]> Cypherpunks repositories - gostls13.git/commitdiff
cmd/compile: don't pull constant offsets out of pointer arithmetic
authorKeith Randall <khr@golang.org>
Mon, 24 Feb 2025 21:07:29 +0000 (13:07 -0800)
committerKeith Randall <khr@golang.org>
Wed, 26 Feb 2025 17:39:12 +0000 (09:39 -0800)
This could lead to manufacturing a pointer that points outside
its original allocation.

Bug was introduced in CL 629858.

Fixes #71932

Change-Id: Ia86ab0b65ce5f80a8e0f4f4c81babd07c5904f8d
Reviewed-on: https://go-review.googlesource.com/c/go/+/652078
Reviewed-by: Keith Randall <khr@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
src/cmd/compile/internal/ssa/_gen/ARM64.rules
src/cmd/compile/internal/ssa/rewriteARM64.go
test/fixedbugs/issue71932.go [new file with mode: 0644]

index 8618c24ebf816ba61299637b710c1de0b221b153..359c1811b0ee9d1bcbcf7132711379ec74ab15a3 100644 (file)
 (SUB a l:(MNEGW x y)) && v.Type.Size() <= 4 && l.Uses==1 && clobber(l) => (MADDW a x y)
 
 // madd/msub can't take constant arguments, so do a bit of reordering if a non-constant is available.
-(ADD a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (ADDconst [c] (ADD <v.Type> a m))
-(ADD a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (SUBconst [c] (ADD <v.Type> a m))
-(SUB a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (SUBconst [c] (SUB <v.Type> a m))
-(SUB a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 => (ADDconst [c] (SUB <v.Type> a m))
+// Note: don't reorder arithmetic concerning pointers, as we must ensure that
+// no intermediate computations are invalid pointers.
+(ADD <t> a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 && !t.IsPtrShaped() => (ADDconst [c] (ADD <v.Type> a m))
+(ADD <t> a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 && !t.IsPtrShaped() => (SUBconst [c] (ADD <v.Type> a m))
+(SUB <t> a p:(ADDconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 && !t.IsPtrShaped() => (SUBconst [c] (SUB <v.Type> a m))
+(SUB <t> a p:(SUBconst [c] m:((MUL|MULW|MNEG|MNEGW) _ _))) && p.Uses==1 && m.Uses==1 && !t.IsPtrShaped() => (ADDconst [c] (SUB <v.Type> a m))
 
 // optimize ADCSflags, SBCSflags and friends
 (ADCSflags x y (Select1 <types.TypeFlags> (ADDSconstflags [-1] (ADCzerocarry <typ.UInt64> c)))) => (ADCSflags x y c)
index 909414ee17ee7d7b6cc28edbf891ff032c463e45..d7f99bc46cbede19987ea5dc0a540443ffbcf442 100644 (file)
@@ -1335,10 +1335,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(ADDconst [c] m:(MUL _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(ADDconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1347,7 +1348,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64ADDconst)
@@ -1359,10 +1360,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(ADDconst [c] m:(MULW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(ADDconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1371,7 +1373,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64ADDconst)
@@ -1383,10 +1385,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(ADDconst [c] m:(MNEG _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(ADDconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1395,7 +1398,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64ADDconst)
@@ -1407,10 +1410,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(ADDconst [c] m:(MNEGW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(ADDconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1419,7 +1423,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64ADDconst)
@@ -1431,10 +1435,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(SUBconst [c] m:(MUL _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(SUBconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1443,7 +1448,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64SUBconst)
@@ -1455,10 +1460,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(SUBconst [c] m:(MULW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(SUBconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1467,7 +1473,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64SUBconst)
@@ -1479,10 +1485,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(SUBconst [c] m:(MNEG _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(SUBconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1491,7 +1498,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64SUBconst)
@@ -1503,10 +1510,11 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                }
                break
        }
-       // match: (ADD a p:(SUBconst [c] m:(MNEGW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (ADD <t> a p:(SUBconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (ADD <v.Type> a m))
        for {
+               t := v.Type
                for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
                        a := v_0
                        p := v_1
@@ -1515,7 +1523,7 @@ func rewriteValueARM64_OpARM64ADD(v *Value) bool {
                        }
                        c := auxIntToInt64(p.AuxInt)
                        m := p.Args[0]
-                       if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+                       if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                                continue
                        }
                        v.reset(OpARM64SUBconst)
@@ -16647,10 +16655,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg3(a, x, y)
                return true
        }
-       // match: (SUB a p:(ADDconst [c] m:(MUL _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(ADDconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64ADDconst {
@@ -16658,7 +16667,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64SUBconst)
@@ -16668,10 +16677,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg(v0)
                return true
        }
-       // match: (SUB a p:(ADDconst [c] m:(MULW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(ADDconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64ADDconst {
@@ -16679,7 +16689,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64SUBconst)
@@ -16689,10 +16699,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg(v0)
                return true
        }
-       // match: (SUB a p:(ADDconst [c] m:(MNEG _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(ADDconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64ADDconst {
@@ -16700,7 +16711,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64SUBconst)
@@ -16710,10 +16721,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg(v0)
                return true
        }
-       // match: (SUB a p:(ADDconst [c] m:(MNEGW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(ADDconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (SUBconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64ADDconst {
@@ -16721,7 +16733,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64SUBconst)
@@ -16731,10 +16743,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg(v0)
                return true
        }
-       // match: (SUB a p:(SUBconst [c] m:(MUL _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(SUBconst [c] m:(MUL _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64SUBconst {
@@ -16742,7 +16755,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MUL || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64ADDconst)
@@ -16752,10 +16765,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg(v0)
                return true
        }
-       // match: (SUB a p:(SUBconst [c] m:(MULW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(SUBconst [c] m:(MULW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64SUBconst {
@@ -16763,7 +16777,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MULW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64ADDconst)
@@ -16773,10 +16787,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg(v0)
                return true
        }
-       // match: (SUB a p:(SUBconst [c] m:(MNEG _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(SUBconst [c] m:(MNEG _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64SUBconst {
@@ -16784,7 +16799,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MNEG || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64ADDconst)
@@ -16794,10 +16809,11 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                v.AddArg(v0)
                return true
        }
-       // match: (SUB a p:(SUBconst [c] m:(MNEGW _ _)))
-       // cond: p.Uses==1 && m.Uses==1
+       // match: (SUB <t> a p:(SUBconst [c] m:(MNEGW _ _)))
+       // cond: p.Uses==1 && m.Uses==1 && !t.IsPtrShaped()
        // result: (ADDconst [c] (SUB <v.Type> a m))
        for {
+               t := v.Type
                a := v_0
                p := v_1
                if p.Op != OpARM64SUBconst {
@@ -16805,7 +16821,7 @@ func rewriteValueARM64_OpARM64SUB(v *Value) bool {
                }
                c := auxIntToInt64(p.AuxInt)
                m := p.Args[0]
-               if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1) {
+               if m.Op != OpARM64MNEGW || !(p.Uses == 1 && m.Uses == 1 && !t.IsPtrShaped()) {
                        break
                }
                v.reset(OpARM64ADDconst)
diff --git a/test/fixedbugs/issue71932.go b/test/fixedbugs/issue71932.go
new file mode 100644 (file)
index 0000000..d69b241
--- /dev/null
@@ -0,0 +1,50 @@
+// run
+
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "runtime"
+
+const C = 16
+
+type T [C * C]byte
+
+func main() {
+       var ts []*T
+
+       for i := 0; i < 100; i++ {
+               t := new(T)
+               // Save every even object.
+               if i%2 == 0 {
+                       ts = append(ts, t)
+               }
+       }
+       // Make sure the odd objects are collected.
+       runtime.GC()
+
+       for _, t := range ts {
+               f(t, C, C)
+       }
+}
+
+//go:noinline
+func f(t *T, i, j uint) {
+       if i == 0 || i > C || j == 0 || j > C {
+               return // gets rid of bounds check below (via prove pass)
+       }
+       p := &t[i*j-1]
+       *p = 0
+       runtime.GC()
+       *p = 0
+
+       // This goes badly if compiled to
+       //   q := &t[i*j]
+       //   *(q-1) = 0
+       //   runtime.GC()
+       //   *(q-1) = 0
+       // as at the GC call, q is an invalid pointer
+       // (it points past the end of t's allocation).
+}