diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 28244f092..410c8bd72 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -523,7 +523,7 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header, fourByteBuf []byte) ( // Nothing case MessageCompressionLZ4: - decomp, err := c.lz4Decompress(buf) + decomp, err := lz4Decompress(buf) BufferPool.Put(buf) if err != nil { return nil, errors.Wrap(err, "decompressing message") @@ -740,26 +740,56 @@ func (c *rawConnection) writerLoop() { func (c *rawConnection) writeMessage(msg message) error { msgContext, _ := messageContext(msg) l.Debugf("Writing %v", msgContext) - if c.shouldCompressMessage(msg) { - return c.writeCompressedMessage(msg) - } - return c.writeUncompressedMessage(msg) -} -func (c *rawConnection) writeCompressedMessage(msg message) error { size := msg.ProtoSize() - buf := BufferPool.Get(size) - if _, err := msg.MarshalTo(buf); err != nil { - BufferPool.Put(buf) + hdr := Header{ + Type: c.typeOf(msg), + } + hdrSize := hdr.ProtoSize() + if hdrSize > 1<<16-1 { + panic("impossibly large header") + } + + overhead := 2 + hdrSize + 4 + totSize := overhead + size + buf := BufferPool.Get(totSize) + defer BufferPool.Put(buf) + + // Message + if _, err := msg.MarshalTo(buf[2+hdrSize+4:]); err != nil { return errors.Wrap(err, "marshalling message") } - compressed, err := c.lz4Compress(buf) - if err != nil { - BufferPool.Put(buf) - return errors.Wrap(err, "compressing message") + if c.shouldCompressMessage(msg) { + ok, err := c.writeCompressedMessage(msg, buf[overhead:], overhead) + if ok { + return err + } } + // Header length + binary.BigEndian.PutUint16(buf, uint16(hdrSize)) + // Header + if _, err := hdr.MarshalTo(buf[2:]); err != nil { + return errors.Wrap(err, "marshalling header") + } + // Message length + binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(size)) + + n, err := c.cw.Write(buf) + + l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err) + if err != nil { + return errors.Wrap(err, "writing message") + } + return nil +} + +// Write msg out compressed, given its uncompressed marshaled payload and overhead. +// +// The first return value indicates whether compression succeeded. +// If not, the caller should retry without compression. +func (c *rawConnection) writeCompressedMessage(msg message, marshaled []byte, overhead int) (ok bool, err error) { hdr := Header{ Type: c.typeOf(msg), Compression: MessageCompressionLZ4, @@ -769,71 +799,32 @@ func (c *rawConnection) writeCompressedMessage(msg message) error { panic("impossibly large header") } - compressedSize := len(compressed) - totSize := 2 + hdrSize + 4 + compressedSize - buf = BufferPool.Upgrade(buf, totSize) + cOverhead := 2 + hdrSize + 4 + maxCompressed := cOverhead + lz4.CompressBound(len(marshaled)) + buf := BufferPool.Get(maxCompressed) + defer BufferPool.Put(buf) + + compressedSize, err := lz4Compress(marshaled, buf[cOverhead:]) + totSize := compressedSize + cOverhead + if err != nil || totSize >= len(marshaled)+overhead { + return false, nil + } // Header length binary.BigEndian.PutUint16(buf, uint16(hdrSize)) // Header if _, err := hdr.MarshalTo(buf[2:]); err != nil { - BufferPool.Put(buf) - BufferPool.Put(compressed) - return errors.Wrap(err, "marshalling header") + return true, errors.Wrap(err, "marshalling header") } // Message length binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(compressedSize)) - // Message - copy(buf[2+hdrSize+4:], compressed) - BufferPool.Put(compressed) - - n, err := c.cw.Write(buf) - BufferPool.Put(buf) - - l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, compressedSize, size, err) - if err != nil { - return errors.Wrap(err, "writing message") - } - return nil -} - -func (c *rawConnection) writeUncompressedMessage(msg message) error { - size := msg.ProtoSize() - - hdr := Header{ - Type: c.typeOf(msg), - } - hdrSize := hdr.ProtoSize() - if hdrSize > 1<<16-1 { - panic("impossibly large header") - } - - totSize := 2 + hdrSize + 4 + size - buf := BufferPool.Get(totSize) - - // Header length - binary.BigEndian.PutUint16(buf, uint16(hdrSize)) - // Header - if _, err := hdr.MarshalTo(buf[2:]); err != nil { - BufferPool.Put(buf) - return errors.Wrap(err, "marshalling header") - } - // Message length - binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(size)) - // Message - if _, err := msg.MarshalTo(buf[2+hdrSize+4:]); err != nil { - BufferPool.Put(buf) - return errors.Wrap(err, "marshalling message") - } n, err := c.cw.Write(buf[:totSize]) - BufferPool.Put(buf) - - l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err) + l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, compressedSize, len(marshaled), err) if err != nil { - return errors.Wrap(err, "writing message") + return true, errors.Wrap(err, "writing message") } - return nil + return true, nil } func (c *rawConnection) typeOf(msg message) MessageType { @@ -1018,23 +1009,20 @@ func (c *rawConnection) Statistics() Statistics { } } -func (c *rawConnection) lz4Compress(src []byte) ([]byte, error) { - var err error - buf := BufferPool.Get(lz4.CompressBound(len(src))) +func lz4Compress(src, buf []byte) (int, error) { compressed, err := lz4.Encode(buf, src) if err != nil { - BufferPool.Put(buf) - return nil, err + return -1, err } if &compressed[0] != &buf[0] { panic("bug: lz4.Compress allocated, which it must not (should use buffer pool)") } binary.BigEndian.PutUint32(compressed, binary.LittleEndian.Uint32(compressed)) - return compressed, nil + return len(compressed), nil } -func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) { +func lz4Decompress(src []byte) ([]byte, error) { size := binary.BigEndian.Uint32(src) binary.LittleEndian.PutUint32(src, size) var err error diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go index a82cbb961..1c1d617f7 100644 --- a/lib/protocol/protocol_test.go +++ b/lib/protocol/protocol_test.go @@ -17,6 +17,7 @@ import ( "testing/quick" "time" + lz4 "github.com/bkaradzic/go-lz4" "github.com/syncthing/syncthing/lib/rand" "github.com/syncthing/syncthing/lib/testutils" ) @@ -439,9 +440,42 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool { return true } -func TestLZ4Compression(t *testing.T) { - c := new(rawConnection) +func TestWriteCompressed(t *testing.T) { + for _, random := range []bool{false, true} { + buf := new(bytes.Buffer) + c := &rawConnection{ + cr: &countingReader{Reader: buf}, + cw: &countingWriter{Writer: buf}, + compression: CompressionAlways, + } + msg := &Response{Data: make([]byte, 10240)} + if random { + // This should make the message uncompressible. + rand.Read(msg.Data) + } + + if err := c.writeMessage(msg); err != nil { + t.Fatal(err) + } + got, err := c.readMessage(make([]byte, 4)) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got.(*Response).Data, msg.Data) { + t.Error("received the wrong message") + } + + hdr := Header{Type: c.typeOf(msg)} + size := int64(2 + hdr.ProtoSize() + 4 + msg.ProtoSize()) + if c.cr.tot > size { + t.Errorf("compression enlarged message from %d to %d", + size, c.cr.tot) + } + } +} + +func TestLZ4Compression(t *testing.T) { for i := 0; i < 10; i++ { dataLen := 150 + rand.Intn(150) data := make([]byte, dataLen) @@ -449,13 +483,15 @@ func TestLZ4Compression(t *testing.T) { if err != nil { t.Fatal(err) } - comp, err := c.lz4Compress(data) + + comp := make([]byte, lz4.CompressBound(dataLen)) + compLen, err := lz4Compress(data, comp) if err != nil { t.Errorf("compressing %d bytes: %v", dataLen, err) continue } - res, err := c.lz4Decompress(comp) + res, err := lz4Decompress(comp[:compLen]) if err != nil { t.Errorf("decompressing %d bytes to %d: %v", len(comp), dataLen, err) continue @@ -470,38 +506,6 @@ func TestLZ4Compression(t *testing.T) { } } -func TestStressLZ4CompressGrows(t *testing.T) { - c := new(rawConnection) - success := 0 - for i := 0; i < 100; i++ { - // Create a slize that is precisely one min block size, fill it with - // random data. This shouldn't compress at all, so will in fact - // become larger when LZ4 does its thing. - data := make([]byte, MinBlockSize) - if _, err := rand.Reader.Read(data); err != nil { - t.Fatal("randomness failure") - } - - comp, err := c.lz4Compress(data) - if err != nil { - t.Fatal("unexpected compression error: ", err) - } - if len(comp) < len(data) { - // data size should grow. We must have been really unlucky in - // the random generation, try again. - continue - } - - // Putting it into the buffer pool shouldn't panic because the block - // should come from there to begin with. - BufferPool.Put(comp) - success++ - } - if success == 0 { - t.Fatal("unable to find data that grows when compressed") - } -} - func TestCheckFilename(t *testing.T) { cases := []struct { name string