Commit 057d30bacc

Frank Denis <124872+jedisct1@users.noreply.github.com>
2023-05-23 16:36:44
std.crypto.chacha: remove the hack for ChaCha with a 64-bit counter (#15818)
Support for 64-bit counters was a hack built upon the version with a 32-bit counter that emulated a larger counter by splitting the input into large blocks. This is fragile, particularly if the initial counter is set to a non-default value and if we have parallelism. Simply add a comptime parameter to check if we have a 32-bit or a 64-bit counter instead. Also convert a couple of while() loops to for(), and change @panic() to @compileError().
1 parent 7cb3a67
Changed files (1)
lib
std
lib/std/crypto/chacha20.zig
@@ -109,14 +109,18 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
                         mem.readIntLittle(u32, c[8..12]),
                         mem.readIntLittle(u32, c[12..16]),
                     };
+                    const n1 = @addWithOverflow(d[0], 1);
                     return BlockVec{
                         constant_le,
                         Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
                         Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
-                        Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3] },
+                        Lane{ d[0], d[1], d[2], d[3], n1[0], d[1] +% n1[1], d[2], d[3] },
                     };
                 },
                 4 => {
+                    const n1 = @addWithOverflow(d[0], 1);
+                    const n2 = @addWithOverflow(d[0], 2);
+                    const n3 = @addWithOverflow(d[0], 3);
                     const constant_le = Lane{
                         mem.readIntLittle(u32, c[0..4]),
                         mem.readIntLittle(u32, c[4..8]),
@@ -139,10 +143,10 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
                         constant_le,
                         Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
                         Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
-                        Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3], d[0] +% 2, d[1], d[2], d[3], d[0] +% 3, d[1], d[2], d[3] },
+                        Lane{ d[0], d[1], d[2], d[3], n1[0], d[1] +% n1[1], d[2], d[3], n2[0], d[1] +% n2[1], d[2], d[3], n3[0], d[1] +% n3[1], d[2], d[3] },
                     };
                 },
-                else => @panic("invalid degree"),
+                else => @compileError("invalid degree"),
             }
         }
 
@@ -153,19 +157,19 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
                 1 => [_]i32{ 3, 0, 1, 2 },
                 2 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 },
                 4 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 } ++ [_]i32{ 11, 8, 9, 10 } ++ [_]i32{ 15, 12, 13, 14 },
-                else => @panic("invalid degree"),
+                else => @compileError("invalid degree"),
             };
             const m1 = switch (degree) {
                 1 => [_]i32{ 2, 3, 0, 1 },
                 2 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 },
                 4 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 } ++ [_]i32{ 10, 11, 8, 9 } ++ [_]i32{ 14, 15, 12, 13 },
-                else => @panic("invalid degree"),
+                else => @compileError("invalid degree"),
             };
             const m2 = switch (degree) {
                 1 => [_]i32{ 1, 2, 3, 0 },
                 2 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 },
                 4 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 } ++ [_]i32{ 9, 10, 11, 8 } ++ [_]i32{ 13, 14, 15, 12 },
-                else => @panic("invalid degree"),
+                else => @compileError("invalid degree"),
             };
 
             var r: usize = 0;
@@ -212,8 +216,7 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
 
         inline fn hashToBytes(comptime dm: usize, out: *[64 * dm]u8, x: BlockVec) void {
             for (0..dm) |d| {
-                var i: usize = 0;
-                while (i < 4) : (i += 1) {
+                for (0..4) |i| {
                     mem.writeIntLittle(u32, out[64 * d + 16 * i + 0 ..][0..4], x[i][0 + 4 * d]);
                     mem.writeIntLittle(u32, out[64 * d + 16 * i + 4 ..][0..4], x[i][1 + 4 * d]);
                     mem.writeIntLittle(u32, out[64 * d + 16 * i + 8 ..][0..4], x[i][2 + 4 * d]);
@@ -229,8 +232,8 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
             x[3] +%= ctx[3];
         }
 
-        fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
-            var ctx = initContext(key, counter);
+        fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, nonce_and_counter: [4]u32, comptime count64: bool) void {
+            var ctx = initContext(key, nonce_and_counter);
             var x: BlockVec = undefined;
             var buf: [64 * degree]u8 = undefined;
             var i: usize = 0;
@@ -242,16 +245,20 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
 
                     var xout = out[i..];
                     const xin = in[i..];
-                    var j: usize = 0;
-                    while (j < 64 * d) : (j += 1) {
+                    for (0..64 * d) |j| {
                         xout[j] = xin[j];
                     }
-                    j = 0;
-                    while (j < 64 * d) : (j += 1) {
+                    for (0..64 * d) |j| {
                         xout[j] ^= buf[j];
                     }
                     inline for (0..d) |d_| {
-                        ctx[3][4 * d_] += @intCast(u32, d);
+                        if (count64) {
+                            const next = @addWithOverflow(ctx[3][4 * d_], d);
+                            ctx[3][4 * d_] = next[0];
+                            ctx[3][4 * d_ + 1] +%= next[1];
+                        } else {
+                            ctx[3][4 * d_] +%= d;
+                        }
                     }
                 }
             }
@@ -262,15 +269,14 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
 
                 var xout = out[i..];
                 const xin = in[i..];
-                var j: usize = 0;
-                while (j < in.len % 64) : (j += 1) {
+                for (0..in.len % 64) |j| {
                     xout[j] = xin[j] ^ buf[j];
                 }
             }
         }
 
-        fn chacha20Stream(out: []u8, key: [8]u32, counter: [4]u32) void {
-            var ctx = initContext(key, counter);
+        fn chacha20Stream(out: []u8, key: [8]u32, nonce_and_counter: [4]u32, comptime count64: bool) void {
+            var ctx = initContext(key, nonce_and_counter);
             var x: BlockVec = undefined;
             var i: usize = 0;
             inline for ([_]comptime_int{ 4, 2, 1 }) |d| {
@@ -279,7 +285,13 @@ fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type
                     contextFeedback(&x, ctx);
                     hashToBytes(d, out[i..][0 .. 64 * d], x);
                     inline for (0..d) |d_| {
-                        ctx[3][4 * d_] += @intCast(u32, d);
+                        if (count64) {
+                            const next = @addWithOverflow(ctx[3][4 * d_], d);
+                            ctx[3][4 * d_] = next[0];
+                            ctx[3][4 * d_ + 1] +%= next[1];
+                        } else {
+                            ctx[3][4 * d_] +%= d;
+                        }
                     }
                 }
             }
@@ -382,8 +394,7 @@ fn ChaChaNonVecImpl(comptime rounds_nb: usize) type {
         }
 
         inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
-            var i: usize = 0;
-            while (i < 4) : (i += 1) {
+            for (0..4) |i| {
                 mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i * 4 + 0]);
                 mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i * 4 + 1]);
                 mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i * 4 + 2]);
@@ -392,14 +403,13 @@ fn ChaChaNonVecImpl(comptime rounds_nb: usize) type {
         }
 
         inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void {
-            var i: usize = 0;
-            while (i < 16) : (i += 1) {
+            for (0..16) |i| {
                 x[i] +%= ctx[i];
             }
         }
 
-        fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
-            var ctx = initContext(key, counter);
+        fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, nonce_and_counter: [4]u32, comptime count64: bool) void {
+            var ctx = initContext(key, nonce_and_counter);
             var x: BlockVec = undefined;
             var buf: [64]u8 = undefined;
             var i: usize = 0;
@@ -410,15 +420,19 @@ fn ChaChaNonVecImpl(comptime rounds_nb: usize) type {
 
                 var xout = out[i..];
                 const xin = in[i..];
-                var j: usize = 0;
-                while (j < 64) : (j += 1) {
+                for (0..64) |j| {
                     xout[j] = xin[j];
                 }
-                j = 0;
-                while (j < 64) : (j += 1) {
+                for (0..64) |j| {
                     xout[j] ^= buf[j];
                 }
-                ctx[12] += 1;
+                if (count64) {
+                    const next = @addWithOverflow(ctx[12], 1);
+                    ctx[12] = next[0];
+                    ctx[13] +%= next[1];
+                } else {
+                    ctx[12] +%= 1;
+                }
             }
             if (i < in.len) {
                 chacha20Core(x[0..], ctx);
@@ -427,22 +441,27 @@ fn ChaChaNonVecImpl(comptime rounds_nb: usize) type {
 
                 var xout = out[i..];
                 const xin = in[i..];
-                var j: usize = 0;
-                while (j < in.len % 64) : (j += 1) {
+                for (0..in.len % 64) |j| {
                     xout[j] = xin[j] ^ buf[j];
                 }
             }
         }
 
-        fn chacha20Stream(out: []u8, key: [8]u32, counter: [4]u32) void {
-            var ctx = initContext(key, counter);
+        fn chacha20Stream(out: []u8, key: [8]u32, nonce_and_counter: [4]u32, comptime count64: bool) void {
+            var ctx = initContext(key, nonce_and_counter);
             var x: BlockVec = undefined;
             var i: usize = 0;
             while (i + 64 <= out.len) : (i += 64) {
                 chacha20Core(x[0..], ctx);
                 contextFeedback(&x, ctx);
                 hashToBytes(out[i..][0..64], x);
-                ctx[12] += 1;
+                if (count64) {
+                    const next = @addWithOverflow(ctx[12], 1);
+                    ctx[12] = next[0];
+                    ctx[13] +%= next[1];
+                } else {
+                    ctx[12] +%= 1;
+                }
             }
             if (i < out.len) {
                 chacha20Core(x[0..], ctx);
@@ -496,8 +515,7 @@ fn ChaChaImpl(comptime rounds_nb: usize) type {
 
 fn keyToWords(key: [32]u8) [8]u32 {
     var k: [8]u32 = undefined;
-    var i: usize = 0;
-    while (i < 8) : (i += 1) {
+    for (0..8) |i| {
         k[i] = mem.readIntLittle(u32, key[i * 4 ..][0..4]);
     }
     return k;
@@ -527,26 +545,26 @@ fn ChaChaIETF(comptime rounds_nb: usize) type {
         /// Using the AEAD or one of the `box` versions is usually preferred.
         pub fn xor(out: []u8, in: []const u8, counter: u32, key: [key_length]u8, nonce: [nonce_length]u8) void {
             assert(in.len == out.len);
-            assert(in.len / 64 <= (1 << 32 - 1) - counter);
+            assert(in.len <= 64 * (@as(u39, 1 << 32) - counter));
 
             var d: [4]u32 = undefined;
             d[0] = counter;
             d[1] = mem.readIntLittle(u32, nonce[0..4]);
             d[2] = mem.readIntLittle(u32, nonce[4..8]);
             d[3] = mem.readIntLittle(u32, nonce[8..12]);
-            ChaChaImpl(rounds_nb).chacha20Xor(out, in, keyToWords(key), d);
+            ChaChaImpl(rounds_nb).chacha20Xor(out, in, keyToWords(key), d, false);
         }
 
         /// Write the output of the ChaCha20 stream cipher into `out`.
         pub fn stream(out: []u8, counter: u32, key: [key_length]u8, nonce: [nonce_length]u8) void {
-            assert(out.len / 64 <= (1 << 32 - 1) - counter);
+            assert(out.len <= 64 * (@as(u39, 1 << 32) - counter));
 
             var d: [4]u32 = undefined;
             d[0] = counter;
             d[1] = mem.readIntLittle(u32, nonce[0..4]);
             d[2] = mem.readIntLittle(u32, nonce[4..8]);
             d[3] = mem.readIntLittle(u32, nonce[8..12]);
-            ChaChaImpl(rounds_nb).chacha20Stream(out, keyToWords(key), d);
+            ChaChaImpl(rounds_nb).chacha20Stream(out, keyToWords(key), d, false);
         }
     };
 }
@@ -565,47 +583,28 @@ fn ChaChaWith64BitNonce(comptime rounds_nb: usize) type {
         /// Using the AEAD or one of the `box` versions is usually preferred.
         pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void {
             assert(in.len == out.len);
-            assert(in.len / 64 <= (1 << 64 - 1) - counter);
+            assert(in.len <= 64 * (@as(u71, 1 << 64) - counter));
 
-            var cursor: usize = 0;
             const k = keyToWords(key);
             var c: [4]u32 = undefined;
             c[0] = @truncate(u32, counter);
             c[1] = @truncate(u32, counter >> 32);
             c[2] = mem.readIntLittle(u32, nonce[0..4]);
             c[3] = mem.readIntLittle(u32, nonce[4..8]);
-
-            // The full block size is greater than the address space on a 32bit machine
-            const big_block = if (@sizeOf(usize) > 4) (block_length << 32) else maxInt(usize);
-
-            // first partial big block
-            if (((@intCast(u64, maxInt(u32) - @truncate(u32, counter)) + 1) << 6) < in.len) {
-                ChaChaImpl(rounds_nb).chacha20Xor(out[cursor..big_block], in[cursor..big_block], k, c);
-                cursor = big_block - cursor;
-                c[1] += 1;
-                if (comptime @sizeOf(usize) > 4) {
-                    // A big block is giant: 256 GiB, but we can avoid this limitation
-                    var remaining_blocks: u32 = @intCast(u32, (in.len / big_block));
-                    while (remaining_blocks > 0) : (remaining_blocks -= 1) {
-                        ChaChaImpl(rounds_nb).chacha20Xor(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c);
-                        c[1] += 1; // upper 32-bit of counter, generic chacha20Xor() doesn't know about this.
-                        cursor += big_block;
-                    }
-                }
-            }
-            ChaChaImpl(rounds_nb).chacha20Xor(out[cursor..], in[cursor..], k, c);
+            ChaChaImpl(rounds_nb).chacha20Xor(out, in, k, c, true);
         }
 
         /// Write the output of the ChaCha20 stream cipher into `out`.
         pub fn stream(out: []u8, counter: u32, key: [key_length]u8, nonce: [nonce_length]u8) void {
-            assert(out.len / 64 <= (1 << 32 - 1) - counter);
+            assert(out.len <= 64 * (@as(u71, 1 << 64) - counter));
+
             const k = keyToWords(key);
             var c: [4]u32 = undefined;
             c[0] = @truncate(u32, counter);
             c[1] = @truncate(u32, counter >> 32);
             c[2] = mem.readIntLittle(u32, nonce[0..4]);
             c[3] = mem.readIntLittle(u32, nonce[4..8]);
-            ChaChaImpl(rounds_nb).chacha20Stream(out, k, c);
+            ChaChaImpl(rounds_nb).chacha20Stream(out, k, c, true);
         }
     };
 }
@@ -649,6 +648,7 @@ fn ChaChaPoly1305(comptime rounds_nb: usize) type {
         /// k: private key
         pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) void {
             assert(c.len == m.len);
+            assert(m.len <= 64 * (@as(u39, 1 << 32) - 1));
 
             var polyKey = [_]u8{0} ** 32;
             ChaChaIETF(rounds_nb).xor(polyKey[0..], polyKey[0..], 0, k, npub);
@@ -766,7 +766,7 @@ test "chacha20 AEAD API" {
         aead.encrypt(c[0..], tag[0..], m, ad, nonce, key);
         try aead.decrypt(out[0..], c[0..], tag, ad[0..], nonce, key);
         try testing.expectEqualSlices(u8, out[0..], m);
-        c[0] += 1;
+        c[0] +%= 1;
         try testing.expectError(error.AuthenticationFailed, aead.decrypt(out[0..], c[0..], tag, ad[0..], nonce, key));
     }
 }
@@ -1154,7 +1154,7 @@ test "crypto.xchacha20" {
         var buf: [2 * c.len]u8 = undefined;
         try testing.expectEqualStrings(try std.fmt.bufPrint(&buf, "{s}", .{std.fmt.fmtSliceHexUpper(&c)}), "994D2DD32333F48E53650C02C7A2ABB8E018B0836D7175AEC779F52E961780768F815C58F1AA52D211498DB89B9216763F569C9433A6BBFCEFB4D4A49387A4C5207FBB3B5A92B5941294DF30588C6740D39DC16FA1F0E634F7246CF7CDCB978E44347D89381B7A74EB7084F754B90BDE9AAF5A94B8F2A85EFD0B50692AE2D425E234");
         try testing.expectEqualSlices(u8, out[0..], m);
-        c[0] += 1;
+        c[0] +%= 1;
         try testing.expectError(error.AuthenticationFailed, XChaCha20Poly1305.decrypt(out[0..], c[0..m.len], c[m.len..].*, ad, nonce, key));
     }
 }