age

Simple, secure encryption with UNIX-style composability.
git clone git://git.sgregoratto.me/age
Log | Files | Refs | README | LICENSE

commit e9c118cea0a3c7b24c789d4150f2d6f2d09fd525
parent 52dbe9eecf8b2153f77b13157a382d53215b99df
Author: Filippo Valsorda <hi@filippo.io>
Date:   Sun,  6 Oct 2019 21:19:04 -0400

internal: implement STREAM, key exchange, encryption and decryption

Developed live over 6 hours of streaming on Twitch.

https://twitter.com/FiloSottile/status/1180875486911766528

Diffstat:
Mgo.mod | 2++
Ago.sum | 8++++++++
Ainternal/age/age.go | 118+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/age/age_test.go | 103+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/age/primitives.go | 51+++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/age/recipients_test.go | 131+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/age/scrypt.go | 125+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/age/ssh.go | 118+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/age/x25519.go | 123+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Minternal/format/format.go | 19+++++++++++++++----
Minternal/format/format_gofuzz.go | 2++
Ainternal/stream/stream.go | 208+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/stream/stream_test.go | 93+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
13 files changed, 1097 insertions(+), 4 deletions(-)

diff --git a/go.mod b/go.mod @@ -1,3 +1,5 @@ module github.com/FiloSottile/age go 1.12 + +require golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc diff --git a/go.sum b/go.sum @@ -0,0 +1,8 @@ +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4= +golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/age/age.go b/internal/age/age.go @@ -0,0 +1,118 @@ +package age + +import ( + "crypto/hmac" + "crypto/rand" + "errors" + "fmt" + "io" + + "github.com/FiloSottile/age/internal/format" + "github.com/FiloSottile/age/internal/stream" +) + +type Identity interface { + Type() string + Unwrap(block *format.Recipient) (fileKey []byte, err error) +} + +type Recipient interface { + Type() string + Wrap(fileKey []byte) (*format.Recipient, error) +} + +func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { + if len(recipients) == 0 { + return nil, errors.New("no recipients specified") + } + + fileKey := make([]byte, 16) + if _, err := rand.Read(fileKey); err != nil { + return nil, err + } + + hdr := &format.Header{} + // TODO: remove the AEAD marker from v1. + hdr.AEAD = "ChaChaPoly" + for i, r := range recipients { + if r.Type() == "scrypt" && len(recipients) != 1 { + return nil, errors.New("an scrypt recipient must be the only one") + } + + block, err := r.Wrap(fileKey) + if err != nil { + return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err) + } + hdr.Recipients = append(hdr.Recipients, block) + } + if mac, err := headerMAC(fileKey, hdr); err != nil { + return nil, fmt.Errorf("failed to compute header MAC: %v", err) + } else { + hdr.MAC = mac + } + if err := hdr.Marshal(dst); err != nil { + return nil, fmt.Errorf("failed to write header: %v", err) + } + + nonce := make([]byte, 16) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + if _, err := dst.Write(nonce); err != nil { + return nil, fmt.Errorf("failed to write nonce: %v", err) + } + + return stream.NewWriter(streamKey(fileKey, nonce), dst) +} + +func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { + if len(identities) == 0 { + return nil, errors.New("no identities specified") + } + + hdr, payload, err := format.Parse(src) + if err != nil { + return nil, fmt.Errorf("failed to read header: %v", err) + } + if hdr.AEAD != "ChaChaPoly" { + return nil, fmt.Errorf("unsupported AEAD: %v", hdr.AEAD) + } + if len(hdr.Recipients) > 20 { + return nil, errors.New("too many recipients") + } + + var fileKey []byte +RecipientsLoop: + for _, r := range hdr.Recipients { + if r.Type == "scrypt" && len(hdr.Recipients) != 1 { + return nil, errors.New("an scrypt recipient must be the only one") + } + for _, i := range identities { + + if i.Type() != r.Type { + continue + } + + fileKey, err = i.Unwrap(r) + if err == nil { + break RecipientsLoop + } + } + } + if fileKey == nil { + return nil, errors.New("no identity matched a recipient") + } + + if mac, err := headerMAC(fileKey, hdr); err != nil { + return nil, fmt.Errorf("failed to compute header MAC: %v", err) + } else if !hmac.Equal(mac, hdr.MAC) { + return nil, errors.New("bad header MAC") + } + + nonce := make([]byte, 16) + if _, err := io.ReadFull(payload, nonce); err != nil { + return nil, fmt.Errorf("failed to read nonce: %v", err) + } + + return stream.NewReader(streamKey(fileKey, nonce), payload) +} diff --git a/internal/age/age_test.go b/internal/age/age_test.go @@ -0,0 +1,103 @@ +package age_test + +import ( + "bytes" + "crypto/rand" + "io" + "io/ioutil" + "testing" + + "github.com/FiloSottile/age/internal/age" + "golang.org/x/crypto/curve25519" +) + +const helloWorld = "Hello, Twitch!" + +func TestEncryptDecryptX25519(t *testing.T) { + var secretKeyA, publicKeyA, secretKeyB, publicKeyB [32]byte + if _, err := rand.Read(secretKeyA[:]); err != nil { + t.Fatal(err) + } + if _, err := rand.Read(secretKeyB[:]); err != nil { + t.Fatal(err) + } + curve25519.ScalarBaseMult(&publicKeyA, &secretKeyA) + curve25519.ScalarBaseMult(&publicKeyB, &secretKeyB) + + rA, err := age.NewX25519Recipient(publicKeyA[:]) + if err != nil { + t.Fatal(err) + } + rB, err := age.NewX25519Recipient(publicKeyB[:]) + if err != nil { + t.Fatal(err) + } + buf := &bytes.Buffer{} + w, err := age.Encrypt(buf, rA, rB) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(w, helloWorld); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + t.Logf("%s", buf.Bytes()) + + i, err := age.NewX25519Identity(secretKeyB[:]) + if err != nil { + t.Fatal(err) + } + out, err := age.Decrypt(buf, i) + if err != nil { + t.Fatal(err) + } + outBytes, err := ioutil.ReadAll(out) + if err != nil { + t.Fatal(err) + } + if string(outBytes) != helloWorld { + t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld) + } +} + +func TestEncryptDecryptScrypt(t *testing.T) { + password := "twitch.tv/filosottile" + + r, err := age.NewScryptRecipient(password) + if err != nil { + t.Fatal(err) + } + r.SetWorkFactor(1 << 15) + buf := &bytes.Buffer{} + w, err := age.Encrypt(buf, r) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(w, helloWorld); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + t.Logf("%s", buf.Bytes()) + + i, err := age.NewScryptIdentity(password) + if err != nil { + t.Fatal(err) + } + out, err := age.Decrypt(buf, i) + if err != nil { + t.Fatal(err) + } + outBytes, err := ioutil.ReadAll(out) + if err != nil { + t.Fatal(err) + } + if string(outBytes) != helloWorld { + t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld) + } +} diff --git a/internal/age/primitives.go b/internal/age/primitives.go @@ -0,0 +1,51 @@ +package age + +import ( + "crypto/hmac" + "crypto/sha256" + "io" + + "github.com/FiloSottile/age/internal/format" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/hkdf" +) + +func aeadEncrypt(key, plaintext []byte) ([]byte, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, err + } + nonce := make([]byte, chacha20poly1305.NonceSize) + return aead.Seal(nil, nonce, plaintext, nil), nil +} + +func aeadDecrypt(key, ciphertext []byte) ([]byte, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, err + } + nonce := make([]byte, chacha20poly1305.NonceSize) + return aead.Open(nil, nonce, ciphertext, nil) +} + +func headerMAC(fileKey []byte, hdr *format.Header) ([]byte, error) { + h := hkdf.New(sha256.New, fileKey, nil, []byte("header")) + hmacKey := make([]byte, 32) + if _, err := io.ReadFull(h, hmacKey); err != nil { + return nil, err + } + hh := hmac.New(sha256.New, hmacKey) + if err := hdr.MarshalWithoutMAC(hh); err != nil { + return nil, err + } + return hh.Sum(nil), nil +} + +func streamKey(fileKey, nonce []byte) []byte { + h := hkdf.New(sha256.New, fileKey, nonce, []byte("payload")) + streamKey := make([]byte, chacha20poly1305.KeySize) + if _, err := io.ReadFull(h, streamKey); err != nil { + panic("age: internal error: failed to read from HKDF: " + err.Error()) + } + return streamKey +} diff --git a/internal/age/recipients_test.go b/internal/age/recipients_test.go @@ -0,0 +1,131 @@ +package age_test + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "testing" + + "github.com/FiloSottile/age/internal/age" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ssh" +) + +func TestX25519RoundTrip(t *testing.T) { + var secretKey, publicKey, fileKey [32]byte + if _, err := rand.Read(secretKey[:]); err != nil { + t.Fatal(err) + } + if _, err := rand.Read(fileKey[:]); err != nil { + t.Fatal(err) + } + curve25519.ScalarBaseMult(&publicKey, &secretKey) + + r, err := age.NewX25519Recipient(publicKey[:]) + if err != nil { + t.Fatal(err) + } + i, err := age.NewX25519Identity(secretKey[:]) + if err != nil { + t.Fatal(err) + } + + if r.Type() != i.Type() || r.Type() != "X25519" { + t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type()) + } + + block, err := r.Wrap(fileKey[:]) + if err != nil { + t.Fatal(err) + } + t.Logf("%#v", block) + + out, err := i.Unwrap(block) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(fileKey[:], out) { + t.Errorf("invalid output: %x, expected %x", out, fileKey[:]) + } +} + +func TestScryptRoundTrip(t *testing.T) { + password := "twitch.tv/filosottile" + + r, err := age.NewScryptRecipient(password) + if err != nil { + t.Fatal(err) + } + r.SetWorkFactor(1 << 15) + i, err := age.NewScryptIdentity(password) + if err != nil { + t.Fatal(err) + } + + if r.Type() != i.Type() || r.Type() != "scrypt" { + t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type()) + } + + fileKey := make([]byte, 16) + if _, err := rand.Read(fileKey[:]); err != nil { + t.Fatal(err) + } + block, err := r.Wrap(fileKey[:]) + if err != nil { + t.Fatal(err) + } + t.Logf("%#v", block) + + out, err := i.Unwrap(block) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(fileKey[:], out) { + t.Errorf("invalid output: %x, expected %x", out, fileKey[:]) + } +} + +func TestSSHRSARoundTrip(t *testing.T) { + pk, err := rsa.GenerateKey(rand.Reader, 768) + if err != nil { + t.Fatal(err) + } + pub, err := ssh.NewPublicKey(&pk.PublicKey) + if err != nil { + t.Fatal(err) + } + + r, err := age.NewSSHRSARecipient(pub) + if err != nil { + t.Fatal(err) + } + i, err := age.NewSSHRSAIdentity(pk) + if err != nil { + t.Fatal(err) + } + + if r.Type() != i.Type() || r.Type() != "ssh-rsa" { + t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type()) + } + + fileKey := make([]byte, 16) + if _, err := rand.Read(fileKey[:]); err != nil { + t.Fatal(err) + } + block, err := r.Wrap(fileKey[:]) + if err != nil { + t.Fatal(err) + } + t.Logf("%#v", block) + + out, err := i.Unwrap(block) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(fileKey[:], out) { + t.Errorf("invalid output: %x, expected %x", out, fileKey[:]) + } +} diff --git a/internal/age/scrypt.go b/internal/age/scrypt.go @@ -0,0 +1,125 @@ +package age + +import ( + "crypto/rand" + "errors" + "fmt" + "strconv" + + "github.com/FiloSottile/age/internal/format" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/scrypt" +) + +type ScryptRecipient struct { + password []byte + workFactor int +} + +var _ Recipient = &ScryptRecipient{} + +func (*ScryptRecipient) Type() string { return "scrypt" } + +func NewScryptRecipient(password string) (*ScryptRecipient, error) { + if len(password) == 0 { + return nil, errors.New("empty scrypt password") + } + r := &ScryptRecipient{ + password: []byte(password), + workFactor: 1 << 18, // 1s on a modern machine + } + return r, nil +} + +func (r *ScryptRecipient) SetWorkFactor(N int) { + // TODO: automatically scale this to 1s (with a min) in the CLI. + r.workFactor = N +} + +func (r *ScryptRecipient) Wrap(fileKey []byte) (*format.Recipient, error) { + salt := make([]byte, 16) + if _, err := rand.Read(salt[:]); err != nil { + return nil, err + } + + N := r.workFactor + l := &format.Recipient{ + Type: "scrypt", + Args: []string{format.EncodeToString(salt), strconv.Itoa(N)}, + } + + k, err := scrypt.Key(r.password, salt, N, 8, 1, chacha20poly1305.KeySize) + if err != nil { + return nil, fmt.Errorf("failed to generate scrypt hash: %v", err) + } + + wrappedKey, err := aeadEncrypt(k, fileKey) + if err != nil { + return nil, err + } + l.Body = []byte(format.EncodeToString(wrappedKey) + "\n") + + return l, nil +} + +type ScryptIdentity struct { + password []byte + maxWorkFactor int +} + +var _ Identity = &ScryptIdentity{} + +func (*ScryptIdentity) Type() string { return "scrypt" } + +func NewScryptIdentity(password string) (*ScryptIdentity, error) { + if len(password) == 0 { + return nil, errors.New("empty scrypt password") + } + i := &ScryptIdentity{ + password: []byte(password), + maxWorkFactor: 1 << 22, // 15s on a modern machine + } + return i, nil +} + +func (i *ScryptIdentity) SetMaxWorkFactor(N int) { + i.maxWorkFactor = N +} + +func (i *ScryptIdentity) Unwrap(block *format.Recipient) ([]byte, error) { + if block.Type != "scrypt" { + return nil, errors.New("wrong recipient block type") + } + if len(block.Args) != 2 { + return nil, errors.New("invalid scrypt recipient block") + } + salt, err := format.DecodeString(block.Args[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse scrypt salt: %v", err) + } + if len(salt) != 16 { + return nil, errors.New("invalid scrypt recipient block") + } + N, err := strconv.Atoi(block.Args[1]) + if err != nil { + return nil, fmt.Errorf("failed to parse scrypt work factor: %v", err) + } + if N > i.maxWorkFactor { + return nil, fmt.Errorf("scrypt work factor too large: %v", N) + } + wrappedKey, err := format.DecodeString(string(block.Body)) + if err != nil { + return nil, fmt.Errorf("failed to parse scrypt recipient: %v", err) + } + + k, err := scrypt.Key(i.password, salt, N, 8, 1, chacha20poly1305.KeySize) + if err != nil { + return nil, fmt.Errorf("failed to generate scrypt hash: %v", err) + } + + fileKey, err := aeadDecrypt(k, wrappedKey) + if err != nil { + return nil, fmt.Errorf("failed to decrypt file key: %v", err) + } + return fileKey, nil +} diff --git a/internal/age/ssh.go b/internal/age/ssh.go @@ -0,0 +1,118 @@ +package age + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "errors" + "fmt" + + "github.com/FiloSottile/age/internal/format" + "golang.org/x/crypto/ssh" +) + +const oaepLabel = "age-tool.com ssh-rsa" + +type SSHRSARecipient struct { + sshKey ssh.PublicKey + pubKey *rsa.PublicKey +} + +var _ Recipient = &SSHRSARecipient{} + +func (*SSHRSARecipient) Type() string { return "ssh-rsa" } + +func NewSSHRSARecipient(pk ssh.PublicKey) (*SSHRSARecipient, error) { + if pk.Type() != "ssh-rsa" { + return nil, errors.New("SSH public key is not an RSA key") + } + r := &SSHRSARecipient{ + sshKey: pk, + } + + if pk, ok := pk.(ssh.CryptoPublicKey); ok { + if pk, ok := pk.CryptoPublicKey().(*rsa.PublicKey); ok { + r.pubKey = pk + } else { + return nil, errors.New("unexpected public key type") + } + } else { + return nil, errors.New("pk does not implement ssh.CryptoPublicKey") + } + return r, nil +} + +func (r *SSHRSARecipient) Wrap(fileKey []byte) (*format.Recipient, error) { + h := sha256.New() + h.Write(r.sshKey.Marshal()) + hh := h.Sum(nil) + + l := &format.Recipient{ + Type: "ssh-rsa", + Args: []string{format.EncodeToString(hh[:4])}, + } + + wrappedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, + r.pubKey, fileKey, []byte(oaepLabel)) + if err != nil { + return nil, err + } + l.Body = []byte(format.EncodeToString(wrappedKey) + "\n") + + return l, nil +} + +type SSHRSAIdentity struct { + k *rsa.PrivateKey + sshKey ssh.PublicKey +} + +var _ Identity = &SSHRSAIdentity{} + +func (*SSHRSAIdentity) Type() string { return "ssh-rsa" } + +func NewSSHRSAIdentity(key *rsa.PrivateKey) (*SSHRSAIdentity, error) { + s, err := ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + i := &SSHRSAIdentity{ + k: key, sshKey: s.PublicKey(), + } + return i, nil +} + +func (i *SSHRSAIdentity) Unwrap(block *format.Recipient) ([]byte, error) { + if block.Type != "ssh-rsa" { + return nil, errors.New("wrong recipient block type") + } + if len(block.Args) != 1 { + return nil, errors.New("invalid ssh-rsa recipient block") + } + hash, err := format.DecodeString(block.Args[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err) + } + if len(hash) != 4 { + return nil, errors.New("invalid ssh-rsa recipient block") + } + wrappedKey, err := format.DecodeString(string(block.Body)) + if err != nil { + return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err) + } + + h := sha256.New() + h.Write(i.sshKey.Marshal()) + hh := h.Sum(nil) + if !bytes.Equal(hh[:4], hash) { + return nil, errors.New("wrong ssh-rsa key") + } + + fileKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, i.k, + wrappedKey, []byte(oaepLabel)) + if err != nil { + return nil, fmt.Errorf("failed to decrypt file key: %v", err) + } + return fileKey, nil +} diff --git a/internal/age/x25519.go b/internal/age/x25519.go @@ -0,0 +1,123 @@ +package age + +import ( + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "io" + + "github.com/FiloSottile/age/internal/format" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +const x25519Label = "age-tool.com X25519" + +type X25519Recipient struct { + theirPublicKey [32]byte +} + +var _ Recipient = &X25519Recipient{} + +func (*X25519Recipient) Type() string { return "X25519" } + +func NewX25519Recipient(publicKey []byte) (*X25519Recipient, error) { + if len(publicKey) != 32 { + return nil, errors.New("invalid X25519 public key") + } + r := &X25519Recipient{} + copy(r.theirPublicKey[:], publicKey) + return r, nil +} + +func (r *X25519Recipient) Wrap(fileKey []byte) (*format.Recipient, error) { + var ephemeral, ourPublicKey [32]byte + if _, err := rand.Read(ephemeral[:]); err != nil { + return nil, err + } + curve25519.ScalarBaseMult(&ourPublicKey, &ephemeral) + + var sharedSecret [32]byte + curve25519.ScalarMult(&sharedSecret, &ephemeral, &r.theirPublicKey) + + l := &format.Recipient{ + Type: "X25519", + Args: []string{format.EncodeToString(ourPublicKey[:])}, + } + + salt := make([]byte, 0, 32*2) + salt = append(salt, ourPublicKey[:]...) + salt = append(salt, r.theirPublicKey[:]...) + h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label)) + wrappingKey := make([]byte, chacha20poly1305.KeySize) + if _, err := io.ReadFull(h, wrappingKey); err != nil { + return nil, err + } + + wrappedKey, err := aeadEncrypt(wrappingKey, fileKey) + if err != nil { + return nil, err + } + l.Body = []byte(format.EncodeToString(wrappedKey) + "\n") + + return l, nil +} + +type X25519Identity struct { + secretKey, ourPublicKey [32]byte +} + +var _ Identity = &X25519Identity{} + +func (*X25519Identity) Type() string { return "X25519" } + +func NewX25519Identity(secretKey []byte) (*X25519Identity, error) { + if len(secretKey) != 32 { + return nil, errors.New("invalid X25519 secret key") + } + i := &X25519Identity{} + copy(i.secretKey[:], secretKey) + curve25519.ScalarBaseMult(&i.ourPublicKey, &i.secretKey) + return i, nil +} + +func (i *X25519Identity) Unwrap(block *format.Recipient) ([]byte, error) { + if block.Type != "X25519" { + return nil, errors.New("wrong recipient block type") + } + if len(block.Args) != 1 { + return nil, errors.New("invalid X25519 recipient block") + } + publicKey, err := format.DecodeString(block.Args[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err) + } + if len(publicKey) != 32 { + return nil, errors.New("invalid X25519 recipient block") + } + wrappedKey, err := format.DecodeString(string(block.Body)) + if err != nil { + return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err) + } + + var sharedSecret, theirPublicKey [32]byte + copy(theirPublicKey[:], publicKey) + curve25519.ScalarMult(&sharedSecret, &i.secretKey, &theirPublicKey) + + salt := make([]byte, 0, 32*2) + salt = append(salt, theirPublicKey[:]...) + salt = append(salt, i.ourPublicKey[:]...) + h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label)) + wrappingKey := make([]byte, chacha20poly1305.KeySize) + if _, err := io.ReadFull(h, wrappingKey); err != nil { + return nil, err + } + + fileKey, err := aeadDecrypt(wrappingKey, wrappedKey) + if err != nil { + return nil, fmt.Errorf("failed to decrypt file key: %v", err) + } + return fileKey, nil +} diff --git a/internal/format/format.go b/internal/format/format.go @@ -24,7 +24,7 @@ type Recipient struct { var b64 = base64.RawURLEncoding.Strict() -func decodeString(s string) ([]byte, error) { +func DecodeString(s string) ([]byte, error) { // CR and LF are ignored by DecodeString. LF is handled by the parser, // but CR can introduce malleability. if strings.Contains(s, "\r") { @@ -33,12 +33,14 @@ func decodeString(s string) ([]byte, error) { return b64.DecodeString(s) } +var EncodeToString = b64.EncodeToString // TODO: wrap lines + const intro = "This is a file encrypted with age-tool.com, version 1\n" var recipientPrefix = []byte("->") var footerPrefix = []byte("---") -func (h *Header) Marshal(w io.Writer) error { +func (h *Header) MarshalWithoutMAC(w io.Writer) error { if _, err := io.WriteString(w, intro); err != nil { return err } @@ -54,12 +56,21 @@ func (h *Header) Marshal(w io.Writer) error { if _, err := io.WriteString(w, "\n"); err != nil { return err } + // TODO: check that Body ends with a newline. if _, err := w.Write(r.Body); err != nil { return err } } + _, err := fmt.Fprintf(w, "%s %s", footerPrefix, h.AEAD) + return err +} + +func (h *Header) Marshal(w io.Writer) error { + if err := h.MarshalWithoutMAC(w); err != nil { + return err + } mac := b64.EncodeToString(h.MAC) - _, err := fmt.Fprintf(w, "%s %s %s\n", footerPrefix, h.AEAD, mac) + _, err := fmt.Fprintf(w, " %s\n", mac) return err } @@ -100,7 +111,7 @@ func Parse(input io.Reader) (*Header, io.Reader, error) { return nil, nil, errorf("malformed closing line: %q", line) } h.AEAD = args[0] - h.MAC, err = decodeString(args[1]) + h.MAC, err = DecodeString(args[1]) if err != nil { return nil, nil, errorf("malformed closing line %q: %v", line, err) } diff --git a/internal/format/format_gofuzz.go b/internal/format/format_gofuzz.go @@ -1,3 +1,5 @@ +// +build gofuzz + package format import ( diff --git a/internal/stream/stream.go b/internal/stream/stream.go @@ -0,0 +1,208 @@ +package stream + +import ( + "crypto/cipher" + "errors" + "io" + + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/poly1305" +) + +const ChunkSize = 64 * 1024 + +type Reader struct { + a cipher.AEAD + src io.Reader + + unread []byte // decrypted but unread data, backed by buf + buf [encChunkSize]byte + + err error + nonce [chacha20poly1305.NonceSize]byte +} + +const ( + encChunkSize = ChunkSize + poly1305.TagSize + lastChunkFlag = 0x01 +) + +func NewReader(key []byte, src io.Reader) (*Reader, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, err + } + return &Reader{ + a: aead, + src: src, + }, nil +} + +func (r *Reader) Read(p []byte) (int, error) { + if len(r.unread) > 0 { + n := copy(p, r.unread) + r.unread = r.unread[n:] + return n, nil + } + if r.err != nil { + return 0, r.err + } + if len(p) == 0 { + return 0, nil + } + + last, err := r.readChunk() + if err != nil { + r.err = err + return 0, err + } + + n := copy(p, r.unread) + r.unread = r.unread[n:] + + if last { + r.err = io.EOF + } + + return n, nil +} + +// readChunk reads the next chunk of ciphertext from r.c and makes in available +// in r.unread. last is true if the chunk was marked as the end of the message. +// readChunk must not be called again after returning a last chunk or an error. +func (r *Reader) readChunk() (last bool, err error) { + if len(r.unread) != 0 { + panic("stream: internal error: readChunk called with dirty buffer") + } + + in := r.buf[:] + n, err := io.ReadFull(r.src, in) + switch { + case err == io.EOF: + // A message can't end without a marked chunk. This message is truncated. + return false, io.ErrUnexpectedEOF + case err == io.ErrUnexpectedEOF: + // The last chunk can be short. + in = in[:n] + last = true + setLastChunkFlag(&r.nonce) + case err != nil: + return false, err + } + + outBuf := make([]byte, 0, ChunkSize) + out, err := r.a.Open(outBuf, r.nonce[:], in, nil) + if err != nil && !last { + // Check if this was a full-length final chunk. + last = true + setLastChunkFlag(&r.nonce) + out, err = r.a.Open(outBuf, r.nonce[:], in, nil) + } + if err != nil { + return false, err + } + + incNonce(&r.nonce) + r.unread = r.buf[:copy(r.buf[:], out)] + return last, nil +} + +func incNonce(nonce *[chacha20poly1305.NonceSize]byte) { + for i := len(nonce) - 2; i >= 0; i-- { + nonce[i]++ + if nonce[i] != 0 { + break + } else if i == 0 { + // The counter is 88 bits, this is unreachable. + panic("stream: chunk counter wrapped around") + } + } +} + +func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) { + nonce[len(nonce)-1] = lastChunkFlag +} + +type Writer struct { + a cipher.AEAD + dst io.Writer + unwritten []byte // backed by buf + buf [encChunkSize]byte + nonce [chacha20poly1305.NonceSize]byte + err error +} + +func NewWriter(key []byte, dst io.Writer) (*Writer, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, err + } + w := &Writer{ + a: aead, + dst: dst, + } + w.unwritten = w.buf[:0] + return w, nil +} + +func (w *Writer) Write(p []byte) (n int, err error) { + // TODO: consider refactoring with a bytes.Buffer. + if w.err != nil { + return 0, w.err + } + if len(p) == 0 { + return 0, nil + } + + total := len(p) + for len(p) > 0 { + free := ChunkSize - len(w.unwritten) + freeBuf := w.buf[len(w.unwritten) : len(w.unwritten)+free] + n := copy(freeBuf, p) + p = p[n:] + w.unwritten = w.unwritten[:len(w.unwritten)+n] + + if len(w.unwritten) == ChunkSize && len(p) > 0 { + if err := w.flushChunk(notLastChunk); err != nil { + w.err = err + return 0, err + } + } + } + return total, nil +} + +func (w *Writer) Close() error { + // TODO: close w.dst if it can be interface upgraded to io.Closer. + if w.err != nil { + return w.err + } + + err := w.flushChunk(lastChunk) + if err != nil { + w.err = err + } else { + w.err = errors.New("stream.Writer is already closed") + } + return err +} + +const ( + lastChunk = true + notLastChunk = false +) + +func (w *Writer) flushChunk(last bool) error { + if !last && len(w.unwritten) != ChunkSize { + panic("stream: internal error: flush called with partial chunk") + } + + if last { + setLastChunkFlag(&w.nonce) + } + buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil) + _, err := w.dst.Write(buf) + w.unwritten = w.buf[:0] + incNonce(&w.nonce) + return err +} diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go @@ -0,0 +1,93 @@ +package stream_test + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + + "github.com/FiloSottile/age/internal/stream" + "golang.org/x/crypto/chacha20poly1305" +) + +const cs = stream.ChunkSize + +func TestRoundTrip(t *testing.T) { + for _, stepSize := range []int{512, 600, 1000, cs} { + for _, length := range []int{0, 1000, cs, cs + 100} { + t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), + func(t *testing.T) { testRoundTrip(t, stepSize, length) }) + } + } +} + +func testRoundTrip(t *testing.T, stepSize, length int) { + src := make([]byte, length) + if _, err := rand.Read(src); err != nil { + t.Fatal(err) + } + buf := &bytes.Buffer{} + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + w, err := stream.NewWriter(key, buf) + if err != nil { + t.Fatal(err) + } + + var n int + for n < length { + b := length - n + if b > stepSize { + b = stepSize + } + nn, err := w.Write(src[n : n+b]) + if err != nil { + t.Fatal(err) + } + if nn != b { + t.Errorf("Write returned %d, expected %d", nn, b) + } + n += nn + + nn, err = w.Write(src[n:n]) + if err != nil { + t.Fatal(err) + } + if nn != 0 { + t.Errorf("Write returned %d, expected 0", nn) + } + } + + if err := w.Close(); err != nil { + t.Error("Close returned an error:", err) + } + + t.Logf("buffer size: %d", buf.Len()) + + r, err := stream.NewReader(key, buf) + if err != nil { + t.Fatal(err) + } + + n = 0 + readBuf := make([]byte, stepSize) + for n < length { + b := length - n + if b > stepSize { + b = stepSize + } + nn, err := r.Read(readBuf) + if err != nil { + t.Fatalf("Read error at index %d: %v", n, err) + } + + if !bytes.Equal(readBuf[:nn], src[n:n+nn]) { + t.Errorf("wrong data at indexes %d - %d", n, n+nn) + } + + n += nn + } +}