type FileSet struct {
mutex sync.RWMutex // protects the file set
base int // base offset for the next file
- files []*File // list of files in the order added to the set
+ tree tree // AVL tree of files, keyed by Pos range, in ascending base order
last atomic.Pointer[File] // cache of last file looked up
}
}
// add the file to the file set
s.base = base
- s.files = append(s.files, f)
+ s.tree.add(f)
s.last.Store(f)
return f
}
s.mutex.Lock()
defer s.mutex.Unlock()
- // Merge and sort.
- newFiles := append(s.files, files...)
- slices.SortFunc(newFiles, func(x, y *File) int {
- return cmp.Compare(x.Base(), y.Base())
- })
-
- // Reject overlapping files.
- // Discard adjacent identical files.
- out := newFiles[:0]
- for i, file := range newFiles {
- if i > 0 {
- prev := newFiles[i-1]
- if file == prev {
- continue
- }
- if prev.Base()+prev.Size()+1 > file.Base() {
- panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)",
- prev.Name(), prev.Base(), prev.Base()+prev.Size(),
- file.Name(), file.Base(), file.Base()+file.Size()))
- }
- }
- out = append(out, file)
- }
- newFiles = out
-
- s.files = newFiles
-
- // Advance base.
- if len(newFiles) > 0 {
- last := newFiles[len(newFiles)-1]
- newBase := last.Base() + last.Size() + 1
- if s.base < newBase {
- s.base = newBase
- }
+ for _, f := range files {
+ s.tree.add(f)
+ s.base = max(s.base, f.Base()+f.Size()+1)
}
}
s.mutex.Lock()
defer s.mutex.Unlock()
- if i := searchFiles(s.files, file.base); i >= 0 && s.files[i] == file {
- last := &s.files[len(s.files)-1]
- s.files = slices.Delete(s.files, i, i+1)
- *last = nil // don't prolong lifetime when popping last element
+ pn, _ := s.tree.locate(file.key())
+ if *pn != nil && (*pn).file == file {
+ s.tree.delete(pn)
}
}
-// Iterate calls f for the files in the file set in the order they were added
-// until f returns false.
-func (s *FileSet) Iterate(f func(*File) bool) {
- for i := 0; ; i++ {
- var file *File
- s.mutex.RLock()
- if i < len(s.files) {
- file = s.files[i]
- }
- s.mutex.RUnlock()
- if file == nil || !f(file) {
- break
- }
- }
-}
+// Iterate calls yield for the files in the file set in ascending Base
+// order until yield returns false.
+//
+// The read lock is not held while yield is running.
+func (s *FileSet) Iterate(yield func(*File) bool) {
+ s.mutex.RLock()
+ defer s.mutex.RUnlock()
-func searchFiles(a []*File, x int) int {
- i, found := slices.BinarySearchFunc(a, x, func(a *File, x int) int {
- return cmp.Compare(a.base, x)
+ // Unlock around user code.
+ // The iterator is robust to modification by yield.
+ // Avoid range here, so we can use defer:
+ // each callback releases the read lock, and its deferred
+ // RLock reacquires it before control returns to the tree
+ // iterator, even if yield panics.
+ s.tree.all()(func(f *File) bool {
+ s.mutex.RUnlock()
+ defer s.mutex.RLock()
+ return yield(f)
})
- if !found {
- // We want the File containing x, but if we didn't
- // find x then i is the next one.
- i--
- }
- return i
}
func (s *FileSet) file(p Pos) *File {
s.mutex.RLock()
defer s.mutex.RUnlock()
- // p is not in last file - search all files
- if i := searchFiles(s.files, int(p)); i >= 0 {
- f := s.files[i]
- // f.base <= int(p) by definition of searchFiles
- if int(p) <= f.base+f.size {
- // Update cache of last file. A race is ok,
- // but an exclusive lock causes heavy contention.
- s.last.Store(f)
- return f
- }
+ // Search using the zero-width key {p, p}.
+ // compareKey treats overlapping keys as equal, so locate
+ // stops at the node whose file range contains p, or else
+ // yields an empty (nil) slot if no file contains p.
+ pn, _ := s.tree.locate(key{int(p), int(p)})
+ if n := *pn; n != nil {
+ // Update cache of last file. A race is ok,
+ // but an exclusive lock causes heavy contention.
+ s.last.Store(n.file)
+ return n.file
}
return nil
}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package token
+
+// tree is a self-balancing AVL tree; see
+// Lewis & Denenberg, Data Structures and Their Algorithms.
+//
+// An AVL tree is a binary tree in which the difference between the
+// heights of a node's two subtrees--the node's "balance factor"--is
+// at most one. It is more strictly balanced than a red/black tree,
+// and thus favors lookups at the expense of updates, which is the
+// appropriate trade-off for FileSet.
+//
+// Insertion at a node may cause its ancestors' balance factors to
+// temporarily reach ±2, requiring rebalancing of each such ancestor
+// by a rotation.
+//
+// Each key is the pos-end range of a single File.
+// All Files in the tree must have disjoint ranges.
+//
+// The implementation is simplified from Russ Cox's github.com/rsc/omap.
+
+import (
+ "fmt"
+ "iter"
+)
+
+// A tree is a tree-based ordered map:
+// each value is a *File, keyed by its Pos range.
+// All map entries cover disjoint ranges.
+//
+// The zero value of tree is an empty map ready to use.
+type tree struct {
+ root *node // root of the tree; nil if the tree is empty
+}
+
+// A node is an element of a tree; it holds a single File.
+type node struct {
+ // We use the notation (parent left right) in many comments.
+ parent *node // nil for the root, and for a node that has been deleted
+ left *node // subtree of keys before key, or nil
+ right *node // subtree of keys after key, or nil
+ file *File
+ key key // = file.key(), but improves locality (25% faster)
+ balance int32 // height(right) - height(left); at most ±2
+ height int32 // height of this subtree (0 for a leaf); -1 if the node is not in the tree
+}
+
+// A key represents the Pos range of a File.
+// Both bounds are inclusive: a File's key is {base, base+size},
+// and keys that share even a single endpoint compare as overlapping.
+type key struct{ start, end int }
+
+// key returns the range of positions, from f.base to f.base+f.size
+// inclusive, that f occupies.
+func (f *File) key() key {
+	start := f.base
+	return key{start: start, end: start + f.size}
+}
+
+// compareKey orders two keys: it returns -1 if x lies wholly
+// before y, +1 if x lies wholly after y, and 0 if the two ranges
+// overlap. Provided all files in the tree have disjoint ranges,
+// this is a total order.
+//
+// Because distinct files are separated by at least one unit,
+// strict < comparisons suffice, and a zero-width key{p, p}
+// finds the file containing p even when p is the first or
+// last position of that file.
+func compareKey(x, y key) int {
+	if x.end < y.start {
+		return -1
+	}
+	if y.end < x.start {
+		return +1
+	}
+	return 0
+}
+
+// check recursively verifies the subtree rooted at n: every node's
+// parent link must match its actual parent, and its cached height
+// and balance must be correct. It panics on the first violation.
+// The checks are compiled out unless debugging is enabled.
+func (n *node) check(parent *node) {
+	const debugging = false
+	if !debugging || n == nil {
+		return
+	}
+	if n.parent != parent {
+		panic("bad parent")
+	}
+	n.left.check(n)
+	n.right.check(n)
+	n.checkBalance()
+}
+
+// checkBalance panics unless n's cached balance and height agree
+// with the heights of its children and the balance is within ±2.
+func (n *node) checkBalance() {
+	lh := n.left.safeHeight()
+	rh := n.right.safeHeight()
+	diff := rh - lh
+	if n.balance != diff {
+		panic("bad node.balance")
+	}
+	if diff < -2 || diff > +2 {
+		panic(fmt.Sprintf("node.balance out of range: %d", diff))
+	}
+	if want := 1 + max(lh, rh); n.height != want {
+		panic("bad node.height")
+	}
+}
+
+// locate searches the tree for the node whose key overlaps k.
+//
+// It returns pos, a pointer to the variable (the root field or some
+// node's left or right field) that holds the matching node, together
+// with that variable's owning node, parent (nil for the root). If no
+// node matches, *pos is nil and pos designates the slot where a node
+// for k should be inserted by a subsequent call to [tree.set].
+func (t *tree) locate(k key) (pos **node, parent *node) {
+	pos = &t.root
+	for n := t.root; n != nil; n = *pos {
+		switch sign := compareKey(k, n.key); {
+		case sign < 0:
+			pos, parent = &n.left, n
+		case sign > 0:
+			pos, parent = &n.right, n
+		default:
+			return pos, parent // found
+		}
+	}
+	return pos, parent
+}
+
+// all returns an iterator over the files of tree t,
+// in ascending key order.
+// If t is modified during the iteration,
+// some files may not be visited.
+// No file will be visited multiple times.
+func (t *tree) all() iter.Seq[*File] {
+ return func(yield func(*File) bool) {
+ if t == nil {
+ return
+ }
+ // Start at the leftmost (minimum) node, if any.
+ x := t.root
+ if x != nil {
+ for x.left != nil {
+ x = x.left
+ }
+ }
+ for x != nil && yield(x.file) {
+ // yield may have modified the tree,
+ // so inspect x afresh before advancing.
+ if x.height >= 0 {
+ // still in tree
+ x = x.next()
+ } else {
+ // deleted
+ // (tree.delete set x.height to -1 and cleared
+ // x's links, so resume from the place where
+ // x's key would now be found.)
+ x = t.nextAfter(t.locate(x.key))
+ }
+ }
+ }
+}
+
+// nextAfter returns the node that follows, in key order, the
+// position (pos, parent) reported by [tree.locate], or nil if
+// there is none. The position may be an occupied slot or an
+// empty one.
+func (t *tree) nextAfter(pos **node, parent *node) *node {
+	if n := *pos; n != nil {
+		// Occupied slot: the answer is the node's successor.
+		return n.next()
+	}
+	if parent == nil {
+		// Empty tree.
+		return nil
+	}
+	if pos == &parent.left {
+		// Empty left slot: the parent itself comes next.
+		return parent
+	}
+	// Empty right slot: whatever follows the parent comes next.
+	return parent.next()
+}
+
+// next returns the in-order successor of x, or nil if x is the
+// last node of its tree.
+func (x *node) next() *node {
+	if r := x.right; r != nil {
+		// The successor is the leftmost node of the right subtree.
+		for r.left != nil {
+			r = r.left
+		}
+		return r
+	}
+	// Otherwise it is the nearest ancestor reached from its left side.
+	for p := x.parent; p != nil; x, p = p, p.parent {
+		if p.right != x {
+			return p
+		}
+	}
+	return nil
+}
+
+// setRoot makes x (which may be nil) the root of t,
+// clearing x's parent link.
+func (t *tree) setRoot(x *node) {
+	if t.root = x; x == nil {
+		return
+	}
+	x.parent = nil
+}
+
+// setLeft makes y (which may be nil) the left child of x,
+// updating y's parent link.
+func (x *node) setLeft(y *node) {
+	if x.left = y; y == nil {
+		return
+	}
+	y.parent = x
+}
+
+// setRight makes y (which may be nil) the right child of x,
+// updating y's parent link.
+func (x *node) setRight(y *node) {
+	if x.right = y; y == nil {
+		return
+	}
+	y.parent = x
+}
+
+// safeHeight returns the height of the subtree rooted at n.
+// A nil (empty) subtree has height -1.
+func (n *node) safeHeight() int32 {
+	if n != nil {
+		return n.height
+	}
+	return -1
+}
+
+// update recomputes n's cached height and balance from the
+// heights of its children, which must be up to date.
+func (n *node) update() {
+	l, r := n.left.safeHeight(), n.right.safeHeight()
+	n.balance = r - l
+	n.height = 1 + max(l, r)
+}
+
+// replaceChild substitutes new for old as a child of parent,
+// or as the root of t if parent is nil. It panics if old is
+// not in fact the child (or root) it is claimed to be.
+func (t *tree) replaceChild(parent, old, new *node) {
+	if parent == nil {
+		if t.root != old {
+			panic("corrupt tree")
+		}
+		t.setRoot(new)
+		return
+	}
+	if parent.left == old {
+		parent.setLeft(new)
+		return
+	}
+	if parent.right == old {
+		parent.setRight(new)
+		return
+	}
+	panic("corrupt tree")
+}
+
+// rebalanceUp visits each excessively unbalanced ancestor
+// of x, restoring balance by rotating it.
+//
+// x is a node that has just been mutated, and so the height and
+// balance of x and its ancestors may be stale, but the children of x
+// must be in a valid state.
+//
+// If x is nil, rebalanceUp does nothing.
+func (t *tree) rebalanceUp(x *node) {
+ for x != nil {
+ h := x.height // height before recomputation, to detect a change
+ x.update()
+ switch x.balance {
+ case -2:
+ // The left subtree is too tall. If the left child
+ // leans right, rotate it left first (double rotation).
+ if x.left.balance == 1 {
+ t.rotateLeft(x.left)
+ }
+ x = t.rotateRight(x)
+
+ case +2:
+ // The right subtree is too tall. If the right child
+ // leans left, rotate it right first (double rotation).
+ if x.right.balance == -1 {
+ t.rotateRight(x.right)
+ }
+ x = t.rotateLeft(x)
+ }
+ if x.height == h {
+ // x's height has not changed, so the height
+ // and balance of its ancestors have not changed;
+ // no further rebalancing is required.
+ return
+ }
+ x = x.parent
+ }
+}
+
+// rotateRight rotates the subtree rooted at node y,
+// turning (y (x a b) c) into (x a (y b c)).
+// It returns x, the new root of the subtree.
+// y must have a left child.
+func (t *tree) rotateRight(y *node) *node {
+ // p -> (y (x a b) c)
+ p := y.parent
+ x := y.left
+ b := x.right
+
+ x.checkBalance()
+ y.checkBalance()
+
+ x.setRight(y)
+ y.setLeft(b)
+ t.replaceChild(p, y, x)
+
+ // y is now a child of x, so update y first.
+ y.update()
+ x.update()
+ return x
+}
+
+// rotateLeft rotates the subtree rooted at node x,
+// turning (x a (y b c)) into (y (x a b) c).
+// It returns y, the new root of the subtree.
+// x must have a right child.
+func (t *tree) rotateLeft(x *node) *node {
+ // p -> (x a (y b c))
+ p := x.parent
+ y := x.right
+ b := y.left
+
+ x.checkBalance()
+ y.checkBalance()
+
+ y.setLeft(x)
+ x.setRight(b)
+ t.replaceChild(p, x, y)
+
+ // x is now a child of y, so update x first.
+ x.update()
+ y.update()
+ return y
+}
+
+// add inserts file into the tree. Adding a file that is already
+// present is a no-op; adding a file whose range overlaps that of
+// a different file causes a panic.
+func (t *tree) add(file *File) {
+	pos, parent := t.locate(file.key())
+	switch n := *pos; {
+	case n == nil:
+		// No overlapping entry: insert at the located slot.
+		t.set(file, pos, parent)
+	case n.file != file:
+		prev := n.file
+		panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)",
+			prev.Name(), prev.Base(), prev.Base()+prev.Size(),
+			file.Name(), file.Base(), file.Base()+file.Size()))
+	}
+}
+
+// set updates the existing node at (pos, parent) if present, or
+// inserts a new node if not, so that it refers to file.
+// (pos, parent) must be a result pair from [tree.locate].
+// As currently written, the update path panics; see below.
+func (t *tree) set(file *File, pos **node, parent *node) {
+ if x := *pos; x != nil {
+ // This code path isn't currently needed
+ // because FileSet never updates an existing entry.
+ // Remove this assertion if things change.
+ panic("unreachable according to current FileSet requirements")
+ x.file = file
+ return
+ }
+ // The new node starts with height -1 so that rebalanceUp,
+ // which recomputes it as 0, observes a change in height
+ // and goes on to update the node's ancestors.
+ x := &node{file: file, key: file.key(), parent: parent, height: -1}
+ *pos = x
+ t.rebalanceUp(x)
+}
+
+// delete deletes the node at pos, which must be the variable
+// holding an existing node, as returned by [tree.locate].
+// It panics if *pos is nil.
+func (t *tree) delete(pos **node) {
+ t.root.check(nil)
+
+ x := *pos
+ switch {
+ case x == nil:
+ // This code path isn't currently needed because FileSet
+ // only calls delete after a positive locate.
+ // Remove this assertion if things change.
+ panic("unreachable according to current FileSet requirements")
+ return
+
+ case x.left == nil:
+ // No left child: splice x's right subtree
+ // (possibly empty) into x's place.
+ if *pos = x.right; *pos != nil {
+ (*pos).parent = x.parent
+ }
+ t.rebalanceUp(x.parent)
+
+ case x.right == nil:
+ // Only a left child: splice it into x's place.
+ *pos = x.left
+ x.left.parent = x.parent
+ t.rebalanceUp(x.parent)
+
+ default:
+ // Two children: replace x by its in-order successor.
+ t.deleteSwap(pos)
+ }
+
+ // Detach x from the tree. The height of -1 marks x as
+ // deleted, which tree.all relies on to resume an iteration
+ // that is positioned at x.
+ x.balance = -100 // presumably a poison value to expose use of a deleted node
+ x.parent = nil
+ x.left = nil
+ x.right = nil
+ x.height = -1
+ t.root.check(nil)
+}
+
+// deleteSwap deletes a node that has two children by replacing
+// it by its in-order successor, then triggers a rebalance.
+// The deleted node's own fields are left for the caller to clear.
+func (t *tree) deleteSwap(pos **node) {
+ x := *pos
+ // z is the successor of x: the minimum of x's right subtree.
+ // deleteMin unlinks it but leaves z.parent unchanged.
+ z := t.deleteMin(&x.right)
+
+ *pos = z
+ unbalanced := z.parent // lowest potentially unbalanced node
+ if unbalanced == x {
+ unbalanced = z // (x a (z nil b)) -> (z a b)
+ }
+ // z takes over x's position. It inherits x's height and
+ // balance so that rebalanceUp, on reaching z, compares its
+ // recomputed height with the height this position had before.
+ z.parent = x.parent
+ z.height = x.height
+ z.balance = x.balance
+ z.setLeft(x.left)
+ z.setRight(x.right)
+
+ t.rebalanceUp(unbalanced)
+}
+
+// deleteMin unlinks and returns the minimum (leftmost) node z
+// of the non-empty subtree held in *zpos, putting z's right
+// child (if any) in its place.
+// It does not modify z's own fields, nor does it rebalance.
+func (t *tree) deleteMin(zpos **node) (z *node) {
+ for (*zpos).left != nil {
+ zpos = &(*zpos).left
+ }
+ z = *zpos
+ *zpos = z.right
+ if *zpos != nil {
+ (*zpos).parent = z.parent
+ }
+ return z
+}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package token
+
+import (
+ "math/rand/v2"
+ "slices"
+ "testing"
+)
+
+// TestTree provides basic coverage of the AVL tree operations:
+// it adds files in random order, deletes a random subset, then
+// checks position lookups and the iteration sequence.
+func TestTree(t *testing.T) {
+	// Use a reproducible PRNG: the seeds are chosen at random
+	// but logged, so that a failure can be replayed.
+	seed1, seed2 := rand.Uint64(), rand.Uint64()
+	t.Logf("random seeds: %d, %d", seed1, seed2)
+	rng := rand.New(rand.NewPCG(seed1, seed2))
+
+	// Create a number of Files of fixed size, in ascending base
+	// order, each separated from its predecessor by one unit.
+	files := make([]*File, 500)
+	var base int
+	for i := range files {
+		base++
+		size := 1000
+		files[i] = &File{base: base, size: size}
+		base += size
+	}
+
+	// Add them all to the tree in random order.
+	var tr tree
+	{
+		files2 := slices.Clone(files)
+		Shuffle(rng, files2)
+		for _, f := range files2 {
+			tr.add(f)
+		}
+	}
+
+	// Randomly delete a subset of them.
+	for range 100 {
+		i := rng.IntN(len(files))
+		file := files[i]
+		if file == nil {
+			continue // already deleted
+		}
+		files[i] = nil
+
+		pn, _ := tr.locate(file.key())
+		// Report a failed lookup as a test failure
+		// rather than crashing on a nil dereference.
+		if *pn == nil {
+			t.Fatalf("locate(%v) found no file", file.key())
+		}
+		if (*pn).file != file {
+			t.Fatalf("locate returned wrong file")
+		}
+		tr.delete(pn)
+	}
+
+	// Check some position lookups within each file.
+	for _, file := range files {
+		if file == nil {
+			continue // deleted
+		}
+		for _, pos := range []int{
+			file.base,               // start
+			file.base + file.size/2, // midpoint
+			file.base + file.size,   // end
+		} {
+			pn, _ := tr.locate(key{pos, pos})
+			if *pn == nil {
+				t.Fatalf("lookup %s@%d found no file", file.name, pos)
+			}
+			if (*pn).file != file {
+				t.Fatalf("lookup %s@%d returned wrong file %s",
+					file.name, pos,
+					(*pn).file.name)
+			}
+		}
+	}
+
+	// Check that the sequence is the same.
+	files = slices.DeleteFunc(files, func(f *File) bool { return f == nil })
+	if !slices.Equal(slices.Collect(tr.all()), files) {
+		t.Fatalf("incorrect tree.all sequence")
+	}
+}
+
+// Shuffle randomly permutes the elements of slice in place,
+// using rng as the source of randomness.
+func Shuffle[T any](rng *rand.Rand, slice []*T) {
+	swap := func(a, b int) {
+		slice[a], slice[b] = slice[b], slice[a]
+	}
+	rng.Shuffle(len(slice), swap)
+}