]> Cypherpunks repositories - gostls13.git/commitdiff
go/types, types2: record correct argument type for cap, len
authorRobert Griesemer <gri@golang.org>
Sat, 5 Feb 2022 00:57:43 +0000 (16:57 -0800)
committerRobert Griesemer <gri@golang.org>
Mon, 7 Feb 2022 20:21:30 +0000 (20:21 +0000)
Record the actual argument type for a cap/len call, not the
underlying type.

Fixes #51055.

Change-Id: Ia0e746a462377f030424ccaec0babf72b78da420
Reviewed-on: https://go-review.googlesource.com/c/go/+/383474
Trust: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
src/cmd/compile/internal/types2/builtins.go
src/cmd/compile/internal/types2/builtins_test.go
src/go/types/builtins.go
src/go/types/builtins_test.go
src/go/types/example_test.go

index c2f955ce8c70a2409a91c3c59111c9a4838b7c04..f9db07fdea5d9fde84105d1333d758f9a2a58ca1 100644 (file)
@@ -142,9 +142,8 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) (
                // cap(x)
                // len(x)
                mode := invalid
-               var typ Type
                var val constant.Value
-               switch typ = arrayPtrDeref(under(x.typ)); t := typ.(type) {
+               switch t := arrayPtrDeref(under(x.typ)).(type) {
                case *Basic:
                        if isString(t) && id == _Len {
                                if x.mode == constant_ {
@@ -201,17 +200,19 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) (
                        }
                }
 
-               if mode == invalid && typ != Typ[Invalid] {
+               if mode == invalid && under(x.typ) != Typ[Invalid] {
                        check.errorf(x, invalidArg+"%s for %s", x, bin.name)
                        return
                }
 
+               // record the signature before changing x.typ
+               if check.Types != nil && mode != constant_ {
+                       check.recordBuiltinType(call.Fun, makeSig(Typ[Int], x.typ))
+               }
+
                x.mode = mode
                x.typ = Typ[Int]
                x.val = val
-               if check.Types != nil && mode != constant_ {
-                       check.recordBuiltinType(call.Fun, makeSig(x.typ, typ))
-               }
 
        case _Close:
                // close(c)
index be5707cdfeac4dce2820226470cf2b1fc16608a4..e07a7794f6c74521ab954a8e3f498e211b4e700d 100644 (file)
@@ -28,6 +28,8 @@ var builtinCalls = []struct {
        {"cap", `var s [10]int; _ = cap(&s)`, `invalid type`}, // constant
        {"cap", `var s []int64; _ = cap(s)`, `func([]int64) int`},
        {"cap", `var c chan<-bool; _ = cap(c)`, `func(chan<- bool) int`},
+       {"cap", `type S []byte; var s S; _ = cap(s)`, `func(p.S) int`},
+       {"cap", `var s P; _ = cap(s)`, `func(P) int`},
 
        {"len", `_ = len("foo")`, `invalid type`}, // constant
        {"len", `var s string; _ = len(s)`, `func(string) int`},
@@ -36,6 +38,8 @@ var builtinCalls = []struct {
        {"len", `var s []int64; _ = len(s)`, `func([]int64) int`},
        {"len", `var c chan<-bool; _ = len(c)`, `func(chan<- bool) int`},
        {"len", `var m map[string]float32; _ = len(m)`, `func(map[string]float32) int`},
+       {"len", `type S []byte; var s S; _ = len(s)`, `func(p.S) int`},
+       {"len", `var s P; _ = len(s)`, `func(P) int`},
 
        {"close", `var c chan int; close(c)`, `func(chan int)`},
        {"close", `var c chan<- chan string; close(c)`, `func(chan<- chan string)`},
@@ -159,7 +163,7 @@ func parseGenericSrc(path, src string) (*syntax.File, error) {
 }
 
 func testBuiltinSignature(t *testing.T, name, src0, want string) {
-       src := fmt.Sprintf(`package p; import "unsafe"; type _ unsafe.Pointer /* use unsafe */; func _[P any]() { %s }`, src0)
+       src := fmt.Sprintf(`package p; import "unsafe"; type _ unsafe.Pointer /* use unsafe */; func _[P ~[]byte]() { %s }`, src0)
        f, err := parseGenericSrc("", src)
        if err != nil {
                t.Errorf("%s: %s", src0, err)
index f9aece225b2abe3d3b9a4221ccddb683404ba715..8fcfcb935f983cd0d46252582a361069076fee38 100644 (file)
@@ -143,9 +143,8 @@ func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ b
                // cap(x)
                // len(x)
                mode := invalid
-               var typ Type
                var val constant.Value
-               switch typ = arrayPtrDeref(under(x.typ)); t := typ.(type) {
+               switch t := arrayPtrDeref(under(x.typ)).(type) {
                case *Basic:
                        if isString(t) && id == _Len {
                                if x.mode == constant_ {
@@ -202,7 +201,7 @@ func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ b
                        }
                }
 
-               if mode == invalid && typ != Typ[Invalid] {
+               if mode == invalid && under(x.typ) != Typ[Invalid] {
                        code := _InvalidCap
                        if id == _Len {
                                code = _InvalidLen
@@ -211,12 +210,14 @@ func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ b
                        return
                }
 
+               // record the signature before changing x.typ
+               if check.Types != nil && mode != constant_ {
+                       check.recordBuiltinType(call.Fun, makeSig(Typ[Int], x.typ))
+               }
+
                x.mode = mode
                x.typ = Typ[Int]
                x.val = val
-               if check.Types != nil && mode != constant_ {
-                       check.recordBuiltinType(call.Fun, makeSig(x.typ, typ))
-               }
 
        case _Close:
                // close(c)
index edcd7e7724a1aaf12802594ed6a8b29985eef771..7e967a36e10e0d0fdd264032e8842bef1aacff1f 100644 (file)
@@ -29,6 +29,8 @@ var builtinCalls = []struct {
        {"cap", `var s [10]int; _ = cap(&s)`, `invalid type`}, // constant
        {"cap", `var s []int64; _ = cap(s)`, `func([]int64) int`},
        {"cap", `var c chan<-bool; _ = cap(c)`, `func(chan<- bool) int`},
+       {"cap", `type S []byte; var s S; _ = cap(s)`, `func(p.S) int`},
+       {"cap", `var s P; _ = cap(s)`, `func(P) int`},
 
        {"len", `_ = len("foo")`, `invalid type`}, // constant
        {"len", `var s string; _ = len(s)`, `func(string) int`},
@@ -37,6 +39,8 @@ var builtinCalls = []struct {
        {"len", `var s []int64; _ = len(s)`, `func([]int64) int`},
        {"len", `var c chan<-bool; _ = len(c)`, `func(chan<- bool) int`},
        {"len", `var m map[string]float32; _ = len(m)`, `func(map[string]float32) int`},
+       {"len", `type S []byte; var s S; _ = len(s)`, `func(p.S) int`},
+       {"len", `var s P; _ = len(s)`, `func(P) int`},
 
        {"close", `var c chan int; close(c)`, `func(chan int)`},
        {"close", `var c chan<- chan string; close(c)`, `func(chan<- chan string)`},
@@ -157,7 +161,7 @@ func TestBuiltinSignatures(t *testing.T) {
 // parseGenericSrc in types2 is not necessary. We can just parse in testBuiltinSignature below.
 
 func testBuiltinSignature(t *testing.T, name, src0, want string) {
-       src := fmt.Sprintf(`package p; import "unsafe"; type _ unsafe.Pointer /* use unsafe */; func _[P any]() { %s }`, src0)
+       src := fmt.Sprintf(`package p; import "unsafe"; type _ unsafe.Pointer /* use unsafe */; func _[P ~[]byte]() { %s }`, src0)
        f, err := parser.ParseFile(fset, "", src, 0)
        if err != nil {
                t.Errorf("%s: %s", src0, err)
index 270256748645a9d7d6380df4d17cbb2d45f08dc0..3c1bdb58c3e4112fbec0df393f6b5e32ab85c42a 100644 (file)
@@ -279,7 +279,7 @@ func fib(x int) int {
        //
        // Types and Values of each expression:
        //  4: 8 | string              | type    : string
-       //  6:15 | len                 | builtin : func(string) int
+       //  6:15 | len                 | builtin : func(fib.S) int
        //  6:15 | len(b)              | value   : int
        //  6:19 | b                   | var     : fib.S
        //  6:23 | S                   | type    : fib.S