age

Simple, secure encryption with UNIX-style composability.
git clone git://git.sgregoratto.me/age
Log | Files | Refs | README | LICENSE

stream.go (4377B)


      1 // Copyright 2019 Google LLC
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file or at
      5 // https://developers.google.com/open-source/licenses/bsd
      6 
      7 package stream
      8 
      9 import (
     10 	"crypto/cipher"
     11 	"errors"
     12 	"io"
     13 
     14 	"golang.org/x/crypto/chacha20poly1305"
     15 	"golang.org/x/crypto/poly1305"
     16 )
     17 
     18 const ChunkSize = 64 * 1024
     19 
     20 type Reader struct {
     21 	a   cipher.AEAD
     22 	src io.Reader
     23 
     24 	unread []byte // decrypted but unread data, backed by buf
     25 	buf    [encChunkSize]byte
     26 
     27 	err   error
     28 	nonce [chacha20poly1305.NonceSize]byte
     29 }
     30 
     31 const (
     32 	encChunkSize  = ChunkSize + poly1305.TagSize
     33 	lastChunkFlag = 0x01
     34 )
     35 
     36 func NewReader(key []byte, src io.Reader) (*Reader, error) {
     37 	aead, err := chacha20poly1305.New(key)
     38 	if err != nil {
     39 		return nil, err
     40 	}
     41 	return &Reader{
     42 		a:   aead,
     43 		src: src,
     44 	}, nil
     45 }
     46 
     47 func (r *Reader) Read(p []byte) (int, error) {
     48 	if len(r.unread) > 0 {
     49 		n := copy(p, r.unread)
     50 		r.unread = r.unread[n:]
     51 		return n, nil
     52 	}
     53 	if r.err != nil {
     54 		return 0, r.err
     55 	}
     56 	if len(p) == 0 {
     57 		return 0, nil
     58 	}
     59 
     60 	last, err := r.readChunk()
     61 	if err != nil {
     62 		r.err = err
     63 		return 0, err
     64 	}
     65 
     66 	n := copy(p, r.unread)
     67 	r.unread = r.unread[n:]
     68 
     69 	if last {
     70 		r.err = io.EOF
     71 	}
     72 
     73 	return n, nil
     74 }
     75 
     76 // readChunk reads the next chunk of ciphertext from r.c and makes in available
     77 // in r.unread. last is true if the chunk was marked as the end of the message.
     78 // readChunk must not be called again after returning a last chunk or an error.
     79 func (r *Reader) readChunk() (last bool, err error) {
     80 	if len(r.unread) != 0 {
     81 		panic("stream: internal error: readChunk called with dirty buffer")
     82 	}
     83 
     84 	in := r.buf[:]
     85 	n, err := io.ReadFull(r.src, in)
     86 	switch {
     87 	case err == io.EOF:
     88 		// A message can't end without a marked chunk. This message is truncated.
     89 		return false, io.ErrUnexpectedEOF
     90 	case err == io.ErrUnexpectedEOF:
     91 		// The last chunk can be short.
     92 		in = in[:n]
     93 		last = true
     94 		setLastChunkFlag(&r.nonce)
     95 	case err != nil:
     96 		return false, err
     97 	}
     98 
     99 	outBuf := make([]byte, 0, ChunkSize)
    100 	out, err := r.a.Open(outBuf, r.nonce[:], in, nil)
    101 	if err != nil && !last {
    102 		// Check if this was a full-length final chunk.
    103 		last = true
    104 		setLastChunkFlag(&r.nonce)
    105 		out, err = r.a.Open(outBuf, r.nonce[:], in, nil)
    106 	}
    107 	if err != nil {
    108 		return false, err
    109 	}
    110 
    111 	incNonce(&r.nonce)
    112 	r.unread = r.buf[:copy(r.buf[:], out)]
    113 	return last, nil
    114 }
    115 
    116 func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
    117 	for i := len(nonce) - 2; i >= 0; i-- {
    118 		nonce[i]++
    119 		if nonce[i] != 0 {
    120 			break
    121 		} else if i == 0 {
    122 			// The counter is 88 bits, this is unreachable.
    123 			panic("stream: chunk counter wrapped around")
    124 		}
    125 	}
    126 }
    127 
    128 func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
    129 	nonce[len(nonce)-1] = lastChunkFlag
    130 }
    131 
    132 type Writer struct {
    133 	a         cipher.AEAD
    134 	dst       io.Writer
    135 	unwritten []byte // backed by buf
    136 	buf       [encChunkSize]byte
    137 	nonce     [chacha20poly1305.NonceSize]byte
    138 	err       error
    139 }
    140 
    141 func NewWriter(key []byte, dst io.Writer) (*Writer, error) {
    142 	aead, err := chacha20poly1305.New(key)
    143 	if err != nil {
    144 		return nil, err
    145 	}
    146 	w := &Writer{
    147 		a:   aead,
    148 		dst: dst,
    149 	}
    150 	w.unwritten = w.buf[:0]
    151 	return w, nil
    152 }
    153 
    154 func (w *Writer) Write(p []byte) (n int, err error) {
    155 	// TODO: consider refactoring with a bytes.Buffer.
    156 	if w.err != nil {
    157 		return 0, w.err
    158 	}
    159 	if len(p) == 0 {
    160 		return 0, nil
    161 	}
    162 
    163 	total := len(p)
    164 	for len(p) > 0 {
    165 		freeBuf := w.buf[len(w.unwritten):ChunkSize]
    166 		n := copy(freeBuf, p)
    167 		p = p[n:]
    168 		w.unwritten = w.unwritten[:len(w.unwritten)+n]
    169 
    170 		if len(w.unwritten) == ChunkSize && len(p) > 0 {
    171 			if err := w.flushChunk(notLastChunk); err != nil {
    172 				w.err = err
    173 				return 0, err
    174 			}
    175 		}
    176 	}
    177 	return total, nil
    178 }
    179 
    180 func (w *Writer) Close() error {
    181 	// TODO: close w.dst if it can be interface upgraded to io.Closer.
    182 	if w.err != nil {
    183 		return w.err
    184 	}
    185 
    186 	err := w.flushChunk(lastChunk)
    187 	if err != nil {
    188 		w.err = err
    189 	} else {
    190 		w.err = errors.New("stream.Writer is already closed")
    191 	}
    192 	return err
    193 }
    194 
    195 const (
    196 	lastChunk    = true
    197 	notLastChunk = false
    198 )
    199 
    200 func (w *Writer) flushChunk(last bool) error {
    201 	if !last && len(w.unwritten) != ChunkSize {
    202 		panic("stream: internal error: flush called with partial chunk")
    203 	}
    204 
    205 	if last {
    206 		setLastChunkFlag(&w.nonce)
    207 	}
    208 	buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
    209 	_, err := w.dst.Write(buf)
    210 	w.unwritten = w.buf[:0]
    211 	incNonce(&w.nonce)
    212 	return err
    213 }