Commit a5d4ad17b7

Frank Denis <124872+jedisct1@users.noreply.github.com>
2024-11-20 11:16:09
crypto.keccak.State: add checks to prevent insecure transitions (#22020)
* crypto.keccak.State: don't unconditionally permute after a squeeze() Now, squeeze() behaves like absorb() Namely, squeeze(x[0..t]); squeeze(x[t..n)); with t <= n becomes equivalent to squeeze(x[0..n]). * keccak: in debug mode, track transitions to prevent insecure ones. Fixes #22019
1 parent dafe1a9
Changed files (1)
lib
std
lib/std/crypto/keccak_p.zig
@@ -4,6 +4,7 @@ const assert = std.debug.assert;
 const math = std.math;
 const mem = std.mem;
 const native_endian = builtin.cpu.arch.endian();
+const mode = @import("builtin").mode;
 
 /// The Keccak-f permutation.
 pub fn KeccakF(comptime f: u11) type {
@@ -199,6 +200,46 @@ pub fn State(comptime f: u11, comptime capacity: u11, comptime rounds: u5) type
     comptime assert(f >= 200 and f <= 1600 and f % 200 == 0); // invalid state size
     comptime assert(capacity < f and capacity % 8 == 0); // invalid capacity size
 
+    // In debug mode, track transitions to prevent insecure ones.
+    const Op = enum { uninitialized, initialized, updated, absorb, squeeze };
+    const TransitionTracker = if (mode == .Debug) struct {
+        op: Op = .uninitialized,
+
+        fn to(tracker: *@This(), next_op: Op) void {
+            switch (next_op) {
+                .updated => {
+                    switch (tracker.op) {
+                        .uninitialized => @panic("cannot permute before initializing"),
+                        else => {},
+                    }
+                },
+                .absorb => {
+                    switch (tracker.op) {
+                        .squeeze => @panic("cannot absorb right after squeezing"),
+                        else => {},
+                    }
+                },
+                .squeeze => {
+                    switch (tracker.op) {
+                        .uninitialized => @panic("cannot squeeze before initializing"),
+                        .initialized => @panic("cannot squeeze right after initializing"),
+                        .absorb => @panic("cannot squeeze right after absorbing"),
+                        else => {},
+                    }
+                },
+                .uninitialized => @panic("cannot transition to uninitialized"),
+                .initialized => {},
+            }
+            tracker.op = next_op;
+        }
+    } else struct {
+        // No-op in non-debug modes.
+        inline fn to(tracker: *@This(), next_op: Op) void {
+            _ = tracker; // no-op
+            _ = next_op; // no-op
+        }
+    };
+
     return struct {
         const Self = @This();
 
@@ -215,67 +256,108 @@ pub fn State(comptime f: u11, comptime capacity: u11, comptime rounds: u5) type
 
         st: KeccakF(f) = .{},
 
+        transition: TransitionTracker = .{},
+
         /// Absorb a slice of bytes into the sponge.
-        pub fn absorb(self: *Self, bytes_: []const u8) void {
-            var bytes = bytes_;
+        pub fn absorb(self: *Self, bytes: []const u8) void {
+            self.transition.to(.absorb);
+            var i: usize = 0;
             if (self.offset > 0) {
                 const left = @min(rate - self.offset, bytes.len);
                 @memcpy(self.buf[self.offset..][0..left], bytes[0..left]);
                 self.offset += left;
+                if (left == bytes.len) return;
                 if (self.offset == rate) {
-                    self.offset = 0;
                     self.st.addBytes(self.buf[0..]);
                     self.st.permuteR(rounds);
+                    self.offset = 0;
                 }
-                if (left == bytes.len) return;
-                bytes = bytes[left..];
+                i = left;
             }
-            while (bytes.len >= rate) {
-                self.st.addBytes(bytes[0..rate]);
+            while (i + rate < bytes.len) : (i += rate) {
+                self.st.addBytes(bytes[i..][0..rate]);
                 self.st.permuteR(rounds);
-                bytes = bytes[rate..];
             }
-            if (bytes.len > 0) {
-                @memcpy(self.buf[0..bytes.len], bytes);
-                self.offset = bytes.len;
+            const left = bytes.len - i;
+            if (left > 0) {
+                @memcpy(self.buf[0..left], bytes[i..][0..left]);
             }
+            self.offset = left;
         }
 
         /// Initialize the state from a slice of bytes.
-        pub fn init(bytes: [f / 8]u8) Self {
-            return .{ .st = KeccakF(f).init(bytes) };
+        pub fn init(bytes: [f / 8]u8, delim: u8) Self {
+            var st = Self{ .st = KeccakF(f).init(bytes), .delim = delim };
+            st.transition.to(.initialized);
+            return st;
         }
 
         /// Permute the state
         pub fn permute(self: *Self) void {
+            if (mode == .Debug) {
+                if (self.transition.op == .absorb and self.offset > 0) {
+                    @panic("cannot permute with pending input - call fillBlock() or pad() instead");
+                }
+            }
+            self.transition.to(.updated);
             self.st.permuteR(rounds);
             self.offset = 0;
         }
 
-        /// Align the input to the rate boundary.
+        /// Align the input to the rate boundary and permute.
         pub fn fillBlock(self: *Self) void {
+            self.transition.to(.absorb);
             self.st.addBytes(self.buf[0..self.offset]);
             self.st.permuteR(rounds);
             self.offset = 0;
+            self.transition.to(.updated);
         }
 
         /// Mark the end of the input.
         pub fn pad(self: *Self) void {
+            self.transition.to(.absorb);
             self.st.addBytes(self.buf[0..self.offset]);
+            if (self.offset == rate) {
+                self.st.permuteR(rounds);
+                self.offset = 0;
+            }
             self.st.addByte(self.delim, self.offset);
             self.st.addByte(0x80, rate - 1);
             self.st.permuteR(rounds);
             self.offset = 0;
+            self.transition.to(.updated);
         }
 
         /// Squeeze a slice of bytes from the sponge.
+        /// The function can be called multiple times.
         pub fn squeeze(self: *Self, out: []u8) void {
+            self.transition.to(.squeeze);
             var i: usize = 0;
-            while (i < out.len) : (i += rate) {
-                const left = @min(rate, out.len - i);
-                self.st.extractBytes(out[i..][0..left]);
+            if (self.offset == rate) {
+                self.st.permuteR(rounds);
+            } else if (self.offset > 0) {
+                @branchHint(.unlikely);
+                var buf: [rate]u8 = undefined;
+                self.st.extractBytes(buf[0..]);
+                const left = @min(rate - self.offset, out.len);
+                @memcpy(out[0..left], buf[self.offset..][0..left]);
+                self.offset += left;
+                if (left == out.len) return;
+                if (self.offset == rate) {
+                    self.offset = 0;
+                    self.st.permuteR(rounds);
+                }
+                i = left;
+            }
+            while (i + rate < out.len) : (i += rate) {
+                self.st.extractBytes(out[i..][0..rate]);
                 self.st.permuteR(rounds);
             }
+            const left = out.len - i;
+            if (left > 0) {
+                self.st.extractBytes(out[i..][0..left]);
+            }
+            self.offset = left;
         }
     };
 }
@@ -298,3 +380,26 @@ test "Keccak-f800" {
     };
     try std.testing.expectEqualSlices(u32, &st.st, &expected);
 }
+
+test "squeeze" {
+    var st = State(800, 256, 22).init([_]u8{0x80} ** 100, 0x01);
+
+    var out0: [15]u8 = undefined;
+    var out1: [out0.len]u8 = undefined;
+    st.permute();
+    var st0 = st;
+    st0.squeeze(out0[0..]);
+    var st1 = st;
+    st1.squeeze(out1[0 .. out1.len / 2]);
+    st1.squeeze(out1[out1.len / 2 ..]);
+    try std.testing.expectEqualSlices(u8, &out0, &out1);
+
+    var out2: [100]u8 = undefined;
+    var out3: [out2.len]u8 = undefined;
+    var st2 = st;
+    st2.squeeze(out2[0..]);
+    var st3 = st;
+    st3.squeeze(out3[0 .. out2.len / 2]);
+    st3.squeeze(out3[out2.len / 2 ..]);
+    try std.testing.expectEqualSlices(u8, &out2, &out3);
+}