Commit 21ab99174e

Andrew Kelley <andrew@ziglang.org>
2022-12-28 07:49:15
std.crypto.tls.Client: use enums more
1 parent 477864d
Changed files (2)
lib
std
lib/std/crypto/tls/Client.zig
@@ -9,7 +9,6 @@ const assert = std.debug.assert;
 const ApplicationCipher = tls.ApplicationCipher;
 const CipherSuite = tls.CipherSuite;
 const ContentType = tls.ContentType;
-const HandshakeType = tls.HandshakeType;
 const HandshakeCipher = tls.HandshakeCipher;
 const max_ciphertext_len = tls.max_ciphertext_len;
 const hkdfExpandLabel = tls.hkdfExpandLabel;
@@ -91,7 +90,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
         extensions_header;
 
     const out_handshake =
-        [_]u8{@enumToInt(HandshakeType.client_hello)} ++
+        [_]u8{@enumToInt(tls.HandshakeType.client_hello)} ++
         int3(@intCast(u24, client_hello.len + host_len)) ++
         client_hello;
 
@@ -142,7 +141,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                 return error.TlsAlert;
             },
             .handshake => {
-                if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
+                if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) {
                     return error.TlsUnexpectedMessage;
                 }
                 const length = mem.readIntBig(u24, frag[1..4]);
@@ -175,27 +174,27 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                 var shared_key: [32]u8 = undefined;
                 var have_shared_key = false;
                 while (i < frag.len) {
-                    const et = mem.readIntBig(u16, frag[i..][0..2]);
+                    const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2]));
                     i += 2;
                     const ext_size = mem.readIntBig(u16, frag[i..][0..2]);
                     i += 2;
                     const next_i = i + ext_size;
                     if (next_i > frag.len) return error.TlsBadLength;
                     switch (et) {
-                        @enumToInt(tls.ExtensionType.supported_versions) => {
+                        .supported_versions => {
                             if (supported_version != 0) return error.TlsIllegalParameter;
                             supported_version = mem.readIntBig(u16, frag[i..][0..2]);
                         },
-                        @enumToInt(tls.ExtensionType.key_share) => {
+                        .key_share => {
                             if (have_shared_key) return error.TlsIllegalParameter;
                             have_shared_key = true;
-                            const named_group = mem.readIntBig(u16, frag[i..][0..2]);
+                            const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2]));
                             i += 2;
                             const key_size = mem.readIntBig(u16, frag[i..][0..2]);
                             i += 2;
 
                             switch (named_group) {
-                                @enumToInt(tls.NamedGroup.x25519) => {
+                                .x25519 => {
                                     if (key_size != 32) return error.TlsBadLength;
                                     const server_pub_key = frag[i..][0..32];
 
@@ -204,7 +203,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                         server_pub_key.*,
                                     ) catch return error.TlsDecryptFailure;
                                 },
-                                @enumToInt(tls.NamedGroup.secp256r1) => {
+                                .secp256r1 => {
                                     const server_pub_key = frag[i..][0..key_size];
 
                                     const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
@@ -217,7 +216,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                     shared_key = mul.affineCoordinates().x.toBytes(.Big);
                                 },
                                 else => {
-                                    std.debug.print("named group: {x}\n", .{named_group});
+                                    //std.debug.print("named group: {x}\n", .{named_group});
                                     return error.TlsIllegalParameter;
                                 },
                             }
@@ -380,7 +379,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                     .handshake => {
                         var ct_i: usize = 0;
                         while (true) {
-                            const handshake_type = cleartext[ct_i];
+                            const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
                             ct_i += 1;
                             const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
                             ct_i += 3;
@@ -390,7 +389,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                             const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i];
                             const handshake = cleartext[ct_i..next_handshake_i];
                             switch (handshake_type) {
-                                @enumToInt(HandshakeType.encrypted_extensions) => {
+                                .encrypted_extensions => {
                                     if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
                                     handshake_state = .certificate;
                                     switch (handshake_cipher) {
@@ -400,13 +399,13 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                     var hs_i: usize = 2;
                                     const end_ext_i = 2 + total_ext_size;
                                     while (hs_i < end_ext_i) {
-                                        const et = mem.readIntBig(u16, handshake[hs_i..][0..2]);
+                                        const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2]));
                                         hs_i += 2;
                                         const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
                                         hs_i += 2;
                                         const next_ext_i = hs_i + ext_size;
                                         switch (et) {
-                                            @enumToInt(tls.ExtensionType.server_name) => {},
+                                            .server_name => {},
                                             else => {
                                                 std.debug.print("encrypted extension: {any}\n", .{
                                                     et,
@@ -416,7 +415,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                         hs_i = next_ext_i;
                                     }
                                 },
-                                @enumToInt(HandshakeType.certificate) => cert: {
+                                .certificate => cert: {
                                     switch (handshake_cipher) {
                                         inline else => |*p| p.transcript_hash.update(wrapped_handshake),
                                     }
@@ -488,7 +487,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                         hs_i += total_ext_size;
                                     }
                                 },
-                                @enumToInt(HandshakeType.certificate_verify) => {
+                                .certificate_verify => {
                                     switch (handshake_state) {
                                         .trust_chain_established => handshake_state = .finished,
                                         .certificate => return error.TlsCertificateNotVerified,
@@ -535,7 +534,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                         },
                                     }
                                 },
-                                @enumToInt(HandshakeType.finished) => {
+                                .finished => {
                                     if (handshake_state != .finished) return error.TlsUnexpectedMessage;
                                     // This message is to trick buggy proxies into behaving correctly.
                                     const client_change_cipher_spec_msg = [_]u8{
@@ -555,7 +554,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                             const handshake_hash = p.transcript_hash.finalResult();
                                             const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
                                             const out_cleartext = [_]u8{
-                                                @enumToInt(HandshakeType.finished),
+                                                @enumToInt(tls.HandshakeType.finished),
                                                 0, 0, verify_data.len, // length
                                             } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)};
 
@@ -810,7 +809,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                     .handshake => {
                         var ct_i: usize = 0;
                         while (true) {
-                            const handshake_type = cleartext[ct_i];
+                            const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
                             ct_i += 1;
                             const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
                             ct_i += 3;
@@ -819,10 +818,10 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                                 return error.TlsBadLength;
                             const handshake = cleartext[ct_i..next_handshake_i];
                             switch (handshake_type) {
-                                @enumToInt(HandshakeType.new_session_ticket) => {
+                                .new_session_ticket => {
                                     std.debug.print("server sent a new session ticket\n", .{});
                                 },
-                                @enumToInt(HandshakeType.key_update) => {
+                                .key_update => {
                                     switch (c.application_cipher) {
                                         inline else => |*p| {
                                             const P = @TypeOf(p.*);
lib/std/crypto/tls.zig
@@ -74,6 +74,7 @@ pub const HandshakeType = enum(u8) {
     finished = 20,
     key_update = 24,
     message_hash = 254,
+    _,
 };
 
 pub const ExtensionType = enum(u16) {
@@ -121,6 +122,8 @@ pub const ExtensionType = enum(u16) {
     signature_algorithms_cert = 50,
     /// RFC 8446
     key_share = 51,
+
+    _,
 };
 
 pub const AlertLevel = enum(u8) {