}
// Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
+//
+// The decapsulation key must be kept secret.
func (dk *DecapsulationKey) Bytes() []byte {
var b [SeedSize]byte
copy(b[:], dk.d[:])
// EncapsulationKey returns the public encapsulation key necessary to produce
// ciphertexts.
-func (dk *DecapsulationKey) EncapsulationKey() []byte {
+func (dk *DecapsulationKey) EncapsulationKey() *EncapsulationKey {
+ return &EncapsulationKey{
+ ρ: dk.ρ,
+ h: dk.h,
+ encryptionKey: dk.encryptionKey,
+ }
+}
+
+// An EncapsulationKey is the public key used to produce ciphertexts to be
+// decapsulated by the corresponding [DecapsulationKey].
+type EncapsulationKey struct {
+ ρ [32]byte // sampleNTT seed for A
+ h [32]byte // H(ek)
+ encryptionKey
+}
+
+// Bytes returns the encapsulation key as a byte slice.
+func (ek *EncapsulationKey) Bytes() []byte {
// The actual logic is in a separate function to outline this allocation.
b := make([]byte, 0, EncapsulationKeySize)
- return dk.encapsulationKey(b)
+ return ek.bytes(b)
}
-func (dk *DecapsulationKey) encapsulationKey(b []byte) []byte {
- for i := range dk.t {
- b = polyByteEncode(b, dk.t[i])
+func (ek *EncapsulationKey) bytes(b []byte) []byte {
+ for i := range ek.t {
+ b = polyByteEncode(b, ek.t[i])
}
- b = append(b, dk.ρ[:]...)
+ b = append(b, ek.ρ[:]...)
return b
}
return kemKeyGen(dk, &d, &z)
}
-// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte
+// NewDecapsulationKey parses a decapsulation key from a 64-byte
// seed in the "d || z" form. The seed must be uniformly random.
-func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) {
+func NewDecapsulationKey(seed []byte) (*DecapsulationKey, error) {
// The actual logic is in a separate function to outline this allocation.
dk := &DecapsulationKey{}
return newKeyFromSeed(dk, seed)
}
H := sha3.New256()
- ek := dk.EncapsulationKey()
+ ek := dk.EncapsulationKey().Bytes()
H.Write(ek)
H.Sum(dk.h[:0])
// Encapsulate generates a shared key and an associated ciphertext from an
// encapsulation key, drawing random bytes from crypto/rand.
-// If the encapsulation key is not valid, Encapsulate returns an error.
//
// The shared key must be kept secret.
-func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
+func (ek *EncapsulationKey) Encapsulate() (ciphertext, sharedKey []byte) {
// The actual logic is in a separate function to outline this allocation.
var cc [CiphertextSize]byte
- return encapsulate(&cc, encapsulationKey)
+ return ek.encapsulate(&cc)
}
-func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
- if len(encapsulationKey) != EncapsulationKeySize {
- return nil, nil, errors.New("mlkem768: invalid encapsulation key length")
- }
+func (ek *EncapsulationKey) encapsulate(cc *[CiphertextSize]byte) (ciphertext, sharedKey []byte) {
var m [messageSize]byte
rand.Read(m[:])
// Note that the modulus check (step 2 of the encapsulation key check from
// FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK.
- return kemEncaps(cc, encapsulationKey, &m)
+ return kemEncaps(cc, ek, &m)
}
// kemEncaps generates a shared key and an associated ciphertext.
//
// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
-func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) {
+func kemEncaps(cc *[CiphertextSize]byte, ek *EncapsulationKey, m *[messageSize]byte) (c, K []byte) {
if cc == nil {
cc = &[CiphertextSize]byte{}
}
- H := sha3.Sum256(ek[:])
g := sha3.New512()
g.Write(m[:])
- g.Write(H[:])
+ g.Write(ek.h[:])
G := g.Sum(nil)
K, r := G[:SharedKeySize], G[SharedKeySize:]
- var ex encryptionKey
- if err := parseEK(&ex, ek[:]); err != nil {
- return nil, nil, err
- }
- c = pkeEncrypt(cc, &ex, m, r)
- return c, K, nil
+ c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
+ return c, K
+}
+
+// NewEncapsulationKey parses an encapsulation key from its encoded form.
+// If the encapsulation key is not valid, NewEncapsulationKey returns an error.
+func NewEncapsulationKey(encapsulationKey []byte) (*EncapsulationKey, error) {
+ // The actual logic is in a separate function to outline this allocation.
+ ek := &EncapsulationKey{}
+ return parseEK(ek, encapsulationKey)
}
// parseEK parses an encryption key from its encoded form.
//
// It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
// Algorithm 14.
-func parseEK(ex *encryptionKey, ekPKE []byte) error {
+func parseEK(ek *EncapsulationKey, ekPKE []byte) (*EncapsulationKey, error) {
if len(ekPKE) != encryptionKeySize {
- return errors.New("mlkem768: invalid encryption key length")
+ return nil, errors.New("mlkem768: invalid encapsulation key length")
}
- for i := range ex.t {
+ ek.h = sha3.Sum256(ekPKE[:])
+
+ for i := range ek.t {
var err error
- ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
+ ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
if err != nil {
- return err
+ return nil, err
}
ekPKE = ekPKE[encodingSize12:]
}
- ρ := ekPKE
+ copy(ek.ρ[:], ekPKE)
for i := byte(0); i < k; i++ {
for j := byte(0); j < k; j++ {
- ex.a[i*k+j] = sampleNTT(ρ, j, i)
+ ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
}
}
- return nil
+ return ek, nil
}
// pkeEncrypt encrypt a plaintext message.
if err != nil {
t.Fatal(err)
}
- c, Ke, err := Encapsulate(dk.EncapsulationKey())
- if err != nil {
- t.Fatal(err)
- }
+ c, Ke := dk.EncapsulationKey().Encapsulate()
Kd, err := dk.Decapsulate(c)
if err != nil {
t.Fatal(err)
if err != nil {
t.Fatal(err)
}
- if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) {
+ if bytes.Equal(dk.EncapsulationKey().Bytes(), dk1.EncapsulationKey().Bytes()) {
t.Fail()
}
if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
t.Fail()
}
- c1, Ke1, err := Encapsulate(dk.EncapsulationKey())
- if err != nil {
- t.Fatal(err)
- }
+ c1, Ke1 := dk.EncapsulationKey().Encapsulate()
if bytes.Equal(c, c1) {
t.Fail()
}
t.Fatal(err)
}
ek := dk.EncapsulationKey()
+ ekBytes := dk.EncapsulationKey().Bytes()
+ c, _ := ek.Encapsulate()
- for i := 0; i < len(ek)-1; i++ {
- if _, _, err := Encapsulate(ek[:i]); err == nil {
+ for i := 0; i < len(ekBytes)-1; i++ {
+ if _, err := NewEncapsulationKey(ekBytes[:i]); err == nil {
t.Errorf("expected error for ek length %d", i)
}
}
- ekLong := ek
+ ekLong := ekBytes
for i := 0; i < 100; i++ {
ekLong = append(ekLong, 0)
- if _, _, err := Encapsulate(ekLong); err == nil {
+ if _, err := NewEncapsulationKey(ekLong); err == nil {
t.Errorf("expected error for ek length %d", len(ekLong))
}
}
- c, _, err := Encapsulate(ek)
- if err != nil {
- t.Fatal(err)
- }
-
for i := 0; i < len(c)-1; i++ {
if _, err := dk.Decapsulate(c[:i]); err == nil {
t.Errorf("expected error for c length %d", i)
for i := 0; i < n; i++ {
s.Read(seed)
- dk, err := NewKeyFromSeed(seed)
+ dk, err := NewDecapsulationKey(seed)
if err != nil {
t.Fatal(err)
}
ek := dk.EncapsulationKey()
- o.Write(ek)
+ o.Write(ek.Bytes())
s.Read(msg[:])
- ct, k, err := kemEncaps(nil, ek, &msg)
- if err != nil {
- t.Fatal(err)
- }
+ ct, k := kemEncaps(nil, ek, &msg)
o.Write(ct)
o.Write(k)
b.ResetTimer()
for i := 0; i < b.N; i++ {
dk := kemKeyGen(&dk, &d, &z)
- sink ^= dk.EncapsulationKey()[0]
+ sink ^= dk.EncapsulationKey().Bytes()[0]
}
}
rand.Read(seed)
var m [messageSize]byte
rand.Read(m[:])
- dk, err := NewKeyFromSeed(seed)
+ dk, err := NewDecapsulationKey(seed)
if err != nil {
b.Fatal(err)
}
- ek := dk.EncapsulationKey()
+ ekBytes := dk.EncapsulationKey().Bytes()
var c [CiphertextSize]byte
b.ResetTimer()
for i := 0; i < b.N; i++ {
- c, K, err := kemEncaps(&c, ek, &m)
+ ek, err := NewEncapsulationKey(ekBytes)
if err != nil {
b.Fatal(err)
}
+ c, K := kemEncaps(&c, ek, &m)
sink ^= c[0] ^ K[0]
}
}
b.Fatal(err)
}
ek := dk.EncapsulationKey()
- c, _, err := Encapsulate(ek)
- if err != nil {
- b.Fatal(err)
- }
+ c, _ := ek.Encapsulate()
b.ResetTimer()
for i := 0; i < b.N; i++ {
K := kemDecaps(dk, (*[CiphertextSize]byte)(c))
b.Fatal(err)
}
ek := dk.EncapsulationKey()
- c, _, err := Encapsulate(ek)
+ ekBytes := ek.Bytes()
+ c, _ := ek.Encapsulate()
if err != nil {
b.Fatal(err)
}
if err != nil {
b.Fatal(err)
}
- ekS := dkS.EncapsulationKey()
+ ekS := dkS.EncapsulationKey().Bytes()
sink ^= ekS[0]
Ks, err := dk.Decapsulate(c)
})
b.Run("Bob", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- cS, Ks, err := Encapsulate(ek)
+ ek, err := NewEncapsulationKey(ekBytes)
+ if err != nil {
+ b.Fatal(err)
+ }
+ cS, Ks := ek.Encapsulate()
if err != nil {
b.Fatal(err)
}
if _, err := io.ReadFull(config.rand(), seed); err != nil {
return nil, nil, nil, err
}
- keyShareKeys.kyber, err = mlkem768.NewKeyFromSeed(seed)
+ keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey(seed)
if err != nil {
return nil, nil, nil, err
}
// both, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2.
hello.keyShares = []keyShare{
{group: x25519Kyber768Draft00, data: append(keyShareKeys.ecdhe.PublicKey().Bytes(),
- keyShareKeys.kyber.EncapsulationKey()...)},
+ keyShareKeys.kyber.EncapsulationKey().Bytes()...)},
{group: X25519, data: keyShareKeys.ecdhe.PublicKey().Bytes()},
}
} else {
if err != nil {
return nil, err
}
- return kyberSharedSecret(K, c), nil
+ return kyberSharedSecret(c, K), nil
}
// kyberEncapsulate implements encapsulation according to Kyber Round 3.
func kyberEncapsulate(ek []byte) (c, ss []byte, err error) {
- c, ss, err = mlkem768.Encapsulate(ek)
+ k, err := mlkem768.NewEncapsulationKey(ek)
if err != nil {
return nil, nil, err
}
- return c, kyberSharedSecret(ss, c), nil
+ c, ss = k.Encapsulate()
+ return c, kyberSharedSecret(c, ss), nil
}
-func kyberSharedSecret(K, c []byte) []byte {
+func kyberSharedSecret(c, K []byte) []byte {
// Package mlkem768 implements ML-KEM, which compared to Kyber removed a
// final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber.
// See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3.
if err != nil {
t.Fatal(err)
}
- ct, ss, err := kyberEncapsulate(dk.EncapsulationKey())
+ ct, ss, err := kyberEncapsulate(dk.EncapsulationKey().Bytes())
if err != nil {
t.Fatal(err)
}