master
  1const std = @import("../../std.zig");
  2const builtin = @import("builtin");
  3const mem = std.mem;
  4const debug = std.debug;
  5
  6const has_vaes = builtin.cpu.arch == .x86_64 and builtin.cpu.has(.x86, .vaes);
  7const has_avx512f = builtin.cpu.arch == .x86_64 and builtin.zig_backend != .stage2_x86_64 and builtin.cpu.has(.x86, .avx512f);
  8
  9/// A single AES block.
 10pub const Block = struct {
 11    const Repr = @Vector(2, u64);
 12
 13    /// The length of an AES block in bytes.
 14    pub const block_length: usize = 16;
 15
 16    /// Internal representation of a block.
 17    repr: Repr,
 18
 19    /// Convert a byte sequence into an internal representation.
 20    pub fn fromBytes(bytes: *const [16]u8) Block {
 21        const repr = mem.bytesToValue(Repr, bytes);
 22        return Block{ .repr = repr };
 23    }
 24
 25    /// Convert the internal representation of a block into a byte sequence.
 26    pub fn toBytes(block: Block) [16]u8 {
 27        return mem.toBytes(block.repr);
 28    }
 29
 30    /// XOR the block with a byte sequence.
 31    pub fn xorBytes(block: Block, bytes: *const [16]u8) [16]u8 {
 32        const x = block.repr ^ fromBytes(bytes).repr;
 33        return mem.toBytes(x);
 34    }
 35
 36    /// Encrypt a block with a round key.
 37    pub fn encrypt(block: Block, round_key: Block) Block {
 38        return Block{
 39            .repr = asm (
 40                \\ vaesenc %[rk], %[in], %[out]
 41                : [out] "=x" (-> Repr),
 42                : [in] "x" (block.repr),
 43                  [rk] "x" (round_key.repr),
 44            ),
 45        };
 46    }
 47
 48    /// Encrypt a block with the last round key.
 49    pub fn encryptLast(block: Block, round_key: Block) Block {
 50        return Block{
 51            .repr = asm (
 52                \\ vaesenclast %[rk], %[in], %[out]
 53                : [out] "=x" (-> Repr),
 54                : [in] "x" (block.repr),
 55                  [rk] "x" (round_key.repr),
 56            ),
 57        };
 58    }
 59
 60    /// Decrypt a block with a round key.
 61    pub fn decrypt(block: Block, inv_round_key: Block) Block {
 62        return Block{
 63            .repr = asm (
 64                \\ vaesdec %[rk], %[in], %[out]
 65                : [out] "=x" (-> Repr),
 66                : [in] "x" (block.repr),
 67                  [rk] "x" (inv_round_key.repr),
 68            ),
 69        };
 70    }
 71
 72    /// Decrypt a block with the last round key.
 73    pub fn decryptLast(block: Block, inv_round_key: Block) Block {
 74        return Block{
 75            .repr = asm (
 76                \\ vaesdeclast %[rk], %[in], %[out]
 77                : [out] "=x" (-> Repr),
 78                : [in] "x" (block.repr),
 79                  [rk] "x" (inv_round_key.repr),
 80            ),
 81        };
 82    }
 83
 84    /// Apply the bitwise XOR operation to the content of two blocks.
 85    pub fn xorBlocks(block1: Block, block2: Block) Block {
 86        return Block{ .repr = block1.repr ^ block2.repr };
 87    }
 88
 89    /// Apply the bitwise AND operation to the content of two blocks.
 90    pub fn andBlocks(block1: Block, block2: Block) Block {
 91        return Block{ .repr = block1.repr & block2.repr };
 92    }
 93
 94    /// Apply the bitwise OR operation to the content of two blocks.
 95    pub fn orBlocks(block1: Block, block2: Block) Block {
 96        return Block{ .repr = block1.repr | block2.repr };
 97    }
 98
 99    /// Apply the inverse MixColumns operation to a block.
100    pub fn invMixColumns(block: Block) Block {
101        return Block{
102            .repr = asm (
103                \\ vaesimc %[in], %[out]
104                : [out] "=x" (-> Repr),
105                : [in] "x" (block.repr),
106            ),
107        };
108    }
109
110    /// Perform operations on multiple blocks in parallel.
111    pub const parallel = struct {
112        const cpu = std.Target.x86.cpu;
113
114        /// The recommended number of AES encryption/decryption to perform in parallel for the chosen implementation.
115        pub const optimal_parallel_blocks = switch (builtin.cpu.model) {
116            &cpu.westmere, &cpu.goldmont => 3,
117            &cpu.cannonlake, &cpu.skylake, &cpu.skylake_avx512, &cpu.tremont, &cpu.goldmont_plus, &cpu.cascadelake => 4,
118            &cpu.icelake_client, &cpu.icelake_server, &cpu.tigerlake, &cpu.rocketlake, &cpu.alderlake => 6,
119            &cpu.haswell, &cpu.broadwell => 7,
120            &cpu.sandybridge, &cpu.ivybridge => 8,
121            &cpu.znver1, &cpu.znver2, &cpu.znver3, &cpu.znver4 => 8,
122            else => 8,
123        };
124
125        /// Encrypt multiple blocks in parallel, each their own round key.
126        pub fn encryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
127            comptime var i = 0;
128            var out: [count]Block = undefined;
129            inline while (i < count) : (i += 1) {
130                out[i] = blocks[i].encrypt(round_keys[i]);
131            }
132            return out;
133        }
134
135        /// Decrypt multiple blocks in parallel, each their own round key.
136        pub fn decryptParallel(comptime count: usize, blocks: [count]Block, round_keys: [count]Block) [count]Block {
137            comptime var i = 0;
138            var out: [count]Block = undefined;
139            inline while (i < count) : (i += 1) {
140                out[i] = blocks[i].decrypt(round_keys[i]);
141            }
142            return out;
143        }
144
145        /// Encrypt multiple blocks in parallel with the same round key.
146        pub fn encryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
147            comptime var i = 0;
148            var out: [count]Block = undefined;
149            inline while (i < count) : (i += 1) {
150                out[i] = blocks[i].encrypt(round_key);
151            }
152            return out;
153        }
154
155        /// Decrypt multiple blocks in parallel with the same round key.
156        pub fn decryptWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
157            comptime var i = 0;
158            var out: [count]Block = undefined;
159            inline while (i < count) : (i += 1) {
160                out[i] = blocks[i].decrypt(round_key);
161            }
162            return out;
163        }
164
165        /// Encrypt multiple blocks in parallel with the same last round key.
166        pub fn encryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
167            comptime var i = 0;
168            var out: [count]Block = undefined;
169            inline while (i < count) : (i += 1) {
170                out[i] = blocks[i].encryptLast(round_key);
171            }
172            return out;
173        }
174
175        /// Decrypt multiple blocks in parallel with the same last round key.
176        pub fn decryptLastWide(comptime count: usize, blocks: [count]Block, round_key: Block) [count]Block {
177            comptime var i = 0;
178            var out: [count]Block = undefined;
179            inline while (i < count) : (i += 1) {
180                out[i] = blocks[i].decryptLast(round_key);
181            }
182            return out;
183        }
184    };
185};
186
/// A fixed-size vector of `blocks_count` AES blocks.
/// Every operation is applied to all blocks, using the widest AES instructions
/// the target supports (see `native_vector_size`).
pub fn BlockVec(comptime blocks_count: comptime_int) type {
    return struct {
        const Self = @This();

        /// The number of AES blocks the target architecture can process with a single instruction.
        ///
        /// 4 blocks (512-bit words) require BOTH VAES and AVX-512F: AVX-512F by itself
        /// provides no 512-bit AES instructions (e.g. Skylake-AVX512 has AVX-512F but not
        /// VAES), so choosing 4 on `has_avx512f` alone would emit `vaesenc`/`vaesdec` on
        /// vector widths the CPU cannot execute. 2 blocks (256-bit words) require VAES.
        pub const native_vector_size = w: {
            if (has_vaes and has_avx512f and blocks_count % 4 == 0) break :w 4;
            if (has_vaes and blocks_count % 2 == 0) break :w 2;
            break :w 1;
        };

        /// The size of the AES block vector that the target architecture can process with a single instruction, in bytes.
        pub const native_word_size = native_vector_size * 16;

        /// Number of native words needed to hold all blocks; exact because
        /// `native_vector_size` always divides `blocks_count`.
        const native_words = blocks_count / native_vector_size;

        const Repr = @Vector(native_vector_size * 2, u64);

        /// Internal representation of a block vector.
        repr: [native_words]Repr,

        /// Length of the block vector in bytes.
        pub const block_length: usize = blocks_count * 16;

        /// Convert a byte sequence into an internal representation.
        pub fn fromBytes(bytes: *const [blocks_count * 16]u8) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = mem.bytesToValue(Repr, bytes[i * native_word_size ..][0..native_word_size]);
            }
            return out;
        }

        /// Convert the internal representation of a block vector into a byte sequence.
        pub fn toBytes(block_vec: Self) [blocks_count * 16]u8 {
            var out: [blocks_count * 16]u8 = undefined;
            inline for (0..native_words) |i| {
                out[i * native_word_size ..][0..native_word_size].* = mem.toBytes(block_vec.repr[i]);
            }
            return out;
        }

        /// XOR the block vector with a byte sequence.
        pub fn xorBytes(block_vec: Self, bytes: *const [blocks_count * 16]u8) [blocks_count * 16]u8 {
            var x: Self = undefined;
            inline for (0..native_words) |i| {
                x.repr[i] = block_vec.repr[i] ^ mem.bytesToValue(Repr, bytes[i * native_word_size ..][0..native_word_size]);
            }
            return x.toBytes();
        }

        /// Apply the forward AES operation to the block vector with a vector of round keys.
        pub fn encrypt(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesenc %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the forward AES operation to the block vector with a vector of last round keys.
        pub fn encryptLast(block_vec: Self, round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesenclast %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of round keys.
        pub fn decrypt(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesdec %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (inv_round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the inverse AES operation to the block vector with a vector of last round keys.
        pub fn decryptLast(block_vec: Self, inv_round_key_vec: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = asm (
                    \\ vaesdeclast %[rk], %[in], %[out]
                    : [out] "=x" (-> Repr),
                    : [in] "x" (block_vec.repr[i]),
                      [rk] "x" (inv_round_key_vec.repr[i]),
                );
            }
            return out;
        }

        /// Apply the bitwise XOR operation to the content of two block vectors.
        pub fn xorBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i] ^ block_vec2.repr[i];
            }
            return out;
        }

        /// Apply the bitwise AND operation to the content of two block vectors.
        pub fn andBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i] & block_vec2.repr[i];
            }
            return out;
        }

        /// Apply the bitwise OR operation to the content of two block vectors.
        /// The second operand is a block vector (`Self`), like in `xorBlocks`/`andBlocks`;
        /// it used to be declared as a single `Block`, which made every call ill-typed.
        pub fn orBlocks(block_vec1: Self, block_vec2: Self) Self {
            var out: Self = undefined;
            inline for (0..native_words) |i| {
                out.repr[i] = block_vec1.repr[i] | block_vec2.repr[i];
            }
            return out;
        }

        /// Apply the inverse MixColumns operation to each block in the vector.
        /// Done one 128-bit block at a time via `Block.invMixColumns` (`vaesimc`).
        pub fn invMixColumns(block_vec: Self) Self {
            var out_bytes: [blocks_count * 16]u8 = undefined;
            const in_bytes = block_vec.toBytes();
            inline for (0..blocks_count) |i| {
                const block = Block.fromBytes(in_bytes[i * 16 ..][0..16]);
                out_bytes[i * 16 ..][0..16].* = block.invMixColumns().toBytes();
            }
            return fromBytes(&out_bytes);
        }
    };
}
335
/// The expanded AES key schedule (`rounds + 1` round keys) for `Aes`,
/// which must have 10 (AES-128) or 14 (AES-256) rounds.
fn KeySchedule(comptime Aes: type) type {
    const rounds = Aes.rounds;
    std.debug.assert(rounds == 10 or rounds == 14);

    return struct {
        const Self = @This();
        const Repr = Aes.block.Repr;

        /// Round keys, in the order they are applied.
        round_keys: [rounds + 1]Block,

        /// Derive the next 128-bit round key.
        ///
        /// `t` is fed to `vaeskeygenassist` with round constant `rc`; `tx` is the
        /// previous round key of the same parity. The `vpslldq`/`vpxor` pairs turn
        /// `tx` into the running XOR of its 32-bit words
        /// (w0, w0^w1, w0^w1^w2, w0^w1^w2^w3). That is XORed with one word of the
        /// keygenassist output broadcast to every lane: word 3 (mask 0xff) when
        /// `second` is false, word 2 (mask 0xaa) when `second` is true.
        fn drc(comptime second: bool, comptime rc: u8, t: Repr, tx: Repr) Repr {
            var s: Repr = undefined;
            var ts: Repr = undefined;
            return asm (
                \\ vaeskeygenassist %[rc], %[t], %[s]
                \\ vpslldq $4, %[tx], %[ts]
                \\ vpxor   %[ts], %[tx], %[r]
                \\ vpslldq $8, %[r], %[ts]
                \\ vpxor   %[ts], %[r], %[r]
                \\ vpshufd %[mask], %[s], %[ts]
                \\ vpxor   %[ts], %[r], %[r]
                : [r] "=&x" (-> Repr),
                  [s] "=&x" (s),
                  [ts] "=&x" (ts),
                : [rc] "n" (rc),
                  [t] "x" (t),
                  [tx] "x" (tx),
                  [mask] "n" (@as(u8, if (second) 0xaa else 0xff)),
            );
        }

        /// Expand a 128-bit key (initially in `t1`) into 11 round keys.
        /// `t1` is overwritten and ends up holding the last round key.
        fn expand128(t1: *Block) Self {
            const rcons = [_]u8{ 1, 2, 4, 8, 16, 32, 64, 128, 27, 54 };
            var keys: [rcons.len + 1]Block = undefined;
            inline for (rcons, 0..) |rcon, r| {
                keys[r] = t1.*;
                t1.repr = drc(false, rcon, t1.repr, t1.repr);
            }
            keys[rcons.len] = t1.*;
            return .{ .round_keys = keys };
        }

        /// Expand a 256-bit key (low half in `t1`, high half in `t2`) into 15 round keys.
        /// Both `t1` and `t2` are overwritten during expansion.
        fn expand256(t1: *Block, t2: *Block) Self {
            const rcons = [_]u8{ 1, 2, 4, 8, 16, 32 };
            var keys: [15]Block = undefined;
            keys[0] = t1.*;
            inline for (rcons, 0..) |rcon, r| {
                keys[2 * r + 1] = t2.*;
                t1.repr = drc(false, rcon, t2.repr, t1.repr);
                keys[2 * r + 2] = t1.*;
                t2.repr = drc(true, rcon, t1.repr, t2.repr);
            }
            keys[2 * rcons.len + 1] = t2.*;
            t1.repr = drc(false, 64, t2.repr, t1.repr);
            keys[2 * rcons.len + 2] = t1.*;
            return .{ .round_keys = keys };
        }

        /// Build the decryption schedule: keys in reverse order, with InvMixColumns
        /// applied to every key except the first and last.
        pub fn invert(key_schedule: Self) Self {
            const fwd = &key_schedule.round_keys;
            var inv: [rounds + 1]Block = undefined;
            inv[0] = fwd[rounds];
            inline for (1..rounds) |r| {
                inv[r] = fwd[rounds - r].invMixColumns();
            }
            inv[rounds] = fwd[0];
            return .{ .round_keys = inv };
        }
    };
}
415
/// A context to perform encryption using the standard AES key schedule.
///
/// `Aes` must declare `key_bits` (128 or 256), `rounds`, and `block`,
/// as `Aes128` / `Aes256` do.
pub fn AesEncryptCtx(comptime Aes: type) type {
    // An explicit message instead of the opaque "reached unreachable code" that a
    // failing `std.debug.assert` produces at comptime.
    if (Aes.key_bits != 128 and Aes.key_bits != 256) {
        @compileError("AesEncryptCtx only supports 128-bit and 256-bit keys");
    }
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        key_schedule: KeySchedule(Aes),

        /// Create a new encryption context with the given key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            var t1 = Block.fromBytes(key[0..16]);
            // `Aes.key_bits` is comptime-known, so only one branch is compiled.
            const key_schedule = if (Aes.key_bits == 128)
                KeySchedule(Aes).expand128(&t1)
            else blk: {
                var t2 = Block.fromBytes(key[16..32]);
                break :blk KeySchedule(Aes).expand256(&t1, &t2);
            };
            return .{ .key_schedule = key_schedule };
        }

        /// Run all AES encryption rounds on one block.
        fn encryptBlock(ctx: Self, input: Block) Block {
            const round_keys = &ctx.key_schedule.round_keys;
            var t = input.xorBlocks(round_keys[0]);
            inline for (1..rounds) |r| {
                t = t.encrypt(round_keys[r]);
            }
            return t.encryptLast(round_keys[rounds]);
        }

        /// Run all AES encryption rounds on `count` consecutive 16-byte blocks of `bytes`,
        /// interleaving the blocks round by round.
        fn encryptBlocks(ctx: Self, comptime count: usize, bytes: *const [16 * count]u8) [count]Block {
            const round_keys = &ctx.key_schedule.round_keys;
            var ts: [count]Block = undefined;
            inline for (0..count) |j| {
                ts[j] = Block.fromBytes(bytes[16 * j ..][0..16]).xorBlocks(round_keys[0]);
            }
            inline for (1..rounds) |r| {
                ts = Block.parallel.encryptWide(count, ts, round_keys[r]);
            }
            return Block.parallel.encryptLastWide(count, ts, round_keys[rounds]);
        }

        /// Encrypt a single block.
        pub fn encrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            dst.* = ctx.encryptBlock(Block.fromBytes(src)).toBytes();
        }

        /// Encrypt `counter` and XOR the result into `src`, writing to `dst` (one CTR-style block).
        pub fn xor(ctx: Self, dst: *[16]u8, src: *const [16]u8, counter: [16]u8) void {
            dst.* = ctx.encryptBlock(Block.fromBytes(&counter)).xorBytes(src);
        }

        /// Encrypt `count` blocks, possibly leveraging parallelization.
        /// All of `src` is read before `dst` is written.
        pub fn encryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            const ts = ctx.encryptBlocks(count, src);
            inline for (ts, 0..) |t, j| {
                dst[16 * j ..][0..16].* = t.toBytes();
            }
        }

        /// Encrypt `count` counter blocks and XOR them into `src`, possibly leveraging parallelization.
        pub fn xorWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8, counters: [16 * count]u8) void {
            const ts = ctx.encryptBlocks(count, &counters);
            inline for (ts, 0..) |t, j| {
                dst[16 * j ..][0..16].* = t.xorBytes(src[16 * j ..][0..16]);
            }
        }
    };
}
504
/// A context to perform decryption using the standard AES key schedule
/// (the inverted form of the encryption schedule).
pub fn AesDecryptCtx(comptime Aes: type) type {
    std.debug.assert(Aes.key_bits == 128 or Aes.key_bits == 256);
    const rounds = Aes.rounds;

    return struct {
        const Self = @This();
        pub const block = Aes.block;
        pub const block_length = block.block_length;
        key_schedule: KeySchedule(Aes),

        /// Derive a decryption context from an encryption context by inverting its schedule.
        pub fn initFromEnc(ctx: AesEncryptCtx(Aes)) Self {
            const inverted = ctx.key_schedule.invert();
            return .{ .key_schedule = inverted };
        }

        /// Create a decryption context directly from a key.
        pub fn init(key: [Aes.key_bits / 8]u8) Self {
            return initFromEnc(AesEncryptCtx(Aes).init(key));
        }

        /// Decrypt one 16-byte block from `src` into `dst`.
        pub fn decrypt(ctx: Self, dst: *[16]u8, src: *const [16]u8) void {
            const keys = &ctx.key_schedule.round_keys;
            var state = Block.fromBytes(src).xorBlocks(keys[0]);
            inline for (1..rounds) |r| {
                state = state.decrypt(keys[r]);
            }
            dst.* = state.decryptLast(keys[rounds]).toBytes();
        }

        /// Decrypt `count` consecutive blocks, interleaving them round by round.
        pub fn decryptWide(ctx: Self, comptime count: usize, dst: *[16 * count]u8, src: *const [16 * count]u8) void {
            const keys = &ctx.key_schedule.round_keys;
            var states: [count]Block = undefined;
            inline for (0..count) |j| {
                states[j] = Block.fromBytes(src[16 * j ..][0..16]).xorBlocks(keys[0]);
            }
            inline for (1..rounds) |r| {
                states = Block.parallel.decryptWide(count, states, keys[r]);
            }
            states = Block.parallel.decryptLastWide(count, states, keys[rounds]);
            inline for (0..count) |j| {
                dst[16 * j ..][0..16].* = states[j].toBytes();
            }
        }
    };
}
561
/// AES-128 with the standard key schedule.
pub const Aes128 = struct {
    const Self = @This();

    /// Key length in bits.
    pub const key_bits: usize = 128;
    /// Round count Nr = Nk + 6 with Nk = key_bits / 32 (FIPS 197): 10 for AES-128.
    pub const rounds = key_bits / 32 + 6;
    /// Block type used by the encryption/decryption contexts.
    pub const block = Block;

    /// Build an encryption context from a 16-byte key.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(Self) {
        return AesEncryptCtx(Self).init(key);
    }

    /// Build a decryption context from a 16-byte key.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(Self) {
        return AesDecryptCtx(Self).init(key);
    }
};
578
/// AES-256 with the standard key schedule.
pub const Aes256 = struct {
    const Self = @This();

    /// Key length in bits.
    pub const key_bits: usize = 256;
    /// Round count Nr = Nk + 6 with Nk = key_bits / 32 (FIPS 197): 14 for AES-256.
    pub const rounds = key_bits / 32 + 6;
    /// Block type used by the encryption/decryption contexts.
    pub const block = Block;

    /// Build an encryption context from a 32-byte key.
    pub fn initEnc(key: [key_bits / 8]u8) AesEncryptCtx(Self) {
        return AesEncryptCtx(Self).init(key);
    }

    /// Build a decryption context from a 32-byte key.
    pub fn initDec(key: [key_bits / 8]u8) AesDecryptCtx(Self) {
        return AesDecryptCtx(Self).init(key);
    }
};