// Copyright 2019+ Klaus Post. All rights reserved. // License information can be found in the LICENSE file. // Based on work by Yann Collet, released under BSD License. package zstd import ( "bytes" "encoding/hex" "errors" "hash" "io" "sync" "github.com/klauspost/compress/zstd/internal/xxhash" ) type frameDec struct { o decoderOptions crc hash.Hash64 frameDone sync.WaitGroup offset int64 WindowSize uint64 DictionaryID uint32 FrameContentSize uint64 HasCheckSum bool SingleSegment bool // maxWindowSize is the maximum windows size to support. // should never be bigger than max-int. maxWindowSize uint64 // In order queue of blocks being decoded. decoding chan *blockDec // Frame history passed between blocks history history rawInput byteBuffer // Byte buffer that can be reused for small input blocks. bBuf byteBuf // asyncRunning indicates whether the async routine processes input on 'decoding'. asyncRunning bool asyncRunningMu sync.Mutex } const ( // The minimum Window_Size is 1 KB. minWindowSize = 1 << 10 ) var ( frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd} skippableFrameMagic = []byte{0x2a, 0x4d, 0x18} ) func newFrameDec(o decoderOptions) *frameDec { d := frameDec{ o: o, maxWindowSize: 1 << 30, } return &d } // reset will read the frame header and prepare for block decoding. // If nothing can be read from the input, io.EOF will be returned. // Any other error indicated that the stream contained data, but // there was a problem. func (d *frameDec) reset(br byteBuffer) error { d.HasCheckSum = false d.WindowSize = 0 var b []byte for { b = br.readSmall(4) if b == nil { return io.EOF } if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 { if debug { println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic)) } // Break if not skippable frame. break } // Read size to skip b = br.readSmall(4) if b == nil { println("Reading Frame Size EOF") return io.ErrUnexpectedEOF } n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) println("Skipping frame with", n, "bytes.") err := br.skipN(int(n)) if err != nil { if debug { println("Reading discarded frame", err) } return err } } if !bytes.Equal(b, frameMagic) { println("Got magic numbers: ", b, "want:", frameMagic) return ErrMagicMismatch } // Read Frame_Header_Descriptor fhd, err := br.readByte() if err != nil { println("Reading Frame_Header_Descriptor", err) return err } d.SingleSegment = fhd&(1<<5) != 0 if fhd&(1<<3) != 0 { return errors.New("Reserved bit set on frame header") } // Read Window_Descriptor // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor d.WindowSize = 0 if !d.SingleSegment { wd, err := br.readByte() if err != nil { println("Reading Window_Descriptor", err) return err } printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3) windowLog := 10 + (wd >> 3) windowBase := uint64(1) << windowLog windowAdd := (windowBase / 8) * uint64(wd&0x7) d.WindowSize = windowBase + windowAdd } // Read Dictionary_ID // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id d.DictionaryID = 0 if size := fhd & 3; size != 0 { if size == 3 { size = 4 } b = br.readSmall(int(size)) if b == nil { if debug { println("Reading Dictionary_ID", io.ErrUnexpectedEOF) } return io.ErrUnexpectedEOF } switch size { case 1: d.DictionaryID = uint32(b[0]) case 2: d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8) case 4: d.DictionaryID = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) } if debug { println("Dict size", size, "ID:", d.DictionaryID) } if d.DictionaryID != 0 { return ErrUnknownDictionary } } // Read Frame_Content_Size // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size var fcsSize int v := fhd >> 6 switch v { case 0: if d.SingleSegment { fcsSize = 1 } default: fcsSize = 1 << v } d.FrameContentSize = 0 if fcsSize > 0 { b := br.readSmall(fcsSize) if b == nil { println("Reading Frame content", io.ErrUnexpectedEOF) return io.ErrUnexpectedEOF } switch fcsSize { case 1: d.FrameContentSize = uint64(b[0]) case 2: // When FCS_Field_Size is 2, the offset of 256 is added. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256 case 4: d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3] << 24)) case 8: d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24) d.FrameContentSize = uint64(d1) | (uint64(d2) << 32) } if debug { println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize])) } } // Move this to shared. d.HasCheckSum = fhd&(1<<2) != 0 if d.HasCheckSum { if d.crc == nil { d.crc = xxhash.New() } d.crc.Reset() } if d.WindowSize == 0 && d.SingleSegment { // We may not need window in this case. d.WindowSize = d.FrameContentSize if d.WindowSize < minWindowSize { d.WindowSize = minWindowSize } } if d.WindowSize > d.maxWindowSize { printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize) return ErrWindowSizeExceeded } // The minimum Window_Size is 1 KB. if d.WindowSize < minWindowSize { println("got window size: ", d.WindowSize) return ErrWindowSizeTooSmall } d.history.windowSize = int(d.WindowSize) d.history.maxSize = d.history.windowSize + maxBlockSize // history contains input - maybe we do something d.rawInput = br return nil } // next will start decoding the next block from stream. func (d *frameDec) next(block *blockDec) error { if debug { printf("decoding new block %p:%p", block, block.data) } err := block.reset(d.rawInput, d.WindowSize) if err != nil { println("block error:", err) // Signal the frame decoder we have a problem. d.sendErr(block, err) return err } block.input <- struct{}{} if debug { println("next block:", block) } d.asyncRunningMu.Lock() defer d.asyncRunningMu.Unlock() if !d.asyncRunning { return nil } if block.Last { // We indicate the frame is done by sending io.EOF d.decoding <- block return io.EOF } d.decoding <- block return nil } // sendEOF will queue an error block on the frame. // This will cause the frame decoder to return when it encounters the block. // Returns true if the decoder was added. func (d *frameDec) sendErr(block *blockDec, err error) bool { d.asyncRunningMu.Lock() defer d.asyncRunningMu.Unlock() if !d.asyncRunning { return false } println("sending error", err.Error()) block.sendErr(err) d.decoding <- block return true } // checkCRC will check the checksum if the frame has one. // Will return ErrCRCMismatch if crc check failed, otherwise nil. func (d *frameDec) checkCRC() error { if !d.HasCheckSum { return nil } var tmp [4]byte got := d.crc.Sum64() // Flip to match file order. tmp[0] = byte(got >> 0) tmp[1] = byte(got >> 8) tmp[2] = byte(got >> 16) tmp[3] = byte(got >> 24) // We can overwrite upper tmp now want := d.rawInput.readSmall(4) if want == nil { println("CRC missing?") return io.ErrUnexpectedEOF } if !bytes.Equal(tmp[:], want) { if debug { println("CRC Check Failed:", tmp[:], "!=", want) } return ErrCRCMismatch } println("CRC ok") return nil } func (d *frameDec) initAsync() { if !d.o.lowMem && !d.SingleSegment { // set max extra size history to 20MB. d.history.maxSize = d.history.windowSize + maxBlockSize*10 } // re-alloc if more than one extra block size. if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize { d.history.b = make([]byte, 0, d.history.maxSize) } if cap(d.history.b) < d.history.maxSize { d.history.b = make([]byte, 0, d.history.maxSize) } if cap(d.decoding) < d.o.concurrent { d.decoding = make(chan *blockDec, d.o.concurrent) } if debug { h := d.history printf("history init. len: %d, cap: %d", len(h.b), cap(h.b)) } d.asyncRunningMu.Lock() d.asyncRunning = true d.asyncRunningMu.Unlock() } // startDecoder will start decoding blocks and write them to the writer. // The decoder will stop as soon as an error occurs or at end of frame. // When the frame has finished decoding the *bufio.Reader // containing the remaining input will be sent on frameDec.frameDone. func (d *frameDec) startDecoder(output chan decodeOutput) { // TODO: Init to dictionary d.history.reset() written := int64(0) defer func() { d.asyncRunningMu.Lock() d.asyncRunning = false d.asyncRunningMu.Unlock() // Drain the currently decoding. d.history.error = true flushdone: for { select { case b := <-d.decoding: b.history <- &d.history output <- <-b.result default: break flushdone } } println("frame decoder done, signalling done") d.frameDone.Done() }() // Get decoder for first block. block := <-d.decoding block.history <- &d.history for { var next *blockDec // Get result r := <-block.result if r.err != nil { println("Result contained error", r.err) output <- r return } if debug { println("got result, from ", d.offset, "to", d.offset+int64(len(r.b))) d.offset += int64(len(r.b)) } if !block.Last { // Send history to next block select { case next = <-d.decoding: if debug { println("Sending ", len(d.history.b), "bytes as history") } next.history <- &d.history default: // Wait until we have sent the block, so // other decoders can potentially get the decoder. next = nil } } // Add checksum, async to decoding. if d.HasCheckSum { n, err := d.crc.Write(r.b) if err != nil { r.err = err if n != len(r.b) { r.err = io.ErrShortWrite } output <- r return } } written += int64(len(r.b)) if d.SingleSegment && uint64(written) > d.FrameContentSize { r.err = ErrFrameSizeExceeded output <- r return } if block.Last { r.err = d.checkCRC() output <- r return } output <- r if next == nil { // There was no decoder available, we wait for one now that we have sent to the writer. if debug { println("Sending ", len(d.history.b), " bytes as history") } next = <-d.decoding next.history <- &d.history } block = next } } // runDecoder will create a sync decoder that will decode a block of data. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { // TODO: Init to dictionary d.history.reset() saved := d.history.b // We use the history for output to avoid copying it. d.history.b = dst // Store input length, so we only check new data. crcStart := len(dst) var err error for { err = dec.reset(d.rawInput, d.WindowSize) if err != nil { break } if debug { println("next block:", dec) } err = dec.decodeBuf(&d.history) if err != nil || dec.Last { break } if uint64(len(d.history.b)) > d.o.maxDecodedSize { err = ErrDecoderSizeExceeded break } if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize { err = ErrFrameSizeExceeded break } } dst = d.history.b if err == nil { if d.HasCheckSum { var n int n, err = d.crc.Write(dst[crcStart:]) if err == nil { if n != len(dst)-crcStart { err = io.ErrShortWrite } } err = d.checkCRC() } } d.history.b = saved return dst, err }