Commit 737e7be46c

Andrew Kelley <andrew@ziglang.org>
2024-02-23 01:54:46
std.http: refactor unit tests
avoid a little bit of boilerplate
1 parent abde76a
Changed files (1)
lib
std
lib/std/http/test.zig
@@ -1,28 +1,57 @@
 const builtin = @import("builtin");
 const std = @import("std");
-const testing = std.testing;
 const native_endian = builtin.cpu.arch.endian();
+const expect = std.testing.expect;
+const expectEqual = std.testing.expectEqual;
+const expectEqualStrings = std.testing.expectEqualStrings;
+const expectError = std.testing.expectError;
 
 test "trailers" {
-    if (builtin.single_threaded) return error.SkipZigTest;
-    if (builtin.os.tag == .wasi) return error.SkipZigTest;
+    const test_server = try createTestServer(struct {
+        fn run(net_server: *std.net.Server) anyerror!void {
+            var header_buffer: [1024]u8 = undefined;
+            var remaining: usize = 1;
+            while (remaining != 0) : (remaining -= 1) {
+                const conn = try net_server.accept();
+                defer conn.stream.close();
 
-    const gpa = testing.allocator;
+                var server = std.http.Server.init(conn, &header_buffer);
 
-    const address = try std.net.Address.parseIp("127.0.0.1", 0);
-    var http_server = try address.listen(.{
-        .reuse_address = true,
-    });
+                try expectEqual(.ready, server.state);
+                var request = try server.receiveHead();
+                try serve(&request);
+                try expectEqual(.ready, server.state);
+            }
+        }
 
-    const port = http_server.listen_address.in.getPort();
+        fn serve(request: *std.http.Server.Request) !void {
+            try expectEqualStrings(request.head.target, "/trailer");
 
-    const server_thread = try std.Thread.spawn(.{}, serverThread, .{&http_server});
-    defer server_thread.join();
+            var send_buffer: [1024]u8 = undefined;
+            var response = request.respondStreaming(.{
+                .send_buffer = &send_buffer,
+            });
+            try response.writeAll("Hello, ");
+            try response.flush();
+            try response.writeAll("World!\n");
+            try response.flush();
+            try response.endChunked(.{
+                .trailers = &.{
+                    .{ .name = "X-Checksum", .value = "aaaa" },
+                },
+            });
+        }
+    });
+    defer test_server.destroy();
+
+    const gpa = std.testing.allocator;
 
     var client: std.http.Client = .{ .allocator = gpa };
     defer client.deinit();
 
-    const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/trailer", .{port});
+    const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/trailer", .{
+        test_server.port(),
+    });
     defer gpa.free(location);
     const uri = try std.Uri.parse(location);
 
@@ -39,94 +68,38 @@ test "trailers" {
         const body = try req.reader().readAllAlloc(gpa, 8192);
         defer gpa.free(body);
 
-        try testing.expectEqualStrings("Hello, World!\n", body);
+        try expectEqualStrings("Hello, World!\n", body);
 
         var it = req.response.iterateHeaders();
         {
             const header = it.next().?;
-            try testing.expect(!it.is_trailer);
-            try testing.expectEqualStrings("connection", header.name);
-            try testing.expectEqualStrings("keep-alive", header.value);
+            try expect(!it.is_trailer);
+            try expectEqualStrings("connection", header.name);
+            try expectEqualStrings("keep-alive", header.value);
         }
         {
             const header = it.next().?;
-            try testing.expect(!it.is_trailer);
-            try testing.expectEqualStrings("transfer-encoding", header.name);
-            try testing.expectEqualStrings("chunked", header.value);
+            try expect(!it.is_trailer);
+            try expectEqualStrings("transfer-encoding", header.name);
+            try expectEqualStrings("chunked", header.value);
         }
         {
             const header = it.next().?;
-            try testing.expect(it.is_trailer);
-            try testing.expectEqualStrings("X-Checksum", header.name);
-            try testing.expectEqualStrings("aaaa", header.value);
+            try expect(it.is_trailer);
+            try expectEqualStrings("X-Checksum", header.name);
+            try expectEqualStrings("aaaa", header.value);
         }
-        try testing.expectEqual(null, it.next());
+        try expectEqual(null, it.next());
     }
 
     // connection has been kept alive
-    try testing.expect(client.connection_pool.free_len == 1);
-}
-
-fn serverThread(http_server: *std.net.Server) anyerror!void {
-    var header_buffer: [1024]u8 = undefined;
-    var remaining: usize = 1;
-    while (remaining != 0) : (remaining -= 1) {
-        const conn = try http_server.accept();
-        defer conn.stream.close();
-
-        var server = std.http.Server.init(conn, &header_buffer);
-
-        try testing.expectEqual(.ready, server.state);
-        var request = try server.receiveHead();
-        try serve(&request);
-        try testing.expectEqual(.ready, server.state);
-    }
-}
-
-fn serve(request: *std.http.Server.Request) !void {
-    try testing.expectEqualStrings(request.head.target, "/trailer");
-
-    var send_buffer: [1024]u8 = undefined;
-    var response = request.respondStreaming(.{
-        .send_buffer = &send_buffer,
-    });
-    try response.writeAll("Hello, ");
-    try response.flush();
-    try response.writeAll("World!\n");
-    try response.flush();
-    try response.endChunked(.{
-        .trailers = &.{
-            .{ .name = "X-Checksum", .value = "aaaa" },
-        },
-    });
+    try expect(client.connection_pool.free_len == 1);
 }
 
 test "HTTP server handles a chunked transfer coding request" {
-    // This test requires spawning threads.
-    if (builtin.single_threaded) {
-        return error.SkipZigTest;
-    }
-
-    if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
-        // https://github.com/ziglang/zig/issues/13782
-        return error.SkipZigTest;
-    }
-
-    if (builtin.os.tag == .wasi) return error.SkipZigTest;
-
-    const allocator = std.testing.allocator;
-    const expect = std.testing.expect;
-
-    const max_header_size = 8192;
-
-    const address = try std.net.Address.parseIp("127.0.0.1", 0);
-    var socket_server = try address.listen(.{ .reuse_address = true });
-    defer socket_server.deinit();
-    const server_port = socket_server.listen_address.in.getPort();
-
-    const server_thread = try std.Thread.spawn(.{}, (struct {
-        fn apply(net_server: *std.net.Server) !void {
-            var header_buffer: [max_header_size]u8 = undefined;
+    const test_server = try createTestServer(struct {
+        fn run(net_server: *std.net.Server) !void {
+            var header_buffer: [8192]u8 = undefined;
             const conn = try net_server.accept();
             defer conn.stream.close();
 
@@ -146,7 +119,8 @@ test "HTTP server handles a chunked transfer coding request" {
                 .keep_alive = false,
             });
         }
-    }).apply, .{&socket_server});
+    });
+    defer test_server.destroy();
 
     const request_bytes =
         "POST / HTTP/1.1\r\n" ++
@@ -162,30 +136,48 @@ test "HTTP server handles a chunked transfer coding request" {
         "0\r\n" ++
         "\r\n";
 
-    const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
+    const gpa = std.testing.allocator;
+    const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port());
     defer stream.close();
     try stream.writeAll(request_bytes);
-
-    server_thread.join();
 }
 
 test "echo content server" {
-    if (builtin.single_threaded) return error.SkipZigTest;
-    if (builtin.os.tag == .wasi) return error.SkipZigTest;
+    const test_server = try createTestServer(struct {
+        fn run(net_server: *std.net.Server) anyerror!void {
+            var read_buffer: [1024]u8 = undefined;
 
-    if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
-        // https://github.com/ziglang/zig/issues/13782
-        return error.SkipZigTest;
-    }
+            accept: while (true) {
+                const conn = try net_server.accept();
+                defer conn.stream.close();
 
-    const gpa = std.testing.allocator;
+                var http_server = std.http.Server.init(conn, &read_buffer);
 
-    const address = try std.net.Address.parseIp("127.0.0.1", 0);
-    var socket_server = try address.listen(.{ .reuse_address = true });
-    defer socket_server.deinit();
-    const port = socket_server.listen_address.in.getPort();
+                while (http_server.state == .ready) {
+                    var request = http_server.receiveHead() catch |err| switch (err) {
+                        error.HttpConnectionClosing => continue :accept,
+                        else => |e| return e,
+                    };
+                    if (std.mem.eql(u8, request.head.target, "/end")) {
+                        return request.respond("", .{ .keep_alive = false });
+                    }
+                    if (request.head.expect) |expect_header_value| {
+                        if (std.mem.eql(u8, expect_header_value, "garbage")) {
+                            try expectError(error.HttpExpectationFailed, request.reader());
+                            try request.respond("", .{ .keep_alive = false });
+                            continue;
+                        }
+                    }
+                    handleRequest(&request) catch |err| {
+                        // This message helps the person troubleshooting determine whether
+                        // output comes from the server thread or the client thread.
+                        std.debug.print("handleRequest failed with '{s}'\n", .{@errorName(err)});
+                        return err;
+                    };
+                }
+            }
+        }
 
-    const server_thread = try std.Thread.spawn(.{}, (struct {
         fn handleRequest(request: *std.http.Server.Request) !void {
             //std.debug.print("server received {s} {s} {s}\n", .{
             //    @tagName(request.head.method),
@@ -196,9 +188,9 @@ test "echo content server" {
             const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192);
             defer std.testing.allocator.free(body);
 
-            try testing.expect(std.mem.startsWith(u8, request.head.target, "/echo-content"));
-            try testing.expectEqualStrings("Hello, World!\n", body);
-            try testing.expectEqualStrings("text/plain", request.head.content_type.?);
+            try expect(std.mem.startsWith(u8, request.head.target, "/echo-content"));
+            try expectEqualStrings("Hello, World!\n", body);
+            try expectEqualStrings("text/plain", request.head.content_type.?);
 
             var send_buffer: [100]u8 = undefined;
             var response = request.respondStreaming(.{
@@ -206,7 +198,7 @@ test "echo content server" {
                 .content_length = switch (request.head.transfer_encoding) {
                     .chunked => null,
                     .none => len: {
-                        try testing.expectEqual(14, request.head.content_length.?);
+                        try expectEqual(14, request.head.content_length.?);
                         break :len 14;
                     },
                 },
@@ -219,54 +211,19 @@ test "echo content server" {
             try response.end();
             //std.debug.print("  server finished responding\n", .{});
         }
-
-        fn run(net_server: *std.net.Server) anyerror!void {
-            var read_buffer: [1024]u8 = undefined;
-
-            accept: while (true) {
-                const conn = try net_server.accept();
-                defer conn.stream.close();
-
-                var http_server = std.http.Server.init(conn, &read_buffer);
-
-                while (http_server.state == .ready) {
-                    var request = http_server.receiveHead() catch |err| switch (err) {
-                        error.HttpConnectionClosing => continue :accept,
-                        else => |e| return e,
-                    };
-                    if (std.mem.eql(u8, request.head.target, "/end")) {
-                        return request.respond("", .{ .keep_alive = false });
-                    }
-                    if (request.head.expect) |expect| {
-                        if (std.mem.eql(u8, expect, "garbage")) {
-                            try testing.expectError(error.HttpExpectationFailed, request.reader());
-                            try request.respond("", .{ .keep_alive = false });
-                            continue;
-                        }
-                    }
-                    handleRequest(&request) catch |err| {
-                        // This message helps the person troubleshooting determine whether
-                        // output comes from the server thread or the client thread.
-                        std.debug.print("handleRequest failed with '{s}'\n", .{@errorName(err)});
-                        return err;
-                    };
-                }
-            }
-        }
-    }).run, .{&socket_server});
-
-    defer server_thread.join();
+    });
+    defer test_server.destroy();
 
     {
-        var client: std.http.Client = .{ .allocator = gpa };
+        var client: std.http.Client = .{ .allocator = std.testing.allocator };
         defer client.deinit();
 
-        try echoTests(&client, port);
+        try echoTests(&client, test_server.port());
     }
 }
 
 fn echoTests(client: *std.http.Client, port: u16) !void {
-    const gpa = testing.allocator;
+    const gpa = std.testing.allocator;
     var location_buffer: [100]u8 = undefined;
 
     { // send content-length request
@@ -295,11 +252,11 @@ fn echoTests(client: *std.http.Client, port: u16) !void {
         const body = try req.reader().readAllAlloc(gpa, 8192);
         defer gpa.free(body);
 
-        try testing.expectEqualStrings("Hello, World!\n", body);
+        try expectEqualStrings("Hello, World!\n", body);
     }
 
     // connection has been kept alive
-    try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
+    try expect(client.http_proxy != null or client.connection_pool.free_len == 1);
 
     { // send chunked request
         const uri = try std.Uri.parse(try std.fmt.bufPrint(
@@ -329,11 +286,11 @@ fn echoTests(client: *std.http.Client, port: u16) !void {
         const body = try req.reader().readAllAlloc(gpa, 8192);
         defer gpa.free(body);
 
-        try testing.expectEqualStrings("Hello, World!\n", body);
+        try expectEqualStrings("Hello, World!\n", body);
     }
 
     // connection has been kept alive
-    try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
+    try expect(client.http_proxy != null or client.connection_pool.free_len == 1);
 
     { // Client.fetch()
 
@@ -352,8 +309,8 @@ fn echoTests(client: *std.http.Client, port: u16) !void {
             },
             .response_storage = .{ .dynamic = &body },
         });
-        try testing.expectEqual(.ok, res.status);
-        try testing.expectEqualStrings("Hello, World!\n", body.items);
+        try expectEqual(.ok, res.status);
+        try expectEqualStrings("Hello, World!\n", body.items);
     }
 
     { // expect: 100-continue
@@ -379,12 +336,12 @@ fn echoTests(client: *std.http.Client, port: u16) !void {
         try req.finish();
 
         try req.wait();
-        try testing.expectEqual(.ok, req.response.status);
+        try expectEqual(.ok, req.response.status);
 
         const body = try req.reader().readAllAlloc(gpa, 8192);
         defer gpa.free(body);
 
-        try testing.expectEqualStrings("Hello, World!\n", body);
+        try expectEqualStrings("Hello, World!\n", body);
     }
 
     { // expect: garbage
@@ -406,7 +363,7 @@ fn echoTests(client: *std.http.Client, port: u16) !void {
 
         try req.send(.{});
         try req.wait();
-        try testing.expectEqual(.expectation_failed, req.response.status);
+        try expectEqual(.expectation_failed, req.response.status);
     }
 
     _ = try client.fetch(.{
@@ -415,3 +372,32 @@ fn echoTests(client: *std.http.Client, port: u16) !void {
         },
     });
 }
+
+const TestServer = struct {
+    server_thread: std.Thread,
+    net_server: std.net.Server,
+
+    fn destroy(self: *@This()) void {
+        self.server_thread.join();
+        std.testing.allocator.destroy(self);
+    }
+
+    fn port(self: @This()) u16 {
+        return self.net_server.listen_address.in.getPort();
+    }
+};
+
+fn createTestServer(S: type) !*TestServer {
+    if (builtin.single_threaded) return error.SkipZigTest;
+    if (builtin.os.tag == .wasi) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
+        // https://github.com/ziglang/zig/issues/13782
+        return error.SkipZigTest;
+    }
+
+    const address = try std.net.Address.parseIp("127.0.0.1", 0);
+    const test_server = try std.testing.allocator.create(TestServer);
+    test_server.net_server = try address.listen(.{ .reuse_address = true });
+    test_server.server_thread = try std.Thread.spawn(.{}, S.run, .{&test_server.net_server});
+    return test_server;
+}