From 29f9f3ec80237ff0b2eb17f042747c08a10e0a14 Mon Sep 17 00:00:00 2001 From: Kelsey Hightower Date: Fri, 29 Aug 2014 22:19:30 -0700 Subject: [PATCH] net/http: add BasicAuth method to *http.Request The net/http package supports setting the HTTP Authorization header using the Basic Authentication Scheme as defined in RFC 2617, but does not provide support for extracting the username and password from an authenticated request using the Basic Authentication Scheme. Add BasicAuth method to *http.Request that returns the username and password from authenticated requests using the Basic Authentication Scheme. Fixes #6779. LGTM=bradfitz R=golang-codereviews, josharian, bradfitz, alberto.garcia.hierro, blakesgentry CC=golang-codereviews https://golang.org/cl/76540043 --- src/pkg/net/http/request.go | 30 ++++++++++++++ src/pkg/net/http/request_test.go | 70 ++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index 6372943188..263c26c9bd 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "crypto/tls" + "encoding/base64" "errors" "fmt" "io" @@ -521,6 +522,35 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { return req, nil } +// BasicAuth returns the username and password provided in the request's +// Authorization header, if the request uses HTTP Basic Authentication. +// See RFC 2617, Section 2. +func (r *Request) BasicAuth() (username, password string, ok bool) { + auth := r.Header.Get("Authorization") + if auth == "" { + return + } + return parseBasicAuth(auth) +} + +// parseBasicAuth parses an HTTP Basic Authentication string. +// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). +func parseBasicAuth(auth string) (username, password string, ok bool) { + if !strings.HasPrefix(auth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + return cs[:s], cs[s+1:], true +} + // SetBasicAuth sets the request's Authorization header to use HTTP // Basic Authentication with the provided username and password. // diff --git a/src/pkg/net/http/request_test.go b/src/pkg/net/http/request_test.go index b9fa3c2bfc..759ea4e8b5 100644 --- a/src/pkg/net/http/request_test.go +++ b/src/pkg/net/http/request_test.go @@ -7,6 +7,7 @@ package http_test import ( "bufio" "bytes" + "encoding/base64" "fmt" "io" "io/ioutil" @@ -396,6 +397,75 @@ func TestParseHTTPVersion(t *testing.T) { } } +type getBasicAuthTest struct { + username, password string + ok bool +} + +type parseBasicAuthTest getBasicAuthTest + +type basicAuthCredentialsTest struct { + username, password string +} + +var getBasicAuthTests = []struct { + username, password string + ok bool +}{ + {"Aladdin", "open sesame", true}, + {"Aladdin", "open:sesame", true}, + {"", "", true}, +} + +func TestGetBasicAuth(t *testing.T) { + for _, tt := range getBasicAuthTests { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.SetBasicAuth(tt.username, tt.password) + username, password, ok := r.BasicAuth() + if ok != tt.ok || username != tt.username || password != tt.password { + t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok}, + getBasicAuthTest{tt.username, tt.password, tt.ok}) + } + } + // Unauthenticated request. + r, _ := NewRequest("GET", "http://example.com/", nil) + username, password, ok := r.BasicAuth() + if ok { + t.Errorf("expected false from BasicAuth when the request is unauthenticated") + } + want := basicAuthCredentialsTest{"", ""} + if username != want.username || password != want.password { + t.Errorf("expected credentials: %#v when the request is unauthenticated, got %#v", + want, basicAuthCredentialsTest{username, password}) + } +} + +var parseBasicAuthTests = []struct { + header, username, password string + ok bool +}{ + {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true}, + {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true}, + {"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true}, + {"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, + {base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, + {"Basic ", "", "", false}, + {"Basic Aladdin:open sesame", "", "", false}, + {`Digest username="Aladdin"`, "", "", false}, +} + +func TestParseBasicAuth(t *testing.T) { + for _, tt := range parseBasicAuthTests { + r, _ := NewRequest("GET", "http://example.com/", nil) + r.Header.Set("Authorization", tt.header) + username, password, ok := r.BasicAuth() + if ok != tt.ok || username != tt.username || password != tt.password { + t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok}, + getBasicAuthTest{tt.username, tt.password, tt.ok}) + } + } +} + type logWrites struct { t *testing.T dst *[]string -- 2.48.1