]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/hpke: add Recipient role
authorRoland Shoemaker <roland@golang.org>
Sun, 11 Aug 2024 01:19:09 +0000 (18:19 -0700)
committerGopher Robot <gobot@golang.org>
Tue, 19 Nov 2024 16:48:05 +0000 (16:48 +0000)
Adds the Recipient role, alongside the existing Sender role. Also factor
out all of the shared underlying bits and pieces into a shared type that
is embedded in the Sender/Recipient roles.

Change-Id: I7640d8732aa0dd5cc9e38b8c26f0cfa7856170f6
Reviewed-on: https://go-review.googlesource.com/c/go/+/623575
Auto-Submit: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
src/crypto/internal/hpke/hpke.go
src/crypto/internal/hpke/hpke_test.go

index 022cdd28dfe4e8588c67ddf79820c3104a4b4ad3..978f79cbcf3a8a596f555f466daa66a9b4499eba 100644 (file)
@@ -26,10 +26,10 @@ type hkdfKDF struct {
        hash crypto.Hash
 }
 
-func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte {
-       labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey))
+func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte {
+       labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
        labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
-       labeledIKM = append(labeledIKM, suiteID...)
+       labeledIKM = append(labeledIKM, sid...)
        labeledIKM = append(labeledIKM, label...)
        labeledIKM = append(labeledIKM, inputKey...)
        return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
@@ -59,13 +59,17 @@ type dhKEM struct {
        nSecret uint16
 }
 
+type KemID uint16
+
+const DHKEM_X25519_HKDF_SHA256 = 0x0020
+
 var SupportedKEMs = map[uint16]struct {
        curve   ecdh.Curve
        hash    crypto.Hash
        nSecret uint16
 }{
        // RFC 9180 Section 7.1
-       0x0020: {ecdh.X25519(), crypto.SHA256, 32},
+       DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
 }
 
 func newDHKem(kemID uint16) (*dhKEM, error) {
@@ -108,9 +112,22 @@ func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encap
        return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
 }
 
-type Sender struct {
+func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
+       pubEph, err := dh.dh.NewPublicKey(encPubEph)
+       if err != nil {
+               return nil, err
+       }
+       dhVal, err := secRecipient.ECDH(pubEph)
+       if err != nil {
+               return nil, err
+       }
+       kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
+
+       return dh.ExtractAndExpand(dhVal, kemContext), nil
+}
+
+type context struct {
        aead cipher.AEAD
-       kem  *dhKEM
 
        sharedSecret []byte
 
@@ -123,6 +140,14 @@ type Sender struct {
        seqNum uint128
 }
 
+type Sender struct {
+       *context
+}
+
+type Receipient struct {
+       *context
+}
+
 var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
        block, err := aes.NewCipher(key)
        if err != nil {
@@ -131,97 +156,143 @@ var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
        return cipher.NewGCM(block)
 }
 
+type AEADID uint16
+
+const (
+       AEAD_AES_128_GCM      = 0x0001
+       AEAD_AES_256_GCM      = 0x0002
+       AEAD_ChaCha20Poly1305 = 0x0003
+)
+
 var SupportedAEADs = map[uint16]struct {
        keySize   int
        nonceSize int
        aead      func([]byte) (cipher.AEAD, error)
 }{
        // RFC 9180, Section 7.3
-       0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
-       0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
-       0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
+       AEAD_AES_128_GCM:      {keySize: 16, nonceSize: 12, aead: aesGCMNew},
+       AEAD_AES_256_GCM:      {keySize: 32, nonceSize: 12, aead: aesGCMNew},
+       AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
 }
 
+type KDFID uint16
+
+const KDF_HKDF_SHA256 = 0x0001
+
 var SupportedKDFs = map[uint16]func() *hkdfKDF{
        // RFC 9180, Section 7.2
-       0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
+       KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
 }
 
-func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) {
-       suiteID := SuiteID(kemID, kdfID, aeadID)
-
-       kem, err := newDHKem(kemID)
-       if err != nil {
-               return nil, nil, err
-       }
-       pubRecipient, ok := pub.(*ecdh.PublicKey)
-       if !ok {
-               return nil, nil, errors.New("incorrect public key type")
-       }
-       sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient)
-       if err != nil {
-               return nil, nil, err
-       }
+func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
+       sid := suiteID(kemID, kdfID, aeadID)
 
        kdfInit, ok := SupportedKDFs[kdfID]
        if !ok {
-               return nil, nil, errors.New("unsupported KDF id")
+               return nil, errors.New("unsupported KDF id")
        }
        kdf := kdfInit()
 
        aeadInfo, ok := SupportedAEADs[aeadID]
        if !ok {
-               return nil, nil, errors.New("unsupported AEAD id")
+               return nil, errors.New("unsupported AEAD id")
        }
 
-       pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil)
-       infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info)
+       pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
+       infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info)
        ksContext := append([]byte{0}, pskIDHash...)
        ksContext = append(ksContext, infoHash...)
 
-       secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil)
+       secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
 
-       key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
-       baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
-       exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
+       key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
+       baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
+       exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
 
        aead, err := aeadInfo.aead(key)
        if err != nil {
-               return nil, nil, err
+               return nil, err
        }
 
-       return encapsulatedKey, &Sender{
-               kem:            kem,
+       return &context{
                aead:           aead,
                sharedSecret:   sharedSecret,
-               suiteID:        suiteID,
+               suiteID:        sid,
                key:            key,
                baseNonce:      baseNonce,
                exporterSecret: exporterSecret,
        }, nil
 }
 
-func (s *Sender) nextNonce() []byte {
-       nonce := s.seqNum.bytes()[16-s.aead.NonceSize():]
-       for i := range s.baseNonce {
-               nonce[i] ^= s.baseNonce[i]
+func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
+       kem, err := newDHKem(kemID)
+       if err != nil {
+               return nil, nil, err
+       }
+       sharedSecret, encapsulatedKey, err := kem.Encap(pub)
+       if err != nil {
+               return nil, nil, err
+       }
+
+       context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
+       if err != nil {
+               return nil, nil, err
+       }
+
+       return encapsulatedKey, &Sender{context}, nil
+}
+
+func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) {
+       kem, err := newDHKem(kemID)
+       if err != nil {
+               return nil, err
+       }
+       sharedSecret, err := kem.Decap(encPubEph, priv)
+       if err != nil {
+               return nil, err
+       }
+
+       context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
+       if err != nil {
+               return nil, err
        }
+
+       return &Receipient{context}, nil
+}
+
+func (ctx *context) nextNonce() []byte {
+       nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
+       for i := range ctx.baseNonce {
+               nonce[i] ^= ctx.baseNonce[i]
+       }
+       return nonce
+}
+
+func (ctx *context) incrementNonce() {
        // Message limit is, according to the RFC, 2^95+1, which
        // is somewhat confusing, but we do as we're told.
-       if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 {
+       if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
                panic("message limit reached")
        }
-       s.seqNum = s.seqNum.addOne()
-       return nonce
+       ctx.seqNum = ctx.seqNum.addOne()
 }
 
 func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
-
        ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
+       s.incrementNonce()
        return ciphertext, nil
 }
 
-func SuiteID(kemID, kdfID, aeadID uint16) []byte {
+func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) {
+       plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
+       if err != nil {
+               return nil, err
+       }
+       r.incrementNonce()
+       return plaintext, nil
+}
+
+func suiteID(kemID, kdfID, aeadID uint16) []byte {
        suiteID := make([]byte, 0, 4+2+2+2)
        suiteID = append(suiteID, []byte("HPKE")...)
        suiteID = byteorder.BeAppendUint16(suiteID, kemID)
@@ -238,6 +309,14 @@ func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
        return kemInfo.curve.NewPublicKey(bytes)
 }
 
+func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
+       kemInfo, ok := SupportedKEMs[kemID]
+       if !ok {
+               return nil, errors.New("unsupported KEM id")
+       }
+       return kemInfo.curve.NewPrivateKey(bytes)
+}
+
 type uint128 struct {
        hi, lo uint64
 }
index dbdfd7a80a327fa64660d3eff0d8dea4fb06de40..51beeed21281e7e0bf710a73af0ffca8e53f4368 100644 (file)
@@ -104,7 +104,7 @@ func TestRFC9180Vectors(t *testing.T) {
                        }
                        t.Cleanup(func() { testingOnlyGenerateKey = nil })
 
-                       encap, context, err := SetupSender(
+                       encap, sender, err := SetupSender(
                                uint16(kemID),
                                uint16(kdfID),
                                uint16(aeadID),
@@ -119,21 +119,42 @@ func TestRFC9180Vectors(t *testing.T) {
                        if !bytes.Equal(encap, expectedEncap) {
                                t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
                        }
-                       expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
-                       if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
-                               t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
-                       }
-                       expectedKey := mustDecodeHex(t, setup["key"])
-                       if !bytes.Equal(context.key, expectedKey) {
-                               t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
+
+                       privKeyBytes := mustDecodeHex(t, setup["skRm"])
+                       priv, err := ParseHPKEPrivateKey(uint16(kemID), privKeyBytes)
+                       if err != nil {
+                               t.Fatal(err)
                        }
-                       expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
-                       if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
-                               t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
+
+                       receipient, err := SetupReceipient(
+                               uint16(kemID),
+                               uint16(kdfID),
+                               uint16(aeadID),
+                               priv,
+                               info,
+                               encap,
+                       )
+                       if err != nil {
+                               t.Fatal(err)
                        }
-                       expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
-                       if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
-                               t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
+
+                       for _, ctx := range []*context{sender.context, receipient.context} {
+                               expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
+                               if !bytes.Equal(ctx.sharedSecret, expectedSharedSecret) {
+                                       t.Errorf("unexpected shared secret, got: %x, want %x", ctx.sharedSecret, expectedSharedSecret)
+                               }
+                               expectedKey := mustDecodeHex(t, setup["key"])
+                               if !bytes.Equal(ctx.key, expectedKey) {
+                                       t.Errorf("unexpected key, got: %x, want %x", ctx.key, expectedKey)
+                               }
+                               expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
+                               if !bytes.Equal(ctx.baseNonce, expectedBaseNonce) {
+                                       t.Errorf("unexpected base nonce, got: %x, want %x", ctx.baseNonce, expectedBaseNonce)
+                               }
+                               expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
+                               if !bytes.Equal(ctx.exporterSecret, expectedExporterSecret) {
+                                       t.Errorf("unexpected exporter secret, got: %x, want %x", ctx.exporterSecret, expectedExporterSecret)
+                               }
                        }
 
                        for _, enc := range parseVectorEncryptions(vector.Encryptions) {
@@ -142,26 +163,31 @@ func TestRFC9180Vectors(t *testing.T) {
                                        if err != nil {
                                                t.Fatal(err)
                                        }
-                                       context.seqNum = uint128{lo: uint64(seqNum)}
+                                       sender.seqNum = uint128{lo: uint64(seqNum)}
+                                       receipient.seqNum = uint128{lo: uint64(seqNum)}
                                        expectedNonce := mustDecodeHex(t, enc["nonce"])
-                                       // We can't call nextNonce, because it increments the sequence number,
-                                       // so just compute it directly.
-                                       computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
-                                       for i := range context.baseNonce {
-                                               computedNonce[i] ^= context.baseNonce[i]
-                                       }
+                                       computedNonce := sender.nextNonce()
                                        if !bytes.Equal(computedNonce, expectedNonce) {
                                                t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
                                        }
 
                                        expectedCiphertext := mustDecodeHex(t, enc["ct"])
-                                       ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
+                                       ciphertext, err := sender.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
                                        if err != nil {
                                                t.Fatal(err)
                                        }
                                        if !bytes.Equal(ciphertext, expectedCiphertext) {
                                                t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
                                        }
+
+                                       expectedPlaintext := mustDecodeHex(t, enc["pt"])
+                                       plaintext, err := receipient.Open(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["ct"]))
+                                       if err != nil {
+                                               t.Fatal(err)
+                                       }
+                                       if !bytes.Equal(plaintext, expectedPlaintext) {
+                                               t.Errorf("unexpected plaintext: got %x want %x", plaintext, expectedPlaintext)
+                                       }
                                })
                        }
                })