]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: recognize (*[Big]T)(ptr)[:n:m] pattern for -d=checkptr
authorMatthew Dempsky <mdempsky@google.com>
Thu, 17 Oct 2019 21:29:16 +0000 (14:29 -0700)
committerMatthew Dempsky <mdempsky@google.com>
Mon, 21 Oct 2019 23:16:27 +0000 (23:16 +0000)
A common idiom for turning an unsafe.Pointer into a slice is to write:

    s := (*[Big]T)(ptr)[:n:m]

This technically violates Go's unsafe pointer rules (rule #1 says T2
can't be bigger than T1), but it's fairly common and not too difficult
to recognize, so might as well allow it for now so we can make
progress on #34972.

This should be revisited if #19367 is accepted.

Updates #22218.
Updates #34972.

Change-Id: Id824e2461904e770910b6e728b4234041d2cc8bc
Reviewed-on: https://go-review.googlesource.com/c/go/+/201839
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/cmd/compile/internal/gc/builtin.go
src/cmd/compile/internal/gc/builtin/runtime.go
src/cmd/compile/internal/gc/walk.go
src/runtime/checkptr.go

index ab65696a09194fc3db89ee4eebeea0e3b84ffdf2..17c45cab15ba11275ae44209b4db236013b27f9b 100644 (file)
@@ -312,7 +312,7 @@ func runtimeTypes() []*types.Type {
        typs[117] = functype(nil, []*Node{anonfield(typs[23]), anonfield(typs[23])}, []*Node{anonfield(typs[23])})
        typs[118] = functype(nil, []*Node{anonfield(typs[50])}, nil)
        typs[119] = functype(nil, []*Node{anonfield(typs[50]), anonfield(typs[50])}, nil)
-       typs[120] = functype(nil, []*Node{anonfield(typs[56]), anonfield(typs[1])}, nil)
+       typs[120] = functype(nil, []*Node{anonfield(typs[56]), anonfield(typs[1]), anonfield(typs[50])}, nil)
        typs[121] = types.NewSlice(typs[56])
        typs[122] = functype(nil, []*Node{anonfield(typs[56]), anonfield(typs[121])}, nil)
        return typs[:]
index 10a2241597fdf162e317a0a9c38bf60f469b1f19..3fc82c26812e59d5c1c627946b73fa06c1553007 100644 (file)
@@ -235,7 +235,7 @@ func racewriterange(addr, size uintptr)
 func msanread(addr, size uintptr)
 func msanwrite(addr, size uintptr)
 
-func checkptrAlignment(unsafe.Pointer, *byte)
+func checkptrAlignment(unsafe.Pointer, *byte, uintptr)
 func checkptrArithmetic(unsafe.Pointer, []unsafe.Pointer)
 
 // architecture variants
index 8f6da254715fd7eb9de8a82d908e5c18011ed2cc..78de8114d0a2278c627907485329a0da5a52a6f0 100644 (file)
@@ -953,7 +953,7 @@ opswitch:
                n.Left = walkexpr(n.Left, init)
                if n.Op == OCONVNOP && checkPtr(Curfn, 1) {
                        if n.Type.IsPtr() && n.Left.Type.Etype == TUNSAFEPTR { // unsafe.Pointer to *T
-                               n = walkCheckPtrAlignment(n, init)
+                               n = walkCheckPtrAlignment(n, init, nil)
                                break
                        }
                        if n.Type.Etype == TUNSAFEPTR && n.Left.Type.Etype == TUINTPTR { // uintptr to unsafe.Pointer
@@ -1120,7 +1120,12 @@ opswitch:
                n.List.SetSecond(walkexpr(n.List.Second(), init))
 
        case OSLICE, OSLICEARR, OSLICESTR, OSLICE3, OSLICE3ARR:
-               n.Left = walkexpr(n.Left, init)
+               checkSlice := checkPtr(Curfn, 1) && n.Op == OSLICE3ARR && n.Left.Op == OCONVNOP && n.Left.Left.Type.Etype == TUNSAFEPTR
+               if checkSlice {
+                       n.Left.Left = walkexpr(n.Left.Left, init)
+               } else {
+                       n.Left = walkexpr(n.Left, init)
+               }
                low, high, max := n.SliceBounds()
                low = walkexpr(low, init)
                if low != nil && isZero(low) {
@@ -1130,6 +1135,9 @@ opswitch:
                high = walkexpr(high, init)
                max = walkexpr(max, init)
                n.SetSliceBounds(low, high, max)
+               if checkSlice {
+                       n.Left = walkCheckPtrAlignment(n.Left, init, max)
+               }
                if n.Op.IsSlice3() {
                        if max != nil && max.Op == OCAP && samesafeexpr(n.Left, max.Left) {
                                // Reduce x[i:j:cap(x)] to x[i:j].
@@ -3912,13 +3920,29 @@ func isRuneCount(n *Node) bool {
        return Debug['N'] == 0 && !instrumenting && n.Op == OLEN && n.Left.Op == OSTR2RUNES
 }
 
-func walkCheckPtrAlignment(n *Node, init *Nodes) *Node {
-       if n.Type.Elem().Alignment() == 1 && n.Type.Elem().Size() == 1 {
+func walkCheckPtrAlignment(n *Node, init *Nodes, count *Node) *Node {
+       if !n.Type.IsPtr() {
+               Fatalf("expected pointer type: %v", n.Type)
+       }
+       elem := n.Type.Elem()
+       if count != nil {
+               if !elem.IsArray() {
+                       Fatalf("expected array type: %v", elem)
+               }
+               elem = elem.Elem()
+       }
+
+       size := elem.Size()
+       if elem.Alignment() == 1 && (size == 0 || size == 1 && count == nil) {
                return n
        }
 
+       if count == nil {
+               count = nodintconst(1)
+       }
+
        n.Left = cheapexpr(n.Left, init)
-       init.Append(mkcall("checkptrAlignment", nil, init, convnop(n.Left, types.Types[TUNSAFEPTR]), typename(n.Type.Elem())))
+       init.Append(mkcall("checkptrAlignment", nil, init, convnop(n.Left, types.Types[TUNSAFEPTR]), typename(elem), conv(count, types.Types[TUINTPTR])))
        return n
 }
 
index a6d33c5af1024bc4c93369fb2e241731b754764c..d1fc651509d2158ded14dad93cb97d41c659b122 100644 (file)
@@ -9,18 +9,19 @@ import "unsafe"
 type ptrAlign struct {
        ptr  unsafe.Pointer
        elem *_type
+       n    uintptr
 }
 
-func checkptrAlignment(p unsafe.Pointer, elem *_type) {
-       // Check that (*T)(p) is appropriately aligned.
+func checkptrAlignment(p unsafe.Pointer, elem *_type, n uintptr) {
+       // Check that (*[n]elem)(p) is appropriately aligned.
        // TODO(mdempsky): What about fieldAlign?
        if uintptr(p)&(uintptr(elem.align)-1) != 0 {
-               panic(ptrAlign{p, elem})
+               panic(ptrAlign{p, elem, n})
        }
 
-       // Check that (*T)(p) doesn't straddle multiple heap objects.
-       if elem.size != 1 && checkptrBase(p) != checkptrBase(add(p, elem.size-1)) {
-               panic(ptrAlign{p, elem})
+       // Check that (*[n]elem)(p) doesn't straddle multiple heap objects.
+       if size := n * elem.size; size > 1 && checkptrBase(p) != checkptrBase(add(p, size-1)) {
+               panic(ptrAlign{p, elem, n})
        }
 }
 
@@ -34,6 +35,9 @@ func checkptrArithmetic(p unsafe.Pointer, originals []unsafe.Pointer) {
                panic(ptrArith{p, originals})
        }
 
+       // Check that if the computed pointer p points into a heap
+       // object, then one of the original pointers must have pointed
+       // into the same object.
        base := checkptrBase(p)
        if base == 0 {
                return