master
  1const std = @import("../../std.zig");
  2const math = std.math;
  3const mem = std.mem;
  4
  5const side_channels_mitigations = std.options.side_channels_mitigations;
  6
  7/// A single AES block.
  8pub const Block = struct {
  9    const Repr = [4]u32;
 10
 11    pub const block_length: usize = 16;
 12
 13    /// Internal representation of a block.
 14    repr: Repr align(16),
 15
 16    /// Convert a byte sequence into an internal representation.
 17    pub fn fromBytes(bytes: *const [16]u8) Block {
 18        const s0 = mem.readInt(u32, bytes[0..4], .little);
 19        const s1 = mem.readInt(u32, bytes[4..8], .little);
 20        const s2 = mem.readInt(u32, bytes[8..12], .little);
 21        const s3 = mem.readInt(u32, bytes[12..16], .little);
 22        return Block{ .repr = Repr{ s0, s1, s2, s3 } };
 23    }
 24
 25    /// Convert the internal representation of a block into a byte sequence.
 26    pub fn toBytes(block: Block) [16]u8 {
 27        var bytes: [16]u8 = undefined;
 28        mem.writeInt(u32, bytes[0..4], block.repr[0], .little);
 29        mem.writeInt(u32, bytes[4..8], block.repr[1], .little);
 30        mem.writeInt(u32, bytes[8..12], block.repr[2], .little);
 31        mem.writeInt(u32, bytes[12..16], block.repr[3], .little);
 32        return bytes;
 33    }
 34
 35    /// XOR the block with a byte sequence.
 36    pub fn xorBytes(block: Block, bytes: *const [16]u8) [16]u8 {
 37        const block_bytes = block.toBytes();
 38        var x: [16]u8 = undefined;
 39        comptime var i: usize = 0;
 40        inline while (i < 16) : (i += 1) {
 41            x[i] = block_bytes[i] ^ bytes[i];
 42        }
 43        return x;
 44    }
 45
 46    /// Encrypt a block with a round key.
 47    pub fn encrypt(block: Block, round_key: Block) Block {
 48        const s0 = block.repr[0];
 49        const s1 = block.repr[1];
 50        const s2 = block.repr[2];
 51        const s3 = block.repr[3];
 52
 53        var x: [4]u32 = undefined;
 54        x = table_lookup(&table_encrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s3 >> 24)));
 55        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
 56        x = table_lookup(&table_encrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s0 >> 24)));
 57        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
 58        x = table_lookup(&table_encrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s1 >> 24)));
 59        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
 60        x = table_lookup(&table_encrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s2 >> 24)));
 61        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
 62
 63        t0 ^= round_key.repr[0];
 64        t1 ^= round_key.repr[1];
 65        t2 ^= round_key.repr[2];
 66        t3 ^= round_key.repr[3];
 67
 68        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
 69    }
 70
 71    /// Encrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS*
 72    pub fn encryptUnprotected(block: Block, round_key: Block) Block {
 73        const s0 = block.repr[0];
 74        const s1 = block.repr[1];
 75        const s2 = block.repr[2];
 76        const s3 = block.repr[3];
 77
 78        var x: [4]u32 = undefined;
 79        x = .{
 80            table_encrypt[0][@as(u8, @truncate(s0))],
 81            table_encrypt[1][@as(u8, @truncate(s1 >> 8))],
 82            table_encrypt[2][@as(u8, @truncate(s2 >> 16))],
 83            table_encrypt[3][@as(u8, @truncate(s3 >> 24))],
 84        };
 85        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
 86        x = .{
 87            table_encrypt[0][@as(u8, @truncate(s1))],
 88            table_encrypt[1][@as(u8, @truncate(s2 >> 8))],
 89            table_encrypt[2][@as(u8, @truncate(s3 >> 16))],
 90            table_encrypt[3][@as(u8, @truncate(s0 >> 24))],
 91        };
 92        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
 93        x = .{
 94            table_encrypt[0][@as(u8, @truncate(s2))],
 95            table_encrypt[1][@as(u8, @truncate(s3 >> 8))],
 96            table_encrypt[2][@as(u8, @truncate(s0 >> 16))],
 97            table_encrypt[3][@as(u8, @truncate(s1 >> 24))],
 98        };
 99        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
100        x = .{
101            table_encrypt[0][@as(u8, @truncate(s3))],
102            table_encrypt[1][@as(u8, @truncate(s0 >> 8))],
103            table_encrypt[2][@as(u8, @truncate(s1 >> 16))],
104            table_encrypt[3][@as(u8, @truncate(s2 >> 24))],
105        };
106        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
107
108        t0 ^= round_key.repr[0];
109        t1 ^= round_key.repr[1];
110        t2 ^= round_key.repr[2];
111        t3 ^= round_key.repr[3];
112
113        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
114    }
115
116    /// Encrypt a block with the last round key.
117    pub fn encryptLast(block: Block, round_key: Block) Block {
118        const s0 = block.repr[0];
119        const s1 = block.repr[1];
120        const s2 = block.repr[2];
121        const s3 = block.repr[3];
122
123        // Last round uses s-box directly and XORs to produce output.
124        var x: [4]u8 = undefined;
125        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s3 >> 24)));
126        var t0 = mem.readInt(u32, &x, .little);
127        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s0 >> 24)));
128        var t1 = mem.readInt(u32, &x, .little);
129        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s1 >> 24)));
130        var t2 = mem.readInt(u32, &x, .little);
131        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s2 >> 24)));
132        var t3 = mem.readInt(u32, &x, .little);
133
134        t0 ^= round_key.repr[0];
135        t1 ^= round_key.repr[1];
136        t2 ^= round_key.repr[2];
137        t3 ^= round_key.repr[3];
138
139        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
140    }
141
142    /// Decrypt a block with a round key.
143    pub fn decrypt(block: Block, round_key: Block) Block {
144        const s0 = block.repr[0];
145        const s1 = block.repr[1];
146        const s2 = block.repr[2];
147        const s3 = block.repr[3];
148
149        var x: [4]u32 = undefined;
150        x = table_lookup(&table_decrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s1 >> 24)));
151        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
152        x = table_lookup(&table_decrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s2 >> 24)));
153        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
154        x = table_lookup(&table_decrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s3 >> 24)));
155        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
156        x = table_lookup(&table_decrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s0 >> 24)));
157        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
158
159        t0 ^= round_key.repr[0];
160        t1 ^= round_key.repr[1];
161        t2 ^= round_key.repr[2];
162        t3 ^= round_key.repr[3];
163
164        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
165    }
166
167    /// Decrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS*
168    pub fn decryptUnprotected(block: Block, round_key: Block) Block {
169        const s0 = block.repr[0];
170        const s1 = block.repr[1];
171        const s2 = block.repr[2];
172        const s3 = block.repr[3];
173
174        var x: [4]u32 = undefined;
175        x = .{
176            table_decrypt[0][@as(u8, @truncate(s0))],
177            table_decrypt[1][@as(u8, @truncate(s3 >> 8))],
178            table_decrypt[2][@as(u8, @truncate(s2 >> 16))],
179            table_decrypt[3][@as(u8, @truncate(s1 >> 24))],
180        };
181        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
182        x = .{
183            table_decrypt[0][@as(u8, @truncate(s1))],
184            table_decrypt[1][@as(u8, @truncate(s0 >> 8))],
185            table_decrypt[2][@as(u8, @truncate(s3 >> 16))],
186            table_decrypt[3][@as(u8, @truncate(s2 >> 24))],
187        };
188        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
189        x = .{
190            table_decrypt[0][@as(u8, @truncate(s2))],
191            table_decrypt[1][@as(u8, @truncate(s1 >> 8))],
192            table_decrypt[2][@as(u8, @truncate(s0 >> 16))],
193            table_decrypt[3][@as(u8, @truncate(s3 >> 24))],
194        };
195        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
196        x = .{
197            table_decrypt[0][@as(u8, @truncate(s3))],
198            table_decrypt[1][@as(u8, @truncate(s2 >> 8))],
199            table_decrypt[2][@as(u8, @truncate(s1 >> 16))],
200            table_decrypt[3][@as(u8, @truncate(s0 >> 24))],
201        };
202        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];
203
204        t0 ^= round_key.repr[0];
205        t1 ^= round_key.repr[1];
206        t2 ^= round_key.repr[2];
207        t3 ^= round_key.repr[3];
208
209        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
210    }
211
212    /// Decrypt a block with the last round key.
213    pub fn decryptLast(block: Block, round_key: Block) Block {
214        const s0 = block.repr[0];
215        const s1 = block.repr[1];
216        const s2 = block.repr[2];
217        const s3 = block.repr[3];
218
219        // Last round uses s-box directly and XORs to produce output.
220        var x: [4]u8 = undefined;
221        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s1 >> 24)));
222        var t0 = mem.readInt(u32, &x, .little);
223        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s2 >> 24)));
224        var t1 = mem.readInt(u32, &x, .little);
225        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s3 >> 24)));
226        var t2 = mem.readInt(u32, &x, .little);
227        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s0 >> 24)));
228        var t3 = mem.readInt(u32, &x, .little);
229
230        t0 ^= round_key.repr[0];
231        t1 ^= round_key.repr[1];
232        t2 ^= round_key.repr[2];
233        t3 ^= round_key.repr[3];
234
235        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
236    }
237
238    /// Apply the bitwise XOR operation to the content of two blocks.
239    pub fn xorBlocks(block1: Block, block2: Block) Block {
240        var x: Repr = undefined;
241        comptime var i = 0;
242        inline while (i < 4) : (i += 1) {
243            x[i] = block1.repr[i] ^ block2.repr[i];
244        }
245        return Block{ .repr = x };
246    }
247
248    /// Apply the bitwise AND operation to the content of two blocks.
249    pub fn andBlocks(block1: Block, block2: Block) Block {
250        var x: Repr = undefined;
251        comptime var i = 0;
252        inline while (i < 4) : (i += 1) {
253            x[i] = block1.repr[i] & block2.repr[i];
254        }
255        return Block{ .repr = x };
256    }
257
258    /// Apply the bitwise OR operation to the content of two blocks.
259    pub fn orBlocks(block1: Block, block2: Block) Block {
260        var x: Repr = undefined;
261        comptime var i = 0;
262        inline while (i < 4) : (i += 1) {
263            x[i] = block1.repr[i] | block2.repr[i];
264        }
265        return Block{ .repr = x };
266    }
267
268    /// Apply the inverse MixColumns operation to a block.
269    pub fn invMixColumns(block: Block) Block {
270        var out: Repr = undefined;
271        inline for (0..4) |i| {
272            const col = block.repr[i];
273            const b0: u8 = @truncate(col);
274            const b1: u8 = @truncate(col >> 8);
275            const b2: u8 = @truncate(col >> 16);
276            const b3: u8 = @truncate(col >> 24);
277
278            const r0 = mul(0x0e, b0) ^ mul(0x0b, b1) ^ mul(0x0d, b2) ^ mul(0x09, b3);
279            const r1 = mul(0x09, b0) ^ mul(0x0e, b1) ^ mul(0x0b, b2) ^ mul(0x0d, b3);
280            const r2 = mul(0x0d, b0) ^ mul(0x09, b1) ^ mul(0x0e, b2) ^ mul(0x0b, b3);
281            const r3 = mul(0x0b, b0) ^ mul(0x0d, b1) ^ mul(0x09, b2) ^ mul(0x0e, b3);
282
283            out[i] = @as(u32, r0) | (@as(u32, r1) << 8) | (@as(u32, r2) << 16) | (@as(u32, r3) << 24);
284        }
285        return Block{ .repr = out };
286    }
287
288    /// Perform operations on multiple blocks in parallel.
289    pub const parallel = struct {
290        /// The recommended number of AES encryption/decryption to perform in parallel for the chosen implementation.
291        pub const optimal_parallel_blocks = 1;
292
293        /// Encrypt multiple blocks in parallel, each their own round key.
294        pub fn encryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
295            var i = 0;
296            var out: [count]Block = undefined;
297            while (i < count) : (i += 1) {
298                out[i] = blocks[i].encrypt(round_keys[i]);
299            }
300            return out;
301        }
302
303        /// Decrypt multiple blocks in parallel, each their own round key.
304        pub fn decryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
305            var i = 0;
306            var out: [count]Block = undefined;
307            while (i < count) : (i += 1) {
308                out[i] = blocks[i].decrypt(round_keys[i]);
309            }
310            return out;
311        }
312
313        /// Encrypt multiple blocks in parallel with the same round key.
314        pub fn encryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
315            var i = 0;
316            var out: [count]Block = undefined;
317            while (i < count) : (i += 1) {
318                out[i] = blocks[i].encrypt(round_key);
319            }
320            return out;
321        }
322
323        /// Decrypt multiple blocks in parallel with the same round key.
324        pub fn decryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
325            var i = 0;
326            var out: [count]Block = undefined;
327            while (i < count) : (i += 1) {
328                out[i] = blocks[i].decrypt(round_key);
329            }
330            return out;
331        }
332
333        /// Encrypt multiple blocks in parallel with the same last round key.
334        pub fn encryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
335            var i = 0;
336            var out: [count]Block = undefined;
337            while (i < count) : (i += 1) {
338                out[i] = blocks[i].encryptLast(round_key);
339            }
340            return out;
341        }
342
343        /// Decrypt multiple blocks in parallel with the same last round key.
344        pub fn decryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
345            var i = 0;
346            var out: [count]Block = undefined;
347            while (i < count) : (i += 1) {
348                out[i] = blocks[i].decryptLast(round_key);
349            }
350            return out;
351        }
352    };
353};
354
/// A fixed-size vector of AES blocks.
/// All operations are performed in parallel, using SIMD instructions when available.
///
/// In this portable implementation `native_vector_size` is 1, so every
/// operation is applied to each contained `Block` in turn.
pub fn BlockVec(comptime blocks_count: comptime_int) type {
    return struct {
        const Self = @This();

        /// The number of AES blocks the target architecture can process with a single instruction.
        pub const native_vector_size = 1;

        /// The size of the AES block vector that the target architecture can process with a single instruction, in bytes.
        pub const native_word_size = native_vector_size * 16;

        const native_words = blocks_count;

        /// Internal representation of a block vector.
        repr: [native_words]Block,

        /// Length of the block vector in bytes.
        pub const block_length: usize = blocks_count * 16;

        /// Convert a byte sequence into an internal representation.
        pub fn fromBytes(bytes: *const [blocks_count * 16]u8) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = Block.fromBytes(bytes[i * native_word_size ..][0..native_word_size]);
            }
            return out;
        }

        /// Convert the internal representation of a block vector into a byte sequence.
        pub fn toBytes(block_vec: Self) [blocks_count * 16]u8 {
            var out: [blocks_count * 16]u8 = undefined;
            for (0..native_words) |i| {
                out[i * native_word_size ..][0..native_word_size].* = block_vec.repr[i].toBytes();
            }
            return out;
        }

        /// XOR the block vector with a byte sequence.
        ///
        /// Returns `blocks_count * 16` bytes. This was previously hard-coded to `[32]u8`,
        /// and the body tried to return a `Self`, so it could not compile.
        pub fn xorBytes(block_vec: Self, bytes: *const [blocks_count * 16]u8) [blocks_count * 16]u8 {
            var out: [blocks_count * 16]u8 = undefined;
            for (0..native_words) |i| {
                const offset = i * native_word_size;
                out[offset..][0..native_word_size].* = block_vec.repr[i].xorBytes(bytes[offset..][0..native_word_size]);
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of round keys.
        pub fn encrypt(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].encrypt(round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of last round keys.
        pub fn encryptLast(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].encryptLast(round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of round keys.
        pub fn decrypt(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].decrypt(inv_round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of last round keys.
        pub fn decryptLast(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].decryptLast(inv_round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise XOR operation to the content of two block vectors.
        pub fn xorBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].xorBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise AND operation to the content of two block vectors.
        pub fn andBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].andBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise OR operation to the content of two block vectors.
        ///
        /// The second operand is a `Self`, matching `xorBlocks`/`andBlocks`. It was
        /// previously typed as `Block`, which made `block_vec2.repr[i]` a `u32`
        /// and the call a type error.
        pub fn orBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].orBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the inverse MixColumns operation to each block in the vector.
        pub fn invMixColumns(block_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].invMixColumns();
            }
            return out;
        }
    };
}
475
fn KeySchedule(comptime Aes: type) type {
    std.debug.assert(Aes.rounds == 10 or Aes.rounds == 14);
    const key_length = Aes.key_bits / 8;
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        const words_in_key = key_length / 4;

        /// One key per round plus the initial whitening key, stored in the
        /// little-endian `Block` word layout. Index 0 is applied first.
        round_keys: [rounds + 1]Block,

        // Key expansion algorithm. See FIPS-197, Figure 11.
        fn expandKey(key: [key_length]u8) Self {
            const subw = struct {
                // SubWord: apply the forward S-box to each byte in w. This uses the
                // key-schedule copy of the table (sbox_key_schedule), not sbox_encrypt.
                fn func(w: u32) u32 {
                    const x = sbox_lookup(&sbox_key_schedule, @as(u8, @truncate(w)), @as(u8, @truncate(w >> 8)), @as(u8, @truncate(w >> 16)), @as(u8, @truncate(w >> 24)));
                    return mem.readInt(u32, &x, .little);
                }
            }.func;

            var round_keys: [rounds + 1]Block = undefined;

            // The first `words_in_key` words are the key itself, read big-endian as in FIPS-197.
            inline for (0..words_in_key) |i| {
                round_keys[i / 4].repr[i % 4] = mem.readInt(u32, key[4 * i ..][0..4], .big);
            }

            // Every remaining word is derived from the previous word and the word
            // `words_in_key` positions back.
            inline for (words_in_key..round_keys.len * 4) |i| {
                var t = round_keys[(i - 1) / 4].repr[(i - 1) % 4];
                if (i % words_in_key == 0) {
                    // RotWord, SubWord, then XOR with the round constant in the top byte.
                    t = subw(std.math.rotl(u32, t, 8)) ^ (@as(u32, powx[i / words_in_key - 1]) << 24);
                } else if (words_in_key > 6 and i % words_in_key == 4) {
                    // AES-256 applies an extra SubWord halfway through each key-length stride.
                    t = subw(t);
                }
                round_keys[i / 4].repr[i % 4] = round_keys[(i - words_in_key) / 4].repr[(i - words_in_key) % 4] ^ t;
            }

            // The expansion works on big-endian words; Block words are little-endian.
            for (&round_keys) |*round_key| {
                for (&round_key.repr) |*word| {
                    word.* = @byteSwap(word.*);
                }
            }
            return Self{ .round_keys = round_keys };
        }

        /// Invert the key schedule.
        ///
        /// The round keys are reversed. Every key except the first and last
        /// decryption key also gets InvMixColumns applied, as required by the
        /// equivalent inverse cipher that `Block.decrypt` implements. That transform is
        /// computed by a forward S-box lookup followed by a `table_decrypt` lookup,
        /// whose inverse S-box cancels the forward one. (This assumes `table_lookup`
        /// returns `{t[0][a0], t[1][a1], t[2][a2], t[3][a3]}`, as the direct indexing in
        /// `Block.decryptUnprotected` suggests.)
        pub fn invert(key_schedule: Self) Self {
            var inv_round_keys: [rounds + 1]Block = undefined;
            for (&inv_round_keys, 0..) |*inv_round_key, r| {
                const round_key = key_schedule.round_keys[rounds - r];
                inline for (0..4) |j| {
                    var rk = round_key.repr[j];
                    if (r > 0 and r < rounds) {
                        const x = sbox_lookup(&sbox_key_schedule, @as(u8, @truncate(rk >> 24)), @as(u8, @truncate(rk >> 16)), @as(u8, @truncate(rk >> 8)), @as(u8, @truncate(rk)));
                        const y = table_lookup(&table_decrypt, x[3], x[2], x[1], x[0]);
                        rk = y[0] ^ y[1] ^ y[2] ^ y[3];
                    }
                    inv_round_key.repr[j] = rk;
                }
            }
            return Self{ .round_keys = inv_round_keys };
        }
    };
}
541
/// A context to perform encryption using the standard AES key schedule.
pub fn AesEncryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;

        /// Expanded round keys derived from the cipher key.
        key_schedule: KeySchedule(Aes),

        /// Create a new encryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            return .{ .key_schedule = KeySchedule(Aes).expandKey(key) };
        }

        /// Run the whole AES encryption round sequence over `input`.
        ///
        /// With full side-channel mitigations, every middle round uses
        /// `Block.encrypt`. Otherwise rounds 1-4 and the second-to-last round
        /// use `Block.encrypt`, and the rounds in between use
        /// `Block.encryptUnprotected`.
        fn encryptState(ctx: Self, input: Block) Block {
            const keys = ctx.key_schedule.round_keys;
            var state = input.xorBlocks(keys[0]);
            if (side_channels_mitigations == .full) {
                inline for (1..rounds) |r| {
                    state = state.encrypt(keys[r]);
                }
            } else {
                inline for (1..5) |r| {
                    state = state.encrypt(keys[r]);
                }
                inline for (5..rounds - 1) |r| {
                    state = state.encryptUnprotected(keys[r]);
                }
                state = state.encrypt(keys[rounds - 1]);
            }
            return state.encryptLast(keys[rounds]);
        }

        /// Encrypt a single block.
        pub fn encrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            const ciphertext = ctx.encryptState(Block.fromBytes(src));
            dst.* = ciphertext.toBytes();
        }

        /// Encrypt+XOR a single block.
        ///
        /// `counter` is encrypted and the resulting keystream block is XORed into `src`.
        pub fn xor(ctx: Self, dst: *[16]u8, src: *const [16]u8, counter: [16]u8) void {
            const keystream = ctx.encryptState(Block.fromBytes(&counter));
            dst.* = keystream.xorBytes(src);
        }

        /// Encrypt multiple blocks, possibly leveraging parallelization.
        pub fn encryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            for (0..count) |n| {
                const offset = 16 * n;
                ctx.encrypt(dst[offset..][0..16], src[offset..][0..16]);
            }
        }

        /// Encrypt+XOR multiple blocks, possibly leveraging parallelization.
        pub fn xorWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8, counters: [16 * count]u8) void {
            for (0..count) |n| {
                const offset = 16 * n;
                ctx.xor(dst[offset..][0..16], src[offset..][0..16], counters[offset..][0..16].*);
            }
        }
    };
}
622
/// A context to perform decryption using the standard AES key schedule.
pub fn AesDecryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;

        /// Inverted round keys, obtained with `KeySchedule.invert`.
        key_schedule: KeySchedule(Aes),

        /// Create a decryption context from an existing encryption context.
        pub fn initFromEnc(ctx: AesEncryptCtx(Aes)) Self {
            return .{ .key_schedule = ctx.key_schedule.invert() };
        }

        /// Create a new decryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            return initFromEnc(AesEncryptCtx(Aes).init(key));
        }

        /// Run the whole AES decryption round sequence over `input`.
        ///
        /// With full side-channel mitigations, every middle round uses
        /// `Block.decrypt`. Otherwise rounds 1-4 and the second-to-last round
        /// use `Block.decrypt`, and the rounds in between use
        /// `Block.decryptUnprotected`.
        fn decryptState(ctx: Self, input: Block) Block {
            const keys = ctx.key_schedule.round_keys;
            var state = input.xorBlocks(keys[0]);
            if (side_channels_mitigations == .full) {
                inline for (1..rounds) |r| {
                    state = state.decrypt(keys[r]);
                }
            } else {
                inline for (1..5) |r| {
                    state = state.decrypt(keys[r]);
                }
                inline for (5..rounds - 1) |r| {
                    state = state.decryptUnprotected(keys[r]);
                }
                state = state.decrypt(keys[rounds - 1]);
            }
            return state.decryptLast(keys[rounds]);
        }

        /// Decrypt a single block.
        pub fn decrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            const plaintext = ctx.decryptState(Block.fromBytes(src));
            dst.* = plaintext.toBytes();
        }

        /// Decrypt multiple blocks, possibly leveraging parallelization.
        pub fn decryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            for (0..count) |n| {
                const offset = 16 * n;
                ctx.decrypt(dst[offset..][0..16], src[offset..][0..16]);
            }
        }
    };
}
678
/// AES-128 with the standard key schedule.
pub const Aes128 = struct {
    const Self = @This();
    const Key = [key_bits / 8]u8;

    pub const key_bits: usize = 128;
    /// Round count, key_bits / 32 + 6 (10 rounds for a 128-bit key).
    pub const rounds = key_bits / 32 + 6;
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: Key) AesEncryptCtx(Self) {
        return AesEncryptCtx(Self).init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: Key) AesDecryptCtx(Self) {
        return AesDecryptCtx(Self).init(key);
    }
};
695
/// AES-256 with the standard key schedule.
pub const Aes256 = struct {
    const Self = @This();
    const Key = [key_bits / 8]u8;

    pub const key_bits: usize = 256;
    /// Round count, key_bits / 32 + 6 (14 rounds for a 256-bit key).
    pub const rounds = key_bits / 32 + 6;
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: Key) AesEncryptCtx(Self) {
        return AesEncryptCtx(Self).init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: Key) AesDecryptCtx(Self) {
        return AesDecryptCtx(Self).init(key);
    }
};
712
// constants

// Rijndael's irreducible polynomial x⁸ + x⁴ + x³ + x + 1 (bits 8, 4, 3, 1 and 0).
const poly: u9 = 0x11b;

// Successive powers of x modulo poly in GF(2⁸): powx[k] = xᵏ.
// The key schedule uses them as round constants.
const powx = blk: {
    var powers: [16]u8 = undefined;
    var x: u8 = 1;
    for (0..powers.len) |k| {
        powers[k] = x;
        x = mul(x, 2);
    }
    break :blk powers;
};
730
// S-boxes, used for the last rounds and for the key schedule. The key schedule
// gets its own copy of the forward S-box so that its lookups occupy distinct L1
// cache entries from the ones used by the S-box during encryption.
const sbox_encrypt align(64) = generateSbox(false);
const sbox_decrypt align(64) = generateSbox(true);
const sbox_key_schedule align(64) = generateSbox(false);

// Four rotated 4-byte lookup tables per direction, used by the main rounds.
const table_encrypt align(64) = generateTable(false);
const table_decrypt align(64) = generateTable(true);
736
// Build the AES S-box, or its inverse when `invert` is true.
//
// The loop walks the multiplicative group of GF(2⁸) using 3 as a generator.
// `elem` steps through the powers of 3, while `elem_inv` steps through the
// powers of 3⁻¹ (0xf6), so `elem_inv` is always the multiplicative inverse of
// `elem`. The affine transform applied to `elem_inv` gives S(elem). Zero has no
// inverse, so its entry is patched in after the loop.
fn generateSbox(invert: bool) [256]u8 {
    @setEvalBranchQuota(10000);

    var out: [256]u8 = undefined;
    var elem: u8 = 1;
    var elem_inv: u8 = 1;

    for (0..out.len) |_| {
        elem = mul(elem, 3);
        elem_inv = mul(elem_inv, 0xf6);

        // Affine transform: b ^ rotl(b,1) ^ rotl(b,2) ^ rotl(b,3) ^ rotl(b,4) ^ 0x63.
        const affine: u8 = elem_inv ^ 0x63 ^
            math.rotl(u8, elem_inv, 1) ^
            math.rotl(u8, elem_inv, 2) ^
            math.rotl(u8, elem_inv, 3) ^
            math.rotl(u8, elem_inv, 4);

        if (invert) out[affine] = elem else out[elem] = affine;
    }

    // S(0) = 0x63, because the affine transform of 0 is 0x63.
    if (invert) out[0x63] = 0x00 else out[0x00] = 0x63;

    return out;
}
770
// Build the four combined S-box/MixColumns lookup tables for encryption, or
// the inverse S-box/InvMixColumns tables for decryption when `invert` is true.
//
// Each table[0] entry packs the S-box output multiplied by the four
// (Inv)MixColumns coefficients into one u32. table[1], table[2] and table[3]
// hold the same entries rotated left by 8, 16 and 24 bits.
fn generateTable(invert: bool) [4][256]u32 {
    @setEvalBranchQuota(50000);

    // (Inv)MixColumns coefficients, ordered from the most to the least
    // significant byte of each packed entry.
    const coeffs: [4]u8 = if (invert) .{ 0xb, 0xd, 0x9, 0xe } else .{ 0x3, 0x1, 0x1, 0x2 };

    var table: [4][256]u32 = undefined;

    for (generateSbox(invert), 0..) |s, idx| {
        const packed_entry: u32 = (@as(u32, mul(s, coeffs[0])) << 24) |
            (@as(u32, mul(s, coeffs[1])) << 16) |
            (@as(u32, mul(s, coeffs[2])) << 8) |
            @as(u32, mul(s, coeffs[3]));

        table[0][idx] = packed_entry;
        table[1][idx] = math.rotl(u32, packed_entry, 8);
        table[2][idx] = math.rotl(u32, packed_entry, 16);
        table[3][idx] = math.rotl(u32, packed_entry, 24);
    }

    return table;
}
790
// Multiply `a` and `b` in GF(2⁸). Each byte is treated as a polynomial over
// GF(2); the function returns their carry-less (XOR) product reduced modulo `poly`.
fn mul(a: u8, b: u8) u8 {
    @setEvalBranchQuota(30000);

    var acc: u9 = 0;
    var addend: u9 = b;
    var bits: u8 = a;

    while (bits != 0) {
        // Add (XOR) the current multiple of `b` when this bit of `a` is set.
        if (bits & 1 == 1) acc ^= addend;

        // Multiply by x. `addend` is < 0x100 here, so nothing is shifted out of
        // the u9. If the product reaches degree 8, reduce it.
        addend <<= 1;
        if (addend >> 8 != 0) addend ^= poly;

        bits >>= 1;
    }

    return @truncate(acc);
}
812
// Estimated cache line size in bytes, taken from `std.atomic.cache_line`. It is
// used to choose the row stride of the masked lookups in `sbox_lookup` and
// `table_lookup`.
const cache_line_bytes = std.atomic.cache_line;
814
// Look up four S-box entries, `sbox[idx0]` through `sbox[idx3]`, returned in
// that order.
//
// With `side_channels_mitigations == .none` this is four direct reads.
// Otherwise the S-box is treated as `sbox.len / stride` rows of `stride` bytes.
// For each index, the byte at the same in-row offset (`idx % stride`) is copied
// from every row into the scratch array `t`. The requested byte is then picked
// from `t` by row number (`idx / stride`). Every row of `sbox` is read no matter
// what the indices are; only the offset within each row depends on them.
// Smaller strides mean more rows and more reads:
//   .basic  -> 4 rows
//   .medium -> rows of two cache lines
//   .full   -> rows of one cache line
//
// NOTE(review): this assumes `stride` divides `sbox.len` (true for
// power-of-two cache line sizes). Otherwise `idx / stride` could exceed
// `t[0].len`.
fn sbox_lookup(sbox: *align(64) const [256]u8, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u8 {
    if (side_channels_mitigations == .none) {
        return [4]u8{
            sbox[idx0],
            sbox[idx1],
            sbox[idx2],
            sbox[idx3],
        };
    } else {
        // Row length in bytes; comptime-known, since the mitigation level is a build option.
        const stride = switch (side_channels_mitigations) {
            .none => unreachable,
            .basic => sbox.len / 4,
            .medium => @min(sbox.len, 2 * cache_line_bytes),
            .full => @min(sbox.len, cache_line_bytes),
        };
        // Offsets within a row for each requested index.
        const of0 = idx0 % stride;
        const of1 = idx1 % stride;
        const of2 = idx2 % stride;
        const of3 = idx3 % stride;
        // t[k][i] = sbox[i * stride + ofk] for every row i.
        var t: [4][sbox.len / stride]u8 align(64) = undefined;
        var i: usize = 0;
        while (i < t[0].len) : (i += 1) {
            const tx = sbox[i * stride ..];
            t[0][i] = tx[of0];
            t[1][i] = tx[of1];
            t[2][i] = tx[of2];
            t[3][i] = tx[of3];
        }
        // Force the full scan to be kept even though most of `t` is never used.
        std.mem.doNotOptimizeAway(t);
        // Select the entry from the row that actually holds each index.
        return [4]u8{
            t[0][idx0 / stride],
            t[1][idx1 / stride],
            t[2][idx2 / stride],
            t[3][idx3 / stride],
        };
    }
}
852
// Look up `table[0][idx0]`, `table[1][idx1]`, `table[2][idx2]` and
// `table[3][idx3]`, returned in that order.
//
// With `side_channels_mitigations == .none` this is four direct reads.
// Otherwise only `table[0]` is scanned. It is treated as `table_len / stride`
// rows of `stride` u32 entries, and for each index the entry at the same
// in-row offset is copied from every row into `t`. The entries for tables 1-3
// are then recovered by rotating left by 8, 16 and 24 bits. This relies on
// `table[k]` being `table[0]` rotated left by 8*k bits, which is how
// `generateTable` builds it (assumes `table` came from `generateTable`).
// Strides are counted in u32 entries, so `cache_line_bytes / 4` entries make
// up one cache line. `@max(1, ...)` keeps the stride non-zero.
//
// NOTE(review): this assumes `stride` divides `table_len` (true for
// power-of-two cache line sizes). Otherwise `idx / stride` could exceed
// `t[0].len`.
fn table_lookup(table: *align(64) const [4][256]u32, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u32 {
    if (side_channels_mitigations == .none) {
        return [4]u32{
            table[0][idx0],
            table[1][idx1],
            table[2][idx2],
            table[3][idx3],
        };
    } else {
        const table_len: usize = 256;
        // Row length in u32 entries; comptime-known, since the mitigation level is a build option.
        const stride = switch (side_channels_mitigations) {
            .none => unreachable,
            .basic => table_len / 4,
            .medium => @max(1, @min(table_len, 2 * cache_line_bytes / 4)),
            .full => @max(1, @min(table_len, cache_line_bytes / 4)),
        };
        // Offsets within a row for each requested index.
        const of0 = idx0 % stride;
        const of1 = idx1 % stride;
        const of2 = idx2 % stride;
        const of3 = idx3 % stride;
        // t[k][i] = table[0][i * stride + ofk] for every row i; only table[0] is touched.
        var t: [4][table_len / stride]u32 align(64) = undefined;
        var i: usize = 0;
        while (i < t[0].len) : (i += 1) {
            const tx = table[0][i * stride ..];
            t[0][i] = tx[of0];
            t[1][i] = tx[of1];
            t[2][i] = tx[of2];
            t[3][i] = tx[of3];
        }
        // Force the full scan to be kept even though most of `t` is never used.
        std.mem.doNotOptimizeAway(t);
        // Select the entry for each index, then rotate to obtain the table[1..3] values.
        return [4]u32{
            t[0][idx0 / stride],
            math.rotl(u32, (&t[1])[idx1 / stride], 8),
            math.rotl(u32, (&t[2])[idx2 / stride], 16),
            math.rotl(u32, (&t[3])[idx3 / stride], 24),
        };
    }
}