master
  1// Based on Go stdlib implementation
  2
  3const std = @import("../std.zig");
  4const mem = std.mem;
  5const debug = std.debug;
  6
  7/// Counter mode.
  8///
  9/// This mode creates a key stream by encrypting an incrementing counter using a block cipher, and adding it to the source material.
 10///
 11/// Important: the counter mode doesn't provide authenticated encryption: the ciphertext can be trivially modified without this being detected.
 12/// As a result, applications should generally never use it directly, but only in a construction that includes a MAC.
 13pub fn ctr(comptime BlockCipher: anytype, block_cipher: BlockCipher, dst: []u8, src: []const u8, iv: [BlockCipher.block_length]u8, endian: std.builtin.Endian) void {
 14    ctrSlice(BlockCipher, block_cipher, dst, src, iv, endian, 0, BlockCipher.block_length);
 15}
 16
 17/// Counter mode with configurable counter position and size.
 18///
 19/// This extended version allows specifying where the counter is located within the IV block
 20/// and how many bytes it occupies. This is useful for modes like AES-GCM-SIV which use a
 21/// 32-bit counter at the beginning of the block.
 22///
 23/// @param counter_offset: Byte offset where the counter starts
 24/// @param counter_size: Size of the counter in bytes
 25pub fn ctrSlice(
 26    comptime BlockCipher: anytype,
 27    block_cipher: BlockCipher,
 28    dst: []u8,
 29    src: []const u8,
 30    iv: [BlockCipher.block_length]u8,
 31    endian: std.builtin.Endian,
 32    comptime counter_offset: usize,
 33    comptime counter_size: usize,
 34) void {
 35    debug.assert(dst.len >= src.len);
 36    const block_length = BlockCipher.block_length;
 37    debug.assert(counter_offset + counter_size <= block_length);
 38    debug.assert(counter_size > 0 and counter_size <= block_length);
 39
 40    var counterBlock = iv;
 41    var i: usize = 0;
 42
 43    const CounterInt = std.meta.Int(.unsigned, counter_size * 8);
 44
 45    const parallel_count = BlockCipher.block.parallel.optimal_parallel_blocks;
 46    const wide_block_length = parallel_count * block_length;
 47    var cnt_val = mem.readInt(CounterInt, counterBlock[counter_offset..][0..counter_size], endian);
 48    if (src.len >= wide_block_length) {
 49        var counters: [parallel_count * block_length]u8 = undefined;
 50        inline for (0..parallel_count) |j| {
 51            counters[j * block_length ..][0..block_length].* = iv;
 52        }
 53        while (i + wide_block_length <= src.len) : (i += wide_block_length) {
 54            comptime var j = 0;
 55            inline while (j < parallel_count) : (j += 1) {
 56                mem.writeInt(CounterInt, counters[j * block_length + counter_offset ..][0..counter_size], cnt_val +% j, endian);
 57            }
 58            cnt_val += parallel_count;
 59            block_cipher.xorWide(parallel_count, dst[i .. i + wide_block_length][0..wide_block_length], src[i .. i + wide_block_length][0..wide_block_length], counters);
 60        }
 61        mem.writeInt(CounterInt, counterBlock[counter_offset..][0..counter_size], cnt_val, endian);
 62    }
 63    while (i + block_length <= src.len) : (i += block_length) {
 64        block_cipher.xor(dst[i .. i + block_length][0..block_length], src[i .. i + block_length][0..block_length], counterBlock);
 65        cnt_val +%= 1;
 66        mem.writeInt(CounterInt, counterBlock[counter_offset..][0..counter_size], cnt_val, endian);
 67    }
 68    if (i < src.len) {
 69        var pad: [block_length]u8 = @splat(0);
 70        const src_slice = src[i..];
 71        @memcpy(pad[0..src_slice.len], src_slice);
 72        block_cipher.xor(&pad, &pad, counterBlock);
 73        const pad_slice = pad[0 .. src.len - i];
 74        @memcpy(dst[i..][0..pad_slice.len], pad_slice);
 75    }
 76}
 77
 78test "ctr mode" {
 79    const testing = std.testing;
 80    const aes = std.crypto.core.aes;
 81
 82    // Test key and IV from NIST SP 800-38A
 83    const key = [_]u8{ 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c };
 84    const iv = [_]u8{ 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff };
 85    const ctx = aes.Aes128.initEnc(key);
 86
 87    // Test 1: Empty input
 88    {
 89        const in = [_]u8{};
 90        const expected = [_]u8{};
 91        var out: [0]u8 = undefined;
 92        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
 93        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
 94    }
 95
 96    // Test 2: Single byte
 97    {
 98        const in = [_]u8{0x6b};
 99        const expected = [_]u8{0x87};
100        var out: [1]u8 = undefined;
101        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
102        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
103    }
104
105    // Test 3: Less than one block (15 bytes)
106    {
107        const in = [_]u8{ 0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17 };
108        const expected = [_]u8{ 0x87, 0x4d, 0x61, 0x91, 0xb6, 0x20, 0xe3, 0x26, 0x1b, 0xef, 0x68, 0x64, 0x99, 0x0d, 0xb6 };
109        var out: [15]u8 = undefined;
110        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
111        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
112    }
113
114    // Test 4: Exactly one block (16 bytes)
115    {
116        const in = [_]u8{ 0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a };
117        const expected = [_]u8{ 0x87, 0x4d, 0x61, 0x91, 0xb6, 0x20, 0xe3, 0x26, 0x1b, 0xef, 0x68, 0x64, 0x99, 0x0d, 0xb6, 0xce };
118        var out: [16]u8 = undefined;
119        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
120        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
121    }
122
123    // Test 5: One block plus one byte (17 bytes)
124    {
125        const in = [_]u8{ 0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a, 0xae };
126        const expected = [_]u8{ 0x87, 0x4d, 0x61, 0x91, 0xb6, 0x20, 0xe3, 0x26, 0x1b, 0xef, 0x68, 0x64, 0x99, 0x0d, 0xb6, 0xce, 0x98 };
127        var out: [17]u8 = undefined;
128        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
129        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
130    }
131
132    // Test 6: Exactly two blocks (32 bytes)
133    {
134        const in = [_]u8{
135            0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a,
136            0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, 0x9e, 0xb7, 0x6f, 0xac, 0x45, 0xaf, 0x8e, 0x51,
137        };
138        const expected = [_]u8{
139            0x87, 0x4d, 0x61, 0x91, 0xb6, 0x20, 0xe3, 0x26, 0x1b, 0xef, 0x68, 0x64, 0x99, 0x0d, 0xb6, 0xce,
140            0x98, 0x06, 0xf6, 0x6b, 0x79, 0x70, 0xfd, 0xff, 0x86, 0x17, 0x18, 0x7b, 0xb9, 0xff, 0xfd, 0xff,
141        };
142        var out: [32]u8 = undefined;
143        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
144        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
145    }
146
147    // Test 7: Two blocks plus 5 bytes (37 bytes)
148    {
149        const in = [_]u8{
150            0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a,
151            0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, 0x9e, 0xb7, 0x6f, 0xac, 0x45, 0xaf, 0x8e, 0x51,
152            0x30, 0xc8, 0x1c, 0x46, 0xa3,
153        };
154        const expected = [_]u8{
155            0x87, 0x4d, 0x61, 0x91, 0xb6, 0x20, 0xe3, 0x26, 0x1b, 0xef, 0x68, 0x64, 0x99, 0x0d, 0xb6, 0xce,
156            0x98, 0x06, 0xf6, 0x6b, 0x79, 0x70, 0xfd, 0xff, 0x86, 0x17, 0x18, 0x7b, 0xb9, 0xff, 0xfd, 0xff,
157            0x5a, 0xe4, 0xdf, 0x3e, 0xdb,
158        };
159        var out: [37]u8 = undefined;
160        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
161        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
162    }
163
164    // Test 8: Four blocks (64 bytes) - NIST test vector
165    {
166        const in = [_]u8{
167            0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a,
168            0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, 0x9e, 0xb7, 0x6f, 0xac, 0x45, 0xaf, 0x8e, 0x51,
169            0x30, 0xc8, 0x1c, 0x46, 0xa3, 0x5c, 0xe4, 0x11, 0xe5, 0xfb, 0xc1, 0x19, 0x1a, 0x0a, 0x52, 0xef,
170            0xf6, 0x9f, 0x24, 0x45, 0xdf, 0x4f, 0x9b, 0x17, 0xad, 0x2b, 0x41, 0x7b, 0xe6, 0x6c, 0x37, 0x10,
171        };
172        const expected = [_]u8{
173            0x87, 0x4d, 0x61, 0x91, 0xb6, 0x20, 0xe3, 0x26, 0x1b, 0xef, 0x68, 0x64, 0x99, 0x0d, 0xb6, 0xce,
174            0x98, 0x06, 0xf6, 0x6b, 0x79, 0x70, 0xfd, 0xff, 0x86, 0x17, 0x18, 0x7b, 0xb9, 0xff, 0xfd, 0xff,
175            0x5a, 0xe4, 0xdf, 0x3e, 0xdb, 0xd5, 0xd3, 0x5e, 0x5b, 0x4f, 0x09, 0x02, 0x0d, 0xb0, 0x3e, 0xab,
176            0x1e, 0x03, 0x1d, 0xda, 0x2f, 0xbe, 0x03, 0xd1, 0x79, 0x21, 0x70, 0xa0, 0xf3, 0x00, 0x9c, 0xee,
177        };
178        var out: [64]u8 = undefined;
179        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
180        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
181    }
182
183    // Test 9: Large input (> 2*block_length, 100 bytes)
184    {
185        // Create a 100-byte input by extending with zeros
186        var in: [100]u8 = [_]u8{0} ** 100;
187        @memcpy(in[0..64], &[_]u8{
188            0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a,
189            0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, 0x9e, 0xb7, 0x6f, 0xac, 0x45, 0xaf, 0x8e, 0x51,
190            0x30, 0xc8, 0x1c, 0x46, 0xa3, 0x5c, 0xe4, 0x11, 0xe5, 0xfb, 0xc1, 0x19, 0x1a, 0x0a, 0x52, 0xef,
191            0xf6, 0x9f, 0x24, 0x45, 0xdf, 0x4f, 0x9b, 0x17, 0xad, 0x2b, 0x41, 0x7b, 0xe6, 0x6c, 0x37, 0x10,
192        });
193
194        // Expected output: first 64 bytes from NIST, then CTR continues with zeros
195        var expected: [100]u8 = undefined;
196        @memcpy(expected[0..64], &[_]u8{
197            0x87, 0x4d, 0x61, 0x91, 0xb6, 0x20, 0xe3, 0x26, 0x1b, 0xef, 0x68, 0x64, 0x99, 0x0d, 0xb6, 0xce,
198            0x98, 0x06, 0xf6, 0x6b, 0x79, 0x70, 0xfd, 0xff, 0x86, 0x17, 0x18, 0x7b, 0xb9, 0xff, 0xfd, 0xff,
199            0x5a, 0xe4, 0xdf, 0x3e, 0xdb, 0xd5, 0xd3, 0x5e, 0x5b, 0x4f, 0x09, 0x02, 0x0d, 0xb0, 0x3e, 0xab,
200            0x1e, 0x03, 0x1d, 0xda, 0x2f, 0xbe, 0x03, 0xd1, 0x79, 0x21, 0x70, 0xa0, 0xf3, 0x00, 0x9c, 0xee,
201        });
202        // Compute the rest with zeros XORed with keystream
203        @memcpy(expected[64..], &[_]u8{
204            0xb0, 0x0d, 0x47, 0xf8, 0x14, 0x8a, 0x91, 0x0e, 0xf0, 0x68, 0x30, 0x97, 0x90, 0x4b, 0xa5, 0x02,
205            0x58, 0x99, 0x44, 0x5a, 0x4d, 0xe1, 0x01, 0xf5, 0x13, 0xca, 0xd1, 0x98, 0x7d, 0x89, 0xe9, 0x1b,
206            0x3b, 0xd9, 0xac, 0x79,
207        });
208
209        var out: [100]u8 = undefined;
210        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], iv, std.builtin.Endian.big);
211        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
212    }
213
214    // Test 10: Test with different endianness (little-endian counter)
215    {
216        const le_iv = [_]u8{ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };
217        const in = [_]u8{ 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff };
218
219        // We'll compute the expected value from the actual encryption
220        var out: [16]u8 = undefined;
221        ctr(aes.AesEncryptCtx(aes.Aes128), ctx, out[0..], in[0..], le_iv, std.builtin.Endian.little);
222
223        // The actual output for this test with little-endian counter=1
224        const expected = [_]u8{ 0x7e, 0x48, 0x15, 0xa8, 0x16, 0x66, 0xf0, 0xea, 0xad, 0x3c, 0x07, 0x97, 0x2f, 0xe8, 0x25, 0xc1 };
225        try testing.expectEqualSlices(u8, expected[0..], out[0..]);
226    }
227}