]> Cypherpunks repositories - gostls13.git/commitdiff
slices: handle aliasing cases in Insert/Replace
authorKeith Randall <khr@golang.org>
Sat, 13 May 2023 03:23:57 +0000 (20:23 -0700)
committerKeith Randall <khr@golang.org>
Tue, 16 May 2023 23:34:50 +0000 (23:34 +0000)
Handle cases where the inserted slice is actually part of the slice
that is being inserted into.

Requires a bit more work, but no more allocations. (Compare to #494536.)

Not entirely sure this is worth the complication.

Fixes #60138

Change-Id: Ia72c872b04309b99025e6ca5a4a326ebed2abb69
Reviewed-on: https://go-review.googlesource.com/c/go/+/494817
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Keith Randall <khr@google.com>
Run-TryBot: Keith Randall <khr@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
src/go/build/deps_test.go
src/slices/slices.go
src/slices/slices_test.go

index 324afbfd7c777461e8ae952332d7c8ac1653cb8f..e93422addcace77f1dc5f4de454422998f74cedf 100644 (file)
@@ -46,9 +46,13 @@ var depsRules = `
          internal/goexperiment, internal/goos,
          internal/goversion, internal/nettrace, internal/platform,
          log/internal,
-         maps, slices, unicode/utf8, unicode/utf16, unicode,
+         maps, unicode/utf8, unicode/utf16, unicode,
          unsafe;
 
+       # slices depends on unsafe for overlapping check.
+       unsafe
+       < slices;
+
        # These packages depend only on internal/goarch and unsafe.
        internal/goarch, unsafe
        < internal/abi;
index dd414635ce2330f84e570973ee6391dadf4c8b0d..3c1dfac3dd6379a87301273570d759607ba9bc5d 100644 (file)
@@ -5,6 +5,10 @@
 // Package slices defines various functions useful with slices of any type.
 package slices
 
+import (
+       "unsafe"
+)
+
 // Equal reports whether two slices are equal: the same length and all
 // elements equal. If the lengths are different, Equal returns false.
 // Otherwise, the elements are compared in increasing index order, and the
@@ -81,22 +85,83 @@ func ContainsFunc[E any](s []E, f func(E) bool) bool {
 // Insert panics if i is out of range.
 // This function is O(len(s) + len(v)).
 func Insert[S ~[]E, E any](s S, i int, v ...E) S {
-       tot := len(s) + len(v)
-       if tot <= cap(s) {
-               s2 := s[:tot]
-               copy(s2[i+len(v):], s[i:])
+       m := len(v)
+       if m == 0 {
+               return s
+       }
+       n := len(s)
+       if i == n {
+               return append(s, v...)
+       }
+       if n+m > cap(s) {
+               // Use append rather than make so that we bump the size of
+               // the slice up to the next storage class.
+               // This is what Grow does but we don't call Grow because
+               // that might copy the values twice.
+               s2 := append(S(nil), make(S, n+m)...)
+               copy(s2, s[:i])
                copy(s2[i:], v)
+               copy(s2[i+m:], s[i:])
                return s2
        }
-       // Use append rather than make so that we bump the size of
-       // the slice up to the next storage class.
-       // This is what Grow does but we don't call Grow because
-       // that might copy the values twice.
-       s2 := append(S(nil), make(S, tot)...)
-       copy(s2, s[:i])
-       copy(s2[i:], v)
-       copy(s2[i+len(v):], s[i:])
-       return s2
+       s = s[:n+m]
+
+       // before:
+       // s: aaaaaaaabbbbccccccccdddd
+       //            ^   ^       ^   ^
+       //            i  i+m      n  n+m
+       // after:
+       // s: aaaaaaaavvvvbbbbcccccccc
+       //            ^   ^       ^   ^
+       //            i  i+m      n  n+m
+       //
+       // a are the values that don't move in s.
+       // v are the values copied in from v.
+       // b and c are the values from s that are shifted up in index.
+       // d are the values that get overwritten, never to be seen again.
+
+       if !overlaps(v, s[i+m:]) {
+               // Easy case - v does not overlap either the c or d regions.
+               // (It might be in some of a or b, or elsewhere entirely.)
+               // The data we copy up doesn't write to v at all, so just do it.
+
+               copy(s[i+m:], s[i:])
+
+               // Now we have
+               // s: aaaaaaaabbbbbbbbcccccccc
+               //            ^   ^       ^   ^
+               //            i  i+m      n  n+m
+               // Note the b values are duplicated.
+
+               copy(s[i:], v)
+
+               // Now we have
+               // s: aaaaaaaavvvvbbbbcccccccc
+               //            ^   ^       ^   ^
+               //            i  i+m      n  n+m
+               // That's the result we want.
+               return s
+       }
+
+       // The hard case - v overlaps c or d. We can't just shift up
+       // the data because we'd move or clobber the values we're trying
+       // to insert.
+       // So instead, write v on top of d, then rotate.
+       copy(s[n:], v)
+
+       // Now we have
+       // s: aaaaaaaabbbbccccccccvvvv
+       //            ^   ^       ^   ^
+       //            i  i+m      n  n+m
+
+       rotateRight(s[i:], m)
+
+       // Now we have
+       // s: aaaaaaaavvvvbbbbcccccccc
+       //            ^   ^       ^   ^
+       //            i  i+m      n  n+m
+       // That's the result we want.
+       return s
 }
 
 // Delete removes the elements s[i:j] from s, returning the modified slice.
@@ -143,18 +208,89 @@ func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
 // modified slice. Replace panics if s[i:j] is not a valid slice of s.
 func Replace[S ~[]E, E any](s S, i, j int, v ...E) S {
        _ = s[i:j] // verify that i:j is a valid subslice
+
+       if i == j {
+               return Insert(s, i, v...)
+       }
+       if j == len(s) {
+               return append(s[:i], v...)
+       }
+
        tot := len(s[:i]) + len(v) + len(s[j:])
-       if tot <= cap(s) {
-               s2 := s[:tot]
-               copy(s2[i+len(v):], s[j:])
+       if tot > cap(s) {
+               // Too big to fit, allocate and copy over.
+               s2 := append(S(nil), make(S, tot)...) // See Insert
+               copy(s2, s[:i])
                copy(s2[i:], v)
+               copy(s2[i+len(v):], s[j:])
                return s2
        }
-       s2 := make(S, tot)
-       copy(s2, s[:i])
-       copy(s2[i:], v)
-       copy(s2[i+len(v):], s[j:])
-       return s2
+
+       r := s[:tot]
+
+       if i+len(v) <= j {
+               // Easy, as v fits in the deleted portion.
+               copy(r[i:], v)
+               if i+len(v) != j {
+                       copy(r[i+len(v):], s[j:])
+               }
+               return r
+       }
+
+       // We are expanding (v is bigger than j-i).
+       // The situation is something like this:
+       // (example has i=4,j=8,len(s)=16,len(v)=6)
+       // s: aaaaxxxxbbbbbbbbyy
+       //        ^   ^       ^ ^
+       //        i   j  len(s) tot
+       // a: prefix of s
+       // x: deleted range
+       // b: more of s
+       // y: area to expand into
+
+       if !overlaps(r[i+len(v):], v) {
+               // Easy, as v is not clobbered by the first copy.
+               copy(r[i+len(v):], s[j:])
+               copy(r[i:], v)
+               return r
+       }
+
+       // This is a situation where we don't have a single place to which
+       // we can copy v. Parts of it need to go to two different places.
+       // We want to copy the prefix of v into y and the suffix into x, then
+       // rotate |y| spots to the right.
+       //
+       //        v[2:]      v[:2]
+       //         |           |
+       // s: aaaavvvvbbbbbbbbvv
+       //        ^   ^       ^ ^
+       //        i   j  len(s) tot
+       //
+       // If either of those two destinations don't alias v, then we're good.
+       y := len(v) - (j - i) // length of y portion
+
+       if !overlaps(r[i:j], v) {
+               copy(r[i:j], v[y:])
+               copy(r[len(s):], v[:y])
+               rotateRight(r[i:], y)
+               return r
+       }
+       if !overlaps(r[len(s):], v) {
+               copy(r[len(s):], v[:y])
+               copy(r[i:j], v[y:])
+               rotateRight(r[i:], y)
+               return r
+       }
+
+       // Now we know that v overlaps both x and y.
+       // That means that the entirety of b is *inside* v.
+       // So we don't need to preserve b at all; instead we
+       // can copy v first, then copy the b part of v out of
+       // v to the right destination.
+       k := startIdx(v, s[j:])
+       copy(r[i:], v)
+       copy(r[i+len(v):], r[i+k:])
+       return r
 }
 
 // Clone returns a copy of the slice.
@@ -224,3 +360,90 @@ func Grow[S ~[]E, E any](s S, n int) S {
 func Clip[S ~[]E, E any](s S) S {
        return s[:len(s):len(s)]
 }
+
+// Rotation algorithm explanation:
+//
+// rotate left by 2
+// start with
+//   0123456789
+// split up like this
+//   01 234567 89
+// swap first 2 and last 2
+//   89 234567 01
+// join first parts
+//   89234567 01
+// recursively rotate first left part by 2
+//   23456789 01
+// join at the end
+//   2345678901
+//
+// rotate left by 8
+// start with
+//   0123456789
+// split up like this
+//   01 234567 89
+// swap first 2 and last 2
+//   89 234567 01
+// join last parts
+//   89 23456701
+// recursively rotate second part left by 6
+//   89 01234567
+// join at the end
+//   8901234567
+
+// TODO: There are other rotate algorithms.
+// This algorithm has the desirable property that it moves each element exactly twice.
+// The triple-reverse algorithm is simpler and more cache friendly, but takes more writes.
+// The follow-cycles algorithm can be 1-write but it is not very cache friendly.
+
+// rotateLeft rotates b left by n spaces.
+// s_final[i] = s_orig[i+r], wrapping around.
+func rotateLeft[S ~[]E, E any](s S, r int) {
+       for r != 0 && r != len(s) {
+               if r*2 <= len(s) {
+                       swap(s[:r], s[len(s)-r:])
+                       s = s[:len(s)-r]
+               } else {
+                       swap(s[:len(s)-r], s[r:])
+                       s, r = s[len(s)-r:], r*2-len(s)
+               }
+       }
+}
+func rotateRight[S ~[]E, E any](s S, r int) {
+       rotateLeft(s, len(s)-r)
+}
+
+// swap swaps the contents of x and y. x and y must be equal length and disjoint.
+func swap[S ~[]E, E any](x, y S) {
+       for i := 0; i < len(x); i++ {
+               x[i], y[i] = y[i], x[i]
+       }
+}
+
+// overlaps reports whether the memory ranges a[0:len(a)] and b[0:len(b)] overlap.
+func overlaps[S ~[]E, E any](a, b S) bool {
+       if len(a) == 0 || len(b) == 0 {
+               return false
+       }
+       elemSize := unsafe.Sizeof(a[0])
+       if elemSize == 0 {
+               return false
+       }
+       // TODO: use a runtime/unsafe facility once one becomes available. See issue 12445.
+       // Also see crypto/internal/alias/alias.go:AnyOverlap
+       return uintptr(unsafe.Pointer(&a[0])) <= uintptr(unsafe.Pointer(&b[len(b)-1]))+(elemSize-1) &&
+               uintptr(unsafe.Pointer(&b[0])) <= uintptr(unsafe.Pointer(&a[len(a)-1]))+(elemSize-1)
+}
+
+// startIdx returns the index in haystack where the needle starts.
+// prerequisite: the needle must be aliased entirely inside the haystack.
+func startIdx[S ~[]E, E any](haystack, needle S) int {
+       p := &needle[0]
+       for i := range haystack {
+               if p == &haystack[i] {
+                       return i
+               }
+       }
+       // TODO: what if the overlap is by a non-integral number of Es?
+       panic("needle not found")
+}
index 4d893617f78636895833a06eb4912c4cb2568d6f..c13a67c2d473554ebfeb2f4adf165631b9dc1758 100644 (file)
@@ -302,6 +302,31 @@ func TestInsert(t *testing.T) {
        }
 }
 
+func TestInsertOverlap(t *testing.T) {
+       const N = 10
+       a := make([]int, N)
+       want := make([]int, 2*N)
+       for n := 0; n <= N; n++ { // length
+               for i := 0; i <= n; i++ { // insertion point
+                       for x := 0; x <= N; x++ { // start of inserted data
+                               for y := x; y <= N; y++ { // end of inserted data
+                                       for k := 0; k < N; k++ {
+                                               a[k] = k
+                                       }
+                                       want = want[:0]
+                                       want = append(want, a[:i]...)
+                                       want = append(want, a[x:y]...)
+                                       want = append(want, a[i:n]...)
+                                       got := Insert(a[:n], i, a[x:y]...)
+                                       if !Equal(got, want) {
+                                               t.Errorf("Insert with overlap failed n=%d i=%d x=%d y=%d, got %v want %v", n, i, x, y, got, want)
+                                       }
+                               }
+                       }
+               }
+       }
+}
+
 var deleteTests = []struct {
        s    []int
        i, j int
@@ -662,6 +687,33 @@ func TestReplacePanics(t *testing.T) {
        }
 }
 
+func TestReplaceOverlap(t *testing.T) {
+       const N = 10
+       a := make([]int, N)
+       want := make([]int, 2*N)
+       for n := 0; n <= N; n++ { // length
+               for i := 0; i <= n; i++ { // insertion point 1
+                       for j := i; j <= n; j++ { // insertion point 2
+                               for x := 0; x <= N; x++ { // start of inserted data
+                                       for y := x; y <= N; y++ { // end of inserted data
+                                               for k := 0; k < N; k++ {
+                                                       a[k] = k
+                                               }
+                                               want = want[:0]
+                                               want = append(want, a[:i]...)
+                                               want = append(want, a[x:y]...)
+                                               want = append(want, a[j:n]...)
+                                               got := Replace(a[:n], i, j, a[x:y]...)
+                                               if !Equal(got, want) {
+                                                       t.Errorf("Insert with overlap failed n=%d i=%d j=%d x=%d y=%d, got %v want %v", n, i, j, x, y, got, want)
+                                               }
+                                       }
+                               }
+                       }
+               }
+       }
+}
+
 func BenchmarkReplace(b *testing.B) {
        cases := []struct {
                name string
@@ -710,3 +762,22 @@ func BenchmarkReplace(b *testing.B) {
        }
 
 }
+
+func TestRotate(t *testing.T) {
+       const N = 10
+       s := make([]int, 0, N)
+       for n := 0; n < N; n++ {
+               for r := 0; r < n; r++ {
+                       s = s[:0]
+                       for i := 0; i < n; i++ {
+                               s = append(s, i)
+                       }
+                       rotateLeft(s, r)
+                       for i := 0; i < n; i++ {
+                               if s[i] != (i+r)%n {
+                                       t.Errorf("expected n=%d r=%d i:%d want:%d got:%d", n, r, i, (i+r)%n, s[i])
+                               }
+                       }
+               }
+       }
+}