From 6b34eba007052b5985abc0a3ff1e90316ec28d91 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Mon, 4 Mar 2013 14:40:12 -0800 Subject: [PATCH] go/types: "missing return" check Implementation closely based on Russ' CL 7440047. Future work: The error messages could be better (e.g., instead of "missing return" it might say "missing return (no default in switch)", etc.). R=adonovan, rsc CC=golang-dev https://golang.org/cl/7437049 --- src/pkg/go/types/check.go | 3 + src/pkg/go/types/check_test.go | 1 + src/pkg/go/types/return.go | 186 +++++++++++++++++++++++++++ src/pkg/go/types/testdata/decls1.src | 8 +- src/pkg/go/types/testdata/stmt1.src | 164 +++++++++++++++++++++++ 5 files changed, 358 insertions(+), 4 deletions(-) create mode 100644 src/pkg/go/types/return.go create mode 100644 src/pkg/go/types/testdata/stmt1.src diff --git a/src/pkg/go/types/check.go b/src/pkg/go/types/check.go index f7b87e30c6..19f4c34d11 100644 --- a/src/pkg/go/types/check.go +++ b/src/pkg/go/types/check.go @@ -481,6 +481,9 @@ func check(ctxt *Context, fset *token.FileSet, files []*ast.File) (pkg *Package, } check.funcsig = f.sig check.stmtList(f.body.List) + if len(f.sig.Results) > 0 && f.body != nil && !check.isTerminating(f.body, "") { + check.errorf(f.body.Rbrace, "missing return") + } } // remaining untyped expressions must indeed be untyped diff --git a/src/pkg/go/types/check_test.go b/src/pkg/go/types/check_test.go index 470f3a1a93..28308a579a 100644 --- a/src/pkg/go/types/check_test.go +++ b/src/pkg/go/types/check_test.go @@ -57,6 +57,7 @@ var tests = []struct { {"builtins", []string{"testdata/builtins.src"}}, {"conversions", []string{"testdata/conversions.src"}}, {"stmt0", []string{"testdata/stmt0.src"}}, + {"stmt1", []string{"testdata/stmt1.src"}}, } var fset = token.NewFileSet() diff --git a/src/pkg/go/types/return.go b/src/pkg/go/types/return.go new file mode 100644 index 0000000000..5806fb25da --- /dev/null +++ b/src/pkg/go/types/return.go @@ -0,0 +1,186 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements isTerminating. + +package types + +import ( + "go/ast" + "go/token" +) + +// isTerminating reports if s is a terminating statement. +// If s is labeled, label is the label name; otherwise s +// is "". +func (check *checker) isTerminating(s ast.Stmt, label string) bool { + switch s := s.(type) { + default: + unreachable() + + case *ast.BadStmt, *ast.DeclStmt, *ast.EmptyStmt, *ast.SendStmt, + *ast.IncDecStmt, *ast.AssignStmt, *ast.GoStmt, *ast.DeferStmt, + *ast.RangeStmt: + // no chance + + case *ast.LabeledStmt: + return check.isTerminating(s.Stmt, s.Label.Name) + + case *ast.ExprStmt: + // the predeclared panic() function is terminating + if call, _ := s.X.(*ast.CallExpr); call != nil { + if id, _ := call.Fun.(*ast.Ident); id != nil { + if obj := check.lookup(id); obj != nil { + // TODO(gri) Predeclared functions should be modelled as objects + // rather then ordinary functions that have a predeclared + // function type. This would simplify code here and else- + // where. + if f, _ := obj.(*Func); f != nil && f.Type == predeclaredFunctions[_Panic] { + return true + } + } + } + } + + case *ast.ReturnStmt: + return true + + case *ast.BranchStmt: + if s.Tok == token.GOTO || s.Tok == token.FALLTHROUGH { + return true + } + + case *ast.BlockStmt: + return check.isTerminatingList(s.List, "") + + case *ast.IfStmt: + if s.Else != nil && + check.isTerminating(s.Body, "") && + check.isTerminating(s.Else, "") { + return true + } + + case *ast.SwitchStmt: + return check.isTerminatingSwitch(s.Body, label) + + case *ast.TypeSwitchStmt: + return check.isTerminatingSwitch(s.Body, label) + + case *ast.SelectStmt: + for _, s := range s.Body.List { + cc := s.(*ast.CommClause) + if !check.isTerminatingList(cc.Body, "") || hasBreakList(cc.Body, label, true) { + return false + } + + } + return true + + case *ast.ForStmt: + if s.Cond == nil && !hasBreak(s.Body, label, true) { + return true + } + + } + + return false +} + +func (check *checker) isTerminatingList(list []ast.Stmt, label string) bool { + n := len(list) + return n > 0 && check.isTerminating(list[n-1], label) +} + +func (check *checker) isTerminatingSwitch(body *ast.BlockStmt, label string) bool { + hasDefault := false + for _, s := range body.List { + cc := s.(*ast.CaseClause) + if cc.List == nil { + hasDefault = true + } + if !check.isTerminatingList(cc.Body, "") || hasBreakList(cc.Body, label, true) { + return false + } + } + return hasDefault +} + +// hasBreak reports if s is or contains a break statement +// referring to the label-ed statement or implicit-ly the +// closest outer breakable statement. +func hasBreak(s ast.Stmt, label string, implicit bool) bool { + switch s := s.(type) { + default: + unreachable() + + case *ast.BadStmt, *ast.DeclStmt, *ast.EmptyStmt, *ast.ExprStmt, + *ast.SendStmt, *ast.IncDecStmt, *ast.AssignStmt, *ast.GoStmt, + *ast.DeferStmt, *ast.ReturnStmt: + // no chance + + case *ast.LabeledStmt: + return hasBreak(s.Stmt, label, implicit) + + case *ast.BranchStmt: + if s.Tok == token.BREAK { + if s.Label == nil { + return implicit + } + if s.Label.Name == label { + return true + } + } + + case *ast.BlockStmt: + return hasBreakList(s.List, label, implicit) + + case *ast.IfStmt: + if hasBreak(s.Body, label, implicit) || + s.Else != nil && hasBreak(s.Else, label, implicit) { + return true + } + + case *ast.CaseClause: + return hasBreakList(s.Body, label, implicit) + + case *ast.SwitchStmt: + if label != "" && hasBreak(s.Body, label, false) { + return true + } + + case *ast.TypeSwitchStmt: + if label != "" && hasBreak(s.Body, label, false) { + return true + } + + case *ast.CommClause: + return hasBreakList(s.Body, label, implicit) + + case *ast.SelectStmt: + if label != "" && hasBreak(s.Body, label, false) { + return true + } + + case *ast.ForStmt: + if label != "" && hasBreak(s.Body, label, false) { + return true + } + + case *ast.RangeStmt: + if label != "" && hasBreak(s.Body, label, false) { + return true + } + } + + return false +} + +func hasBreakList(list []ast.Stmt, label string, implicit bool) bool { + for _, s := range list { + if hasBreak(s, label, implicit) { + return true + } + } + return false +} diff --git a/src/pkg/go/types/testdata/decls1.src b/src/pkg/go/types/testdata/decls1.src index 2251f457f3..d74b7d9bed 100644 --- a/src/pkg/go/types/testdata/decls1.src +++ b/src/pkg/go/types/testdata/decls1.src @@ -109,9 +109,9 @@ func f0() {} func f1(a /* ERROR "not a type" */) {} func f2(a, b, c d /* ERROR "not a type" */) {} -func f3() int {} -func f4() a /* ERROR "not a type" */ {} -func f5() (a, b, c d /* ERROR "not a type" */) {} +func f3() int { return 0 } +func f4() a /* ERROR "not a type" */ { return 0 /* ERROR "cannot convert" */ } +func f5() (a, b, c d /* ERROR "not a type" */) { return } func f6(a, b, c int) complex128 { return 0 } @@ -128,5 +128,5 @@ func (x *T) m3() {} func init() {} func /* ERROR "no arguments and no return values" */ init(int) {} func /* ERROR "no arguments and no return values" */ init() int { return 0 } -func /* ERROR "no arguments and no return values" */ init(int) int {} +func /* ERROR "no arguments and no return values" */ init(int) int { return 0 } func (T) init(int) int { return 0 } diff --git a/src/pkg/go/types/testdata/stmt1.src b/src/pkg/go/types/testdata/stmt1.src new file mode 100644 index 0000000000..537c3f4a34 --- /dev/null +++ b/src/pkg/go/types/testdata/stmt1.src @@ -0,0 +1,164 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// terminating statements + +package stmt1 + +func _() {} + +func _() int {} /* ERROR "missing return" */ + +func _() int { panic(0) } + +// block statements +func _(x, y int) (z int) { + { + return + } +} + +func _(x, y int) (z int) { + { + } +} /* ERROR "missing return" */ + +// if statements +func _(x, y int) (z int) { + if x < y { return } + return 1 +} + +func _(x, y int) (z int) { + if x < y { return } +} /* ERROR "missing return" */ + +func _(x, y int) (z int) { + if x < y { + } else { return 1 + } +} /* ERROR "missing return" */ + +func _(x, y int) (z int) { + if x < y { return + } else { return + } +} + +// for statements +func _(x, y int) (z int) { + for x < y { + return + } +} /* ERROR "missing return" */ + +func _(x, y int) (z int) { + for { + return + } +} + +func _(x, y int) (z int) { + for { + return + break + } +} /* ERROR "missing return" */ + +func _(x, y int) (z int) { + for { + for { break } + return + } +} + +func _(x, y int) (z int) { +L: for { + for { break L } + return + } +} /* ERROR "missing return" */ + +// switch statements +func _(x, y int) (z int) { + switch x { + case 0: return + default: return + } +} + +func _(x, y int) (z int) { + switch x { + case 0: return + } +} /* ERROR "missing return" */ + +func _(x, y int) (z int) { + switch x { + case 0: return + case 1: break + } +} /* ERROR "missing return" */ + +func _(x, y int) (z int) { + switch x { + case 0: return + default: + switch y { + case 0: break + } + panic(0) + } +} + +func _(x, y int) (z int) { +L: switch x { + case 0: return + default: + switch y { + case 0: break L + } + panic(0) + } +} /* ERROR "missing return" */ + +// select statements +func _(ch chan int) (z int) { + select {} +} // nice! + +func _(ch chan int) (z int) { + select { + default: break + } +} /* ERROR "missing return" */ + +func _(ch chan int) (z int) { + select { + case <-ch: return + default: break + } +} /* ERROR "missing return" */ + +func _(ch chan int) (z int) { + select { + case <-ch: return + default: + for i := 0; i < 10; i++ { + break + } + return + } +} + +func _(ch chan int) (z int) { +L: select { + case <-ch: return + default: + for i := 0; i < 10; i++ { + break L + } + return + } +} /* ERROR "missing return" */ -- 2.48.1