Commit 653d4158cd

Andrew Kelley <andrew@ziglang.org>
2024-02-23 10:58:02
std.http.Server: expose arbitrary HTTP headers
Ultimate flexibility, just be sure to destroy the correct amount of information when looking at them.
1 parent 5b34a1b
lib/std/http/Client.zig
@@ -568,8 +568,8 @@ pub const Response = struct {
         try expectEqual(@as(u10, 999), parseInt3("999"));
     }
 
-    pub fn iterateHeaders(r: Response) proto.HeaderIterator {
-        return proto.HeaderIterator.init(r.parser.get());
+    pub fn iterateHeaders(r: Response) http.HeaderIterator {
+        return http.HeaderIterator.init(r.parser.get());
     }
 };
 
lib/std/http/HeaderIterator.zig
@@ -0,0 +1,62 @@
+bytes: []const u8,
+index: usize,
+is_trailer: bool,
+
+pub fn init(bytes: []const u8) HeaderIterator {
+    return .{
+        .bytes = bytes,
+        .index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2,
+        .is_trailer = false,
+    };
+}
+
+pub fn next(it: *HeaderIterator) ?std.http.Header {
+    const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?;
+    var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": ");
+    const name = kv_it.next().?;
+    const value = kv_it.rest();
+    if (value.len == 0) {
+        if (it.is_trailer) return null;
+        const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse
+            return null;
+        it.is_trailer = true;
+        it.index = next_end + 2;
+        kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": ");
+        return .{
+            .name = kv_it.next().?,
+            .value = kv_it.rest(),
+        };
+    }
+    it.index = end + 2;
+    return .{
+        .name = name,
+        .value = value,
+    };
+}
+
+test next {
+    var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n");
+    try std.testing.expect(!it.is_trailer);
+    {
+        const header = it.next().?;
+        try std.testing.expect(!it.is_trailer);
+        try std.testing.expectEqualStrings("a", header.name);
+        try std.testing.expectEqualStrings("b", header.value);
+    }
+    {
+        const header = it.next().?;
+        try std.testing.expect(!it.is_trailer);
+        try std.testing.expectEqualStrings("c", header.name);
+        try std.testing.expectEqualStrings("d", header.value);
+    }
+    {
+        const header = it.next().?;
+        try std.testing.expect(it.is_trailer);
+        try std.testing.expectEqualStrings("e", header.name);
+        try std.testing.expectEqualStrings("f", header.value);
+    }
+    try std.testing.expectEqual(null, it.next());
+}
+
+const HeaderIterator = @This();
+const std = @import("../std.zig");
lib/std/http/protocol.zig
@@ -250,68 +250,6 @@ pub const HeadersParser = struct {
     }
 };
 
-pub const HeaderIterator = struct {
-    bytes: []const u8,
-    index: usize,
-    is_trailer: bool,
-
-    pub fn init(bytes: []const u8) HeaderIterator {
-        return .{
-            .bytes = bytes,
-            .index = std.mem.indexOfPosLinear(u8, bytes, 0, "\r\n").? + 2,
-            .is_trailer = false,
-        };
-    }
-
-    pub fn next(it: *HeaderIterator) ?std.http.Header {
-        const end = std.mem.indexOfPosLinear(u8, it.bytes, it.index, "\r\n").?;
-        var kv_it = std.mem.splitSequence(u8, it.bytes[it.index..end], ": ");
-        const name = kv_it.next().?;
-        const value = kv_it.rest();
-        if (value.len == 0) {
-            if (it.is_trailer) return null;
-            const next_end = std.mem.indexOfPosLinear(u8, it.bytes, end + 2, "\r\n") orelse
-                return null;
-            it.is_trailer = true;
-            it.index = next_end + 2;
-            kv_it = std.mem.splitSequence(u8, it.bytes[end + 2 .. next_end], ": ");
-            return .{
-                .name = kv_it.next().?,
-                .value = kv_it.rest(),
-            };
-        }
-        it.index = end + 2;
-        return .{
-            .name = name,
-            .value = value,
-        };
-    }
-
-    test next {
-        var it = HeaderIterator.init("200 OK\r\na: b\r\nc: d\r\n\r\ne: f\r\n\r\n");
-        try std.testing.expect(!it.is_trailer);
-        {
-            const header = it.next().?;
-            try std.testing.expect(!it.is_trailer);
-            try std.testing.expectEqualStrings("a", header.name);
-            try std.testing.expectEqualStrings("b", header.value);
-        }
-        {
-            const header = it.next().?;
-            try std.testing.expect(!it.is_trailer);
-            try std.testing.expectEqualStrings("c", header.name);
-            try std.testing.expectEqualStrings("d", header.value);
-        }
-        {
-            const header = it.next().?;
-            try std.testing.expect(it.is_trailer);
-            try std.testing.expectEqualStrings("e", header.name);
-            try std.testing.expectEqualStrings("f", header.value);
-        }
-        try std.testing.expectEqual(null, it.next());
-    }
-};
-
 inline fn int16(array: *const [2]u8) u16 {
     return @as(u16, @bitCast(array.*));
 }
lib/std/http/Server.zig
@@ -273,6 +273,10 @@ pub const Request = struct {
         }
     };
 
+    pub fn iterateHeaders(r: *Request) http.HeaderIterator {
+        return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]);
+    }
+
     pub const RespondOptions = struct {
         version: http.Version = .@"HTTP/1.1",
         status: http.Status = .ok,
lib/std/http/test.zig
@@ -290,6 +290,58 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
     try expectEqualStrings(expected_response.items, response);
 }
 
+test "receiving arbitrary http headers from the client" {
+    const test_server = try createTestServer(struct {
+        fn run(net_server: *std.net.Server) anyerror!void {
+            var read_buffer: [666]u8 = undefined;
+            var remaining: usize = 1;
+            while (remaining != 0) : (remaining -= 1) {
+                const conn = try net_server.accept();
+                defer conn.stream.close();
+
+                var server = http.Server.init(conn, &read_buffer);
+                try expectEqual(.ready, server.state);
+                var request = try server.receiveHead();
+                try expectEqualStrings("/bar", request.head.target);
+                var it = request.iterateHeaders();
+                {
+                    const header = it.next().?;
+                    try expectEqualStrings("CoNneCtIoN", header.name);
+                    try expectEqualStrings("close", header.value);
+                    try expect(!it.is_trailer);
+                }
+                {
+                    const header = it.next().?;
+                    try expectEqualStrings("aoeu", header.name);
+                    try expectEqualStrings("asdf", header.value);
+                    try expect(!it.is_trailer);
+                }
+                try request.respond("", .{});
+            }
+        }
+    });
+    defer test_server.destroy();
+
+    const request_bytes = "GET /bar HTTP/1.1\r\n" ++
+        "CoNneCtIoN: close\r\n" ++
+        "aoeu: asdf\r\n" ++
+        "\r\n";
+    const gpa = std.testing.allocator;
+    const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port());
+    defer stream.close();
+    try stream.writeAll(request_bytes);
+
+    const response = try stream.reader().readAllAlloc(gpa, 8192);
+    defer gpa.free(response);
+
+    var expected_response = std.ArrayList(u8).init(gpa);
+    defer expected_response.deinit();
+
+    try expected_response.appendSlice("HTTP/1.1 200 OK\r\n");
+    try expected_response.appendSlice("content-length: 0\r\n\r\n");
+    try expectEqualStrings(expected_response.items, response);
+}
+
 test "general client/server API coverage" {
     if (builtin.os.tag == .windows) {
         // This test was never passing on Windows.
lib/std/http.zig
@@ -3,6 +3,7 @@ pub const Server = @import("http/Server.zig");
 pub const protocol = @import("http/protocol.zig");
 pub const HeadParser = @import("http/HeadParser.zig");
 pub const ChunkParser = @import("http/ChunkParser.zig");
+pub const HeaderIterator = @import("http/HeaderIterator.zig");
 
 pub const Version = enum {
     @"HTTP/1.0",