master
1const std = @import("../../std.zig");
2const builtin = @import("builtin");
3const mem = std.mem;
4const debug = std.debug;
5
/// Whether the compilation target is x86_64.
const is_x86_64 = builtin.cpu.arch == .x86_64;

/// Whether the target CPU advertises the VAES feature (AES instructions on wider vector registers).
const has_vaes = is_x86_64 and builtin.cpu.has(.x86, .vaes);

/// Whether the target CPU advertises AVX-512F, excluding builds that use the self-hosted x86_64 backend.
const has_avx512f = is_x86_64 and
    builtin.zig_backend != .stage2_x86_64 and
    builtin.cpu.has(.x86, .avx512f);
8
/// A single 16-byte AES block, operated on with x86 AES instructions.
pub const Block = struct {
    const Repr = @Vector(2, u64);

    /// The length of an AES block in bytes.
    pub const block_length: usize = 16;

    /// Internal representation of a block: the 16 bytes reinterpreted as two u64 lanes.
    repr: Repr,

    /// Load a block from a 16-byte sequence.
    pub fn fromBytes(bytes: *const [16]u8) Block {
        return .{ .repr = mem.bytesToValue(Repr, bytes) };
    }

    /// Store the block back into a 16-byte sequence.
    pub fn toBytes(block: Block) [16]u8 {
        return mem.toBytes(block.repr);
    }

    /// Return the bytes of this block XORed with `bytes`.
    pub fn xorBytes(block: Block, bytes: *const [16]u8) [16]u8 {
        return block.xorBlocks(fromBytes(bytes)).toBytes();
    }

    /// Apply one AES encryption round with `round_key` (the `vaesenc` instruction).
    pub fn encrypt(block: Block, round_key: Block) Block {
        const state: Repr = asm (
            \\ vaesenc %[rk], %[in], %[out]
            : [out] "=x" (-> Repr),
            : [in] "x" (block.repr),
              [rk] "x" (round_key.repr),
        );
        return .{ .repr = state };
    }

    /// Apply the final AES encryption round with `round_key` (the `vaesenclast` instruction).
    pub fn encryptLast(block: Block, round_key: Block) Block {
        const state: Repr = asm (
            \\ vaesenclast %[rk], %[in], %[out]
            : [out] "=x" (-> Repr),
            : [in] "x" (block.repr),
              [rk] "x" (round_key.repr),
        );
        return .{ .repr = state };
    }

    /// Apply one AES decryption round with an inverse-schedule round key (the `vaesdec` instruction).
    pub fn decrypt(block: Block, inv_round_key: Block) Block {
        const state: Repr = asm (
            \\ vaesdec %[rk], %[in], %[out]
            : [out] "=x" (-> Repr),
            : [in] "x" (block.repr),
              [rk] "x" (inv_round_key.repr),
        );
        return .{ .repr = state };
    }

    /// Apply the final AES decryption round with an inverse-schedule round key (the `vaesdeclast` instruction).
    pub fn decryptLast(block: Block, inv_round_key: Block) Block {
        const state: Repr = asm (
            \\ vaesdeclast %[rk], %[in], %[out]
            : [out] "=x" (-> Repr),
            : [in] "x" (block.repr),
              [rk] "x" (inv_round_key.repr),
        );
        return .{ .repr = state };
    }

    /// Bitwise XOR of two blocks.
    pub fn xorBlocks(block1: Block, block2: Block) Block {
        const combined = block1.repr ^ block2.repr;
        return .{ .repr = combined };
    }

    /// Bitwise AND of two blocks.
    pub fn andBlocks(block1: Block, block2: Block) Block {
        const combined = block1.repr & block2.repr;
        return .{ .repr = combined };
    }

    /// Bitwise OR of two blocks.
    pub fn orBlocks(block1: Block, block2: Block) Block {
        const combined = block1.repr | block2.repr;
        return .{ .repr = combined };
    }

    /// Apply the inverse MixColumns transformation (the `vaesimc` instruction).
    pub fn invMixColumns(block: Block) Block {
        const mixed: Repr = asm (
            \\ vaesimc %[in], %[out]
            : [out] "=x" (-> Repr),
            : [in] "x" (block.repr),
        );
        return .{ .repr = mixed };
    }

    /// Helpers that apply the same AES round operation to several blocks.
    /// Every helper is unrolled at compile time (`inline for`), so each block gets its own instruction.
    pub const parallel = struct {
        const cpu = std.Target.x86.cpu;

        /// The recommended number of AES encryption/decryption to perform in parallel for the chosen implementation.
        pub const optimal_parallel_blocks = switch (builtin.cpu.model) {
            &cpu.westmere, &cpu.goldmont => 3,
            &cpu.cannonlake, &cpu.skylake, &cpu.skylake_avx512, &cpu.tremont, &cpu.goldmont_plus, &cpu.cascadelake => 4,
            &cpu.icelake_client, &cpu.icelake_server, &cpu.tigerlake, &cpu.rocketlake, &cpu.alderlake => 6,
            &cpu.haswell, &cpu.broadwell => 7,
            // Listed explicitly for documentation; these match the default below.
            &cpu.sandybridge, &cpu.ivybridge, &cpu.znver1, &cpu.znver2, &cpu.znver3, &cpu.znver4 => 8,
            else => 8,
        };

        /// Encrypt `count` blocks, each with its own round key.
        pub fn encryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].encrypt(round_keys[k]);
            }
            return result;
        }

        /// Decrypt `count` blocks, each with its own round key.
        pub fn decryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].decrypt(round_keys[k]);
            }
            return result;
        }

        /// Encrypt `count` blocks with a shared round key.
        pub fn encryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].encrypt(round_key);
            }
            return result;
        }

        /// Decrypt `count` blocks with a shared round key.
        pub fn decryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].decrypt(round_key);
            }
            return result;
        }

        /// Apply the final encryption round to `count` blocks with a shared last round key.
        pub fn encryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].encryptLast(round_key);
            }
            return result;
        }

        /// Apply the final decryption round to `count` blocks with a shared last round key.
        pub fn decryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
            var result: [count]Block = undefined;
            inline for (0..count) |k| {
                result[k] = blocks[k].decryptLast(round_key);
            }
            return result;
        }
    };
};
186
/// A fixed-size vector of AES blocks.
/// All operations are performed in parallel, using SIMD instructions when available.
pub fn BlockVec(comptime blocks_count: comptime_int) type {
    return struct {
        const Self = @This();

        /// The number of AES blocks the target architecture can process with a single instruction.
        ///
        /// The 256-bit and 512-bit forms of `vaesenc`/`vaesdec` belong to the VAES
        /// extension, so every wide configuration requires `has_vaes`. A target with
        /// AVX-512F but without VAES uses one block per word instead of emitting
        /// 512-bit AES instructions it cannot execute.
        pub const native_vector_size = w: {
            if (has_vaes and has_avx512f and blocks_count % 4 == 0) break :w 4;
            if (has_vaes and blocks_count % 2 == 0) break :w 2;
            break :w 1;
        };

        /// The size of the AES block vector that the target architecture can process with a single instruction, in bytes.
        pub const native_word_size = native_vector_size * 16;

        /// Number of native words needed to hold all `blocks_count` blocks.
        const native_words = blocks_count / native_vector_size;

        /// One native word: `native_vector_size` blocks of two u64 lanes each.
        const Repr = @Vector(native_vector_size * 2, u64);

        /// Internal representation of a block vector.
        repr: [native_words]Repr,

        /// Length of the block vector in bytes.
        pub const block_length: usize = blocks_count * 16;

        /// Convert a byte sequence into an internal representation.
        pub fn fromBytes(bytes: *const [blocks_count * 16]u8) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = mem.bytesToValue(Repr, bytes[i * native_word_size ..][0..native_word_size]);
            }
            return out;
        }

        /// Convert the internal representation of a block vector into a byte sequence.
        pub fn toBytes(block_vec: Self) [blocks_count * 16]u8 {
            var out: [blocks_count * 16]u8 = undefined;
            inline for (0..native_words) |i| {
                out[i * native_word_size ..][0..native_word_size].* = mem.toBytes(block_vec.repr[i]);
            }
            return out;
        }

        /// XOR the block vector with a byte sequence.
        pub fn xorBytes(block_vec: Self, bytes: *const [blocks_count * 16]u8) [blocks_count * 16]u8 {
            var x: Self = undefined;
            inline for (0..native_words) |i| {
                x.repr[i] = block_vec.repr[i] ^ mem.bytesToValue(Repr, bytes[i * native_word_size ..][0..native_word_size]);
            }
            return x.toBytes();
        }

        /// Apply the forward AES operation to the block vector with a vector of round keys.
        pub fn encrypt(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesenc %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of last round keys.
        pub fn encryptLast(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesenclast %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of round keys.
        pub fn decrypt(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesdec %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (inv_round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of last round keys.
        pub fn decryptLast(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesdeclast %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (inv_round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the bitwise XOR operation to the content of two block vectors.
        pub fn xorBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i] ^ block_vec2.repr[i];
            }
            return out;
        }

        /// Apply the bitwise AND operation to the content of two block vectors.
        pub fn andBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i] & block_vec2.repr[i];
            }
            return out;
        }

        /// Apply the bitwise OR operation to the content of two block vectors.
        /// Both operands are block vectors of the same size, like `xorBlocks` and `andBlocks`.
        pub fn orBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i] | block_vec2.repr[i];
            }
            return out;
        }

        /// Apply the inverse MixColumns operation to each block in the vector.
        /// `vaesimc` has no wide (VAES) form, so the blocks are processed one at a time.
        pub fn invMixColumns(block_vec: Self) Self {
            var out_bytes: [blocks_count * 16]u8 = undefined;
            const in_bytes = block_vec.toBytes();
            inline for (0..blocks_count) |i| {
                const block = Block.fromBytes(in_bytes[i * 16 ..][0..16]);
                out_bytes[i * 16 ..][0..16].* = block.invMixColumns().toBytes();
            }
            return fromBytes(&out_bytes);
        }
    };
}
335
fn KeySchedule(comptime Aes: type) type {
    const rounds = Aes.rounds;
    // Only the AES-128 (10 rounds) and AES-256 (14 rounds) schedules are implemented below.
    std.debug.assert(rounds == 10 or rounds == 14);

    return struct {
        const Self = @This();

        const Repr = Aes.block.Repr;

        /// Round keys, stored in the order the cipher rounds consume them.
        round_keys: [rounds + 1]Block,

        /// One key-expansion step.
        /// Runs `vaeskeygenassist` with constant `rc` on `t`, then computes the
        /// prefix XOR of the 32-bit words of `tx` (cumulative 4- and 8-byte shifts).
        /// Finally it XORs in one dword of the keygen result, broadcast by `vpshufd`.
        /// `second` switches the shuffle mask from 0xff to 0xaa; `expand256` uses it
        /// when deriving the second key half.
        fn drc(comptime second: bool, comptime rc: u8, t: Repr, tx: Repr) Repr {
            const mask: u8 = if (second) 0xaa else 0xff;
            var s: Repr = undefined;
            var ts: Repr = undefined;
            return asm (
                \\ vaeskeygenassist %[rc], %[t], %[s]
                \\ vpslldq $4, %[tx], %[ts]
                \\ vpxor %[ts], %[tx], %[r]
                \\ vpslldq $8, %[r], %[ts]
                \\ vpxor %[ts], %[r], %[r]
                \\ vpshufd %[mask], %[s], %[ts]
                \\ vpxor %[ts], %[r], %[r]
                : [r] "=&x" (-> Repr),
                  [s] "=&x" (s),
                  [ts] "=&x" (ts),
                : [rc] "n" (rc),
                  [t] "x" (t),
                  [tx] "x" (tx),
                  [mask] "n" (mask),
            );
        }

        /// Expand a 128-bit key held in `t1` into 11 round keys.
        /// `t1` is updated in place and holds the final round key on return.
        fn expand128(t1: *Block) Self {
            const rcons = [_]u8{ 1, 2, 4, 8, 16, 32, 64, 128, 27, 54 };
            var keys: [rcons.len + 1]Block = undefined;
            inline for (rcons, 0..) |rcon, n| {
                keys[n] = t1.*;
                t1.repr = drc(false, rcon, t1.repr, t1.repr);
            }
            keys[rcons.len] = t1.*;
            return .{ .round_keys = keys };
        }

        /// Expand a 256-bit key (low half in `t1`, high half in `t2`) into 15 round keys.
        /// Both halves are updated in place during the expansion.
        fn expand256(t1: *Block, t2: *Block) Self {
            const rcons = [_]u8{ 1, 2, 4, 8, 16, 32 };
            var keys: [15]Block = undefined;
            keys[0] = t1.*;
            inline for (rcons, 0..) |rcon, n| {
                keys[2 * n + 1] = t2.*;
                t1.repr = drc(false, rcon, t2.repr, t1.repr);
                keys[2 * n + 2] = t1.*;
                t2.repr = drc(true, rcon, t1.repr, t2.repr);
            }
            keys[2 * rcons.len + 1] = t2.*;
            t1.repr = drc(false, 64, t2.repr, t1.repr);
            keys[2 * rcons.len + 2] = t1.*;
            return .{ .round_keys = keys };
        }

        /// Invert the key schedule.
        /// The key order is reversed, and every inner key goes through inverse
        /// MixColumns (`vaesimc`). The first and last keys are copied unchanged.
        pub fn invert(key_schedule: Self) Self {
            const fwd = &key_schedule.round_keys;
            var inv: [rounds + 1]Block = undefined;
            inv[0] = fwd[rounds];
            inline for (1..rounds) |n| {
                inv[n] = fwd[rounds - n].invMixColumns();
            }
            inv[rounds] = fwd[0];
            return .{ .round_keys = inv };
        }
    };
}
415
/// A context to perform encryption using the standard AES key schedule.
pub fn AesEncryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        /// Expanded encryption round keys.
        key_schedule: KeySchedule(Aes),

        /// Create a new encryption context by expanding `key`.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            var lo = Block.fromBytes(key[0..16]);
            if (Aes.key_bits == 128) {
                return .{ .key_schedule = KeySchedule(Aes).expand128(&lo) };
            } else {
                var hi = Block.fromBytes(key[16..32]);
                return .{ .key_schedule = KeySchedule(Aes).expand256(&lo, &hi) };
            }
        }

        /// Run every encryption round over `state`:
        /// the initial key whitening, `rounds - 1` full rounds, and the final round.
        fn encryptRounds(ctx: Self, state: Block) Block {
            const round_keys = ctx.key_schedule.round_keys;
            var t = state.xorBlocks(round_keys[0]);
            inline for (1..rounds) |r| {
                t = t.encrypt(round_keys[r]);
            }
            return t.encryptLast(round_keys[rounds]);
        }

        /// Encrypt the 16-byte block `src` into `dst`.
        pub fn encrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            dst.* = ctx.encryptRounds(Block.fromBytes(src)).toBytes();
        }

        /// Encrypt `counter` and write that keystream block XORed with `src` into `dst`.
        pub fn xor(ctx: Self, dst: *[16]u8, src: *const [16]u8, counter: [16]u8) void {
            dst.* = ctx.encryptRounds(Block.fromBytes(&counter)).xorBytes(src);
        }

        /// Encrypt `count` consecutive 16-byte blocks from `src` into `dst`.
        /// All blocks advance one round at a time through `Block.parallel`.
        pub fn encryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            const round_keys = ctx.key_schedule.round_keys;
            var ts: [count]Block = undefined;
            inline for (0..count) |j| {
                ts[j] = Block.fromBytes(src[j * 16 ..][0..16]).xorBlocks(round_keys[0]);
            }
            inline for (1..rounds) |r| {
                ts = Block.parallel.encryptWide(count, ts, round_keys[r]);
            }
            ts = Block.parallel.encryptLastWide(count, ts, round_keys[rounds]);
            inline for (0..count) |j| {
                dst[j * 16 ..][0..16].* = ts[j].toBytes();
            }
        }

        /// Encrypt `count` counter blocks and XOR the resulting keystream with `src` into `dst`.
        pub fn xorWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8, counters: [16 * count]u8) void {
            const round_keys = ctx.key_schedule.round_keys;
            var ts: [count]Block = undefined;
            inline for (0..count) |j| {
                ts[j] = Block.fromBytes(counters[j * 16 ..][0..16]).xorBlocks(round_keys[0]);
            }
            inline for (1..rounds) |r| {
                ts = Block.parallel.encryptWide(count, ts, round_keys[r]);
            }
            ts = Block.parallel.encryptLastWide(count, ts, round_keys[rounds]);
            inline for (0..count) |j| {
                dst[j * 16 ..][0..16].* = ts[j].xorBytes(src[j * 16 ..][0..16]);
            }
        }
    };
}
504
/// A context to perform decryption using the standard AES key schedule.
pub fn AesDecryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        /// Inverted (decryption) round keys, produced by `KeySchedule.invert`.
        key_schedule: KeySchedule(Aes),

        /// Create a decryption context by inverting the schedule of an encryption context.
        pub fn initFromEnc(ctx: AesEncryptCtx(Aes)) Self {
            const inverted = ctx.key_schedule.invert();
            return .{ .key_schedule = inverted };
        }

        /// Create a new decryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            return initFromEnc(AesEncryptCtx(Aes).init(key));
        }

        /// Decrypt the 16-byte block `src` into `dst`.
        pub fn decrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            const inv_round_keys = ctx.key_schedule.round_keys;
            var state = Block.fromBytes(src).xorBlocks(inv_round_keys[0]);
            inline for (1..rounds) |r| {
                state = state.decrypt(inv_round_keys[r]);
            }
            dst.* = state.decryptLast(inv_round_keys[rounds]).toBytes();
        }

        /// Decrypt `count` consecutive 16-byte blocks from `src` into `dst`.
        /// All blocks advance one round at a time through `Block.parallel`.
        pub fn decryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            const inv_round_keys = ctx.key_schedule.round_keys;
            var ts: [count]Block = undefined;
            inline for (0..count) |j| {
                ts[j] = Block.fromBytes(src[j * 16 ..][0..16]).xorBlocks(inv_round_keys[0]);
            }
            inline for (1..rounds) |r| {
                ts = Block.parallel.decryptWide(count, ts, inv_round_keys[r]);
            }
            ts = Block.parallel.decryptLastWide(count, ts, inv_round_keys[rounds]);
            inline for (0..count) |j| {
                dst[j * 16 ..][0..16].* = ts[j].toBytes();
            }
        }
    };
}
561
/// AES-128 with the standard key schedule.
pub const Aes128 = struct {
    /// Key length in bits.
    pub const key_bits: usize = 128;
    /// Number of rounds: Nr = Nk + 6 with Nk = key_bits / 32, i.e. 10.
    pub const rounds = key_bits / 32 + 6;
    /// Block type used by this cipher.
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(@This()) {
        return AesEncryptCtx(@This()).init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(@This()) {
        return AesDecryptCtx(@This()).init(key);
    }
};
578
/// AES-256 with the standard key schedule.
pub const Aes256 = struct {
    /// Key length in bits.
    pub const key_bits: usize = 256;
    /// Number of rounds: Nr = Nk + 6 with Nk = key_bits / 32, i.e. 14.
    pub const rounds = key_bits / 32 + 6;
    /// Block type used by this cipher.
    pub const block = Block;

    /// Create a new context for encryption.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(@This()) {
        return AesEncryptCtx(@This()).init(key);
    }

    /// Create a new context for decryption.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(@This()) {
        return AesDecryptCtx(@This()).init(key);
    }
};