Commit b4b9f6aa4a

Andrew Kelley <andrew@ziglang.org>
2024-02-21 08:16:03
std.http.Server: reimplement chunked uploading
* Uncouple std.http.ChunkParser from protocol.zig
* Fix receiveHead not passing leftover buffer through the header parser.
* Fix content-length read streaming

This implementation handles the final chunk length correctly rather than "hoping" that the buffer already contains \r\n.
1 parent a8958c9
lib/std/http/ChunkParser.zig
@@ -0,0 +1,131 @@
+//! Parser for transfer-encoding: chunked.
+
+state: State,
+chunk_len: u64,
+
+pub const init: ChunkParser = .{
+    .state = .head_size,
+    .chunk_len = 0,
+};
+
+pub const State = enum {
+    head_size,
+    head_ext,
+    head_r,
+    data,
+    data_suffix,
+    data_suffix_r,
+    invalid,
+};
+
+/// Returns the number of bytes consumed by the chunk size. This is always
+/// less than or equal to `bytes.len`.
+///
+/// After this function returns, `chunk_len` will contain the parsed chunk size
+/// in bytes when `state` is `data`. Alternately, `state` may become `invalid`,
+/// indicating a syntax error in the input stream.
+///
+/// If the amount returned is less than `bytes.len`, the parser is in the
+/// `data` state and the first byte of the chunk is at `bytes[result]`.
+///
+/// Asserts `state` is neither `data` nor `invalid`.
+pub fn feed(p: *ChunkParser, bytes: []const u8) usize {
+    for (bytes, 0..) |c, i| switch (p.state) {
+        .data_suffix => switch (c) {
+            '\r' => p.state = .data_suffix_r,
+            '\n' => p.state = .head_size,
+            else => {
+                p.state = .invalid;
+                return i;
+            },
+        },
+        .data_suffix_r => switch (c) {
+            '\n' => p.state = .head_size,
+            else => {
+                p.state = .invalid;
+                return i;
+            },
+        },
+        .head_size => {
+            const digit = switch (c) {
+                '0'...'9' => |b| b - '0',
+                'A'...'Z' => |b| b - 'A' + 10,
+                'a'...'z' => |b| b - 'a' + 10,
+                '\r' => {
+                    p.state = .head_r;
+                    continue;
+                },
+                '\n' => {
+                    p.state = .data;
+                    return i + 1;
+                },
+                else => {
+                    p.state = .head_ext;
+                    continue;
+                },
+            };
+
+            const new_len = p.chunk_len *% 16 +% digit;
+            if (new_len <= p.chunk_len and p.chunk_len != 0) {
+                p.state = .invalid;
+                return i;
+            }
+
+            p.chunk_len = new_len;
+        },
+        .head_ext => switch (c) {
+            '\r' => p.state = .head_r,
+            '\n' => {
+                p.state = .data;
+                return i + 1;
+            },
+            else => continue,
+        },
+        .head_r => switch (c) {
+            '\n' => {
+                p.state = .data;
+                return i + 1;
+            },
+            else => {
+                p.state = .invalid;
+                return i;
+            },
+        },
+        .data => unreachable,
+        .invalid => unreachable,
+    };
+    return bytes.len;
+}
+
+const ChunkParser = @This();
+const std = @import("std");
+
+test feed {
+    const testing = std.testing;
+
+    const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
+
+    var p = init;
+    const first = p.feed(data[0..]);
+    try testing.expectEqual(@as(u32, 4), first);
+    try testing.expectEqual(@as(u64, 0xff), p.chunk_len);
+    try testing.expectEqual(.data, p.state);
+
+    p = init;
+    const second = p.feed(data[first..]);
+    try testing.expectEqual(@as(u32, 13), second);
+    try testing.expectEqual(@as(u64, 0xf0f000), p.chunk_len);
+    try testing.expectEqual(.data, p.state);
+
+    p = init;
+    const third = p.feed(data[first + second ..]);
+    try testing.expectEqual(@as(u32, 3), third);
+    try testing.expectEqual(@as(u64, 0), p.chunk_len);
+    try testing.expectEqual(.data, p.state);
+
+    p = init;
+    const fourth = p.feed(data[first + second + third ..]);
+    try testing.expectEqual(@as(u32, 16), fourth);
+    try testing.expectEqual(@as(u64, 0xffffffffffffffff), p.chunk_len);
+    try testing.expectEqual(.invalid, p.state);
+}
lib/std/http/HeadParser.zig
@@ -1,3 +1,5 @@
+//! Finds the end of an HTTP head in a stream.
+
 state: State = .start,
 
 pub const State = enum {
@@ -17,13 +19,12 @@ pub const State = enum {
 /// `bytes[result]`.
 pub fn feed(p: *HeadParser, bytes: []const u8) usize {
     const vector_len: comptime_int = @max(std.simd.suggestVectorLength(u8) orelse 1, 8);
-    const len: u32 = @intCast(bytes.len);
-    var index: u32 = 0;
+    var index: usize = 0;
 
     while (true) {
         switch (p.state) {
             .finished => return index,
-            .start => switch (len - index) {
+            .start => switch (bytes.len - index) {
                 0 => return index,
                 1 => {
                     switch (bytes[index]) {
@@ -218,7 +219,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_n => switch (len - index) {
+            .seen_n => switch (bytes.len - index) {
                 0 => return index,
                 else => {
                     switch (bytes[index]) {
@@ -230,7 +231,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_r => switch (len - index) {
+            .seen_r => switch (bytes.len - index) {
                 0 => return index,
                 1 => {
                     switch (bytes[index]) {
@@ -286,7 +287,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_rn => switch (len - index) {
+            .seen_rn => switch (bytes.len - index) {
                 0 => return index,
                 1 => {
                     switch (bytes[index]) {
@@ -317,7 +318,7 @@ pub fn feed(p: *HeadParser, bytes: []const u8) usize {
                     continue;
                 },
             },
-            .seen_rnr => switch (len - index) {
+            .seen_rnr => switch (bytes.len - index) {
                 0 => return index,
                 else => {
                     switch (bytes[index]) {
lib/std/http/protocol.zig
@@ -97,85 +97,32 @@ pub const HeadersParser = struct {
         return @intCast(result);
     }
 
-    /// Returns the number of bytes consumed by the chunk size. This is always
-    /// less than or equal to `bytes.len`.
-    /// You should check `r.state == .chunk_data` after this to check if the
-    /// chunk size has been fully parsed.
-    ///
-    /// If the amount returned is less than `bytes.len`, you may assume that
-    /// the parser is in the `chunk_data` state and that the first byte of the
-    /// chunk is at `bytes[result]`.
     pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 {
-        const len = @as(u32, @intCast(bytes.len));
-
-        for (bytes[0..], 0..) |c, i| {
-            const index = @as(u32, @intCast(i));
-            switch (r.state) {
-                .chunk_data_suffix => switch (c) {
-                    '\r' => r.state = .chunk_data_suffix_r,
-                    '\n' => r.state = .chunk_head_size,
-                    else => {
-                        r.state = .invalid;
-                        return index;
-                    },
-                },
-                .chunk_data_suffix_r => switch (c) {
-                    '\n' => r.state = .chunk_head_size,
-                    else => {
-                        r.state = .invalid;
-                        return index;
-                    },
-                },
-                .chunk_head_size => {
-                    const digit = switch (c) {
-                        '0'...'9' => |b| b - '0',
-                        'A'...'Z' => |b| b - 'A' + 10,
-                        'a'...'z' => |b| b - 'a' + 10,
-                        '\r' => {
-                            r.state = .chunk_head_r;
-                            continue;
-                        },
-                        '\n' => {
-                            r.state = .chunk_data;
-                            return index + 1;
-                        },
-                        else => {
-                            r.state = .chunk_head_ext;
-                            continue;
-                        },
-                    };
-
-                    const new_len = r.next_chunk_length *% 16 +% digit;
-                    if (new_len <= r.next_chunk_length and r.next_chunk_length != 0) {
-                        r.state = .invalid;
-                        return index;
-                    }
-
-                    r.next_chunk_length = new_len;
-                },
-                .chunk_head_ext => switch (c) {
-                    '\r' => r.state = .chunk_head_r,
-                    '\n' => {
-                        r.state = .chunk_data;
-                        return index + 1;
-                    },
-                    else => continue,
-                },
-                .chunk_head_r => switch (c) {
-                    '\n' => {
-                        r.state = .chunk_data;
-                        return index + 1;
-                    },
-                    else => {
-                        r.state = .invalid;
-                        return index;
-                    },
-                },
+        var cp: std.http.ChunkParser = .{
+            .state = switch (r.state) {
+                .chunk_head_size => .head_size,
+                .chunk_head_ext => .head_ext,
+                .chunk_head_r => .head_r,
+                .chunk_data => .data,
+                .chunk_data_suffix => .data_suffix,
+                .chunk_data_suffix_r => .data_suffix_r,
+                .invalid => .invalid,
                 else => unreachable,
-            }
-        }
-
-        return len;
+            },
+            .chunk_len = r.next_chunk_length,
+        };
+        const result = cp.feed(bytes);
+        r.state = switch (cp.state) {
+            .head_size => .chunk_head_size,
+            .head_ext => .chunk_head_ext,
+            .head_r => .chunk_head_r,
+            .data => .chunk_data,
+            .data_suffix => .chunk_data_suffix,
+            .data_suffix_r => .chunk_data_suffix_r,
+            .invalid => .invalid,
+        };
+        r.next_chunk_length = cp.chunk_len;
+        return @intCast(result);
     }
 
     /// Returns whether or not the parser has finished parsing a complete
@@ -464,41 +411,6 @@ const MockBufferedConnection = struct {
     }
 };
 
-test "HeadersParser.findChunkedLen" {
-    var r: HeadersParser = undefined;
-    const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
-
-    r = HeadersParser.init(&.{});
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const first = r.findChunkedLen(data[0..]);
-    try testing.expectEqual(@as(u32, 4), first);
-    try testing.expectEqual(@as(u64, 0xff), r.next_chunk_length);
-    try testing.expectEqual(State.chunk_data, r.state);
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const second = r.findChunkedLen(data[first..]);
-    try testing.expectEqual(@as(u32, 13), second);
-    try testing.expectEqual(@as(u64, 0xf0f000), r.next_chunk_length);
-    try testing.expectEqual(State.chunk_data, r.state);
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const third = r.findChunkedLen(data[first + second ..]);
-    try testing.expectEqual(@as(u32, 3), third);
-    try testing.expectEqual(@as(u64, 0), r.next_chunk_length);
-    try testing.expectEqual(State.chunk_data, r.state);
-    r.state = .chunk_head_size;
-    r.next_chunk_length = 0;
-
-    const fourth = r.findChunkedLen(data[first + second + third ..]);
-    try testing.expectEqual(@as(u32, 16), fourth);
-    try testing.expectEqual(@as(u64, 0xffffffffffffffff), r.next_chunk_length);
-    try testing.expectEqual(State.invalid, r.state);
-}
-
 test "HeadersParser.read length" {
     // mock BufferedConnection for read
     var headers_buf: [256]u8 = undefined;
lib/std/http/Server.zig
@@ -1,4 +1,5 @@
 //! Blocking HTTP server implementation.
+//! Handles a single connection's lifecycle.
 
 connection: net.Server.Connection,
 /// Keeps track of whether the Server is ready to accept a new request on the
@@ -62,20 +63,19 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
     // In case of a reused connection, move the next request's bytes to the
     // beginning of the buffer.
     if (s.next_request_start > 0) {
-        if (s.read_buffer_len > s.next_request_start) {
-            const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
-            const dest = s.read_buffer[0..leftover.len];
-            if (leftover.len <= s.next_request_start) {
-                @memcpy(dest, leftover);
-            } else {
-                mem.copyBackwards(u8, dest, leftover);
-            }
-            s.read_buffer_len = leftover.len;
-        }
+        if (s.read_buffer_len > s.next_request_start) rebase(s, 0);
         s.next_request_start = 0;
     }
 
     var hp: http.HeadParser = .{};
+
+    if (s.read_buffer_len > 0) {
+        const bytes = s.read_buffer[0..s.read_buffer_len];
+        const end = hp.feed(bytes);
+        if (hp.state == .finished)
+            return finishReceivingHead(s, end);
+    }
+
     while (true) {
         const buf = s.read_buffer[s.read_buffer_len..];
         if (buf.len == 0)
@@ -85,16 +85,21 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
         s.read_buffer_len += read_n;
         const bytes = buf[0..read_n];
         const end = hp.feed(bytes);
-        if (hp.state == .finished) return .{
-            .server = s,
-            .head_end = end,
-            .head = Request.Head.parse(s.read_buffer[0..end]) catch
-                return error.HttpHeadersInvalid,
-            .reader_state = undefined,
-        };
+        if (hp.state == .finished)
+            return finishReceivingHead(s, s.read_buffer_len - bytes.len + end);
     }
 }
 
+fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request {
+    return .{
+        .server = s,
+        .head_end = head_end,
+        .head = Request.Head.parse(s.read_buffer[0..head_end]) catch
+            return error.HttpHeadersInvalid,
+        .reader_state = undefined,
+    };
+}
+
 pub const Request = struct {
     server: *Server,
     /// Index into Server's read_buffer.
@@ -102,6 +107,7 @@ pub const Request = struct {
     head: Head,
     reader_state: union {
         remaining_content_length: u64,
+        chunk_parser: http.ChunkParser,
     },
 
     pub const Compression = union(enum) {
@@ -416,51 +422,130 @@ pub const Request = struct {
         };
     }
 
-    pub const ReadError = net.Stream.ReadError;
+    pub const ReadError = net.Stream.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize };
 
     fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
         const request: *Request = @constCast(@alignCast(@ptrCast(context)));
         const s = request.server;
         assert(s.state == .receiving_body);
-
         const remaining_content_length = &request.reader_state.remaining_content_length;
-
         if (remaining_content_length.* == 0) {
             s.state = .ready;
             return 0;
         }
-
-        const available_bytes = s.read_buffer_len - request.head_end;
-        if (available_bytes == 0)
-            s.read_buffer_len += try s.connection.stream.read(s.read_buffer[request.head_end..]);
-
-        const available_buf = s.read_buffer[request.head_end..s.read_buffer_len];
-        const len = @min(remaining_content_length.*, available_buf.len, buffer.len);
-        @memcpy(buffer[0..len], available_buf[0..len]);
+        const available = try fill(s, request.head_end);
+        const len = @min(remaining_content_length.*, available.len, buffer.len);
+        @memcpy(buffer[0..len], available[0..len]);
         remaining_content_length.* -= len;
+        s.next_request_start += len;
         if (remaining_content_length.* == 0)
             s.state = .ready;
         return len;
     }
 
+    fn fill(s: *Server, head_end: usize) ReadError![]u8 {
+        const available = s.read_buffer[s.next_request_start..s.read_buffer_len];
+        if (available.len > 0) return available;
+        s.next_request_start = head_end;
+        s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]);
+        return s.read_buffer[head_end..s.read_buffer_len];
+    }
+
     fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize {
         const request: *Request = @constCast(@alignCast(@ptrCast(context)));
         const s = request.server;
         assert(s.state == .receiving_body);
-        _ = buffer;
-        @panic("TODO");
-    }
 
-    pub const ReadAllError = ReadError || error{HttpBodyOversize};
+        const cp = &request.reader_state.chunk_parser;
+        const head_end = request.head_end;
+
+        // Protect against returning 0 before the end of stream.
+        var out_end: usize = 0;
+        while (out_end == 0) {
+            switch (cp.state) {
+                .invalid => return 0,
+                .data => {
+                    const available = try fill(s, head_end);
+                    const len = @min(cp.chunk_len, available.len, buffer.len);
+                    @memcpy(buffer[0..len], available[0..len]);
+                    cp.chunk_len -= len;
+                    if (cp.chunk_len == 0)
+                        cp.state = .data_suffix;
+                    out_end += len;
+                    s.next_request_start += len;
+                    continue;
+                },
+                else => {
+                    const available = try fill(s, head_end);
+                    const n = cp.feed(available);
+                    switch (cp.state) {
+                        .invalid => return error.HttpChunkInvalid,
+                        .data => {
+                            if (cp.chunk_len == 0) {
+                                // The next bytes in the stream are trailers,
+                                // or \r\n to indicate end of chunked body.
+                                //
+                                // This function must append the trailers at
+                                // head_end so that headers and trailers are
+                                // together.
+                                //
+                                // Since returning 0 would indicate end of
+                                // stream, this function must read all the
+                                // trailers before returning.
+                                if (s.next_request_start > head_end) rebase(s, head_end);
+                                var hp: http.HeadParser = .{};
+                                {
+                                    const bytes = s.read_buffer[head_end..s.read_buffer_len];
+                                    const end = hp.feed(bytes);
+                                    if (hp.state == .finished) {
+                                        s.next_request_start = s.read_buffer_len - bytes.len + end;
+                                        return out_end;
+                                    }
+                                }
+                                while (true) {
+                                    const buf = s.read_buffer[s.read_buffer_len..];
+                                    if (buf.len == 0)
+                                        return error.HttpHeadersOversize;
+                                    const read_n = try s.connection.stream.read(buf);
+                                    s.read_buffer_len += read_n;
+                                    const bytes = buf[0..read_n];
+                                    const end = hp.feed(bytes);
+                                    if (hp.state == .finished) {
+                                        s.next_request_start = s.read_buffer_len - bytes.len + end;
+                                        return out_end;
+                                    }
+                                }
+                            }
+                            const data = available[n..];
+                            const len = @min(cp.chunk_len, data.len, buffer.len);
+                            @memcpy(buffer[0..len], data[0..len]);
+                            cp.chunk_len -= len;
+                            if (cp.chunk_len == 0)
+                                cp.state = .data_suffix;
+                            out_end += len;
+                            s.next_request_start += n + len;
+                            continue;
+                        },
+                        else => continue,
+                    }
+                },
+            }
+        }
+        return out_end;
+    }
 
     pub fn reader(request: *Request) std.io.AnyReader {
         const s = request.server;
         assert(s.state == .received_head);
         s.state = .receiving_body;
+        s.next_request_start = request.head_end;
         switch (request.head.transfer_encoding) {
-            .chunked => return .{
-                .readFn = read_chunked,
-                .context = request,
+            .chunked => {
+                request.reader_state = .{ .chunk_parser = http.ChunkParser.init };
+                return .{
+                    .readFn = read_chunked,
+                    .context = request,
+                };
             },
             .none => {
                 request.reader_state = .{
@@ -489,31 +574,8 @@ pub const Request = struct {
         const s = request.server;
         if (keep_alive and request.head.keep_alive) switch (s.state) {
             .received_head => {
-                s.state = .receiving_body;
-                switch (request.head.transfer_encoding) {
-                    .none => t: {
-                        const len = request.head.content_length orelse break :t;
-                        const head_end = request.head_end;
-                        var total_body_discarded: usize = 0;
-                        while (true) {
-                            const available_bytes = s.read_buffer_len - head_end;
-                            const remaining_len = len - total_body_discarded;
-                            if (available_bytes >= remaining_len) {
-                                s.next_request_start = head_end + remaining_len;
-                                break :t;
-                            }
-                            total_body_discarded += available_bytes;
-                            // Preserve request header memory until receiveHead is called.
-                            const buf = s.read_buffer[head_end..];
-                            const read_n = s.connection.stream.read(buf) catch return false;
-                            s.read_buffer_len = head_end + read_n;
-                        }
-                    },
-                    .chunked => {
-                        @panic("TODO");
-                    },
-                }
-                s.state = .ready;
+                _ = request.reader().discard() catch return false;
+                assert(s.state == .ready);
                 return true;
             },
             .receiving_body, .ready => return true,
@@ -799,6 +861,17 @@ pub const Response = struct {
     }
 };
 
+fn rebase(s: *Server, index: usize) void {
+    const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
+    const dest = s.read_buffer[index..][0..leftover.len];
+    if (leftover.len <= s.next_request_start - index) {
+        @memcpy(dest, leftover);
+    } else {
+        mem.copyBackwards(u8, dest, leftover);
+    }
+    s.read_buffer_len = index + leftover.len;
+}
+
 const std = @import("../std.zig");
 const http = std.http;
 const mem = std.mem;
lib/std/http/test.zig
@@ -164,7 +164,7 @@ test "HTTP server handles a chunked transfer coding request" {
 
     const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
     defer stream.close();
-    _ = try stream.writeAll(request_bytes[0..]);
+    try stream.writeAll(request_bytes);
 
     server_thread.join();
 }
lib/std/http.zig
@@ -4,6 +4,7 @@ pub const Client = @import("http/Client.zig");
 pub const Server = @import("http/Server.zig");
 pub const protocol = @import("http/protocol.zig");
 pub const HeadParser = @import("http/HeadParser.zig");
+pub const ChunkParser = @import("http/ChunkParser.zig");
 
 pub const Version = enum {
     @"HTTP/1.0",
@@ -313,5 +314,6 @@ test {
     _ = Server;
     _ = Status;
     _ = HeadParser;
+    _ = ChunkParser;
     _ = @import("http/test.zig");
 }