]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/boring/bcache: make Cache type-safe using generics
authorRuss Cox <rsc@golang.org>
Tue, 16 Aug 2022 14:37:48 +0000 (10:37 -0400)
committerGopher Robot <gobot@golang.org>
Thu, 18 Aug 2022 00:30:19 +0000 (00:30 +0000)
Generics lets us write Cache[K, V] instead of using unsafe.Pointer,
which lets us remove all the uses of package unsafe around the
uses of the cache.

I tried to do Cache[*K, *V] instead of Cache[K, V] but that was not possible.

Change-Id: If3b54cf4c8d2a44879a5f343fd91ecff096537e9
Reviewed-on: https://go-review.googlesource.com/c/go/+/423357
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Russ Cox <rsc@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Auto-Submit: Russ Cox <rsc@golang.org>

src/crypto/ecdsa/boring.go
src/crypto/internal/boring/bcache/cache.go
src/crypto/internal/boring/bcache/cache_test.go
src/crypto/rsa/boring.go

index 4495730b84ffdc60f1f0af363ab2724319b21706..275c60b4de49eb31091cd652829d3ddf56b602fa 100644 (file)
@@ -11,7 +11,6 @@ import (
        "crypto/internal/boring/bbig"
        "crypto/internal/boring/bcache"
        "math/big"
-       "unsafe"
 )
 
 // Cached conversions from Go PublicKey/PrivateKey to BoringCrypto.
@@ -27,8 +26,8 @@ import (
 // still matches before using the cached key. The theory is that the real
 // operations are significantly more expensive than the comparison.
 
-var pubCache bcache.Cache
-var privCache bcache.Cache
+var pubCache bcache.Cache[PublicKey, boringPub]
+var privCache bcache.Cache[PrivateKey, boringPriv]
 
 func init() {
        pubCache.Register()
@@ -41,7 +40,7 @@ type boringPub struct {
 }
 
 func boringPublicKey(pub *PublicKey) (*boring.PublicKeyECDSA, error) {
-       b := (*boringPub)(pubCache.Get(unsafe.Pointer(pub)))
+       b := pubCache.Get(pub)
        if b != nil && publicKeyEqual(&b.orig, pub) {
                return b.key, nil
        }
@@ -53,7 +52,7 @@ func boringPublicKey(pub *PublicKey) (*boring.PublicKeyECDSA, error) {
                return nil, err
        }
        b.key = key
-       pubCache.Put(unsafe.Pointer(pub), unsafe.Pointer(b))
+       pubCache.Put(pub, b)
        return key, nil
 }
 
@@ -63,7 +62,7 @@ type boringPriv struct {
 }
 
 func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyECDSA, error) {
-       b := (*boringPriv)(privCache.Get(unsafe.Pointer(priv)))
+       b := privCache.Get(priv)
        if b != nil && privateKeyEqual(&b.orig, priv) {
                return b.key, nil
        }
@@ -75,7 +74,7 @@ func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyECDSA, error) {
                return nil, err
        }
        b.key = key
-       privCache.Put(unsafe.Pointer(priv), unsafe.Pointer(b))
+       privCache.Put(priv, b)
        return key, nil
 }
 
index c0b9d7bf2a49841904a2ed72efd620569d0c255e..7934d03e7badcdbdfb25d9636805b4e1c7f1fc37 100644 (file)
@@ -22,20 +22,18 @@ import (
 // This means that clients need to be able to cope with cache entries
 // disappearing, but it also means that clients don't need to worry about
 // cache entries keeping the keys from being collected.
-//
-// TODO(rsc): Make Cache generic once consumers can handle that.
-type Cache struct {
-       // ptable is an atomic *[cacheSize]unsafe.Pointer,
-       // where each unsafe.Pointer is an atomic *cacheEntry.
+type Cache[K, V any] struct {
        // The runtime atomically stores nil to ptable at the start of each GC.
-       ptable unsafe.Pointer
+       ptable atomic.Pointer[cacheTable[K, V]]
 }
 
+type cacheTable[K, V any] [cacheSize]atomic.Pointer[cacheEntry[K, V]]
+
 // A cacheEntry is a single entry in the linked list for a given hash table entry.
-type cacheEntry struct {
-       k    unsafe.Pointer // immutable once created
-       v    unsafe.Pointer // read and written atomically to allow updates
-       next *cacheEntry    // immutable once linked into table
+type cacheEntry[K, V any] struct {
+       k    *K                // immutable once created
+       v    atomic.Pointer[V] // read and written atomically to allow updates
+       next *cacheEntry[K, V] // immutable once linked into table
 }
 
 func registerCache(unsafe.Pointer) // provided by runtime
@@ -43,7 +41,7 @@ func registerCache(unsafe.Pointer) // provided by runtime
 // Register registers the cache with the runtime,
 // so that c.ptable can be cleared at the start of each GC.
 // Register must be called during package initialization.
-func (c *Cache) Register() {
+func (c *Cache[K, V]) Register() {
        registerCache(unsafe.Pointer(&c.ptable))
 }
 
@@ -54,45 +52,45 @@ const cacheSize = 1021
 
 // table returns a pointer to the current cache hash table,
 // coping with the possibility of the GC clearing it out from under us.
-func (c *Cache) table() *[cacheSize]unsafe.Pointer {
+func (c *Cache[K, V]) table() *cacheTable[K, V] {
        for {
-               p := atomic.LoadPointer(&c.ptable)
+               p := c.ptable.Load()
                if p == nil {
-                       p = unsafe.Pointer(new([cacheSize]unsafe.Pointer))
-                       if !atomic.CompareAndSwapPointer(&c.ptable, nil, p) {
+                       p = new(cacheTable[K, V])
+                       if !c.ptable.CompareAndSwap(nil, p) {
                                continue
                        }
                }
-               return (*[cacheSize]unsafe.Pointer)(p)
+               return p
        }
 }
 
 // Clear clears the cache.
 // The runtime does this automatically at each garbage collection;
 // this method is exposed only for testing.
-func (c *Cache) Clear() {
+func (c *Cache[K, V]) Clear() {
        // The runtime does this at the start of every garbage collection
        // (itself, not by calling this function).
-       atomic.StorePointer(&c.ptable, nil)
+       c.ptable.Store(nil)
 }
 
 // Get returns the cached value associated with v,
 // which is either the value v corresponding to the most recent call to Put(k, v)
 // or nil if that cache entry has been dropped.
-func (c *Cache) Get(k unsafe.Pointer) unsafe.Pointer {
-       head := &c.table()[uintptr(k)%cacheSize]
-       e := (*cacheEntry)(atomic.LoadPointer(head))
+func (c *Cache[K, V]) Get(k *K) *V {
+       head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]
+       e := head.Load()
        for ; e != nil; e = e.next {
                if e.k == k {
-                       return atomic.LoadPointer(&e.v)
+                       return e.v.Load()
                }
        }
        return nil
 }
 
 // Put sets the cached value associated with k to v.
-func (c *Cache) Put(k, v unsafe.Pointer) {
-       head := &c.table()[uintptr(k)%cacheSize]
+func (c *Cache[K, V]) Put(k *K, v *V) {
+       head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]
 
        // Strategy is to walk the linked list at head,
        // same as in Get, to look for existing entry.
@@ -112,20 +110,21 @@ func (c *Cache) Put(k, v unsafe.Pointer) {
        //
        //  2. We only allocate the entry to be added once,
        //     saving it in add for the next attempt.
-       var add, noK *cacheEntry
+       var add, noK *cacheEntry[K, V]
        n := 0
        for {
-               e := (*cacheEntry)(atomic.LoadPointer(head))
+               e := head.Load()
                start := e
                for ; e != nil && e != noK; e = e.next {
                        if e.k == k {
-                               atomic.StorePointer(&e.v, v)
+                               e.v.Store(v)
                                return
                        }
                        n++
                }
                if add == nil {
-                       add = &cacheEntry{k, v, nil}
+                       add = &cacheEntry[K, V]{k: k}
+                       add.v.Store(v)
                }
                add.next = start
                if n >= 1000 {
@@ -133,7 +132,7 @@ func (c *Cache) Put(k, v unsafe.Pointer) {
                        // throw it away to avoid quadratic lookup behavior.
                        add.next = nil
                }
-               if atomic.CompareAndSwapPointer(head, unsafe.Pointer(start), unsafe.Pointer(add)) {
+               if head.CompareAndSwap(start, add) {
                        return
                }
                noK = start
index 8b2cf3d09415d8e7c7a7d7222596f56acf67da7a..19458a1c245dfcc4432950557d7e5e59d8ca2657 100644 (file)
@@ -10,31 +10,39 @@ import (
        "sync"
        "sync/atomic"
        "testing"
-       "unsafe"
 )
 
-var registeredCache Cache
+var registeredCache Cache[int, int32]
 
 func init() {
        registeredCache.Register()
 }
 
+var seq atomic.Uint32
+
+func next[T int | int32]() *T {
+       x := new(T)
+       *x = T(seq.Add(1))
+       return x
+}
+
+func str[T int | int32](x *T) string {
+       if x == nil {
+               return "nil"
+       }
+       return fmt.Sprint(*x)
+}
+
 func TestCache(t *testing.T) {
        // Use unregistered cache for functionality tests,
        // to keep the runtime from clearing behind our backs.
-       c := new(Cache)
+       c := new(Cache[int, int32])
 
        // Create many entries.
-       seq := uint32(0)
-       next := func() unsafe.Pointer {
-               x := new(int)
-               *x = int(atomic.AddUint32(&seq, 1))
-               return unsafe.Pointer(x)
-       }
-       m := make(map[unsafe.Pointer]unsafe.Pointer)
+       m := make(map[*int]*int32)
        for i := 0; i < 10000; i++ {
-               k := next()
-               v := next()
+               k := next[int]()
+               v := next[int32]()
                m[k] = v
                c.Put(k, v)
        }
@@ -42,7 +50,7 @@ func TestCache(t *testing.T) {
        // Overwrite a random 20% of those.
        n := 0
        for k := range m {
-               v := next()
+               v := next[int32]()
                m[k] = v
                c.Put(k, v)
                if n++; n >= 2000 {
@@ -51,12 +59,6 @@ func TestCache(t *testing.T) {
        }
 
        // Check results.
-       str := func(p unsafe.Pointer) string {
-               if p == nil {
-                       return "nil"
-               }
-               return fmt.Sprint(*(*int)(p))
-       }
        for k, v := range m {
                if cv := c.Get(k); cv != v {
                        t.Fatalf("c.Get(%v) = %v, want %v", str(k), str(cv), str(v))
@@ -86,7 +88,7 @@ func TestCache(t *testing.T) {
        // Lists are discarded if they reach 1000 entries,
        // and there are cacheSize list heads, so we should be
        // able to do 100 * cacheSize entries with no problem at all.
-       c = new(Cache)
+       c = new(Cache[int, int32])
        var barrier, wg sync.WaitGroup
        const N = 100
        barrier.Add(N)
@@ -96,9 +98,9 @@ func TestCache(t *testing.T) {
                go func() {
                        defer wg.Done()
 
-                       m := make(map[unsafe.Pointer]unsafe.Pointer)
+                       m := make(map[*int]*int32)
                        for j := 0; j < cacheSize; j++ {
-                               k, v := next(), next()
+                               k, v := next[int](), next[int32]()
                                m[k] = v
                                c.Put(k, v)
                        }
index 9b1db564c36f3cee2a5cd2edc645da4b082dc6e7..b9f9d3154f2589e04fcf6592511e1d55f371051e 100644 (file)
@@ -11,7 +11,6 @@ import (
        "crypto/internal/boring/bbig"
        "crypto/internal/boring/bcache"
        "math/big"
-       "unsafe"
 )
 
 // Cached conversions from Go PublicKey/PrivateKey to BoringCrypto.
@@ -32,8 +31,8 @@ type boringPub struct {
        orig PublicKey
 }
 
-var pubCache bcache.Cache
-var privCache bcache.Cache
+var pubCache bcache.Cache[PublicKey, boringPub]
+var privCache bcache.Cache[PrivateKey, boringPriv]
 
 func init() {
        pubCache.Register()
@@ -41,7 +40,7 @@ func init() {
 }
 
 func boringPublicKey(pub *PublicKey) (*boring.PublicKeyRSA, error) {
-       b := (*boringPub)(pubCache.Get(unsafe.Pointer(pub)))
+       b := pubCache.Get(pub)
        if b != nil && publicKeyEqual(&b.orig, pub) {
                return b.key, nil
        }
@@ -53,7 +52,7 @@ func boringPublicKey(pub *PublicKey) (*boring.PublicKeyRSA, error) {
                return nil, err
        }
        b.key = key
-       pubCache.Put(unsafe.Pointer(pub), unsafe.Pointer(b))
+       pubCache.Put(pub, b)
        return key, nil
 }
 
@@ -63,7 +62,7 @@ type boringPriv struct {
 }
 
 func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyRSA, error) {
-       b := (*boringPriv)(privCache.Get(unsafe.Pointer(priv)))
+       b := privCache.Get(priv)
        if b != nil && privateKeyEqual(&b.orig, priv) {
                return b.key, nil
        }
@@ -87,7 +86,7 @@ func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyRSA, error) {
                return nil, err
        }
        b.key = key
-       privCache.Put(unsafe.Pointer(priv), unsafe.Pointer(b))
+       privCache.Put(priv, b)
        return key, nil
 }