"io"
"math"
"mime/multipart"
+ "net/http"
. "net/http"
"net/http/httptest"
"net/url"
})
server := httptest.NewServer(mux)
defer server.Close()
- _, err := Get(server.URL + test.url)
+ res, err := Get(server.URL + test.url)
if err != nil {
t.Fatal(err)
}
+ res.Body.Close()
}
}
})
server := httptest.NewServer(mux)
defer server.Close()
- _, err := Get(server.URL + "/a/b/c/d/e")
+ res, err := Get(server.URL + "/a/b/c/d/e")
if err != nil {
t.Fatal(err)
}
+ res.Body.Close()
+}
+
+func TestStatus(t *testing.T) {
+ // The main purpose of this test is to check 405 responses and the Allow header.
+ h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
+ mux := NewServeMux()
+ mux.Handle("GET /g", h)
+ mux.Handle("POST /p", h)
+ mux.Handle("PATCH /p", h)
+ mux.Handle("PUT /r", h)
+ mux.Handle("GET /r/", h)
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ for _, test := range []struct {
+ method, path string
+ wantStatus int
+ wantAllow string
+ }{
+ {"GET", "/g", 200, ""},
+ {"HEAD", "/g", 200, ""},
+ {"POST", "/g", 405, "GET, HEAD"},
+ {"GET", "/x", 404, ""},
+ {"GET", "/p", 405, "PATCH, POST"},
+ {"GET", "/./p", 405, "PATCH, POST"},
+ {"GET", "/r/", 200, ""},
+ {"GET", "/r", 200, ""}, // redirected
+ {"HEAD", "/r/", 200, ""},
+ {"HEAD", "/r", 200, ""}, // redirected
+ {"PUT", "/r/", 405, "GET, HEAD"},
+ {"PUT", "/r", 200, ""},
+ } {
+ req, err := http.NewRequest(test.method, server.URL+test.path, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if g, w := res.StatusCode, test.wantStatus; g != w {
+ t.Errorf("%s %s: got %d, want %d", test.method, test.path, g, w)
+ }
+ if g, w := res.Header.Get("Allow"), test.wantAllow; g != w {
+ t.Errorf("%s %s, Allow: got %q, want %q", test.method, test.path, g, w)
+ }
+ }
}
}
return path[:i], path[i:]
}
+
+// matchingMethods adds to methodSet all the methods that would result in a
+// match if passed to routingNode.match with the given host and path.
+func (root *routingNode) matchingMethods(host, path string, methodSet map[string]bool) {
+ if host != "" {
+ root.findChild(host).matchingMethodsPath(path, methodSet)
+ }
+ root.emptyChild.matchingMethodsPath(path, methodSet)
+ if methodSet["GET"] {
+ methodSet["HEAD"] = true
+ }
+}
+
+func (n *routingNode) matchingMethodsPath(path string, set map[string]bool) {
+ if n == nil {
+ return
+ }
+ n.children.eachPair(func(method string, c *routingNode) bool {
+ if p, _ := c.matchPath(path, nil); p != nil {
+ set[method] = true
+ }
+ return true
+ })
+ // Don't look at the empty child. If there were an empty
+ // child, it would match on any method, but we only
+ // call this when we fail to match on a method.
+}
})
}
+func TestMatchingMethods(t *testing.T) {
+ hostTree := buildTree("GET a.com/", "PUT b.com/", "POST /foo/{x}")
+ for _, test := range []struct {
+ name string
+ tree *routingNode
+ host, path string
+ want string
+ }{
+ {
+ "post",
+ buildTree("POST /"), "", "/foo",
+ "POST",
+ },
+ {
+ "get",
+ buildTree("GET /"), "", "/foo",
+ "GET,HEAD",
+ },
+ {
+ "host",
+ hostTree, "", "/foo",
+ "",
+ },
+ {
+ "host",
+ hostTree, "", "/foo/bar",
+ "POST",
+ },
+ {
+ "host2",
+ hostTree, "a.com", "/foo/bar",
+ "GET,HEAD,POST",
+ },
+ {
+ "host3",
+ hostTree, "b.com", "/bar",
+ "PUT",
+ },
+ {
+ // This case shouldn't come up because we only call matchingMethods
+ // when there was no match, but we include it for completeness.
+ "empty",
+ buildTree("/"), "", "/",
+ "",
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ ms := map[string]bool{}
+ test.tree.matchingMethods(test.host, test.path, ms)
+ keys := mapKeys(ms)
+ sort.Strings(keys)
+ got := strings.Join(keys, ",")
+ if got != test.want {
+ t.Errorf("got %s, want %s", got, test.want)
+ }
+ })
+ }
+}
+
func (n *routingNode) print(w io.Writer, level int) {
indent := strings.Repeat(" ", level)
if n.pattern != nil {
urlpkg "net/url"
"path"
"runtime"
+ "sort"
"strconv"
"strings"
"sync"
// TODO(jba): use escaped path. This is an independent change that is also part
// of proposal https://go.dev/issue/61410.
path := r.URL.Path
-
+ host := r.URL.Host
// CONNECT requests are not canonicalized.
if r.Method == "CONNECT" {
// If r.URL.Path is /tree and its handler is not registered,
// the /tree -> /tree/ redirect applies to CONNECT requests
// but the path canonicalization does not.
- _, _, u := mux.matchOrRedirect(r.URL.Host, r.Method, path, r.URL)
+ _, _, u := mux.matchOrRedirect(host, r.Method, path, r.URL)
if u != nil {
return RedirectHandler(u.String(), StatusMovedPermanently), u.Path, nil, nil
}
} else {
// All other requests have any port stripped and path cleaned
// before passing to mux.handler.
- host := stripHostPort(r.Host)
+ host = stripHostPort(r.Host)
path = cleanPath(path)
// If the given path is /tree and its handler is not registered,
}
}
if n == nil {
- // TODO(jba): support 405 (MethodNotAllowed) by checking for patterns with different methods.
+ // We didn't find a match with the request method. To distinguish between
+ // Not Found and Method Not Allowed, see if there is another pattern that
+ // matches except for the method.
+ allowedMethods := mux.matchingMethods(host, path)
+ if len(allowedMethods) > 0 {
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Allow", strings.Join(allowedMethods, ", "))
+ Error(w, StatusText(StatusMethodNotAllowed), StatusMethodNotAllowed)
+ }), "", nil, nil
+ }
return NotFoundHandler(), "", nil, nil
}
return n.handler, n.pattern.String(), n.pattern, matches
return len(n.pattern.segments) == strings.Count(path, "/")
}
+// matchingMethods return a sorted list of all methods that would match with the given host and path.
+func (mux *ServeMux) matchingMethods(host, path string) []string {
+ // Hold the read lock for the entire method so that the two matches are done
+ // on the same set of registered patterns.
+ mux.mu.RLock()
+ defer mux.mu.RUnlock()
+ ms := map[string]bool{}
+ mux.tree.matchingMethods(host, path, ms)
+ // matchOrRedirect will try appending a trailing slash if there is no match.
+ mux.tree.matchingMethods(host, path+"/", ms)
+ methods := mapKeys(ms)
+ sort.Strings(methods)
+ return methods
+}
+
+// TODO: replace with maps.Keys when it is defined.
+func mapKeys[K comparable, V any](m map[K]V) []K {
+ var ks []K
+ for k := range m {
+ ks = append(ks, k)
+ }
+ return ks
+}
+
// ServeHTTP dispatches the request to the handler whose
// pattern most closely matches the request URL.
func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {