Commit 9135115573

e4m2 <git@e4m2.com>
2023-08-14 21:39:51
std.crypto.aead: Consistent decryption tail and doc fixes (#16781)
* Consistent decryption tail for all AEADs * Remove outdated note This was previously copied here from another function. There used to be another comment on the tag verification linking to issue #1776, but that one was not copied over. As it stands, this note seems fairly misleading/irrelevant. * Prettier docs * Add note about plaintext contents to docs * Capitalization * Fixup missing XChaChaPoly docs
1 parent 8f3ccbb
lib/std/crypto/aegis.zig
@@ -17,10 +17,11 @@
 //! https://datatracker.ietf.org/doc/draft-irtf-cfrg-aegis-aead/
 
 const std = @import("std");
+const crypto = std.crypto;
 const mem = std.mem;
 const assert = std.debug.assert;
-const AesBlock = std.crypto.core.aes.Block;
-const AuthenticationError = std.crypto.errors.AuthenticationError;
+const AesBlock = crypto.core.aes.Block;
+const AuthenticationError = crypto.errors.AuthenticationError;
 
 /// AEGIS-128L with a 128-bit authentication tag.
 pub const Aegis128L = Aegis128LGeneric(128);
@@ -169,12 +170,15 @@ fn Aegis128LGeneric(comptime tag_bits: u9) type {
             tag.* = state.mac(tag_bits, ad.len, m.len);
         }
 
-        /// m: message: output buffer should be of size c.len
-        /// c: ciphertext
-        /// tag: authentication tag
-        /// ad: Associated Data
-        /// npub: public nonce
-        /// k: private key
+        /// `m`: Message
+        /// `c`: Ciphertext
+        /// `tag`: Authentication tag
+        /// `ad`: Associated data
+        /// `npub`: Public nonce
+        /// `k`: Private key
+        /// Asserts `c.len == m.len`.
+        ///
+        /// Contents of `m` are undefined if an error is returned.
         pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
             assert(c.len == m.len);
             var state = State128L.init(key, npub);
@@ -203,12 +207,10 @@ fn Aegis128LGeneric(comptime tag_bits: u9) type {
                 blocks[0] = blocks[0].xorBlocks(AesBlock.fromBytes(dst[0..16]));
                 blocks[4] = blocks[4].xorBlocks(AesBlock.fromBytes(dst[16..32]));
             }
-            const computed_tag = state.mac(tag_bits, ad.len, m.len);
-            var acc: u8 = 0;
-            for (computed_tag, 0..) |_, j| {
-                acc |= (computed_tag[j] ^ tag[j]);
-            }
-            if (acc != 0) {
+            var computed_tag = state.mac(tag_bits, ad.len, m.len);
+            const verify = crypto.utils.timingSafeEql([tag_length]u8, computed_tag, tag);
+            if (!verify) {
+                crypto.utils.secureZero(u8, &computed_tag);
                 @memset(m, undefined);
                 return error.AuthenticationFailed;
             }
@@ -351,12 +353,15 @@ fn Aegis256Generic(comptime tag_bits: u9) type {
             tag.* = state.mac(tag_bits, ad.len, m.len);
         }
 
-        /// m: message: output buffer should be of size c.len
-        /// c: ciphertext
-        /// tag: authentication tag
-        /// ad: Associated Data
-        /// npub: public nonce
-        /// k: private key
+        /// `m`: Message
+        /// `c`: Ciphertext
+        /// `tag`: Authentication tag
+        /// `ad`: Associated data
+        /// `npub`: Public nonce
+        /// `k`: Private key
+        /// Asserts `c.len == m.len`.
+        ///
+        /// Contents of `m` are undefined if an error is returned.
         pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
             assert(c.len == m.len);
             var state = State256.init(key, npub);
@@ -384,12 +389,10 @@ fn Aegis256Generic(comptime tag_bits: u9) type {
                 const blocks = &state.blocks;
                 blocks[0] = blocks[0].xorBlocks(AesBlock.fromBytes(&dst));
             }
-            const computed_tag = state.mac(tag_bits, ad.len, m.len);
-            var acc: u8 = 0;
-            for (computed_tag, 0..) |_, j| {
-                acc |= (computed_tag[j] ^ tag[j]);
-            }
-            if (acc != 0) {
+            var computed_tag = state.mac(tag_bits, ad.len, m.len);
+            const verify = crypto.utils.timingSafeEql([tag_length]u8, computed_tag, tag);
+            if (!verify) {
+                crypto.utils.secureZero(u8, &computed_tag);
                 @memset(m, undefined);
                 return error.AuthenticationFailed;
             }
lib/std/crypto/aes_gcm.zig
@@ -55,6 +55,15 @@ fn AesGcm(comptime Aes: anytype) type {
             }
         }
 
+        /// `m`: Message
+        /// `c`: Ciphertext
+        /// `tag`: Authentication tag
+        /// `ad`: Associated data
+        /// `npub`: Public nonce
+        /// `k`: Private key
+        /// Asserts `c.len == m.len`.
+        ///
+        /// Contents of `m` are undefined if an error is returned.
         pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
             assert(c.len == m.len);
 
@@ -86,11 +95,9 @@ fn AesGcm(comptime Aes: anytype) type {
                 computed_tag[i] ^= x;
             }
 
-            var acc: u8 = 0;
-            for (computed_tag, 0..) |_, p| {
-                acc |= (computed_tag[p] ^ tag[p]);
-            }
-            if (acc != 0) {
+            const verify = crypto.utils.timingSafeEql([tag_length]u8, computed_tag, tag);
+            if (!verify) {
+                crypto.utils.secureZero(u8, &computed_tag);
                 @memset(m, undefined);
                 return error.AuthenticationFailed;
             }
lib/std/crypto/aes_ocb.zig
@@ -168,12 +168,15 @@ fn AesOcb(comptime Aes: anytype) type {
             tag.* = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
         }
 
-        /// m: message: output buffer should be of size c.len
-        /// c: ciphertext
-        /// tag: authentication tag
-        /// ad: Associated Data
-        /// npub: public nonce
-        /// k: secret key
+        /// `m`: Message
+        /// `c`: Ciphertext
+        /// `tag`: Authentication tag
+        /// `ad`: Associated data
+        /// `npub`: Public nonce
+        /// `k`: Private key
+        /// Asserts `c.len == m.len`.
+        ///
+        /// Contents of `m` are undefined if an error is returned.
         pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
             assert(c.len == m.len);
 
@@ -232,8 +235,9 @@ fn AesOcb(comptime Aes: anytype) type {
             aes_enc_ctx.encrypt(&e, &e);
             var computed_tag = xorBlocks(e, hash(aes_enc_ctx, &lx, ad));
             const verify = crypto.utils.timingSafeEql([tag_length]u8, computed_tag, tag);
-            crypto.utils.secureZero(u8, &computed_tag);
             if (!verify) {
+                crypto.utils.secureZero(u8, &computed_tag);
+                @memset(m, undefined);
                 return error.AuthenticationFailed;
             }
         }
lib/std/crypto/chacha20.zig
@@ -2,13 +2,14 @@
 
 const std = @import("../std.zig");
 const builtin = @import("builtin");
+const crypto = std.crypto;
 const math = std.math;
 const mem = std.mem;
 const assert = std.debug.assert;
 const testing = std.testing;
 const maxInt = math.maxInt;
-const Poly1305 = std.crypto.onetimeauth.Poly1305;
-const AuthenticationError = std.crypto.errors.AuthenticationError;
+const Poly1305 = crypto.onetimeauth.Poly1305;
+const AuthenticationError = crypto.errors.AuthenticationError;
 
 /// IETF-variant of the ChaCha20 stream cipher, as designed for TLS.
 pub const ChaCha20IETF = ChaChaIETF(20);
@@ -675,13 +676,15 @@ fn ChaChaPoly1305(comptime rounds_nb: usize) type {
             mac.final(tag);
         }
 
-        /// m: message: output buffer should be of size c.len
-        /// c: ciphertext
-        /// tag: authentication tag
-        /// ad: Associated Data
-        /// npub: public nonce
-        /// k: private key
-        /// NOTE: the check of the authentication tag is currently not done in constant time
+        /// `m`: Message
+        /// `c`: Ciphertext
+        /// `tag`: Authentication tag
+        /// `ad`: Associated data
+        /// `npub`: Public nonce
+        /// `k`: Private key
+        /// Asserts `c.len == m.len`.
+        ///
+        /// Contents of `m` are undefined if an error is returned.
         pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) AuthenticationError!void {
             assert(c.len == m.len);
 
@@ -706,14 +709,13 @@ fn ChaChaPoly1305(comptime rounds_nb: usize) type {
             mem.writeIntLittle(u64, lens[0..8], ad.len);
             mem.writeIntLittle(u64, lens[8..16], c.len);
             mac.update(lens[0..]);
-            var computedTag: [16]u8 = undefined;
-            mac.final(computedTag[0..]);
+            var computed_tag: [16]u8 = undefined;
+            mac.final(computed_tag[0..]);
 
-            var acc: u8 = 0;
-            for (computedTag, 0..) |_, i| {
-                acc |= computedTag[i] ^ tag[i];
-            }
-            if (acc != 0) {
+            const verify = crypto.utils.timingSafeEql([tag_length]u8, computed_tag, tag);
+            if (!verify) {
+                crypto.utils.secureZero(u8, &computed_tag);
+                @memset(m, undefined);
                 return error.AuthenticationFailed;
             }
             ChaChaIETF(rounds_nb).xor(m[0..c.len], c, 1, k, npub);
@@ -738,12 +740,15 @@ fn XChaChaPoly1305(comptime rounds_nb: usize) type {
             return ChaChaPoly1305(rounds_nb).encrypt(c, tag, m, ad, extended.nonce, extended.key);
         }
 
-        /// m: message: output buffer should be of size c.len
-        /// c: ciphertext
-        /// tag: authentication tag
-        /// ad: Associated Data
-        /// npub: public nonce
-        /// k: private key
+        /// `m`: Message
+        /// `c`: Ciphertext
+        /// `tag`: Authentication tag
+        /// `ad`: Associated data
+        /// `npub`: Public nonce
+        /// `k`: Private key
+        /// Asserts `c.len == m.len`.
+        ///
+        /// Contents of `m` are undefined if an error is returned.
         pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) AuthenticationError!void {
             const extended = extend(k, npub, rounds_nb);
             return ChaChaPoly1305(rounds_nb).decrypt(m, c, tag, ad, extended.nonce, extended.key);
lib/std/crypto/isap.zig
@@ -147,11 +147,21 @@ pub const IsapA128A = struct {
         tag.* = mac(c, ad, npub, key);
     }
 
+    /// `m`: Message
+    /// `c`: Ciphertext
+    /// `tag`: Authentication tag
+    /// `ad`: Associated data
+    /// `npub`: Public nonce
+    /// `k`: Private key
+    /// Asserts `c.len == m.len`.
+    ///
+    /// Contents of `m` are undefined if an error is returned.
     pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) AuthenticationError!void {
         var computed_tag = mac(c, ad, npub, key);
-        const res = crypto.utils.timingSafeEql([tag_length]u8, computed_tag, tag);
-        crypto.utils.secureZero(u8, &computed_tag);
-        if (!res) {
+        const verify = crypto.utils.timingSafeEql([tag_length]u8, computed_tag, tag);
+        if (!verify) {
+            crypto.utils.secureZero(u8, &computed_tag);
+            @memset(m, undefined);
             return error.AuthenticationFailed;
         }
         xor(m, c, npub, key);
lib/std/crypto/salsa20.zig
@@ -394,12 +394,15 @@ pub const XSalsa20Poly1305 = struct {
         mac.final(tag);
     }
 
-    /// m: message: output buffer should be of size c.len
-    /// c: ciphertext
-    /// tag: authentication tag
-    /// ad: Associated Data
-    /// npub: public nonce
-    /// k: private key
+    /// `m`: Message
+    /// `c`: Ciphertext
+    /// `tag`: Authentication tag
+    /// `ad`: Associated data
+    /// `npub`: Public nonce
+    /// `k`: Private key
+    /// Asserts `c.len == m.len`.
+    ///
+    /// Contents of `m` are undefined if an error is returned.
     pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) AuthenticationError!void {
         debug.assert(c.len == m.len);
         const extended = extend(rounds, k, npub);
@@ -410,14 +413,13 @@ pub const XSalsa20Poly1305 = struct {
         var mac = Poly1305.init(block0[0..32]);
         mac.update(ad);
         mac.update(c);
-        var computedTag: [tag_length]u8 = undefined;
-        mac.final(&computedTag);
-        var acc: u8 = 0;
-        for (computedTag, 0..) |_, i| {
-            acc |= computedTag[i] ^ tag[i];
-        }
-        if (acc != 0) {
-            utils.secureZero(u8, &computedTag);
+        var computed_tag: [tag_length]u8 = undefined;
+        mac.final(&computed_tag);
+
+        const verify = utils.timingSafeEql([tag_length]u8, computed_tag, tag);
+        if (!verify) {
+            utils.secureZero(u8, &computed_tag);
+            @memset(m, undefined);
             return error.AuthenticationFailed;
         }
         @memcpy(m[0..mlen0], block0[32..][0..mlen0]);