master
1const std = @import("std");
2const builtin = @import("builtin");
3const crypto = std.crypto;
4const aes = crypto.core.aes;
5const assert = std.debug.assert;
6const math = std.math;
7const mem = std.mem;
8const AuthenticationError = crypto.errors.AuthenticationError;
9
/// A single 16-byte AES block, the unit OCB operates on.
const Block: type = [16]u8;

/// AES-OCB instantiated with the AES-128 block cipher.
pub const Aes128Ocb: type = AesOcb(aes.Aes128);
/// AES-OCB instantiated with the AES-256 block cipher.
pub const Aes256Ocb: type = AesOcb(aes.Aes256);
14
/// AES-OCB (RFC 7253 - https://competitions.cr.yp.to/round3/ocbv11.pdf)
/// instantiated for the block cipher `Aes`, with a 12-byte nonce and a 16-byte tag.
fn AesOcb(comptime Aes: anytype) type {
    const EncryptCtx = aes.AesEncryptCtx(Aes);
    const DecryptCtx = aes.AesDecryptCtx(Aes);

    return struct {
        pub const key_length = Aes.key_bits / 8;
        pub const nonce_length: usize = 12;
        pub const tag_length: usize = 16;

        /// Key-dependent masks: `star` = L_* = ENCIPHER(K, zeros(128)),
        /// `dol` = L_$ = double(L_*), and `table[i]` = L_i where
        /// L_0 = double(L_$) and L_i = double(L_{i-1}).
        const Lx = struct {
            star: Block align(16),
            dol: Block align(16),
            table: [56]Block align(16) = undefined,
            /// Highest index of `table` that has been filled in so far.
            upto: usize,

            /// Doubling in GF(2^128): shift left by one bit and, if the bit shifted
            /// out was set, xor the reduction constant 0x87 into the result.
            fn double(l: Block) Block {
                const l_ = mem.readInt(u128, &l, .big);
                const l_2 = (l_ << 1) ^ (0x87 & -%(l_ >> 127));
                var l2: Block = undefined;
                mem.writeInt(u128, &l2, l_2, .big);
                return l2;
            }

            /// Ensures `table[0..upto]` (inclusive) is computed and returns it.
            /// Already-computed entries are reused, and `lx.upto` never decreases,
            /// so a later call with a smaller bound does not discard earlier work.
            fn precomp(lx: *Lx, upto: usize) []const Block {
                const table = &lx.table;
                assert(upto < table.len);
                var i = lx.upto;
                while (i < upto) : (i += 1) {
                    table[i + 1] = double(table[i]);
                }
                lx.upto = @max(lx.upto, upto);
                return lx.table[0 .. upto + 1];
            }

            /// Derives L_*, L_$ and L_0 from the encryption context.
            fn init(aes_enc_ctx: EncryptCtx) Lx {
                const zeros = [_]u8{0} ** 16;
                var star: Block = undefined;
                aes_enc_ctx.encrypt(&star, &zeros);
                const dol = double(star);
                var lx = Lx{ .star = star, .dol = dol, .upto = 0 };
                lx.table[0] = double(dol);
                return lx;
            }
        };

        /// HASH(K, A): masks each full 16-byte block of `a` with a running offset,
        /// encrypts it and xors the results together. A trailing partial block is
        /// padded with 0x80 then zeros and masked with the offset xor L_*.
        fn hash(aes_enc_ctx: EncryptCtx, lx: *Lx, a: []const u8) Block {
            const full_blocks: usize = a.len / 16;
            const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
            const lt = lx.precomp(x_max);
            var sum = [_]u8{0} ** 16;
            var offset = [_]u8{0} ** 16;
            var i: usize = 0;
            while (i < full_blocks) : (i += 1) {
                xorWith(&offset, lt[@ctz(i + 1)]);
                var e = xorBlocks(offset, a[i * 16 ..][0..16].*);
                aes_enc_ctx.encrypt(&e, &e);
                xorWith(&sum, e);
            }
            const leftover = a.len % 16;
            if (leftover > 0) {
                xorWith(&offset, lx.star);
                var padded = [_]u8{0} ** 16;
                @memcpy(padded[0..leftover], a[i * 16 ..][0..leftover]);
                padded[leftover] = 0x80;
                var e = xorBlocks(offset, padded);
                aes_enc_ctx.encrypt(&e, &e);
                xorWith(&sum, e);
            }
            return sum;
        }

        /// Computes the initial offset from the nonce. It builds
        /// Nonce = num2str(TAGLEN mod 128, 7) || zeros || 1 || N and encrypts it
        /// with its low 6 bits cleared to obtain Ktop. The result is the 128 bits
        /// of Stretch = Ktop || (Ktop[1..64] xor Ktop[9..72]) starting at bit
        /// `bottom` (the cleared 6 bits).
        fn getOffset(aes_enc_ctx: EncryptCtx, npub: [nonce_length]u8) Block {
            var nx = [_]u8{0} ** 16;
            nx[0] = @as(u8, @intCast(@as(u7, @truncate(tag_length * 8)) << 1));
            nx[16 - nonce_length - 1] = 1;
            nx[nx.len - nonce_length ..].* = npub;

            const bottom: u6 = @truncate(nx[15]);
            nx[15] &= 0xc0;
            var ktop_: Block = undefined;
            aes_enc_ctx.encrypt(&ktop_, &nx);
            const ktop = mem.readInt(u128, &ktop_, .big);
            const stretch = (@as(u192, ktop) << 64) | @as(u192, @as(u64, @truncate(ktop >> 64)) ^ @as(u64, @truncate(ktop >> 56)));
            var offset: Block = undefined;
            mem.writeInt(u128, &offset, @as(u128, @truncate(stretch >> (64 - @as(u7, bottom)))), .big);
            return offset;
        }

        // `wb` is the number of blocks passed to `encryptWide`/`decryptWide` per
        // call. It is 4 on x86_64 with AES-NI or aarch64 with the AES extension,
        // and 0 otherwise, which disables the wide loops.
        const has_aesni = builtin.cpu.has(.x86, .aes);
        const has_armaes = builtin.cpu.has(.aarch64, .aes);
        const wb: usize = if ((builtin.cpu.arch == .x86_64 and has_aesni) or (builtin.cpu.arch == .aarch64 and has_armaes)) 4 else 0;

        /// Encrypts `m` into `c` and writes the authentication tag to `tag`.
        ///
        /// `c`: ciphertext output buffer; must have the same length as `m` (asserted).
        ///      It may alias `m` for in-place encryption.
        /// `tag`: authentication tag output
        /// `m`: message
        /// `ad`: associated data (authenticated but not encrypted)
        /// `npub`: public nonce
        /// `key`: secret key
        pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) void {
            assert(c.len == m.len);

            const aes_enc_ctx = Aes.initEnc(key);
            const full_blocks: usize = m.len / 16;
            // Block i uses L_{ntz(i)}, and ntz(i) <= floor(log2(full_blocks)).
            const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
            var lx = Lx.init(aes_enc_ctx);
            const lt = lx.precomp(x_max);

            var offset = getOffset(aes_enc_ctx, npub);
            var sum = [_]u8{0} ** 16;
            var i: usize = 0;

            // Wide path: derive `wb` offsets, then encrypt `wb` blocks in one call.
            while (wb > 0 and i + wb <= full_blocks) : (i += wb) {
                var offsets: [wb]Block align(16) = undefined;
                var es: [16 * wb]u8 align(16) = undefined;
                var j: usize = 0;
                while (j < wb) : (j += 1) {
                    xorWith(&offset, lt[@ctz(i + 1 + j)]);
                    offsets[j] = offset;
                    const p = m[(i + j) * 16 ..][0..16].*;
                    es[j * 16 ..][0..16].* = xorBlocks(p, offsets[j]);
                    xorWith(&sum, p);
                }
                aes_enc_ctx.encryptWide(wb, &es, &es);
                j = 0;
                while (j < wb) : (j += 1) {
                    const e = es[j * 16 ..][0..16].*;
                    c[(i + j) * 16 ..][0..16].* = xorBlocks(e, offsets[j]);
                }
            }
            // Remaining full blocks, one at a time.
            while (i < full_blocks) : (i += 1) {
                xorWith(&offset, lt[@ctz(i + 1)]);
                const p = m[i * 16 ..][0..16].*;
                var e = xorBlocks(p, offset);
                aes_enc_ctx.encrypt(&e, &e);
                c[i * 16 ..][0..16].* = xorBlocks(e, offset);
                xorWith(&sum, p);
            }
            // Final partial block: xor with a pad encrypted from the offset xor L_*,
            // and add the 0x80-padded plaintext to the checksum.
            const leftover = m.len % 16;
            if (leftover > 0) {
                xorWith(&offset, lx.star);
                var pad = offset;
                aes_enc_ctx.encrypt(&pad, &pad);
                var e = [_]u8{0} ** 16;
                @memcpy(e[0..leftover], m[i * 16 ..][0..leftover]);
                e[leftover] = 0x80;
                for (m[i * 16 ..], 0..) |x, j| {
                    c[i * 16 + j] = pad[j] ^ x;
                }
                xorWith(&sum, e);
            }
            // Tag = ENCIPHER(K, Checksum xor Offset xor L_$) xor HASH(K, A).
            var e = xorBlocks(xorBlocks(sum, offset), lx.dol);
            aes_enc_ctx.encrypt(&e, &e);
            tag.* = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
        }

        /// Decrypts `c` into `m` and verifies `tag` over the ciphertext and `ad`.
        ///
        /// `m`: Message output buffer
        /// `c`: Ciphertext (may alias `m` for in-place decryption)
        /// `tag`: Authentication tag
        /// `ad`: Associated data
        /// `npub`: Public nonce
        /// `key`: Secret key
        /// Asserts `c.len == m.len`.
        ///
        /// `m` is written before the tag is checked. If verification fails, `m` is
        /// securely zeroed and `error.AuthenticationFailed` is returned, so no
        /// unauthenticated plaintext is left in the caller's buffer.
        pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
            assert(c.len == m.len);

            const aes_enc_ctx = Aes.initEnc(key);
            const aes_dec_ctx = DecryptCtx.initFromEnc(aes_enc_ctx);
            const full_blocks: usize = m.len / 16;
            // Block i uses L_{ntz(i)}, and ntz(i) <= floor(log2(full_blocks)).
            const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
            var lx = Lx.init(aes_enc_ctx);
            const lt = lx.precomp(x_max);

            var offset = getOffset(aes_enc_ctx, npub);
            var sum = [_]u8{0} ** 16;
            var i: usize = 0;

            // Wide path: derive `wb` offsets, then decrypt `wb` blocks in one call.
            while (wb > 0 and i + wb <= full_blocks) : (i += wb) {
                var offsets: [wb]Block align(16) = undefined;
                var es: [16 * wb]u8 align(16) = undefined;
                var j: usize = 0;
                while (j < wb) : (j += 1) {
                    xorWith(&offset, lt[@ctz(i + 1 + j)]);
                    offsets[j] = offset;
                    const q = c[(i + j) * 16 ..][0..16].*;
                    es[j * 16 ..][0..16].* = xorBlocks(q, offsets[j]);
                }
                aes_dec_ctx.decryptWide(wb, &es, &es);
                j = 0;
                while (j < wb) : (j += 1) {
                    const p = xorBlocks(es[j * 16 ..][0..16].*, offsets[j]);
                    m[(i + j) * 16 ..][0..16].* = p;
                    xorWith(&sum, p);
                }
            }
            // Remaining full blocks, one at a time.
            while (i < full_blocks) : (i += 1) {
                xorWith(&offset, lt[@ctz(i + 1)]);
                const q = c[i * 16 ..][0..16].*;
                var e = xorBlocks(q, offset);
                aes_dec_ctx.decrypt(&e, &e);
                const p = xorBlocks(e, offset);
                m[i * 16 ..][0..16].* = p;
                xorWith(&sum, p);
            }
            // Final partial block: the pad is produced with the *encryption*
            // direction, exactly as in `encrypt`, then the recovered plaintext is
            // 0x80-padded into the checksum.
            const leftover = m.len % 16;
            if (leftover > 0) {
                xorWith(&offset, lx.star);
                var pad = offset;
                aes_enc_ctx.encrypt(&pad, &pad);
                for (c[i * 16 ..], 0..) |x, j| {
                    m[i * 16 + j] = pad[j] ^ x;
                }
                var e = [_]u8{0} ** 16;
                @memcpy(e[0..leftover], m[i * 16 ..][0..leftover]);
                e[leftover] = 0x80;
                xorWith(&sum, e);
            }
            var e = xorBlocks(xorBlocks(sum, offset), lx.dol);
            aes_enc_ctx.encrypt(&e, &e);
            var computed_tag = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
            const verify = crypto.timing_safe.eql([tag_length]u8, computed_tag, tag);
            if (!verify) {
                crypto.secureZero(u8, &computed_tag);
                // `@memset(m, undefined)` is elided in release builds, which would
                // leave unauthenticated plaintext in `m`; wipe it explicitly.
                crypto.secureZero(u8, m);
                return error.AuthenticationFailed;
            }
        }
    };
}
246
/// Returns the bytewise xor of two blocks.
fn xorBlocks(x: Block, y: Block) Block {
    var out: Block = undefined;
    for (&out, x, y) |*dst, a, b| {
        dst.* = a ^ b;
    }
    return out;
}
254
/// Xors `y` into `x` in place.
fn xorWith(x: *Block, y: Block) void {
    for (x, y) |*dst, src| {
        dst.* ^= src;
    }
}
260
// Shorthands used by the tests below.
const testing = std.testing;
const hexToBytes = std.fmt.hexToBytes;
263
test "AesOcb test vector 1" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&nonce, "BBAA99887766554433221100");

    // Empty message and empty associated data: the output is just the tag.
    var c: [0]u8 = undefined;
    Aes128Ocb.encrypt(&c, &tag, "", "", nonce, k);

    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "785407BFFFC8AD9EDCC5520AC9111EE6");
    // Previously the expected tag was decoded but never compared.
    try testing.expectEqualSlices(u8, &expected_tag, &tag);

    var m: [0]u8 = undefined;
    try Aes128Ocb.decrypt(&m, &c, tag, "", nonce, k);
}
282
test "AesOcb test vector 2" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    // Empty plaintext with 40 bytes of associated data: only a tag is produced.
    var key: [Aes128Ocb.key_length]u8 = undefined;
    _ = try hexToBytes(&key, "000102030405060708090A0B0C0D0E0F");
    var assoc: [40]u8 = undefined;
    _ = try hexToBytes(&assoc, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    var npub: [Aes128Ocb.nonce_length]u8 = undefined;
    _ = try hexToBytes(&npub, "BBAA9988776655443322110E");

    var ciphertext: [0]u8 = undefined;
    var got_tag: [Aes128Ocb.tag_length]u8 = undefined;
    Aes128Ocb.encrypt(&ciphertext, &got_tag, "", &assoc, npub, key);

    var want_tag: [got_tag.len]u8 = undefined;
    _ = try hexToBytes(&want_tag, "C5CD9D1850C141E358649994EE701B68");
    try testing.expectEqualSlices(u8, &want_tag, &got_tag);

    var plaintext: [0]u8 = undefined;
    try Aes128Ocb.decrypt(&plaintext, &ciphertext, got_tag, &assoc, npub, key);
}
304
test "AesOcb test vector 3" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    var m: [40]u8 = undefined;
    var c: [m.len]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&m, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    _ = try hexToBytes(&nonce, "BBAA9988776655443322110F");

    // 40-byte message (two full blocks plus a partial block), no associated data.
    Aes128Ocb.encrypt(&c, &tag, &m, "", nonce, k);

    var expected_c: [c.len]u8 = undefined;
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "479AD363AC366B95A98CA5F3000B1479");
    _ = try hexToBytes(&expected_c, "4412923493C57D5DE0D700F753CCE0D1D2D95060122E9F15A5DDBFC5787E50B5CC55EE507BCB084E");

    try testing.expectEqualSlices(u8, &expected_tag, &tag);
    try testing.expectEqualSlices(u8, &expected_c, &c);
    var m2: [m.len]u8 = undefined;
    try Aes128Ocb.decrypt(&m2, &c, tag, "", nonce, k);
    // Use a test expectation rather than std.debug.assert, which panics instead of
    // failing the test and is not checked in ReleaseFast builds.
    try testing.expectEqualSlices(u8, &m, &m2);
}
330
test "AesOcb test vector 4" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    var m: [40]u8 = undefined;
    var c: [m.len]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&m, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    _ = try hexToBytes(&nonce, "BBAA9988776655443322110D");
    // The associated data is a copy of the message.
    const ad = m;

    Aes128Ocb.encrypt(&c, &tag, &m, &ad, nonce, k);

    var expected_c: [c.len]u8 = undefined;
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "ED07BA06A4A69483A7035490C5769E60");
    _ = try hexToBytes(&expected_c, "D5CA91748410C1751FF8A2F618255B68A0A12E093FF454606E59F9C1D0DDC54B65E8628E568BAD7A");

    try testing.expectEqualSlices(u8, &expected_tag, &tag);
    try testing.expectEqualSlices(u8, &expected_c, &c);
    var m2: [m.len]u8 = undefined;
    try Aes128Ocb.decrypt(&m2, &c, tag, &ad, nonce, k);
    // Use a test expectation rather than std.debug.assert, which panics instead of
    // failing the test and is not checked in ReleaseFast builds.
    try testing.expectEqualSlices(u8, &m, &m2);
}
357
test "AesOcb in-place encryption-decryption" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var key: [Aes128Ocb.key_length]u8 = undefined;
    var npub: [Aes128Ocb.nonce_length]u8 = undefined;
    _ = try hexToBytes(&key, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&npub, "BBAA9988776655443322110D");

    // `buf` holds the plaintext, then the ciphertext, then the plaintext again.
    var buf: [40]u8 = undefined;
    _ = try hexToBytes(&buf, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    const ad = buf;
    const original = buf;

    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    Aes128Ocb.encrypt(&buf, &tag, &buf, &ad, npub, key);

    var want_tag: [tag.len]u8 = undefined;
    var want_ciphertext: [buf.len]u8 = undefined;
    _ = try hexToBytes(&want_tag, "ED07BA06A4A69483A7035490C5769E60");
    _ = try hexToBytes(&want_ciphertext, "D5CA91748410C1751FF8A2F618255B68A0A12E093FF454606E59F9C1D0DDC54B65E8628E568BAD7A");
    try testing.expectEqualSlices(u8, &want_tag, &tag);
    try testing.expectEqualSlices(u8, &want_ciphertext, &buf);

    try Aes128Ocb.decrypt(&buf, &buf, tag, &ad, npub, key);
    try testing.expectEqualSlices(u8, &original, &buf);
}