Commit 933fd5110d

Ryan Liptak <squeek502@hotmail.com>
2021-10-02 07:01:31
deflate: Better Huffman.construct errors and error handling
This brings construct error handling in line with puff.c
1 parent c4cd592
Changed files (1)
lib
std
compress
lib/std/compress/deflate.zig
@@ -45,7 +45,9 @@ const Huffman = struct {
 
     min_code_len: u16,
 
-    fn construct(self: *Huffman, code_length: []const u16) !void {
+    const ConstructError = error{ Oversubscribed, IncompleteSet };
+
+    fn construct(self: *Huffman, code_length: []const u16) ConstructError!void {
         for (self.count) |*val| {
             val.* = 0;
         }
@@ -70,7 +72,7 @@ const Huffman = struct {
             // Make sure the number of codes with this length isn't too high.
             left -= @as(isize, @bitCast(i16, val));
             if (left < 0)
-                return error.InvalidTree;
+                return error.Oversubscribed;
         }
 
         // Compute the offset of the first symbol represented by a code of a
@@ -125,6 +127,9 @@ const Huffman = struct {
 
         self.last_code = codes[PREFIX_LUT_BITS + 1];
         self.last_index = offset[PREFIX_LUT_BITS + 1] - self.count[PREFIX_LUT_BITS + 1];
+
+        if (left > 0)
+            return error.IncompleteSet;
     }
 };
 
@@ -322,7 +327,13 @@ pub fn InflateStream(comptime ReaderType: type) type {
                 try lencode.construct(len_lengths[0..]);
 
                 const dist_lengths = [_]u16{5} ** MAXDCODES;
-                try distcode.construct(dist_lengths[0..]);
+                distcode.construct(dist_lengths[0..]) catch |err| switch (err) {
+                    // This error is expected because we only compute distance codes
+                    // 0-29, which is fine since "distance codes 30-31 will never actually
+                    // occur in the compressed data" (from section 3.2.6 of RFC1951).
+                    error.IncompleteSet => {},
+                    else => return err,
+                };
             }
 
             self.hlen = &lencode;
@@ -357,7 +368,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
                     lengths[val] = @intCast(u16, try self.readBits(3));
                 }
 
-                try lencode.construct(lengths[0..]);
+                lencode.construct(lengths[0..]) catch return error.InvalidTree;
             }
 
             // Read the length/literal and distance code length tables.
@@ -406,8 +417,24 @@ pub fn InflateStream(comptime ReaderType: type) type {
             if (lengths[256] == 0)
                 return error.MissingEOBCode;
 
-            try self.huffman_tables[0].construct(lengths[0..nlen]);
-            try self.huffman_tables[1].construct(lengths[nlen .. nlen + ndist]);
+            self.huffman_tables[0].construct(lengths[0..nlen]) catch |err| switch (err) {
+                error.Oversubscribed => return error.InvalidTree,
+                error.IncompleteSet => {
+                    // incomplete code ok only for single length 1 code
+                    if (nlen != self.huffman_tables[0].count[0] + self.huffman_tables[0].count[1]) {
+                        return error.InvalidTree;
+                    }
+                },
+            };
+            self.huffman_tables[1].construct(lengths[nlen .. nlen + ndist]) catch |err| switch (err) {
+                error.Oversubscribed => return error.InvalidTree,
+                error.IncompleteSet => {
+                    // incomplete code ok only for single length 1 code
+                    if (ndist != self.huffman_tables[1].count[0] + self.huffman_tables[1].count[1]) {
+                        return error.InvalidTree;
+                    }
+                },
+            };
 
             self.hlen = &self.huffman_tables[0];
             self.hdist = &self.huffman_tables[1];