From b1b3243a1b98121911c351bc16cb95489024dc1d Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Thu, 6 Aug 2015 13:55:19 -0700 Subject: [PATCH] go/types: check for duplicate values in expression switches Fixes #11578. Change-Id: I29a542be247127f470ba6c39aac0d0f6a18de553 Reviewed-on: https://go-review.googlesource.com/13285 Reviewed-by: Alan Donovan --- src/go/types/stmt.go | 91 +++++++++++++++++++++++++++++---- src/go/types/testdata/stmt0.src | 79 ++++++++++++++++++++++++++-- 2 files changed, 155 insertions(+), 15 deletions(-) diff --git a/src/go/types/stmt.go b/src/go/types/stmt.go index 0ab2492d09..062b767c1a 100644 --- a/src/go/types/stmt.go +++ b/src/go/types/stmt.go @@ -155,20 +155,84 @@ func (check *Checker) suspendedCall(keyword string, call *ast.CallExpr) { check.errorf(x.pos(), "%s %s %s", keyword, msg, &x) } -func (check *Checker) caseValues(x *operand, values []ast.Expr) { - // No duplicate checking for now. See issue 4524. +// goVal returns the Go value for val, or nil. +func goVal(val constant.Value) interface{} { + // val should exist, but be conservative and check + if val == nil { + return nil + } + // Match implementation restriction of other compilers. + // gc only checks duplicates for integer, floating-point + // and string values, so only create Go values for these + // types. + switch val.Kind() { + case constant.Int: + if x, ok := constant.Int64Val(val); ok { + return x + } + if x, ok := constant.Uint64Val(val); ok { + return x + } + case constant.Float: + if x, ok := constant.Float64Val(val); ok { + return x + } + case constant.String: + return constant.StringVal(val) + } + return nil +} + +// A valueMap maps a case value (of a basic Go type) to a list of positions +// where the same case value appeared, together with the corresponding case +// types. +// Since two case values may have the same "underlying" value but different +// types we need to also check the value's types (e.g., byte(1) vs myByte(1)) +// when the switch expression is of interface type. +type ( + valueMap map[interface{}][]valueType // underlying Go value -> valueType + valueType struct { + pos token.Pos + typ Type + } +) + +func (check *Checker) caseValues(x *operand, values []ast.Expr, seen valueMap) { +L: for _, e := range values { var v operand check.expr(&v, e) if x.mode == invalid || v.mode == invalid { - continue + continue L } check.convertUntyped(&v, x.typ) if v.mode == invalid { - continue + continue L } // Order matters: By comparing v against x, error positions are at the case values. - check.comparison(&v, x, token.EQL) + res := v // keep original v unchanged + check.comparison(&res, x, token.EQL) + if res.mode == invalid { + continue L + } + if v.mode != constant_ { + continue L // we're done + } + // look for duplicate values + if val := goVal(v.val); val != nil { + if list := seen[val]; list != nil { + // look for duplicate types for a given value + // (quadratic algorithm, but these lists tend to be very short) + for _, vt := range list { + if Identical(v.typ, vt.typ) { + check.errorf(v.pos(), "duplicate case %s in expression switch", &v) + check.error(vt.pos, "\tprevious case") // secondary error, \t indented + continue L + } + } + } + seen[val] = append(seen[val], valueType{v.pos(), v.typ}) + } } } @@ -177,15 +241,19 @@ L: for _, e := range types { T = check.typOrNil(e) if T == Typ[Invalid] { - continue + continue L } - // complain about duplicate types - // TODO(gri) use a type hash to avoid quadratic algorithm + // look for duplicate types + // (quadratic algorithm, but type switches tend to be reasonably small) for t, pos := range seen { if T == nil && t == nil || T != nil && t != nil && Identical(T, t) { // talk about "case" rather than "type" because of nil case - check.error(e.Pos(), "duplicate case in type switch") - check.errorf(pos, "\tprevious case %s", T) // secondary error, \t indented + Ts := "nil" + if T != nil { + Ts = T.String() + } + check.errorf(e.Pos(), "duplicate case %s in type switch", Ts) + check.error(pos, "\tprevious case") // secondary error, \t indented continue L } } @@ -408,13 +476,14 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) { check.multipleDefaults(s.Body.List) + seen := make(valueMap) // map of seen case values to positions and types for i, c := range s.Body.List { clause, _ := c.(*ast.CaseClause) if clause == nil { check.invalidAST(c.Pos(), "incorrect expression switch case") continue } - check.caseValues(&x, clause.List) + check.caseValues(&x, clause.List, seen) check.openScope(clause, "case") inner := inner if i+1 < len(s.Body.List) { diff --git a/src/go/types/testdata/stmt0.src b/src/go/types/testdata/stmt0.src index 7e28c23fb0..e946066c49 100644 --- a/src/go/types/testdata/stmt0.src +++ b/src/go/types/testdata/stmt0.src @@ -438,14 +438,85 @@ func switches0() { switch x { case 1: - case 1 /* DISABLED "duplicate case" */ : + case 1 /* ERROR "duplicate case" */ : + case ( /* ERROR "duplicate case" */ 1): case 2, 3, 4: - case 1 /* DISABLED "duplicate case" */ : + case 5, 1 /* ERROR "duplicate case" */ : } switch uint64(x) { - case 1 /* DISABLED duplicate case */ <<64-1: - case 1 /* DISABLED duplicate case */ <<64-1: + case 1<<64 - 1: + case 1 /* ERROR duplicate case */ <<64 - 1: + case 2, 3, 4: + case 5, 1 /* ERROR duplicate case */ <<64 - 1: + } + + var y32 float32 + switch y32 { + case 1.1: + case 11/10: // integer division! + case 11. /* ERROR duplicate case */ /10: + case 2, 3.0, 4.1: + case 5.2, 1.10 /* ERROR duplicate case */ : + } + + var y64 float64 + switch y64 { + case 1.1: + case 11/10: // integer division! + case 11. /* ERROR duplicate case */ /10: + case 2, 3.0, 4.1: + case 5.2, 1.10 /* ERROR duplicate case */ : + } + + var s string + switch s { + case "foo": + case "foo" /* ERROR duplicate case */ : + case "f" /* ERROR duplicate case */ + "oo": + case "abc", "def", "ghi": + case "jkl", "foo" /* ERROR duplicate case */ : + } + + type T int + type F float64 + type S string + type B bool + var i interface{} + switch i { + case nil: + case nil: // no duplicate detection + case (*int)(nil): + case (*int)(nil): // do duplicate detection + case 1: + case byte(1): + case int /* ERROR duplicate case */ (1): + case T(1): + case 1.0: + case F(1.0): + case F /* ERROR duplicate case */ (1.0): + case "hello": + case S("hello"): + case S /* ERROR duplicate case */ ("hello"): + case 1==1, B(false): + case false, B(2==2): + } + + // switch on array + var a [3]int + switch a { + case [3]int{1, 2, 3}: + case [3]int{1, 2, 3}: // no duplicate detection + case [ /* ERROR "mismatched types */ 4]int{4, 5, 6}: + } + + // switch on channel + var c1, c2 chan int + switch c1 { + case nil: + case c1: + case c2: + case c1, c2: // no duplicate detection } } -- 2.48.1