]> Cypherpunks repositories - gostls13.git/commitdiff
http: add file protocol transport
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 31 Aug 2011 04:47:41 +0000 (21:47 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 31 Aug 2011 04:47:41 +0000 (21:47 -0700)
Off by default (security risk), but users can
wire it up if desired.

Fixes #2113

R=rsc, bradfitz
CC=golang-dev
https://golang.org/cl/4959049

src/pkg/http/Makefile
src/pkg/http/filetransport.go [new file with mode: 0644]
src/pkg/http/filetransport_test.go [new file with mode: 0644]

index df4ab95101b39adce1bd32dfc4fc7880f0ed24f4..af4fbc12e024ba36a415373afb8d6210e967af02 100644 (file)
@@ -10,6 +10,7 @@ GOFILES=\
        client.go\
        cookie.go\
        dump.go\
+       filetransport.go\
        fs.go\
        header.go\
        lex.go\
diff --git a/src/pkg/http/filetransport.go b/src/pkg/http/filetransport.go
new file mode 100644 (file)
index 0000000..78f3aa2
--- /dev/null
@@ -0,0 +1,124 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "fmt"
+       "io"
+       "os"
+)
+
+// fileTransport implements RoundTripper for the 'file' protocol.
+type fileTransport struct {
+       fh fileHandler
+}
+
+// NewFileTransport returns a new RoundTripper, serving the provided
+// FileSystem. The returned RoundTripper ignores the URL host in its
+// incoming requests, as well as most other properties of the
+// request.
+//
+// The typical use case for NewFileTransport is to register the "file"
+// protocol with a Transport, as in:
+//
+//   t := &http.Transport{}
+//   t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
+//   c := &http.Client{Transport: t}
+//   res, err := c.Get("file:///etc/passwd")
+//   ...
+func NewFileTransport(fs FileSystem) RoundTripper {
+       return fileTransport{fileHandler{fs}}
+}
+
+func (t fileTransport) RoundTrip(req *Request) (resp *Response, err os.Error) {
+       // We start ServeHTTP in a goroutine, which may take a long
+       // time if the file is large.  The newPopulateResponseWriter
+       // call returns a channel which either ServeHTTP or finish()
+       // sends our *Response on, once the *Response itself has been
+       // populated (even if the body itself is still being
+       // written to the res.Body, a pipe)
+       rw, resc := newPopulateResponseWriter()
+       go func() {
+               t.fh.ServeHTTP(rw, req)
+               rw.finish()
+       }()
+       return <-resc, nil
+}
+
+func newPopulateResponseWriter() (*populateResponse, <-chan *Response) {
+       pr, pw := io.Pipe()
+       rw := &populateResponse{
+               ch: make(chan *Response),
+               pw: pw,
+               res: &Response{
+                       Proto:      "HTTP/1.0",
+                       ProtoMajor: 1,
+                       Header:     make(Header),
+                       Close:      true,
+                       Body:       pr,
+               },
+       }
+       return rw, rw.ch
+}
+
+// populateResponse is a ResponseWriter that populates the *Response
+// in res, and writes its body to a pipe connected to the response
+// body. Once writes begin or finish() is called, the response is sent
+// on ch.
+type populateResponse struct {
+       res          *Response
+       ch           chan *Response
+       wroteHeader  bool
+       hasContent   bool
+       sentResponse bool
+       pw           *io.PipeWriter
+}
+
+func (pr *populateResponse) finish() {
+       if !pr.wroteHeader {
+               pr.WriteHeader(500)
+       }
+       if !pr.sentResponse {
+               pr.sendResponse()
+       }
+       pr.pw.Close()
+}
+
+func (pr *populateResponse) sendResponse() {
+       if pr.sentResponse {
+               return
+       }
+       pr.sentResponse = true
+
+       if pr.hasContent {
+               pr.res.ContentLength = -1
+       }
+       pr.ch <- pr.res
+}
+
+func (pr *populateResponse) Header() Header {
+       return pr.res.Header
+}
+
+func (pr *populateResponse) WriteHeader(code int) {
+       if pr.wroteHeader {
+               return
+       }
+       pr.wroteHeader = true
+
+       pr.res.StatusCode = code
+       pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code))
+}
+
+func (pr *populateResponse) Write(p []byte) (n int, err os.Error) {
+       if !pr.wroteHeader {
+               pr.WriteHeader(StatusOK)
+       }
+       pr.hasContent = true
+       if !pr.sentResponse {
+               pr.sendResponse()
+       }
+       return pr.pw.Write(p)
+}
diff --git a/src/pkg/http/filetransport_test.go b/src/pkg/http/filetransport_test.go
new file mode 100644 (file)
index 0000000..2634243
--- /dev/null
@@ -0,0 +1,63 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+       "http"
+       "io/ioutil"
+       "path/filepath"
+       "os"
+       "testing"
+)
+
+func checker(t *testing.T) func(string, os.Error) {
+       return func(call string, err os.Error) {
+               if err == nil {
+                       return
+               }
+               t.Fatalf("%s: %v", call, err)
+       }
+}
+
+func TestFileTransport(t *testing.T) {
+       check := checker(t)
+
+       dname, err := ioutil.TempDir("", "")
+       check("TempDir", err)
+       fname := filepath.Join(dname, "foo.txt")
+       err = ioutil.WriteFile(fname, []byte("Bar"), 0644)
+       check("WriteFile", err)
+
+       tr := &http.Transport{}
+       tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname)))
+       c := &http.Client{Transport: tr}
+
+       fooURLs := []string{"file:///foo.txt", "file://../foo.txt"}
+       for _, urlstr := range fooURLs {
+               res, err := c.Get(urlstr)
+               check("Get "+urlstr, err)
+               if res.StatusCode != 200 {
+                       t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode)
+               }
+               if res.ContentLength != -1 {
+                       t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength)
+               }
+               if res.Body == nil {
+                       t.Fatalf("for %s, nil Body", urlstr)
+               }
+               slurp, err := ioutil.ReadAll(res.Body)
+               check("ReadAll "+urlstr, err)
+               if string(slurp) != "Bar" {
+                       t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar")
+               }
+       }
+
+       const badURL = "file://../no-exist.txt"
+       res, err := c.Get(badURL)
+       check("Get "+badURL, err)
+       if res.StatusCode != 404 {
+               t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode)
+       }
+}