]> Cypherpunks repositories - gostls13.git/commitdiff
http: let FileServer work when path doesn't begin with a slash
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 18 Jul 2011 16:04:48 +0000 (09:04 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 18 Jul 2011 16:04:48 +0000 (09:04 -0700)
... as when it's over-stripped with StripPrefix.

R=golang-dev, andybalholm, rsc
CC=golang-dev
https://golang.org/cl/4759052

src/pkg/http/fs.go
src/pkg/http/fs_test.go

index 34fe77d6bd677c1ea021f7d283f44bfcf7e2fac4..4a514beb0b8a0f26ac10f38e8204f29ecd527828 100644 (file)
@@ -242,7 +242,12 @@ func FileServer(root FileSystem) Handler {
 }
 
 func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
-       serveFile(w, r, f.root, path.Clean(r.URL.Path), true)
+       upath := r.URL.Path
+       if !strings.HasPrefix(upath, "/") {
+               upath = "/" + upath
+               r.URL.Path = upath
+       }
+       serveFile(w, r, f.root, path.Clean(upath), true)
 }
 
 // httpRange specifies the byte range to be sent to the client.
index e278e25399d34febb6d5e36b07d37fe7c72eecb1..0101ad88ca9b96478a42ce8b2e9c08f151eb0673 100644 (file)
@@ -10,6 +10,8 @@ import (
        "http/httptest"
        "io/ioutil"
        "os"
+       "path/filepath"
+       "strings"
        "testing"
 )
 
@@ -117,6 +119,36 @@ func TestFileServerCleans(t *testing.T) {
        }
 }
 
+func TestFileServerImplicitLeadingSlash(t *testing.T) {
+       tempDir, err := ioutil.TempDir("", "")
+       if err != nil {
+               t.Fatalf("TempDir: %v", err)
+       }
+       defer os.RemoveAll(tempDir)
+       if err := ioutil.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil {
+               t.Fatalf("WriteFile: %v", err)
+       }
+       ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir))))
+       defer ts.Close()
+       get := func(suffix string) string {
+               res, err := Get(ts.URL + suffix)
+               if err != nil {
+                       t.Fatalf("Get %s: %v", suffix, err)
+               }
+               b, err := ioutil.ReadAll(res.Body)
+               if err != nil {
+                       t.Fatalf("ReadAll %s: %v", suffix, err)
+               }
+               return string(b)
+       }
+       if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") {
+               t.Logf("expected a directory listing with foo.txt, got %q", s)
+       }
+       if s := get("/bar/foo.txt"); s != "Hello world" {
+               t.Logf("expected %q, got %q", "Hello world", s)
+       }
+}
+
 func TestDirJoin(t *testing.T) {
        wfi, err := os.Stat("/etc/hosts")
        if err != nil {