Commit 0e5e6cb10c

Nameless <truemedian@gmail.com>
2023-05-28 09:37:56
std.http: add TlsAlert descriptions so that they can at least be viewed in err return traces
1 parent 8136123
Changed files (5)
lib/std/crypto/tls/Client.zig
@@ -89,12 +89,11 @@ pub const StreamInterface = struct {
 };
 
 pub fn InitError(comptime Stream: type) type {
-    return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{
+    return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
         InsufficientEntropy,
         DiskQuota,
         LockViolation,
         NotOpenForWriting,
-        TlsAlert,
         TlsUnexpectedMessage,
         TlsIllegalParameter,
         TlsDecryptFailure,
@@ -251,8 +250,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                 const level = ptd.decode(tls.AlertLevel);
                 const desc = ptd.decode(tls.AlertDescription);
                 _ = level;
-                _ = desc;
-                return error.TlsAlert;
+
+                // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
+                try desc.toError();
+                // TODO: handle server-side closures
+                return error.TlsUnexpectedMessage;
             },
             .handshake => {
                 try ptd.ensure(4);
@@ -1071,8 +1073,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
                 const level = @intToEnum(tls.AlertLevel, frag[in]);
                 const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
                 _ = level;
-                _ = desc;
-                return error.TlsAlert;
+
+                try desc.toError();
+                // TODO: handle server-side closures
+                return error.TlsUnexpectedMessage;
             },
             .application_data => {
                 const cleartext = switch (c.application_cipher) {
@@ -1112,7 +1116,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
                             return vp.total;
                         }
                         _ = level;
-                        return error.TlsAlert;
+
+                        try desc.toError();
+                        // TODO: handle server-side closures
+                        return error.TlsUnexpectedMessage;
                     },
                     .handshake => {
                         var ct_i: usize = 0;
lib/std/crypto/tls.zig
@@ -138,6 +138,35 @@ pub const AlertLevel = enum(u8) {
 };
 
 pub const AlertDescription = enum(u8) {
+    pub const Error = error{
+        TlsAlertUnexpectedMessage,
+        TlsAlertBadRecordMac,
+        TlsAlertRecordOverflow,
+        TlsAlertHandshakeFailure,
+        TlsAlertBadCertificate,
+        TlsAlertUnsupportedCertificate,
+        TlsAlertCertificateRevoked,
+        TlsAlertCertificateExpired,
+        TlsAlertCertificateUnknown,
+        TlsAlertIllegalParameter,
+        TlsAlertUnknownCa,
+        TlsAlertAccessDenied,
+        TlsAlertDecodeError,
+        TlsAlertDecryptError,
+        TlsAlertProtocolVersion,
+        TlsAlertInsufficientSecurity,
+        TlsAlertInternalError,
+        TlsAlertInappropriateFallback,
+        TlsAlertMissingExtension,
+        TlsAlertUnsupportedExtension,
+        TlsAlertUnrecognizedName,
+        TlsAlertBadCertificateStatusResponse,
+        TlsAlertUnknownPskIdentity,
+        TlsAlertCertificateRequired,
+        TlsAlertNoApplicationProtocol,
+        TlsAlertUnknown,
+    };
+
     close_notify = 0,
     unexpected_message = 10,
     bad_record_mac = 20,
@@ -166,6 +195,39 @@ pub const AlertDescription = enum(u8) {
     certificate_required = 116,
     no_application_protocol = 120,
     _,
+
+    pub fn toError(alert: AlertDescription) Error!void {
+        return switch (alert) {
+            .close_notify => {}, // not an error
+            .unexpected_message => error.TlsAlertUnexpectedMessage,
+            .bad_record_mac => error.TlsAlertBadRecordMac,
+            .record_overflow => error.TlsAlertRecordOverflow,
+            .handshake_failure => error.TlsAlertHandshakeFailure,
+            .bad_certificate => error.TlsAlertBadCertificate,
+            .unsupported_certificate => error.TlsAlertUnsupportedCertificate,
+            .certificate_revoked => error.TlsAlertCertificateRevoked,
+            .certificate_expired => error.TlsAlertCertificateExpired,
+            .certificate_unknown => error.TlsAlertCertificateUnknown,
+            .illegal_parameter => error.TlsAlertIllegalParameter,
+            .unknown_ca => error.TlsAlertUnknownCa,
+            .access_denied => error.TlsAlertAccessDenied,
+            .decode_error => error.TlsAlertDecodeError,
+            .decrypt_error => error.TlsAlertDecryptError,
+            .protocol_version => error.TlsAlertProtocolVersion,
+            .insufficient_security => error.TlsAlertInsufficientSecurity,
+            .internal_error => error.TlsAlertInternalError,
+            .inappropriate_fallback => error.TlsAlertInappropriateFallback,
+            .user_canceled => {}, // not an error
+            .missing_extension => error.TlsAlertMissingExtension,
+            .unsupported_extension => error.TlsAlertUnsupportedExtension,
+            .unrecognized_name => error.TlsAlertUnrecognizedName,
+            .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse,
+            .unknown_psk_identity => error.TlsAlertUnknownPskIdentity,
+            .certificate_required => error.TlsAlertCertificateRequired,
+            .no_application_protocol => error.TlsAlertNoApplicationProtocol,
+            _ => error.TlsAlertUnknown,
+        };
+    }
 };
 
 pub const SignatureScheme = enum(u16) {
lib/std/http/Client.zig
@@ -168,19 +168,23 @@ pub const Connection = struct {
         return switch (conn.protocol) {
             .plain => conn.stream.readAtLeast(buffer, len),
             .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
-        } catch |err| switch (err) {
-            error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
-            error.TlsAlert => return error.TlsAlert,
-            error.ConnectionTimedOut => return error.ConnectionTimedOut,
-            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
-            else => return error.UnexpectedReadFailure,
+        } catch |err| {
+            // TODO: https://github.com/ziglang/zig/issues/2473
+            if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
+
+            switch (err) {
+                error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
+                error.ConnectionTimedOut => return error.ConnectionTimedOut,
+                error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+                else => return error.UnexpectedReadFailure,
+            }
         };
     }
 
     pub fn fill(conn: *Connection) ReadError!void {
         if (conn.read_end != conn.read_start) return;
 
-        const nread = try conn.conn.read(conn.read_buf[0..]);
+        const nread = try conn.read(conn.read_buf[0..]);
         if (nread == 0) return error.EndOfStream;
         conn.read_start = 0;
         conn.read_end = @intCast(u16, nread);
@@ -204,8 +208,8 @@ pub const Connection = struct {
 
             if (available_read > available_buffer) { // partially read buffered data
                 @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..][0..available_buffer]);
-                out_index += available_buffer;
-                conn.read_start += available_buffer;
+                out_index += @intCast(u16, available_buffer);
+                conn.read_start += @intCast(u16, available_buffer);
 
                 break;
             } else if (available_read > 0) { // fully read buffered data
@@ -759,7 +763,7 @@ pub const Request = struct {
                 try req.connection.data.fill();
 
                 const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
-                req.connection.data.clear(@intCast(u16, nchecked));
+                req.connection.data.drop(@intCast(u16, nchecked));
             }
 
             if (has_trail) {
lib/std/http/protocol.zig
@@ -513,8 +513,8 @@ pub const HeadersParser = struct {
     ///
     /// If `skip` is true, the buffer will be unused and the body will be skipped.
     ///
-    /// See `std.http.Client.BufferedConnection for an example of `bconn`.
-    pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize {
+    /// See `std.http.Client.BufferedConnection for an example of `conn`.
+    pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize {
         assert(r.state.isContent());
         if (r.done) return 0;
 
@@ -526,10 +526,10 @@ pub const HeadersParser = struct {
                     const data_avail = r.next_chunk_length;
 
                     if (skip) {
-                        try bconn.fill();
+                        try conn.fill();
 
-                        const nread = @min(bconn.peek().len, data_avail);
-                        bconn.clear(@intCast(u16, nread));
+                        const nread = @min(conn.peek().len, data_avail);
+                        conn.drop(@intCast(u16, nread));
                         r.next_chunk_length -= nread;
 
                         if (r.next_chunk_length == 0) r.done = true;
@@ -539,7 +539,7 @@ pub const HeadersParser = struct {
                         const out_avail = buffer.len;
 
                         const can_read = @intCast(usize, @min(data_avail, out_avail));
-                        const nread = try bconn.read(buffer[0..can_read]);
+                        const nread = try conn.read(buffer[0..can_read]);
                         r.next_chunk_length -= nread;
 
                         if (r.next_chunk_length == 0) r.done = true;
@@ -548,15 +548,15 @@ pub const HeadersParser = struct {
                     }
                 },
                 .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
-                    try bconn.fill();
+                    try conn.fill();
 
-                    const i = r.findChunkedLen(bconn.peek());
-                    bconn.clear(@intCast(u16, i));
+                    const i = r.findChunkedLen(conn.peek());
+                    conn.drop(@intCast(u16, i));
 
                     switch (r.state) {
                         .invalid => return error.HttpChunkInvalid,
                         .chunk_data => if (r.next_chunk_length == 0) {
-                            if (std.mem.eql(u8, bconn.peek(), "\r\n")) {
+                            if (std.mem.eql(u8, conn.peek(), "\r\n")) {
                                 r.state = .finished;
                             } else {
                                 // The trailer section is formatted identically to the header section.
@@ -576,14 +576,14 @@ pub const HeadersParser = struct {
                     const out_avail = buffer.len - out_index;
 
                     if (skip) {
-                        try bconn.fill();
+                        try conn.fill();
 
-                        const nread = @min(bconn.peek().len, data_avail);
-                        bconn.clear(@intCast(u16, nread));
+                        const nread = @min(conn.peek().len, data_avail);
+                        conn.drop(@intCast(u16, nread));
                         r.next_chunk_length -= nread;
                     } else {
                         const can_read = @intCast(usize, @min(data_avail, out_avail));
-                        const nread = try bconn.read(buffer[out_index..][0..can_read]);
+                        const nread = try conn.read(buffer[out_index..][0..can_read]);
                         r.next_chunk_length -= nread;
                         out_index += nread;
                     }
@@ -628,74 +628,74 @@ const MockBufferedConnection = struct {
     start: u16 = 0,
     end: u16 = 0,
 
-    pub fn fill(bconn: *MockBufferedConnection) ReadError!void {
-        if (bconn.end != bconn.start) return;
+    pub fn fill(conn: *MockBufferedConnection) ReadError!void {
+        if (conn.end != conn.start) return;
 
-        const nread = try bconn.conn.read(bconn.buf[0..]);
+        const nread = try conn.conn.read(conn.buf[0..]);
         if (nread == 0) return error.EndOfStream;
-        bconn.start = 0;
-        bconn.end = @truncate(u16, nread);
+        conn.start = 0;
+        conn.end = @truncate(u16, nread);
     }
 
-    pub fn peek(bconn: *MockBufferedConnection) []const u8 {
-        return bconn.buf[bconn.start..bconn.end];
+    pub fn peek(conn: *MockBufferedConnection) []const u8 {
+        return conn.buf[conn.start..conn.end];
     }
 
     pub fn drop(conn: *MockBufferedConnection, num: u16) void {
         conn.start += num;
     }
 
-    pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
+    pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
         var out_index: u16 = 0;
         while (out_index < len) {
-            const available = bconn.end - bconn.start;
+            const available = conn.end - conn.start;
             const left = buffer.len - out_index;
 
             if (available > 0) {
                 const can_read = @truncate(u16, @min(available, left));
 
-                @memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]);
+                @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]);
                 out_index += can_read;
-                bconn.start += can_read;
+                conn.start += can_read;
 
                 continue;
             }
 
-            if (left > bconn.buf.len) {
+            if (left > conn.buf.len) {
                 // skip the buffer if the output is large enough
-                return bconn.conn.read(buffer[out_index..]);
+                return conn.conn.read(buffer[out_index..]);
             }
 
-            try bconn.fill();
+            try conn.fill();
         }
 
         return out_index;
     }
 
-    pub fn read(bconn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
-        return bconn.readAtLeast(buffer, 1);
+    pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
+        return conn.readAtLeast(buffer, 1);
     }
 
     pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
     pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read);
 
-    pub fn reader(bconn: *MockBufferedConnection) Reader {
-        return Reader{ .context = bconn };
+    pub fn reader(conn: *MockBufferedConnection) Reader {
+        return Reader{ .context = conn };
     }
 
-    pub fn writeAll(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
-        return bconn.conn.writeAll(buffer);
+    pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
+        return conn.conn.writeAll(buffer);
     }
 
-    pub fn write(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
-        return bconn.conn.write(buffer);
+    pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
+        return conn.conn.write(buffer);
     }
 
     pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
     pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write);
 
-    pub fn writer(bconn: *MockBufferedConnection) Writer {
-        return Writer{ .context = bconn };
+    pub fn writer(conn: *MockBufferedConnection) Writer {
+        return Writer{ .context = conn };
     }
 };
 
@@ -753,12 +753,12 @@ test "HeadersParser.read length" {
     const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";
     var fbs = std.io.fixedBufferStream(data);
 
-    var bconn = MockBufferedConnection{
+    var conn = MockBufferedConnection{
         .conn = fbs,
     };
 
     while (true) { // read headers
-        try bconn.fill();
+        try conn.fill();
 
         const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
         conn.drop(@intCast(u16, nchecked));
@@ -769,7 +769,7 @@ test "HeadersParser.read length" {
     var buf: [8]u8 = undefined;
 
     r.next_chunk_length = 5;
-    const len = try r.read(&bconn, &buf, false);
+    const len = try r.read(&conn, &buf, false);
     try std.testing.expectEqual(@as(usize, 5), len);
     try std.testing.expectEqualStrings("Hello", buf[0..len]);
 
@@ -784,12 +784,12 @@ test "HeadersParser.read chunked" {
     const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";
     var fbs = std.io.fixedBufferStream(data);
 
-    var bconn = MockBufferedConnection{
+    var conn = MockBufferedConnection{
         .conn = fbs,
     };
 
     while (true) { // read headers
-        try bconn.fill();
+        try conn.fill();
 
         const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
         conn.drop(@intCast(u16, nchecked));
@@ -799,7 +799,7 @@ test "HeadersParser.read chunked" {
     var buf: [8]u8 = undefined;
 
     r.state = .chunk_head_size;
-    const len = try r.read(&bconn, &buf, false);
+    const len = try r.read(&conn, &buf, false);
     try std.testing.expectEqual(@as(usize, 5), len);
     try std.testing.expectEqualStrings("Hello", buf[0..len]);
 
@@ -814,12 +814,12 @@ test "HeadersParser.read chunked trailer" {
     const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";
     var fbs = std.io.fixedBufferStream(data);
 
-    var bconn = MockBufferedConnection{
+    var conn = MockBufferedConnection{
         .conn = fbs,
     };
 
     while (true) { // read headers
-        try bconn.fill();
+        try conn.fill();
 
         const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
         conn.drop(@intCast(u16, nchecked));
@@ -829,12 +829,12 @@ test "HeadersParser.read chunked trailer" {
     var buf: [8]u8 = undefined;
 
     r.state = .chunk_head_size;
-    const len = try r.read(&bconn, &buf, false);
+    const len = try r.read(&conn, &buf, false);
     try std.testing.expectEqual(@as(usize, 5), len);
     try std.testing.expectEqualStrings("Hello", buf[0..len]);
 
     while (true) { // read headers
-        try bconn.fill();
+        try conn.fill();
 
         const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
         conn.drop(@intCast(u16, nchecked));
lib/std/http/Server.zig
@@ -118,7 +118,7 @@ pub const BufferedConnection = struct {
         return bconn.read_buf[bconn.read_start..bconn.read_end];
     }
 
-    pub fn clear(bconn: *BufferedConnection, num: u16) void {
+    pub fn drop(bconn: *BufferedConnection, num: u16) void {
         bconn.read_start += num;
     }
 
@@ -545,7 +545,7 @@ pub const Response = struct {
             try res.connection.fill();
 
             const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek());
-            res.connection.clear(@intCast(u16, nchecked));
+            res.connection.drop(@intCast(u16, nchecked));
 
             if (res.request.parser.state.isContent()) break;
         }
@@ -612,7 +612,7 @@ pub const Response = struct {
                 try res.connection.fill();
 
                 const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek());
-                res.connection.clear(@intCast(u16, nchecked));
+                res.connection.drop(@intCast(u16, nchecked));
             }
 
             if (has_trail) {