Commit 7f9a4625fd

Nameless <truemedian@gmail.com>
2023-04-08 16:58:48
std.http: reenable protocol read tests, add missing branch in findHeaders end
1 parent ef6d58e
Changed files (1)
lib
std
lib/std/http/protocol.zig
@@ -82,7 +82,7 @@ pub const HeadersParser = struct {
     /// If the amount returned is less than `bytes.len`, you may assume that the parser is in a content state and the
     /// first byte of content is located at `bytes[result]`.
     pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 {
-        const vector_len: comptime_int = comptime std.simd.suggestVectorSize(u8) orelse 8;
+        const vector_len: comptime_int = comptime std.math.max(std.simd.suggestVectorSize(u8) orelse 1, 8);
         const len = @intCast(u32, bytes.len);
         var index: u32 = 0;
 
@@ -232,7 +232,7 @@ pub const HeadersParser = struct {
                                     else => {},
                                 }
                             },
-                            4...vector_len - 1 => {
+                            4...vector_len => {
                                 inline for (0..vector_len - 3) |i_usize| {
                                     const i = @truncate(u32, i_usize);
 
@@ -311,6 +311,7 @@ pub const HeadersParser = struct {
 
                         switch (b16) {
                             int16("\r\n") => r.state = .seen_rn,
+                            int16("\n\r") => r.state = .seen_rnr,
                             int16("\n\n") => r.state = .finished,
                             else => {},
                         }
@@ -614,6 +615,86 @@ inline fn intShift(comptime T: type, x: anytype) T {
     }
 }
 
+/// A buffered (and peekable) Connection.
+const MockBufferedConnection = struct {
+    pub const buffer_size = 0x2000;
+
+    conn: std.io.FixedBufferStream([]const u8),
+    buf: [buffer_size]u8 = undefined,
+    start: u16 = 0,
+    end: u16 = 0,
+
+    pub fn fill(bconn: *MockBufferedConnection) ReadError!void {
+        if (bconn.end != bconn.start) return;
+
+        const nread = try bconn.conn.read(bconn.buf[0..]);
+        if (nread == 0) return error.EndOfStream;
+        bconn.start = 0;
+        bconn.end = @truncate(u16, nread);
+    }
+
+    pub fn peek(bconn: *MockBufferedConnection) []const u8 {
+        return bconn.buf[bconn.start..bconn.end];
+    }
+
+    pub fn clear(bconn: *MockBufferedConnection, num: u16) void {
+        bconn.start += num;
+    }
+
+    pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
+        var out_index: u16 = 0;
+        while (out_index < len) {
+            const available = bconn.end - bconn.start;
+            const left = buffer.len - out_index;
+
+            if (available > 0) {
+                const can_read = @truncate(u16, @min(available, left));
+
+                std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
+                out_index += can_read;
+                bconn.start += can_read;
+
+                continue;
+            }
+
+            if (left > bconn.buf.len) {
+                // skip the buffer if the output is large enough
+                return bconn.conn.read(buffer[out_index..]);
+            }
+
+            try bconn.fill();
+        }
+
+        return out_index;
+    }
+
+    pub fn read(bconn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
+        return bconn.readAtLeast(buffer, 1);
+    }
+
+    pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
+    pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read);
+
+    pub fn reader(bconn: *MockBufferedConnection) Reader {
+        return Reader{ .context = bconn };
+    }
+
+    pub fn writeAll(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
+        return bconn.conn.writeAll(buffer);
+    }
+
+    pub fn write(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
+        return bconn.conn.write(buffer);
+    }
+
+    pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
+    pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write);
+
+    pub fn writer(bconn: *MockBufferedConnection) Writer {
+        return Writer{ .context = bconn };
+    }
+};
+
 test "HeadersParser.findHeadersEnd" {
     var r: HeadersParser = undefined;
     const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello";
@@ -662,18 +743,29 @@ test "HeadersParser.findChunkedLen" {
 
 test "HeadersParser.read length" {
     // mock BufferedConnection for read
-    if (true) return error.SkipZigTest;
 
     var r = HeadersParser.initDynamic(256);
     defer r.header_bytes.deinit(std.testing.allocator);
     const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";
     var fbs = std.io.fixedBufferStream(data);
 
-    try r.waitForCompleteHead(fbs.reader(), std.testing.allocator);
+    var bconn = MockBufferedConnection{
+        .conn = fbs,
+    };
+
+    while (true) { // read headers
+        try bconn.fill();
+
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+        bconn.clear(@intCast(u16, nchecked));
+
+        if (r.state.isContent()) break;
+    }
+
     var buf: [8]u8 = undefined;
 
     r.next_chunk_length = 5;
-    const len = try r.read(fbs.reader(), &buf, false);
+    const len = try r.read(&bconn, &buf, false);
     try std.testing.expectEqual(@as(usize, 5), len);
     try std.testing.expectEqualStrings("Hello", buf[0..len]);
 
@@ -682,18 +774,28 @@ test "HeadersParser.read length" {
 
 test "HeadersParser.read chunked" {
     // mock BufferedConnection for read
-    if (true) return error.SkipZigTest;
 
     var r = HeadersParser.initDynamic(256);
     defer r.header_bytes.deinit(std.testing.allocator);
     const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";
     var fbs = std.io.fixedBufferStream(data);
 
-    try r.waitForCompleteHead(fbs.reader(), std.testing.allocator);
+    var bconn = MockBufferedConnection{
+        .conn = fbs,
+    };
+
+    while (true) { // read headers
+        try bconn.fill();
+
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+        bconn.clear(@intCast(u16, nchecked));
+
+        if (r.state.isContent()) break;
+    }
     var buf: [8]u8 = undefined;
 
     r.state = .chunk_head_size;
-    const len = try r.read(fbs.reader(), &buf, false);
+    const len = try r.read(&bconn, &buf, false);
     try std.testing.expectEqual(@as(usize, 5), len);
     try std.testing.expectEqualStrings("Hello", buf[0..len]);
 
@@ -702,22 +804,39 @@ test "HeadersParser.read chunked" {
 
 test "HeadersParser.read chunked trailer" {
     // mock BufferedConnection for read
-    if (true) return error.SkipZigTest;
 
     var r = HeadersParser.initDynamic(256);
     defer r.header_bytes.deinit(std.testing.allocator);
     const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";
     var fbs = std.io.fixedBufferStream(data);
 
-    try r.waitForCompleteHead(fbs.reader(), std.testing.allocator);
+    var bconn = MockBufferedConnection{
+        .conn = fbs,
+    };
+
+    while (true) { // read headers
+        try bconn.fill();
+
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+        bconn.clear(@intCast(u16, nchecked));
+
+        if (r.state.isContent()) break;
+    }
     var buf: [8]u8 = undefined;
 
     r.state = .chunk_head_size;
-    const len = try r.read(fbs.reader(), &buf, false);
+    const len = try r.read(&bconn, &buf, false);
     try std.testing.expectEqual(@as(usize, 5), len);
     try std.testing.expectEqualStrings("Hello", buf[0..len]);
 
-    try r.waitForCompleteHead(fbs.reader(), std.testing.allocator);
+    while (true) { // read headers
+        try bconn.fill();
+
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+        bconn.clear(@intCast(u16, nchecked));
+
+        if (r.state.isContent()) break;
+    }
 
     try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.header_bytes.items);
 }