Commit 23ccff9cce

Nameless <truemedian@gmail.com>
2023-05-28 09:50:51
std.http.Server: collapse BufferedConnection into Connection
1 parent 0e5e6cb
Changed files (3)
lib
test
standalone
lib/std/http/Client.zig
@@ -184,7 +184,7 @@ pub const Connection = struct {
     pub fn fill(conn: *Connection) ReadError!void {
         if (conn.read_end != conn.read_start) return;
 
-        const nread = try conn.read(conn.read_buf[0..]);
+        const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
         if (nread == 0) return error.EndOfStream;
         conn.read_start = 0;
         conn.read_end = @intCast(u16, nread);
@@ -207,13 +207,13 @@ pub const Connection = struct {
             const available_buffer = buffer.len - out_index;
 
             if (available_read > available_buffer) { // partially read buffered data
-                @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..][0..available_buffer]);
+                @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
                 out_index += @intCast(u16, available_buffer);
                 conn.read_start += @intCast(u16, available_buffer);
 
                 break;
             } else if (available_read > 0) { // fully read buffered data
-                @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..]);
+                @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
                 out_index += available_read;
                 conn.read_start += available_read;
 
@@ -608,6 +608,8 @@ pub const Request = struct {
         try w.print("{}", .{req.headers});
 
         try w.writeAll("\r\n");
+
+        try buffered.flush();
     }
 
     pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
lib/std/http/Server.zig
@@ -16,39 +16,92 @@ socket: net.StreamServer,
 
 /// An interface to either a plain or TLS connection.
 pub const Connection = struct {
+    pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
+    pub const Protocol = enum { plain };
+
     stream: net.Stream,
     protocol: Protocol,
 
     closing: bool = true,
 
-    pub const Protocol = enum { plain };
+    read_buf: [buffer_size]u8 = undefined,
+    read_start: u16 = 0,
+    read_end: u16 = 0,
 
-    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+    pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
         return switch (conn.protocol) {
-            .plain => conn.stream.read(buffer),
-            // .tls => return conn.tls_client.read(conn.stream, buffer),
-        } catch |err| switch (err) {
-            error.ConnectionTimedOut => return error.ConnectionTimedOut,
-            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
-            else => return error.UnexpectedReadFailure,
+            .plain => conn.stream.readAtLeast(buffer, len),
+            // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
+        } catch |err| {
+            switch (err) {
+                error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+                else => return error.UnexpectedReadFailure,
+            }
         };
     }
 
+    pub fn fill(conn: *Connection) ReadError!void {
+        if (conn.read_end != conn.read_start) return;
+
+        const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
+        if (nread == 0) return error.EndOfStream;
+        conn.read_start = 0;
+        conn.read_end = @intCast(u16, nread);
+    }
+
+    pub fn peek(conn: *Connection) []const u8 {
+        return conn.read_buf[conn.read_start..conn.read_end];
+    }
+
+    pub fn drop(conn: *Connection, num: u16) void {
+        conn.read_start += num;
+    }
+
     pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
-        return switch (conn.protocol) {
-            .plain => conn.stream.readAtLeast(buffer, len),
-            // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
-        } catch |err| switch (err) {
-            error.ConnectionTimedOut => return error.ConnectionTimedOut,
-            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
-            else => return error.UnexpectedReadFailure,
-        };
+        assert(len <= buffer.len);
+
+        var out_index: u16 = 0;
+        while (out_index < len) {
+            const available_read = conn.read_end - conn.read_start;
+            const available_buffer = buffer.len - out_index;
+
+            if (available_read > available_buffer) { // partially read buffered data
+                @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
+                out_index += @intCast(u16, available_buffer);
+                conn.read_start += @intCast(u16, available_buffer);
+
+                break;
+            } else if (available_read > 0) { // fully read buffered data
+                @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
+                out_index += available_read;
+                conn.read_start += available_read;
+
+                if (out_index >= len) break;
+            }
+
+            const leftover_buffer = available_buffer - available_read;
+            const leftover_len = len - out_index;
+
+            if (leftover_buffer > conn.read_buf.len) {
+                // skip the buffer if the output is large enough
+                return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
+            }
+
+            try conn.fill();
+        }
+
+        return out_index;
+    }
+
+    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+        return conn.readAtLeast(buffer, 1);
     }
 
     pub const ReadError = error{
         ConnectionTimedOut,
         ConnectionResetByPeer,
         UnexpectedReadFailure,
+        EndOfStream,
     };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
@@ -93,112 +146,6 @@ pub const Connection = struct {
     }
 };
 
-/// A buffered (and peekable) Connection.
-pub const BufferedConnection = struct {
-    pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
-
-    conn: Connection,
-    read_buf: [buffer_size]u8 = undefined,
-    read_start: u16 = 0,
-    read_end: u16 = 0,
-
-    write_buf: [buffer_size]u8 = undefined,
-    write_end: u16 = 0,
-
-    pub fn fill(bconn: *BufferedConnection) ReadError!void {
-        if (bconn.read_end != bconn.read_start) return;
-
-        const nread = try bconn.conn.read(bconn.read_buf[0..]);
-        if (nread == 0) return error.EndOfStream;
-        bconn.read_start = 0;
-        bconn.read_end = @intCast(u16, nread);
-    }
-
-    pub fn peek(bconn: *BufferedConnection) []const u8 {
-        return bconn.read_buf[bconn.read_start..bconn.read_end];
-    }
-
-    pub fn drop(bconn: *BufferedConnection, num: u16) void {
-        bconn.read_start += num;
-    }
-
-    pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
-        var out_index: u16 = 0;
-        while (out_index < len) {
-            const available = bconn.read_end - bconn.read_start;
-            const left = buffer.len - out_index;
-
-            if (available > 0) {
-                const can_read = @intCast(u16, @min(available, left));
-
-                @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]);
-                out_index += can_read;
-                bconn.read_start += can_read;
-
-                continue;
-            }
-
-            if (left > bconn.read_buf.len) {
-                // skip the buffer if the output is large enough
-                return bconn.conn.read(buffer[out_index..]);
-            }
-
-            try bconn.fill();
-        }
-
-        return out_index;
-    }
-
-    pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
-        return bconn.readAtLeast(buffer, 1);
-    }
-
-    pub const ReadError = Connection.ReadError || error{EndOfStream};
-    pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
-
-    pub fn reader(bconn: *BufferedConnection) Reader {
-        return Reader{ .context = bconn };
-    }
-
-    pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
-        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
-            bconn.write_end += @intCast(u16, buffer.len);
-        } else {
-            try bconn.flush();
-            try bconn.conn.writeAll(buffer);
-        }
-    }
-
-    pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
-        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
-            bconn.write_end += @intCast(u16, buffer.len);
-
-            return buffer.len;
-        } else {
-            try bconn.flush();
-            return try bconn.conn.write(buffer);
-        }
-    }
-
-    pub fn flush(bconn: *BufferedConnection) WriteError!void {
-        defer bconn.write_end = 0;
-        return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
-    }
-
-    pub const WriteError = Connection.WriteError;
-    pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
-
-    pub fn writer(bconn: *BufferedConnection) Writer {
-        return Writer{ .context = bconn };
-    }
-
-    pub fn close(bconn: *BufferedConnection) void {
-        bconn.conn.close();
-    }
-};
-
 /// The mode of transport for responses.
 pub const ResponseTransfer = union(enum) {
     content_length: u64,
@@ -351,7 +298,7 @@ pub const Response = struct {
 
     allocator: Allocator,
     address: net.Address,
-    connection: BufferedConnection,
+    connection: Connection,
 
     headers: http.Headers,
     request: Request,
@@ -388,7 +335,7 @@ pub const Response = struct {
 
         if (!res.request.parser.done) {
             // If the response wasn't fully read, then we need to close the connection.
-            res.connection.conn.closing = true;
+            res.connection.closing = true;
             return .closing;
         }
 
@@ -402,9 +349,9 @@ pub const Response = struct {
         const req_connection = res.request.headers.getFirstValue("connection");
         const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
         if (req_keepalive and (res_keepalive or res_connection == null)) {
-            res.connection.conn.closing = false;
+            res.connection.closing = false;
         } else {
-            res.connection.conn.closing = true;
+            res.connection.closing = true;
         }
 
         switch (res.request.compression) {
@@ -434,14 +381,14 @@ pub const Response = struct {
             .parser = res.request.parser,
         };
 
-        if (res.connection.conn.closing) {
+        if (res.connection.closing) {
             return .closing;
         } else {
             return .reset;
         }
     }
 
-    pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
+    pub const DoError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
 
     /// Send the response headers.
     pub fn do(res: *Response) !void {
@@ -450,7 +397,8 @@ pub const Response = struct {
             .first, .start, .responded, .finished => unreachable,
         }
 
-        const w = res.connection.writer();
+        var buffered = std.io.bufferedWriter(res.connection.writer());
+        const w = buffered.writer();
 
         try w.writeAll(@tagName(res.version));
         try w.writeByte(' ');
@@ -508,10 +456,10 @@ pub const Response = struct {
 
         try w.writeAll("\r\n");
 
-        try res.connection.flush();
+        try buffered.flush();
     }
 
-    pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
+    pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
 
     pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);
 
@@ -532,7 +480,7 @@ pub const Response = struct {
         return index;
     }
 
-    pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
+    pub const WaitError = Connection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
 
     /// Wait for the client to send a complete request head.
     pub fn wait(res: *Response) WaitError!void {
@@ -637,7 +585,7 @@ pub const Response = struct {
         return index;
     }
 
-    pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
+    pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
 
     pub const Writer = std.io.Writer(*Response, WriteError, write);
 
@@ -692,8 +640,6 @@ pub const Response = struct {
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
         }
-
-        try res.connection.flush();
     }
 };
 
@@ -742,10 +688,10 @@ pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response {
     return Response{
         .allocator = options.allocator,
         .address = in.address,
-        .connection = .{ .conn = .{
+        .connection = .{
             .stream = in.stream,
             .protocol = .plain,
-        } },
+        },
         .headers = .{ .allocator = options.allocator },
         .request = .{
             .version = undefined,
test/standalone/http.zig
@@ -86,7 +86,6 @@ fn handleRequest(res: *Server.Response) !void {
         try res.writeAll("World!\n");
         // try res.finish();
         try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
-        try res.connection.flush();
     } else if (mem.eql(u8, res.request.target, "/redirect/1")) {
         res.transfer_encoding = .chunked;