master
  1//! Handles a single connection lifecycle.
  2
  3const std = @import("../std.zig");
  4const http = std.http;
  5const mem = std.mem;
  6const Uri = std.Uri;
  7const assert = std.debug.assert;
  8const testing = std.testing;
  9const Writer = std.Io.Writer;
 10const Reader = std.Io.Reader;
 11
/// This file's struct type. A `Server` handles the lifecycle of one connection.
const Server = @This();

/// Data from the HTTP server to the HTTP client.
out: *Writer,
/// Data from the HTTP client to the HTTP server, together with the request
/// lifecycle state (`ready`, `received_head`, the body states, `closing`).
reader: http.Reader,
 17
 18/// Initialize an HTTP server that can respond to multiple requests on the same
 19/// connection.
 20///
 21/// The buffer of `in` must be large enough to store the client's entire HTTP
 22/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`.
 23///
 24/// The returned `Server` is ready for `receiveHead` to be called.
 25pub fn init(in: *Reader, out: *Writer) Server {
 26    return .{
 27        .reader = .{
 28            .in = in,
 29            .state = .ready,
 30            // Populated when `http.Reader.bodyReader` is called.
 31            .interface = undefined,
 32            .max_head_len = in.buffer.len,
 33        },
 34        .out = out,
 35    };
 36}
 37
/// Errors returned by `receiveHead`: the reader's `HeadError` set, plus a
/// parse failure of the received head.
pub const ReceiveHeadError = http.Reader.HeadError || error{
    /// Client sent headers that did not conform to the HTTP protocol.
    ///
    /// To find out more detailed diagnostics, `Request.head_buffer` can be
    /// passed directly to `Request.Head.parse`.
    HttpHeadersInvalid,
};
 45
 46pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
 47    const head_buffer = try s.reader.receiveHead();
 48    return .{
 49        .server = s,
 50        .head_buffer = head_buffer,
 51        // No need to track the returned error here since users can repeat the
 52        // parse with the header buffer to get detailed diagnostics.
 53        .head = Request.Head.parse(head_buffer) catch return error.HttpHeadersInvalid,
 54    };
 55}
 56
 57pub const Request = struct {
    /// The connection this request arrived on.
    server: *Server,
    /// Pointers in this struct are invalidated when the request body stream is
    /// initialized.
    head: Head,
    /// Raw bytes of the request head, as returned by the server's reader;
    /// `iterateHeaders` walks these bytes.
    head_buffer: []const u8,
    /// NOTE(review): nothing in this file reads or writes this field; confirm
    /// whether it is still needed.
    respond_err: ?RespondError = null,

    /// Error set stored in `respond_err`.
    pub const RespondError = error{
        /// The request contained an `expect` header with an unrecognized value.
        HttpExpectationFailed,
    };
 69
 70    pub const Head = struct {
 71        method: http.Method,
 72        target: []const u8,
 73        version: http.Version,
 74        expect: ?[]const u8,
 75        content_type: ?[]const u8,
 76        content_length: ?u64,
 77        transfer_encoding: http.TransferEncoding,
 78        transfer_compression: http.ContentEncoding,
 79        keep_alive: bool,
 80
 81        pub const ParseError = error{
 82            UnknownHttpMethod,
 83            HttpHeadersInvalid,
 84            HttpHeaderContinuationsUnsupported,
 85            HttpTransferEncodingUnsupported,
 86            HttpConnectionHeaderUnsupported,
 87            InvalidContentLength,
 88            CompressionUnsupported,
 89            MissingFinalNewline,
 90        };
 91
 92        pub fn parse(bytes: []const u8) ParseError!Head {
 93            var it = mem.splitSequence(u8, bytes, "\r\n");
 94
 95            const first_line = it.next().?;
 96            if (first_line.len < 10)
 97                return error.HttpHeadersInvalid;
 98
 99            const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse
100                return error.HttpHeadersInvalid;
101
102            const method = std.meta.stringToEnum(http.Method, first_line[0..method_end]) orelse
103                return error.UnknownHttpMethod;
104
105            const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse
106                return error.HttpHeadersInvalid;
107            if (version_start == method_end) return error.HttpHeadersInvalid;
108
109            const version_str = first_line[version_start + 1 ..];
110            if (version_str.len != 8) return error.HttpHeadersInvalid;
111            const version: http.Version = switch (int64(version_str[0..8])) {
112                int64("HTTP/1.0") => .@"HTTP/1.0",
113                int64("HTTP/1.1") => .@"HTTP/1.1",
114                else => return error.HttpHeadersInvalid,
115            };
116
117            const target = first_line[method_end + 1 .. version_start];
118
119            var head: Head = .{
120                .method = method,
121                .target = target,
122                .version = version,
123                .expect = null,
124                .content_type = null,
125                .content_length = null,
126                .transfer_encoding = .none,
127                .transfer_compression = .identity,
128                .keep_alive = switch (version) {
129                    .@"HTTP/1.0" => false,
130                    .@"HTTP/1.1" => true,
131                },
132            };
133
134            while (it.next()) |line| {
135                if (line.len == 0) return head;
136                switch (line[0]) {
137                    ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
138                    else => {},
139                }
140
141                var line_it = mem.splitScalar(u8, line, ':');
142                const header_name = line_it.next().?;
143                const header_value = mem.trim(u8, line_it.rest(), " \t");
144                if (header_name.len == 0) return error.HttpHeadersInvalid;
145
146                if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
147                    head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
148                } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) {
149                    head.expect = header_value;
150                } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) {
151                    head.content_type = header_value;
152                } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
153                    if (head.content_length != null) return error.HttpHeadersInvalid;
154                    head.content_length = std.fmt.parseInt(u64, header_value, 10) catch
155                        return error.InvalidContentLength;
156                } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
157                    if (head.transfer_compression != .identity) return error.HttpHeadersInvalid;
158
159                    const trimmed = mem.trim(u8, header_value, " ");
160
161                    if (http.ContentEncoding.fromString(trimmed)) |ce| {
162                        head.transfer_compression = ce;
163                    } else {
164                        return error.HttpTransferEncodingUnsupported;
165                    }
166                } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
167                    // Transfer-Encoding: second, first
168                    // Transfer-Encoding: deflate, chunked
169                    var iter = mem.splitBackwardsScalar(u8, header_value, ',');
170
171                    const first = iter.first();
172                    const trimmed_first = mem.trim(u8, first, " ");
173
174                    var next: ?[]const u8 = first;
175                    if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| {
176                        if (head.transfer_encoding != .none)
177                            return error.HttpHeadersInvalid; // we already have a transfer encoding
178                        head.transfer_encoding = transfer;
179
180                        next = iter.next();
181                    }
182
183                    if (next) |second| {
184                        const trimmed_second = mem.trim(u8, second, " ");
185
186                        if (http.ContentEncoding.fromString(trimmed_second)) |transfer| {
187                            if (head.transfer_compression != .identity)
188                                return error.HttpHeadersInvalid; // double compression is not supported
189                            head.transfer_compression = transfer;
190                        } else {
191                            return error.HttpTransferEncodingUnsupported;
192                        }
193                    }
194
195                    if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
196                }
197            }
198            return error.MissingFinalNewline;
199        }
200
201        test parse {
202            const request_bytes = "GET /hi HTTP/1.0\r\n" ++
203                "content-tYpe: text/plain\r\n" ++
204                "content-Length:10\r\n" ++
205                "expeCt:   100-continue \r\n" ++
206                "TRansfer-encoding:\tdeflate, chunked \r\n" ++
207                "connectioN:\t keep-alive \r\n\r\n";
208
209            const req = try parse(request_bytes);
210
211            try testing.expectEqual(.GET, req.method);
212            try testing.expectEqual(.@"HTTP/1.0", req.version);
213            try testing.expectEqualStrings("/hi", req.target);
214
215            try testing.expectEqualStrings("text/plain", req.content_type.?);
216            try testing.expectEqualStrings("100-continue", req.expect.?);
217
218            try testing.expectEqual(true, req.keep_alive);
219            try testing.expectEqual(10, req.content_length.?);
220            try testing.expectEqual(.chunked, req.transfer_encoding);
221            try testing.expectEqual(.deflate, req.transfer_compression);
222        }
223
224        inline fn int64(array: *const [8]u8) u64 {
225            return @bitCast(array.*);
226        }
227
228        /// Help the programmer avoid bugs by calling this when the string
229        /// memory of `Head` becomes invalidated.
230        fn invalidateStrings(h: *Head) void {
231            h.target = undefined;
232            if (h.expect) |*s| s.* = undefined;
233            if (h.content_type) |*s| s.* = undefined;
234        }
235    };
236
237    pub fn iterateHeaders(r: *const Request) http.HeaderIterator {
238        assert(r.server.reader.state == .received_head);
239        return http.HeaderIterator.init(r.head_buffer);
240    }
241
242    test iterateHeaders {
243        const request_bytes = "GET /hi HTTP/1.0\r\n" ++
244            "content-tYpe: text/plain\r\n" ++
245            "content-Length:10\r\n" ++
246            "expeCt:   100-continue \r\n" ++
247            "TRansfer-encoding:\tdeflate, chunked \r\n" ++
248            "connectioN:\t keep-alive \r\n\r\n";
249
250        var server: Server = .{
251            .reader = .{
252                .in = undefined,
253                .state = .received_head,
254                .interface = undefined,
255                .max_head_len = 4096,
256            },
257            .out = undefined,
258        };
259
260        var request: Request = .{
261            .server = &server,
262            .head = undefined,
263            .head_buffer = @constCast(request_bytes),
264        };
265
266        var it = request.iterateHeaders();
267        {
268            const header = it.next().?;
269            try testing.expectEqualStrings("content-tYpe", header.name);
270            try testing.expectEqualStrings("text/plain", header.value);
271            try testing.expect(!it.is_trailer);
272        }
273        {
274            const header = it.next().?;
275            try testing.expectEqualStrings("content-Length", header.name);
276            try testing.expectEqualStrings("10", header.value);
277            try testing.expect(!it.is_trailer);
278        }
279        {
280            const header = it.next().?;
281            try testing.expectEqualStrings("expeCt", header.name);
282            try testing.expectEqualStrings("100-continue", header.value);
283            try testing.expect(!it.is_trailer);
284        }
285        {
286            const header = it.next().?;
287            try testing.expectEqualStrings("TRansfer-encoding", header.name);
288            try testing.expectEqualStrings("deflate, chunked", header.value);
289            try testing.expect(!it.is_trailer);
290        }
291        {
292            const header = it.next().?;
293            try testing.expectEqualStrings("connectioN", header.name);
294            try testing.expectEqualStrings("keep-alive", header.value);
295            try testing.expect(!it.is_trailer);
296        }
297        try testing.expectEqual(null, it.next());
298    }
299
    /// Options for `respond` and `respondUnflushed`. Also embedded in
    /// `RespondStreamingOptions.respond_options`.
    pub const RespondOptions = struct {
        /// HTTP version written in the status line.
        version: http.Version = .@"HTTP/1.1",
        /// Status code written in the status line. Must not be `continue`.
        status: http.Status = .ok,
        /// Reason phrase for the status line. `null` uses the standard phrase
        /// for `status`, or an empty phrase if there is none.
        reason: ?[]const u8 = null,
        /// Whether the server wants to keep the connection open. This is only
        /// honored if the client also allows it, any request body is
        /// discarded successfully, and `transfer_encoding` is not `.none`.
        keep_alive: bool = true,
        /// Additional header fields, written after the framing headers. Names
        /// must be non-empty. In safe builds, `respondUnflushed` also asserts
        /// that names contain no ':' and that no name or value contains "\r\n".
        extra_headers: []const http.Header = &.{},
        /// Body framing:
        /// - `null`: a "content-length" header (for `respondStreaming`, only
        ///   when its `content_length` option is set; otherwise chunked).
        /// - `.chunked`: "transfer-encoding: chunked".
        /// - `.none`: no framing header is sent and the connection is not
        ///   kept alive.
        transfer_encoding: ?http.TransferEncoding = null,
    };
308
309    /// Send an entire HTTP response to the client, including headers and body.
310    ///
311    /// Automatically handles HEAD requests by omitting the body.
312    ///
313    /// Unless `transfer_encoding` is specified, uses the "content-length"
314    /// header.
315    ///
316    /// If the request contains a body and the connection is to be reused,
317    /// discards the request body, leaving the Server in the `ready` state. If
318    /// this discarding fails, the connection is marked as not to be reused and
319    /// no error is surfaced.
320    ///
321    /// Asserts status is not `continue`.
322    /// Asserts that "\r\n" does not occur in any header name or value.
323    pub fn respond(
324        request: *Request,
325        content: []const u8,
326        options: RespondOptions,
327    ) ExpectContinueError!void {
328        try respondUnflushed(request, content, options);
329        try request.server.out.flush();
330    }
331
332    pub fn respondUnflushed(
333        request: *Request,
334        content: []const u8,
335        options: RespondOptions,
336    ) ExpectContinueError!void {
337        assert(options.status != .@"continue");
338        if (std.debug.runtime_safety) {
339            for (options.extra_headers) |header| {
340                assert(header.name.len != 0);
341                assert(std.mem.indexOfScalar(u8, header.name, ':') == null);
342                assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null);
343                assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
344            }
345        }
346        try writeExpectContinue(request);
347
348        const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none;
349        const server_keep_alive = !transfer_encoding_none and options.keep_alive;
350        const keep_alive = request.discardBody(server_keep_alive);
351
352        const phrase = options.reason orelse options.status.phrase() orelse "";
353
354        const out = request.server.out;
355        try out.print("{s} {d} {s}\r\n", .{
356            @tagName(options.version), @intFromEnum(options.status), phrase,
357        });
358
359        switch (options.version) {
360            .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
361            .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
362        }
363
364        if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
365            .none => {},
366            .chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
367        } else {
368            try out.print("content-length: {d}\r\n", .{content.len});
369        }
370
371        for (options.extra_headers) |header| {
372            var vecs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" };
373            try out.writeVecAll(&vecs);
374        }
375
376        try out.writeAll("\r\n");
377
378        if (request.head.method != .HEAD) {
379            const is_chunked = (options.transfer_encoding orelse .none) == .chunked;
380            if (is_chunked) {
381                if (content.len > 0) try out.print("{x}\r\n{s}\r\n", .{ content.len, content });
382                try out.writeAll("0\r\n\r\n");
383            } else if (content.len > 0) {
384                try out.writeAll(content);
385            }
386        }
387    }
388
    /// Options for `respondStreaming`.
    pub const RespondStreamingOptions = struct {
        /// If provided, the response will use the content-length header;
        /// otherwise it will use transfer-encoding: chunked. An explicit
        /// `respond_options.transfer_encoding` takes precedence over this.
        content_length: ?u64 = null,
        /// Options that are shared with the `respond` method.
        respond_options: RespondOptions = .{},
    };
396
397    /// The header is not guaranteed to be sent until `BodyWriter.flush` or
398    /// `BodyWriter.end` is called.
399    ///
400    /// If the request contains a body and the connection is to be reused,
401    /// discards the request body, leaving the Server in the `ready` state. If
402    /// this discarding fails, the connection is marked as not to be reused and
403    /// no error is surfaced.
404    ///
405    /// HEAD requests are handled transparently by setting the
406    /// `BodyWriter.elide` flag on the returned `BodyWriter`, causing
407    /// the response stream to omit the body. However, it may be worth noticing
408    /// that flag and skipping any expensive work that would otherwise need to
409    /// be done to satisfy the request.
410    ///
411    /// Asserts status is not `continue`.
412    pub fn respondStreaming(
413        request: *Request,
414        buffer: []u8,
415        options: RespondStreamingOptions,
416    ) ExpectContinueError!http.BodyWriter {
417        try writeExpectContinue(request);
418        const o = options.respond_options;
419        assert(o.status != .@"continue");
420        const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none;
421        const server_keep_alive = !transfer_encoding_none and o.keep_alive;
422        const keep_alive = request.discardBody(server_keep_alive);
423        const phrase = o.reason orelse o.status.phrase() orelse "";
424        const out = request.server.out;
425
426        try out.print("{s} {d} {s}\r\n", .{
427            @tagName(o.version), @intFromEnum(o.status), phrase,
428        });
429
430        switch (o.version) {
431            .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
432            .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
433        }
434
435        if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
436            .chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
437            .none => {},
438        } else if (options.content_length) |len| {
439            try out.print("content-length: {d}\r\n", .{len});
440        } else {
441            try out.writeAll("transfer-encoding: chunked\r\n");
442        }
443
444        for (o.extra_headers) |header| {
445            assert(header.name.len != 0);
446            var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" };
447            try out.writeVecAll(&bufs);
448        }
449
450        try out.writeAll("\r\n");
451        const elide_body = request.head.method == .HEAD;
452        const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) {
453            .chunked => .init_chunked,
454            .none => .none,
455        } else if (options.content_length) |len| .{
456            .content_length = len,
457        } else .init_chunked;
458
459        return if (elide_body) .{
460            .http_protocol_output = request.server.out,
461            .state = state,
462            .writer = .{
463                .buffer = buffer,
464                .vtable = &.{
465                    .drain = http.BodyWriter.elidingDrain,
466                    .sendFile = http.BodyWriter.elidingSendFile,
467                },
468            },
469        } else .{
470            .http_protocol_output = request.server.out,
471            .state = state,
472            .writer = .{
473                .buffer = buffer,
474                .vtable = switch (state) {
475                    .none => &.{
476                        .drain = http.BodyWriter.noneDrain,
477                        .sendFile = http.BodyWriter.noneSendFile,
478                    },
479                    .content_length => &.{
480                        .drain = http.BodyWriter.contentLengthDrain,
481                        .sendFile = http.BodyWriter.contentLengthSendFile,
482                    },
483                    .chunk_len => &.{
484                        .drain = http.BodyWriter.chunkedDrain,
485                        .sendFile = http.BodyWriter.chunkedSendFile,
486                    },
487                    .end => unreachable,
488                },
489            },
490        };
491    }
492
    /// Result of `upgradeRequested`.
    pub const UpgradeRequest = union(enum) {
        /// The `upgrade` header named "websocket" (case-insensitively). The
        /// payload is the `sec-websocket-key` header value, or null if the
        /// client did not send one.
        websocket: ?[]const u8,
        /// The `upgrade` header named some other protocol; this is its value.
        other: []const u8,
        /// No upgrade was requested, or the request is not an HTTP/1.1 GET.
        none,
    };
498
499    /// Does not invalidate `request.head`.
500    pub fn upgradeRequested(request: *const Request) UpgradeRequest {
501        switch (request.head.version) {
502            .@"HTTP/1.0" => return .none,
503            .@"HTTP/1.1" => if (request.head.method != .GET) return .none,
504        }
505
506        var sec_websocket_key: ?[]const u8 = null;
507        var upgrade_name: ?[]const u8 = null;
508        var it = request.iterateHeaders();
509        while (it.next()) |header| {
510            if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
511                sec_websocket_key = header.value;
512            } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
513                upgrade_name = header.value;
514            }
515        }
516
517        const name = upgrade_name orelse return .none;
518        if (std.ascii.eqlIgnoreCase(name, "websocket")) return .{ .websocket = sec_websocket_key };
519        return .{ .other = name };
520    }
521
    /// Options for `respondWebSocket`.
    pub const WebSocketOptions = struct {
        /// The value from `UpgradeRequest.websocket` (sec-websocket-key header value).
        key: []const u8,
        /// Reason phrase for the 101 status line. Defaults to the standard
        /// phrase for that status, or an empty phrase if there is none.
        reason: ?[]const u8 = null,
        /// Additional header fields appended to the handshake response.
        /// Each name must be non-empty.
        extra_headers: []const http.Header = &.{},
    };
528
529    /// The header is not guaranteed to be sent until `WebSocket.flush` is
530    /// called on the returned struct.
531    pub fn respondWebSocket(request: *Request, options: WebSocketOptions) ExpectContinueError!WebSocket {
532        if (request.head.expect != null) return error.HttpExpectationFailed;
533
534        const out = request.server.out;
535        const version: http.Version = .@"HTTP/1.1";
536        const status: http.Status = .switching_protocols;
537        const phrase = options.reason orelse status.phrase() orelse "";
538
539        assert(request.head.version == version);
540        assert(request.head.method == .GET);
541
542        var sha1 = std.crypto.hash.Sha1.init(.{});
543        sha1.update(options.key);
544        sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
545        var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
546        sha1.final(&digest);
547        try out.print("{s} {d} {s}\r\n", .{ @tagName(version), @intFromEnum(status), phrase });
548        try out.writeAll("connection: upgrade\r\nupgrade: websocket\r\nsec-websocket-accept: ");
549        const base64_digest = try out.writableArray(28);
550        assert(std.base64.standard.Encoder.encode(base64_digest, &digest).len == base64_digest.len);
551        try out.writeAll("\r\n");
552
553        for (options.extra_headers) |header| {
554            assert(header.name.len != 0);
555            var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" };
556            try out.writeVecAll(&bufs);
557        }
558
559        try out.writeAll("\r\n");
560
561        return .{
562            .input = request.server.reader.in,
563            .output = request.server.out,
564            .key = options.key,
565        };
566    }
567
568    /// In the case that the request contains "expect: 100-continue", this
569    /// function writes the continuation header, which means it can fail with a
570    /// write error. After sending the continuation header, it sets the
571    /// request's expect field to `null`.
572    ///
573    /// Asserts that this function is only called once.
574    ///
575    /// See `readerExpectNone` for an infallible alternative that cannot write
576    /// to the server output stream.
577    pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*Reader {
578        const flush = request.head.expect != null;
579        try writeExpectContinue(request);
580        if (flush) try request.server.out.flush();
581        return readerExpectNone(request, buffer);
582    }
583
584    /// Asserts the expect header is `null`. The caller must handle the
585    /// expectation manually and then set the value to `null` prior to calling
586    /// this function.
587    ///
588    /// Asserts that this function is only called once.
589    ///
590    /// Invalidates the string memory inside `Head`.
591    pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader {
592        assert(request.server.reader.state == .received_head);
593        assert(request.head.expect == null);
594        request.head.invalidateStrings();
595        if (!request.head.method.requestHasBody()) return .ending;
596        return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length);
597    }
598
    /// Error set of the response functions (`respond`, `respondUnflushed`,
    /// `respondStreaming`, `respondWebSocket`) and of `readerExpectContinue`
    /// and `writeExpectContinue`.
    pub const ExpectContinueError = error{
        /// Writing to the server's output stream failed. Despite this set's
        /// name, this covers every write those functions perform, not only
        /// "HTTP/1.1 100 Continue\r\n\r\n".
        WriteFailed,
        /// The client sent an expect HTTP header value other than
        /// "100-continue". (`respondWebSocket` returns it for any `expect`
        /// header.)
        HttpExpectationFailed,
    };
606
607    pub fn writeExpectContinue(request: *Request) ExpectContinueError!void {
608        const expect = request.head.expect orelse return;
609        if (!mem.eql(u8, expect, "100-continue")) return error.HttpExpectationFailed;
610        try request.server.out.writeAll("HTTP/1.1 100 Continue\r\n\r\n");
611        request.head.expect = null;
612    }
613
614    /// Returns whether the connection should remain persistent.
615    ///
616    /// If it would fail, it instead sets the Server state to receiving body
617    /// and returns false.
618    fn discardBody(request: *Request, keep_alive: bool) bool {
619        // Prepare to receive another request on the same connection.
620        // There are two factors to consider:
621        // * Any body the client sent must be discarded.
622        // * The Server's read_buffer may already have some bytes in it from
623        //   whatever came after the head, which may be the next HTTP request
624        //   or the request body.
625        // If the connection won't be kept alive, then none of this matters
626        // because the connection will be severed after the response is sent.
627        const r = &request.server.reader;
628        if (keep_alive and request.head.keep_alive) switch (r.state) {
629            .received_head => {
630                if (request.head.method.requestHasBody()) {
631                    assert(request.head.transfer_encoding != .none or request.head.content_length != null);
632                    const reader_interface = request.readerExpectContinue(&.{}) catch return false;
633                    _ = reader_interface.discardRemaining() catch return false;
634                    assert(r.state == .ready);
635                } else {
636                    r.state = .ready;
637                }
638                return true;
639            },
640            .body_remaining_content_length, .body_remaining_chunk_len, .body_none, .ready => return true,
641            else => unreachable,
642        };
643
644        // Avoid clobbering the state in case a reading stream already exists.
645        switch (r.state) {
646            .received_head => r.state = .closing,
647            else => {},
648        }
649        return false;
650    }
651};
652
/// A server-side WebSocket connection, as returned by
/// `Request.respondWebSocket`.
///
/// See https://tools.ietf.org/html/rfc6455
pub const WebSocket = struct {
    /// The client's `sec-websocket-key` value.
    key: []const u8,
    /// Stream of frames from the client.
    input: *Reader,
    /// Stream of frames to the client.
    output: *Writer,

    /// First byte of a frame header. In a packed struct the first field
    /// occupies the least significant bits, so `opcode` is the low nibble and
    /// `fin` is the most significant bit. The field order must not change.
    pub const Header0 = packed struct(u8) {
        opcode: Opcode,
        rsv3: u1 = 0,
        rsv2: u1 = 0,
        rsv1: u1 = 0,
        fin: bool,
    };

    /// Second byte of a frame header: a 7-bit payload length (126 and 127
    /// announce a 16-bit or 64-bit extended length) and the mask flag.
    pub const Header1 = packed struct(u8) {
        payload_len: enum(u7) {
            len16 = 126,
            len64 = 127,
            _,
        },
        mask: bool,
    };

    /// Frame opcodes. Non-exhaustive so that reserved values can be held.
    pub const Opcode = enum(u4) {
        continuation = 0,
        text = 1,
        binary = 2,
        connection_close = 8,
        ping = 9,
        /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
        /// heartbeat. A response to an unsolicited Pong frame is not expected."
        pong = 10,
        _,
    };

    pub const ReadSmallTextMessageError = error{
        /// The client sent a close frame.
        ConnectionClose,
        /// A continuation frame or an unknown opcode was received.
        UnexpectedOpCode,
        /// The frame is fragmented, or its payload does not fit in the input buffer.
        MessageOversize,
        /// The client's frame was not masked.
        MissingMaskBit,
        ReadFailed,
        EndOfStream,
    };

    pub const SmallMessage = struct {
        /// Can be text, binary, or ping.
        opcode: Opcode,
        data: []u8,
    };

    /// Reads the next text, binary, or ping message, consuming and ignoring
    /// any pong frames that precede it.
    ///
    /// Only single-frame messages whose payload fits in the input buffer are
    /// supported. The payload is unmasked in place and returned as a slice of
    /// the input buffer, so it is invalidated by the next read.
    pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
        const reader = ws.input;
        while (true) {
            const frame_header = try reader.takeArray(2);
            const first: Header0 = @bitCast(frame_header[0]);
            const second: Header1 = @bitCast(frame_header[1]);

            switch (first.opcode) {
                .text, .binary, .ping, .pong => {},
                .connection_close => return error.ConnectionClose,
                .continuation => return error.UnexpectedOpCode,
                _ => return error.UnexpectedOpCode,
            }
            // Fragmented messages are not reassembled.
            if (!first.fin) return error.MessageOversize;
            // Every client-to-server frame must be masked.
            if (!second.mask) return error.MissingMaskBit;

            const payload_len: usize = switch (second.payload_len) {
                .len16 => try reader.takeInt(u16, .big),
                .len64 => std.math.cast(usize, try reader.takeInt(u64, .big)) orelse
                    return error.MessageOversize,
                else => @intFromEnum(second.payload_len),
            };
            if (payload_len > reader.buffer.len) return error.MessageOversize;
            // Copied by value so it stays valid across the following `take`.
            const masking_key: [4]u8 = (try reader.takeArray(4)).*;
            const payload = try reader.take(payload_len);

            if (first.opcode == .pong) continue;

            unmask(payload, masking_key);
            return .{ .opcode = first.opcode, .data = payload };
        }
    }

    /// XORs `payload` in place with the repeating 4-byte `key`. Whole 4-byte
    /// words are processed at once, and the remaining tail byte by byte.
    fn unmask(payload: []u8, key: [4]u8) void {
        const key_word: u32 = @bitCast(key);
        const word_bytes = (payload.len / 4) * 4;
        const words: []align(1) u32 = @ptrCast(payload[0..word_bytes]);
        for (words) |*word| word.* ^= key_word;
        for (payload[word_bytes..], 0..) |*byte, i| byte.* ^= key[i];
    }

    /// Sends `data` as a single unmasked frame and flushes the output.
    pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void {
        try ws.writeMessageUnflushed(data, op);
        try ws.flush();
    }

    /// Buffers `data` as a single unmasked frame without flushing.
    pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void {
        var single = [_][]const u8{data};
        return ws.writeMessageVecUnflushed(&single, op);
    }

    /// Sends the concatenation of `data` as a single frame and flushes.
    pub fn writeMessageVec(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void {
        try ws.writeMessageVecUnflushed(data, op);
        return ws.flush();
    }

    /// Buffers one unfragmented, unmasked frame whose payload is the
    /// concatenation of `data`.
    ///
    /// NOTE(review): `data` is handed to `Writer.writeVecAll`, which takes a
    /// mutable slice. Its entries may be modified, so don't rely on their
    /// contents afterwards (TODO confirm).
    pub fn writeMessageVecUnflushed(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void {
        var total_len: u64 = 0;
        for (data) |part| total_len += part.len;

        const out = ws.output;
        const first: Header0 = .{ .opcode = op, .fin = true };
        try out.writeByte(@bitCast(first));

        // Pick the shortest length encoding that can hold `total_len`.
        if (total_len <= 125) {
            const second: Header1 = .{ .payload_len = @enumFromInt(total_len), .mask = false };
            try out.writeByte(@bitCast(second));
        } else if (total_len <= 0xffff) {
            const second: Header1 = .{ .payload_len = .len16, .mask = false };
            try out.writeByte(@bitCast(second));
            try out.writeInt(u16, @intCast(total_len), .big);
        } else {
            const second: Header1 = .{ .payload_len = .len64, .mask = false };
            try out.writeByte(@bitCast(second));
            try out.writeInt(u64, total_len, .big);
        }
        try out.writeVecAll(data);
    }

    /// Flushes any buffered frames to the client.
    pub fn flush(ws: *WebSocket) Writer.Error!void {
        try ws.output.flush();
    }
};