master
1const std = @import("../../std.zig");
2const mem = std.mem;
3const debug = std.debug;
4
/// One 16-byte AES block, held as two 64-bit lanes so that it fits a single
/// 128-bit vector register.
pub const Block = struct {
    const Repr = @Vector(2, u64);

    /// Size of one block, in bytes.
    pub const block_length: usize = 16;

    /// Internal representation of a block.
    repr: Repr,

    /// All-zero key operand for AESE/AESD. Those instructions XOR their key
    /// operand into the state before the byte substitution; feeding zero turns
    /// that into a no-op, and the real round key is XORed in after the asm.
    const zero = @Vector(2, u64){ 0, 0 };

    /// Load a block from 16 bytes, reinterpreting them in memory order.
    pub fn fromBytes(bytes: *const [16]u8) Block {
        return .{ .repr = mem.bytesToValue(Repr, bytes) };
    }

    /// Store a block into 16 bytes; the inverse of `fromBytes`.
    pub fn toBytes(block: Block) [16]u8 {
        return mem.toBytes(block.repr);
    }

    /// Return the bytes of `block` XORed with `bytes`.
    pub fn xorBytes(block: Block, bytes: *const [16]u8) [16]u8 {
        const other = fromBytes(bytes);
        return mem.toBytes(block.repr ^ other.repr);
    }

    /// One middle encryption round (AESE + AESMC), then XOR with `round_key`.
    pub fn encrypt(block: Block, round_key: Block) Block {
        const state: Repr = asm (
            \\ mov %[out].16b, %[in].16b
            \\ aese %[out].16b, %[zero].16b
            \\ aesmc %[out].16b, %[out].16b
            : [out] "=&x" (-> Repr),
            : [in] "x" (block.repr),
              [zero] "x" (zero),
        );
        return .{ .repr = state ^ round_key.repr };
    }

    /// Final encryption round (AESE only, no MixColumns), then XOR with `round_key`.
    pub fn encryptLast(block: Block, round_key: Block) Block {
        const state: Repr = asm (
            \\ mov %[out].16b, %[in].16b
            \\ aese %[out].16b, %[zero].16b
            : [out] "=&x" (-> Repr),
            : [in] "x" (block.repr),
              [zero] "x" (zero),
        );
        return .{ .repr = state ^ round_key.repr };
    }

    /// One middle decryption round (AESD + AESIMC), then XOR with `inv_round_key`.
    pub fn decrypt(block: Block, inv_round_key: Block) Block {
        const state: Repr = asm (
            \\ mov %[out].16b, %[in].16b
            \\ aesd %[out].16b, %[zero].16b
            \\ aesimc %[out].16b, %[out].16b
            : [out] "=&x" (-> Repr),
            : [in] "x" (block.repr),
              [zero] "x" (zero),
        );
        return .{ .repr = state ^ inv_round_key.repr };
    }

    /// Final decryption round (AESD only, no InvMixColumns), then XOR with `inv_round_key`.
    pub fn decryptLast(block: Block, inv_round_key: Block) Block {
        const state: Repr = asm (
            \\ mov %[out].16b, %[in].16b
            \\ aesd %[out].16b, %[zero].16b
            : [out] "=&x" (-> Repr),
            : [in] "x" (block.repr),
              [zero] "x" (zero),
        );
        return .{ .repr = state ^ inv_round_key.repr };
    }

    /// Lane-wise XOR of two blocks.
    pub fn xorBlocks(block1: Block, block2: Block) Block {
        return .{ .repr = block1.repr ^ block2.repr };
    }

    /// Lane-wise AND of two blocks.
    pub fn andBlocks(block1: Block, block2: Block) Block {
        return .{ .repr = block1.repr & block2.repr };
    }

    /// Lane-wise OR of two blocks.
    pub fn orBlocks(block1: Block, block2: Block) Block {
        return .{ .repr = block1.repr | block2.repr };
    }

    /// Apply InvMixColumns (AESIMC) to a block.
    pub fn invMixColumns(block: Block) Block {
        const mixed: Repr = asm (
            \\ aesimc %[out].16b, %[in].16b
            : [out] "=x" (-> Repr),
            : [in] "x" (block.repr),
        );
        return .{ .repr = mixed };
    }

    /// Helpers that apply a round operation to several independent blocks.
    /// The loops are unrolled at compile time so the CPU can overlap the
    /// independent AES instructions.
    pub const parallel = struct {
        /// The recommended number of AES encryption/decryption to perform in parallel for the chosen implementation.
        pub const optimal_parallel_blocks = 6;

        /// Encrypt multiple blocks in parallel, each with its own round key.
        pub fn encryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].encrypt(round_keys[k]);
            }
            return result;
        }

        /// Decrypt multiple blocks in parallel, each with its own round key.
        pub fn decryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].decrypt(round_keys[k]);
            }
            return result;
        }

        /// Encrypt multiple blocks in parallel with the same round key.
        pub fn encryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].encrypt(round_key);
            }
            return result;
        }

        /// Decrypt multiple blocks in parallel with the same round key.
        pub fn decryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].decrypt(round_key);
            }
            return result;
        }

        /// Encrypt multiple blocks in parallel with the same last round key.
        pub fn encryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].encryptLast(round_key);
            }
            return result;
        }

        /// Decrypt multiple blocks in parallel with the same last round key.
        pub fn decryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].decryptLast(round_key);
            }
            return result;
        }
    };
};

/// A fixed-size vector of AES blocks.
/// All operations are performed in parallel, using SIMD instructions when available.
///
/// In this implementation each native word holds exactly one `Block`, so every
/// operation applies the corresponding `Block` operation to each of the
/// `blocks_count` blocks in turn (unrolled at compile time).
pub fn BlockVec(comptime blocks_count: comptime_int) type {
    return struct {
        const Self = @This();

        /// The number of AES blocks the target architecture can process with a single instruction.
        pub const native_vector_size = 1;

        /// The size of the AES block vector that the target architecture can process with a single instruction, in bytes.
        pub const native_word_size = native_vector_size * 16;

        const native_words = blocks_count;

        /// Internal representation of a block vector.
        repr: [native_words]Block,

        /// Length of the block vector in bytes.
        pub const block_length: usize = blocks_count * 16;

        /// Convert a byte sequence into an internal representation.
        pub fn fromBytes(bytes: *const [blocks_count * 16]u8) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = Block.fromBytes(bytes[i * native_word_size ..][0..native_word_size]);
            }
            return out;
        }

        /// Convert the internal representation of a block vector into a byte sequence.
        pub fn toBytes(block_vec: Self) [blocks_count * 16]u8 {
            var out: [blocks_count * 16]u8 = undefined;
            inline for (0..native_words) |i| {
                out[i * native_word_size ..][0..native_word_size].* = block_vec.repr[i].toBytes();
            }
            return out;
        }

        /// XOR the block vector with a byte sequence.
        ///
        /// Returns `block_length` bytes: the vector's bytes XORed with `bytes`.
        pub fn xorBytes(block_vec: Self, bytes: *const [blocks_count * 16]u8) [blocks_count * 16]u8 {
            // `Block.xorBytes` yields raw bytes, so the result is assembled
            // directly in a byte buffer sized for the whole vector.
            var out: [blocks_count * 16]u8 = undefined;
            inline for (0..native_words) |i| {
                const off = i * native_word_size;
                out[off..][0..native_word_size].* = block_vec.repr[i].xorBytes(bytes[off..][0..native_word_size]);
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of round keys.
        pub fn encrypt(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].encrypt(round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of last round keys.
        pub fn encryptLast(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].encryptLast(round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of round keys.
        pub fn decrypt(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].decrypt(inv_round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of last round keys.
        pub fn decryptLast(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].decryptLast(inv_round_key_vec.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise XOR operation to the content of two block vectors.
        pub fn xorBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].xorBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise AND operation to the content of two block vectors.
        pub fn andBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].andBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the bitwise OR operation to the content of two block vectors.
        ///
        /// Both operands are block vectors, matching `xorBlocks` and `andBlocks`.
        pub fn orBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i].orBlocks(block_vec2.repr[i]);
            }
            return out;
        }

        /// Apply the inverse MixColumns operation to each block in the vector.
        pub fn invMixColumns(block_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec.repr[i].invMixColumns();
            }
            return out;
        }
    };
}

/// Round-key storage and expansion for `Aes`, which must describe AES-128
/// (10 rounds) or AES-256 (14 rounds).
///
/// Expansion evaluates SubWord with the AESE instruction; `invert` derives the
/// decryption keys used by the AESD/AESIMC round functions.
fn KeySchedule(comptime Aes: type) type {
    const rounds = Aes.rounds;
    std.debug.assert(rounds == 10 or rounds == 14);

    return struct {
        const Self = @This();

        const Repr = Aes.block.Repr;

        /// All-zero vector: the neutral key operand for AESE and the fill value for EXT.
        const zero = @Vector(2, u64){ 0, 0 };
        /// TBL indices that broadcast the last 32-bit word, rotated by one byte, into every lane.
        const mask1 = @Vector(16, u8){ 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12 };
        /// TBL indices that broadcast the last 32-bit word, unrotated, into every lane.
        const mask2 = @Vector(16, u8){ 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15 };

        /// One key per round plus the initial whitening key.
        round_keys: [rounds + 1]Block,

        /// Derive the next AES-128 round key from the current key `t` and round constant `rc`.
        ///
        /// TBL gathers the rotated last word into all lanes; AESE with a zero key then
        /// applies the S-box (ShiftRows has no effect because all columns are equal).
        /// The EXT/EOR chain accumulates the running XOR of the previous words, with
        /// `rc` (broadcast by MOVI) folded into every lane.
        fn drc128(comptime rc: u8, t: Repr) Repr {
            var scratch1: Repr = undefined;
            var scratch2: Repr = undefined;
            var scratch3: Repr = undefined;
            var scratch4: Repr = undefined;

            return asm (
                \\ movi %[v2].4s, %[rc]
                \\ tbl %[v4].16b, {%[t].16b}, %[mask].16b
                \\ ext %[r].16b, %[zero].16b, %[t].16b, #12
                \\ aese %[v4].16b, %[zero].16b
                \\ eor %[v2].16b, %[r].16b, %[v2].16b
                \\ ext %[r].16b, %[zero].16b, %[r].16b, #12
                \\ eor %[v1].16b, %[v2].16b, %[t].16b
                \\ ext %[v3].16b, %[zero].16b, %[r].16b, #12
                \\ eor %[v1].16b, %[v1].16b, %[r].16b
                \\ eor %[r].16b, %[v1].16b, %[v3].16b
                \\ eor %[r].16b, %[r].16b, %[v4].16b
                : [r] "=&x" (-> Repr),
                  [v1] "=&x" (scratch1),
                  [v2] "=&x" (scratch2),
                  [v3] "=&x" (scratch3),
                  [v4] "=&x" (scratch4),
                : [rc] "N" (rc),
                  [t] "x" (t),
                  [zero] "x" (zero),
                  [mask] "x" (mask1),
            );
        }

        /// One AES-256 half-step: update `tx` using the last word of the other half `t`.
        ///
        /// With `second == false` this uses RotWord (`mask1`) and round constant `rc`;
        /// with `second == true` it uses SubWord only (`mask2`) and a zero constant.
        fn drc256(comptime second: bool, comptime rc: u8, t: Repr, tx: Repr) Repr {
            var scratch1: Repr = undefined;
            var scratch2: Repr = undefined;
            var scratch3: Repr = undefined;
            var scratch4: Repr = undefined;

            return asm (
                \\ movi %[v2].4s, %[rc]
                \\ tbl %[v4].16b, {%[t].16b}, %[mask].16b
                \\ ext %[r].16b, %[zero].16b, %[tx].16b, #12
                \\ aese %[v4].16b, %[zero].16b
                \\ eor %[v1].16b, %[tx].16b, %[r].16b
                \\ ext %[r].16b, %[zero].16b, %[r].16b, #12
                \\ eor %[v1].16b, %[v1].16b, %[r].16b
                \\ ext %[v3].16b, %[zero].16b, %[r].16b, #12
                \\ eor %[v1].16b, %[v1].16b, %[v2].16b
                \\ eor %[v1].16b, %[v1].16b, %[v3].16b
                \\ eor %[r].16b, %[v1].16b, %[v4].16b
                : [r] "=&x" (-> Repr),
                  [v1] "=&x" (scratch1),
                  [v2] "=&x" (scratch2),
                  [v3] "=&x" (scratch3),
                  [v4] "=&x" (scratch4),
                : [rc] "N" (if (second) @as(u8, 0) else rc),
                  [t] "x" (t),
                  [tx] "x" (tx),
                  [zero] "x" (zero),
                  [mask] "x" (if (second) mask2 else mask1),
            );
        }

        /// Expand a 128-bit key whose first round key is `t1`.
        /// On return `t1` holds the final round key.
        fn expand128(t1: *Block) Self {
            const rcs = [_]u8{ 1, 2, 4, 8, 16, 32, 64, 128, 27, 54 };
            var keys: [11]Block = undefined;
            inline for (0..rcs.len) |round| {
                keys[round] = t1.*;
                t1.repr = drc128(rcs[round], t1.repr);
            }
            keys[rcs.len] = t1.*;
            return .{ .round_keys = keys };
        }

        /// Expand a 256-bit key given as its first (`t1`) and second (`t2`) 16-byte halves.
        /// Both halves are overwritten during expansion.
        fn expand256(t1: *Block, t2: *Block) Self {
            const rcs = [_]u8{ 1, 2, 4, 8, 16, 32 };
            var keys: [15]Block = undefined;
            keys[0] = t1.*;
            inline for (rcs, 0..) |rc, round| {
                const slot = 2 * round + 1;
                keys[slot] = t2.*;
                t1.repr = drc256(false, rc, t2.repr, t1.repr);
                keys[slot + 1] = t1.*;
                t2.repr = drc256(true, rc, t1.repr, t2.repr);
            }
            // The last round needs only one more half-step.
            keys[2 * rcs.len + 1] = t2.*;
            t1.repr = drc256(false, 64, t2.repr, t1.repr);
            keys[2 * rcs.len + 2] = t1.*;
            return .{ .round_keys = keys };
        }

        /// Invert the key schedule.
        ///
        /// Reverses the key order and applies InvMixColumns to every key except the
        /// first and last, producing the keys expected by `Block.decrypt`.
        pub fn invert(key_schedule: Self) Self {
            const forward = &key_schedule.round_keys;
            var inverse: [rounds + 1]Block = undefined;
            inverse[0] = forward[rounds];
            inline for (1..rounds) |i| {
                inverse[i] = forward[rounds - i].invMixColumns();
            }
            inverse[rounds] = forward[0];
            return .{ .round_keys = inverse };
        }
    };
}

/// A context to perform encryption using the standard AES key schedule.
pub fn AesEncryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        key_schedule: KeySchedule(Aes),

        /// Create a new encryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            var t1 = Block.fromBytes(key[0..16]);
            const schedule = if (Aes.key_bits == 128)
                KeySchedule(Aes).expand128(&t1)
            else blk: {
                var t2 = Block.fromBytes(key[16..32]);
                break :blk KeySchedule(Aes).expand256(&t1, &t2);
            };
            return .{ .key_schedule = schedule };
        }

        /// Run the full cipher (whitening, middle rounds, final round) on one block.
        fn cipher(ctx: Self, input: Block) Block {
            const rk = ctx.key_schedule.round_keys;
            var state = input.xorBlocks(rk[0]);
            inline for (1..rounds) |r| {
                state = state.encrypt(rk[r]);
            }
            return state.encryptLast(rk[rounds]);
        }

        /// Run the full cipher on `count` consecutive 16-byte blocks read from `inputs`.
        /// All input bytes are read before anything is returned.
        fn cipherWide(ctx: Self, comptime count: usize, inputs: *const [16 * count]u8) [count]Block {
            const rk = ctx.key_schedule.round_keys;
            var states: [count]Block = undefined;
            inline for (0..count) |j| {
                states[j] = Block.fromBytes(inputs[16 * j ..][0..16]).xorBlocks(rk[0]);
            }
            inline for (1..rounds) |r| {
                states = Block.parallel.encryptWide(count, states, rk[r]);
            }
            return Block.parallel.encryptLastWide(count, states, rk[rounds]);
        }

        /// Encrypt a single block.
        pub fn encrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            const ciphertext = ctx.cipher(Block.fromBytes(src));
            dst.* = ciphertext.toBytes();
        }

        /// Encrypt+XOR a single block: `dst = src ^ E(counter)`.
        pub fn xor(ctx: Self, dst: *[16]u8, src: *const [16]u8, counter: [16]u8) void {
            const keystream = ctx.cipher(Block.fromBytes(&counter));
            dst.* = keystream.xorBytes(src);
        }

        /// Encrypt multiple blocks, possibly leveraging parallelization.
        pub fn encryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            const ciphertexts = ctx.cipherWide(count, src);
            inline for (0..count) |j| {
                dst[16 * j ..][0..16].* = ciphertexts[j].toBytes();
            }
        }

        /// Encrypt+XOR multiple blocks, possibly leveraging parallelization.
        pub fn xorWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8, counters: [16 * count]u8) void {
            const keystream = ctx.cipherWide(count, &counters);
            inline for (0..count) |j| {
                dst[16 * j ..][0..16].* = keystream[j].xorBytes(src[16 * j ..][0..16]);
            }
        }
    };
}

/// A context to perform decryption using the standard AES key schedule.
///
/// The key schedule is the encryption schedule passed through `KeySchedule.invert`.
pub fn AesDecryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        key_schedule: KeySchedule(Aes),

        /// Create a decryption context from an existing encryption context.
        pub fn initFromEnc(ctx: AesEncryptCtx(Aes)) Self {
            const inverse = ctx.key_schedule.invert();
            return .{ .key_schedule = inverse };
        }

        /// Create a new decryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            return initFromEnc(AesEncryptCtx(Aes).init(key));
        }

        /// Decrypt a single block. `src` is fully read before `dst` is written.
        pub fn decrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            const ik = ctx.key_schedule.round_keys;
            var state = Block.fromBytes(src).xorBlocks(ik[0]);
            inline for (1..rounds) |r| {
                state = state.decrypt(ik[r]);
            }
            dst.* = state.decryptLast(ik[rounds]).toBytes();
        }

        /// Decrypt multiple blocks, possibly leveraging parallelization.
        pub fn decryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            const ik = ctx.key_schedule.round_keys;
            var states: [count]Block = undefined;
            inline for (0..count) |j| {
                states[j] = Block.fromBytes(src[16 * j ..][0..16]).xorBlocks(ik[0]);
            }
            inline for (1..rounds) |r| {
                states = Block.parallel.decryptWide(count, states, ik[r]);
            }
            states = Block.parallel.decryptLastWide(count, states, ik[rounds]);
            inline for (0..count) |j| {
                dst[16 * j ..][0..16].* = states[j].toBytes();
            }
        }
    };
}

/// AES-128 with the standard key schedule.
pub const Aes128 = struct {
    const EncCtx = AesEncryptCtx(Aes128);
    const DecCtx = AesDecryptCtx(Aes128);

    pub const key_bits: usize = 128;
    /// Number of rounds: Nk + 6 with Nk = key_bits / 32 key words (here 10).
    pub const rounds = key_bits / 32 + 6;
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(Aes128) {
        return EncCtx.init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(Aes128) {
        return DecCtx.init(key);
    }
};

/// AES-256 with the standard key schedule.
pub const Aes256 = struct {
    const EncCtx = AesEncryptCtx(Aes256);
    const DecCtx = AesDecryptCtx(Aes256);

    pub const key_bits: usize = 256;
    /// Number of rounds: Nk + 6 with Nk = key_bits / 32 key words (here 14).
    pub const rounds = key_bits / 32 + 6;
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(Aes256) {
        return EncCtx.init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(Aes256) {
        return DecCtx.init(key);
    }
};