stream.go (4377B)
1 // Copyright 2019 Google LLC 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file or at 5 // https://developers.google.com/open-source/licenses/bsd 6 7 package stream 8 9 import ( 10 "crypto/cipher" 11 "errors" 12 "io" 13 14 "golang.org/x/crypto/chacha20poly1305" 15 "golang.org/x/crypto/poly1305" 16 ) 17 18 const ChunkSize = 64 * 1024 19 20 type Reader struct { 21 a cipher.AEAD 22 src io.Reader 23 24 unread []byte // decrypted but unread data, backed by buf 25 buf [encChunkSize]byte 26 27 err error 28 nonce [chacha20poly1305.NonceSize]byte 29 } 30 31 const ( 32 encChunkSize = ChunkSize + poly1305.TagSize 33 lastChunkFlag = 0x01 34 ) 35 36 func NewReader(key []byte, src io.Reader) (*Reader, error) { 37 aead, err := chacha20poly1305.New(key) 38 if err != nil { 39 return nil, err 40 } 41 return &Reader{ 42 a: aead, 43 src: src, 44 }, nil 45 } 46 47 func (r *Reader) Read(p []byte) (int, error) { 48 if len(r.unread) > 0 { 49 n := copy(p, r.unread) 50 r.unread = r.unread[n:] 51 return n, nil 52 } 53 if r.err != nil { 54 return 0, r.err 55 } 56 if len(p) == 0 { 57 return 0, nil 58 } 59 60 last, err := r.readChunk() 61 if err != nil { 62 r.err = err 63 return 0, err 64 } 65 66 n := copy(p, r.unread) 67 r.unread = r.unread[n:] 68 69 if last { 70 r.err = io.EOF 71 } 72 73 return n, nil 74 } 75 76 // readChunk reads the next chunk of ciphertext from r.c and makes in available 77 // in r.unread. last is true if the chunk was marked as the end of the message. 78 // readChunk must not be called again after returning a last chunk or an error. 79 func (r *Reader) readChunk() (last bool, err error) { 80 if len(r.unread) != 0 { 81 panic("stream: internal error: readChunk called with dirty buffer") 82 } 83 84 in := r.buf[:] 85 n, err := io.ReadFull(r.src, in) 86 switch { 87 case err == io.EOF: 88 // A message can't end without a marked chunk. This message is truncated. 89 return false, io.ErrUnexpectedEOF 90 case err == io.ErrUnexpectedEOF: 91 // The last chunk can be short. 92 in = in[:n] 93 last = true 94 setLastChunkFlag(&r.nonce) 95 case err != nil: 96 return false, err 97 } 98 99 outBuf := make([]byte, 0, ChunkSize) 100 out, err := r.a.Open(outBuf, r.nonce[:], in, nil) 101 if err != nil && !last { 102 // Check if this was a full-length final chunk. 103 last = true 104 setLastChunkFlag(&r.nonce) 105 out, err = r.a.Open(outBuf, r.nonce[:], in, nil) 106 } 107 if err != nil { 108 return false, err 109 } 110 111 incNonce(&r.nonce) 112 r.unread = r.buf[:copy(r.buf[:], out)] 113 return last, nil 114 } 115 116 func incNonce(nonce *[chacha20poly1305.NonceSize]byte) { 117 for i := len(nonce) - 2; i >= 0; i-- { 118 nonce[i]++ 119 if nonce[i] != 0 { 120 break 121 } else if i == 0 { 122 // The counter is 88 bits, this is unreachable. 123 panic("stream: chunk counter wrapped around") 124 } 125 } 126 } 127 128 func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) { 129 nonce[len(nonce)-1] = lastChunkFlag 130 } 131 132 type Writer struct { 133 a cipher.AEAD 134 dst io.Writer 135 unwritten []byte // backed by buf 136 buf [encChunkSize]byte 137 nonce [chacha20poly1305.NonceSize]byte 138 err error 139 } 140 141 func NewWriter(key []byte, dst io.Writer) (*Writer, error) { 142 aead, err := chacha20poly1305.New(key) 143 if err != nil { 144 return nil, err 145 } 146 w := &Writer{ 147 a: aead, 148 dst: dst, 149 } 150 w.unwritten = w.buf[:0] 151 return w, nil 152 } 153 154 func (w *Writer) Write(p []byte) (n int, err error) { 155 // TODO: consider refactoring with a bytes.Buffer. 156 if w.err != nil { 157 return 0, w.err 158 } 159 if len(p) == 0 { 160 return 0, nil 161 } 162 163 total := len(p) 164 for len(p) > 0 { 165 freeBuf := w.buf[len(w.unwritten):ChunkSize] 166 n := copy(freeBuf, p) 167 p = p[n:] 168 w.unwritten = w.unwritten[:len(w.unwritten)+n] 169 170 if len(w.unwritten) == ChunkSize && len(p) > 0 { 171 if err := w.flushChunk(notLastChunk); err != nil { 172 w.err = err 173 return 0, err 174 } 175 } 176 } 177 return total, nil 178 } 179 180 func (w *Writer) Close() error { 181 // TODO: close w.dst if it can be interface upgraded to io.Closer. 182 if w.err != nil { 183 return w.err 184 } 185 186 err := w.flushChunk(lastChunk) 187 if err != nil { 188 w.err = err 189 } else { 190 w.err = errors.New("stream.Writer is already closed") 191 } 192 return err 193 } 194 195 const ( 196 lastChunk = true 197 notLastChunk = false 198 ) 199 200 func (w *Writer) flushChunk(last bool) error { 201 if !last && len(w.unwritten) != ChunkSize { 202 panic("stream: internal error: flush called with partial chunk") 203 } 204 205 if last { 206 setLastChunkFlag(&w.nonce) 207 } 208 buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil) 209 _, err := w.dst.Write(buf) 210 w.unwritten = w.buf[:0] 211 incNonce(&w.nonce) 212 return err 213 }