From c2d3e1123c2f49aab02260f125a940ed723e42a0 Mon Sep 17 00:00:00 2001 From: Caio Marcelo de Oliveira Filho Date: Sat, 20 Feb 2016 18:36:03 -0200 Subject: [PATCH] cmd/go: better error for test functions with wrong signature Check the function types before compiling the tests. Extend the same approach taken by the type check used for TestMain function. To keep existing behavior, wrong arguments for TestMain are ignored instead of causing an error. Fixes #14226. Change-Id: I488a2555cddb273d35c1a8c4645bb5435c9eb91d Reviewed-on: https://go-review.googlesource.com/19763 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- src/cmd/go/test.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/cmd/go/test.go b/src/cmd/go/test.go index 995ba146f5..e23f939255 100644 --- a/src/cmd/go/test.go +++ b/src/cmd/go/test.go @@ -1206,11 +1206,11 @@ func (b *builder) notest(a *action) error { return nil } -// isTestMain tells whether fn is a TestMain(m *testing.M) function. -func isTestMain(fn *ast.FuncDecl) bool { - if fn.Name.String() != "TestMain" || - fn.Type.Results != nil && len(fn.Type.Results.List) > 0 || - fn.Type.Params == nil || +// isTestFunc tells whether fn has the type of a testing function. arg +// specifies the parameter type we look for: B, M or T. +func isTestFunc(fn *ast.FuncDecl, arg string) bool { + if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 || + fn.Type.Params.List == nil || len(fn.Type.Params.List) != 1 || len(fn.Type.Params.List[0].Names) > 1 { return false @@ -1222,10 +1222,11 @@ func isTestMain(fn *ast.FuncDecl) bool { // We can't easily check that the type is *testing.M // because we don't know how testing has been imported, // but at least check that it's *M or *something.M. - if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" { + // Same applies for B and T. + if name, ok := ptr.X.(*ast.Ident); ok && name.Name == arg { return true } - if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" { + if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == arg { return true } return false @@ -1344,16 +1345,22 @@ func (t *testFuncs) load(filename, pkg string, doImport, seen *bool) error { } name := n.Name.String() switch { - case isTestMain(n): + case name == "TestMain" && isTestFunc(n, "M"): if t.TestMain != nil { return errors.New("multiple definitions of TestMain") } t.TestMain = &testFunc{pkg, name, ""} *doImport, *seen = true, true case isTest(name, "Test"): + if !isTestFunc(n, "T") { + return fmt.Errorf("wrong type for %s", name) + } t.Tests = append(t.Tests, testFunc{pkg, name, ""}) *doImport, *seen = true, true case isTest(name, "Benchmark"): + if !isTestFunc(n, "B") { + return fmt.Errorf("wrong type for %s", name) + } t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, ""}) *doImport, *seen = true, true } -- 2.48.1