Commit 082acd7f17

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-23 13:46:15
std.compress.zstandard: clean up integer casts
1 parent fc64c27
Changed files (1)
lib
std
compress
zstandard
lib/std/compress/zstandard/decompress.zig
@@ -168,8 +168,11 @@ pub const DecodeState = struct {
                 const data = table[@field(self, @tagName(choice)).state];
                 const T = @TypeOf(@field(self, @tagName(choice))).State;
                 const bits_summand = try bit_reader.readBitsNoEof(T, data.bits);
-                const next_state = data.baseline + bits_summand;
-                @field(self, @tagName(choice)).state = @intCast(@TypeOf(@field(self, @tagName(choice))).State, next_state);
+                const next_state = std.math.cast(
+                    @TypeOf(@field(self, @tagName(choice))).State,
+                    data.baseline + bits_summand,
+                ) orelse return error.MalformedFseBits;
+                @field(self, @tagName(choice)).state = next_state;
             },
         }
     }
@@ -1045,10 +1048,10 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.H
             const even_data = entries[even_state];
             var read_bits: usize = 0;
             const even_bits = try huff_bits.readBits(u32, even_data.bits, &read_bits);
-            weights[i] = @intCast(u4, even_data.symbol);
+            weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
             i += 1;
             if (read_bits < even_data.bits) {
-                weights[i] = @intCast(u4, entries[odd_state].symbol);
+                weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree;
                 log.debug("overflow condition: setting weights[{d}] = {d}", .{ i, weights[i] });
                 i += 1;
                 break;
@@ -1058,11 +1061,11 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.H
             read_bits = 0;
             const odd_data = entries[odd_state];
             const odd_bits = try huff_bits.readBits(u32, odd_data.bits, &read_bits);
-            weights[i] = @intCast(u4, odd_data.symbol);
+            weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
             i += 1;
             if (read_bits < odd_data.bits) {
                 if (i == 256) return error.MalformedHuffmanTree;
-                weights[i] = @intCast(u4, entries[even_state].symbol);
+                weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree;
                 log.debug("overflow condition: setting weights[{d}] = {d}", .{ i, weights[i] });
                 i += 1;
                 break;
@@ -1100,9 +1103,9 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.H
     log.debug("weight power sum = {d}", .{weight_power_sum});
 
     // advance to next power of two (even if weight_power_sum is a power of 2)
-    max_number_of_bits = @intCast(u4, std.math.log2_int(u16, weight_power_sum) + 1);
+    max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1;
     const next_power_of_two = @as(u16, 1) << max_number_of_bits;
-    weights[symbol_count - 1] = @intCast(u4, std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1);
+    weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1;
     log.debug("weights[{d}] = {d}", .{ symbol_count - 1, weights[symbol_count - 1] });
 
     var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined;
@@ -1367,7 +1370,7 @@ fn decodeFseTable(
     while (accumulated_probability < total_probability) {
         // WARNING: The RFC in poorly worded, and would suggest std.math.log2_int_ceil is correct here,
         //          but power of two (remaining probabilities + 1) need max bits set to 1 more.
-        const max_bits = @intCast(u4, std.math.log2_int(u16, total_probability - accumulated_probability + 1)) + 1;
+        const max_bits = std.math.log2_int(u16, total_probability - accumulated_probability + 1) + 1;
         const small = try bit_reader.readBitsNoEof(u16, max_bits - 1);
 
         const cutoff = (@as(u16, 1) << max_bits) - 1 - (total_probability - accumulated_probability + 1);