Commit 20fba0933f

LemonBoy <thatlemon@gmail.com>
2020-10-29 17:16:03
std/deflate: Avoid reading past end of stream
Use a conservative (and slower) approach in the Huffman decoder fast path. Closes #6847
1 parent 88eb3ae
Changed files (3)
lib/std/compress/deflate.zig
@@ -27,6 +27,8 @@ const FIXLCODES = 288;
 const PREFIX_LUT_BITS = 9;
 
 const Huffman = struct {
+    const LUTEntry = packed struct { symbol: u16 align(4), len: u16 };
+
     // Number of codes for each possible length
     count: [MAXBITS + 1]u16,
     // Mapping between codes and symbols
@@ -40,19 +42,23 @@ const Huffman = struct {
     // canonical Huffman code and we have to decode it using a slower method.
     //
     // [1] https://github.com/madler/zlib/blob/v1.2.11/doc/algorithm.txt#L58
-    prefix_lut: [1 << PREFIX_LUT_BITS]u16,
-    prefix_lut_len: [1 << PREFIX_LUT_BITS]u16,
+    prefix_lut: [1 << PREFIX_LUT_BITS]LUTEntry,
     // The following info refer to the codes of length PREFIX_LUT_BITS+1 and are
     // used to bootstrap the bit-by-bit reading method if the fast-path fails.
     last_code: u16,
     last_index: u16,
 
+    min_code_len: u16,
+
     fn construct(self: *Huffman, code_length: []const u16) !void {
         for (self.count) |*val| {
             val.* = 0;
         }
 
+        self.min_code_len = math.maxInt(u16);
         for (code_length) |len| {
+            if (len != 0 and len < self.min_code_len)
+                self.min_code_len = len;
             self.count[len] += 1;
         }
 
@@ -85,39 +91,38 @@ const Huffman = struct {
             }
         }
 
-        self.prefix_lut_len = mem.zeroes(@TypeOf(self.prefix_lut_len));
+        self.prefix_lut = mem.zeroes(@TypeOf(self.prefix_lut));
 
         for (code_length) |len, symbol| {
             if (len != 0) {
                 // Fill the symbol table.
                 // The symbols are assigned sequentially for each length.
                 self.symbol[offset[len]] = @truncate(u16, symbol);
-                // Track the last assigned offset
+                // Track the last assigned offset.
                 offset[len] += 1;
             }
 
             if (len == 0 or len > PREFIX_LUT_BITS)
                 continue;
 
-            // Given a Huffman code of length N we have to massage it so
-            // that it becomes an index in the lookup table.
-            // The bit order is reversed as the fast path reads the bit
-            // sequence MSB to LSB using an &, the order is flipped wrt the
-            // one obtained by reading bit-by-bit.
-            // The codes are prefix-free, if the prefix matches we can
-            // safely ignore the trail bits. We do so by replicating the
-            // symbol info for each combination of the trailing bits.
+            // Given a Huffman code of length N we transform it into an index
+            // into the lookup table by reversing its bits and filling the
+            // remaining bits (PREFIX_LUT_BITS - N) with every possible
+            // combination of bits to act as a wildcard.
             const bits_to_fill = @intCast(u5, PREFIX_LUT_BITS - len);
-            const rev_code = bitReverse(codes[len], len);
-            // Track the last used code, but only for lengths < PREFIX_LUT_BITS
+            const rev_code = bitReverse(u16, codes[len], len);
+
+            // Track the last used code, but only for lengths < PREFIX_LUT_BITS.
             codes[len] += 1;
 
             var j: usize = 0;
             while (j < @as(usize, 1) << bits_to_fill) : (j += 1) {
                 const index = rev_code | (j << @intCast(u5, len));
-                assert(self.prefix_lut_len[index] == 0);
-                self.prefix_lut[index] = @truncate(u16, symbol);
-                self.prefix_lut_len[index] = @truncate(u16, len);
+                assert(self.prefix_lut[index].len == 0);
+                self.prefix_lut[index] = .{
+                    .symbol = @truncate(u16, symbol),
+                    .len = @truncate(u16, len),
+                };
             }
         }
 
@@ -126,14 +131,10 @@ const Huffman = struct {
     }
 };
 
-// Reverse bit-by-bit a N-bit value
-fn bitReverse(x: usize, N: usize) usize {
-    var tmp: usize = 0;
-    var i: usize = 0;
-    while (i < N) : (i += 1) {
-        tmp |= ((x >> @intCast(u5, i)) & 1) << @intCast(u5, N - i - 1);
-    }
-    return tmp;
+// Reverse bit-by-bit a N-bit code.
+fn bitReverse(comptime T: type, value: T, N: usize) T {
+    const r = @bitReverse(T, value);
+    return r >> @intCast(math.Log2Int(T), @typeInfo(T).Int.bits - N);
 }
 
 pub fn InflateStream(comptime ReaderType: type) type {
@@ -269,8 +270,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
         hdist: *Huffman,
         hlen: *Huffman,
 
-        // Temporary buffer for the bitstream, only bits 0..`bits_left` are
-        // considered valid.
+        // Temporary buffer for the bitstream.
+        // Bits 0..`bits_left` are filled with data, the remaining ones are zeros.
         bits: u32,
         bits_left: usize,
 
@@ -280,7 +281,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
                 self.bits |= @as(u32, byte) << @intCast(u5, self.bits_left);
                 self.bits_left += 8;
             }
-            return self.bits & ((@as(u32, 1) << @intCast(u5, bits)) - 1);
+            const mask = (@as(u32, 1) << @intCast(u5, bits)) - 1;
+            return self.bits & mask;
         }
         fn readBits(self: *Self, bits: usize) !u32 {
             const val = self.peekBits(bits);
@@ -293,8 +295,8 @@ pub fn InflateStream(comptime ReaderType: type) type {
         }
 
         fn stored(self: *Self) !void {
-            // Discard the remaining bits, the lenght field is always
-            // byte-aligned (and so is the data)
+            // Discard the remaining bits, the length field is always
+            // byte-aligned (and so is the data).
             self.discardBits(self.bits_left);
 
             const length = try self.inner_reader.readIntLittle(u16);
@@ -481,32 +483,52 @@ pub fn InflateStream(comptime ReaderType: type) type {
         }
 
         fn decode(self: *Self, h: *Huffman) !u16 {
-            // Fast path, read some bits and hope they're prefixes of some code
-            const prefix = try self.peekBits(PREFIX_LUT_BITS);
-            if (h.prefix_lut_len[prefix] != 0) {
-                self.discardBits(h.prefix_lut_len[prefix]);
-                return h.prefix_lut[prefix];
+            // Using u32 instead of u16 to reduce the number of casts needed.
+            var prefix: u32 = 0;
+
+            // Fast path, read some bits and hope they're the prefix of some code.
+            // We can't read PREFIX_LUT_BITS as we don't want to read past the
+            // deflate stream end, use an incremental approach instead.
+            var code_len = h.min_code_len;
+            while (true) {
+                _ = try self.peekBits(code_len);
+                // Small optimization win, use as many bits as possible in the
+                // table lookup.
+                prefix = self.bits & ((1 << PREFIX_LUT_BITS) - 1);
+
+                const lut_entry = &h.prefix_lut[prefix];
+                // The code is longer than PREFIX_LUT_BITS!
+                if (lut_entry.len == 0)
+                    break;
+                // If the code lenght doesn't increase we found a match.
+                if (lut_entry.len <= code_len) {
+                    self.discardBits(code_len);
+                    return lut_entry.symbol;
+                }
+
+                code_len = lut_entry.len;
             }
 
             // The sequence we've read is not a prefix of any code of length <=
-            // PREFIX_LUT_BITS, keep decoding it using a slower method
-            self.discardBits(PREFIX_LUT_BITS);
+            // PREFIX_LUT_BITS, keep decoding it using a slower method.
+            prefix = try self.readBits(PREFIX_LUT_BITS);
 
             // Speed up the decoding by starting from the first code length
-            // that's not covered by the table
+            // that's not covered by the table.
             var len: usize = PREFIX_LUT_BITS + 1;
             var first: usize = h.last_code;
             var index: usize = h.last_index;
 
             // Reverse the prefix so that the LSB becomes the MSB and make space
-            // for the next bit
-            var code = bitReverse(prefix, PREFIX_LUT_BITS + 1);
+            // for the next bit.
+            var code = bitReverse(u32, prefix, PREFIX_LUT_BITS + 1);
 
             while (len <= MAXBITS) : (len += 1) {
                 code |= try self.readBits(1);
                 const count = h.count[len];
-                if (code < first + count)
+                if (code < first + count) {
                     return h.symbol[index + (code - first)];
+                }
                 index += count;
                 first += count;
                 first <<= 1;
@@ -520,7 +542,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
             while (true) {
                 switch (self.state) {
                     .DecodeBlockHeader => {
-                        // The compressed stream is done
+                        // The compressed stream is done.
                         if (self.seen_eos) return;
 
                         const last = @intCast(u1, try self.readBits(1));
@@ -528,7 +550,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
 
                         self.seen_eos = last != 0;
 
-                        // The next state depends on the block type
+                        // The next state depends on the block type.
                         switch (kind) {
                             0 => try self.stored(),
                             1 => try self.fixed(),
@@ -553,7 +575,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
                             var tmp: [1]u8 = undefined;
                             if ((try self.inner_reader.read(&tmp)) != 1) {
                                 // Unexpected end of stream, keep this error
-                                // consistent with the use of readBitsNoEof
+                                // consistent with the use of readBitsNoEof.
                                 return error.EndOfStream;
                             }
                             self.window.appendUnsafe(tmp[0]);
lib/std/compress/zlib.zig
@@ -144,6 +144,19 @@ test "compressed data" {
     );
 }
 
+test "don't read past deflate stream's end" {
+    try testReader(
+        &[_]u8{
+            0x08, 0xd7, 0x63, 0xf8, 0xcf, 0xc0, 0xc0, 0x00, 0xc1, 0xff,
+            0xff, 0x43, 0x30, 0x03, 0x03, 0xc3, 0xff, 0xff, 0xff, 0x01,
+            0x83, 0x95, 0x0b, 0xf5,
+        },
+        // SHA256 of
+        // 00ff 0000 00ff 0000 00ff 00ff ffff 00ff ffff 0000 0000 ffff ff
+        "3bbba1cc65408445c81abb61f3d2b86b1b60ee0d70b4c05b96d1499091a08c93",
+    );
+}
+
 test "sanity checks" {
     // Truncated header
     testing.expectError(
lib/std/math.zig
@@ -1141,4 +1141,3 @@ test "math.comptime" {
     comptime const v = sin(@as(f32, 1)) + ln(@as(f32, 5));
     testing.expect(v == sin(@as(f32, 1)) + ln(@as(f32, 5)));
 }
-