return nil
}
-// isTestMain tells whether fn is a TestMain(m *testing.M) function.
-func isTestMain(fn *ast.FuncDecl) bool {
- if fn.Name.String() != "TestMain" ||
- fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
- fn.Type.Params == nil ||
+// isTestFunc tells whether fn has the type of a testing function. arg
+// specifies the parameter type we look for: B, M or T.
+func isTestFunc(fn *ast.FuncDecl, arg string) bool {
+ if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
+ fn.Type.Params.List == nil ||
len(fn.Type.Params.List) != 1 ||
len(fn.Type.Params.List[0].Names) > 1 {
return false
// We can't easily check that the type is *testing.M
// because we don't know how testing has been imported,
// but at least check that it's *M or *something.M.
- if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" {
+ // Same applies for B and T.
+ if name, ok := ptr.X.(*ast.Ident); ok && name.Name == arg {
return true
}
- if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" {
+ if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == arg {
return true
}
return false
}
name := n.Name.String()
switch {
- case isTestMain(n):
+ case name == "TestMain" && isTestFunc(n, "M"):
if t.TestMain != nil {
return errors.New("multiple definitions of TestMain")
}
t.TestMain = &testFunc{pkg, name, ""}
*doImport, *seen = true, true
case isTest(name, "Test"):
+ if !isTestFunc(n, "T") {
+ return fmt.Errorf("wrong type for %s", name)
+ }
t.Tests = append(t.Tests, testFunc{pkg, name, ""})
*doImport, *seen = true, true
case isTest(name, "Benchmark"):
+ if !isTestFunc(n, "B") {
+ return fmt.Errorf("wrong type for %s", name)
+ }
t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, ""})
*doImport, *seen = true, true
}