From be11422b1ec46fb69b387ef29a521ed42621fe3d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 18 Sep 2023 15:51:04 -0400 Subject: [PATCH] net/http: index patterns for faster conflict detection Add an index so that pattern registration isn't always quadratic. If there were no index, then every pattern that was registered would have to be compared to every existing pattern for conflicts. This would make registration quadratic in the number of patterns, in every case. The index in this CL should help most of the time. If a pattern has a literal segment, it will weed out all other patterns that have a different literal in that position. The worst case will still be quadratic, but it is unlikely that a set of such patterns would arise naturally. One novel (to me) aspect of the CL is the use of fuzz testing on data that is neither a string nor a byte slice. The test uses fuzzing to generate a byte slice, then decodes the byte slice into a valid pattern (most of the time). This test actually caught a bug: see https://go.dev/cl/529119. Change-Id: Ice0be6547decb5ce75a8062e4e17227815d5d0b0 Reviewed-on: https://go-review.googlesource.com/c/go/+/529121 Run-TryBot: Jonathan Amsterdam TryBot-Result: Gopher Robot Reviewed-by: Damien Neil --- src/net/http/routing_index.go | 124 +++++++++++ src/net/http/routing_index_test.go | 207 ++++++++++++++++++ src/net/http/server.go | 20 +- .../testdata/fuzz/FuzzIndex/48161038f0c8b2da | 2 + .../testdata/fuzz/FuzzIndex/716514f590ce7ab3 | 2 + 5 files changed, 346 insertions(+), 9 deletions(-) create mode 100644 src/net/http/routing_index.go create mode 100644 src/net/http/routing_index_test.go create mode 100644 src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da create mode 100644 src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 diff --git a/src/net/http/routing_index.go b/src/net/http/routing_index.go new file mode 100644 index 0000000000..9ac42c997d --- /dev/null +++ b/src/net/http/routing_index.go @@ -0,0 +1,124 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import "math" + +// A routingIndex optimizes conflict detection by indexing patterns. +// +// The basic idea is to rule out patterns that cannot conflict with a given +// pattern because they have a different literal in a corresponding segment. +// See the comments in [routingIndex.possiblyConflictingPatterns] for more details. +type routingIndex struct { + // map from a particular segment position and value to all registered patterns + // with that value in that position. + // For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c" + // but not "/a", "b/a", "/a/c" or "/a/{x}". + segments map[routingIndexKey][]*pattern + // All patterns that end in a multi wildcard (including trailing slash). + // We do not try to be clever about indexing multi patterns, because there + // are unlikely to be many of them. + multis []*pattern +} + +type routingIndexKey struct { + pos int // 0-based segment position + s string // literal, or empty for wildcard +} + +func (idx *routingIndex) addPattern(pat *pattern) { + if pat.lastSegment().multi { + idx.multis = append(idx.multis, pat) + } else { + if idx.segments == nil { + idx.segments = map[routingIndexKey][]*pattern{} + } + for pos, seg := range pat.segments { + key := routingIndexKey{pos: pos, s: ""} + if !seg.wild { + key.s = seg.s + } + idx.segments[key] = append(idx.segments[key], pat) + } + } +} + +// possiblyConflictingPatterns calls f on all patterns that might conflict with +// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately +// with that error. +// +// To be correct, possiblyConflictingPatterns must include all patterns that +// might conflict. But it may also include patterns that cannot conflict. +// For instance, an implementation that returns all registered patterns is correct. +// We use this fact throughout, simplifying the implementation by returning more +// patterns that we might need to. +func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) { + // Terminology: + // dollar pattern: one ending in "{$}" + // multi pattern: one ending in a trailing slash or "{x...}" wildcard + // ordinary pattern: neither of the above + + // apply f to all the pats, stopping on error. + apply := func(pats []*pattern) error { + if err != nil { + return err + } + for _, p := range pats { + err = f(p) + if err != nil { + return err + } + } + return nil + } + + // Our simple indexing scheme doesn't try to prune multi patterns; assume + // any of them can match the argument. + if err := apply(idx.multis); err != nil { + return err + } + if pat.lastSegment().s == "/" { + // All paths that a dollar pattern matches end in a slash; no paths that + // an ordinary pattern matches do. So only other dollar or multi + // patterns can conflict with a dollar pattern. Furthermore, conflicting + // dollar patterns must have the {$} in the same position. + return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}]) + } + // For ordinary and multi patterns, the only conflicts can be with a multi, + // or a pattern that has the same literal or a wildcard at some literal + // position. + // We could intersect all the possible matches at each position, but we + // do something simpler: we find the position with the fewest patterns. + var lmin, wmin []*pattern + min := math.MaxInt + hasLit := false + for i, seg := range pat.segments { + if seg.multi { + break + } + if !seg.wild { + hasLit = true + lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}] + wpats := idx.segments[routingIndexKey{s: "", pos: i}] + if sum := len(lpats) + len(wpats); sum < min { + lmin = lpats + wmin = wpats + min = sum + } + } + } + if hasLit { + apply(lmin) + apply(wmin) + return err + } + + // This pattern is all wildcards. + // Check it against everything. + for _, pats := range idx.segments { + apply(pats) + } + return err +} diff --git a/src/net/http/routing_index_test.go b/src/net/http/routing_index_test.go new file mode 100644 index 0000000000..7030fc8a67 --- /dev/null +++ b/src/net/http/routing_index_test.go @@ -0,0 +1,207 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "fmt" + "slices" + "sort" + "strings" + "testing" +) + +func TestIndex(t *testing.T) { + pats := []string{"HEAD /", "/a"} + + var patterns []*pattern + var idx routingIndex + for _, p := range pats { + pat := mustParsePattern(t, p) + patterns = append(patterns, pat) + idx.addPattern(pat) + } + + compare := func(pat *pattern) { + t.Helper() + got := indexConflicts(pat, &idx) + want := trueConflicts(pat, patterns) + if !slices.Equal(got, want) { + t.Errorf("%q:\ngot %q\nwant %q", pat, got, want) + } + } + + compare(mustParsePattern(t, "GET /foo")) + compare(mustParsePattern(t, "GET /{x}")) +} + +// This test works by comparing possiblyConflictingPatterns with +// an exhaustive loop through all patterns. +func FuzzIndex(f *testing.F) { + inits := []string{"/a", "/a/b", "/{x0}", "/{x0}/b", "/a/{x0}", "/a/{$}", "/a/b/{$}", + "/a/", "/a/b/", "/{x}/b/c/{$}", "GET /{x0}/", "HEAD /a"} + + var patterns []*pattern + var idx routingIndex + + // compare takes a fatalf function because fuzzing doesn't like + // it when the fuzz function calls f.Fatalf. + compare := func(pat *pattern, fatalf func(string, ...any)) { + got := indexConflicts(pat, &idx) + want := trueConflicts(pat, patterns) + if !slices.Equal(got, want) { + fatalf("%q:\ngot %q\nwant %q", pat, got, want) + } + } + + for _, p := range inits { + pat, err := parsePattern(p) + if err != nil { + f.Fatal(err) + } + compare(pat, f.Fatalf) + patterns = append(patterns, pat) + idx.addPattern(pat) + f.Add(bytesFromPattern(pat)) + } + + f.Fuzz(func(t *testing.T, pb []byte) { + pat := bytesToPattern(pb) + if pat == nil { + return + } + compare(pat, t.Fatalf) + }) +} + +func trueConflicts(pat *pattern, pats []*pattern) []string { + var s []string + for _, p := range pats { + if pat.conflictsWith(p) { + s = append(s, p.String()) + } + } + sort.Strings(s) + return s +} + +func indexConflicts(pat *pattern, idx *routingIndex) []string { + var s []string + idx.possiblyConflictingPatterns(pat, func(p *pattern) error { + if pat.conflictsWith(p) { + s = append(s, p.String()) + } + return nil + }) + sort.Strings(s) + return slices.Compact(s) +} + +// TODO: incorporate host and method; make encoding denser. +func bytesToPattern(bs []byte) *pattern { + if len(bs) == 0 { + return nil + } + var sb strings.Builder + wc := 0 + for _, b := range bs[:len(bs)-1] { + sb.WriteByte('/') + switch b & 0x3 { + case 0: + fmt.Fprintf(&sb, "{x%d}", wc) + wc++ + case 1: + sb.WriteString("a") + case 2: + sb.WriteString("b") + case 3: + sb.WriteString("c") + } + } + sb.WriteByte('/') + switch bs[len(bs)-1] & 0x7 { + case 0: + fmt.Fprintf(&sb, "{x%d}", wc) + case 1: + sb.WriteString("a") + case 2: + sb.WriteString("b") + case 3: + sb.WriteString("c") + case 4, 5: + fmt.Fprintf(&sb, "{x%d...}", wc) + default: + sb.WriteString("{$}") + } + pat, err := parsePattern(sb.String()) + if err != nil { + panic(err) + } + return pat +} + +func bytesFromPattern(p *pattern) []byte { + var bs []byte + for _, s := range p.segments { + var b byte + switch { + case s.multi: + b = 4 + case s.wild: + b = 0 + case s.s == "/": + b = 7 + case s.s == "a": + b = 1 + case s.s == "b": + b = 2 + case s.s == "c": + b = 3 + default: + panic("bad pattern") + } + bs = append(bs, b) + } + return bs +} + +func TestBytesPattern(t *testing.T) { + tests := []struct { + bs []byte + pat string + }{ + {[]byte{0, 1, 2, 3}, "/{x0}/a/b/c"}, + {[]byte{16, 17, 18, 19}, "/{x0}/a/b/c"}, + {[]byte{4, 4}, "/{x0}/{x1...}"}, + {[]byte{6, 7}, "/b/{$}"}, + } + t.Run("To", func(t *testing.T) { + for _, test := range tests { + p := bytesToPattern(test.bs) + got := p.String() + if got != test.pat { + t.Errorf("%v: got %q, want %q", test.bs, got, test.pat) + } + } + }) + t.Run("From", func(t *testing.T) { + for _, test := range tests { + p, err := parsePattern(test.pat) + if err != nil { + t.Fatal(err) + } + got := bytesFromPattern(p) + var want []byte + for _, b := range test.bs[:len(test.bs)-1] { + want = append(want, b%4) + + } + want = append(want, test.bs[len(test.bs)-1]%8) + if !bytes.Equal(got, want) { + t.Errorf("%s: got %v, want %v", test.pat, got, want) + } + } + }) +} diff --git a/src/net/http/server.go b/src/net/http/server.go index bc5bcb9a71..629d8d3c62 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -2347,7 +2347,8 @@ func RedirectHandler(url string, code int) Handler { type ServeMux struct { mu sync.RWMutex tree routingNode - patterns []*pattern + index routingIndex + patterns []*pattern // TODO(jba): remove if possible } // NewServeMux allocates and returns a new ServeMux. @@ -2624,8 +2625,8 @@ func (mux *ServeMux) register(pattern string, handler Handler) { } } -func (mux *ServeMux) registerErr(pattern string, handler Handler) error { - if pattern == "" { +func (mux *ServeMux) registerErr(patstr string, handler Handler) error { + if patstr == "" { return errors.New("http: invalid pattern") } if handler == nil { @@ -2635,9 +2636,9 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error { return errors.New("http: nil handler") } - pat, err := parsePattern(pattern) + pat, err := parsePattern(patstr) if err != nil { - return fmt.Errorf("parsing %q: %w", pattern, err) + return fmt.Errorf("parsing %q: %w", patstr, err) } // Get the caller's location, for better conflict error messages. @@ -2652,16 +2653,17 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error { mux.mu.Lock() defer mux.mu.Unlock() // Check for conflict. - // This makes a quadratic number of calls to conflictsWith: we check - // each pattern against every other pattern. - // TODO(jba): add indexing to speed this up. - for _, pat2 := range mux.patterns { + if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error { if pat.conflictsWith(pat2) { return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s)", pat, pat.loc, pat2, pat2.loc) } + return nil + }); err != nil { + return err } mux.tree.addPattern(pat, handler) + mux.index.addPattern(pat) mux.patterns = append(mux.patterns, pat) return nil } diff --git a/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da b/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da new file mode 100644 index 0000000000..06a7336a8d --- /dev/null +++ b/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("101$") diff --git a/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 b/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 new file mode 100644 index 0000000000..520bff177b --- /dev/null +++ b/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("1010") -- 2.50.0