diff --git a/lib/bytesutil/bytebuffer.go b/lib/bytesutil/bytebuffer.go index 4c294d5f29..2748d0a252 100644 --- a/lib/bytesutil/bytebuffer.go +++ b/lib/bytesutil/bytebuffer.go @@ -13,6 +13,7 @@ var ( // Verify ByteBuffer implements the given interfaces. _ io.Writer = &ByteBuffer{} _ fs.ReadAtCloser = &ByteBuffer{} + _ io.ReaderFrom = &ByteBuffer{} // Verify reader implement filestream.ReadCloser interface. _ filestream.ReadCloser = &reader{} @@ -48,6 +49,30 @@ func (bb *ByteBuffer) ReadAt(p []byte, offset int64) { } } +// ReadFrom reads all the data from r to bb until EOF. +func (bb *ByteBuffer) ReadFrom(r io.Reader) (int64, error) { + b := bb.B + bLen := len(b) + b = Resize(b, 4*1024) + b = b[:cap(b)] + offset := bLen + for { + if free := len(b) - offset; free < offset { + n := len(b) + b = append(b, make([]byte, n)...) + } + n, err := r.Read(b[offset:]) + offset += n + if err != nil { + bb.B = b[:offset] + if err == io.EOF { + err = nil + } + return int64(offset - bLen), err + } + } +} + // MustClose closes bb for subsequent re-use. func (bb *ByteBuffer) MustClose() { // Do nothing, since certain code rely on bb reading after MustClose call. diff --git a/lib/bytesutil/bytebuffer_test.go b/lib/bytesutil/bytebuffer_test.go index 64306e1ea3..0db4235486 100644 --- a/lib/bytesutil/bytebuffer_test.go +++ b/lib/bytesutil/bytebuffer_test.go @@ -1,6 +1,7 @@ package bytesutil import ( + "bytes" "fmt" "io" "testing" @@ -66,6 +67,92 @@ func TestByteBuffer(t *testing.T) { } } +func TestByteBufferReadFrom(t *testing.T) { + var bbPool ByteBufferPool + + t.Run("zero_bytes", func(t *testing.T) { + t.Parallel() + bb := bbPool.Get() + defer bbPool.Put(bb) + src := bytes.NewBufferString("") + n, err := bb.ReadFrom(src) + if err != nil { + t.Fatalf("error when reading empty string: %s", err) + } + if n != 0 { + t.Fatalf("unexpected number of bytes read; got %d; want %d", n, 0) + } + if len(bb.B) != 0 { + t.Fatalf("unexpejcted len(bb.B); got %d; want %d", len(bb.B), 0) + } + }) + + t.Run("non_zero_bytes", func(t *testing.T) { + t.Parallel() + bb := bbPool.Get() + defer bbPool.Put(bb) + s := "foobarbaz" + src := bytes.NewBufferString(s) + n, err := bb.ReadFrom(src) + if err != nil { + t.Fatalf("error when reading non-empty string: %s", err) + } + if n != int64(len(s)) { + t.Fatalf("unexpected number of bytes read; got %d; want %d", n, len(s)) + } + if string(bb.B) != s { + t.Fatalf("unexpected value read; got %q; want %q", bb.B, s) + } + }) + + t.Run("big_number_of_bytes", func(t *testing.T) { + t.Parallel() + bb := bbPool.Get() + defer bbPool.Put(bb) + b := make([]byte, 1024*1024+234) + for i := range b { + b[i] = byte(i) + } + src := bytes.NewBuffer(b) + n, err := bb.ReadFrom(src) + if err != nil { + t.Fatalf("cannot read big value: %s", err) + } + if n != int64(len(b)) { + t.Fatalf("unexpected number of bytes read; got %d; want %d", n, len(b)) + } + if string(bb.B) != string(b) { + t.Fatalf("unexpected value read; got %q; want %q", bb.B, b) + } + }) + + t.Run("non_empty_bb", func(t *testing.T) { + t.Parallel() + bb := bbPool.Get() + defer bbPool.Put(bb) + prefix := []byte("prefix") + bb.B = append(bb.B[:0], prefix...) + s := "aosdfdsafdjsf" + src := bytes.NewBufferString(s) + n, err := bb.ReadFrom(src) + if err != nil { + t.Fatalf("cannot read to non-empty bb: %s", err) + } + if n != int64(len(s)) { + t.Fatalf("unexpected number of bytes read; got %d; want %d", n, len(s)) + } + if len(bb.B) != len(prefix)+len(s) { + t.Fatalf("unexpected bb.B len; got %d; want %d", len(bb.B), len(prefix)+len(s)) + } + if string(bb.B[:len(prefix)]) != string(prefix) { + t.Fatalf("unexpected prefix; got %q; want %q", bb.B[:len(prefix)], prefix) + } + if string(bb.B[len(prefix):]) != s { + t.Fatalf("unexpected data read; got %q; want %q", bb.B[len(prefix):], s) + } + }) +} + func TestByteBufferRead(t *testing.T) { var bb ByteBuffer diff --git a/lib/prompb/util.go b/lib/prompb/util.go index 7dbccf9f72..4768bacf08 100644 --- a/lib/prompb/util.go +++ b/lib/prompb/util.go @@ -11,13 +11,9 @@ import ( // ReadSnappy reads r, unpacks it using snappy, appends it to dst // and returns the result. func ReadSnappy(dst []byte, r io.Reader, maxSize int64) ([]byte, error) { - bb := bodyBufferPool.Get() - bb.B = bb.B[:0] - cb := copyBufferPool.Get() - cb.B = bytesutil.Resize(cb.B, 16*1024) lr := io.LimitReader(r, maxSize+1) - reqLen, err := io.CopyBuffer(bb, lr, cb.B) - copyBufferPool.Put(cb) + bb := bodyBufferPool.Get() + reqLen, err := bb.ReadFrom(lr) if err != nil { bodyBufferPool.Put(bb) return dst, fmt.Errorf("cannot read compressed request: %s", err) @@ -45,7 +41,6 @@ func ReadSnappy(dst []byte, r io.Reader, maxSize int64) ([]byte, error) { } var bodyBufferPool bytesutil.ByteBufferPool -var copyBufferPool bytesutil.ByteBufferPool // Reset resets wr. func (wr *WriteRequest) Reset() {