]> Cypherpunks repositories - gostls13.git/commitdiff
http: Transport hook to register non-http(s) protocols
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 25 May 2011 19:31:11 +0000 (12:31 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 25 May 2011 19:31:11 +0000 (12:31 -0700)
This permits external packages implementing e.g.
FTP or gopher to register themselves with the
http.DefaultClient:

package ftp
func init() {
    http.DefaultTransport.RegisterProtocol("ftp", &ftp{})
}

Client code would look like:

import (
    _ "github.com/exampleuser/go/gopher"
    _ "github.com/exampleuser/go/ftp"
)

func main() {
    resp, err := http.Get("ftp://example.com/path.txt")
    ...
}

R=dsymonds, rsc
CC=golang-dev
https://golang.org/cl/4526077

src/pkg/http/transport.go
src/pkg/http/transport_test.go

index fdb1b0829a04616c55cc3bac0d50b31a8445ea0a..2b5e5a4250e0cc4e171f223039474ccd2b52524c 100644 (file)
@@ -36,6 +36,7 @@ const DefaultMaxIdleConnsPerHost = 2
 type Transport struct {
        lk       sync.Mutex
        idleConn map[string][]*persistConn
+       altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
 
        // TODO: tunable on global max cached connections
        // TODO: tunable on timeout on cached connections
@@ -97,7 +98,16 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
                }
        }
        if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
-               return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
+               t.lk.Lock()
+               var rt RoundTripper
+               if t.altProto != nil {
+                       rt = t.altProto[req.URL.Scheme]
+               }
+               t.lk.Unlock()
+               if rt == nil {
+                       return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
+               }
+               return rt.RoundTrip(req)
        }
 
        cm, err := t.connectMethodForRequest(req)
@@ -117,6 +127,27 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
        return pconn.roundTrip(req)
 }
 
+// RegisterProtocol registers a new protocol with scheme.
+// The Transport will pass requests using the given scheme to rt.
+// It is rt's responsibility to simulate HTTP request semantics.
+//
+// RegisterProtocol can be used by other packages to provide
+// implementations of protocol schemes like "ftp" or "file".
+func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
+       if scheme == "http" || scheme == "https" {
+               panic("protocol " + scheme + " already registered")
+       }
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.altProto == nil {
+               t.altProto = make(map[string]RoundTripper)
+       }
+       if _, exists := t.altProto[scheme]; exists {
+               panic("protocol " + scheme + " already registered")
+       }
+       t.altProto[scheme] = rt
+}
+
 // CloseIdleConnections closes any connections which were previously
 // connected from previous requests but are now sitting idle in
 // a "keep-alive" state. It does not interrupt any connections currently
index 9cd18ffecf42b91e1a28fb136d80504aa53653b6..76e97640e357881f180e88834438951555196877 100644 (file)
@@ -17,6 +17,7 @@ import (
        "io/ioutil"
        "os"
        "strconv"
+       "strings"
        "testing"
        "time"
 )
@@ -531,6 +532,36 @@ func TestTransportGzipRecursive(t *testing.T) {
        }
 }
 
+type fooProto struct{}
+
+func (fooProto) RoundTrip(req *Request) (*Response, os.Error) {
+       res := &Response{
+               Status:     "200 OK",
+               StatusCode: 200,
+               Header:     make(Header),
+               Body:       ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
+       }
+       return res, nil
+}
+
+func TestTransportAltProto(t *testing.T) {
+       tr := &Transport{}
+       c := &Client{Transport: tr}
+       tr.RegisterProtocol("foo", fooProto{})
+       res, err := c.Get("foo://bar.com/path")
+       if err != nil {
+               t.Fatal(err)
+       }
+       bodyb, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal(err)
+       }
+       body := string(bodyb)
+       if e := "You wanted foo://bar.com/path"; body != e {
+               t.Errorf("got response %q, want %q", body, e)
+       }
+}
+
 // rgz is a gzip quine that uncompresses to itself.
 var rgz = []byte{
        0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,