Commit 1d20ada366

Andrew Kelley <andrew@ziglang.org>
2022-12-29 03:54:17
std.crypto.tls.Client: refactor to reduce namespace bloat
1 parent 16af628
Changed files (1)
lib
std
crypto
lib/std/crypto/tls/Client.zig
@@ -5,18 +5,14 @@ const net = std.net;
 const mem = std.mem;
 const crypto = std.crypto;
 const assert = std.debug.assert;
+const Certificate = std.crypto.Certificate;
 
-const ApplicationCipher = tls.ApplicationCipher;
-const CipherSuite = tls.CipherSuite;
-const ContentType = tls.ContentType;
-const HandshakeCipher = tls.HandshakeCipher;
 const max_ciphertext_len = tls.max_ciphertext_len;
 const hkdfExpandLabel = tls.hkdfExpandLabel;
 const int2 = tls.int2;
 const int3 = tls.int3;
 const array = tls.array;
 const enum_array = tls.enum_array;
-const Certificate = crypto.Certificate;
 
 read_seq: u64,
 write_seq: u64,
@@ -27,7 +23,7 @@ partially_read_len: u15,
 /// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by
 /// the read() API user is not large enough.
 partial_cleartext_index: u15,
-application_cipher: ApplicationCipher,
+application_cipher: tls.ApplicationCipher,
 eof: bool,
 /// The size is enough to contain exactly one TLSCiphertext record.
 /// Contains encrypted bytes.
@@ -101,7 +97,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
         client_hello;
 
     const plaintext_header = [_]u8{
-        @enumToInt(ContentType.handshake),
+        @enumToInt(tls.ContentType.handshake),
         0x03, 0x01, // legacy_record_version
     } ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake;
 
@@ -121,7 +117,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
 
     const client_hello_bytes1 = plaintext_header[5..];
 
-    var handshake_cipher: HandshakeCipher = undefined;
+    var handshake_cipher: tls.HandshakeCipher = undefined;
 
     var handshake_buf: [8000]u8 = undefined;
     var len: usize = 0;
@@ -129,7 +125,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
         const plaintext = handshake_buf[0..5];
         len = try stream.readAtLeast(&handshake_buf, plaintext.len);
         if (len < plaintext.len) return error.EndOfStream;
-        const ct = @intToEnum(ContentType, plaintext[0]);
+        const ct = @intToEnum(tls.ContentType, plaintext[0]);
         const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
         const end = plaintext.len + frag_len;
         if (end > handshake_buf.len) return error.TlsRecordOverflow;
@@ -169,7 +165,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                 i += 32;
                 const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]);
                 i += 2;
-                const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int);
+                const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int);
                 const legacy_compression_method = frag[i];
                 i += 1;
                 _ = legacy_compression_method;
@@ -247,8 +243,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                     .AEGIS_256_SHA384,
                     .AEGIS_128L_SHA256,
                     => |tag| {
-                        const P = std.meta.TagPayloadByName(HandshakeCipher, @tagName(tag));
-                        handshake_cipher = @unionInit(HandshakeCipher, @tagName(tag), .{
+                        const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag));
+                        handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{
                             .handshake_secret = undefined,
                             .master_secret = undefined,
                             .client_handshake_key = undefined,
@@ -338,7 +334,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
             len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
             if (end_hdr > len) return error.EndOfStream;
         }
-        const ct = @intToEnum(ContentType, handshake_buf[i]);
+        const ct = @intToEnum(tls.ContentType, handshake_buf[i]);
         i += 1;
         const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
         i += 2;
@@ -380,7 +376,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                     },
                 };
 
-                const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
+                const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
                 switch (inner_ct) {
                     .handshake => {
                         var ct_i: usize = 0;
@@ -546,7 +542,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                     if (handshake_state != .finished) return error.TlsUnexpectedMessage;
                                     // This message is to trick buggy proxies into behaving correctly.
                                     const client_change_cipher_spec_msg = [_]u8{
-                                        @enumToInt(ContentType.change_cipher_spec),
+                                        @enumToInt(tls.ContentType.change_cipher_spec),
                                         0x03, 0x03, // legacy protocol version
                                         0x00, 0x01, // length
                                         0x01,
@@ -564,12 +560,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                             const out_cleartext = [_]u8{
                                                 @enumToInt(tls.HandshakeType.finished),
                                                 0, 0, verify_data.len, // length
-                                            } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)};
+                                            } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
 
                                             const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
 
                                             var finished_msg = [_]u8{
-                                                @enumToInt(ContentType.application_data),
+                                                @enumToInt(tls.ContentType.application_data),
                                                 0x03, 0x03, // legacy protocol version
                                                 0, wrapped_len, // byte length of encrypted record
                                             } ++ @as([wrapped_len]u8, undefined);
@@ -590,7 +586,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                             //    std.fmt.fmtSliceHexLower(&client_secret),
                                             //    std.fmt.fmtSliceHexLower(&server_secret),
                                             //});
-                                            break :c @unionInit(ApplicationCipher, @tagName(tag), .{
+                                            break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
                                                 .client_secret = client_secret,
                                                 .server_secret = server_secret,
                                                 .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
@@ -661,7 +657,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
                 if (encrypted_content_len == 0) break :l overhead_len;
 
                 mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
-                cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
+                cleartext_buf[encrypted_content_len] = @enumToInt(tls.ContentType.application_data);
                 bytes_i += encrypted_content_len;
                 const ciphertext_len = encrypted_content_len + 1;
                 const cleartext = cleartext_buf[0..ciphertext_len];
@@ -669,7 +665,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
                 const record_start = ciphertext_end;
                 const ad = ciphertext_buf[ciphertext_end..][0..5];
                 ad.* =
-                    [_]u8{@enumToInt(ContentType.application_data)} ++
+                    [_]u8{@enumToInt(tls.ContentType.application_data)} ++
                     int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++
                     int2(ciphertext_len + P.AEAD.tag_length);
                 ciphertext_end += ad.len;
@@ -818,7 +814,7 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize {
             return finishRead(c, frag, in, out);
         }
         const record_start = in;
-        const ct = @intToEnum(ContentType, frag[in]);
+        const ct = @intToEnum(tls.ContentType, frag[in]);
         in += 1;
         const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
         in += 2;
@@ -861,7 +857,7 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                     },
                 };
 
-                const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
+                const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
                 switch (inner_ct) {
                     .alert => {
                         c.read_seq += 1;