Commit 907f3ef887

Frank Denis <124872+jedisct1@users.noreply.github.com>
2022-11-06 23:52:41
crypto.salsa20: make the number of rounds a comptime parameter (#13442)
...instead of hard-coding it to 20. - This is consistent with the ChaCha implementation. - NaCl and libsodium, which this API is designed to interoperate with, also support 8- and 12-round variants. The 12-round variant, in particular, provides the same security level as the 20-round variant, but is obviously faster. - scrypt currently uses its own non-optimized version of Salsa, just because it uses 8 rounds instead of 20. This will help remove code duplication. No behavior or public API changes. Salsa20 and XSalsa20 still represent the 20-round variant.
1 parent 4c719ad
Changed files (2)
lib
lib/std/crypto/salsa20.zig
@@ -14,297 +14,293 @@ const AuthenticationError = crypto.errors.AuthenticationError;
 const IdentityElementError = crypto.errors.IdentityElementError;
 const WeakPublicKeyError = crypto.errors.WeakPublicKeyError;
 
-const Salsa20VecImpl = struct {
-    const Lane = @Vector(4, u32);
-    const Half = @Vector(2, u32);
-    const BlockVec = [4]Lane;
-
-    fn initContext(key: [8]u32, d: [4]u32) BlockVec {
-        const c = "expand 32-byte k";
-        const constant_le = comptime [4]u32{
-            mem.readIntLittle(u32, c[0..4]),
-            mem.readIntLittle(u32, c[4..8]),
-            mem.readIntLittle(u32, c[8..12]),
-            mem.readIntLittle(u32, c[12..16]),
-        };
-        return BlockVec{
-            Lane{ key[0], key[1], key[2], key[3] },
-            Lane{ key[4], key[5], key[6], key[7] },
-            Lane{ constant_le[0], constant_le[1], constant_le[2], constant_le[3] },
-            Lane{ d[0], d[1], d[2], d[3] },
-        };
-    }
-
-    inline fn salsa20Core(x: *BlockVec, input: BlockVec, comptime feedback: bool) void {
-        const n1n2n3n0 = Lane{ input[3][1], input[3][2], input[3][3], input[3][0] };
-        const n1n2 = Half{ n1n2n3n0[0], n1n2n3n0[1] };
-        const n3n0 = Half{ n1n2n3n0[2], n1n2n3n0[3] };
-        const k0k1 = Half{ input[0][0], input[0][1] };
-        const k2k3 = Half{ input[0][2], input[0][3] };
-        const k4k5 = Half{ input[1][0], input[1][1] };
-        const k6k7 = Half{ input[1][2], input[1][3] };
-        const n0k0 = Half{ n3n0[1], k0k1[0] };
-        const k0n0 = Half{ n0k0[1], n0k0[0] };
-        const k4k5k0n0 = Lane{ k4k5[0], k4k5[1], k0n0[0], k0n0[1] };
-        const k1k6 = Half{ k0k1[1], k6k7[0] };
-        const k6k1 = Half{ k1k6[1], k1k6[0] };
-        const n1n2k6k1 = Lane{ n1n2[0], n1n2[1], k6k1[0], k6k1[1] };
-        const k7n3 = Half{ k6k7[1], n3n0[0] };
-        const n3k7 = Half{ k7n3[1], k7n3[0] };
-        const k2k3n3k7 = Lane{ k2k3[0], k2k3[1], n3k7[0], n3k7[1] };
-
-        var diag0 = input[2];
-        var diag1 = @shuffle(u32, k4k5k0n0, undefined, [_]i32{ 1, 2, 3, 0 });
-        var diag2 = @shuffle(u32, n1n2k6k1, undefined, [_]i32{ 1, 2, 3, 0 });
-        var diag3 = @shuffle(u32, k2k3n3k7, undefined, [_]i32{ 1, 2, 3, 0 });
-
-        const start0 = diag0;
-        const start1 = diag1;
-        const start2 = diag2;
-        const start3 = diag3;
-
-        var i: usize = 0;
-        while (i < 20) : (i += 2) {
-            var a0 = diag1 +% diag0;
-            diag3 ^= math.rotl(Lane, a0, 7);
-            var a1 = diag0 +% diag3;
-            diag2 ^= math.rotl(Lane, a1, 9);
-            var a2 = diag3 +% diag2;
-            diag1 ^= math.rotl(Lane, a2, 13);
-            var a3 = diag2 +% diag1;
-            diag0 ^= math.rotl(Lane, a3, 18);
-
-            var diag3_shift = @shuffle(u32, diag3, undefined, [_]i32{ 3, 0, 1, 2 });
-            var diag2_shift = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 });
-            var diag1_shift = @shuffle(u32, diag1, undefined, [_]i32{ 1, 2, 3, 0 });
-            diag3 = diag3_shift;
-            diag2 = diag2_shift;
-            diag1 = diag1_shift;
-
-            a0 = diag3 +% diag0;
-            diag1 ^= math.rotl(Lane, a0, 7);
-            a1 = diag0 +% diag1;
-            diag2 ^= math.rotl(Lane, a1, 9);
-            a2 = diag1 +% diag2;
-            diag3 ^= math.rotl(Lane, a2, 13);
-            a3 = diag2 +% diag3;
-            diag0 ^= math.rotl(Lane, a3, 18);
-
-            diag1_shift = @shuffle(u32, diag1, undefined, [_]i32{ 3, 0, 1, 2 });
-            diag2_shift = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 });
-            diag3_shift = @shuffle(u32, diag3, undefined, [_]i32{ 1, 2, 3, 0 });
-            diag1 = diag1_shift;
-            diag2 = diag2_shift;
-            diag3 = diag3_shift;
+/// The Salsa cipher with 20 rounds.
+pub const Salsa20 = Salsa(20);
+
+/// The XSalsa cipher with 20 rounds.
+pub const XSalsa20 = XSalsa(20);
+
+fn SalsaVecImpl(comptime rounds: comptime_int) type {
+    return struct {
+        const Lane = @Vector(4, u32);
+        const Half = @Vector(2, u32);
+        const BlockVec = [4]Lane;
+
+        fn initContext(key: [8]u32, d: [4]u32) BlockVec {
+            const c = "expand 32-byte k";
+            const constant_le = comptime [4]u32{
+                mem.readIntLittle(u32, c[0..4]),
+                mem.readIntLittle(u32, c[4..8]),
+                mem.readIntLittle(u32, c[8..12]),
+                mem.readIntLittle(u32, c[12..16]),
+            };
+            return BlockVec{
+                Lane{ key[0], key[1], key[2], key[3] },
+                Lane{ key[4], key[5], key[6], key[7] },
+                Lane{ constant_le[0], constant_le[1], constant_le[2], constant_le[3] },
+                Lane{ d[0], d[1], d[2], d[3] },
+            };
         }
 
-        if (feedback) {
-            diag0 +%= start0;
-            diag1 +%= start1;
-            diag2 +%= start2;
-            diag3 +%= start3;
-        }
+        inline fn salsaCore(x: *BlockVec, input: BlockVec, comptime feedback: bool) void {
+            const n1n2n3n0 = Lane{ input[3][1], input[3][2], input[3][3], input[3][0] };
+            const n1n2 = Half{ n1n2n3n0[0], n1n2n3n0[1] };
+            const n3n0 = Half{ n1n2n3n0[2], n1n2n3n0[3] };
+            const k0k1 = Half{ input[0][0], input[0][1] };
+            const k2k3 = Half{ input[0][2], input[0][3] };
+            const k4k5 = Half{ input[1][0], input[1][1] };
+            const k6k7 = Half{ input[1][2], input[1][3] };
+            const n0k0 = Half{ n3n0[1], k0k1[0] };
+            const k0n0 = Half{ n0k0[1], n0k0[0] };
+            const k4k5k0n0 = Lane{ k4k5[0], k4k5[1], k0n0[0], k0n0[1] };
+            const k1k6 = Half{ k0k1[1], k6k7[0] };
+            const k6k1 = Half{ k1k6[1], k1k6[0] };
+            const n1n2k6k1 = Lane{ n1n2[0], n1n2[1], k6k1[0], k6k1[1] };
+            const k7n3 = Half{ k6k7[1], n3n0[0] };
+            const n3k7 = Half{ k7n3[1], k7n3[0] };
+            const k2k3n3k7 = Lane{ k2k3[0], k2k3[1], n3k7[0], n3k7[1] };
+
+            var diag0 = input[2];
+            var diag1 = @shuffle(u32, k4k5k0n0, undefined, [_]i32{ 1, 2, 3, 0 });
+            var diag2 = @shuffle(u32, n1n2k6k1, undefined, [_]i32{ 1, 2, 3, 0 });
+            var diag3 = @shuffle(u32, k2k3n3k7, undefined, [_]i32{ 1, 2, 3, 0 });
+
+            const start0 = diag0;
+            const start1 = diag1;
+            const start2 = diag2;
+            const start3 = diag3;
+
+            var i: usize = 0;
+            while (i < rounds) : (i += 2) {
+                diag3 ^= math.rotl(Lane, diag1 +% diag0, 7);
+                diag2 ^= math.rotl(Lane, diag0 +% diag3, 9);
+                diag1 ^= math.rotl(Lane, diag3 +% diag2, 13);
+                diag0 ^= math.rotl(Lane, diag2 +% diag1, 18);
+
+                diag3 = @shuffle(u32, diag3, undefined, [_]i32{ 3, 0, 1, 2 });
+                diag2 = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 });
+                diag1 = @shuffle(u32, diag1, undefined, [_]i32{ 1, 2, 3, 0 });
+
+                diag1 ^= math.rotl(Lane, diag3 +% diag0, 7);
+                diag2 ^= math.rotl(Lane, diag0 +% diag1, 9);
+                diag3 ^= math.rotl(Lane, diag1 +% diag2, 13);
+                diag0 ^= math.rotl(Lane, diag2 +% diag3, 18);
+
+                diag1 = @shuffle(u32, diag1, undefined, [_]i32{ 3, 0, 1, 2 });
+                diag2 = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 });
+                diag3 = @shuffle(u32, diag3, undefined, [_]i32{ 1, 2, 3, 0 });
+            }
 
-        const x0x1x10x11 = Lane{ diag0[0], diag1[1], diag0[2], diag1[3] };
-        const x12x13x6x7 = Lane{ diag1[0], diag2[1], diag1[2], diag2[3] };
-        const x8x9x2x3 = Lane{ diag2[0], diag3[1], diag2[2], diag3[3] };
-        const x4x5x14x15 = Lane{ diag3[0], diag0[1], diag3[2], diag0[3] };
+            if (feedback) {
+                diag0 +%= start0;
+                diag1 +%= start1;
+                diag2 +%= start2;
+                diag3 +%= start3;
+            }
 
-        x[0] = Lane{ x0x1x10x11[0], x0x1x10x11[1], x8x9x2x3[2], x8x9x2x3[3] };
-        x[1] = Lane{ x4x5x14x15[0], x4x5x14x15[1], x12x13x6x7[2], x12x13x6x7[3] };
-        x[2] = Lane{ x8x9x2x3[0], x8x9x2x3[1], x0x1x10x11[2], x0x1x10x11[3] };
-        x[3] = Lane{ x12x13x6x7[0], x12x13x6x7[1], x4x5x14x15[2], x4x5x14x15[3] };
-    }
+            const x0x1x10x11 = Lane{ diag0[0], diag1[1], diag0[2], diag1[3] };
+            const x12x13x6x7 = Lane{ diag1[0], diag2[1], diag1[2], diag2[3] };
+            const x8x9x2x3 = Lane{ diag2[0], diag3[1], diag2[2], diag3[3] };
+            const x4x5x14x15 = Lane{ diag3[0], diag0[1], diag3[2], diag0[3] };
 
-    fn hashToBytes(out: *[64]u8, x: BlockVec) void {
-        var i: usize = 0;
-        while (i < 4) : (i += 1) {
-            mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
-            mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
-            mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
-            mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
+            x[0] = Lane{ x0x1x10x11[0], x0x1x10x11[1], x8x9x2x3[2], x8x9x2x3[3] };
+            x[1] = Lane{ x4x5x14x15[0], x4x5x14x15[1], x12x13x6x7[2], x12x13x6x7[3] };
+            x[2] = Lane{ x8x9x2x3[0], x8x9x2x3[1], x0x1x10x11[2], x0x1x10x11[3] };
+            x[3] = Lane{ x12x13x6x7[0], x12x13x6x7[1], x4x5x14x15[2], x4x5x14x15[3] };
         }
-    }
 
-    fn salsa20Xor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void {
-        var ctx = initContext(key, d);
-        var x: BlockVec = undefined;
-        var buf: [64]u8 = undefined;
-        var i: usize = 0;
-        while (i + 64 <= in.len) : (i += 64) {
-            salsa20Core(x[0..], ctx, true);
-            hashToBytes(buf[0..], x);
-            var xout = out[i..];
-            const xin = in[i..];
-            var j: usize = 0;
-            while (j < 64) : (j += 1) {
-                xout[j] = xin[j];
+        fn hashToBytes(out: *[64]u8, x: BlockVec) void {
+            var i: usize = 0;
+            while (i < 4) : (i += 1) {
+                mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
+                mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
+                mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
+                mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
             }
-            j = 0;
-            while (j < 64) : (j += 1) {
-                xout[j] ^= buf[j];
+        }
+
+        fn salsaXor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void {
+            var ctx = initContext(key, d);
+            var x: BlockVec = undefined;
+            var buf: [64]u8 = undefined;
+            var i: usize = 0;
+            while (i + 64 <= in.len) : (i += 64) {
+                salsaCore(x[0..], ctx, true);
+                hashToBytes(buf[0..], x);
+                var xout = out[i..];
+                const xin = in[i..];
+                var j: usize = 0;
+                while (j < 64) : (j += 1) {
+                    xout[j] = xin[j];
+                }
+                j = 0;
+                while (j < 64) : (j += 1) {
+                    xout[j] ^= buf[j];
+                }
+                ctx[3][2] +%= 1;
+                if (ctx[3][2] == 0) {
+                    ctx[3][3] += 1;
+                }
             }
-            ctx[3][2] +%= 1;
-            if (ctx[3][2] == 0) {
-                ctx[3][3] += 1;
+            if (i < in.len) {
+                salsaCore(x[0..], ctx, true);
+                hashToBytes(buf[0..], x);
+
+                var xout = out[i..];
+                const xin = in[i..];
+                var j: usize = 0;
+                while (j < in.len % 64) : (j += 1) {
+                    xout[j] = xin[j] ^ buf[j];
+                }
             }
         }
-        if (i < in.len) {
-            salsa20Core(x[0..], ctx, true);
-            hashToBytes(buf[0..], x);
 
-            var xout = out[i..];
-            const xin = in[i..];
-            var j: usize = 0;
-            while (j < in.len % 64) : (j += 1) {
-                xout[j] = xin[j] ^ buf[j];
+        fn hsalsa(input: [16]u8, key: [32]u8) [32]u8 {
+            var c: [4]u32 = undefined;
+            for (c) |_, i| {
+                c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
             }
+            const ctx = initContext(keyToWords(key), c);
+            var x: BlockVec = undefined;
+            salsaCore(x[0..], ctx, false);
+            var out: [32]u8 = undefined;
+            mem.writeIntLittle(u32, out[0..4], x[0][0]);
+            mem.writeIntLittle(u32, out[4..8], x[1][1]);
+            mem.writeIntLittle(u32, out[8..12], x[2][2]);
+            mem.writeIntLittle(u32, out[12..16], x[3][3]);
+            mem.writeIntLittle(u32, out[16..20], x[1][2]);
+            mem.writeIntLittle(u32, out[20..24], x[1][3]);
+            mem.writeIntLittle(u32, out[24..28], x[2][0]);
+            mem.writeIntLittle(u32, out[28..32], x[2][1]);
+            return out;
         }
-    }
+    };
+}
 
-    fn hsalsa20(input: [16]u8, key: [32]u8) [32]u8 {
-        var c: [4]u32 = undefined;
-        for (c) |_, i| {
-            c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
+fn SalsaNonVecImpl(comptime rounds: comptime_int) type {
+    return struct {
+        const BlockVec = [16]u32;
+
+        fn initContext(key: [8]u32, d: [4]u32) BlockVec {
+            const c = "expand 32-byte k";
+            const constant_le = comptime [4]u32{
+                mem.readIntLittle(u32, c[0..4]),
+                mem.readIntLittle(u32, c[4..8]),
+                mem.readIntLittle(u32, c[8..12]),
+                mem.readIntLittle(u32, c[12..16]),
+            };
+            return BlockVec{
+                constant_le[0], key[0],         key[1],         key[2],
+                key[3],         constant_le[1], d[0],           d[1],
+                d[2],           d[3],           constant_le[2], key[4],
+                key[5],         key[6],         key[7],         constant_le[3],
+            };
         }
-        const ctx = initContext(keyToWords(key), c);
-        var x: BlockVec = undefined;
-        salsa20Core(x[0..], ctx, false);
-        var out: [32]u8 = undefined;
-        mem.writeIntLittle(u32, out[0..4], x[0][0]);
-        mem.writeIntLittle(u32, out[4..8], x[1][1]);
-        mem.writeIntLittle(u32, out[8..12], x[2][2]);
-        mem.writeIntLittle(u32, out[12..16], x[3][3]);
-        mem.writeIntLittle(u32, out[16..20], x[1][2]);
-        mem.writeIntLittle(u32, out[20..24], x[1][3]);
-        mem.writeIntLittle(u32, out[24..28], x[2][0]);
-        mem.writeIntLittle(u32, out[28..32], x[2][1]);
-        return out;
-    }
-};
 
-const Salsa20NonVecImpl = struct {
-    const BlockVec = [16]u32;
-
-    fn initContext(key: [8]u32, d: [4]u32) BlockVec {
-        const c = "expand 32-byte k";
-        const constant_le = comptime [4]u32{
-            mem.readIntLittle(u32, c[0..4]),
-            mem.readIntLittle(u32, c[4..8]),
-            mem.readIntLittle(u32, c[8..12]),
-            mem.readIntLittle(u32, c[12..16]),
-        };
-        return BlockVec{
-            constant_le[0], key[0],         key[1],         key[2],
-            key[3],         constant_le[1], d[0],           d[1],
-            d[2],           d[3],           constant_le[2], key[4],
-            key[5],         key[6],         key[7],         constant_le[3],
+        const QuarterRound = struct {
+            a: usize,
+            b: usize,
+            c: usize,
+            d: u6,
         };
-    }
 
-    const QuarterRound = struct {
-        a: usize,
-        b: usize,
-        c: usize,
-        d: u6,
-    };
-
-    inline fn Rp(a: usize, b: usize, c: usize, d: u6) QuarterRound {
-        return QuarterRound{
-            .a = a,
-            .b = b,
-            .c = c,
-            .d = d,
-        };
-    }
+        inline fn Rp(a: usize, b: usize, c: usize, d: u6) QuarterRound {
+            return QuarterRound{
+                .a = a,
+                .b = b,
+                .c = c,
+                .d = d,
+            };
+        }
 
-    inline fn salsa20Core(x: *BlockVec, input: BlockVec, comptime feedback: bool) void {
-        const arx_steps = comptime [_]QuarterRound{
-            Rp(4, 0, 12, 7),   Rp(8, 4, 0, 9),    Rp(12, 8, 4, 13),   Rp(0, 12, 8, 18),
-            Rp(9, 5, 1, 7),    Rp(13, 9, 5, 9),   Rp(1, 13, 9, 13),   Rp(5, 1, 13, 18),
-            Rp(14, 10, 6, 7),  Rp(2, 14, 10, 9),  Rp(6, 2, 14, 13),   Rp(10, 6, 2, 18),
-            Rp(3, 15, 11, 7),  Rp(7, 3, 15, 9),   Rp(11, 7, 3, 13),   Rp(15, 11, 7, 18),
-            Rp(1, 0, 3, 7),    Rp(2, 1, 0, 9),    Rp(3, 2, 1, 13),    Rp(0, 3, 2, 18),
-            Rp(6, 5, 4, 7),    Rp(7, 6, 5, 9),    Rp(4, 7, 6, 13),    Rp(5, 4, 7, 18),
-            Rp(11, 10, 9, 7),  Rp(8, 11, 10, 9),  Rp(9, 8, 11, 13),   Rp(10, 9, 8, 18),
-            Rp(12, 15, 14, 7), Rp(13, 12, 15, 9), Rp(14, 13, 12, 13), Rp(15, 14, 13, 18),
-        };
-        x.* = input;
-        var j: usize = 0;
-        while (j < 20) : (j += 2) {
-            inline for (arx_steps) |r| {
-                x[r.a] ^= math.rotl(u32, x[r.b] +% x[r.c], r.d);
+        inline fn salsaCore(x: *BlockVec, input: BlockVec, comptime feedback: bool) void {
+            const arx_steps = comptime [_]QuarterRound{
+                Rp(4, 0, 12, 7),   Rp(8, 4, 0, 9),    Rp(12, 8, 4, 13),   Rp(0, 12, 8, 18),
+                Rp(9, 5, 1, 7),    Rp(13, 9, 5, 9),   Rp(1, 13, 9, 13),   Rp(5, 1, 13, 18),
+                Rp(14, 10, 6, 7),  Rp(2, 14, 10, 9),  Rp(6, 2, 14, 13),   Rp(10, 6, 2, 18),
+                Rp(3, 15, 11, 7),  Rp(7, 3, 15, 9),   Rp(11, 7, 3, 13),   Rp(15, 11, 7, 18),
+                Rp(1, 0, 3, 7),    Rp(2, 1, 0, 9),    Rp(3, 2, 1, 13),    Rp(0, 3, 2, 18),
+                Rp(6, 5, 4, 7),    Rp(7, 6, 5, 9),    Rp(4, 7, 6, 13),    Rp(5, 4, 7, 18),
+                Rp(11, 10, 9, 7),  Rp(8, 11, 10, 9),  Rp(9, 8, 11, 13),   Rp(10, 9, 8, 18),
+                Rp(12, 15, 14, 7), Rp(13, 12, 15, 9), Rp(14, 13, 12, 13), Rp(15, 14, 13, 18),
+            };
+            x.* = input;
+            var j: usize = 0;
+            while (j < rounds) : (j += 2) {
+                inline for (arx_steps) |r| {
+                    x[r.a] ^= math.rotl(u32, x[r.b] +% x[r.c], r.d);
+                }
             }
-        }
-        if (feedback) {
-            j = 0;
-            while (j < 16) : (j += 1) {
-                x[j] +%= input[j];
+            if (feedback) {
+                j = 0;
+                while (j < 16) : (j += 1) {
+                    x[j] +%= input[j];
+                }
             }
         }
-    }
 
-    fn hashToBytes(out: *[64]u8, x: BlockVec) void {
-        for (x) |w, i| {
-            mem.writeIntLittle(u32, out[i * 4 ..][0..4], w);
+        fn hashToBytes(out: *[64]u8, x: BlockVec) void {
+            for (x) |w, i| {
+                mem.writeIntLittle(u32, out[i * 4 ..][0..4], w);
+            }
         }
-    }
 
-    fn salsa20Xor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void {
-        var ctx = initContext(key, d);
-        var x: BlockVec = undefined;
-        var buf: [64]u8 = undefined;
-        var i: usize = 0;
-        while (i + 64 <= in.len) : (i += 64) {
-            salsa20Core(x[0..], ctx, true);
-            hashToBytes(buf[0..], x);
-            var xout = out[i..];
-            const xin = in[i..];
-            var j: usize = 0;
-            while (j < 64) : (j += 1) {
-                xout[j] = xin[j];
+        fn salsaXor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void {
+            var ctx = initContext(key, d);
+            var x: BlockVec = undefined;
+            var buf: [64]u8 = undefined;
+            var i: usize = 0;
+            while (i + 64 <= in.len) : (i += 64) {
+                salsaCore(x[0..], ctx, true);
+                hashToBytes(buf[0..], x);
+                var xout = out[i..];
+                const xin = in[i..];
+                var j: usize = 0;
+                while (j < 64) : (j += 1) {
+                    xout[j] = xin[j];
+                }
+                j = 0;
+                while (j < 64) : (j += 1) {
+                    xout[j] ^= buf[j];
+                }
+                ctx[9] += @boolToInt(@addWithOverflow(u32, ctx[8], 1, &ctx[8]));
             }
-            j = 0;
-            while (j < 64) : (j += 1) {
-                xout[j] ^= buf[j];
+            if (i < in.len) {
+                salsaCore(x[0..], ctx, true);
+                hashToBytes(buf[0..], x);
+
+                var xout = out[i..];
+                const xin = in[i..];
+                var j: usize = 0;
+                while (j < in.len % 64) : (j += 1) {
+                    xout[j] = xin[j] ^ buf[j];
+                }
             }
-            ctx[9] += @boolToInt(@addWithOverflow(u32, ctx[8], 1, &ctx[8]));
         }
-        if (i < in.len) {
-            salsa20Core(x[0..], ctx, true);
-            hashToBytes(buf[0..], x);
 
-            var xout = out[i..];
-            const xin = in[i..];
-            var j: usize = 0;
-            while (j < in.len % 64) : (j += 1) {
-                xout[j] = xin[j] ^ buf[j];
+        fn hsalsa(input: [16]u8, key: [32]u8) [32]u8 {
+            var c: [4]u32 = undefined;
+            for (c) |_, i| {
+                c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
             }
+            const ctx = initContext(keyToWords(key), c);
+            var x: BlockVec = undefined;
+            salsaCore(x[0..], ctx, false);
+            var out: [32]u8 = undefined;
+            mem.writeIntLittle(u32, out[0..4], x[0]);
+            mem.writeIntLittle(u32, out[4..8], x[5]);
+            mem.writeIntLittle(u32, out[8..12], x[10]);
+            mem.writeIntLittle(u32, out[12..16], x[15]);
+            mem.writeIntLittle(u32, out[16..20], x[6]);
+            mem.writeIntLittle(u32, out[20..24], x[7]);
+            mem.writeIntLittle(u32, out[24..28], x[8]);
+            mem.writeIntLittle(u32, out[28..32], x[9]);
+            return out;
         }
-    }
-
-    fn hsalsa20(input: [16]u8, key: [32]u8) [32]u8 {
-        var c: [4]u32 = undefined;
-        for (c) |_, i| {
-            c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
-        }
-        const ctx = initContext(keyToWords(key), c);
-        var x: BlockVec = undefined;
-        salsa20Core(x[0..], ctx, false);
-        var out: [32]u8 = undefined;
-        mem.writeIntLittle(u32, out[0..4], x[0]);
-        mem.writeIntLittle(u32, out[4..8], x[5]);
-        mem.writeIntLittle(u32, out[8..12], x[10]);
-        mem.writeIntLittle(u32, out[12..16], x[15]);
-        mem.writeIntLittle(u32, out[16..20], x[6]);
-        mem.writeIntLittle(u32, out[20..24], x[7]);
-        mem.writeIntLittle(u32, out[24..28], x[8]);
-        mem.writeIntLittle(u32, out[28..32], x[9]);
-        return out;
-    }
-};
+    };
+}
 
-const Salsa20Impl = if (builtin.cpu.arch == .x86_64) Salsa20VecImpl else Salsa20NonVecImpl;
+const SalsaImpl = if (builtin.cpu.arch == .x86_64) SalsaVecImpl else SalsaNonVecImpl;
 
 fn keyToWords(key: [32]u8) [8]u32 {
     var k: [8]u32 = undefined;
@@ -315,52 +311,56 @@ fn keyToWords(key: [32]u8) [8]u32 {
     return k;
 }
 
-fn extend(key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [8]u8 } {
+fn extend(comptime rounds: comptime_int, key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [8]u8 } {
     return .{
-        .key = Salsa20Impl.hsalsa20(nonce[0..16].*, key),
+        .key = SalsaImpl(rounds).hsalsa(nonce[0..16].*, key),
         .nonce = nonce[16..24].*,
     };
 }
 
-/// The Salsa20 stream cipher.
-pub const Salsa20 = struct {
-    /// Nonce length in bytes.
-    pub const nonce_length = 8;
-    /// Key length in bytes.
-    pub const key_length = 32;
-
-    /// Add the output of the Salsa20 stream cipher to `in` and stores the result into `out`.
-    /// WARNING: This function doesn't provide authenticated encryption.
-    /// Using the AEAD or one of the `box` versions is usually preferred.
-    pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void {
-        debug.assert(in.len == out.len);
-
-        var d: [4]u32 = undefined;
-        d[0] = mem.readIntLittle(u32, nonce[0..4]);
-        d[1] = mem.readIntLittle(u32, nonce[4..8]);
-        d[2] = @truncate(u32, counter);
-        d[3] = @truncate(u32, counter >> 32);
-        Salsa20Impl.salsa20Xor(out, in, keyToWords(key), d);
-    }
-};
+/// The Salsa stream cipher.
+pub fn Salsa(comptime rounds: comptime_int) type {
+    return struct {
+        /// Nonce length in bytes.
+        pub const nonce_length = 8;
+        /// Key length in bytes.
+        pub const key_length = 32;
+
+        /// Add the output of the Salsa stream cipher to `in` and stores the result into `out`.
+        /// WARNING: This function doesn't provide authenticated encryption.
+        /// Using the AEAD or one of the `box` versions is usually preferred.
+        pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void {
+            debug.assert(in.len == out.len);
+
+            var d: [4]u32 = undefined;
+            d[0] = mem.readIntLittle(u32, nonce[0..4]);
+            d[1] = mem.readIntLittle(u32, nonce[4..8]);
+            d[2] = @truncate(u32, counter);
+            d[3] = @truncate(u32, counter >> 32);
+            SalsaImpl(rounds).salsaXor(out, in, keyToWords(key), d);
+        }
+    };
+}
 
-/// The XSalsa20 stream cipher.
-pub const XSalsa20 = struct {
-    /// Nonce length in bytes.
-    pub const nonce_length = 24;
-    /// Key length in bytes.
-    pub const key_length = 32;
-
-    /// Add the output of the XSalsa20 stream cipher to `in` and stores the result into `out`.
-    /// WARNING: This function doesn't provide authenticated encryption.
-    /// Using the AEAD or one of the `box` versions is usually preferred.
-    pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void {
-        const extended = extend(key, nonce);
-        Salsa20.xor(out, in, counter, extended.key, extended.nonce);
-    }
-};
+/// The XSalsa stream cipher.
+pub fn XSalsa(comptime rounds: comptime_int) type {
+    return struct {
+        /// Nonce length in bytes.
+        pub const nonce_length = 24;
+        /// Key length in bytes.
+        pub const key_length = 32;
+
+        /// Add the output of the XSalsa stream cipher to `in` and stores the result into `out`.
+        /// WARNING: This function doesn't provide authenticated encryption.
+        /// Using the AEAD or one of the `box` versions is usually preferred.
+        pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void {
+            const extended = extend(rounds, key, nonce);
+            Salsa(rounds).xor(out, in, counter, extended.key, extended.nonce);
+        }
+    };
+}
 
-/// The XSalsa20 stream cipher, combined with the Poly1305 MAC
+/// The XSalsa stream cipher, combined with the Poly1305 MAC
 pub const XSalsa20Poly1305 = struct {
     /// Authentication tag length in bytes.
     pub const tag_length = Poly1305.mac_length;
@@ -369,6 +369,8 @@ pub const XSalsa20Poly1305 = struct {
     /// Key length in bytes.
     pub const key_length = XSalsa20.key_length;
 
+    const rounds = 20;
+
     /// c: ciphertext: output buffer should be of size m.len
     /// tag: authentication tag: output MAC
     /// m: message
@@ -377,7 +379,7 @@ pub const XSalsa20Poly1305 = struct {
     /// k: private key
     pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) void {
         debug.assert(c.len == m.len);
-        const extended = extend(k, npub);
+        const extended = extend(rounds, k, npub);
         var block0 = [_]u8{0} ** 64;
         const mlen0 = math.min(32, m.len);
         mem.copy(u8, block0[32..][0..mlen0], m[0..mlen0]);
@@ -398,7 +400,7 @@ pub const XSalsa20Poly1305 = struct {
     /// k: private key
     pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) AuthenticationError!void {
         debug.assert(c.len == m.len);
-        const extended = extend(k, npub);
+        const extended = extend(rounds, k, npub);
         var block0 = [_]u8{0} ** 64;
         const mlen0 = math.min(32, c.len);
         mem.copy(u8, block0[32..][0..mlen0], c[0..mlen0]);
@@ -482,7 +484,7 @@ pub const Box = struct {
     pub fn createSharedSecret(public_key: [public_length]u8, secret_key: [secret_length]u8) (IdentityElementError || WeakPublicKeyError)![shared_length]u8 {
         const p = try X25519.scalarmult(secret_key, public_key);
         const zero = [_]u8{0} ** 16;
-        return Salsa20Impl.hsalsa20(zero, p);
+        return SalsaImpl(20).hsalsa(zero, p);
     }
 
     /// Encrypt and authenticate a message using a recipient's public key `public_key` and a sender's `secret_key`.
lib/std/crypto.zig
@@ -147,6 +147,8 @@ pub const stream = struct {
     };
 
     pub const salsa = struct {
+        pub const Salsa = @import("crypto/salsa20.zig").Salsa;
+        pub const XSalsa = @import("crypto/salsa20.zig").XSalsa;
         pub const Salsa20 = @import("crypto/salsa20.zig").Salsa20;
         pub const XSalsa20 = @import("crypto/salsa20.zig").XSalsa20;
     };