master
1const std = @import("../../std.zig");
2const math = std.math;
3const mem = std.mem;
4
5const side_channels_mitigations = std.options.side_channels_mitigations;
6
/// A single AES block.
///
/// This is the portable, table-based implementation. The 16-byte state is kept as
/// four 32-bit words, one per column, each read from the bytes in little-endian order.
pub const Block = struct {
    const Repr = [4]u32;

    pub const block_length: usize = 16;

    /// Internal representation of a block.
    /// `repr[i]` holds bytes `4 * i .. 4 * i + 4` of the block as a little-endian word.
    repr: Repr align(16),

    /// Convert a byte sequence into an internal representation.
    pub fn fromBytes(bytes: *const [16]u8) Block {
        const s0 = mem.readInt(u32, bytes[0..4], .little);
        const s1 = mem.readInt(u32, bytes[4..8], .little);
        const s2 = mem.readInt(u32, bytes[8..12], .little);
        const s3 = mem.readInt(u32, bytes[12..16], .little);
        return Block{ .repr = Repr{ s0, s1, s2, s3 } };
    }

    /// Convert the internal representation of a block into a byte sequence.
    pub fn toBytes(block: Block) [16]u8 {
        var bytes: [16]u8 = undefined;
        mem.writeInt(u32, bytes[0..4], block.repr[0], .little);
        mem.writeInt(u32, bytes[4..8], block.repr[1], .little);
        mem.writeInt(u32, bytes[8..12], block.repr[2], .little);
        mem.writeInt(u32, bytes[12..16], block.repr[3], .little);
        return bytes;
    }

    /// XOR the block with a byte sequence.
    pub fn xorBytes(block: Block, bytes: *const [16]u8) [16]u8 {
        const block_bytes = block.toBytes();
        var x: [16]u8 = undefined;
        comptime var i: usize = 0;
        inline while (i < 16) : (i += 1) {
            x[i] = block_bytes[i] ^ bytes[i];
        }
        return x;
    }

    /// Encrypt a block with a round key.
    ///
    /// One full AES round: the column bytes are selected in ShiftRows order and looked
    /// up in `table_encrypt` (S-box combined with MixColumns) through `table_lookup`,
    /// which applies the configured side-channel mitigations. `round_key` is XORed last.
    pub fn encrypt(block: Block, round_key: Block) Block {
        const s0 = block.repr[0];
        const s1 = block.repr[1];
        const s2 = block.repr[2];
        const s3 = block.repr[3];

        var x: [4]u32 = undefined;
        x = table_lookup(&table_encrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s3 >> 24)));
        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = table_lookup(&table_encrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s0 >> 24)));
        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = table_lookup(&table_encrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s1 >> 24)));
        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = table_lookup(&table_encrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s2 >> 24)));
        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];

        t0 ^= round_key.repr[0];
        t1 ^= round_key.repr[1];
        t2 ^= round_key.repr[2];
        t3 ^= round_key.repr[3];

        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
    }

    /// Encrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS*
    ///
    /// Same round as `encrypt`, but indexes `table_encrypt` directly regardless of
    /// `side_channels_mitigations`.
    pub fn encryptUnprotected(block: Block, round_key: Block) Block {
        const s0 = block.repr[0];
        const s1 = block.repr[1];
        const s2 = block.repr[2];
        const s3 = block.repr[3];

        var x: [4]u32 = undefined;
        x = .{
            table_encrypt[0][@as(u8, @truncate(s0))],
            table_encrypt[1][@as(u8, @truncate(s1 >> 8))],
            table_encrypt[2][@as(u8, @truncate(s2 >> 16))],
            table_encrypt[3][@as(u8, @truncate(s3 >> 24))],
        };
        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = .{
            table_encrypt[0][@as(u8, @truncate(s1))],
            table_encrypt[1][@as(u8, @truncate(s2 >> 8))],
            table_encrypt[2][@as(u8, @truncate(s3 >> 16))],
            table_encrypt[3][@as(u8, @truncate(s0 >> 24))],
        };
        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = .{
            table_encrypt[0][@as(u8, @truncate(s2))],
            table_encrypt[1][@as(u8, @truncate(s3 >> 8))],
            table_encrypt[2][@as(u8, @truncate(s0 >> 16))],
            table_encrypt[3][@as(u8, @truncate(s1 >> 24))],
        };
        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = .{
            table_encrypt[0][@as(u8, @truncate(s3))],
            table_encrypt[1][@as(u8, @truncate(s0 >> 8))],
            table_encrypt[2][@as(u8, @truncate(s1 >> 16))],
            table_encrypt[3][@as(u8, @truncate(s2 >> 24))],
        };
        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];

        t0 ^= round_key.repr[0];
        t1 ^= round_key.repr[1];
        t2 ^= round_key.repr[2];
        t3 ^= round_key.repr[3];

        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
    }

    /// Encrypt a block with the last round key.
    ///
    /// The final round has no MixColumns: bytes are picked in ShiftRows order,
    /// substituted through `sbox_encrypt` via `sbox_lookup`, then XORed with `round_key`.
    pub fn encryptLast(block: Block, round_key: Block) Block {
        const s0 = block.repr[0];
        const s1 = block.repr[1];
        const s2 = block.repr[2];
        const s3 = block.repr[3];

        // Last round uses s-box directly and XORs to produce output.
        var x: [4]u8 = undefined;
        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s3 >> 24)));
        var t0 = mem.readInt(u32, &x, .little);
        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s0 >> 24)));
        var t1 = mem.readInt(u32, &x, .little);
        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s1 >> 24)));
        var t2 = mem.readInt(u32, &x, .little);
        x = sbox_lookup(&sbox_encrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s2 >> 24)));
        var t3 = mem.readInt(u32, &x, .little);

        t0 ^= round_key.repr[0];
        t1 ^= round_key.repr[1];
        t2 ^= round_key.repr[2];
        t3 ^= round_key.repr[3];

        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
    }

    /// Decrypt a block with a round key.
    ///
    /// One inverse round: bytes are picked in inverse-ShiftRows order and looked up in
    /// `table_decrypt` (inverse S-box combined with InvMixColumns) through
    /// `table_lookup`; `round_key` is XORed last. The decryption contexts pass keys
    /// from the inverted key schedule.
    pub fn decrypt(block: Block, round_key: Block) Block {
        const s0 = block.repr[0];
        const s1 = block.repr[1];
        const s2 = block.repr[2];
        const s3 = block.repr[3];

        var x: [4]u32 = undefined;
        x = table_lookup(&table_decrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s1 >> 24)));
        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = table_lookup(&table_decrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s2 >> 24)));
        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = table_lookup(&table_decrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s3 >> 24)));
        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = table_lookup(&table_decrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s0 >> 24)));
        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];

        t0 ^= round_key.repr[0];
        t1 ^= round_key.repr[1];
        t2 ^= round_key.repr[2];
        t3 ^= round_key.repr[3];

        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
    }

    /// Decrypt a block with a round key *WITHOUT ANY PROTECTION AGAINST SIDE CHANNELS*
    ///
    /// Same round as `decrypt`, but indexes `table_decrypt` directly regardless of
    /// `side_channels_mitigations`.
    pub fn decryptUnprotected(block: Block, round_key: Block) Block {
        const s0 = block.repr[0];
        const s1 = block.repr[1];
        const s2 = block.repr[2];
        const s3 = block.repr[3];

        var x: [4]u32 = undefined;
        x = .{
            table_decrypt[0][@as(u8, @truncate(s0))],
            table_decrypt[1][@as(u8, @truncate(s3 >> 8))],
            table_decrypt[2][@as(u8, @truncate(s2 >> 16))],
            table_decrypt[3][@as(u8, @truncate(s1 >> 24))],
        };
        var t0 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = .{
            table_decrypt[0][@as(u8, @truncate(s1))],
            table_decrypt[1][@as(u8, @truncate(s0 >> 8))],
            table_decrypt[2][@as(u8, @truncate(s3 >> 16))],
            table_decrypt[3][@as(u8, @truncate(s2 >> 24))],
        };
        var t1 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = .{
            table_decrypt[0][@as(u8, @truncate(s2))],
            table_decrypt[1][@as(u8, @truncate(s1 >> 8))],
            table_decrypt[2][@as(u8, @truncate(s0 >> 16))],
            table_decrypt[3][@as(u8, @truncate(s3 >> 24))],
        };
        var t2 = x[0] ^ x[1] ^ x[2] ^ x[3];
        x = .{
            table_decrypt[0][@as(u8, @truncate(s3))],
            table_decrypt[1][@as(u8, @truncate(s2 >> 8))],
            table_decrypt[2][@as(u8, @truncate(s1 >> 16))],
            table_decrypt[3][@as(u8, @truncate(s0 >> 24))],
        };
        var t3 = x[0] ^ x[1] ^ x[2] ^ x[3];

        t0 ^= round_key.repr[0];
        t1 ^= round_key.repr[1];
        t2 ^= round_key.repr[2];
        t3 ^= round_key.repr[3];

        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
    }

    /// Decrypt a block with the last round key.
    ///
    /// The final inverse round has no InvMixColumns: bytes are picked in
    /// inverse-ShiftRows order, substituted through `sbox_decrypt` via `sbox_lookup`,
    /// then XORed with `round_key`.
    pub fn decryptLast(block: Block, round_key: Block) Block {
        const s0 = block.repr[0];
        const s1 = block.repr[1];
        const s2 = block.repr[2];
        const s3 = block.repr[3];

        // Last round uses s-box directly and XORs to produce output.
        var x: [4]u8 = undefined;
        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s0)), @as(u8, @truncate(s3 >> 8)), @as(u8, @truncate(s2 >> 16)), @as(u8, @truncate(s1 >> 24)));
        var t0 = mem.readInt(u32, &x, .little);
        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s1)), @as(u8, @truncate(s0 >> 8)), @as(u8, @truncate(s3 >> 16)), @as(u8, @truncate(s2 >> 24)));
        var t1 = mem.readInt(u32, &x, .little);
        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s2)), @as(u8, @truncate(s1 >> 8)), @as(u8, @truncate(s0 >> 16)), @as(u8, @truncate(s3 >> 24)));
        var t2 = mem.readInt(u32, &x, .little);
        x = sbox_lookup(&sbox_decrypt, @as(u8, @truncate(s3)), @as(u8, @truncate(s2 >> 8)), @as(u8, @truncate(s1 >> 16)), @as(u8, @truncate(s0 >> 24)));
        var t3 = mem.readInt(u32, &x, .little);

        t0 ^= round_key.repr[0];
        t1 ^= round_key.repr[1];
        t2 ^= round_key.repr[2];
        t3 ^= round_key.repr[3];

        return Block{ .repr = Repr{ t0, t1, t2, t3 } };
    }

    /// Apply the bitwise XOR operation to the content of two blocks.
    pub fn xorBlocks(block1: Block, block2: Block) Block {
        var x: Repr = undefined;
        comptime var i = 0;
        inline while (i < 4) : (i += 1) {
            x[i] = block1.repr[i] ^ block2.repr[i];
        }
        return Block{ .repr = x };
    }

    /// Apply the bitwise AND operation to the content of two blocks.
    pub fn andBlocks(block1: Block, block2: Block) Block {
        var x: Repr = undefined;
        comptime var i = 0;
        inline while (i < 4) : (i += 1) {
            x[i] = block1.repr[i] & block2.repr[i];
        }
        return Block{ .repr = x };
    }

    /// Apply the bitwise OR operation to the content of two blocks.
    pub fn orBlocks(block1: Block, block2: Block) Block {
        var x: Repr = undefined;
        comptime var i = 0;
        inline while (i < 4) : (i += 1) {
            x[i] = block1.repr[i] | block2.repr[i];
        }
        return Block{ .repr = x };
    }

    // Multiply `v` by x (0x02) in GF(2⁸) modulo x⁸ + x⁴ + x³ + x + 1 without a
    // data-dependent branch: the reduction constant 0x1b is masked in when the top
    // bit of `v` is set.
    fn xtime(v: u8) u8 {
        return (v << 1) ^ (0x1b & (@as(u8, 0) -% (v >> 7)));
    }

    // Products of a byte with the four InvMixColumns coefficients.
    const InvMixProducts = struct { x9: u8, xb: u8, xd: u8, xe: u8 };

    // Compute `v` multiplied by 0x09, 0x0b, 0x0d and 0x0e using only doublings and XORs.
    fn invMixProducts(v: u8) InvMixProducts {
        const v2 = xtime(v);
        const v4 = xtime(v2);
        const v8 = xtime(v4);
        return .{
            .x9 = v8 ^ v,
            .xb = v8 ^ v2 ^ v,
            .xd = v8 ^ v4 ^ v,
            .xe = v8 ^ v4 ^ v2,
        };
    }

    /// Apply the inverse MixColumns operation to a block.
    ///
    /// The GF(2⁸) products are computed with branch-free doublings (`xtime`) rather
    /// than the generic `mul`, which branches on its operand, so the running time does
    /// not depend on the block contents.
    pub fn invMixColumns(block: Block) Block {
        var out: Repr = undefined;
        inline for (0..4) |i| {
            const col = block.repr[i];
            const p0 = invMixProducts(@truncate(col));
            const p1 = invMixProducts(@truncate(col >> 8));
            const p2 = invMixProducts(@truncate(col >> 16));
            const p3 = invMixProducts(@truncate(col >> 24));

            const r0 = p0.xe ^ p1.xb ^ p2.xd ^ p3.x9;
            const r1 = p0.x9 ^ p1.xe ^ p2.xb ^ p3.xd;
            const r2 = p0.xd ^ p1.x9 ^ p2.xe ^ p3.xb;
            const r3 = p0.xb ^ p1.xd ^ p2.x9 ^ p3.xe;

            out[i] = @as(u32, r0) | (@as(u32, r1) << 8) | (@as(u32, r2) << 16) | (@as(u32, r3) << 24);
        }
        return Block{ .repr = out };
    }

    /// Perform operations on multiple blocks in parallel.
    ///
    /// This portable implementation has no real parallelism: every helper processes
    /// the blocks one after another.
    pub const parallel = struct {
        /// The recommended number of AES encryption/decryption to perform in parallel for the chosen implementation.
        pub const optimal_parallel_blocks = 1;

        /// Encrypt multiple blocks in parallel, each their own round key.
        pub fn encryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
            var out: [count]Block = undefined;
            for (&out, blocks, round_keys) |*dst, src, round_key| {
                dst.* = src.encrypt(round_key);
            }
            return out;
        }

        /// Decrypt multiple blocks in parallel, each their own round key.
        pub fn decryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
            var out: [count]Block = undefined;
            for (&out, blocks, round_keys) |*dst, src, round_key| {
                dst.* = src.decrypt(round_key);
            }
            return out;
        }

        /// Encrypt multiple blocks in parallel with the same round key.
        pub fn encryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var out: [count]Block = undefined;
            for (&out, blocks) |*dst, src| {
                dst.* = src.encrypt(round_key);
            }
            return out;
        }

        /// Decrypt multiple blocks in parallel with the same round key.
        pub fn decryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var out: [count]Block = undefined;
            for (&out, blocks) |*dst, src| {
                dst.* = src.decrypt(round_key);
            }
            return out;
        }

        /// Encrypt multiple blocks in parallel with the same last round key.
        pub fn encryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var out: [count]Block = undefined;
            for (&out, blocks) |*dst, src| {
                dst.* = src.encryptLast(round_key);
            }
            return out;
        }

        /// Decrypt multiple blocks in parallel with the same last round key.
        pub fn decryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var out: [count]Block = undefined;
            for (&out, blocks) |*dst, src| {
                dst.* = src.decryptLast(round_key);
            }
            return out;
        }
    };
};
354
/// A fixed-size vector of AES blocks.
/// All operations are performed in parallel, using SIMD instructions when available.
///
/// In this portable implementation the vector is an array of `blocks_count` `Block`s,
/// and every operation loops over them, applying the matching `Block` operation.
pub fn BlockVec(comptime blocks_count: comptime_int) type {
    return struct {
        const Self = @This();

        /// The number of AES blocks the target architecture can process with a single instruction.
        pub const native_vector_size = 1;

        /// The size of the AES block vector that the target architecture can process with a single instruction, in bytes.
        pub const native_word_size = native_vector_size * 16;

        const native_words = blocks_count;

        /// Internal representation of a block vector.
        repr: [native_words]Block,

        /// Length of the block vector in bytes.
        pub const block_length: usize = blocks_count * 16;

        /// Convert a byte sequence into an internal representation.
        pub fn fromBytes(bytes: *const [blocks_count * 16]u8) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = Block.fromBytes(bytes[i * native_word_size ..][0..native_word_size]);
            }
            return out;
        }

        /// Convert the internal representation of a block vector into a byte sequence.
        pub fn toBytes(block_vec: Self) [blocks_count * 16]u8 {
            var out: [blocks_count * 16]u8 = undefined;
            for (0..native_words) |i| {
                out[i * native_word_size ..][0..native_word_size].* = block_vec.repr[i].toBytes();
            }
            return out;
        }

        /// XOR the block vector with a byte sequence.
        ///
        /// Returns `blocks_count * 16` bytes. The previous declaration returned `[32]u8`
        /// while producing a `Self`, so it could not compile; for two blocks the
        /// return type is unchanged.
        pub fn xorBytes(block_vec: Self, bytes: *const [blocks_count * 16]u8) [blocks_count * 16]u8 {
            var out: [blocks_count * 16]u8 = undefined;
            for (0..native_words) |i| {
                const offset = i * native_word_size;
                out[offset..][0..native_word_size].* = block_vec.repr[i].xorBytes(bytes[offset..][0..native_word_size]);
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of round keys.
        pub fn encrypt(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].encrypt(round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of last round keys.
        pub fn encryptLast(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].encryptLast(round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of round keys.
        pub fn decrypt(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].decrypt(inv_round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of last round keys.
        pub fn decryptLast(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].decryptLast(inv_round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise XOR operation to the content of two block vectors.
        pub fn xorBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].xorBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise AND operation to the content of two block vectors.
        pub fn andBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].andBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise OR operation to the content of two block vectors.
        ///
        /// `block_vec2` is a block vector, like in `xorBlocks` and `andBlocks`. It used to
        /// be typed as a single `Block`, which made `block_vec2.repr[i]` ill-typed.
        pub fn orBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].orBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the inverse MixColumns operation to each block in the vector.
        pub fn invMixColumns(block_vec: Self) Self {
            var out: Self = undefined;
            for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].invMixColumns();
            }
            return out;
        }
    };
}
475
fn KeySchedule(comptime Aes: type) type {
    std.debug.assert(Aes.rounds == 10 or Aes.rounds == 14);
    const key_length = Aes.key_bits / 8;
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        const words_in_key = key_length / 4;

        round_keys: [rounds + 1]Block,

        // SubWord (FIPS-197): substitute each byte of `word` through the S-box copy
        // reserved for the key schedule.
        fn subWord(word: u32) u32 {
            const substituted = sbox_lookup(
                &sbox_key_schedule,
                @truncate(word),
                @truncate(word >> 8),
                @truncate(word >> 16),
                @truncate(word >> 24),
            );
            return mem.readInt(u32, &substituted, .little);
        }

        // Key expansion algorithm. See FIPS-197, Figure 11.
        //
        // The expansion works on big-endian words, as the specification does. At the
        // end every word is byte-swapped into the little-endian column layout of `Block`.
        fn expandKey(key: [key_length]u8) Self {
            var expanded: [rounds + 1]Block = undefined;

            // The first `words_in_key` words are the key itself.
            inline for (0..words_in_key) |w| {
                expanded[w / 4].repr[w % 4] = mem.readInt(u32, key[4 * w ..][0..4], .big);
            }

            // Each further word is the word `words_in_key` positions earlier, XORed with
            // the previous word. The previous word is first transformed with
            // RotWord/SubWord/Rcon at the start of each key-sized group, or with
            // SubWord alone mid-group for 256-bit keys.
            inline for (words_in_key..4 * (rounds + 1)) |w| {
                const prev = expanded[(w - 1) / 4].repr[(w - 1) % 4];
                const feedback = if (w % words_in_key == 0)
                    subWord(math.rotl(u32, prev, 8)) ^ (@as(u32, powx[w / words_in_key - 1]) << 24)
                else if (words_in_key > 6 and w % words_in_key == 4)
                    subWord(prev)
                else
                    prev;
                expanded[w / 4].repr[w % 4] = expanded[(w - words_in_key) / 4].repr[(w - words_in_key) % 4] ^ feedback;
            }

            for (&expanded) |*round_key| {
                for (&round_key.repr) |*word| word.* = @byteSwap(word.*);
            }
            return Self{ .round_keys = expanded };
        }

        /// Invert the key schedule.
        ///
        /// Round keys are returned in reverse order. Every key except the first and
        /// the last is passed through the key-schedule S-box and then `table_decrypt`.
        /// The inverse S-box built into that table cancels the substitution, which
        /// leaves InvMixColumns applied to the key.
        pub fn invert(key_schedule: Self) Self {
            var inv_round_keys: [rounds + 1]Block = undefined;
            for (&inv_round_keys, 0..) |*dst, r| {
                const src = key_schedule.round_keys[rounds - r];
                const apply_inv_mix = r != 0 and r != rounds;
                inline for (0..4) |j| {
                    var word = src.repr[j];
                    if (apply_inv_mix) {
                        const s = sbox_lookup(&sbox_key_schedule, @truncate(word >> 24), @truncate(word >> 16), @truncate(word >> 8), @truncate(word));
                        const t = table_lookup(&table_decrypt, s[3], s[2], s[1], s[0]);
                        word = t[0] ^ t[1] ^ t[2] ^ t[3];
                    }
                    dst.repr[j] = word;
                }
            }
            return Self{ .round_keys = inv_round_keys };
        }
    };
}
541
/// A context to perform encryption using the standard AES key schedule.
pub fn AesEncryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        key_schedule: KeySchedule(Aes),

        /// Create a new encryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            return Self{ .key_schedule = KeySchedule(Aes).expandKey(key) };
        }

        // Run the whole AES cipher over `input` and return the final state.
        //
        // With `side_channels_mitigations == .full` every middle round uses the
        // protected lookups. Otherwise rounds 1-4 and round `rounds - 1` are protected
        // and the rounds in between use `encryptUnprotected`. The last round always
        // goes through `encryptLast`.
        fn cipher(ctx: Self, input: Block) Block {
            const round_keys = ctx.key_schedule.round_keys;
            var state = input.xorBlocks(round_keys[0]);
            inline for (1..rounds) |r| {
                const protect = side_channels_mitigations == .full or r < 5 or r == rounds - 1;
                state = if (protect) state.encrypt(round_keys[r]) else state.encryptUnprotected(round_keys[r]);
            }
            return state.encryptLast(round_keys[rounds]);
        }

        /// Encrypt a single block.
        pub fn encrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            dst.* = ctx.cipher(Block.fromBytes(src)).toBytes();
        }

        /// Encrypt+XOR a single block.
        ///
        /// Encrypts `counter` and XORs the keystream block with `src` into `dst`.
        pub fn xor(ctx: Self, dst: *[16]u8, src: *const [16]u8, counter: [16]u8) void {
            dst.* = ctx.cipher(Block.fromBytes(&counter)).xorBytes(src);
        }

        /// Encrypt multiple blocks, possibly leveraging parallelization.
        pub fn encryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            for (0..count) |i| {
                ctx.encrypt(dst[16 * i ..][0..16], src[16 * i ..][0..16]);
            }
        }

        /// Encrypt+XOR multiple blocks, possibly leveraging parallelization.
        pub fn xorWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8, counters: [16 * count]u8) void {
            for (0..count) |i| {
                ctx.xor(dst[16 * i ..][0..16], src[16 * i ..][0..16], counters[16 * i ..][0..16].*);
            }
        }
    };
}
622
/// A context to perform decryption using the standard AES key schedule.
pub fn AesDecryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        key_schedule: KeySchedule(Aes),

        /// Create a decryption context from an existing encryption context.
        ///
        /// The decryption key schedule is the inverse of the encryption one.
        pub fn initFromEnc(ctx: AesEncryptCtx(Aes)) Self {
            return Self{ .key_schedule = ctx.key_schedule.invert() };
        }

        /// Create a new decryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            return initFromEnc(AesEncryptCtx(Aes).init(key));
        }

        // Run the whole inverse cipher over `input` and return the final state.
        //
        // With `side_channels_mitigations == .full` every middle round uses the
        // protected lookups. Otherwise rounds 1-4 and round `rounds - 1` are protected
        // and the rounds in between use `decryptUnprotected`. The last round always
        // goes through `decryptLast`.
        fn invCipher(ctx: Self, input: Block) Block {
            const inv_round_keys = ctx.key_schedule.round_keys;
            var state = input.xorBlocks(inv_round_keys[0]);
            inline for (1..rounds) |r| {
                const protect = side_channels_mitigations == .full or r < 5 or r == rounds - 1;
                state = if (protect) state.decrypt(inv_round_keys[r]) else state.decryptUnprotected(inv_round_keys[r]);
            }
            return state.decryptLast(inv_round_keys[rounds]);
        }

        /// Decrypt a single block.
        pub fn decrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            dst.* = ctx.invCipher(Block.fromBytes(src)).toBytes();
        }

        /// Decrypt multiple blocks, possibly leveraging parallelization.
        pub fn decryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            for (0..count) |i| {
                ctx.decrypt(dst[16 * i ..][0..16], src[16 * i ..][0..16]);
            }
        }
    };
}
678
/// AES-128 with the standard key schedule.
pub const Aes128 = struct {
    /// Key size, in bits.
    pub const key_bits: usize = 128;
    /// Number of rounds: Nk + 6, where Nk = key_bits / 32 is the key length in words.
    pub const rounds = key_bits / 32 + 6;
    /// Block type processed by this cipher.
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(@This()) {
        return AesEncryptCtx(@This()).init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(@This()) {
        return AesDecryptCtx(@This()).init(key);
    }
};
695
/// AES-256 with the standard key schedule.
pub const Aes256 = struct {
    /// Key size, in bits.
    pub const key_bits: usize = 256;
    /// Number of rounds: Nk + 6, where Nk = key_bits / 32 is the key length in words.
    pub const rounds = key_bits / 32 + 6;
    /// Block type processed by this cipher.
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(@This()) {
        return AesEncryptCtx(@This()).init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(@This()) {
        return AesDecryptCtx(@This()).init(key);
    }
};
712
713// constants
714
// Rijndael's irreducible polynomial x⁸ + x⁴ + x³ + x + 1, as a 9-bit mask
// (bit k set for each xᵏ term).
const poly: u9 = 0x11b;
717
// Powers of x modulo `poly` in GF(2⁸): entry k holds xᵏ, so the table starts
// 0x01, 0x02, 0x04, ... The key schedule uses these values as round constants.
const powx = powers: {
    var table: [16]u8 = undefined;
    var current: u8 = 1;
    for (0..table.len) |k| {
        table[k] = current;
        current = mul(current, 2);
    }
    break :powers table;
};
730
// Forward S-box, used by the encryption rounds.
const sbox_encrypt: [256]u8 align(64) = generateSbox(false);
// A second copy of the forward S-box used only by the key schedule, so that it
// occupies different L1 cache entries from the S-box used for encryption.
const sbox_key_schedule: [256]u8 align(64) = generateSbox(false);
// Inverse S-box, used by the decryption rounds.
const sbox_decrypt: [256]u8 align(64) = generateSbox(true);
// Four 256-entry tables of 4-byte words for encryption.
const table_encrypt: [4][256]u32 align(64) = generateTable(false);
// Four 256-entry tables of 4-byte words for decryption.
const table_decrypt: [4][256]u32 align(64) = generateTable(true);
736
// Generate S-box substitution values (the inverse S-box when `invert` is set).
//
// `power` steps through successive powers of 3 while `inverse` steps through the
// matching powers of 0xf6 (i.e. divides by 3 each step), so `inverse` stays the
// multiplicative inverse of `power`. The affine transformation of `inverse` is the
// S-box value of `power`. Zero is never produced by the loop and is patched at the end.
fn generateSbox(invert: bool) [256]u8 {
    @setEvalBranchQuota(10000);

    var sbox: [256]u8 = undefined;
    var power: u8 = 1;
    var inverse: u8 = 1;

    for (0..sbox.len) |_| {
        power = mul(power, 3);
        inverse = mul(inverse, 0xf6); // divide by 3

        const affine: u8 = 0x63 ^ inverse ^
            math.rotl(u8, inverse, 1) ^
            math.rotl(u8, inverse, 2) ^
            math.rotl(u8, inverse, 3) ^
            math.rotl(u8, inverse, 4);

        if (invert) {
            sbox[affine] = power;
        } else {
            sbox[power] = affine;
        }
    }

    // S(0x00) = 0x63, and therefore S⁻¹(0x63) = 0x00.
    if (invert) sbox[0x63] = 0x00 else sbox[0x00] = 0x63;

    return sbox;
}
770
// Generate lookup tables.
//
// `table[0][i]` packs the S-box value of `i` multiplied by the four (Inv)MixColumns
// coefficients into one little-endian column word, from byte 0 to byte 3:
// 0x2, 0x1, 0x1, 0x3 for encryption, or 0xe, 0x9, 0xd, 0xb (with the inverse
// S-box) when `invert` is set. `table[k]` is `table[0]` rotated left by 8·k bits.
fn generateTable(invert: bool) [4][256]u32 {
    @setEvalBranchQuota(50000);

    const coefficients: [4]u8 = if (invert) .{ 0xe, 0x9, 0xd, 0xb } else .{ 0x2, 0x1, 0x1, 0x3 };

    var table: [4][256]u32 = undefined;
    for (generateSbox(invert), 0..) |substituted, i| {
        const word: u32 = @as(u32, mul(substituted, coefficients[0])) |
            @as(u32, mul(substituted, coefficients[1])) << 8 |
            @as(u32, mul(substituted, coefficients[2])) << 16 |
            @as(u32, mul(substituted, coefficients[3])) << 24;

        table[0][i] = word;
        table[1][i] = math.rotl(u32, word, 8);
        table[2][i] = math.rotl(u32, word, 16);
        table[3][i] = math.rotl(u32, word, 24);
    }

    return table;
}
790
// Multiply a and b as GF(2) polynomials modulo poly (shift-and-add).
//
// This loops over the bits of `a` and branches on the intermediate value derived from
// `b`, so its running time depends on both operands.
fn mul(a: u8, b: u8) u8 {
    @setEvalBranchQuota(30000);

    var remaining: u8 = a; // bits of `a` not processed yet
    var shifted: u9 = b; // b·xᵏ mod poly for the current bit position k
    var product: u9 = 0;

    while (remaining != 0) : (remaining >>= 1) {
        if ((remaining & 1) != 0) product ^= shifted;

        shifted <<= 1;
        if ((shifted & 0x100) != 0) shifted ^= poly;
    }

    return @truncate(product);
}
812
// Cache line size reported by `std.atomic.cache_line`; it sizes the gather strides
// used by `sbox_lookup` and `table_lookup`.
const cache_line_bytes = std.atomic.cache_line;

// Look up four bytes in a 256-entry S-box.
//
// With `.none` mitigations this is four plain loads. Otherwise the S-box is viewed as
// `sbox.len / stride` rows of `stride` bytes. For each requested index, the byte at
// the same column (`idx % stride`) is copied out of every row into a scratch buffer,
// and only then is the wanted row (`idx / stride`) read back from that buffer. Every
// row is read, so only the position within a row influences which S-box addresses
// are accessed.
//
// NOTE(review): `.basic` uses a 64-byte stride (4 rows), while `.medium` uses
// `@min(256, 2 * cache_line_bytes)`, which is at most 2 rows whenever
// `cache_line_bytes >= 64`. For this table `.medium` therefore gathers fewer rows
// than `.basic` — confirm that ordering is intended.
fn sbox_lookup(sbox: *align(64) const [256]u8, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u8 {
    if (side_channels_mitigations == .none) {
        return .{ sbox[idx0], sbox[idx1], sbox[idx2], sbox[idx3] };
    } else {
        const stride = switch (side_channels_mitigations) {
            .none => unreachable,
            .basic => sbox.len / 4,
            .medium => @min(sbox.len, 2 * cache_line_bytes),
            .full => @min(sbox.len, cache_line_bytes),
        };

        // Column of each requested entry within its row.
        const col0 = idx0 % stride;
        const col1 = idx1 % stride;
        const col2 = idx2 % stride;
        const col3 = idx3 % stride;

        var gathered: [4][sbox.len / stride]u8 align(64) = undefined;
        for (0..gathered[0].len) |row| {
            const row_start = sbox[row * stride ..];
            gathered[0][row] = row_start[col0];
            gathered[1][row] = row_start[col1];
            gathered[2][row] = row_start[col2];
            gathered[3][row] = row_start[col3];
        }
        // Force the gathered values to count as used so the gather loop is kept.
        std.mem.doNotOptimizeAway(gathered);

        return .{
            gathered[0][idx0 / stride],
            gathered[1][idx1 / stride],
            gathered[2][idx2 / stride],
            gathered[3][idx3 / stride],
        };
    }
}
852
// Look up one word in each of the four round tables: `table[k][idxk]` for k = 0..3.
//
// With `.none` mitigations the four tables are indexed directly. Otherwise only
// `table[0]` is read. This relies on `table[k]` being `table[0]` rotated left by
// 8·k bits (as `generateTable` builds it), so each gathered word is rotated afterwards.
// The gather works as in `sbox_lookup`: the 256 words are viewed as rows of `stride`
// words, the same column is read from every row into a scratch buffer, and the wanted
// row is then picked from that buffer.
fn table_lookup(table: *align(64) const [4][256]u32, idx0: u8, idx1: u8, idx2: u8, idx3: u8) [4]u32 {
    if (side_channels_mitigations == .none) {
        return .{ table[0][idx0], table[1][idx1], table[2][idx2], table[3][idx3] };
    } else {
        const entries: usize = 256;
        const stride = switch (side_channels_mitigations) {
            .none => unreachable,
            .basic => entries / 4,
            .medium => @max(1, @min(entries, 2 * cache_line_bytes / 4)),
            .full => @max(1, @min(entries, cache_line_bytes / 4)),
        };

        // Column of each requested entry within its row.
        const col0 = idx0 % stride;
        const col1 = idx1 % stride;
        const col2 = idx2 % stride;
        const col3 = idx3 % stride;

        var gathered: [4][entries / stride]u32 align(64) = undefined;
        const first = &table[0];
        for (0..gathered[0].len) |row| {
            const row_start = first[row * stride ..];
            gathered[0][row] = row_start[col0];
            gathered[1][row] = row_start[col1];
            gathered[2][row] = row_start[col2];
            gathered[3][row] = row_start[col3];
        }
        // Force the gathered values to count as used so the gather loop is kept.
        std.mem.doNotOptimizeAway(gathered);

        return .{
            gathered[0][idx0 / stride],
            math.rotl(u32, gathered[1][idx1 / stride], 8),
            math.rotl(u32, gathered[2][idx2 / stride], 16),
            math.rotl(u32, gathered[3][idx3 / stride], 24),
        };
    }
}