Commit a7a933d7ee

Andrew Kelley <andrew@ziglang.org>
2023-01-06 03:27:53
std.http.Client: support transfer-encoding: chunked
closes #14204

In order to add tests for this I need to implement an HTTP server in the standard library (#910), so that's probably the next thing I'll do.
1 parent d711f45
Changed files (2)
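For context before the diff (not part of the commit): with transfer-encoding: chunked, the message body arrives as a series of chunks, each prefixed by its length in hexadecimal and framed by CRLF, and a zero-length chunk marks the end. A minimal sketch of such a body as a Zig string constant; the payload text is made up:

    // Illustrative only: a chunked body whose decoded payload is "Wikipedia ".
    const example_chunked_body =
        "4\r\n" ++ // chunk size in hex
        "Wiki\r\n" ++ // 4 bytes of chunk data, then CRLF
        "6\r\n" ++
        "pedia \r\n" ++
        "0\r\n" ++ // zero-length chunk terminates the body
        "\r\n";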
lib/std/http/Client.zig
@@ -77,12 +77,14 @@ pub const Request = struct {
         /// could be our own array list.
         header_bytes: std.ArrayListUnmanaged(u8),
         max_header_bytes: usize,
+        next_chunk_length: u64,
 
         pub const Headers = struct {
-            location: ?[]const u8 = null,
             status: http.Status,
             version: http.Version,
+            location: ?[]const u8 = null,
             content_length: ?u64 = null,
+            transfer_encoding: ?http.TransferEncoding = null,
 
             pub fn parse(bytes: []const u8) !Response.Headers {
                 var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
@@ -119,6 +121,10 @@ pub const Request = struct {
                     } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                         if (headers.content_length != null) return error.HttpHeadersInvalid;
                         headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
+                    } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+                        if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
+                        headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse
+                            return error.HttpTransferEncodingUnsupported;
                     }
                 }
 
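The header value is matched against the enum tag names with std.meta.stringToEnum, which is an exact, case-sensitive match, so "chunked" is recognized while anything else becomes error.HttpTransferEncodingUnsupported. A self-contained sketch of that mapping; the TransferEncoding enum is mirrored locally only so the snippet compiles on its own:

    const std = @import("std");

    // Local mirror of the enum this commit adds to lib/std/http.zig.
    const TransferEncoding = enum { chunked, compress, deflate, gzip };

    test "transfer-encoding header value mapping (illustrative)" {
        try std.testing.expect(std.meta.stringToEnum(TransferEncoding, "chunked").? == .chunked);
        // Unknown or differently-cased values yield null, which the parser turns into an error.
        try std.testing.expect(std.meta.stringToEnum(TransferEncoding, "br") == null);
        try std.testing.expect(std.meta.stringToEnum(TransferEncoding, "Chunked") == null);
    }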
@@ -164,12 +170,24 @@ pub const Request = struct {
         };
 
         pub const State = enum {
+            /// Begin header parsing states.
             invalid,
-            finished,
             start,
             seen_r,
             seen_rn,
             seen_rnr,
+            finished,
+            /// Begin transfer-encoding: chunked parsing states.
+            chunk_size,
+            chunk_r,
+            chunk_data,
+
+            pub fn zeroMeansEnd(state: State) bool {
+                return switch (state) {
+                    .finished, .chunk_data => true,
+                    else => false,
+                };
+            }
         };
 
         pub fn initDynamic(max: usize) Response {
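The new helper matters because, for a chunked message, a read that produces zero output bytes may only mean the parser consumed framing (a chunk-size line), not that the body is finished. A standalone restatement of that property; the enum is copied here purely so the sketch compiles by itself:

    const std = @import("std");

    // Mirror of Response.State for illustration only.
    const State = enum {
        invalid,
        start,
        seen_r,
        seen_rn,
        seen_rnr,
        finished,
        chunk_size,
        chunk_r,
        chunk_data,

        fn zeroMeansEnd(state: State) bool {
            return switch (state) {
                .finished, .chunk_data => true,
                else => false,
            };
        }
    };

    test "zeroMeansEnd (illustrative)" {
        // A zero-byte result is only end-of-data once headers are done or while chunk payload streams.
        try std.testing.expect(State.finished.zeroMeansEnd());
        try std.testing.expect(State.chunk_data.zeroMeansEnd());
        try std.testing.expect(!State.chunk_size.zeroMeansEnd());
    }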
@@ -179,6 +197,7 @@ pub const Request = struct {
                 .header_bytes = .{},
                 .max_header_bytes = max,
                 .header_bytes_owned = true,
+                .next_chunk_length = undefined,
             };
         }
 
@@ -189,6 +208,7 @@ pub const Request = struct {
                 .header_bytes = .{ .items = buf[0..0], .capacity = buf.len },
                 .max_header_bytes = buf.len,
                 .header_bytes_owned = false,
+                .next_chunk_length = undefined,
             };
         }
 
@@ -362,12 +382,60 @@ pub const Request = struct {
                             continue :state;
                         },
                     },
+                    .chunk_size => unreachable,
+                    .chunk_r => unreachable,
+                    .chunk_data => unreachable,
                 }
 
                 return index;
             }
         }
 
+        pub fn findChunkedLen(r: *Response, bytes: []const u8) usize {
+            var i: usize = 0;
+            if (r.state == .chunk_size) {
+                while (i < bytes.len) : (i += 1) {
+                    const digit = switch (bytes[i]) {
+                        '0'...'9' => |b| b - '0',
+                        'A'...'Z' => |b| b - 'A' + 10,
+                        'a'...'z' => |b| b - 'a' + 10,
+                        '\r' => {
+                            r.state = .chunk_r;
+                            i += 1;
+                            break;
+                        },
+                        else => {
+                            r.state = .invalid;
+                            return i;
+                        },
+                    };
+                    const mul = @mulWithOverflow(r.next_chunk_length, 16);
+                    if (mul[1] != 0) {
+                        r.state = .invalid;
+                        return i;
+                    }
+                    const add = @addWithOverflow(mul[0], digit);
+                    if (add[1] != 0) {
+                        r.state = .invalid;
+                        return i;
+                    }
+                    r.next_chunk_length = add[0];
+                } else {
+                    return i;
+                }
+            }
+            assert(r.state == .chunk_r);
+            if (i == bytes.len) return i;
+
+            if (bytes[i] == '\n') {
+                r.state = .chunk_data;
+                return i + 1;
+            } else {
+                r.state = .invalid;
+                return i;
+            }
+        }
+
         fn parseInt3(nnn: @Vector(3, u8)) u10 {
             const zero: @Vector(3, u8) = .{ '0', '0', '0' };
             const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
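findChunkedLen accumulates the chunk-size line digit by digit in base 16, rejecting overflow via the tuple-returning @mulWithOverflow/@addWithOverflow builtins. A standalone sketch of the same accumulation, using a hypothetical helper restricted to the canonical hex digits and ignoring chunk extensions:

    const std = @import("std");

    // Hypothetical helper, not part of the commit: fold one hex digit into a running chunk length.
    fn accumulateHex(len: u64, c: u8) !u64 {
        const digit: u64 = switch (c) {
            '0'...'9' => c - '0',
            'a'...'f' => c - 'a' + 10,
            'A'...'F' => c - 'A' + 10,
            else => return error.InvalidChunkSize,
        };
        const mul = @mulWithOverflow(len, 16);
        if (mul[1] != 0) return error.InvalidChunkSize;
        const add = @addWithOverflow(mul[0], digit);
        if (add[1] != 0) return error.InvalidChunkSize;
        return add[0];
    }

    test "chunk size accumulation (illustrative)" {
        var len: u64 = 0;
        for ("1a") |c| len = try accumulateHex(len, c);
        try std.testing.expectEqual(@as(u64, 26), len);
    }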
@@ -415,6 +483,7 @@ pub const Request = struct {
     };
 
     pub const Headers = struct {
+        version: http.Version = .@"HTTP/1.1",
         method: http.Method = .GET,
     };
 
@@ -456,9 +525,9 @@ pub const Request = struct {
         assert(len <= buffer.len);
         var index: usize = 0;
         while (index < len) {
-            const headers_finished = req.response.state == .finished;
+            const zero_means_end = req.response.state.zeroMeansEnd();
             const amt = try readAdvanced(req, buffer[index..]);
-            if (amt == 0 and headers_finished) break;
+            if (amt == 0 and zero_means_end) break;
             index += amt;
         }
         return index;
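With zeroMeansEnd, the loop above keeps calling readAdvanced until a zero-byte result genuinely signals end of data, rather than "only framing bytes were consumed". A hypothetical caller sketch; the read method and buffer size are assumptions based on this file's API at the time of the commit, not something the diff itself shows:

    const std = @import("std");

    // Hypothetical helper: drain the response body of a *std.http.Client.Request and count its bytes.
    fn discardBody(req: *std.http.Client.Request) !usize {
        var total: usize = 0;
        var buf: [4096]u8 = undefined;
        while (true) {
            const n = try req.read(&buf);
            if (n == 0) break; // end of message body, whether plain or chunked
            total += n;
        }
        return total;
    }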
@@ -467,47 +536,101 @@ pub const Request = struct {
     /// This one can return 0 without meaning EOF.
     /// TODO change to readvAdvanced
     pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
-        if (req.response.state == .finished) return req.connection.read(buffer);
-
         const amt = try req.connection.read(buffer);
-        const data = buffer[0..amt];
-        const i = req.response.findHeadersEnd(data);
-        if (req.response.state == .invalid) return error.HttpHeadersInvalid;
+        var in = buffer[0..amt];
+        var out_index: usize = 0;
+        while (true) {
+            switch (req.response.state) {
+                .invalid => unreachable,
+                .start, .seen_r, .seen_rn, .seen_rnr => {
+                    const i = req.response.findHeadersEnd(in);
+                    if (req.response.state == .invalid) return error.HttpHeadersInvalid;
+
+                    const headers_data = in[0..i];
+                    if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
+                        return error.HttpHeadersExceededSizeLimit;
+                    }
+                    try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);
+
+                    if (req.response.state == .finished) {
+                        req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
+
+                        if (req.response.headers.status.class() == .redirect) {
+                            if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+                            const location = req.response.headers.location orelse
+                                return error.HttpRedirectMissingLocation;
+                            const new_url = try std.Url.parse(location);
+                            const new_req = try req.client.request(new_url, req.headers, .{
+                                .max_redirects = req.redirects_left - 1,
+                                .header_strategy = if (req.response.header_bytes_owned) .{
+                                    .dynamic = req.response.max_header_bytes,
+                                } else .{
+                                    .static = req.response.header_bytes.unusedCapacitySlice(),
+                                },
+                            });
+                            req.deinit();
+                            req.* = new_req;
+                            assert(out_index == 0);
+                            return readAdvanced(req, buffer);
+                        }
 
-        const headers_data = data[0..i];
-        if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
-            return error.HttpHeadersExceededSizeLimit;
-        }
-        try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);
+                        if (req.response.headers.transfer_encoding) |transfer_encoding| {
+                            switch (transfer_encoding) {
+                                .chunked => {
+                                    req.response.next_chunk_length = 0;
+                                    req.response.state = .chunk_size;
+                                },
+                                .compress => return error.HttpTransferEncodingUnsupported,
+                                .deflate => return error.HttpTransferEncodingUnsupported,
+                                .gzip => return error.HttpTransferEncodingUnsupported,
+                            }
+                        } else if (req.response.headers.content_length) |content_length| {
+                            req.response.next_chunk_length = content_length;
+                        } else {
+                            return error.HttpContentLengthUnknown;
+                        }
 
-        if (req.response.state == .finished) {
-            req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
-        }
+                        in = in[i..];
+                        continue;
+                    }
 
-        if (req.response.headers.status.class() == .redirect) {
-            if (req.redirects_left == 0) return error.TooManyHttpRedirects;
-            const location = req.response.headers.location orelse
-                return error.HttpRedirectMissingLocation;
-            const new_url = try std.Url.parse(location);
-            const new_req = try req.client.request(new_url, req.headers, .{
-                .max_redirects = req.redirects_left - 1,
-                .header_strategy = if (req.response.header_bytes_owned) .{
-                    .dynamic = req.response.max_header_bytes,
-                } else .{
-                    .static = req.response.header_bytes.unusedCapacitySlice(),
+                    assert(out_index == 0);
+                    return 0;
                 },
-            });
-            req.deinit();
-            req.* = new_req;
-            return readAdvanced(req, buffer);
-        }
-
-        const body_data = data[i..];
-        if (body_data.len > 0) {
-            mem.copy(u8, buffer, body_data);
-            return body_data.len;
+                .finished => {
+                    mem.copy(u8, buffer[out_index..], in);
+                    return out_index + in.len;
+                },
+                .chunk_size, .chunk_r => {
+                    const i = req.response.findChunkedLen(in);
+                    switch (req.response.state) {
+                        .invalid => return error.HttpHeadersInvalid,
+                        .chunk_data => {
+                            if (req.response.next_chunk_length == 0) {
+                                req.response.state = .start;
+                                return out_index;
+                            }
+                            in = in[i..];
+                            continue;
+                        },
+                        .chunk_size => return out_index,
+                        else => unreachable,
+                    }
+                },
+                .chunk_data => {
+                    const sub_amt = @min(req.response.next_chunk_length, in.len);
+                    mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
+                    out_index += sub_amt;
+                    req.response.next_chunk_length -= sub_amt;
+                    if (req.response.next_chunk_length == 0) {
+                        req.response.state = .chunk_size;
+                        in = in[sub_amt..];
+                        continue;
+                    }
+                    return out_index;
+                },
+            }
         }
-        return 0;
     }
 
     test {
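To make the readAdvanced state machine concrete, here is a deliberately simplified, self-contained chunked decoder. It is not the std.http code path: it assumes the whole message is already in memory and ignores chunk extensions and trailers, which the incremental reader above has to cope with across multiple reads:

    const std = @import("std");

    // Hypothetical all-in-memory decoder: strips size lines and CRLF framing, copies chunk data to `out`.
    fn decodeChunked(wire: []const u8, out: []u8) ![]const u8 {
        var out_len: usize = 0;
        var i: usize = 0;
        while (true) {
            const line_end = std.mem.indexOfPos(u8, wire, i, "\r\n") orelse return error.InvalidChunk;
            const chunk_len = try std.fmt.parseInt(usize, wire[i..line_end], 16);
            i = line_end + 2; // skip the size line's CRLF
            if (chunk_len == 0) break; // zero-length chunk terminates the body
            if (i + chunk_len + 2 > wire.len or out_len + chunk_len > out.len) return error.InvalidChunk;
            std.mem.copy(u8, out[out_len..], wire[i..][0..chunk_len]);
            out_len += chunk_len;
            i += chunk_len + 2; // skip the chunk data and its trailing CRLF
        }
        return out[0..out_len];
    }

    test "decodeChunked (illustrative)" {
        var buf: [32]u8 = undefined;
        const body = try decodeChunked("4\r\nWiki\r\n6\r\npedia \r\n0\r\n\r\n", &buf);
        try std.testing.expectEqualStrings("Wikipedia ", body);
    }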
@@ -569,7 +692,9 @@ pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Req
         try h.appendSlice(@tagName(headers.method));
         try h.appendSlice(" ");
         try h.appendSlice(url.path);
-        try h.appendSlice(" HTTP/1.1\r\nHost: ");
+        try h.appendSlice(" ");
+        try h.appendSlice(@tagName(headers.version));
+        try h.appendSlice("\r\nHost: ");
         try h.appendSlice(url.host);
         try h.appendSlice("\r\nConnection: close\r\n\r\n");
 
lib/std/http.zig
@@ -246,6 +246,13 @@ pub const Status = enum(u10) {
     }
 };
 
+pub const TransferEncoding = enum {
+    chunked,
+    compress,
+    deflate,
+    gzip,
+};
+
 const std = @import("std.zig");
 
 test {