commit e9c118cea0a3c7b24c789d4150f2d6f2d09fd525
parent 52dbe9eecf8b2153f77b13157a382d53215b99df
Author: Filippo Valsorda <hi@filippo.io>
Date: Sun, 6 Oct 2019 21:19:04 -0400
internal: implement STREAM, key exchange, encryption and decryption
Developed live over 6 hours of streaming on Twitch.
https://twitter.com/FiloSottile/status/1180875486911766528
Diffstat:
13 files changed, 1097 insertions(+), 4 deletions(-)
diff --git a/go.mod b/go.mod
@@ -1,3 +1,5 @@
module github.com/FiloSottile/age
go 1.12
+
+require golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
diff --git a/go.sum b/go.sum
@@ -0,0 +1,8 @@
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
+golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
diff --git a/internal/age/age.go b/internal/age/age.go
@@ -0,0 +1,118 @@
+package age
+
+import (
+ "crypto/hmac"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "io"
+
+ "github.com/FiloSottile/age/internal/format"
+ "github.com/FiloSottile/age/internal/stream"
+)
+
+type Identity interface {
+ Type() string
+ Unwrap(block *format.Recipient) (fileKey []byte, err error)
+}
+
+type Recipient interface {
+ Type() string
+ Wrap(fileKey []byte) (*format.Recipient, error)
+}
+
+func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
+ if len(recipients) == 0 {
+ return nil, errors.New("no recipients specified")
+ }
+
+ fileKey := make([]byte, 16)
+ if _, err := rand.Read(fileKey); err != nil {
+ return nil, err
+ }
+
+ hdr := &format.Header{}
+ // TODO: remove the AEAD marker from v1.
+ hdr.AEAD = "ChaChaPoly"
+ for i, r := range recipients {
+ if r.Type() == "scrypt" && len(recipients) != 1 {
+ return nil, errors.New("an scrypt recipient must be the only one")
+ }
+
+ block, err := r.Wrap(fileKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
+ }
+ hdr.Recipients = append(hdr.Recipients, block)
+ }
+ if mac, err := headerMAC(fileKey, hdr); err != nil {
+ return nil, fmt.Errorf("failed to compute header MAC: %v", err)
+ } else {
+ hdr.MAC = mac
+ }
+ if err := hdr.Marshal(dst); err != nil {
+ return nil, fmt.Errorf("failed to write header: %v", err)
+ }
+
+ nonce := make([]byte, 16)
+ if _, err := rand.Read(nonce); err != nil {
+ return nil, err
+ }
+ if _, err := dst.Write(nonce); err != nil {
+ return nil, fmt.Errorf("failed to write nonce: %v", err)
+ }
+
+ return stream.NewWriter(streamKey(fileKey, nonce), dst)
+}
+
+func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
+ if len(identities) == 0 {
+ return nil, errors.New("no identities specified")
+ }
+
+ hdr, payload, err := format.Parse(src)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read header: %v", err)
+ }
+ if hdr.AEAD != "ChaChaPoly" {
+ return nil, fmt.Errorf("unsupported AEAD: %v", hdr.AEAD)
+ }
+ if len(hdr.Recipients) > 20 {
+ return nil, errors.New("too many recipients")
+ }
+
+ var fileKey []byte
+RecipientsLoop:
+ for _, r := range hdr.Recipients {
+ if r.Type == "scrypt" && len(hdr.Recipients) != 1 {
+ return nil, errors.New("an scrypt recipient must be the only one")
+ }
+ for _, i := range identities {
+
+ if i.Type() != r.Type {
+ continue
+ }
+
+ fileKey, err = i.Unwrap(r)
+ if err == nil {
+ break RecipientsLoop
+ }
+ }
+ }
+ if fileKey == nil {
+ return nil, errors.New("no identity matched a recipient")
+ }
+
+ if mac, err := headerMAC(fileKey, hdr); err != nil {
+ return nil, fmt.Errorf("failed to compute header MAC: %v", err)
+ } else if !hmac.Equal(mac, hdr.MAC) {
+ return nil, errors.New("bad header MAC")
+ }
+
+ nonce := make([]byte, 16)
+ if _, err := io.ReadFull(payload, nonce); err != nil {
+ return nil, fmt.Errorf("failed to read nonce: %v", err)
+ }
+
+ return stream.NewReader(streamKey(fileKey, nonce), payload)
+}
diff --git a/internal/age/age_test.go b/internal/age/age_test.go
@@ -0,0 +1,103 @@
+package age_test
+
+import (
+ "bytes"
+ "crypto/rand"
+ "io"
+ "io/ioutil"
+ "testing"
+
+ "github.com/FiloSottile/age/internal/age"
+ "golang.org/x/crypto/curve25519"
+)
+
+const helloWorld = "Hello, Twitch!"
+
+func TestEncryptDecryptX25519(t *testing.T) {
+ var secretKeyA, publicKeyA, secretKeyB, publicKeyB [32]byte
+ if _, err := rand.Read(secretKeyA[:]); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := rand.Read(secretKeyB[:]); err != nil {
+ t.Fatal(err)
+ }
+ curve25519.ScalarBaseMult(&publicKeyA, &secretKeyA)
+ curve25519.ScalarBaseMult(&publicKeyB, &secretKeyB)
+
+ rA, err := age.NewX25519Recipient(publicKeyA[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ rB, err := age.NewX25519Recipient(publicKeyB[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := &bytes.Buffer{}
+ w, err := age.Encrypt(buf, rA, rB)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(w, helloWorld); err != nil {
+ t.Fatal(err)
+ }
+ if err := w.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ t.Logf("%s", buf.Bytes())
+
+ i, err := age.NewX25519Identity(secretKeyB[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ out, err := age.Decrypt(buf, i)
+ if err != nil {
+ t.Fatal(err)
+ }
+ outBytes, err := ioutil.ReadAll(out)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(outBytes) != helloWorld {
+ t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
+ }
+}
+
+func TestEncryptDecryptScrypt(t *testing.T) {
+ password := "twitch.tv/filosottile"
+
+ r, err := age.NewScryptRecipient(password)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.SetWorkFactor(1 << 15)
+ buf := &bytes.Buffer{}
+ w, err := age.Encrypt(buf, r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(w, helloWorld); err != nil {
+ t.Fatal(err)
+ }
+ if err := w.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ t.Logf("%s", buf.Bytes())
+
+ i, err := age.NewScryptIdentity(password)
+ if err != nil {
+ t.Fatal(err)
+ }
+ out, err := age.Decrypt(buf, i)
+ if err != nil {
+ t.Fatal(err)
+ }
+ outBytes, err := ioutil.ReadAll(out)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(outBytes) != helloWorld {
+ t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
+ }
+}
diff --git a/internal/age/primitives.go b/internal/age/primitives.go
@@ -0,0 +1,51 @@
+package age
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "io"
+
+ "github.com/FiloSottile/age/internal/format"
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/hkdf"
+)
+
+func aeadEncrypt(key, plaintext []byte) ([]byte, error) {
+ aead, err := chacha20poly1305.New(key)
+ if err != nil {
+ return nil, err
+ }
+ nonce := make([]byte, chacha20poly1305.NonceSize)
+ return aead.Seal(nil, nonce, plaintext, nil), nil
+}
+
+func aeadDecrypt(key, ciphertext []byte) ([]byte, error) {
+ aead, err := chacha20poly1305.New(key)
+ if err != nil {
+ return nil, err
+ }
+ nonce := make([]byte, chacha20poly1305.NonceSize)
+ return aead.Open(nil, nonce, ciphertext, nil)
+}
+
+func headerMAC(fileKey []byte, hdr *format.Header) ([]byte, error) {
+ h := hkdf.New(sha256.New, fileKey, nil, []byte("header"))
+ hmacKey := make([]byte, 32)
+ if _, err := io.ReadFull(h, hmacKey); err != nil {
+ return nil, err
+ }
+ hh := hmac.New(sha256.New, hmacKey)
+ if err := hdr.MarshalWithoutMAC(hh); err != nil {
+ return nil, err
+ }
+ return hh.Sum(nil), nil
+}
+
+func streamKey(fileKey, nonce []byte) []byte {
+ h := hkdf.New(sha256.New, fileKey, nonce, []byte("payload"))
+ streamKey := make([]byte, chacha20poly1305.KeySize)
+ if _, err := io.ReadFull(h, streamKey); err != nil {
+ panic("age: internal error: failed to read from HKDF: " + err.Error())
+ }
+ return streamKey
+}
diff --git a/internal/age/recipients_test.go b/internal/age/recipients_test.go
@@ -0,0 +1,131 @@
+package age_test
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "testing"
+
+ "github.com/FiloSottile/age/internal/age"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/ssh"
+)
+
+func TestX25519RoundTrip(t *testing.T) {
+ var secretKey, publicKey, fileKey [32]byte
+ if _, err := rand.Read(secretKey[:]); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := rand.Read(fileKey[:]); err != nil {
+ t.Fatal(err)
+ }
+ curve25519.ScalarBaseMult(&publicKey, &secretKey)
+
+ r, err := age.NewX25519Recipient(publicKey[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ i, err := age.NewX25519Identity(secretKey[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if r.Type() != i.Type() || r.Type() != "X25519" {
+ t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
+ }
+
+ block, err := r.Wrap(fileKey[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("%#v", block)
+
+ out, err := i.Unwrap(block)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !bytes.Equal(fileKey[:], out) {
+ t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
+ }
+}
+
+func TestScryptRoundTrip(t *testing.T) {
+ password := "twitch.tv/filosottile"
+
+ r, err := age.NewScryptRecipient(password)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.SetWorkFactor(1 << 15)
+ i, err := age.NewScryptIdentity(password)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if r.Type() != i.Type() || r.Type() != "scrypt" {
+ t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
+ }
+
+ fileKey := make([]byte, 16)
+ if _, err := rand.Read(fileKey[:]); err != nil {
+ t.Fatal(err)
+ }
+ block, err := r.Wrap(fileKey[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("%#v", block)
+
+ out, err := i.Unwrap(block)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !bytes.Equal(fileKey[:], out) {
+ t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
+ }
+}
+
+func TestSSHRSARoundTrip(t *testing.T) {
+ pk, err := rsa.GenerateKey(rand.Reader, 768)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pub, err := ssh.NewPublicKey(&pk.PublicKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ r, err := age.NewSSHRSARecipient(pub)
+ if err != nil {
+ t.Fatal(err)
+ }
+ i, err := age.NewSSHRSAIdentity(pk)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if r.Type() != i.Type() || r.Type() != "ssh-rsa" {
+ t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
+ }
+
+ fileKey := make([]byte, 16)
+ if _, err := rand.Read(fileKey[:]); err != nil {
+ t.Fatal(err)
+ }
+ block, err := r.Wrap(fileKey[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("%#v", block)
+
+ out, err := i.Unwrap(block)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !bytes.Equal(fileKey[:], out) {
+ t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
+ }
+}
diff --git a/internal/age/scrypt.go b/internal/age/scrypt.go
@@ -0,0 +1,125 @@
+package age
+
+import (
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "strconv"
+
+ "github.com/FiloSottile/age/internal/format"
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/scrypt"
+)
+
+type ScryptRecipient struct {
+ password []byte
+ workFactor int
+}
+
+var _ Recipient = &ScryptRecipient{}
+
+func (*ScryptRecipient) Type() string { return "scrypt" }
+
+func NewScryptRecipient(password string) (*ScryptRecipient, error) {
+ if len(password) == 0 {
+ return nil, errors.New("empty scrypt password")
+ }
+ r := &ScryptRecipient{
+ password: []byte(password),
+ workFactor: 1 << 18, // 1s on a modern machine
+ }
+ return r, nil
+}
+
+func (r *ScryptRecipient) SetWorkFactor(N int) {
+ // TODO: automatically scale this to 1s (with a min) in the CLI.
+ r.workFactor = N
+}
+
+func (r *ScryptRecipient) Wrap(fileKey []byte) (*format.Recipient, error) {
+ salt := make([]byte, 16)
+ if _, err := rand.Read(salt[:]); err != nil {
+ return nil, err
+ }
+
+ N := r.workFactor
+ l := &format.Recipient{
+ Type: "scrypt",
+ Args: []string{format.EncodeToString(salt), strconv.Itoa(N)},
+ }
+
+ k, err := scrypt.Key(r.password, salt, N, 8, 1, chacha20poly1305.KeySize)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate scrypt hash: %v", err)
+ }
+
+ wrappedKey, err := aeadEncrypt(k, fileKey)
+ if err != nil {
+ return nil, err
+ }
+ l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
+
+ return l, nil
+}
+
+type ScryptIdentity struct {
+ password []byte
+ maxWorkFactor int
+}
+
+var _ Identity = &ScryptIdentity{}
+
+func (*ScryptIdentity) Type() string { return "scrypt" }
+
+func NewScryptIdentity(password string) (*ScryptIdentity, error) {
+ if len(password) == 0 {
+ return nil, errors.New("empty scrypt password")
+ }
+ i := &ScryptIdentity{
+ password: []byte(password),
+ maxWorkFactor: 1 << 22, // 15s on a modern machine
+ }
+ return i, nil
+}
+
+func (i *ScryptIdentity) SetMaxWorkFactor(N int) {
+ i.maxWorkFactor = N
+}
+
+func (i *ScryptIdentity) Unwrap(block *format.Recipient) ([]byte, error) {
+ if block.Type != "scrypt" {
+ return nil, errors.New("wrong recipient block type")
+ }
+ if len(block.Args) != 2 {
+ return nil, errors.New("invalid scrypt recipient block")
+ }
+ salt, err := format.DecodeString(block.Args[0])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse scrypt salt: %v", err)
+ }
+ if len(salt) != 16 {
+ return nil, errors.New("invalid scrypt recipient block")
+ }
+ N, err := strconv.Atoi(block.Args[1])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse scrypt work factor: %v", err)
+ }
+ if N > i.maxWorkFactor {
+ return nil, fmt.Errorf("scrypt work factor too large: %v", N)
+ }
+ wrappedKey, err := format.DecodeString(string(block.Body))
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse scrypt recipient: %v", err)
+ }
+
+ k, err := scrypt.Key(i.password, salt, N, 8, 1, chacha20poly1305.KeySize)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate scrypt hash: %v", err)
+ }
+
+ fileKey, err := aeadDecrypt(k, wrappedKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decrypt file key: %v", err)
+ }
+ return fileKey, nil
+}
diff --git a/internal/age/ssh.go b/internal/age/ssh.go
@@ -0,0 +1,118 @@
+package age
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha256"
+ "errors"
+ "fmt"
+
+ "github.com/FiloSottile/age/internal/format"
+ "golang.org/x/crypto/ssh"
+)
+
+const oaepLabel = "age-tool.com ssh-rsa"
+
+type SSHRSARecipient struct {
+ sshKey ssh.PublicKey
+ pubKey *rsa.PublicKey
+}
+
+var _ Recipient = &SSHRSARecipient{}
+
+func (*SSHRSARecipient) Type() string { return "ssh-rsa" }
+
+func NewSSHRSARecipient(pk ssh.PublicKey) (*SSHRSARecipient, error) {
+ if pk.Type() != "ssh-rsa" {
+ return nil, errors.New("SSH public key is not an RSA key")
+ }
+ r := &SSHRSARecipient{
+ sshKey: pk,
+ }
+
+ if pk, ok := pk.(ssh.CryptoPublicKey); ok {
+ if pk, ok := pk.CryptoPublicKey().(*rsa.PublicKey); ok {
+ r.pubKey = pk
+ } else {
+ return nil, errors.New("unexpected public key type")
+ }
+ } else {
+ return nil, errors.New("pk does not implement ssh.CryptoPublicKey")
+ }
+ return r, nil
+}
+
+func (r *SSHRSARecipient) Wrap(fileKey []byte) (*format.Recipient, error) {
+ h := sha256.New()
+ h.Write(r.sshKey.Marshal())
+ hh := h.Sum(nil)
+
+ l := &format.Recipient{
+ Type: "ssh-rsa",
+ Args: []string{format.EncodeToString(hh[:4])},
+ }
+
+ wrappedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader,
+ r.pubKey, fileKey, []byte(oaepLabel))
+ if err != nil {
+ return nil, err
+ }
+ l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
+
+ return l, nil
+}
+
+type SSHRSAIdentity struct {
+ k *rsa.PrivateKey
+ sshKey ssh.PublicKey
+}
+
+var _ Identity = &SSHRSAIdentity{}
+
+func (*SSHRSAIdentity) Type() string { return "ssh-rsa" }
+
+func NewSSHRSAIdentity(key *rsa.PrivateKey) (*SSHRSAIdentity, error) {
+ s, err := ssh.NewSignerFromKey(key)
+ if err != nil {
+ return nil, err
+ }
+ i := &SSHRSAIdentity{
+ k: key, sshKey: s.PublicKey(),
+ }
+ return i, nil
+}
+
+func (i *SSHRSAIdentity) Unwrap(block *format.Recipient) ([]byte, error) {
+ if block.Type != "ssh-rsa" {
+ return nil, errors.New("wrong recipient block type")
+ }
+ if len(block.Args) != 1 {
+ return nil, errors.New("invalid ssh-rsa recipient block")
+ }
+ hash, err := format.DecodeString(block.Args[0])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err)
+ }
+ if len(hash) != 4 {
+ return nil, errors.New("invalid ssh-rsa recipient block")
+ }
+ wrappedKey, err := format.DecodeString(string(block.Body))
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err)
+ }
+
+ h := sha256.New()
+ h.Write(i.sshKey.Marshal())
+ hh := h.Sum(nil)
+ if !bytes.Equal(hh[:4], hash) {
+ return nil, errors.New("wrong ssh-rsa key")
+ }
+
+ fileKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, i.k,
+ wrappedKey, []byte(oaepLabel))
+ if err != nil {
+ return nil, fmt.Errorf("failed to decrypt file key: %v", err)
+ }
+ return fileKey, nil
+}
diff --git a/internal/age/x25519.go b/internal/age/x25519.go
@@ -0,0 +1,123 @@
+package age
+
+import (
+ "crypto/rand"
+ "crypto/sha256"
+ "errors"
+ "fmt"
+ "io"
+
+ "github.com/FiloSottile/age/internal/format"
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/hkdf"
+)
+
+const x25519Label = "age-tool.com X25519"
+
+type X25519Recipient struct {
+ theirPublicKey [32]byte
+}
+
+var _ Recipient = &X25519Recipient{}
+
+func (*X25519Recipient) Type() string { return "X25519" }
+
+func NewX25519Recipient(publicKey []byte) (*X25519Recipient, error) {
+ if len(publicKey) != 32 {
+ return nil, errors.New("invalid X25519 public key")
+ }
+ r := &X25519Recipient{}
+ copy(r.theirPublicKey[:], publicKey)
+ return r, nil
+}
+
+func (r *X25519Recipient) Wrap(fileKey []byte) (*format.Recipient, error) {
+ var ephemeral, ourPublicKey [32]byte
+ if _, err := rand.Read(ephemeral[:]); err != nil {
+ return nil, err
+ }
+ curve25519.ScalarBaseMult(&ourPublicKey, &ephemeral)
+
+ var sharedSecret [32]byte
+ curve25519.ScalarMult(&sharedSecret, &ephemeral, &r.theirPublicKey)
+
+ l := &format.Recipient{
+ Type: "X25519",
+ Args: []string{format.EncodeToString(ourPublicKey[:])},
+ }
+
+ salt := make([]byte, 0, 32*2)
+ salt = append(salt, ourPublicKey[:]...)
+ salt = append(salt, r.theirPublicKey[:]...)
+ h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label))
+ wrappingKey := make([]byte, chacha20poly1305.KeySize)
+ if _, err := io.ReadFull(h, wrappingKey); err != nil {
+ return nil, err
+ }
+
+ wrappedKey, err := aeadEncrypt(wrappingKey, fileKey)
+ if err != nil {
+ return nil, err
+ }
+ l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
+
+ return l, nil
+}
+
+type X25519Identity struct {
+ secretKey, ourPublicKey [32]byte
+}
+
+var _ Identity = &X25519Identity{}
+
+func (*X25519Identity) Type() string { return "X25519" }
+
+func NewX25519Identity(secretKey []byte) (*X25519Identity, error) {
+ if len(secretKey) != 32 {
+ return nil, errors.New("invalid X25519 secret key")
+ }
+ i := &X25519Identity{}
+ copy(i.secretKey[:], secretKey)
+ curve25519.ScalarBaseMult(&i.ourPublicKey, &i.secretKey)
+ return i, nil
+}
+
+func (i *X25519Identity) Unwrap(block *format.Recipient) ([]byte, error) {
+ if block.Type != "X25519" {
+ return nil, errors.New("wrong recipient block type")
+ }
+ if len(block.Args) != 1 {
+ return nil, errors.New("invalid X25519 recipient block")
+ }
+ publicKey, err := format.DecodeString(block.Args[0])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err)
+ }
+ if len(publicKey) != 32 {
+ return nil, errors.New("invalid X25519 recipient block")
+ }
+ wrappedKey, err := format.DecodeString(string(block.Body))
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err)
+ }
+
+ var sharedSecret, theirPublicKey [32]byte
+ copy(theirPublicKey[:], publicKey)
+ curve25519.ScalarMult(&sharedSecret, &i.secretKey, &theirPublicKey)
+
+ salt := make([]byte, 0, 32*2)
+ salt = append(salt, theirPublicKey[:]...)
+ salt = append(salt, i.ourPublicKey[:]...)
+ h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label))
+ wrappingKey := make([]byte, chacha20poly1305.KeySize)
+ if _, err := io.ReadFull(h, wrappingKey); err != nil {
+ return nil, err
+ }
+
+ fileKey, err := aeadDecrypt(wrappingKey, wrappedKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decrypt file key: %v", err)
+ }
+ return fileKey, nil
+}
diff --git a/internal/format/format.go b/internal/format/format.go
@@ -24,7 +24,7 @@ type Recipient struct {
var b64 = base64.RawURLEncoding.Strict()
-func decodeString(s string) ([]byte, error) {
+func DecodeString(s string) ([]byte, error) {
// CR and LF are ignored by DecodeString. LF is handled by the parser,
// but CR can introduce malleability.
if strings.Contains(s, "\r") {
@@ -33,12 +33,14 @@ func decodeString(s string) ([]byte, error) {
return b64.DecodeString(s)
}
+var EncodeToString = b64.EncodeToString // TODO: wrap lines
+
const intro = "This is a file encrypted with age-tool.com, version 1\n"
var recipientPrefix = []byte("->")
var footerPrefix = []byte("---")
-func (h *Header) Marshal(w io.Writer) error {
+func (h *Header) MarshalWithoutMAC(w io.Writer) error {
if _, err := io.WriteString(w, intro); err != nil {
return err
}
@@ -54,12 +56,21 @@ func (h *Header) Marshal(w io.Writer) error {
if _, err := io.WriteString(w, "\n"); err != nil {
return err
}
+ // TODO: check that Body ends with a newline.
if _, err := w.Write(r.Body); err != nil {
return err
}
}
+ _, err := fmt.Fprintf(w, "%s %s", footerPrefix, h.AEAD)
+ return err
+}
+
+func (h *Header) Marshal(w io.Writer) error {
+ if err := h.MarshalWithoutMAC(w); err != nil {
+ return err
+ }
mac := b64.EncodeToString(h.MAC)
- _, err := fmt.Fprintf(w, "%s %s %s\n", footerPrefix, h.AEAD, mac)
+ _, err := fmt.Fprintf(w, " %s\n", mac)
return err
}
@@ -100,7 +111,7 @@ func Parse(input io.Reader) (*Header, io.Reader, error) {
return nil, nil, errorf("malformed closing line: %q", line)
}
h.AEAD = args[0]
- h.MAC, err = decodeString(args[1])
+ h.MAC, err = DecodeString(args[1])
if err != nil {
return nil, nil, errorf("malformed closing line %q: %v", line, err)
}
diff --git a/internal/format/format_gofuzz.go b/internal/format/format_gofuzz.go
@@ -1,3 +1,5 @@
+// +build gofuzz
+
package format
import (
diff --git a/internal/stream/stream.go b/internal/stream/stream.go
@@ -0,0 +1,208 @@
+package stream
+
+import (
+ "crypto/cipher"
+ "errors"
+ "io"
+
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/poly1305"
+)
+
+const ChunkSize = 64 * 1024
+
+type Reader struct {
+ a cipher.AEAD
+ src io.Reader
+
+ unread []byte // decrypted but unread data, backed by buf
+ buf [encChunkSize]byte
+
+ err error
+ nonce [chacha20poly1305.NonceSize]byte
+}
+
+const (
+ encChunkSize = ChunkSize + poly1305.TagSize
+ lastChunkFlag = 0x01
+)
+
+func NewReader(key []byte, src io.Reader) (*Reader, error) {
+ aead, err := chacha20poly1305.New(key)
+ if err != nil {
+ return nil, err
+ }
+ return &Reader{
+ a: aead,
+ src: src,
+ }, nil
+}
+
+func (r *Reader) Read(p []byte) (int, error) {
+ if len(r.unread) > 0 {
+ n := copy(p, r.unread)
+ r.unread = r.unread[n:]
+ return n, nil
+ }
+ if r.err != nil {
+ return 0, r.err
+ }
+ if len(p) == 0 {
+ return 0, nil
+ }
+
+ last, err := r.readChunk()
+ if err != nil {
+ r.err = err
+ return 0, err
+ }
+
+ n := copy(p, r.unread)
+ r.unread = r.unread[n:]
+
+ if last {
+ r.err = io.EOF
+ }
+
+ return n, nil
+}
+
+// readChunk reads the next chunk of ciphertext from r.c and makes in available
+// in r.unread. last is true if the chunk was marked as the end of the message.
+// readChunk must not be called again after returning a last chunk or an error.
+func (r *Reader) readChunk() (last bool, err error) {
+ if len(r.unread) != 0 {
+ panic("stream: internal error: readChunk called with dirty buffer")
+ }
+
+ in := r.buf[:]
+ n, err := io.ReadFull(r.src, in)
+ switch {
+ case err == io.EOF:
+ // A message can't end without a marked chunk. This message is truncated.
+ return false, io.ErrUnexpectedEOF
+ case err == io.ErrUnexpectedEOF:
+ // The last chunk can be short.
+ in = in[:n]
+ last = true
+ setLastChunkFlag(&r.nonce)
+ case err != nil:
+ return false, err
+ }
+
+ outBuf := make([]byte, 0, ChunkSize)
+ out, err := r.a.Open(outBuf, r.nonce[:], in, nil)
+ if err != nil && !last {
+ // Check if this was a full-length final chunk.
+ last = true
+ setLastChunkFlag(&r.nonce)
+ out, err = r.a.Open(outBuf, r.nonce[:], in, nil)
+ }
+ if err != nil {
+ return false, err
+ }
+
+ incNonce(&r.nonce)
+ r.unread = r.buf[:copy(r.buf[:], out)]
+ return last, nil
+}
+
+func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
+ for i := len(nonce) - 2; i >= 0; i-- {
+ nonce[i]++
+ if nonce[i] != 0 {
+ break
+ } else if i == 0 {
+ // The counter is 88 bits, this is unreachable.
+ panic("stream: chunk counter wrapped around")
+ }
+ }
+}
+
+func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
+ nonce[len(nonce)-1] = lastChunkFlag
+}
+
+type Writer struct {
+ a cipher.AEAD
+ dst io.Writer
+ unwritten []byte // backed by buf
+ buf [encChunkSize]byte
+ nonce [chacha20poly1305.NonceSize]byte
+ err error
+}
+
+func NewWriter(key []byte, dst io.Writer) (*Writer, error) {
+ aead, err := chacha20poly1305.New(key)
+ if err != nil {
+ return nil, err
+ }
+ w := &Writer{
+ a: aead,
+ dst: dst,
+ }
+ w.unwritten = w.buf[:0]
+ return w, nil
+}
+
+func (w *Writer) Write(p []byte) (n int, err error) {
+ // TODO: consider refactoring with a bytes.Buffer.
+ if w.err != nil {
+ return 0, w.err
+ }
+ if len(p) == 0 {
+ return 0, nil
+ }
+
+ total := len(p)
+ for len(p) > 0 {
+ free := ChunkSize - len(w.unwritten)
+ freeBuf := w.buf[len(w.unwritten) : len(w.unwritten)+free]
+ n := copy(freeBuf, p)
+ p = p[n:]
+ w.unwritten = w.unwritten[:len(w.unwritten)+n]
+
+ if len(w.unwritten) == ChunkSize && len(p) > 0 {
+ if err := w.flushChunk(notLastChunk); err != nil {
+ w.err = err
+ return 0, err
+ }
+ }
+ }
+ return total, nil
+}
+
+func (w *Writer) Close() error {
+ // TODO: close w.dst if it can be interface upgraded to io.Closer.
+ if w.err != nil {
+ return w.err
+ }
+
+ err := w.flushChunk(lastChunk)
+ if err != nil {
+ w.err = err
+ } else {
+ w.err = errors.New("stream.Writer is already closed")
+ }
+ return err
+}
+
+const (
+ lastChunk = true
+ notLastChunk = false
+)
+
+func (w *Writer) flushChunk(last bool) error {
+ if !last && len(w.unwritten) != ChunkSize {
+ panic("stream: internal error: flush called with partial chunk")
+ }
+
+ if last {
+ setLastChunkFlag(&w.nonce)
+ }
+ buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
+ _, err := w.dst.Write(buf)
+ w.unwritten = w.buf[:0]
+ incNonce(&w.nonce)
+ return err
+}
diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go
@@ -0,0 +1,93 @@
+package stream_test
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "testing"
+
+ "github.com/FiloSottile/age/internal/stream"
+ "golang.org/x/crypto/chacha20poly1305"
+)
+
+const cs = stream.ChunkSize
+
+func TestRoundTrip(t *testing.T) {
+ for _, stepSize := range []int{512, 600, 1000, cs} {
+ for _, length := range []int{0, 1000, cs, cs + 100} {
+ t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
+ func(t *testing.T) { testRoundTrip(t, stepSize, length) })
+ }
+ }
+}
+
+func testRoundTrip(t *testing.T, stepSize, length int) {
+ src := make([]byte, length)
+ if _, err := rand.Read(src); err != nil {
+ t.Fatal(err)
+ }
+ buf := &bytes.Buffer{}
+ key := make([]byte, chacha20poly1305.KeySize)
+ if _, err := rand.Read(key); err != nil {
+ t.Fatal(err)
+ }
+
+ w, err := stream.NewWriter(key, buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var n int
+ for n < length {
+ b := length - n
+ if b > stepSize {
+ b = stepSize
+ }
+ nn, err := w.Write(src[n : n+b])
+ if err != nil {
+ t.Fatal(err)
+ }
+ if nn != b {
+ t.Errorf("Write returned %d, expected %d", nn, b)
+ }
+ n += nn
+
+ nn, err = w.Write(src[n:n])
+ if err != nil {
+ t.Fatal(err)
+ }
+ if nn != 0 {
+ t.Errorf("Write returned %d, expected 0", nn)
+ }
+ }
+
+ if err := w.Close(); err != nil {
+ t.Error("Close returned an error:", err)
+ }
+
+ t.Logf("buffer size: %d", buf.Len())
+
+ r, err := stream.NewReader(key, buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ n = 0
+ readBuf := make([]byte, stepSize)
+ for n < length {
+ b := length - n
+ if b > stepSize {
+ b = stepSize
+ }
+ nn, err := r.Read(readBuf)
+ if err != nil {
+ t.Fatalf("Read error at index %d: %v", n, err)
+ }
+
+ if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
+ t.Errorf("wrong data at indexes %d - %d", n, n+nn)
+ }
+
+ n += nn
+ }
+}