hash crypto.Hash
}
-func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte {
- labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey))
+func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte {
+ labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
- labeledIKM = append(labeledIKM, suiteID...)
+ labeledIKM = append(labeledIKM, sid...)
labeledIKM = append(labeledIKM, label...)
labeledIKM = append(labeledIKM, inputKey...)
return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
nSecret uint16
}
+type KemID uint16
+
+const DHKEM_X25519_HKDF_SHA256 = 0x0020
+
var SupportedKEMs = map[uint16]struct {
curve ecdh.Curve
hash crypto.Hash
nSecret uint16
}{
// RFC 9180 Section 7.1
- 0x0020: {ecdh.X25519(), crypto.SHA256, 32},
+ DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
}
func newDHKem(kemID uint16) (*dhKEM, error) {
return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
}
-type Sender struct {
+func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
+ pubEph, err := dh.dh.NewPublicKey(encPubEph)
+ if err != nil {
+ return nil, err
+ }
+ dhVal, err := secRecipient.ECDH(pubEph)
+ if err != nil {
+ return nil, err
+ }
+ kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
+
+ return dh.ExtractAndExpand(dhVal, kemContext), nil
+}
+
+type context struct {
aead cipher.AEAD
- kem *dhKEM
sharedSecret []byte
seqNum uint128
}
+type Sender struct {
+ *context
+}
+
+type Receipient struct {
+ *context
+}
+
var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return cipher.NewGCM(block)
}
+type AEADID uint16
+
+const (
+ AEAD_AES_128_GCM = 0x0001
+ AEAD_AES_256_GCM = 0x0002
+ AEAD_ChaCha20Poly1305 = 0x0003
+)
+
var SupportedAEADs = map[uint16]struct {
keySize int
nonceSize int
aead func([]byte) (cipher.AEAD, error)
}{
// RFC 9180, Section 7.3
- 0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
- 0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
- 0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
+ AEAD_AES_128_GCM: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
+ AEAD_AES_256_GCM: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
+ AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
}
+type KDFID uint16
+
+const KDF_HKDF_SHA256 = 0x0001
+
var SupportedKDFs = map[uint16]func() *hkdfKDF{
// RFC 9180, Section 7.2
- 0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
+ KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
}
-func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) {
- suiteID := SuiteID(kemID, kdfID, aeadID)
-
- kem, err := newDHKem(kemID)
- if err != nil {
- return nil, nil, err
- }
- pubRecipient, ok := pub.(*ecdh.PublicKey)
- if !ok {
- return nil, nil, errors.New("incorrect public key type")
- }
- sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient)
- if err != nil {
- return nil, nil, err
- }
+func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
+ sid := suiteID(kemID, kdfID, aeadID)
kdfInit, ok := SupportedKDFs[kdfID]
if !ok {
- return nil, nil, errors.New("unsupported KDF id")
+ return nil, errors.New("unsupported KDF id")
}
kdf := kdfInit()
aeadInfo, ok := SupportedAEADs[aeadID]
if !ok {
- return nil, nil, errors.New("unsupported AEAD id")
+ return nil, errors.New("unsupported AEAD id")
}
- pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil)
- infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info)
+ pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
+ infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info)
ksContext := append([]byte{0}, pskIDHash...)
ksContext = append(ksContext, infoHash...)
- secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil)
+ secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
- key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
- baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
- exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
+ key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
+ baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
+ exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
aead, err := aeadInfo.aead(key)
if err != nil {
- return nil, nil, err
+ return nil, err
}
- return encapsulatedKey, &Sender{
- kem: kem,
+ return &context{
aead: aead,
sharedSecret: sharedSecret,
- suiteID: suiteID,
+ suiteID: sid,
key: key,
baseNonce: baseNonce,
exporterSecret: exporterSecret,
}, nil
}
-func (s *Sender) nextNonce() []byte {
- nonce := s.seqNum.bytes()[16-s.aead.NonceSize():]
- for i := range s.baseNonce {
- nonce[i] ^= s.baseNonce[i]
+func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
+ kem, err := newDHKem(kemID)
+ if err != nil {
+ return nil, nil, err
+ }
+ sharedSecret, encapsulatedKey, err := kem.Encap(pub)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return encapsulatedKey, &Sender{context}, nil
+}
+
+func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) {
+ kem, err := newDHKem(kemID)
+ if err != nil {
+ return nil, err
+ }
+ sharedSecret, err := kem.Decap(encPubEph, priv)
+ if err != nil {
+ return nil, err
+ }
+
+ context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
+ if err != nil {
+ return nil, err
}
+
+ return &Receipient{context}, nil
+}
+
+func (ctx *context) nextNonce() []byte {
+ nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
+ for i := range ctx.baseNonce {
+ nonce[i] ^= ctx.baseNonce[i]
+ }
+ return nonce
+}
+
+func (ctx *context) incrementNonce() {
// Message limit is, according to the RFC, 2^95+1, which
// is somewhat confusing, but we do as we're told.
- if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 {
+ if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
panic("message limit reached")
}
- s.seqNum = s.seqNum.addOne()
- return nonce
+ ctx.seqNum = ctx.seqNum.addOne()
}
func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
-
ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
+ s.incrementNonce()
return ciphertext, nil
}
-func SuiteID(kemID, kdfID, aeadID uint16) []byte {
+func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) {
+ plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
+ if err != nil {
+ return nil, err
+ }
+ r.incrementNonce()
+ return plaintext, nil
+}
+
+func suiteID(kemID, kdfID, aeadID uint16) []byte {
suiteID := make([]byte, 0, 4+2+2+2)
suiteID = append(suiteID, []byte("HPKE")...)
suiteID = byteorder.BeAppendUint16(suiteID, kemID)
return kemInfo.curve.NewPublicKey(bytes)
}
+func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
+ kemInfo, ok := SupportedKEMs[kemID]
+ if !ok {
+ return nil, errors.New("unsupported KEM id")
+ }
+ return kemInfo.curve.NewPrivateKey(bytes)
+}
+
type uint128 struct {
hi, lo uint64
}
}
t.Cleanup(func() { testingOnlyGenerateKey = nil })
- encap, context, err := SetupSender(
+ encap, sender, err := SetupSender(
uint16(kemID),
uint16(kdfID),
uint16(aeadID),
if !bytes.Equal(encap, expectedEncap) {
t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
}
- expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
- if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
- t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
- }
- expectedKey := mustDecodeHex(t, setup["key"])
- if !bytes.Equal(context.key, expectedKey) {
- t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
+
+ privKeyBytes := mustDecodeHex(t, setup["skRm"])
+ priv, err := ParseHPKEPrivateKey(uint16(kemID), privKeyBytes)
+ if err != nil {
+ t.Fatal(err)
}
- expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
- if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
- t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
+
+ receipient, err := SetupReceipient(
+ uint16(kemID),
+ uint16(kdfID),
+ uint16(aeadID),
+ priv,
+ info,
+ encap,
+ )
+ if err != nil {
+ t.Fatal(err)
}
- expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
- if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
- t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
+
+ for _, ctx := range []*context{sender.context, receipient.context} {
+ expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
+ if !bytes.Equal(ctx.sharedSecret, expectedSharedSecret) {
+ t.Errorf("unexpected shared secret, got: %x, want %x", ctx.sharedSecret, expectedSharedSecret)
+ }
+ expectedKey := mustDecodeHex(t, setup["key"])
+ if !bytes.Equal(ctx.key, expectedKey) {
+ t.Errorf("unexpected key, got: %x, want %x", ctx.key, expectedKey)
+ }
+ expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
+ if !bytes.Equal(ctx.baseNonce, expectedBaseNonce) {
+ t.Errorf("unexpected base nonce, got: %x, want %x", ctx.baseNonce, expectedBaseNonce)
+ }
+ expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
+ if !bytes.Equal(ctx.exporterSecret, expectedExporterSecret) {
+ t.Errorf("unexpected exporter secret, got: %x, want %x", ctx.exporterSecret, expectedExporterSecret)
+ }
}
for _, enc := range parseVectorEncryptions(vector.Encryptions) {
if err != nil {
t.Fatal(err)
}
- context.seqNum = uint128{lo: uint64(seqNum)}
+ sender.seqNum = uint128{lo: uint64(seqNum)}
+ receipient.seqNum = uint128{lo: uint64(seqNum)}
expectedNonce := mustDecodeHex(t, enc["nonce"])
- // We can't call nextNonce, because it increments the sequence number,
- // so just compute it directly.
- computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
- for i := range context.baseNonce {
- computedNonce[i] ^= context.baseNonce[i]
- }
+ computedNonce := sender.nextNonce()
if !bytes.Equal(computedNonce, expectedNonce) {
t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
}
expectedCiphertext := mustDecodeHex(t, enc["ct"])
- ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
+ ciphertext, err := sender.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(ciphertext, expectedCiphertext) {
t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
}
+
+ expectedPlaintext := mustDecodeHex(t, enc["pt"])
+ plaintext, err := receipient.Open(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["ct"]))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(plaintext, expectedPlaintext) {
+ t.Errorf("unexpected plaintext: got %x want %x", plaintext, expectedPlaintext)
+ }
})
}
})