Commit 6129ecd4fe
Changed files (10)
lib
std
src
test
standalone
lib/std/http/Server/Connection.zig
@@ -0,0 +1,132 @@
+stream: std.net.Stream,
+protocol: Protocol,
+
+closing: bool,
+
+read_buf: [buffer_size]u8,
+read_start: u16,
+read_end: u16,
+
+pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
+pub const Protocol = enum { plain };
+
+pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+ return switch (conn.protocol) {
+ .plain => conn.stream.readAtLeast(buffer, len),
+ // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
+ } catch |err| {
+ switch (err) {
+ error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+ else => return error.UnexpectedReadFailure,
+ }
+ };
+}
+
+pub fn fill(conn: *Connection) ReadError!void {
+ if (conn.read_end != conn.read_start) return;
+
+ const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
+ if (nread == 0) return error.EndOfStream;
+ conn.read_start = 0;
+ conn.read_end = @intCast(nread);
+}
+
+pub fn peek(conn: *Connection) []const u8 {
+ return conn.read_buf[conn.read_start..conn.read_end];
+}
+
+pub fn drop(conn: *Connection, num: u16) void {
+ conn.read_start += num;
+}
+
+pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+ assert(len <= buffer.len);
+
+ var out_index: u16 = 0;
+ while (out_index < len) {
+ const available_read = conn.read_end - conn.read_start;
+ const available_buffer = buffer.len - out_index;
+
+ if (available_read > available_buffer) { // partially read buffered data
+ @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
+ out_index += @as(u16, @intCast(available_buffer));
+ conn.read_start += @as(u16, @intCast(available_buffer));
+
+ break;
+ } else if (available_read > 0) { // fully read buffered data
+ @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
+ out_index += available_read;
+ conn.read_start += available_read;
+
+ if (out_index >= len) break;
+ }
+
+ const leftover_buffer = available_buffer - available_read;
+ const leftover_len = len - out_index;
+
+ if (leftover_buffer > conn.read_buf.len) {
+ // skip the buffer if the output is large enough
+ return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
+ }
+
+ try conn.fill();
+ }
+
+ return out_index;
+}
+
+pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+ return conn.readAtLeast(buffer, 1);
+}
+
+pub const ReadError = error{
+ ConnectionTimedOut,
+ ConnectionResetByPeer,
+ UnexpectedReadFailure,
+ EndOfStream,
+};
+
+pub const Reader = std.io.Reader(*Connection, ReadError, read);
+
+pub fn reader(conn: *Connection) Reader {
+ return .{ .context = conn };
+}
+
+pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
+ return switch (conn.protocol) {
+ .plain => conn.stream.writeAll(buffer),
+ // .tls => return conn.tls_client.writeAll(conn.stream, buffer),
+ } catch |err| switch (err) {
+ error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+ else => return error.UnexpectedWriteFailure,
+ };
+}
+
+pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
+ return switch (conn.protocol) {
+ .plain => conn.stream.write(buffer),
+ // .tls => return conn.tls_client.write(conn.stream, buffer),
+ } catch |err| switch (err) {
+ error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+ else => return error.UnexpectedWriteFailure,
+ };
+}
+
+pub const WriteError = error{
+ ConnectionResetByPeer,
+ UnexpectedWriteFailure,
+};
+
+pub const Writer = std.io.Writer(*Connection, WriteError, write);
+
+pub fn writer(conn: *Connection) Writer {
+ return .{ .context = conn };
+}
+
+pub fn close(conn: *Connection) void {
+ conn.stream.close();
+}
+
+const Connection = @This();
+const std = @import("../../std.zig");
+const assert = std.debug.assert;
lib/std/http/Server.zig
@@ -1,155 +1,54 @@
-//! HTTP Server implementation.
-//!
-//! This server assumes clients are well behaved and standard compliant; it
-//! deadlocks if a client holds a connection open without sending a request.
+version: http.Version,
+status: http.Status,
+reason: ?[]const u8,
+transfer_encoding: ResponseTransfer,
+keep_alive: bool,
+connection: Connection,
-const builtin = @import("builtin");
-const std = @import("../std.zig");
-const testing = std.testing;
-const http = std.http;
-const mem = std.mem;
-const net = std.net;
-const Uri = std.Uri;
-const Allocator = mem.Allocator;
-const assert = std.debug.assert;
-
-const Server = @This();
-const proto = @import("protocol.zig");
-
-/// The underlying server socket.
-socket: net.StreamServer,
-
-/// An interface to a plain connection.
-pub const Connection = struct {
- stream: net.Stream,
- protocol: Protocol,
+/// Externally-owned; must outlive the Server.
+extra_headers: []const http.Header,
- closing: bool = true,
-
- read_buf: [buffer_size]u8 = undefined,
- read_start: u16 = 0,
- read_end: u16 = 0,
-
- pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
- pub const Protocol = enum { plain };
-
- pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
- return switch (conn.protocol) {
- .plain => conn.stream.readAtLeast(buffer, len),
- // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
- } catch |err| {
- switch (err) {
- error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
- else => return error.UnexpectedReadFailure,
- }
- };
- }
-
- pub fn fill(conn: *Connection) ReadError!void {
- if (conn.read_end != conn.read_start) return;
-
- const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
- if (nread == 0) return error.EndOfStream;
- conn.read_start = 0;
- conn.read_end = @intCast(nread);
- }
-
- pub fn peek(conn: *Connection) []const u8 {
- return conn.read_buf[conn.read_start..conn.read_end];
- }
-
- pub fn drop(conn: *Connection, num: u16) void {
- conn.read_start += num;
- }
-
- pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
- assert(len <= buffer.len);
-
- var out_index: u16 = 0;
- while (out_index < len) {
- const available_read = conn.read_end - conn.read_start;
- const available_buffer = buffer.len - out_index;
-
- if (available_read > available_buffer) { // partially read buffered data
- @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
- out_index += @as(u16, @intCast(available_buffer));
- conn.read_start += @as(u16, @intCast(available_buffer));
-
- break;
- } else if (available_read > 0) { // fully read buffered data
- @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
- out_index += available_read;
- conn.read_start += available_read;
-
- if (out_index >= len) break;
- }
-
- const leftover_buffer = available_buffer - available_read;
- const leftover_len = len - out_index;
-
- if (leftover_buffer > conn.read_buf.len) {
- // skip the buffer if the output is large enough
- return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
- }
-
- try conn.fill();
- }
-
- return out_index;
- }
-
- pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
- return conn.readAtLeast(buffer, 1);
- }
-
- pub const ReadError = error{
- ConnectionTimedOut,
- ConnectionResetByPeer,
- UnexpectedReadFailure,
- EndOfStream,
- };
-
- pub const Reader = std.io.Reader(*Connection, ReadError, read);
-
- pub fn reader(conn: *Connection) Reader {
- return Reader{ .context = conn };
- }
-
- pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
- return switch (conn.protocol) {
- .plain => conn.stream.writeAll(buffer),
- // .tls => return conn.tls_client.writeAll(conn.stream, buffer),
- } catch |err| switch (err) {
- error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
- else => return error.UnexpectedWriteFailure,
- };
- }
+/// The HTTP request that this response is responding to.
+///
+/// This field is only valid after calling `wait`.
+request: Request,
- pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
- return switch (conn.protocol) {
- .plain => conn.stream.write(buffer),
- // .tls => return conn.tls_client.write(conn.stream, buffer),
- } catch |err| switch (err) {
- error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
- else => return error.UnexpectedWriteFailure,
- };
- }
+state: State = .first,
- pub const WriteError = error{
- ConnectionResetByPeer,
- UnexpectedWriteFailure,
+/// Initialize an HTTP server that can respond to multiple requests on the same
+/// connection.
+/// The returned `Server` is ready for `reset` or `wait` to be called.
+pub fn init(connection: std.net.Server.Connection, options: Server.Request.InitOptions) Server {
+ return .{
+ .transfer_encoding = .none,
+ .keep_alive = true,
+ .connection = .{
+ .stream = connection.stream,
+ .protocol = .plain,
+ .closing = true,
+ .read_buf = undefined,
+ .read_start = 0,
+ .read_end = 0,
+ },
+ .request = Server.Request.init(options),
+ .version = .@"HTTP/1.1",
+ .status = .ok,
+ .reason = null,
+ .extra_headers = &.{},
};
+}
- pub const Writer = std.io.Writer(*Connection, WriteError, write);
+pub const State = enum {
+ first,
+ start,
+ waited,
+ responded,
+ finished,
+};
- pub fn writer(conn: *Connection) Writer {
- return Writer{ .context = conn };
- }
+pub const ResetState = enum { reset, closing };
- pub fn close(conn: *Connection) void {
- conn.stream.close();
- }
-};
+pub const Connection = @import("Server/Connection.zig");
/// The mode of transport for responses.
pub const ResponseTransfer = union(enum) {
@@ -160,10 +59,10 @@ pub const ResponseTransfer = union(enum) {
/// The decompressor for request messages.
pub const Compression = union(enum) {
- pub const DeflateDecompressor = std.compress.zlib.Decompressor(Response.TransferReader);
- pub const GzipDecompressor = std.compress.gzip.Decompressor(Response.TransferReader);
+ pub const DeflateDecompressor = std.compress.zlib.Decompressor(Server.TransferReader);
+ pub const GzipDecompressor = std.compress.gzip.Decompressor(Server.TransferReader);
// https://github.com/ziglang/zig/issues/18937
- //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
+ //pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Server.TransferReader, .{});
deflate: DeflateDecompressor,
gzip: GzipDecompressor,
@@ -177,14 +76,37 @@ pub const Request = struct {
method: http.Method,
target: []const u8,
version: http.Version,
- expect: ?[]const u8 = null,
- content_type: ?[]const u8 = null,
- content_length: ?u64 = null,
- transfer_encoding: http.TransferEncoding = .none,
- transfer_compression: http.ContentEncoding = .identity,
- keep_alive: bool = false,
+ expect: ?[]const u8,
+ content_type: ?[]const u8,
+ content_length: ?u64,
+ transfer_encoding: http.TransferEncoding,
+ transfer_compression: http.ContentEncoding,
+ keep_alive: bool,
parser: proto.HeadersParser,
- compression: Compression = .none,
+ compression: Compression,
+
+ pub const InitOptions = struct {
+ /// Externally-owned memory used to store the client's entire HTTP header.
+ /// `error.HttpHeadersOversize` is returned from read() when a
+ /// client sends too many bytes of HTTP headers.
+ client_header_buffer: []u8,
+ };
+
+ pub fn init(options: InitOptions) Request {
+ return .{
+ .method = undefined,
+ .target = undefined,
+ .version = undefined,
+ .expect = null,
+ .content_type = null,
+ .content_length = null,
+ .transfer_encoding = .none,
+ .transfer_compression = .identity,
+ .keep_alive = false,
+ .parser = proto.HeadersParser.init(options.client_header_buffer),
+ .compression = .none,
+ };
+ }
pub const ParseError = Allocator.Error || error{
UnknownHttpMethod,
@@ -300,478 +222,316 @@ pub const Request = struct {
}
};
-/// A HTTP response waiting to be sent.
-///
-/// Order of operations:
-/// ```
-/// [/ <--------------------------------------- \]
-/// accept -> wait -> send [ -> write -> finish][ -> reset /]
-/// \ -> read /
-/// ```
-pub const Response = struct {
- version: http.Version = .@"HTTP/1.1",
- status: http.Status = .ok,
- reason: ?[]const u8 = null,
- transfer_encoding: ResponseTransfer,
- keep_alive: bool,
-
- /// The peer's address
- address: net.Address,
-
- /// The underlying connection for this response.
- connection: Connection,
-
- /// Externally-owned; must outlive the Response.
- extra_headers: []const http.Header = &.{},
-
- /// The HTTP request that this response is responding to.
- ///
- /// This field is only valid after calling `wait`.
- request: Request,
-
- state: State = .first,
-
- pub const State = enum {
- first,
- start,
- waited,
- responded,
- finished,
- };
-
- /// Free all resources associated with this response.
- pub fn deinit(res: *Response) void {
- res.connection.close();
+/// Reset this response to its initial state. This must be called before
+/// handling a second request on the same connection.
+pub fn reset(res: *Server) ResetState {
+ if (res.state == .first) {
+ res.state = .start;
+ return .reset;
}
- pub const ResetState = enum { reset, closing };
-
- /// Reset this response to its initial state. This must be called before
- /// handling a second request on the same connection.
- pub fn reset(res: *Response) ResetState {
- if (res.state == .first) {
- res.state = .start;
- return .reset;
- }
+ if (!res.request.parser.done) {
+ // If the response wasn't fully read, then we need to close the connection.
+ res.connection.closing = true;
+ return .closing;
+ }
- if (!res.request.parser.done) {
- // If the response wasn't fully read, then we need to close the connection.
- res.connection.closing = true;
- return .closing;
- }
+ // A connection is only keep-alive if the Connection header is present
+ // and its value is not "close". The server and client must both agree.
+ //
+ // send() defaults to using keep-alive if the client requests it.
+ res.connection.closing = !res.keep_alive or !res.request.keep_alive;
- // A connection is only keep-alive if the Connection header is present
- // and its value is not "close". The server and client must both agree.
- //
- // send() defaults to using keep-alive if the client requests it.
- res.connection.closing = !res.keep_alive or !res.request.keep_alive;
+ res.state = .start;
+ res.version = .@"HTTP/1.1";
+ res.status = .ok;
+ res.reason = null;
- res.state = .start;
- res.version = .@"HTTP/1.1";
- res.status = .ok;
- res.reason = null;
+ res.transfer_encoding = .none;
- res.transfer_encoding = .none;
+ res.request = Request.init(.{
+ .client_header_buffer = res.request.parser.header_bytes_buffer,
+ });
- res.request.parser.reset();
+ return if (res.connection.closing) .closing else .reset;
+}
- res.request = .{
- .version = undefined,
- .method = undefined,
- .target = undefined,
- .parser = res.request.parser,
- };
+pub const SendError = Connection.WriteError || error{
+ UnsupportedTransferEncoding,
+ InvalidContentLength,
+};
- return if (res.connection.closing) .closing else .reset;
+/// Send the HTTP response headers to the client.
+pub fn send(res: *Server) SendError!void {
+ switch (res.state) {
+ .waited => res.state = .responded,
+ .first, .start, .responded, .finished => unreachable,
}
- pub const SendError = Connection.WriteError || error{
- UnsupportedTransferEncoding,
- InvalidContentLength,
- };
-
- /// Send the HTTP response headers to the client.
- pub fn send(res: *Response) SendError!void {
- switch (res.state) {
- .waited => res.state = .responded,
- .first, .start, .responded, .finished => unreachable,
- }
-
- var buffered = std.io.bufferedWriter(res.connection.writer());
- const w = buffered.writer();
-
- try w.writeAll(@tagName(res.version));
- try w.writeByte(' ');
- try w.print("{d}", .{@intFromEnum(res.status)});
- try w.writeByte(' ');
- if (res.reason) |reason| {
- try w.writeAll(reason);
- } else if (res.status.phrase()) |phrase| {
- try w.writeAll(phrase);
- }
- try w.writeAll("\r\n");
+ var buffered = std.io.bufferedWriter(res.connection.writer());
+ const w = buffered.writer();
+
+ try w.writeAll(@tagName(res.version));
+ try w.writeByte(' ');
+ try w.print("{d}", .{@intFromEnum(res.status)});
+ try w.writeByte(' ');
+ if (res.reason) |reason| {
+ try w.writeAll(reason);
+ } else if (res.status.phrase()) |phrase| {
+ try w.writeAll(phrase);
+ }
+ try w.writeAll("\r\n");
- if (res.status == .@"continue") {
- res.state = .waited; // we still need to send another request after this
+ if (res.status == .@"continue") {
+ res.state = .waited; // we still need to send another request after this
+ } else {
+ if (res.keep_alive and res.request.keep_alive) {
+ try w.writeAll("connection: keep-alive\r\n");
} else {
- if (res.keep_alive and res.request.keep_alive) {
- try w.writeAll("connection: keep-alive\r\n");
- } else {
- try w.writeAll("connection: close\r\n");
- }
-
- switch (res.transfer_encoding) {
- .chunked => try w.writeAll("transfer-encoding: chunked\r\n"),
- .content_length => |content_length| try w.print("content-length: {d}\r\n", .{content_length}),
- .none => {},
- }
-
- for (res.extra_headers) |header| {
- try w.print("{s}: {s}\r\n", .{ header.name, header.value });
- }
+ try w.writeAll("connection: close\r\n");
}
- if (res.request.method == .HEAD) {
- res.transfer_encoding = .none;
+ switch (res.transfer_encoding) {
+ .chunked => try w.writeAll("transfer-encoding: chunked\r\n"),
+ .content_length => |content_length| try w.print("content-length: {d}\r\n", .{content_length}),
+ .none => {},
}
- try w.writeAll("\r\n");
-
- try buffered.flush();
- }
-
- const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
-
- const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);
-
- fn transferReader(res: *Response) TransferReader {
- return .{ .context = res };
- }
-
- fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
- if (res.request.parser.done) return 0;
-
- var index: usize = 0;
- while (index == 0) {
- const amt = try res.request.parser.read(&res.connection, buf[index..], false);
- if (amt == 0 and res.request.parser.done) break;
- index += amt;
+ for (res.extra_headers) |header| {
+ try w.print("{s}: {s}\r\n", .{ header.name, header.value });
}
-
- return index;
}
- pub const WaitError = Connection.ReadError ||
- proto.HeadersParser.CheckCompleteHeadError || Request.ParseError ||
- error{CompressionUnsupported};
-
- /// Wait for the client to send a complete request head.
- ///
- /// For correct behavior, the following rules must be followed:
- ///
- /// * If this returns any error in `Connection.ReadError`, you MUST
- /// immediately close the connection by calling `deinit`.
- /// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close
- /// the connection by calling `deinit`.
- /// * If this returns `error.HttpHeadersOversize`, you MUST
- /// respond with a 431 status code and then call `deinit`.
- /// * If this returns any error in `Request.ParseError`, you MUST respond
- /// with a 400 status code and then call `deinit`.
- /// * If this returns any other error, you MUST respond with a 400 status
- /// code and then call `deinit`.
- /// * If the request has an Expect header containing 100-continue, you MUST either:
- /// * Respond with a 100 status code, then call `wait` again.
- /// * Respond with a 417 status code.
- pub fn wait(res: *Response) WaitError!void {
- switch (res.state) {
- .first, .start => res.state = .waited,
- .waited, .responded, .finished => unreachable,
- }
+ if (res.request.method == .HEAD) {
+ res.transfer_encoding = .none;
+ }
- while (true) {
- try res.connection.fill();
+ try w.writeAll("\r\n");
- const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek());
- res.connection.drop(@intCast(nchecked));
+ try buffered.flush();
+}
- if (res.request.parser.state.isContent()) break;
- }
+const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
- try res.request.parse(res.request.parser.get());
+const TransferReader = std.io.Reader(*Server, TransferReadError, transferRead);
- switch (res.request.transfer_encoding) {
- .none => {
- if (res.request.content_length) |len| {
- res.request.parser.next_chunk_length = len;
+fn transferReader(res: *Server) TransferReader {
+ return .{ .context = res };
+}
- if (len == 0) res.request.parser.done = true;
- } else {
- res.request.parser.done = true;
- }
- },
- .chunked => {
- res.request.parser.next_chunk_length = 0;
- res.request.parser.state = .chunk_head_size;
- },
- }
+fn transferRead(res: *Server, buf: []u8) TransferReadError!usize {
+ if (res.request.parser.done) return 0;
- if (!res.request.parser.done) {
- switch (res.request.transfer_compression) {
- .identity => res.request.compression = .none,
- .compress, .@"x-compress" => return error.CompressionUnsupported,
- .deflate => res.request.compression = .{
- .deflate = std.compress.zlib.decompressor(res.transferReader()),
- },
- .gzip, .@"x-gzip" => res.request.compression = .{
- .gzip = std.compress.gzip.decompressor(res.transferReader()),
- },
- .zstd => {
- // https://github.com/ziglang/zig/issues/18937
- return error.CompressionUnsupported;
- },
- }
- }
+ var index: usize = 0;
+ while (index == 0) {
+ const amt = try res.request.parser.read(&res.connection, buf[index..], false);
+ if (amt == 0 and res.request.parser.done) break;
+ index += amt;
}
- pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers };
+ return index;
+}
- pub const Reader = std.io.Reader(*Response, ReadError, read);
+pub const WaitError = Connection.ReadError ||
+ proto.HeadersParser.CheckCompleteHeadError || Request.ParseError ||
+ error{CompressionUnsupported};
- pub fn reader(res: *Response) Reader {
- return .{ .context = res };
+/// Wait for the client to send a complete request head.
+///
+/// For correct behavior, the following rules must be followed:
+///
+/// * If this returns any error in `Connection.ReadError`, you MUST
+/// immediately close the connection by calling `deinit`.
+/// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close
+/// the connection by calling `deinit`.
+/// * If this returns `error.HttpHeadersOversize`, you MUST
+/// respond with a 431 status code and then call `deinit`.
+/// * If this returns any error in `Request.ParseError`, you MUST respond
+/// with a 400 status code and then call `deinit`.
+/// * If this returns any other error, you MUST respond with a 400 status
+/// code and then call `deinit`.
+/// * If the request has an Expect header containing 100-continue, you MUST either:
+/// * Respond with a 100 status code, then call `wait` again.
+/// * Respond with a 417 status code.
+pub fn wait(res: *Server) WaitError!void {
+ switch (res.state) {
+ .first, .start => res.state = .waited,
+ .waited, .responded, .finished => unreachable,
}
- /// Reads data from the response body. Must be called after `wait`.
- pub fn read(res: *Response, buffer: []u8) ReadError!usize {
- switch (res.state) {
- .waited, .responded, .finished => {},
- .first, .start => unreachable,
- }
+ while (true) {
+ try res.connection.fill();
- const out_index = switch (res.request.compression) {
- .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
- .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
- // https://github.com/ziglang/zig/issues/18937
- //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
- else => try res.transferRead(buffer),
- };
+ const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek());
+ res.connection.drop(@intCast(nchecked));
- if (out_index == 0) {
- const has_trail = !res.request.parser.state.isContent();
+ if (res.request.parser.state.isContent()) break;
+ }
- while (!res.request.parser.state.isContent()) { // read trailing headers
- try res.connection.fill();
+ try res.request.parse(res.request.parser.get());
- const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek());
- res.connection.drop(@intCast(nchecked));
- }
+ switch (res.request.transfer_encoding) {
+ .none => {
+ if (res.request.content_length) |len| {
+ res.request.parser.next_chunk_length = len;
- if (has_trail) {
- // The response headers before the trailers are already
- // guaranteed to be valid, so they will always be parsed again
- // and cannot return an error.
- // This will *only* fail for a malformed trailer.
- res.request.parse(res.request.parser.get()) catch return error.InvalidTrailers;
+ if (len == 0) res.request.parser.done = true;
+ } else {
+ res.request.parser.done = true;
}
- }
-
- return out_index;
+ },
+ .chunked => {
+ res.request.parser.next_chunk_length = 0;
+ res.request.parser.state = .chunk_head_size;
+ },
}
- /// Reads data from the response body. Must be called after `wait`.
- pub fn readAll(res: *Response, buffer: []u8) !usize {
- var index: usize = 0;
- while (index < buffer.len) {
- const amt = try read(res, buffer[index..]);
- if (amt == 0) break;
- index += amt;
+ if (!res.request.parser.done) {
+ switch (res.request.transfer_compression) {
+ .identity => res.request.compression = .none,
+ .compress, .@"x-compress" => return error.CompressionUnsupported,
+ .deflate => res.request.compression = .{
+ .deflate = std.compress.zlib.decompressor(res.transferReader()),
+ },
+ .gzip, .@"x-gzip" => res.request.compression = .{
+ .gzip = std.compress.gzip.decompressor(res.transferReader()),
+ },
+ .zstd => {
+ // https://github.com/ziglang/zig/issues/18937
+ return error.CompressionUnsupported;
+ },
}
- return index;
}
+}
+
+pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers };
- pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
+pub const Reader = std.io.Reader(*Server, ReadError, read);
- pub const Writer = std.io.Writer(*Response, WriteError, write);
+pub fn reader(res: *Server) Reader {
+ return .{ .context = res };
+}
- pub fn writer(res: *Response) Writer {
- return .{ .context = res };
+/// Reads data from the response body. Must be called after `wait`.
+pub fn read(res: *Server, buffer: []u8) ReadError!usize {
+ switch (res.state) {
+ .waited, .responded, .finished => {},
+ .first, .start => unreachable,
}
- /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
- /// Must be called after `send` and before `finish`.
- pub fn write(res: *Response, bytes: []const u8) WriteError!usize {
- switch (res.state) {
- .responded => {},
- .first, .waited, .start, .finished => unreachable,
- }
+ const out_index = switch (res.request.compression) {
+ .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+ .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+ // https://github.com/ziglang/zig/issues/18937
+ //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
+ else => try res.transferRead(buffer),
+ };
- switch (res.transfer_encoding) {
- .chunked => {
- if (bytes.len > 0) {
- try res.connection.writer().print("{x}\r\n", .{bytes.len});
- try res.connection.writeAll(bytes);
- try res.connection.writeAll("\r\n");
- }
+ if (out_index == 0) {
+ const has_trail = !res.request.parser.state.isContent();
- return bytes.len;
- },
- .content_length => |*len| {
- if (len.* < bytes.len) return error.MessageTooLong;
+ while (!res.request.parser.state.isContent()) { // read trailing headers
+ try res.connection.fill();
- const amt = try res.connection.write(bytes);
- len.* -= amt;
- return amt;
- },
- .none => return error.NotWriteable,
+ const nchecked = try res.request.parser.checkCompleteHead(res.connection.peek());
+ res.connection.drop(@intCast(nchecked));
}
- }
- /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
- /// Must be called after `send` and before `finish`.
- pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try write(req, bytes[index..]);
+ if (has_trail) {
+ // The response headers before the trailers are already
+ // guaranteed to be valid, so they will always be parsed again
+ // and cannot return an error.
+ // This will *only* fail for a malformed trailer.
+ res.request.parse(res.request.parser.get()) catch return error.InvalidTrailers;
}
}
- pub const FinishError = Connection.WriteError || error{MessageNotCompleted};
-
- /// Finish the body of a request. This notifies the server that you have no more data to send.
- /// Must be called after `send`.
- pub fn finish(res: *Response) FinishError!void {
- switch (res.state) {
- .responded => res.state = .finished,
- .first, .waited, .start, .finished => unreachable,
- }
+ return out_index;
+}
- switch (res.transfer_encoding) {
- .chunked => try res.connection.writeAll("0\r\n\r\n"),
- .content_length => |len| if (len != 0) return error.MessageNotCompleted,
- .none => {},
- }
+/// Reads data from the response body. Must be called after `wait`.
+pub fn readAll(res: *Server, buffer: []u8) !usize {
+ var index: usize = 0;
+ while (index < buffer.len) {
+ const amt = try read(res, buffer[index..]);
+ if (amt == 0) break;
+ index += amt;
}
-};
-
-/// Create a new HTTP server.
-pub fn init(options: net.StreamServer.Options) Server {
- return .{
- .socket = net.StreamServer.init(options),
- };
+ return index;
}
-/// Free all resources associated with this server.
-pub fn deinit(server: *Server) void {
- server.socket.deinit();
-}
+pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
-pub const ListenError = std.os.SocketError || std.os.BindError || std.os.ListenError || std.os.SetSockOptError || std.os.GetSockNameError;
+pub const Writer = std.io.Writer(*Server, WriteError, write);
-/// Start the HTTP server listening on the given address.
-pub fn listen(server: *Server, address: net.Address) ListenError!void {
- try server.socket.listen(address);
+pub fn writer(res: *Server) Writer {
+ return .{ .context = res };
}
-pub const AcceptError = net.StreamServer.AcceptError;
-
-pub const AcceptOptions = struct {
- /// Externally-owned memory used to store the client's entire HTTP header.
- /// `error.HttpHeadersOversize` is returned from read() when a
- /// client sends too many bytes of HTTP headers.
- client_header_buffer: []u8,
-};
+/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
+/// Must be called after `send` and before `finish`.
+pub fn write(res: *Server, bytes: []const u8) WriteError!usize {
+ switch (res.state) {
+ .responded => {},
+ .first, .waited, .start, .finished => unreachable,
+ }
-pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response {
- const in = try server.socket.accept();
+ switch (res.transfer_encoding) {
+ .chunked => {
+ if (bytes.len > 0) {
+ try res.connection.writer().print("{x}\r\n", .{bytes.len});
+ try res.connection.writeAll(bytes);
+ try res.connection.writeAll("\r\n");
+ }
- return .{
- .transfer_encoding = .none,
- .keep_alive = true,
- .address = in.address,
- .connection = .{
- .stream = in.stream,
- .protocol = .plain,
+ return bytes.len;
},
- .request = .{
- .version = undefined,
- .method = undefined,
- .target = undefined,
- .parser = proto.HeadersParser.init(options.client_header_buffer),
+ .content_length => |*len| {
+ if (len.* < bytes.len) return error.MessageTooLong;
+
+ const amt = try res.connection.write(bytes);
+ len.* -= amt;
+ return amt;
},
- };
+ .none => return error.NotWriteable,
+ }
}
-test "HTTP server handles a chunked transfer coding request" {
- // This test requires spawning threads.
- if (builtin.single_threaded) {
- return error.SkipZigTest;
+/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
+/// Must be called after `send` and before `finish`.
+pub fn writeAll(req: *Server, bytes: []const u8) WriteError!void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try write(req, bytes[index..]);
}
+}
- const native_endian = comptime builtin.cpu.arch.endian();
- if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
- // https://github.com/ziglang/zig/issues/13782
- return error.SkipZigTest;
+pub const FinishError = Connection.WriteError || error{MessageNotCompleted};
+
+/// Finish the body of a request. This notifies the server that you have no more data to send.
+/// Must be called after `send`.
+pub fn finish(res: *Server) FinishError!void {
+ switch (res.state) {
+ .responded => res.state = .finished,
+ .first, .waited, .start, .finished => unreachable,
}
- if (builtin.os.tag == .wasi) return error.SkipZigTest;
-
- const allocator = std.testing.allocator;
- const expect = std.testing.expect;
-
- const max_header_size = 8192;
- var server = std.http.Server.init(.{ .reuse_address = true });
- defer server.deinit();
-
- const address = try std.net.Address.parseIp("127.0.0.1", 0);
- try server.listen(address);
- const server_port = server.socket.listen_address.in.getPort();
-
- const server_thread = try std.Thread.spawn(.{}, (struct {
- fn apply(s: *std.http.Server) !void {
- var header_buffer: [max_header_size]u8 = undefined;
- var res = try s.accept(.{
- .allocator = allocator,
- .client_header_buffer = &header_buffer,
- });
- defer res.deinit();
- defer _ = res.reset();
- try res.wait();
-
- try expect(res.request.transfer_encoding == .chunked);
-
- const server_body: []const u8 = "message from server!\n";
- res.transfer_encoding = .{ .content_length = server_body.len };
- res.extra_headers = &.{
- .{ .name = "content-type", .value = "text/plain" },
- };
- res.keep_alive = false;
- try res.send();
-
- var buf: [128]u8 = undefined;
- const n = try res.readAll(&buf);
- try expect(std.mem.eql(u8, buf[0..n], "ABCD"));
- _ = try res.writer().writeAll(server_body);
- try res.finish();
- }
- }).apply, .{&server});
-
- const request_bytes =
- "POST / HTTP/1.1\r\n" ++
- "Content-Type: text/plain\r\n" ++
- "Transfer-Encoding: chunked\r\n" ++
- "\r\n" ++
- "1\r\n" ++
- "A\r\n" ++
- "1\r\n" ++
- "B\r\n" ++
- "2\r\n" ++
- "CD\r\n" ++
- "0\r\n" ++
- "\r\n";
-
- const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
- defer stream.close();
- _ = try stream.writeAll(request_bytes[0..]);
-
- server_thread.join();
+ switch (res.transfer_encoding) {
+ .chunked => try res.connection.writeAll("0\r\n\r\n"),
+ .content_length => |len| if (len != 0) return error.MessageNotCompleted,
+ .none => {},
+ }
}
+
+const builtin = @import("builtin");
+const std = @import("../std.zig");
+const testing = std.testing;
+const http = std.http;
+const mem = std.mem;
+const net = std.net;
+const Uri = std.Uri;
+const Allocator = mem.Allocator;
+const assert = std.debug.assert;
+
+const Server = @This();
+const proto = @import("protocol.zig");
lib/std/http/test.zig
@@ -8,13 +8,12 @@ test "trailers" {
const gpa = testing.allocator;
- var http_server = std.http.Server.init(.{
+ const address = try std.net.Address.parseIp("127.0.0.1", 0);
+ var http_server = try address.listen(.{
.reuse_address = true,
});
- const address = try std.net.Address.parseIp("127.0.0.1", 0);
- try http_server.listen(address);
- const port = http_server.socket.listen_address.in.getPort();
+ const port = http_server.listen_address.in.getPort();
const server_thread = try std.Thread.spawn(.{}, serverThread, .{&http_server});
defer server_thread.join();
@@ -67,17 +66,14 @@ test "trailers" {
try testing.expect(client.connection_pool.free_len == 1);
}
-fn serverThread(http_server: *std.http.Server) anyerror!void {
- const gpa = testing.allocator;
-
+fn serverThread(http_server: *std.net.Server) anyerror!void {
var header_buffer: [1024]u8 = undefined;
var remaining: usize = 1;
accept: while (remaining != 0) : (remaining -= 1) {
- var res = try http_server.accept(.{
- .allocator = gpa,
- .client_header_buffer = &header_buffer,
- });
- defer res.deinit();
+ const conn = try http_server.accept();
+ defer conn.stream.close();
+
+ var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });
res.wait() catch |err| switch (err) {
error.HttpHeadersInvalid => continue :accept,
@@ -90,7 +86,7 @@ fn serverThread(http_server: *std.http.Server) anyerror!void {
}
}
-fn serve(res: *std.http.Server.Response) !void {
+fn serve(res: *std.http.Server) !void {
try testing.expectEqualStrings(res.request.target, "/trailer");
res.transfer_encoding = .chunked;
@@ -99,3 +95,73 @@ fn serve(res: *std.http.Server.Response) !void {
try res.writeAll("World!\n");
try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
}
+
+test "HTTP server handles a chunked transfer coding request" {
+ // This test requires spawning threads.
+ if (builtin.single_threaded) {
+ return error.SkipZigTest;
+ }
+
+ const native_endian = comptime builtin.cpu.arch.endian();
+ if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
+ // https://github.com/ziglang/zig/issues/13782
+ return error.SkipZigTest;
+ }
+
+ if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
+ const allocator = std.testing.allocator;
+ const expect = std.testing.expect;
+
+ const max_header_size = 8192;
+
+ const address = try std.net.Address.parseIp("127.0.0.1", 0);
+ var server = try address.listen(.{ .reuse_address = true });
+ defer server.deinit();
+ const server_port = server.listen_address.in.getPort();
+
+ const server_thread = try std.Thread.spawn(.{}, (struct {
+ fn apply(s: *std.net.Server) !void {
+ var header_buffer: [max_header_size]u8 = undefined;
+ const conn = try s.accept();
+ defer conn.stream.close();
+ var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });
+ try res.wait();
+
+ try expect(res.request.transfer_encoding == .chunked);
+ const server_body: []const u8 = "message from server!\n";
+ res.transfer_encoding = .{ .content_length = server_body.len };
+ res.extra_headers = &.{
+ .{ .name = "content-type", .value = "text/plain" },
+ };
+ res.keep_alive = false;
+ try res.send();
+
+ var buf: [128]u8 = undefined;
+ const n = try res.readAll(&buf);
+ try expect(std.mem.eql(u8, buf[0..n], "ABCD"));
+ _ = try res.writer().writeAll(server_body);
+ try res.finish();
+ }
+ }).apply, .{&server});
+
+ const request_bytes =
+ "POST / HTTP/1.1\r\n" ++
+ "Content-Type: text/plain\r\n" ++
+ "Transfer-Encoding: chunked\r\n" ++
+ "\r\n" ++
+ "1\r\n" ++
+ "A\r\n" ++
+ "1\r\n" ++
+ "B\r\n" ++
+ "2\r\n" ++
+ "CD\r\n" ++
+ "0\r\n" ++
+ "\r\n";
+
+ const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
+ defer stream.close();
+ _ = try stream.writeAll(request_bytes[0..]);
+
+ server_thread.join();
+}
lib/std/net/test.zig
@@ -181,11 +181,9 @@ test "listen on a port, send bytes, receive bytes" {
// configured.
const localhost = try net.Address.parseIp("127.0.0.1", 0);
- var server = net.StreamServer.init(.{});
+ var server = try localhost.listen(.{});
defer server.deinit();
- try server.listen(localhost);
-
const S = struct {
fn clientFn(server_address: net.Address) !void {
const socket = try net.tcpConnectToAddress(server_address);
@@ -215,17 +213,11 @@ test "listen on an in use port" {
const localhost = try net.Address.parseIp("127.0.0.1", 0);
- var server1 = net.StreamServer.init(net.StreamServer.Options{
- .reuse_port = true,
- });
+ var server1 = try localhost.listen(.{ .reuse_port = true });
defer server1.deinit();
- try server1.listen(localhost);
- var server2 = net.StreamServer.init(net.StreamServer.Options{
- .reuse_port = true,
- });
+ var server2 = try server1.listen_address.listen(.{ .reuse_port = true });
defer server2.deinit();
- try server2.listen(server1.listen_address);
}
fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void {
@@ -252,7 +244,7 @@ fn testClient(addr: net.Address) anyerror!void {
try testing.expect(mem.eql(u8, msg, "hello from server\n"));
}
-fn testServer(server: *net.StreamServer) anyerror!void {
+fn testServer(server: *net.Server) anyerror!void {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
var client = try server.accept();
@@ -274,15 +266,14 @@ test "listen on a unix socket, send bytes, receive bytes" {
}
}
- var server = net.StreamServer.init(.{});
- defer server.deinit();
-
const socket_path = try generateFileName("socket.unix");
defer testing.allocator.free(socket_path);
const socket_addr = try net.Address.initUnix(socket_path);
defer std.fs.cwd().deleteFile(socket_path) catch {};
- try server.listen(socket_addr);
+
+ var server = try socket_addr.listen(.{});
+ defer server.deinit();
const S = struct {
fn clientFn(path: []const u8) !void {
@@ -323,9 +314,8 @@ test "non-blocking tcp server" {
}
const localhost = try net.Address.parseIp("127.0.0.1", 0);
- var server = net.StreamServer.init(.{ .force_nonblocking = true });
+ var server = localhost.listen(.{ .force_nonblocking = true });
defer server.deinit();
- try server.listen(localhost);
const accept_err = server.accept();
try testing.expectError(error.WouldBlock, accept_err);
lib/std/os/linux/io_uring.zig
@@ -4,6 +4,7 @@ const assert = std.debug.assert;
const mem = std.mem;
const net = std.net;
const os = std.os;
+const posix = std.posix;
const linux = os.linux;
const testing = std.testing;
@@ -3730,8 +3731,8 @@ const SocketTestHarness = struct {
client: os.socket_t,
fn close(self: SocketTestHarness) void {
- os.closeSocket(self.client);
- os.closeSocket(self.listener);
+ posix.close(self.client);
+ posix.close(self.listener);
}
};
@@ -3739,7 +3740,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness {
// Create a TCP server socket
var address = try net.Address.parseIp4("127.0.0.1", 0);
const listener_socket = try createListenerSocket(&address);
- errdefer os.closeSocket(listener_socket);
+ errdefer posix.close(listener_socket);
// Submit 1 accept
var accept_addr: os.sockaddr = undefined;
@@ -3748,7 +3749,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness {
// Create a TCP client socket
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
- errdefer os.closeSocket(client);
+ errdefer posix.close(client);
_ = try ring.connect(0xcccccccc, client, &address.any, address.getOsSockLen());
try testing.expectEqual(@as(u32, 2), try ring.submit());
@@ -3788,7 +3789,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness {
fn createListenerSocket(address: *net.Address) !os.socket_t {
const kernel_backlog = 1;
const listener_socket = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
- errdefer os.closeSocket(listener_socket);
+ errdefer posix.close(listener_socket);
try os.setsockopt(listener_socket, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1)));
try os.bind(listener_socket, &address.any, address.getOsSockLen());
@@ -3813,7 +3814,7 @@ test "accept multishot" {
var address = try net.Address.parseIp4("127.0.0.1", 0);
const listener_socket = try createListenerSocket(&address);
- defer os.closeSocket(listener_socket);
+ defer posix.close(listener_socket);
// submit multishot accept operation
var addr: os.sockaddr = undefined;
@@ -3826,7 +3827,7 @@ test "accept multishot" {
while (nr > 0) : (nr -= 1) {
// connect client
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
- errdefer os.closeSocket(client);
+ errdefer posix.close(client);
try os.connect(client, &address.any, address.getOsSockLen());
// test accept completion
@@ -3836,7 +3837,7 @@ test "accept multishot" {
try testing.expect(cqe.user_data == userdata);
try testing.expect(cqe.flags & linux.IORING_CQE_F_MORE > 0); // more flag is set
- os.closeSocket(client);
+ posix.close(client);
}
}
@@ -3909,7 +3910,7 @@ test "accept_direct" {
try ring.register_files(registered_fds[0..]);
const listener_socket = try createListenerSocket(&address);
- defer os.closeSocket(listener_socket);
+ defer posix.close(listener_socket);
const accept_userdata: u64 = 0xaaaaaaaa;
const read_userdata: u64 = 0xbbbbbbbb;
@@ -3927,7 +3928,7 @@ test "accept_direct" {
// connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen());
- defer os.closeSocket(client);
+ defer posix.close(client);
// accept completion
const cqe_accept = try ring.copy_cqe();
@@ -3961,7 +3962,7 @@ test "accept_direct" {
// connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen());
- defer os.closeSocket(client);
+ defer posix.close(client);
// completion with error
const cqe_accept = try ring.copy_cqe();
try testing.expect(cqe_accept.user_data == accept_userdata);
@@ -3989,7 +3990,7 @@ test "accept_multishot_direct" {
try ring.register_files(registered_fds[0..]);
const listener_socket = try createListenerSocket(&address);
- defer os.closeSocket(listener_socket);
+ defer posix.close(listener_socket);
const accept_userdata: u64 = 0xaaaaaaaa;
@@ -4003,7 +4004,7 @@ test "accept_multishot_direct" {
// connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen());
- defer os.closeSocket(client);
+ defer posix.close(client);
// accept completion
const cqe_accept = try ring.copy_cqe();
@@ -4018,7 +4019,7 @@ test "accept_multishot_direct" {
// connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen());
- defer os.closeSocket(client);
+ defer posix.close(client);
// completion with error
const cqe_accept = try ring.copy_cqe();
try testing.expect(cqe_accept.user_data == accept_userdata);
@@ -4092,7 +4093,7 @@ test "socket_direct/socket_direct_alloc/close_direct" {
// use sockets from registered_fds in connect operation
var address = try net.Address.parseIp4("127.0.0.1", 0);
const listener_socket = try createListenerSocket(&address);
- defer os.closeSocket(listener_socket);
+ defer posix.close(listener_socket);
const accept_userdata: u64 = 0xaaaaaaaa;
const connect_userdata: u64 = 0xbbbbbbbb;
const close_userdata: u64 = 0xcccccccc;
lib/std/os/test.zig
@@ -817,7 +817,7 @@ test "shutdown socket" {
error.SocketNotConnected => {},
else => |e| return e,
};
- os.closeSocket(sock);
+ std.net.Stream.close(.{ .handle = sock });
}
test "sigaction" {
lib/std/net.zig
@@ -4,15 +4,17 @@ const assert = std.debug.assert;
const net = @This();
const mem = std.mem;
const os = std.os;
+const posix = std.posix;
const fs = std.fs;
const io = std.io;
const native_endian = builtin.target.cpu.arch.endian();
// Windows 10 added support for unix sockets in build 17063, redstone 4 is the
// first release to support them.
-pub const has_unix_sockets = @hasDecl(os.sockaddr, "un") and
- (builtin.target.os.tag != .windows or
- builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false);
+pub const has_unix_sockets = switch (builtin.os.tag) {
+ .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false,
+ else => true,
+};
pub const IPParseError = error{
Overflow,
@@ -206,6 +208,57 @@ pub const Address = extern union {
else => unreachable,
}
}
+
+ pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError ||
+ posix.SetSockOptError || posix.GetSockNameError;
+
+ pub const ListenOptions = struct {
+ /// How many connections the kernel will accept on the application's behalf.
+ /// If more than this many connections pool in the kernel, clients will start
+ /// seeing "Connection refused".
+ kernel_backlog: u31 = 128,
+ reuse_address: bool = false,
+ reuse_port: bool = false,
+ force_nonblocking: bool = false,
+ };
+
+ /// The returned `Server` has an open `stream`.
+ pub fn listen(address: Address, options: ListenOptions) ListenError!Server {
+ const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0;
+ const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock;
+ const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP;
+
+ const sockfd = try posix.socket(address.any.family, sock_flags, proto);
+ var s: Server = .{
+ .listen_address = undefined,
+ .stream = .{ .handle = sockfd },
+ };
+ errdefer s.stream.close();
+
+ if (options.reuse_address) {
+ try posix.setsockopt(
+ sockfd,
+ posix.SOL.SOCKET,
+ posix.SO.REUSEADDR,
+ &mem.toBytes(@as(c_int, 1)),
+ );
+ }
+
+ if (options.reuse_port) {
+ try posix.setsockopt(
+ sockfd,
+ posix.SOL.SOCKET,
+ posix.SO.REUSEPORT,
+ &mem.toBytes(@as(c_int, 1)),
+ );
+ }
+
+ var socklen = address.getOsSockLen();
+ try posix.bind(sockfd, &address.any, socklen);
+ try posix.listen(sockfd, options.kernel_backlog);
+ try posix.getsockname(sockfd, &s.listen_address.any, &socklen);
+ return s;
+ }
};
pub const Ip4Address = extern struct {
@@ -657,7 +710,7 @@ pub fn connectUnixSocket(path: []const u8) !Stream {
os.SOCK.STREAM | os.SOCK.CLOEXEC | opt_non_block,
0,
);
- errdefer os.closeSocket(sockfd);
+ errdefer Stream.close(.{ .handle = sockfd });
var addr = try std.net.Address.initUnix(path);
try os.connect(sockfd, &addr.any, addr.getOsSockLen());
@@ -669,7 +722,7 @@ fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 {
if (builtin.target.os.tag == .linux) {
var ifr: os.ifreq = undefined;
const sockfd = try os.socket(os.AF.UNIX, os.SOCK.DGRAM | os.SOCK.CLOEXEC, 0);
- defer os.closeSocket(sockfd);
+ defer Stream.close(.{ .handle = sockfd });
@memcpy(ifr.ifrn.name[0..name.len], name);
ifr.ifrn.name[name.len] = 0;
@@ -738,7 +791,7 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream {
const sock_flags = os.SOCK.STREAM | nonblock |
(if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC);
const sockfd = try os.socket(address.any.family, sock_flags, os.IPPROTO.TCP);
- errdefer os.closeSocket(sockfd);
+ errdefer Stream.close(.{ .handle = sockfd });
try os.connect(sockfd, &address.any, address.getOsSockLen());
@@ -1068,7 +1121,7 @@ fn linuxLookupName(
var prefixlen: i32 = 0;
const sock_flags = os.SOCK.DGRAM | os.SOCK.CLOEXEC;
if (os.socket(addr.addr.any.family, sock_flags, os.IPPROTO.UDP)) |fd| syscalls: {
- defer os.closeSocket(fd);
+ defer Stream.close(.{ .handle = fd });
os.connect(fd, da, dalen) catch break :syscalls;
key |= DAS_USABLE;
os.getsockname(fd, sa, &salen) catch break :syscalls;
@@ -1553,7 +1606,7 @@ fn resMSendRc(
},
else => |e| return e,
};
- defer os.closeSocket(fd);
+ defer Stream.close(.{ .handle = fd });
// Past this point, there are no errors. Each individual query will
// yield either no reply (indicated by zero length) or an answer
@@ -1729,13 +1782,15 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8)
}
pub const Stream = struct {
- // Underlying socket descriptor.
- // Note that on some platforms this may not be interchangeable with a
- // regular files descriptor.
- handle: os.socket_t,
-
- pub fn close(self: Stream) void {
- os.closeSocket(self.handle);
+ /// Underlying platform-defined type which may or may not be
+ /// interchangeable with a file system file descriptor.
+ handle: posix.socket_t,
+
+ pub fn close(s: Stream) void {
+ switch (builtin.os.tag) {
+ .windows => std.os.windows.closesocket(s.handle) catch unreachable,
+ else => posix.close(s.handle),
+ }
}
pub const ReadError = os.ReadError;
@@ -1839,156 +1894,38 @@ pub const Stream = struct {
}
};
-pub const StreamServer = struct {
- /// Copied from `Options` on `init`.
- kernel_backlog: u31,
- reuse_address: bool,
- reuse_port: bool,
- force_nonblocking: bool,
-
- /// `undefined` until `listen` returns successfully.
+pub const Server = struct {
listen_address: Address,
+ stream: std.net.Stream,
- sockfd: ?os.socket_t,
-
- pub const Options = struct {
- /// How many connections the kernel will accept on the application's behalf.
- /// If more than this many connections pool in the kernel, clients will start
- /// seeing "Connection refused".
- kernel_backlog: u31 = 128,
-
- /// Enable SO.REUSEADDR on the socket.
- reuse_address: bool = false,
-
- /// Enable SO.REUSEPORT on the socket.
- reuse_port: bool = false,
-
- /// Force non-blocking mode.
- force_nonblocking: bool = false,
+ pub const Connection = struct {
+ stream: std.net.Stream,
+ address: Address,
};
- /// After this call succeeds, resources have been acquired and must
- /// be released with `deinit`.
- pub fn init(options: Options) StreamServer {
- return StreamServer{
- .sockfd = null,
- .kernel_backlog = options.kernel_backlog,
- .reuse_address = options.reuse_address,
- .reuse_port = options.reuse_port,
- .force_nonblocking = options.force_nonblocking,
- .listen_address = undefined,
- };
- }
-
- /// Release all resources. The `StreamServer` memory becomes `undefined`.
- pub fn deinit(self: *StreamServer) void {
- self.close();
- self.* = undefined;
- }
-
- pub fn listen(self: *StreamServer, address: Address) !void {
- const nonblock = 0;
- const sock_flags = os.SOCK.STREAM | os.SOCK.CLOEXEC | nonblock;
- var use_sock_flags: u32 = sock_flags;
- if (self.force_nonblocking) use_sock_flags |= os.SOCK.NONBLOCK;
- const proto = if (address.any.family == os.AF.UNIX) @as(u32, 0) else os.IPPROTO.TCP;
-
- const sockfd = try os.socket(address.any.family, use_sock_flags, proto);
- self.sockfd = sockfd;
- errdefer {
- os.closeSocket(sockfd);
- self.sockfd = null;
- }
-
- if (self.reuse_address) {
- try os.setsockopt(
- sockfd,
- os.SOL.SOCKET,
- os.SO.REUSEADDR,
- &mem.toBytes(@as(c_int, 1)),
- );
- }
- if (@hasDecl(os.SO, "REUSEPORT") and self.reuse_port) {
- try os.setsockopt(
- sockfd,
- os.SOL.SOCKET,
- os.SO.REUSEPORT,
- &mem.toBytes(@as(c_int, 1)),
- );
- }
-
- var socklen = address.getOsSockLen();
- try os.bind(sockfd, &address.any, socklen);
- try os.listen(sockfd, self.kernel_backlog);
- try os.getsockname(sockfd, &self.listen_address.any, &socklen);
- }
-
- /// Stop listening. It is still necessary to call `deinit` after stopping listening.
- /// Calling `deinit` will automatically call `close`. It is safe to call `close` when
- /// not listening.
- pub fn close(self: *StreamServer) void {
- if (self.sockfd) |fd| {
- os.closeSocket(fd);
- self.sockfd = null;
- self.listen_address = undefined;
- }
+ pub fn deinit(s: *Server) void {
+ s.stream.close();
+ s.* = undefined;
}
- pub const AcceptError = error{
- ConnectionAborted,
-
- /// The per-process limit on the number of open file descriptors has been reached.
- ProcessFdQuotaExceeded,
-
- /// The system-wide limit on the total number of open files has been reached.
- SystemFdQuotaExceeded,
-
- /// Not enough free memory. This often means that the memory allocation
- /// is limited by the socket buffer limits, not by the system memory.
- SystemResources,
-
- /// Socket is not listening for new connections.
- SocketNotListening,
-
- ProtocolFailure,
-
- /// Socket is in non-blocking mode and there is no connection to accept.
- WouldBlock,
-
- /// Firewall rules forbid connection.
- BlockedByFirewall,
-
- FileDescriptorNotASocket,
-
- ConnectionResetByPeer,
-
- NetworkSubsystemFailed,
+ pub const AcceptError = posix.AcceptError;
- OperationNotSupported,
- } || os.UnexpectedError;
-
- pub const Connection = struct {
- stream: Stream,
- address: Address,
- };
-
- /// If this function succeeds, the returned `Connection` is a caller-managed resource.
- pub fn accept(self: *StreamServer) AcceptError!Connection {
+ /// Blocks until a client connects to the server. The returned `Connection` has
+ /// an open stream.
+ pub fn accept(s: *Server) AcceptError!Connection {
var accepted_addr: Address = undefined;
- var adr_len: os.socklen_t = @sizeOf(Address);
- const accept_result = os.accept(self.sockfd.?, &accepted_addr.any, &adr_len, os.SOCK.CLOEXEC);
-
- if (accept_result) |fd| {
- return Connection{
- .stream = Stream{ .handle = fd },
- .address = accepted_addr,
- };
- } else |err| {
- return err;
- }
+ var addr_len: posix.socklen_t = @sizeOf(Address);
+ const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC);
+ return .{
+ .stream = .{ .handle = fd },
+ .address = accepted_addr,
+ };
}
};
test {
_ = @import("net/test.zig");
+ _ = Server;
+ _ = Stream;
+ _ = Address;
}
lib/std/os.zig
@@ -3598,14 +3598,6 @@ pub fn shutdown(sock: socket_t, how: ShutdownHow) ShutdownError!void {
}
}
-pub fn closeSocket(sock: socket_t) void {
- if (builtin.os.tag == .windows) {
- windows.closesocket(sock) catch unreachable;
- } else {
- close(sock);
- }
-}
-
pub const BindError = error{
/// The address is protected, and the user is not the superuser.
/// For UNIX domain sockets: Search permission is denied on a component
src/main.zig
@@ -3322,13 +3322,13 @@ fn buildOutputType(
.ip4 => |ip4_addr| {
if (build_options.only_core_functionality) unreachable;
- var server = std.net.StreamServer.init(.{
+ const addr: std.net.Address = .{ .in = ip4_addr };
+
+ var server = try addr.listen(.{
.reuse_address = true,
});
defer server.deinit();
- try server.listen(.{ .in = ip4_addr });
-
const conn = try server.accept();
defer conn.stream.close();
test/standalone/http.zig
@@ -1,8 +1,6 @@
const std = @import("std");
const http = std.http;
-const Server = http.Server;
-const Client = http.Client;
const mem = std.mem;
const testing = std.testing;
@@ -19,9 +17,7 @@ var gpa_client = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 })
const salloc = gpa_server.allocator();
const calloc = gpa_client.allocator();
-var server: Server = undefined;
-
-fn handleRequest(res: *Server.Response) !void {
+fn handleRequest(res: *http.Server, listen_port: u16) !void {
const log = std.log.scoped(.server);
log.info("{} {s} {s}", .{ res.request.method, @tagName(res.request.version), res.request.target });
@@ -125,7 +121,9 @@ fn handleRequest(res: *Server.Response) !void {
} else if (mem.eql(u8, res.request.target, "/redirect/3")) {
res.transfer_encoding = .chunked;
- const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{server.socket.listen_address.getPort()});
+ const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{
+ listen_port,
+ });
defer salloc.free(location);
res.status = .found;
@@ -168,14 +166,15 @@ fn handleRequest(res: *Server.Response) !void {
var handle_new_requests = true;
-fn runServer(srv: *Server) !void {
+fn runServer(server: *std.net.Server) !void {
var client_header_buffer: [1024]u8 = undefined;
outer: while (handle_new_requests) {
- var res = try srv.accept(.{
- .allocator = salloc,
+ var connection = try server.accept();
+ defer connection.stream.close();
+
+ var res = http.Server.init(connection, .{
.client_header_buffer = &client_header_buffer,
});
- defer res.deinit();
while (res.reset() != .closing) {
res.wait() catch |err| switch (err) {
@@ -184,16 +183,15 @@ fn runServer(srv: *Server) !void {
else => return err,
};
- try handleRequest(&res);
+ try handleRequest(&res, server.listen_address.getPort());
}
}
}
-fn serverThread(srv: *Server) void {
- defer srv.deinit();
+fn serverThread(server: *std.net.Server) void {
defer _ = gpa_server.deinit();
- runServer(srv) catch |err| {
+ runServer(server) catch |err| {
std.debug.print("server error: {}\n", .{err});
if (@errorReturnTrace()) |trace| {
@@ -205,18 +203,10 @@ fn serverThread(srv: *Server) void {
};
}
-fn killServer(addr: std.net.Address) void {
- handle_new_requests = false;
-
- const conn = std.net.tcpConnectToAddress(addr) catch return;
- conn.close();
-}
-
fn getUnusedTcpPort() !u16 {
const addr = try std.net.Address.parseIp("127.0.0.1", 0);
- var s = std.net.StreamServer.init(.{});
+ var s = try addr.listen(.{});
defer s.deinit();
- try s.listen(addr);
return s.listen_address.in.getPort();
}
@@ -225,16 +215,15 @@ pub fn main() !void {
defer _ = gpa_client.deinit();
- server = Server.init(.{ .reuse_address = true });
-
const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable;
- try server.listen(addr);
+ var server = try addr.listen(.{ .reuse_address = true });
+ defer server.deinit();
- const port = server.socket.listen_address.getPort();
+ const port = server.listen_address.getPort();
const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server});
- var client = Client{ .allocator = calloc };
+ var client: http.Client = .{ .allocator = calloc };
errdefer client.deinit();
// defer client.deinit(); handled below
@@ -691,6 +680,12 @@ pub fn main() !void {
client.deinit();
- killServer(server.socket.listen_address);
+ {
+ handle_new_requests = false;
+
+ const conn = std.net.tcpConnectToAddress(server.listen_address) catch return;
+ conn.close();
+ }
+
server_thread.join();
}