Commit 618a435ad4

Andrew Kelley <andrew@ziglang.org>
2025-08-07 03:37:48
std.http.Server: add safety for invalidated Head strings
and fix bad unit test API usage that it finds
1 parent 858716a
Changed files (3)
lib/std/http/Client.zig
@@ -444,8 +444,8 @@ pub const Connection = struct {
 
 pub const Response = struct {
     request: *Request,
-    /// Pointers in this struct are invalidated with the next call to
-    /// `receiveHead`.
+    /// Pointers in this struct are invalidated when the response body stream
+    /// is initialized.
     head: Head,
 
     pub const Head = struct {
@@ -671,6 +671,16 @@ pub const Response = struct {
             try expectEqual(@as(u10, 418), parseInt3("418"));
             try expectEqual(@as(u10, 999), parseInt3("999"));
         }
+
+        /// Help the programmer avoid bugs by calling this when the string
+        /// memory of `Head` becomes invalidated.
+        fn invalidateStrings(h: *Head) void {
+            h.bytes = undefined;
+            h.reason = undefined;
+            if (h.location) |*s| s.* = undefined;
+            if (h.content_type) |*s| s.* = undefined;
+            if (h.content_disposition) |*s| s.* = undefined;
+        }
     };
 
     /// If compressed body has been negotiated this will return compressed bytes.
@@ -682,7 +692,8 @@ pub const Response = struct {
     ///
     /// See also:
     /// * `readerDecompressing`
-    pub fn reader(response: *const Response, buffer: []u8) *Reader {
+    pub fn reader(response: *Response, buffer: []u8) *Reader {
+        response.head.invalidateStrings();
         const req = response.request;
         if (!req.method.responseHasBody()) return .ending;
         const head = &response.head;
@@ -703,6 +714,7 @@ pub const Response = struct {
         decompressor: *http.Decompressor,
         decompression_buffer: []u8,
     ) *Reader {
+        response.head.invalidateStrings();
         const head = &response.head;
         return response.request.reader.bodyReaderDecompressing(
             head.transfer_encoding,
lib/std/http/Server.zig
@@ -55,8 +55,8 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
 
 pub const Request = struct {
     server: *Server,
-    /// Pointers in this struct are invalidated with the next call to
-    /// `receiveHead`.
+    /// Pointers in this struct are invalidated when the request body stream is
+    /// initialized.
     head: Head,
     head_buffer: []const u8,
     respond_err: ?RespondError = null,
@@ -224,6 +224,14 @@ pub const Request = struct {
         inline fn int64(array: *const [8]u8) u64 {
             return @bitCast(array.*);
         }
+
+        /// Help the programmer avoid bugs by calling this when the string
+        /// memory of `Head` becomes invalidated.
+        fn invalidateStrings(h: *Head) void {
+            h.target = undefined;
+            if (h.expect) |*s| s.* = undefined;
+            if (h.content_type) |*s| s.* = undefined;
+        }
     };
 
     pub fn iterateHeaders(r: *const Request) http.HeaderIterator {
@@ -578,9 +586,12 @@ pub const Request = struct {
     /// this function.
     ///
     /// Asserts that this function is only called once.
+    ///
+    /// Invalidates the string memory inside `Head`.
     pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader {
         assert(request.server.reader.state == .received_head);
         assert(request.head.expect == null);
+        request.head.invalidateStrings();
         if (!request.head.method.requestHasBody()) return .ending;
         return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length);
     }
lib/std/http/test.zig
@@ -65,23 +65,22 @@ test "trailers" {
         try req.sendBodiless();
         var response = try req.receiveHead(&.{});
 
-        const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
-        defer gpa.free(body);
-
-        try expectEqualStrings("Hello, World!\n", body);
-
         {
             var it = response.head.iterateHeaders();
             const header = it.next().?;
-            try expect(!it.is_trailer);
             try expectEqualStrings("transfer-encoding", header.name);
             try expectEqualStrings("chunked", header.value);
             try expectEqual(null, it.next());
         }
+
+        const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
+        defer gpa.free(body);
+
+        try expectEqualStrings("Hello, World!\n", body);
+
         {
             var it = response.iterateTrailers();
             const header = it.next().?;
-            try expect(it.is_trailer);
             try expectEqualStrings("X-Checksum", header.name);
             try expectEqualStrings("aaaa", header.value);
             try expectEqual(null, it.next());
@@ -208,12 +207,14 @@ test "echo content server" {
             //    request.head.target,
             //});
 
+            try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
+            try expectEqualStrings("text/plain", request.head.content_type.?);
+
+            // head strings expire here
             const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited);
             defer std.testing.allocator.free(body);
 
-            try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
             try expectEqualStrings("Hello, World!\n", body);
-            try expectEqualStrings("text/plain", request.head.content_type.?);
 
             var response = try request.respondStreaming(&.{}, .{
                 .content_length = switch (request.head.transfer_encoding) {
@@ -410,17 +411,19 @@ test "general client/server API coverage" {
 
         fn handleRequest(request: *http.Server.Request, listen_port: u16) !void {
             const log = std.log.scoped(.server);
+            const gpa = std.testing.allocator;
 
             log.info("{f} {t} {s}", .{ request.head.method, request.head.version, request.head.target });
+            const target = try gpa.dupe(u8, request.head.target);
+            defer gpa.free(target);
 
-            const gpa = std.testing.allocator;
             const reader = (try request.readerExpectContinue(&.{}));
             const body = try reader.allocRemaining(gpa, .unlimited);
             defer gpa.free(body);
 
-            if (mem.startsWith(u8, request.head.target, "/get")) {
+            if (mem.startsWith(u8, target, "/get")) {
                 var response = try request.respondStreaming(&.{}, .{
-                    .content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null)
+                    .content_length = if (mem.indexOf(u8, target, "?chunked") == null)
                         14
                     else
                         null,
@@ -435,7 +438,7 @@ test "general client/server API coverage" {
                 try w.writeAll("World!\n");
                 try response.end();
                 // Writing again would cause an assertion failure.
-            } else if (mem.startsWith(u8, request.head.target, "/large")) {
+            } else if (mem.startsWith(u8, target, "/large")) {
                 var response = try request.respondStreaming(&.{}, .{
                     .content_length = 14 * 1024 + 14 * 10,
                 });
@@ -458,7 +461,7 @@ test "general client/server API coverage" {
                 }
 
                 try response.end();
-            } else if (mem.eql(u8, request.head.target, "/redirect/1")) {
+            } else if (mem.eql(u8, target, "/redirect/1")) {
                 var response = try request.respondStreaming(&.{}, .{
                     .respond_options = .{
                         .status = .found,
@@ -472,14 +475,14 @@ test "general client/server API coverage" {
                 try w.writeAll("Hello, ");
                 try w.writeAll("Redirected!\n");
                 try response.end();
-            } else if (mem.eql(u8, request.head.target, "/redirect/2")) {
+            } else if (mem.eql(u8, target, "/redirect/2")) {
                 try request.respond("Hello, Redirected!\n", .{
                     .status = .found,
                     .extra_headers = &.{
                         .{ .name = "location", .value = "/redirect/1" },
                     },
                 });
-            } else if (mem.eql(u8, request.head.target, "/redirect/3")) {
+            } else if (mem.eql(u8, target, "/redirect/3")) {
                 const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{
                     listen_port,
                 });
@@ -491,23 +494,23 @@ test "general client/server API coverage" {
                         .{ .name = "location", .value = location },
                     },
                 });
-            } else if (mem.eql(u8, request.head.target, "/redirect/4")) {
+            } else if (mem.eql(u8, target, "/redirect/4")) {
                 try request.respond("Hello, Redirected!\n", .{
                     .status = .found,
                     .extra_headers = &.{
                         .{ .name = "location", .value = "/redirect/3" },
                     },
                 });
-            } else if (mem.eql(u8, request.head.target, "/redirect/5")) {
+            } else if (mem.eql(u8, target, "/redirect/5")) {
                 try request.respond("Hello, Redirected!\n", .{
                     .status = .found,
                     .extra_headers = &.{
                         .{ .name = "location", .value = "/%2525" },
                     },
                 });
-            } else if (mem.eql(u8, request.head.target, "/%2525")) {
+            } else if (mem.eql(u8, target, "/%2525")) {
                 try request.respond("Encoded redirect successful!\n", .{});
-            } else if (mem.eql(u8, request.head.target, "/redirect/invalid")) {
+            } else if (mem.eql(u8, target, "/redirect/invalid")) {
                 const invalid_port = try getUnusedTcpPort();
                 const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}", .{invalid_port});
                 defer gpa.free(location);
@@ -518,7 +521,7 @@ test "general client/server API coverage" {
                         .{ .name = "location", .value = location },
                     },
                 });
-            } else if (mem.eql(u8, request.head.target, "/empty")) {
+            } else if (mem.eql(u8, target, "/empty")) {
                 try request.respond("", .{
                     .extra_headers = &.{
                         .{ .name = "empty", .value = "" },
@@ -559,11 +562,12 @@ test "general client/server API coverage" {
         try req.sendBodiless();
         var response = try req.receiveHead(&redirect_buffer);
 
+        try expectEqualStrings("text/plain", response.head.content_type.?);
+
         const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
 
         try expectEqualStrings("Hello, World!\n", body);
-        try expectEqualStrings("text/plain", response.head.content_type.?);
     }
 
     // connection has been kept alive
@@ -604,12 +608,13 @@ test "general client/server API coverage" {
         try req.sendBodiless();
         var response = try req.receiveHead(&redirect_buffer);
 
+        try expectEqualStrings("text/plain", response.head.content_type.?);
+        try expectEqual(14, response.head.content_length.?);
+
         const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
 
         try expectEqualStrings("", body);
-        try expectEqualStrings("text/plain", response.head.content_type.?);
-        try expectEqual(14, response.head.content_length.?);
     }
 
     // connection has been kept alive
@@ -628,11 +633,12 @@ test "general client/server API coverage" {
         try req.sendBodiless();
         var response = try req.receiveHead(&redirect_buffer);
 
+        try expectEqualStrings("text/plain", response.head.content_type.?);
+
         const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
 
         try expectEqualStrings("Hello, World!\n", body);
-        try expectEqualStrings("text/plain", response.head.content_type.?);
     }
 
     // connection has been kept alive
@@ -651,12 +657,13 @@ test "general client/server API coverage" {
         try req.sendBodiless();
         var response = try req.receiveHead(&redirect_buffer);
 
+        try expectEqualStrings("text/plain", response.head.content_type.?);
+        try expect(response.head.transfer_encoding == .chunked);
+
         const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
 
         try expectEqualStrings("", body);
-        try expectEqualStrings("text/plain", response.head.content_type.?);
-        try expect(response.head.transfer_encoding == .chunked);
     }
 
     // connection has been kept alive
@@ -677,11 +684,12 @@ test "general client/server API coverage" {
         try req.sendBodiless();
         var response = try req.receiveHead(&redirect_buffer);
 
+        try expectEqualStrings("text/plain", response.head.content_type.?);
+
         const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
 
         try expectEqualStrings("Hello, World!\n", body);
-        try expectEqualStrings("text/plain", response.head.content_type.?);
     }
 
     // connection has been closed
@@ -706,11 +714,6 @@ test "general client/server API coverage" {
 
         try std.testing.expectEqual(.ok, response.head.status);
 
-        const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
-        defer gpa.free(body);
-
-        try expectEqualStrings("", body);
-
         var it = response.head.iterateHeaders();
         {
             const header = it.next().?;
@@ -718,6 +721,12 @@ test "general client/server API coverage" {
             try expectEqualStrings("content-length", header.name);
             try expectEqualStrings("0", header.value);
         }
+
+        const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
+        defer gpa.free(body);
+
+        try expectEqualStrings("", body);
+
         {
             const header = it.next().?;
             try expect(!it.is_trailer);