master
  1const std = @import("std");
  2const builtin = @import("builtin");
  3const crypto = std.crypto;
  4const aes = crypto.core.aes;
  5const assert = std.debug.assert;
  6const math = std.math;
  7const mem = std.mem;
  8const AuthenticationError = crypto.errors.AuthenticationError;
  9
/// AES-OCB instantiated with AES-128 as the underlying block cipher.
pub const Aes128Ocb = AesOcb(aes.Aes128);
/// AES-OCB instantiated with AES-256 as the underlying block cipher.
pub const Aes256Ocb = AesOcb(aes.Aes256);

// One 16-byte block: the unit in which offsets, checksums and tags are handled below.
const Block = [16]u8;
 14
 15/// AES-OCB (RFC 7253 - https://competitions.cr.yp.to/round3/ocbv11.pdf)
 16fn AesOcb(comptime Aes: anytype) type {
 17    const EncryptCtx = aes.AesEncryptCtx(Aes);
 18    const DecryptCtx = aes.AesDecryptCtx(Aes);
 19
 20    return struct {
 21        pub const key_length = Aes.key_bits / 8;
 22        pub const nonce_length: usize = 12;
 23        pub const tag_length: usize = 16;
 24
 25        const Lx = struct {
 26            star: Block align(16),
 27            dol: Block align(16),
 28            table: [56]Block align(16) = undefined,
 29            upto: usize,
 30
 31            fn double(l: Block) Block {
 32                const l_ = mem.readInt(u128, &l, .big);
 33                const l_2 = (l_ << 1) ^ (0x87 & -%(l_ >> 127));
 34                var l2: Block = undefined;
 35                mem.writeInt(u128, &l2, l_2, .big);
 36                return l2;
 37            }
 38
 39            fn precomp(lx: *Lx, upto: usize) []const Block {
 40                const table = &lx.table;
 41                assert(upto < table.len);
 42                var i = lx.upto;
 43                while (i + 1 <= upto) : (i += 1) {
 44                    table[i + 1] = double(table[i]);
 45                }
 46                lx.upto = upto;
 47                return lx.table[0 .. upto + 1];
 48            }
 49
 50            fn init(aes_enc_ctx: EncryptCtx) Lx {
 51                const zeros = [_]u8{0} ** 16;
 52                var star: Block = undefined;
 53                aes_enc_ctx.encrypt(&star, &zeros);
 54                const dol = double(star);
 55                var lx = Lx{ .star = star, .dol = dol, .upto = 0 };
 56                lx.table[0] = double(dol);
 57                return lx;
 58            }
 59        };
 60
 61        fn hash(aes_enc_ctx: EncryptCtx, lx: *Lx, a: []const u8) Block {
 62            const full_blocks: usize = a.len / 16;
 63            const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
 64            const lt = lx.precomp(x_max);
 65            var sum = [_]u8{0} ** 16;
 66            var offset = [_]u8{0} ** 16;
 67            var i: usize = 0;
 68            while (i < full_blocks) : (i += 1) {
 69                xorWith(&offset, lt[@ctz(i + 1)]);
 70                var e = xorBlocks(offset, a[i * 16 ..][0..16].*);
 71                aes_enc_ctx.encrypt(&e, &e);
 72                xorWith(&sum, e);
 73            }
 74            const leftover = a.len % 16;
 75            if (leftover > 0) {
 76                xorWith(&offset, lx.star);
 77                var padded = [_]u8{0} ** 16;
 78                @memcpy(padded[0..leftover], a[i * 16 ..][0..leftover]);
 79                padded[leftover] = 0x80;
 80                var e = xorBlocks(offset, padded);
 81                aes_enc_ctx.encrypt(&e, &e);
 82                xorWith(&sum, e);
 83            }
 84            return sum;
 85        }
 86
 87        fn getOffset(aes_enc_ctx: EncryptCtx, npub: [nonce_length]u8) Block {
 88            var nx = [_]u8{0} ** 16;
 89            nx[0] = @as(u8, @intCast(@as(u7, @truncate(tag_length * 8)) << 1));
 90            nx[16 - nonce_length - 1] = 1;
 91            nx[nx.len - nonce_length ..].* = npub;
 92
 93            const bottom: u6 = @truncate(nx[15]);
 94            nx[15] &= 0xc0;
 95            var ktop_: Block = undefined;
 96            aes_enc_ctx.encrypt(&ktop_, &nx);
 97            const ktop = mem.readInt(u128, &ktop_, .big);
 98            const stretch = (@as(u192, ktop) << 64) | @as(u192, @as(u64, @truncate(ktop >> 64)) ^ @as(u64, @truncate(ktop >> 56)));
 99            var offset: Block = undefined;
100            mem.writeInt(u128, &offset, @as(u128, @truncate(stretch >> (64 - @as(u7, bottom)))), .big);
101            return offset;
102        }
103
104        const has_aesni = builtin.cpu.has(.x86, .aes);
105        const has_armaes = builtin.cpu.has(.aarch64, .aes);
106        const wb: usize = if ((builtin.cpu.arch == .x86_64 and has_aesni) or (builtin.cpu.arch == .aarch64 and has_armaes)) 4 else 0;
107
108        /// c: ciphertext: output buffer should be of size m.len
109        /// tag: authentication tag: output MAC
110        /// m: message
111        /// ad: Associated Data
112        /// npub: public nonce
113        /// k: secret key
114        pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) void {
115            assert(c.len == m.len);
116
117            const aes_enc_ctx = Aes.initEnc(key);
118            const full_blocks: usize = m.len / 16;
119            const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
120            var lx = Lx.init(aes_enc_ctx);
121            const lt = lx.precomp(x_max);
122
123            var offset = getOffset(aes_enc_ctx, npub);
124            var sum = [_]u8{0} ** 16;
125            var i: usize = 0;
126
127            while (wb > 0 and i + wb <= full_blocks) : (i += wb) {
128                var offsets: [wb]Block align(16) = undefined;
129                var es: [16 * wb]u8 align(16) = undefined;
130                var j: usize = 0;
131                while (j < wb) : (j += 1) {
132                    xorWith(&offset, lt[@ctz(i + 1 + j)]);
133                    offsets[j] = offset;
134                    const p = m[(i + j) * 16 ..][0..16].*;
135                    es[j * 16 ..][0..16].* = xorBlocks(p, offsets[j]);
136                    xorWith(&sum, p);
137                }
138                aes_enc_ctx.encryptWide(wb, &es, &es);
139                j = 0;
140                while (j < wb) : (j += 1) {
141                    const e = es[j * 16 ..][0..16].*;
142                    c[(i + j) * 16 ..][0..16].* = xorBlocks(e, offsets[j]);
143                }
144            }
145            while (i < full_blocks) : (i += 1) {
146                xorWith(&offset, lt[@ctz(i + 1)]);
147                const p = m[i * 16 ..][0..16].*;
148                var e = xorBlocks(p, offset);
149                aes_enc_ctx.encrypt(&e, &e);
150                c[i * 16 ..][0..16].* = xorBlocks(e, offset);
151                xorWith(&sum, p);
152            }
153            const leftover = m.len % 16;
154            if (leftover > 0) {
155                xorWith(&offset, lx.star);
156                var pad = offset;
157                aes_enc_ctx.encrypt(&pad, &pad);
158                var e = [_]u8{0} ** 16;
159                @memcpy(e[0..leftover], m[i * 16 ..][0..leftover]);
160                e[leftover] = 0x80;
161                for (m[i * 16 ..], 0..) |x, j| {
162                    c[i * 16 + j] = pad[j] ^ x;
163                }
164                xorWith(&sum, e);
165            }
166            var e = xorBlocks(xorBlocks(sum, offset), lx.dol);
167            aes_enc_ctx.encrypt(&e, &e);
168            tag.* = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
169        }
170
171        /// `m`: Message
172        /// `c`: Ciphertext
173        /// `tag`: Authentication tag
174        /// `ad`: Associated data
175        /// `npub`: Public nonce
176        /// `k`: Private key
177        /// Asserts `c.len == m.len`.
178        ///
179        /// Contents of `m` are undefined if an error is returned.
180        pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
181            assert(c.len == m.len);
182
183            const aes_enc_ctx = Aes.initEnc(key);
184            const aes_dec_ctx = DecryptCtx.initFromEnc(aes_enc_ctx);
185            const full_blocks: usize = m.len / 16;
186            const x_max = if (full_blocks > 0) math.log2_int(usize, full_blocks) else 0;
187            var lx = Lx.init(aes_enc_ctx);
188            const lt = lx.precomp(x_max);
189
190            var offset = getOffset(aes_enc_ctx, npub);
191            var sum = [_]u8{0} ** 16;
192            var i: usize = 0;
193
194            while (wb > 0 and i + wb <= full_blocks) : (i += wb) {
195                var offsets: [wb]Block align(16) = undefined;
196                var es: [16 * wb]u8 align(16) = undefined;
197                var j: usize = 0;
198                while (j < wb) : (j += 1) {
199                    xorWith(&offset, lt[@ctz(i + 1 + j)]);
200                    offsets[j] = offset;
201                    const q = c[(i + j) * 16 ..][0..16].*;
202                    es[j * 16 ..][0..16].* = xorBlocks(q, offsets[j]);
203                }
204                aes_dec_ctx.decryptWide(wb, &es, &es);
205                j = 0;
206                while (j < wb) : (j += 1) {
207                    const p = xorBlocks(es[j * 16 ..][0..16].*, offsets[j]);
208                    m[(i + j) * 16 ..][0..16].* = p;
209                    xorWith(&sum, p);
210                }
211            }
212            while (i < full_blocks) : (i += 1) {
213                xorWith(&offset, lt[@ctz(i + 1)]);
214                const q = c[i * 16 ..][0..16].*;
215                var e = xorBlocks(q, offset);
216                aes_dec_ctx.decrypt(&e, &e);
217                const p = xorBlocks(e, offset);
218                m[i * 16 ..][0..16].* = p;
219                xorWith(&sum, p);
220            }
221            const leftover = m.len % 16;
222            if (leftover > 0) {
223                xorWith(&offset, lx.star);
224                var pad = offset;
225                aes_enc_ctx.encrypt(&pad, &pad);
226                for (c[i * 16 ..], 0..) |x, j| {
227                    m[i * 16 + j] = pad[j] ^ x;
228                }
229                var e = [_]u8{0} ** 16;
230                @memcpy(e[0..leftover], m[i * 16 ..][0..leftover]);
231                e[leftover] = 0x80;
232                xorWith(&sum, e);
233            }
234            var e = xorBlocks(xorBlocks(sum, offset), lx.dol);
235            aes_enc_ctx.encrypt(&e, &e);
236            var computed_tag = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
237            const verify = crypto.timing_safe.eql([tag_length]u8, computed_tag, tag);
238            if (!verify) {
239                crypto.secureZero(u8, &computed_tag);
240                @memset(m, undefined);
241                return error.AuthenticationFailed;
242            }
243        }
244    };
245}
246
/// Returns the byte-wise xor of `x` and `y`.
fn xorBlocks(x: Block, y: Block) Block {
    var out: Block = undefined;
    for (&out, x, y) |*dst, a, b| {
        dst.* = a ^ b;
    }
    return out;
}
254
/// Xors `y` into `x` in place, byte by byte.
fn xorWith(x: *Block, y: Block) void {
    for (x, y) |*dst, src| {
        dst.* ^= src;
    }
}
260
261const hexToBytes = std.fmt.hexToBytes;
262const testing = std.testing;
263
// RFC 7253 sample: empty message and empty associated data.
test "AesOcb test vector 1" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&nonce, "BBAA99887766554433221100");

    var c: [0]u8 = undefined;
    Aes128Ocb.encrypt(&c, &tag, "", "", nonce, k);

    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "785407BFFFC8AD9EDCC5520AC9111EE6");

    // The produced tag must match the reference value; a decrypt round trip
    // alone would also pass with a wrong but self-consistent tag.
    try testing.expectEqualSlices(u8, &expected_tag, &tag);

    var m: [0]u8 = undefined;
    try Aes128Ocb.decrypt(&m, "", tag, "", nonce, k);
}
282
// Empty message, 40 bytes of associated data: only the tag is produced.
test "AesOcb test vector 2" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var key: [Aes128Ocb.key_length]u8 = undefined;
    _ = try hexToBytes(&key, "000102030405060708090A0B0C0D0E0F");
    var assoc: [40]u8 = undefined;
    _ = try hexToBytes(&assoc, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    var npub: [Aes128Ocb.nonce_length]u8 = undefined;
    _ = try hexToBytes(&npub, "BBAA9988776655443322110E");

    var want_tag: [Aes128Ocb.tag_length]u8 = undefined;
    _ = try hexToBytes(&want_tag, "C5CD9D1850C141E358649994EE701B68");

    var ciphertext: [0]u8 = undefined;
    var got_tag: [Aes128Ocb.tag_length]u8 = undefined;
    Aes128Ocb.encrypt(&ciphertext, &got_tag, "", &assoc, npub, key);
    try testing.expectEqualSlices(u8, &want_tag, &got_tag);

    var plaintext: [0]u8 = undefined;
    try Aes128Ocb.decrypt(&plaintext, &ciphertext, got_tag, &assoc, npub, key);
}
304
// 40-byte message, no associated data.
test "AesOcb test vector 3" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    var m: [40]u8 = undefined;
    var c: [m.len]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&m, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    _ = try hexToBytes(&nonce, "BBAA9988776655443322110F");

    Aes128Ocb.encrypt(&c, &tag, &m, "", nonce, k);

    var expected_c: [c.len]u8 = undefined;
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "479AD363AC366B95A98CA5F3000B1479");
    _ = try hexToBytes(&expected_c, "4412923493C57D5DE0D700F753CCE0D1D2D95060122E9F15A5DDBFC5787E50B5CC55EE507BCB084E");

    try testing.expectEqualSlices(u8, &expected_tag, &tag);
    try testing.expectEqualSlices(u8, &expected_c, &c);
    var m2: [m.len]u8 = undefined;
    try Aes128Ocb.decrypt(&m2, &c, tag, "", nonce, k);
    // Report a mismatch as a test failure with a diff rather than via
    // `assert`, whose failure is undefined behaviour in ReleaseFast.
    try testing.expectEqualSlices(u8, &m, &m2);
}
330
// 40-byte message with the same 40 bytes used as associated data.
test "AesOcb test vector 4" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var k: [Aes128Ocb.key_length]u8 = undefined;
    var nonce: [Aes128Ocb.nonce_length]u8 = undefined;
    var tag: [Aes128Ocb.tag_length]u8 = undefined;
    var m: [40]u8 = undefined;
    var c: [m.len]u8 = undefined;
    _ = try hexToBytes(&k, "000102030405060708090A0B0C0D0E0F");
    _ = try hexToBytes(&m, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");
    _ = try hexToBytes(&nonce, "BBAA9988776655443322110D");
    const ad = m;

    Aes128Ocb.encrypt(&c, &tag, &m, &ad, nonce, k);

    var expected_c: [c.len]u8 = undefined;
    var expected_tag: [tag.len]u8 = undefined;
    _ = try hexToBytes(&expected_tag, "ED07BA06A4A69483A7035490C5769E60");
    _ = try hexToBytes(&expected_c, "D5CA91748410C1751FF8A2F618255B68A0A12E093FF454606E59F9C1D0DDC54B65E8628E568BAD7A");

    try testing.expectEqualSlices(u8, &expected_tag, &tag);
    try testing.expectEqualSlices(u8, &expected_c, &c);
    var m2: [m.len]u8 = undefined;
    try Aes128Ocb.decrypt(&m2, &c, tag, &ad, nonce, k);
    // Report a mismatch as a test failure with a diff rather than via
    // `assert`, whose failure is undefined behaviour in ReleaseFast.
    try testing.expectEqualSlices(u8, &m, &m2);
}
357
// Encrypts and decrypts with the input and output being the very same buffer.
test "AesOcb in-place encryption-decryption" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    var key: [Aes128Ocb.key_length]u8 = undefined;
    _ = try hexToBytes(&key, "000102030405060708090A0B0C0D0E0F");
    var npub: [Aes128Ocb.nonce_length]u8 = undefined;
    _ = try hexToBytes(&npub, "BBAA9988776655443322110D");
    var buf: [40]u8 = undefined;
    _ = try hexToBytes(&buf, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F2021222324252627");

    // Snapshots taken before `buf` is overwritten by the in-place calls.
    const assoc = buf;
    const plaintext = buf;

    var want_tag: [Aes128Ocb.tag_length]u8 = undefined;
    _ = try hexToBytes(&want_tag, "ED07BA06A4A69483A7035490C5769E60");
    var want_ciphertext: [buf.len]u8 = undefined;
    _ = try hexToBytes(&want_ciphertext, "D5CA91748410C1751FF8A2F618255B68A0A12E093FF454606E59F9C1D0DDC54B65E8628E568BAD7A");

    var got_tag: [Aes128Ocb.tag_length]u8 = undefined;
    Aes128Ocb.encrypt(&buf, &got_tag, &buf, &assoc, npub, key);
    try testing.expectEqualSlices(u8, &want_tag, &got_tag);
    try testing.expectEqualSlices(u8, &want_ciphertext, &buf);

    try Aes128Ocb.decrypt(&buf, &buf, got_tag, &assoc, npub, key);
    try testing.expectEqualSlices(u8, &plaintext, &buf);
}
385}