Commit d3361c41db

Frank Denis <github@pureftpd.org>
2021-04-26 22:32:22
Change timingSafeCompare() to accept slices
1 parent 0747591
Changed files (1)
lib
std
crypto
lib/std/crypto/utils.zig
@@ -1,4 +1,5 @@
 const std = @import("../std.zig");
+const debug = std.debug;
 const mem = std.mem;
 const testing = std.testing;
 
@@ -43,44 +44,37 @@ pub fn timingSafeEql(comptime T: type, a: T, b: T) bool {
 
 /// Compare two integers serialized as arrays of the same size, in constant time.
 /// Returns .lt if a<b, .gt if a>b and .eq if a=b
-pub fn timingSafeCompare(comptime T: type, a: T, b: T, endian: Endian) Order {
-    switch (@typeInfo(T)) {
-        .Array => |info| {
-            const C = info.child;
-            const bits = switch (@typeInfo(C)) {
-                .Int => |cinfo| if (cinfo.signedness != .unsigned) @compileError("Elements to be compared must be unsigned") else cinfo.bits,
-                else => @compileError("Elements to be compared must be integers"),
-            };
-            comptime const Cext = std.meta.Int(.unsigned, bits + 1);
-            var gt: C = 0;
-            var eq: C = 1;
-            if (endian == .Little) {
-                var i = a.len;
-                while (i != 0) {
-                    i -= 1;
-                    const x1 = a[i];
-                    const x2 = b[i];
-                    gt |= @truncate(C, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq;
-                    eq &= @truncate(C, (@as(Cext, (x2 ^ x1)) -% 1) >> bits);
-                }
-            } else {
-                for (a) |x1, i| {
-                    const x2 = b[i];
-                    gt |= @truncate(C, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq;
-                    eq &= @truncate(C, (@as(Cext, (x2 ^ x1)) -% 1) >> bits);
-                }
-            }
-            if (gt != 0) {
-                return Order.gt;
-            } else if (eq != 0) {
-                return Order.eq;
-            }
-            return Order.lt;
-        },
-        else => {
-            @compileError("Only arrays can be compared");
-        },
+pub fn timingSafeCompare(comptime T: type, a: []const T, b: []const T, endian: Endian) Order {
+    debug.assert(a.len == b.len);
+    const bits = switch (@typeInfo(T)) {
+        .Int => |cinfo| if (cinfo.signedness != .unsigned) @compileError("Elements to be compared must be unsigned") else cinfo.bits,
+        else => @compileError("Elements to be compared must be integers"),
+    };
+    comptime const Cext = std.meta.Int(.unsigned, bits + 1);
+    var gt: T = 0;
+    var eq: T = 1;
+    if (endian == .Little) {
+        var i = a.len;
+        while (i != 0) {
+            i -= 1;
+            const x1 = a[i];
+            const x2 = b[i];
+            gt |= @truncate(T, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq;
+            eq &= @truncate(T, (@as(Cext, (x2 ^ x1)) -% 1) >> bits);
+        }
+    } else {
+        for (a) |x1, i| {
+            const x2 = b[i];
+            gt |= @truncate(T, (@as(Cext, x2) -% @as(Cext, x1)) >> bits) & eq;
+            eq &= @truncate(T, (@as(Cext, (x2 ^ x1)) -% 1) >> bits);
+        }
+    }
+    if (gt != 0) {
+        return Order.gt;
+    } else if (eq != 0) {
+        return Order.eq;
     }
+    return Order.lt;
 }
 
 /// Sets a slice to zeroes.
@@ -118,14 +112,14 @@ test "crypto.utils.timingSafeEql (vectors)" {
 test "crypto.utils.timingSafeCompare" {
     var a = [_]u8{10} ** 32;
     var b = [_]u8{10} ** 32;
-    testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .eq);
-    testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .eq);
+    testing.expectEqual(timingSafeCompare(u8, &a, &b, .Big), .eq);
+    testing.expectEqual(timingSafeCompare(u8, &a, &b, .Little), .eq);
     a[31] = 1;
-    testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .lt);
-    testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .lt);
+    testing.expectEqual(timingSafeCompare(u8, &a, &b, .Big), .lt);
+    testing.expectEqual(timingSafeCompare(u8, &a, &b, .Little), .lt);
     a[0] = 20;
-    testing.expectEqual(timingSafeCompare([32]u8, a, b, .Big), .gt);
-    testing.expectEqual(timingSafeCompare([32]u8, a, b, .Little), .lt);
+    testing.expectEqual(timingSafeCompare(u8, &a, &b, .Big), .gt);
+    testing.expectEqual(timingSafeCompare(u8, &a, &b, .Little), .lt);
 }
 
 test "crypto.utils.secureZero" {