From: Kelsey Hightower Date: Sat, 30 Aug 2014 05:19:30 +0000 (-0700) Subject: net/http: add BasicAuth method to *http.Request X-Git-Tag: go1.4beta1~611 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=29f9f3ec80237ff0b2eb17f042747c08a10e0a14;p=gostls13.git net/http: add BasicAuth method to *http.Request The net/http package supports setting the HTTP Authorization header using the Basic Authentication Scheme as defined in RFC 2617, but does not provide support for extracting the username and password from an authenticated request using the Basic Authentication Scheme. Add BasicAuth method to *http.Request that returns the username and password from authenticated requests using the Basic Authentication Scheme. Fixes #6779. LGTM=bradfitz R=golang-codereviews, josharian, bradfitz, alberto.garcia.hierro, blakesgentry CC=golang-codereviews https://golang.org/cl/76540043 --- diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index 6372943188..263c26c9bd 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "crypto/tls" + "encoding/base64" "errors" "fmt" "io" @@ -521,6 +522,35 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { return req, nil } +// BasicAuth returns the username and password provided in the request's +// Authorization header, if the request uses HTTP Basic Authentication. +// See RFC 2617, Section 2. +func (r *Request) BasicAuth() (username, password string, ok bool) { + auth := r.Header.Get("Authorization") + if auth == "" { + return + } + return parseBasicAuth(auth) +} + +// parseBasicAuth parses an HTTP Basic Authentication string. +// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). +func parseBasicAuth(auth string) (username, password string, ok bool) { + if !strings.HasPrefix(auth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + return cs[:s], cs[s+1:], true +} + // SetBasicAuth sets the request's Authorization header to use HTTP // Basic Authentication with the provided username and password. // diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go index b9fa3c2bfc..759ea4e8b5 100644 --- a/src/pkg/net/http/request_test.go +++ b/src/pkg/net/http/request_test.go @@ -7,6 +7,7 @@ package http_test import ( "bufio" "bytes" + "encoding/base64" "fmt" "io" "io/ioutil" @@ -396,6 +397,75 @@ func TestParseHTTPVersion(t *testing.T) { } } +type getBasicAuthTest struct { + username, password string + ok bool +} + +type parseBasicAuthTest getBasicAuthTest + +type basicAuthCredentialsTest struct { + username, password string +} + +var getBasicAuthTests = []struct { + username, password string + ok bool +}{ + {"Aladdin", "open sesame", true}, + {"Aladdin", "open:sesame", true}, + {"", "", true}, +} + +func TestGetBasicAuth(t *testing.T) { + for _, tt := range getBasicAuthTests { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.SetBasicAuth(tt.username, tt.password) + username, password, ok := r.BasicAuth() + if ok != tt.ok || username != tt.username || password != tt.password { + t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok}, + getBasicAuthTest{tt.username, tt.password, tt.ok}) + } + } + // Unauthenticated request. + r, _ := NewRequest("GET", "http://example.com/", nil) + username, password, ok := r.BasicAuth() + if ok { + t.Errorf("expected false from BasicAuth when the request is unauthenticated") + } + want := basicAuthCredentialsTest{"", ""} + if username != want.username || password != want.password { + t.Errorf("expected credentials: %#v when the request is unauthenticated, got %#v", + want, basicAuthCredentialsTest{username, password}) + } +} + +var parseBasicAuthTests = []struct { + header, username, password string + ok bool +}{ + {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true}, + {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true}, + {"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true}, + {"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, + {base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, + {"Basic ", "", "", false}, + {"Basic Aladdin:open sesame", "", "", false}, + {`Digest username="Aladdin"`, "", "", false}, +} + +func TestParseBasicAuth(t *testing.T) { + for _, tt := range parseBasicAuthTests { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.Header.Set("Authorization", tt.header) + username, password, ok := r.BasicAuth() + if ok != tt.ok || username != tt.username || password != tt.password { + t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok}, + getBasicAuthTest{tt.username, tt.password, tt.ok}) + } + } +} + type logWrites struct { t *testing.T dst *[]string