lib/protocol: Write uncompressible messages uncompressed (#7790)

This commit is contained in:
greatroar 2021-06-27 17:59:30 +02:00 committed by GitHub
parent 7a4c6d262f
commit bd363fe0b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 110 deletions

View File

@ -523,7 +523,7 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header, fourByteBuf []byte) (
// Nothing
case MessageCompressionLZ4:
decomp, err := c.lz4Decompress(buf)
decomp, err := lz4Decompress(buf)
BufferPool.Put(buf)
if err != nil {
return nil, errors.Wrap(err, "decompressing message")
@ -740,26 +740,56 @@ func (c *rawConnection) writerLoop() {
func (c *rawConnection) writeMessage(msg message) error {
msgContext, _ := messageContext(msg)
l.Debugf("Writing %v", msgContext)
if c.shouldCompressMessage(msg) {
return c.writeCompressedMessage(msg)
}
return c.writeUncompressedMessage(msg)
}
func (c *rawConnection) writeCompressedMessage(msg message) error {
size := msg.ProtoSize()
buf := BufferPool.Get(size)
if _, err := msg.MarshalTo(buf); err != nil {
BufferPool.Put(buf)
hdr := Header{
Type: c.typeOf(msg),
}
hdrSize := hdr.ProtoSize()
if hdrSize > 1<<16-1 {
panic("impossibly large header")
}
overhead := 2 + hdrSize + 4
totSize := overhead + size
buf := BufferPool.Get(totSize)
defer BufferPool.Put(buf)
// Message
if _, err := msg.MarshalTo(buf[2+hdrSize+4:]); err != nil {
return errors.Wrap(err, "marshalling message")
}
compressed, err := c.lz4Compress(buf)
if err != nil {
BufferPool.Put(buf)
return errors.Wrap(err, "compressing message")
if c.shouldCompressMessage(msg) {
ok, err := c.writeCompressedMessage(msg, buf[overhead:], overhead)
if ok {
return err
}
}
// Header length
binary.BigEndian.PutUint16(buf, uint16(hdrSize))
// Header
if _, err := hdr.MarshalTo(buf[2:]); err != nil {
return errors.Wrap(err, "marshalling header")
}
// Message length
binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(size))
n, err := c.cw.Write(buf)
l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err)
if err != nil {
return errors.Wrap(err, "writing message")
}
return nil
}
// Write msg out compressed, given its uncompressed marshaled payload and overhead.
//
// The first return value indicates whether compression succeeded.
// If not, the caller should retry without compression.
func (c *rawConnection) writeCompressedMessage(msg message, marshaled []byte, overhead int) (ok bool, err error) {
hdr := Header{
Type: c.typeOf(msg),
Compression: MessageCompressionLZ4,
@ -769,71 +799,32 @@ func (c *rawConnection) writeCompressedMessage(msg message) error {
panic("impossibly large header")
}
compressedSize := len(compressed)
totSize := 2 + hdrSize + 4 + compressedSize
buf = BufferPool.Upgrade(buf, totSize)
cOverhead := 2 + hdrSize + 4
maxCompressed := cOverhead + lz4.CompressBound(len(marshaled))
buf := BufferPool.Get(maxCompressed)
defer BufferPool.Put(buf)
compressedSize, err := lz4Compress(marshaled, buf[cOverhead:])
totSize := compressedSize + cOverhead
if err != nil || totSize >= len(marshaled)+overhead {
return false, nil
}
// Header length
binary.BigEndian.PutUint16(buf, uint16(hdrSize))
// Header
if _, err := hdr.MarshalTo(buf[2:]); err != nil {
BufferPool.Put(buf)
BufferPool.Put(compressed)
return errors.Wrap(err, "marshalling header")
return true, errors.Wrap(err, "marshalling header")
}
// Message length
binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(compressedSize))
// Message
copy(buf[2+hdrSize+4:], compressed)
BufferPool.Put(compressed)
n, err := c.cw.Write(buf)
BufferPool.Put(buf)
l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, compressedSize, size, err)
if err != nil {
return errors.Wrap(err, "writing message")
}
return nil
}
func (c *rawConnection) writeUncompressedMessage(msg message) error {
size := msg.ProtoSize()
hdr := Header{
Type: c.typeOf(msg),
}
hdrSize := hdr.ProtoSize()
if hdrSize > 1<<16-1 {
panic("impossibly large header")
}
totSize := 2 + hdrSize + 4 + size
buf := BufferPool.Get(totSize)
// Header length
binary.BigEndian.PutUint16(buf, uint16(hdrSize))
// Header
if _, err := hdr.MarshalTo(buf[2:]); err != nil {
BufferPool.Put(buf)
return errors.Wrap(err, "marshalling header")
}
// Message length
binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(size))
// Message
if _, err := msg.MarshalTo(buf[2+hdrSize+4:]); err != nil {
BufferPool.Put(buf)
return errors.Wrap(err, "marshalling message")
}
n, err := c.cw.Write(buf[:totSize])
BufferPool.Put(buf)
l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err)
l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, compressedSize, len(marshaled), err)
if err != nil {
return errors.Wrap(err, "writing message")
return true, errors.Wrap(err, "writing message")
}
return nil
return true, nil
}
func (c *rawConnection) typeOf(msg message) MessageType {
@ -1018,23 +1009,20 @@ func (c *rawConnection) Statistics() Statistics {
}
}
func (c *rawConnection) lz4Compress(src []byte) ([]byte, error) {
var err error
buf := BufferPool.Get(lz4.CompressBound(len(src)))
func lz4Compress(src, buf []byte) (int, error) {
compressed, err := lz4.Encode(buf, src)
if err != nil {
BufferPool.Put(buf)
return nil, err
return -1, err
}
if &compressed[0] != &buf[0] {
panic("bug: lz4.Compress allocated, which it must not (should use buffer pool)")
}
binary.BigEndian.PutUint32(compressed, binary.LittleEndian.Uint32(compressed))
return compressed, nil
return len(compressed), nil
}
func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) {
func lz4Decompress(src []byte) ([]byte, error) {
size := binary.BigEndian.Uint32(src)
binary.LittleEndian.PutUint32(src, size)
var err error

View File

@ -17,6 +17,7 @@ import (
"testing/quick"
"time"
lz4 "github.com/bkaradzic/go-lz4"
"github.com/syncthing/syncthing/lib/rand"
"github.com/syncthing/syncthing/lib/testutils"
)
@ -439,9 +440,42 @@ func testMarshal(t *testing.T, prefix string, m1, m2 message) bool {
return true
}
func TestLZ4Compression(t *testing.T) {
c := new(rawConnection)
func TestWriteCompressed(t *testing.T) {
for _, random := range []bool{false, true} {
buf := new(bytes.Buffer)
c := &rawConnection{
cr: &countingReader{Reader: buf},
cw: &countingWriter{Writer: buf},
compression: CompressionAlways,
}
msg := &Response{Data: make([]byte, 10240)}
if random {
// This should make the message uncompressible.
rand.Read(msg.Data)
}
if err := c.writeMessage(msg); err != nil {
t.Fatal(err)
}
got, err := c.readMessage(make([]byte, 4))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got.(*Response).Data, msg.Data) {
t.Error("received the wrong message")
}
hdr := Header{Type: c.typeOf(msg)}
size := int64(2 + hdr.ProtoSize() + 4 + msg.ProtoSize())
if c.cr.tot > size {
t.Errorf("compression enlarged message from %d to %d",
size, c.cr.tot)
}
}
}
func TestLZ4Compression(t *testing.T) {
for i := 0; i < 10; i++ {
dataLen := 150 + rand.Intn(150)
data := make([]byte, dataLen)
@ -449,13 +483,15 @@ func TestLZ4Compression(t *testing.T) {
if err != nil {
t.Fatal(err)
}
comp, err := c.lz4Compress(data)
comp := make([]byte, lz4.CompressBound(dataLen))
compLen, err := lz4Compress(data, comp)
if err != nil {
t.Errorf("compressing %d bytes: %v", dataLen, err)
continue
}
res, err := c.lz4Decompress(comp)
res, err := lz4Decompress(comp[:compLen])
if err != nil {
t.Errorf("decompressing %d bytes to %d: %v", len(comp), dataLen, err)
continue
@ -470,38 +506,6 @@ func TestLZ4Compression(t *testing.T) {
}
}
func TestStressLZ4CompressGrows(t *testing.T) {
c := new(rawConnection)
success := 0
for i := 0; i < 100; i++ {
// Create a slize that is precisely one min block size, fill it with
// random data. This shouldn't compress at all, so will in fact
// become larger when LZ4 does its thing.
data := make([]byte, MinBlockSize)
if _, err := rand.Reader.Read(data); err != nil {
t.Fatal("randomness failure")
}
comp, err := c.lz4Compress(data)
if err != nil {
t.Fatal("unexpected compression error: ", err)
}
if len(comp) < len(data) {
// data size should grow. We must have been really unlucky in
// the random generation, try again.
continue
}
// Putting it into the buffer pool shouldn't panic because the block
// should come from there to begin with.
BufferPool.Put(comp)
success++
}
if success == 0 {
t.Fatal("unable to find data that grows when compressed")
}
}
func TestCheckFilename(t *testing.T) {
cases := []struct {
name string