From b30ba3df9ff8969f934bec5016cfce4b91f6ea5b Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Tue, 16 Aug 2022 10:37:48 -0400 Subject: [PATCH] crypto/internal/boring/bcache: make Cache type-safe using generics Generics lets us write Cache[K, V] instead of using unsafe.Pointer, which lets us remove all the uses of package unsafe around the uses of the cache. I tried to do Cache[*K, *V] instead of Cache[K, V] but that was not possible. Change-Id: If3b54cf4c8d2a44879a5f343fd91ecff096537e9 Reviewed-on: https://go-review.googlesource.com/c/go/+/423357 TryBot-Result: Gopher Robot Run-TryBot: Russ Cox Reviewed-by: Cherry Mui Auto-Submit: Russ Cox --- src/crypto/ecdsa/boring.go | 13 ++--- src/crypto/internal/boring/bcache/cache.go | 57 +++++++++---------- .../internal/boring/bcache/cache_test.go | 46 ++++++++------- src/crypto/rsa/boring.go | 13 ++--- 4 files changed, 64 insertions(+), 65 deletions(-) diff --git a/src/crypto/ecdsa/boring.go b/src/crypto/ecdsa/boring.go index 4495730b84..275c60b4de 100644 --- a/src/crypto/ecdsa/boring.go +++ b/src/crypto/ecdsa/boring.go @@ -11,7 +11,6 @@ import ( "crypto/internal/boring/bbig" "crypto/internal/boring/bcache" "math/big" - "unsafe" ) // Cached conversions from Go PublicKey/PrivateKey to BoringCrypto. @@ -27,8 +26,8 @@ import ( // still matches before using the cached key. The theory is that the real // operations are significantly more expensive than the comparison. -var pubCache bcache.Cache -var privCache bcache.Cache +var pubCache bcache.Cache[PublicKey, boringPub] +var privCache bcache.Cache[PrivateKey, boringPriv] func init() { pubCache.Register() @@ -41,7 +40,7 @@ type boringPub struct { } func boringPublicKey(pub *PublicKey) (*boring.PublicKeyECDSA, error) { - b := (*boringPub)(pubCache.Get(unsafe.Pointer(pub))) + b := pubCache.Get(pub) if b != nil && publicKeyEqual(&b.orig, pub) { return b.key, nil } @@ -53,7 +52,7 @@ func boringPublicKey(pub *PublicKey) (*boring.PublicKeyECDSA, error) { return nil, err } b.key = key - pubCache.Put(unsafe.Pointer(pub), unsafe.Pointer(b)) + pubCache.Put(pub, b) return key, nil } @@ -63,7 +62,7 @@ type boringPriv struct { } func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyECDSA, error) { - b := (*boringPriv)(privCache.Get(unsafe.Pointer(priv))) + b := privCache.Get(priv) if b != nil && privateKeyEqual(&b.orig, priv) { return b.key, nil } @@ -75,7 +74,7 @@ func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyECDSA, error) { return nil, err } b.key = key - privCache.Put(unsafe.Pointer(priv), unsafe.Pointer(b)) + privCache.Put(priv, b) return key, nil } diff --git a/src/crypto/internal/boring/bcache/cache.go b/src/crypto/internal/boring/bcache/cache.go index c0b9d7bf2a..7934d03e7b 100644 --- a/src/crypto/internal/boring/bcache/cache.go +++ b/src/crypto/internal/boring/bcache/cache.go @@ -22,20 +22,18 @@ import ( // This means that clients need to be able to cope with cache entries // disappearing, but it also means that clients don't need to worry about // cache entries keeping the keys from being collected. -// -// TODO(rsc): Make Cache generic once consumers can handle that. -type Cache struct { - // ptable is an atomic *[cacheSize]unsafe.Pointer, - // where each unsafe.Pointer is an atomic *cacheEntry. +type Cache[K, V any] struct { // The runtime atomically stores nil to ptable at the start of each GC. - ptable unsafe.Pointer + ptable atomic.Pointer[cacheTable[K, V]] } +type cacheTable[K, V any] [cacheSize]atomic.Pointer[cacheEntry[K, V]] + // A cacheEntry is a single entry in the linked list for a given hash table entry. -type cacheEntry struct { - k unsafe.Pointer // immutable once created - v unsafe.Pointer // read and written atomically to allow updates - next *cacheEntry // immutable once linked into table +type cacheEntry[K, V any] struct { + k *K // immutable once created + v atomic.Pointer[V] // read and written atomically to allow updates + next *cacheEntry[K, V] // immutable once linked into table } func registerCache(unsafe.Pointer) // provided by runtime @@ -43,7 +41,7 @@ func registerCache(unsafe.Pointer) // provided by runtime // Register registers the cache with the runtime, // so that c.ptable can be cleared at the start of each GC. // Register must be called during package initialization. -func (c *Cache) Register() { +func (c *Cache[K, V]) Register() { registerCache(unsafe.Pointer(&c.ptable)) } @@ -54,45 +52,45 @@ const cacheSize = 1021 // table returns a pointer to the current cache hash table, // coping with the possibility of the GC clearing it out from under us. -func (c *Cache) table() *[cacheSize]unsafe.Pointer { +func (c *Cache[K, V]) table() *cacheTable[K, V] { for { - p := atomic.LoadPointer(&c.ptable) + p := c.ptable.Load() if p == nil { - p = unsafe.Pointer(new([cacheSize]unsafe.Pointer)) - if !atomic.CompareAndSwapPointer(&c.ptable, nil, p) { + p = new(cacheTable[K, V]) + if !c.ptable.CompareAndSwap(nil, p) { continue } } - return (*[cacheSize]unsafe.Pointer)(p) + return p } } // Clear clears the cache. // The runtime does this automatically at each garbage collection; // this method is exposed only for testing. -func (c *Cache) Clear() { +func (c *Cache[K, V]) Clear() { // The runtime does this at the start of every garbage collection // (itself, not by calling this function). - atomic.StorePointer(&c.ptable, nil) + c.ptable.Store(nil) } // Get returns the cached value associated with v, // which is either the value v corresponding to the most recent call to Put(k, v) // or nil if that cache entry has been dropped. -func (c *Cache) Get(k unsafe.Pointer) unsafe.Pointer { - head := &c.table()[uintptr(k)%cacheSize] - e := (*cacheEntry)(atomic.LoadPointer(head)) +func (c *Cache[K, V]) Get(k *K) *V { + head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize] + e := head.Load() for ; e != nil; e = e.next { if e.k == k { - return atomic.LoadPointer(&e.v) + return e.v.Load() } } return nil } // Put sets the cached value associated with k to v. -func (c *Cache) Put(k, v unsafe.Pointer) { - head := &c.table()[uintptr(k)%cacheSize] +func (c *Cache[K, V]) Put(k *K, v *V) { + head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize] // Strategy is to walk the linked list at head, // same as in Get, to look for existing entry. @@ -112,20 +110,21 @@ func (c *Cache) Put(k, v unsafe.Pointer) { // // 2. We only allocate the entry to be added once, // saving it in add for the next attempt. - var add, noK *cacheEntry + var add, noK *cacheEntry[K, V] n := 0 for { - e := (*cacheEntry)(atomic.LoadPointer(head)) + e := head.Load() start := e for ; e != nil && e != noK; e = e.next { if e.k == k { - atomic.StorePointer(&e.v, v) + e.v.Store(v) return } n++ } if add == nil { - add = &cacheEntry{k, v, nil} + add = &cacheEntry[K, V]{k: k} + add.v.Store(v) } add.next = start if n >= 1000 { @@ -133,7 +132,7 @@ func (c *Cache) Put(k, v unsafe.Pointer) { // throw it away to avoid quadratic lookup behavior. add.next = nil } - if atomic.CompareAndSwapPointer(head, unsafe.Pointer(start), unsafe.Pointer(add)) { + if head.CompareAndSwap(start, add) { return } noK = start diff --git a/src/crypto/internal/boring/bcache/cache_test.go b/src/crypto/internal/boring/bcache/cache_test.go index 8b2cf3d094..19458a1c24 100644 --- a/src/crypto/internal/boring/bcache/cache_test.go +++ b/src/crypto/internal/boring/bcache/cache_test.go @@ -10,31 +10,39 @@ import ( "sync" "sync/atomic" "testing" - "unsafe" ) -var registeredCache Cache +var registeredCache Cache[int, int32] func init() { registeredCache.Register() } +var seq atomic.Uint32 + +func next[T int | int32]() *T { + x := new(T) + *x = T(seq.Add(1)) + return x +} + +func str[T int | int32](x *T) string { + if x == nil { + return "nil" + } + return fmt.Sprint(*x) +} + func TestCache(t *testing.T) { // Use unregistered cache for functionality tests, // to keep the runtime from clearing behind our backs. - c := new(Cache) + c := new(Cache[int, int32]) // Create many entries. - seq := uint32(0) - next := func() unsafe.Pointer { - x := new(int) - *x = int(atomic.AddUint32(&seq, 1)) - return unsafe.Pointer(x) - } - m := make(map[unsafe.Pointer]unsafe.Pointer) + m := make(map[*int]*int32) for i := 0; i < 10000; i++ { - k := next() - v := next() + k := next[int]() + v := next[int32]() m[k] = v c.Put(k, v) } @@ -42,7 +50,7 @@ func TestCache(t *testing.T) { // Overwrite a random 20% of those. n := 0 for k := range m { - v := next() + v := next[int32]() m[k] = v c.Put(k, v) if n++; n >= 2000 { @@ -51,12 +59,6 @@ func TestCache(t *testing.T) { } // Check results. - str := func(p unsafe.Pointer) string { - if p == nil { - return "nil" - } - return fmt.Sprint(*(*int)(p)) - } for k, v := range m { if cv := c.Get(k); cv != v { t.Fatalf("c.Get(%v) = %v, want %v", str(k), str(cv), str(v)) @@ -86,7 +88,7 @@ func TestCache(t *testing.T) { // Lists are discarded if they reach 1000 entries, // and there are cacheSize list heads, so we should be // able to do 100 * cacheSize entries with no problem at all. - c = new(Cache) + c = new(Cache[int, int32]) var barrier, wg sync.WaitGroup const N = 100 barrier.Add(N) @@ -96,9 +98,9 @@ func TestCache(t *testing.T) { go func() { defer wg.Done() - m := make(map[unsafe.Pointer]unsafe.Pointer) + m := make(map[*int]*int32) for j := 0; j < cacheSize; j++ { - k, v := next(), next() + k, v := next[int](), next[int32]() m[k] = v c.Put(k, v) } diff --git a/src/crypto/rsa/boring.go b/src/crypto/rsa/boring.go index 9b1db564c3..b9f9d3154f 100644 --- a/src/crypto/rsa/boring.go +++ b/src/crypto/rsa/boring.go @@ -11,7 +11,6 @@ import ( "crypto/internal/boring/bbig" "crypto/internal/boring/bcache" "math/big" - "unsafe" ) // Cached conversions from Go PublicKey/PrivateKey to BoringCrypto. @@ -32,8 +31,8 @@ type boringPub struct { orig PublicKey } -var pubCache bcache.Cache -var privCache bcache.Cache +var pubCache bcache.Cache[PublicKey, boringPub] +var privCache bcache.Cache[PrivateKey, boringPriv] func init() { pubCache.Register() @@ -41,7 +40,7 @@ func init() { } func boringPublicKey(pub *PublicKey) (*boring.PublicKeyRSA, error) { - b := (*boringPub)(pubCache.Get(unsafe.Pointer(pub))) + b := pubCache.Get(pub) if b != nil && publicKeyEqual(&b.orig, pub) { return b.key, nil } @@ -53,7 +52,7 @@ func boringPublicKey(pub *PublicKey) (*boring.PublicKeyRSA, error) { return nil, err } b.key = key - pubCache.Put(unsafe.Pointer(pub), unsafe.Pointer(b)) + pubCache.Put(pub, b) return key, nil } @@ -63,7 +62,7 @@ type boringPriv struct { } func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyRSA, error) { - b := (*boringPriv)(privCache.Get(unsafe.Pointer(priv))) + b := privCache.Get(priv) if b != nil && privateKeyEqual(&b.orig, priv) { return b.key, nil } @@ -87,7 +86,7 @@ func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyRSA, error) { return nil, err } b.key = key - privCache.Put(unsafe.Pointer(priv), unsafe.Pointer(b)) + privCache.Put(priv, b) return key, nil } -- 2.50.0