]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add BasicAuth method to *http.Request
authorKelsey Hightower <kelsey.hightower@gmail.com>
Sat, 30 Aug 2014 05:19:30 +0000 (22:19 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 30 Aug 2014 05:19:30 +0000 (22:19 -0700)
The net/http package supports setting the HTTP Authorization header
using the Basic Authentication Scheme as defined in RFC 2617, but does
not provide support for extracting the username and password from an
authenticated request using the Basic Authentication Scheme.

Add BasicAuth method to *http.Request that returns the username and
password from authenticated requests using the Basic Authentication
Scheme.

Fixes #6779.

LGTM=bradfitz
R=golang-codereviews, josharian, bradfitz, alberto.garcia.hierro, blakesgentry
CC=golang-codereviews
https://golang.org/cl/76540043

src/pkg/net/http/request.go
src/pkg/net/http/request_test.go

index 63729431887a126cfce57e09595b63e8a988e735..263c26c9bdd8343310d6c5a11e9ab71ef3a6a835 100644 (file)
@@ -10,6 +10,7 @@ import (
        "bufio"
        "bytes"
        "crypto/tls"
+       "encoding/base64"
        "errors"
        "fmt"
        "io"
@@ -521,6 +522,35 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
        return req, nil
 }
 
+// BasicAuth returns the username and password provided in the request's
+// Authorization header, if the request uses HTTP Basic Authentication.
+// See RFC 2617, Section 2.
+func (r *Request) BasicAuth() (username, password string, ok bool) {
+       auth := r.Header.Get("Authorization")
+       if auth == "" {
+               return
+       }
+       return parseBasicAuth(auth)
+}
+
+// parseBasicAuth parses an HTTP Basic Authentication string.
+// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
+func parseBasicAuth(auth string) (username, password string, ok bool) {
+       if !strings.HasPrefix(auth, "Basic ") {
+               return
+       }
+       c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
+       if err != nil {
+               return
+       }
+       cs := string(c)
+       s := strings.IndexByte(cs, ':')
+       if s < 0 {
+               return
+       }
+       return cs[:s], cs[s+1:], true
+}
+
 // SetBasicAuth sets the request's Authorization header to use HTTP
 // Basic Authentication with the provided username and password.
 //
index b9fa3c2bfc4f68fae9fb4df302325971dd356355..759ea4e8b5d8ed2daf2d5c5b7a31c1b995b35513 100644 (file)
@@ -7,6 +7,7 @@ package http_test
 import (
        "bufio"
        "bytes"
+       "encoding/base64"
        "fmt"
        "io"
        "io/ioutil"
@@ -396,6 +397,75 @@ func TestParseHTTPVersion(t *testing.T) {
        }
 }
 
+type getBasicAuthTest struct {
+       username, password string
+       ok                 bool
+}
+
+type parseBasicAuthTest getBasicAuthTest
+
+type basicAuthCredentialsTest struct {
+       username, password string
+}
+
+var getBasicAuthTests = []struct {
+       username, password string
+       ok                 bool
+}{
+       {"Aladdin", "open sesame", true},
+       {"Aladdin", "open:sesame", true},
+       {"", "", true},
+}
+
+func TestGetBasicAuth(t *testing.T) {
+       for _, tt := range getBasicAuthTests {
+               r, _ := NewRequest("GET", "http://example.com/", nil)
+               r.SetBasicAuth(tt.username, tt.password)
+               username, password, ok := r.BasicAuth()
+               if ok != tt.ok || username != tt.username || password != tt.password {
+                       t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok},
+                               getBasicAuthTest{tt.username, tt.password, tt.ok})
+               }
+       }
+       // Unauthenticated request.
+       r, _ := NewRequest("GET", "http://example.com/", nil)
+       username, password, ok := r.BasicAuth()
+       if ok {
+               t.Errorf("expected false from BasicAuth when the request is unauthenticated")
+       }
+       want := basicAuthCredentialsTest{"", ""}
+       if username != want.username || password != want.password {
+               t.Errorf("expected credentials: %#v when the request is unauthenticated, got %#v",
+                       want, basicAuthCredentialsTest{username, password})
+       }
+}
+
+var parseBasicAuthTests = []struct {
+       header, username, password string
+       ok                         bool
+}{
+       {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+       {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true},
+       {"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true},
+       {"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false},
+       {base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false},
+       {"Basic ", "", "", false},
+       {"Basic Aladdin:open sesame", "", "", false},
+       {`Digest username="Aladdin"`, "", "", false},
+}
+
+func TestParseBasicAuth(t *testing.T) {
+       for _, tt := range parseBasicAuthTests {
+               r, _ := NewRequest("GET", "http://example.com/", nil)
+               r.Header.Set("Authorization", tt.header)
+               username, password, ok := r.BasicAuth()
+               if ok != tt.ok || username != tt.username || password != tt.password {
+                       t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok},
+                               getBasicAuthTest{tt.username, tt.password, tt.ok})
+               }
+       }
+}
+
 type logWrites struct {
        t   *testing.T
        dst *[]string