]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.typeparams] cmd/compile: allow generic funcs to call other generic funcs for...
authorDan Scales <danscales@google.com>
Mon, 8 Feb 2021 18:23:05 +0000 (10:23 -0800)
committerDan Scales <danscales@google.com>
Mon, 8 Feb 2021 19:32:55 +0000 (19:32 +0000)
 - Handle generic function calling itself or another generic function in
   stenciling. This is easy - after it is created, just scan an
   instantiated generic function for function instantiations (that may
   needed to be stenciled), just like non-generic functions. The types
   in the function instantiation will already have been set by the
   stenciling.

 - Handle OTYPE nodes in subster.node() (allows for generic type
   conversions).

 - Eliminated some duplicated work in subster.typ().

 - Added new test case fact.go that tests a generic function calling
   itself, and simple generic type conversions.

 - Cause an error if a generic function is to be exported (which we
   don't handle yet).

 - Fixed some suggested changes in the add.go test case that I missed in
   the last review.

Change-Id: I5d61704254c27962f358d5a3d2e0c62a5099f148
Reviewed-on: https://go-review.googlesource.com/c/go/+/290469
Trust: Dan Scales <danscales@google.com>
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
src/cmd/compile/internal/noder/object.go
src/cmd/compile/internal/noder/stencil.go
test/typeparam/fact.go [new file with mode: 0644]
test/typeparam/sum.go [moved from test/typeparam/add.go with 61% similarity]

index c740285ca2ea3bca7dac8621c58270e2fada613d..b4e5c022dbfe8d80400da3025d6b7c19fc8e5022 100644 (file)
@@ -155,6 +155,9 @@ func (g *irgen) objFinish(name *ir.Name, class ir.Class, typ *types.Type) {
                        break // methods are exported with their receiver type
                }
                if types.IsExported(sym.Name) {
+                       if name.Class == ir.PFUNC && name.Type().NumTParams() > 0 {
+                               base.FatalfAt(name.Pos(), "Cannot export a generic function (yet): %v", name)
+                       }
                        typecheck.Export(name)
                }
                if base.Flag.AsmHdr != "" && !name.Sym().Asm() {
index 3c6c7f4a8c15aabeb0328f606a78872e882a8540..0c4eadcf44854956d31ab990c43c2d95d2a88a28 100644 (file)
@@ -20,7 +20,11 @@ import (
 // creates the required stencils for simple generic functions.
 func (g *irgen) stencil() {
        g.target.Stencils = make(map[*types.Sym]*ir.Func)
-       for _, decl := range g.target.Decls {
+       // Don't use range(g.target.Decls) - we also want to process any new instantiated
+       // functions that are created during this loop, in order to handle generic
+       // functions calling other generic functions.
+       for i := 0; i < len(g.target.Decls); i++ {
+               decl := g.target.Decls[i]
                if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 {
                        // Skip any non-function declarations and skip generic functions
                        continue
@@ -142,6 +146,9 @@ func (subst *subster) node(n ir.Node) ir.Node {
        var edit func(ir.Node) ir.Node
        edit = func(x ir.Node) ir.Node {
                switch x.Op() {
+               case ir.OTYPE:
+                       return ir.TypeNode(subst.typ(x.Type()))
+
                case ir.ONAME:
                        name := x.(*ir.Name)
                        if v := subst.vars[name]; v != nil {
@@ -211,21 +218,21 @@ func (subst *subster) typ(t *types.Type) *types.Type {
        case types.TARRAY:
                elem := t.Elem()
                newelem := subst.typ(elem)
-               if subst.typ(elem) != elem {
+               if newelem != elem {
                        return types.NewArray(newelem, t.NumElem())
                }
 
        case types.TPTR:
                elem := t.Elem()
                newelem := subst.typ(elem)
-               if subst.typ(elem) != elem {
+               if newelem != elem {
                        return types.NewPtr(newelem)
                }
 
        case types.TSLICE:
                elem := t.Elem()
                newelem := subst.typ(elem)
-               if subst.typ(elem) != elem {
+               if newelem != elem {
                        return types.NewSlice(newelem)
                }
 
diff --git a/test/typeparam/fact.go b/test/typeparam/fact.go
new file mode 100644 (file)
index 0000000..e5e0ad4
--- /dev/null
@@ -0,0 +1,35 @@
+// run -gcflags=-G=3
+
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+       "fmt"
+)
+
+
+func fact[T interface { type float64 }](n T) T {
+       if n == T(1) {
+               return T(1)
+       }
+       return n * fact(n - T(1))
+}
+
+func main() {
+       got := fact(4.0)
+       want := 24.0
+       if got != want {
+               panic(fmt.Sprintf("Got %f, want %f", got, want))
+       }
+
+       // Re-enable when types2 bug is fixed (can't do T(1) with more than one
+       // type in the type list).
+       //got = fact(5)
+       //want = 120
+       //if want != got {
+       //      panic(fmt.Sprintf("Want %d, got %d", want, got))
+       //}
+}
similarity index 61%
rename from test/typeparam/add.go
rename to test/typeparam/sum.go
index b0cf76d3ee313bbafc5396d0867e0068033b4a6b..72511c2fe5b4ad9e1a86381dd4f6539017835004 100644 (file)
@@ -10,7 +10,7 @@ import (
        "fmt"
 )
 
-func add[T interface{ type int, float64 }](vec []T) T {
+func sum[T interface{ type int, float64 }](vec []T) T {
        var sum T
        for _, elt := range vec {
                sum = sum + elt
@@ -28,23 +28,23 @@ func abs(f float64) float64 {
 func main() {
        vec1 := []int{3, 4}
        vec2 := []float64{5.8, 9.6}
+       got := sum[int](vec1)
        want := vec1[0] + vec1[1]
-       got := add[int](vec1)
-       if want != got {
-               panic(fmt.Sprintf("Want %d, got %d", want, got))
+       if got != want {
+               panic(fmt.Sprintf("Got %d, want %d", got, want))
        }
-       got = add(vec1)
+       got = sum(vec1)
        if want != got {
-               panic(fmt.Sprintf("Want %d, got %d", want, got))
+               panic(fmt.Sprintf("Got %d, want %d", got, want))
        }
 
        fwant := vec2[0] + vec2[1]
-       fgot := add[float64](vec2)
+       fgot := sum[float64](vec2)
        if abs(fgot - fwant) > 1e-10 {
-               panic(fmt.Sprintf("Want %f, got %f", fwant, fgot))
+               panic(fmt.Sprintf("Got %f, want %f", fgot, fwant))
        }
-       fgot = add(vec2)
+       fgot = sum(vec2)
        if abs(fgot - fwant) > 1e-10 {
-               panic(fmt.Sprintf("Want %f, got %f", fwant, fgot))
+               panic(fmt.Sprintf("Got %f, want %f", fgot, fwant))
        }
 }