Commit 78192637fb

Andrew Kelley <andrew@ziglang.org>
2024-02-17 02:35:57
std.http: parser fixes
* add API for iterating over custom HTTP headers * remove `trailing` flag from std.http.Client.parse. Instead, simply don't call parse() for trailers. * fix the logic inside that parse() function. it was using wrong std.mem functions, ignoring malformed data, and returned errors on dead branches. * simplify logic inside wait() * fix HeadersParser not dropping the 2 read bytes of \r\n after a chunked transfer * move the trailers test to be a std lib unit test and make it pass
1 parent d574875
Changed files (3)
lib/std/http/Client.zig
@@ -428,12 +428,14 @@ pub const Response = struct {
         CompressionUnsupported,
     };
 
-    pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void {
-        var it = mem.tokenizeAny(u8, bytes, "\r\n");
+    pub fn parse(res: *Response, bytes: []const u8) ParseError!void {
+        var it = mem.splitSequence(u8, bytes, "\r\n");
 
-        const first_line = it.next() orelse return error.HttpHeadersInvalid;
-        if (first_line.len < 12)
+        const first_line = it.next().?;
+        if (first_line.len < 12) {
+            std.debug.print("first line: '{s}'\n", .{first_line});
             return error.HttpHeadersInvalid;
+        }
 
         const version: http.Version = switch (int64(first_line[0..8])) {
             int64("HTTP/1.0") => .@"HTTP/1.0",
@@ -449,17 +451,16 @@ pub const Response = struct {
         res.reason = reason;
 
         while (it.next()) |line| {
-            if (line.len == 0) return error.HttpHeadersInvalid;
+            if (line.len == 0) return;
             switch (line[0]) {
                 ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
                 else => {},
             }
 
-            var line_it = mem.tokenizeAny(u8, line, ": ");
-            const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
+            var line_it = mem.splitSequence(u8, line, ": ");
+            const header_name = line_it.next().?;
             const header_value = line_it.rest();
-
-            if (trailing) continue;
+            if (header_value.len == 0) return error.HttpHeadersInvalid;
 
             if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
                 res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
@@ -538,6 +539,10 @@ pub const Response = struct {
         try expectEqual(@as(u10, 999), parseInt3("999"));
     }
 
+    pub fn iterateHeaders(r: Response) proto.HeaderIterator {
+        return proto.HeaderIterator.init(r.parser.get());
+    }
+
     version: http.Version,
     status: http.Status,
     reason: []const u8,
@@ -868,7 +873,7 @@ pub const Request = struct {
                 if (req.response.parser.state.isContent()) break;
             }
 
-            try req.response.parse(req.response.parser.get(), false);
+            try req.response.parse(req.response.parser.get());
 
             if (req.response.status == .@"continue") {
                 // We're done parsing the continue response; reset to prepare
@@ -903,21 +908,21 @@ pub const Request = struct {
                 return; // The response is empty; no further setup or redirection is necessary.
             }
 
-            if (req.response.transfer_encoding != .none) {
-                switch (req.response.transfer_encoding) {
-                    .none => unreachable,
-                    .chunked => {
-                        req.response.parser.next_chunk_length = 0;
-                        req.response.parser.state = .chunk_head_size;
-                    },
-                }
-            } else if (req.response.content_length) |cl| {
-                req.response.parser.next_chunk_length = cl;
+            switch (req.response.transfer_encoding) {
+                .none => {
+                    if (req.response.content_length) |cl| {
+                        req.response.parser.next_chunk_length = cl;
 
-                if (cl == 0) req.response.parser.done = true;
-            } else {
-                // read until the connection is closed
-                req.response.parser.next_chunk_length = std.math.maxInt(u64);
+                        if (cl == 0) req.response.parser.done = true;
+                    } else {
+                        // read until the connection is closed
+                        req.response.parser.next_chunk_length = std.math.maxInt(u64);
+                    }
+                },
+                .chunked => {
+                    req.response.parser.next_chunk_length = 0;
+                    req.response.parser.state = .chunk_head_size;
+                },
             }
 
             if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) {
@@ -1014,27 +1019,16 @@ pub const Request = struct {
             //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
             else => try req.transferRead(buffer),
         };
+        if (out_index > 0) return out_index;
 
-        if (out_index == 0) {
-            const has_trail = !req.response.parser.state.isContent();
-
-            while (!req.response.parser.state.isContent()) { // read trailing headers
-                try req.connection.?.fill();
+        while (!req.response.parser.state.isContent()) { // read trailing headers
+            try req.connection.?.fill();
 
-                const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek());
-                req.connection.?.drop(@intCast(nchecked));
-            }
-
-            if (has_trail) {
-                // The response headers before the trailers are already
-                // guaranteed to be valid, so they will always be parsed again
-                // and cannot return an error.
-                // This will *only* fail for a malformed trailer.
-                req.response.parse(req.response.parser.get(), true) catch return error.InvalidTrailers;
-            }
+            const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek());
+            req.connection.?.drop(@intCast(nchecked));
         }
 
-        return out_index;
+        return 0;
     }
 
     /// Reads data from the response body. Must be called after `wait`.
lib/std/http/protocol.zig
@@ -570,9 +570,10 @@ pub const HeadersParser = struct {
                         .chunk_data => if (r.next_chunk_length == 0) {
                             if (std.mem.eql(u8, conn.peek(), "\r\n")) {
                                 r.state = .finished;
-                                r.done = true;
+                                conn.drop(2);
                             } else {
-                                // The trailer section is formatted identically to the header section.
+                                // The trailer section is formatted identically
+                                // to the header section.
                                 r.state = .seen_rn;
                             }
                             r.done = true;
@@ -613,6 +614,68 @@ pub const HeadersParser = struct {
     }
 };
 
+pub const HeaderIterator = struct {
+    bytes: []const u8,
+    index: usize,
+    is_trailer: bool,
+
+    pub fn init(bytes: []const u8) HeaderIterator {
+        return .{
+            .bytes = bytes,
+            .index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2,
+            .is_trailer = false,
+        };
+    }
+
+    pub fn next(it: *HeaderIterator) ?std.http.Header {
+        const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?;
+        var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": ");
+        const name = kv_it.next().?;
+        const value = kv_it.rest();
+        if (value.len == 0) {
+            if (it.is_trailer) return null;
+            const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse
+                return null;
+            it.is_trailer = true;
+            it.index = next_end + 2;
+            kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": ");
+            return .{
+                .name = kv_it.next().?,
+                .value = kv_it.rest(),
+            };
+        }
+        it.index = end + 2;
+        return .{
+            .name = name,
+            .value = value,
+        };
+    }
+
+    test next {
+        var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n");
+        try std.testing.expect(!it.is_trailer);
+        {
+            const header = it.next().?;
+            try std.testing.expect(!it.is_trailer);
+            try std.testing.expectEqualStrings("a", header.name);
+            try std.testing.expectEqualStrings("b", header.value);
+        }
+        {
+            const header = it.next().?;
+            try std.testing.expect(!it.is_trailer);
+            try std.testing.expectEqualStrings("c", header.name);
+            try std.testing.expectEqualStrings("d", header.value);
+        }
+        {
+            const header = it.next().?;
+            try std.testing.expect(it.is_trailer);
+            try std.testing.expectEqualStrings("e", header.name);
+            try std.testing.expectEqualStrings("f", header.value);
+        }
+        try std.testing.expectEqual(null, it.next());
+    }
+};
+
 inline fn int16(array: *const [2]u8) u16 {
     return @as(u16, @bitCast(array.*));
 }
lib/std/http/test.zig
@@ -1,7 +1,8 @@
 const std = @import("std");
+const testing = std.testing;
 
 test "trailers" {
-    const gpa = std.testing.allocator;
+    const gpa = testing.allocator;
 
     var http_server = std.http.Server.init(.{
         .reuse_address = true,
@@ -21,28 +22,49 @@ test "trailers" {
     defer gpa.free(location);
     const uri = try std.Uri.parse(location);
 
-    var server_header_buffer: [1024]u8 = undefined;
-    var req = try client.open(.GET, uri, .{
-        .server_header_buffer = &server_header_buffer,
-    });
-    defer req.deinit();
-
-    try req.send(.{});
-    try req.wait();
-
-    const body = try req.reader().readAllAlloc(gpa, 8192);
-    defer gpa.free(body);
-
-    try std.testing.expectEqualStrings("Hello, World!\n", body);
-    if (true) @panic("TODO implement inspecting custom headers in responses");
-    //try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?);
+    {
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, .{
+            .server_header_buffer = &server_header_buffer,
+        });
+        defer req.deinit();
+
+        try req.send(.{});
+        try req.wait();
+
+        const body = try req.reader().readAllAlloc(gpa, 8192);
+        defer gpa.free(body);
+
+        try testing.expectEqualStrings("Hello, World!\n", body);
+
+        var it = req.response.iterateHeaders();
+        {
+            const header = it.next().?;
+            try testing.expect(!it.is_trailer);
+            try testing.expectEqualStrings("connection", header.name);
+            try testing.expectEqualStrings("keep-alive", header.value);
+        }
+        {
+            const header = it.next().?;
+            try testing.expect(!it.is_trailer);
+            try testing.expectEqualStrings("transfer-encoding", header.name);
+            try testing.expectEqualStrings("chunked", header.value);
+        }
+        {
+            const header = it.next().?;
+            try testing.expect(it.is_trailer);
+            try testing.expectEqualStrings("X-Checksum", header.name);
+            try testing.expectEqualStrings("aaaa", header.value);
+        }
+        try testing.expectEqual(null, it.next());
+    }
 
     // connection has been kept alive
-    try std.testing.expect(client.connection_pool.free_len == 1);
+    try testing.expect(client.connection_pool.free_len == 1);
 }
 
 fn serverThread(http_server: *std.http.Server) anyerror!void {
-    const gpa = std.testing.allocator;
+    const gpa = testing.allocator;
 
     var header_buffer: [1024]u8 = undefined;
     var remaining: usize = 1;
@@ -60,17 +82,16 @@ fn serverThread(http_server: *std.http.Server) anyerror!void {
         };
         try serve(&res);
 
-        try std.testing.expectEqual(.reset, res.reset());
+        try testing.expectEqual(.reset, res.reset());
     }
 }
 
 fn serve(res: *std.http.Server.Response) !void {
-    try std.testing.expectEqualStrings(res.request.target, "/trailer");
+    try testing.expectEqualStrings(res.request.target, "/trailer");
     res.transfer_encoding = .chunked;
 
     try res.send();
     try res.writeAll("Hello, ");
     try res.writeAll("World!\n");
-    // try res.finish();
     try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
 }