From a5ca63528793a8e0f248d9a95a07845328c3d800 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Wed, 13 Apr 2011 09:38:13 -0700 Subject: [PATCH] gofmt: avoid endless loops With the (partial) resolution of identifiers done by the go/parser, ast.Objects point may introduce cycles in the AST. Don't follow *ast.Objects, and replace them with nil instead (they are likely incorrect after a rewrite anyway). - minor manual cleanups after reflect change automatic rewrite - includes fix by rsc related to reflect change Fixes #1667. R=rsc CC=golang-dev https://golang.org/cl/4387044 --- src/cmd/gofmt/rewrite.go | 48 ++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go index 4590ccb58b..93643dced2 100644 --- a/src/cmd/gofmt/rewrite.go +++ b/src/cmd/gofmt/rewrite.go @@ -46,6 +46,16 @@ func parseExpr(s string, what string) ast.Expr { } +// Keep this function for debugging. +/* +func dump(msg string, val reflect.Value) { + fmt.Printf("%s:\n", msg) + ast.Print(fset, val.Interface()) + fmt.Println() +} +*/ + + // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file. func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { m := make(map[string]reflect.Value) @@ -82,12 +92,29 @@ func setValue(x, y reflect.Value) { } +// Values/types for special cases. +var ( + objectPtrNil = reflect.NewValue((*ast.Object)(nil)) + + identType = reflect.Typeof((*ast.Ident)(nil)) + objectPtrType = reflect.Typeof((*ast.Object)(nil)) + positionType = reflect.Typeof(token.NoPos) +) + + // apply replaces each AST field x in val with f(x), returning val. // To avoid extra conversions, f operates on the reflect.Value form. func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { if !val.IsValid() { return reflect.Value{} } + + // *ast.Objects introduce cycles and are likely incorrect after + // rewrite; don't follow them but replace with nil instead + if val.Type() == objectPtrType { + return objectPtrNil + } + switch v := reflect.Indirect(val); v.Kind() { case reflect.Slice: for i := 0; i < v.Len(); i++ { @@ -107,10 +134,6 @@ func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value } -var positionType = reflect.Typeof(token.NoPos) -var identType = reflect.Typeof((*ast.Ident)(nil)) - - func isWildcard(s string) bool { rune, size := utf8.DecodeRuneInString(s) return size == len(s) && unicode.IsLower(rune) @@ -148,9 +171,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { // Special cases. switch pattern.Type() { - case positionType: - // token positions don't need to match - return true case identType: // For identifiers, only the names need to match // (and none of the other *ast.Object information). @@ -159,6 +179,9 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { p := pattern.Interface().(*ast.Ident) v := val.Interface().(*ast.Ident) return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name + case objectPtrType, positionType: + // object pointers and token positions don't need to match + return true } p := reflect.Indirect(pattern) @@ -169,7 +192,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { switch p.Kind() { case reflect.Slice: - v := v if p.Len() != v.Len() { return false } @@ -181,7 +203,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { return true case reflect.Struct: - v := v if p.NumField() != v.NumField() { return false } @@ -193,7 +214,6 @@ func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { return true case reflect.Interface: - v := v return match(m, p.Elem(), v.Elem()) } @@ -247,12 +267,16 @@ func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) case reflect.Ptr: v := reflect.Zero(p.Type()) - v.Set(subst(m, p.Elem(), pos).Addr()) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos).Addr()) + } return v case reflect.Interface: v := reflect.Zero(p.Type()) - v.Set(subst(m, p.Elem(), pos)) + if elem := p.Elem(); elem.IsValid() { + v.Set(subst(m, elem, pos)) + } return v } -- 2.48.1