]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/cgo: recognize unsafe.{StringData,SliceData}
authorIan Lance Taylor <iant@golang.org>
Mon, 8 May 2023 19:45:42 +0000 (12:45 -0700)
committerGopher Robot <gobot@golang.org>
Mon, 22 May 2023 18:34:47 +0000 (18:34 +0000)
A simple call to unsafe.StringData can't contain any pointers.

When looking for field references, a call to unsafe.StringData or
unsafe.SliceData can be treated as a type conversion.

In order to make unsafe.SliceData useful, recognize slice expressions
when calling C functions.

Fixes #59954

Change-Id: I08a3ace7882073284c1d46a5210582a2521b0b4e
Reviewed-on: https://go-review.googlesource.com/c/go/+/493556
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: David Chase <drchase@google.com>
src/cmd/cgo/gcc.go
src/cmd/cgo/internal/testerrors/ptr_test.go

index 5df4a8c4ad53c86f0073a08a740f0e9325bc77c1..5f7c6fbbe630c95e5c9c721692f31d2105bd87b7 100644 (file)
@@ -938,7 +938,7 @@ func (p *Package) rewriteCall(f *File, call *Call) (string, bool) {
                // constants to the parameter type, to avoid a type mismatch.
                ptype := p.rewriteUnsafe(param.Go)
 
-               if !p.needsPointerCheck(f, param.Go, args[i]) || param.BadPointer {
+               if !p.needsPointerCheck(f, param.Go, args[i]) || param.BadPointer || p.checkUnsafeStringData(args[i]) {
                        if ptype != param.Go {
                                needsUnsafe = true
                        }
@@ -957,6 +957,11 @@ func (p *Package) rewriteCall(f *File, call *Call) (string, bool) {
                        continue
                }
 
+               // Check for a[:].
+               if p.checkSlice(&sb, &sbCheck, arg, i) {
+                       continue
+               }
+
                fmt.Fprintf(&sb, "_cgo%d := %s; ", i, gofmtPos(arg, origArg.Pos()))
                fmt.Fprintf(&sbCheck, "_cgoCheckPointer(_cgo%d, nil); ", i)
        }
@@ -1178,7 +1183,10 @@ func (p *Package) checkIndex(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) boo
        x := arg
        for {
                c, ok := x.(*ast.CallExpr)
-               if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
+               if !ok || len(c.Args) != 1 {
+                       break
+               }
+               if !p.isType(c.Fun) && !p.isUnsafeData(c.Fun, false) {
                        break
                }
                x = c.Args[0]
@@ -1232,7 +1240,10 @@ func (p *Package) checkAddr(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool
        px := &arg
        for {
                c, ok := (*px).(*ast.CallExpr)
-               if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
+               if !ok || len(c.Args) != 1 {
+                       break
+               }
+               if !p.isType(c.Fun) && !p.isUnsafeData(c.Fun, false) {
                        break
                }
                px = &c.Args[0]
@@ -1255,6 +1266,71 @@ func (p *Package) checkAddr(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool
        return true
 }
 
+// checkSlice checks whether arg has the form x[i:j], possibly inside
+// type conversions. If so, it writes
+//
+//     _cgoSliceNN := x[i:j]
+//     _cgoNN := _cgoSliceNN // with type conversions, if any
+//
+// to sb, and writes
+//
+//     _cgoCheckPointer(_cgoSliceNN, true)
+//
+// to sbCheck, and returns true. This tells _cgoCheckPointer to check
+// just the contents of the slice being passed, not any other part
+// of the memory allocation.
+func (p *Package) checkSlice(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool {
+       // Strip type conversions.
+       px := &arg
+       for {
+               c, ok := (*px).(*ast.CallExpr)
+               if !ok || len(c.Args) != 1 {
+                       break
+               }
+               if !p.isType(c.Fun) && !p.isUnsafeData(c.Fun, false) {
+                       break
+               }
+               px = &c.Args[0]
+       }
+       if _, ok := (*px).(*ast.SliceExpr); !ok {
+               return false
+       }
+
+       fmt.Fprintf(sb, "_cgoSlice%d := %s; ", i, gofmtPos(*px, (*px).Pos()))
+
+       origX := *px
+       *px = ast.NewIdent(fmt.Sprintf("_cgoSlice%d", i))
+       fmt.Fprintf(sb, "_cgo%d := %s; ", i, gofmtPos(arg, arg.Pos()))
+       *px = origX
+
+       // Use 0 == 0 to do the right thing in the unlikely event
+       // that "true" is shadowed.
+       fmt.Fprintf(sbCheck, "_cgoCheckPointer(_cgoSlice%d, 0 == 0); ", i)
+
+       return true
+}
+
+// checkUnsafeStringData checks for a call to unsafe.StringData.
+// The result of that call can't contain a pointer so there is
+// no need to call _cgoCheckPointer.
+func (p *Package) checkUnsafeStringData(arg ast.Expr) bool {
+       x := arg
+       for {
+               c, ok := x.(*ast.CallExpr)
+               if !ok || len(c.Args) != 1 {
+                       break
+               }
+               if p.isUnsafeData(c.Fun, true) {
+                       return true
+               }
+               if !p.isType(c.Fun) {
+                       break
+               }
+               x = c.Args[0]
+       }
+       return false
+}
+
 // isType reports whether the expression is definitely a type.
 // This is conservative--it returns false for an unknown identifier.
 func (p *Package) isType(t ast.Expr) bool {
@@ -1299,6 +1375,28 @@ func (p *Package) isType(t ast.Expr) bool {
        return false
 }
 
+// isUnsafeData reports whether the expression is unsafe.StringData
+// or unsafe.SliceData. We can ignore these when checking for pointers
+// because they don't change whether or not their argument contains
+// any Go pointers. If onlyStringData is true we only check for StringData.
+func (p *Package) isUnsafeData(x ast.Expr, onlyStringData bool) bool {
+       st, ok := x.(*ast.SelectorExpr)
+       if !ok {
+               return false
+       }
+       id, ok := st.X.(*ast.Ident)
+       if !ok {
+               return false
+       }
+       if id.Name != "unsafe" {
+               return false
+       }
+       if !onlyStringData && st.Sel.Name == "SliceData" {
+               return true
+       }
+       return st.Sel.Name == "StringData"
+}
+
 // isVariable reports whether x is a variable, possibly with field references.
 func (p *Package) isVariable(x ast.Expr) bool {
        switch x := x.(type) {
index eb923eaa5b5942244dc274a61d4a22b60c6cb5c5..33126f40ae382dae7c4b07920714ba1a8044eb16 100644 (file)
@@ -444,6 +444,28 @@ var ptrTests = []ptrTest{
                body:    `s := &S40{p: new(int)}; C.f40((*C.struct_S40i)(&s.a))`,
                fail:    false,
        },
+       {
+               // Test that we handle unsafe.StringData.
+               name:    "stringdata",
+               c:       `void f41(void* p) {}`,
+               imports: []string{"unsafe"},
+               body:    `s := struct { a [4]byte; p *int }{p: new(int)}; str := unsafe.String(&s.a[0], 4); C.f41(unsafe.Pointer(unsafe.StringData(str)))`,
+               fail:    false,
+       },
+       {
+               name:    "slicedata",
+               c:       `void f42(void* p) {}`,
+               imports: []string{"unsafe"},
+               body:    `s := []*byte{nil, new(byte)}; C.f42(unsafe.Pointer(unsafe.SliceData(s)))`,
+               fail:    true,
+       },
+       {
+               name:    "slicedata2",
+               c:       `void f43(void* p) {}`,
+               imports: []string{"unsafe"},
+               body:    `s := struct { a [4]byte; p *int }{p: new(int)}; C.f43(unsafe.Pointer(unsafe.SliceData(s.a[:])))`,
+               fail:    false,
+       },
 }
 
 func TestPointerChecks(t *testing.T) {
@@ -497,7 +519,7 @@ func buildPtrTests(t *testing.T, gopath string, cgocheck2 bool) (exe string) {
        if err := os.MkdirAll(src, 0777); err != nil {
                t.Fatal(err)
        }
-       if err := os.WriteFile(filepath.Join(src, "go.mod"), []byte("module ptrtest"), 0666); err != nil {
+       if err := os.WriteFile(filepath.Join(src, "go.mod"), []byte("module ptrtest\ngo 1.20"), 0666); err != nil {
                t.Fatal(err)
        }