Commit 4689d93cb2

Nameless <truemedian@gmail.com>
2023-08-27 23:36:24
std.http: allow for arbitrary http methods
1 parent ddef683
Changed files (4)
lib
test
standalone
lib/std/http/Client.zig
@@ -545,7 +545,7 @@ pub const Request = struct {
         var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
         const w = buffered.writer();
 
-        try w.writeAll(@tagName(req.method));
+        try req.method.write(w);
         try w.writeByte(' ');
 
         if (req.method == .CONNECT) {
@@ -627,15 +627,15 @@ pub const Request = struct {
         try buffered.flush();
     }
 
-    pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
+    const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
 
-    pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
+    const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
 
-    pub fn transferReader(req: *Request) TransferReader {
+    fn transferReader(req: *Request) TransferReader {
         return .{ .context = req };
     }
 
-    pub fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
+    fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
         if (req.response.parser.done) return 0;
 
         var index: usize = 0;
lib/std/http/Server.zig
@@ -185,8 +185,10 @@ pub const Request = struct {
             return error.HttpHeadersInvalid;
 
         const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
+        if (method_end > 24) return error.HttpHeadersInvalid;
+
         const method_str = first_line[0..method_end];
-        const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod;
+        const method: http.Method = @enumFromInt(http.Method.parse(method_str));
 
         const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
         if (version_start == method_end) return error.HttpHeadersInvalid;
@@ -467,11 +469,11 @@ pub const Response = struct {
         try buffered.flush();
     }
 
-    pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
+    const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
 
-    pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);
+    const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);
 
-    pub fn transferReader(res: *Response) TransferReader {
+    fn transferReader(res: *Response) TransferReader {
         return .{ .context = res };
     }
 
lib/std/http.zig
@@ -1,3 +1,5 @@
+const std = @import("std.zig");
+
 pub const Client = @import("http/Client.zig");
 pub const Server = @import("http/Server.zig");
 pub const protocol = @import("http/protocol.zig");
@@ -14,16 +16,36 @@ pub const Version = enum {
 /// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
 /// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition
 /// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH
-pub const Method = enum {
-    GET,
-    HEAD,
-    POST,
-    PUT,
-    DELETE,
-    CONNECT,
-    OPTIONS,
-    TRACE,
-    PATCH,
+pub const Method = enum(u64) { // TODO: should be u192 or u256, but neither is supported by the C backend, and therefore cannot pass CI
+    GET = parse("GET"),
+    HEAD = parse("HEAD"),
+    POST = parse("POST"),
+    PUT = parse("PUT"),
+    DELETE = parse("DELETE"),
+    CONNECT = parse("CONNECT"),
+    OPTIONS = parse("OPTIONS"),
+    TRACE = parse("TRACE"),
+    PATCH = parse("PATCH"),
+
+    _,
+
+    /// Converts `s` into a type that may be used as a `Method` field.
+    /// Asserts that `s` is 24 or fewer bytes.
+    pub fn parse(s: []const u8) u64 {
+        var x: u64 = 0;
+        @memcpy(std.mem.asBytes(&x)[0..s.len], s);
+        return x;
+    }
+
+    pub fn write(self: Method, w: anytype) !void {
+        const bytes = std.mem.asBytes(&@intFromEnum(self));
+        const str = std.mem.sliceTo(bytes, 0);
+        try w.writeAll(str);
+    }
+
+    pub fn format(value: Method, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) @TypeOf(writer).Error!void {
+        return try value.write(writer);
+    }
 
     /// Returns true if a request of this method is allowed to have a body
     /// Actual behavior from servers may vary and should still be checked
@@ -31,6 +53,7 @@ pub const Method = enum {
         return switch (self) {
             .POST, .PUT, .PATCH => true,
             .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false,
+            else => true,
         };
     }
 
@@ -40,6 +63,7 @@ pub const Method = enum {
         return switch (self) {
             .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true,
             .HEAD, .PUT, .TRACE => false,
+            else => true,
         };
     }
 
@@ -50,6 +74,7 @@ pub const Method = enum {
         return switch (self) {
             .GET, .HEAD, .OPTIONS, .TRACE => true,
             .POST, .PUT, .DELETE, .CONNECT, .PATCH => false,
+            else => false,
         };
     }
 
@@ -60,6 +85,7 @@ pub const Method = enum {
         return switch (self) {
             .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true,
             .CONNECT, .POST, .PATCH => false,
+            else => false,
         };
     }
 
@@ -70,6 +96,7 @@ pub const Method = enum {
         return switch (self) {
             .GET, .HEAD => true,
             .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false,
+            else => false,
         };
     }
 };
@@ -269,8 +296,6 @@ pub const Connection = enum {
     close,
 };
 
-const std = @import("std.zig");
-
 test {
     _ = Client;
     _ = Method;
test/standalone/http.zig
@@ -20,7 +20,7 @@ var server: Server = undefined;
 fn handleRequest(res: *Server.Response) !void {
     const log = std.log.scoped(.server);
 
-    log.info("{s} {s} {s}", .{ @tagName(res.request.method), @tagName(res.request.version), res.request.target });
+    log.info("{} {s} {s}", .{ res.request.method, @tagName(res.request.version), res.request.target });
 
     if (res.request.headers.contains("expect")) {
         if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) {