Commit b9fd0eeca6

Andrew Kelley <andrew@ziglang.org>
2024-08-05 00:25:42
add std.http.WebSocket
1 parent d36c182
Changed files (2)
lib/std/http/WebSocket.zig
@@ -0,0 +1,243 @@
+//! See https://tools.ietf.org/html/rfc6455
+
+const builtin = @import("builtin");
+const std = @import("std");
+const WebSocket = @This();
+const assert = std.debug.assert;
+const native_endian = builtin.cpu.arch.endian();
+
+key: []const u8,
+request: *std.http.Server.Request,
+recv_fifo: std.fifo.LinearFifo(u8, .Slice),
+reader: std.io.AnyReader,
+response: std.http.Server.Response,
+/// Number of bytes that have been peeked but not discarded yet.
+outstanding_len: usize,
+
+pub const InitError = error{WebSocketUpgradeMissingKey} ||
+    std.http.Server.Request.ReaderError;
+
+pub fn init(
+    ws: *WebSocket,
+    request: *std.http.Server.Request,
+    send_buffer: []u8,
+    recv_buffer: []align(4) u8,
+) InitError!bool {
+    var sec_websocket_key: ?[]const u8 = null;
+    var upgrade_websocket: bool = false;
+    var it = request.iterateHeaders();
+    while (it.next()) |header| {
+        if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
+            sec_websocket_key = header.value;
+        } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
+            if (!std.mem.eql(u8, header.value, "websocket"))
+                return false;
+            upgrade_websocket = true;
+        }
+    }
+    if (!upgrade_websocket)
+        return false;
+
+    const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey;
+
+    var sha1 = std.crypto.hash.Sha1.init(.{});
+    sha1.update(key);
+    sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+    var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
+    sha1.final(&digest);
+    var base64_digest: [28]u8 = undefined;
+    assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len);
+
+    request.head.content_length = std.math.maxInt(u64);
+
+    ws.* = .{
+        .key = key,
+        .recv_fifo = std.fifo.LinearFifo(u8, .Slice).init(recv_buffer),
+        .reader = try request.reader(),
+        .response = request.respondStreaming(.{
+            .send_buffer = send_buffer,
+            .respond_options = .{
+                .status = .switching_protocols,
+                .extra_headers = &.{
+                    .{ .name = "upgrade", .value = "websocket" },
+                    .{ .name = "connection", .value = "upgrade" },
+                    .{ .name = "sec-websocket-accept", .value = &base64_digest },
+                },
+                .transfer_encoding = .none,
+            },
+        }),
+        .request = request,
+        .outstanding_len = 0,
+    };
+    return true;
+}
+
+pub const Header0 = packed struct(u8) {
+    opcode: Opcode,
+    rsv3: u1 = 0,
+    rsv2: u1 = 0,
+    rsv1: u1 = 0,
+    fin: bool,
+};
+
+pub const Header1 = packed struct(u8) {
+    payload_len: enum(u7) {
+        len16 = 126,
+        len64 = 127,
+        _,
+    },
+    mask: bool,
+};
+
+pub const Opcode = enum(u4) {
+    continuation = 0,
+    text = 1,
+    binary = 2,
+    connection_close = 8,
+    ping = 9,
+    /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
+    /// heartbeat. A response to an unsolicited Pong frame is not expected."
+    pong = 10,
+    _,
+};
+
+pub const ReadSmallTextMessageError = error{
+    ConnectionClose,
+    UnexpectedOpCode,
+    MessageTooBig,
+    MissingMaskBit,
+} || RecvError;
+
+pub const SmallMessage = struct {
+    /// Can be text, binary, or ping.
+    opcode: Opcode,
+    data: []u8,
+};
+
+/// Reads the next message from the WebSocket stream, failing if the message does not fit
+/// into `recv_buffer`.
+pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
+    while (true) {
+        const header_bytes = (try recv(ws, 2))[0..2];
+        const h0: Header0 = @bitCast(header_bytes[0]);
+        const h1: Header1 = @bitCast(header_bytes[1]);
+
+        switch (h0.opcode) {
+            .text, .binary, .pong, .ping => {},
+            .connection_close => return error.ConnectionClose,
+            .continuation => return error.UnexpectedOpCode,
+            _ => return error.UnexpectedOpCode,
+        }
+
+        if (!h0.fin) return error.MessageTooBig;
+        if (!h1.mask) return error.MissingMaskBit;
+
+        const len: usize = switch (h1.payload_len) {
+            .len16 => try recvReadInt(ws, u16),
+            .len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig,
+            else => @intFromEnum(h1.payload_len),
+        };
+        if (len > ws.recv_fifo.buf.len) return error.MessageTooBig;
+
+        const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*);
+        const payload = try recv(ws, len);
+
+        // Skip pongs.
+        if (h0.opcode == .pong) continue;
+
+        // The last item may contain a partial word of unused data.
+        const floored_len = (payload.len / 4) * 4;
+        const u32_payload: []align(1) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len]));
+        for (u32_payload) |*elem| elem.* ^= mask;
+        const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len];
+        for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m;
+
+        return .{
+            .opcode = h0.opcode,
+            .data = payload,
+        };
+    }
+}
+
+const RecvError = std.http.Server.Request.ReadError || error{EndOfStream};
+
+fn recv(ws: *WebSocket, len: usize) RecvError![]u8 {
+    ws.recv_fifo.discard(ws.outstanding_len);
+    assert(len <= ws.recv_fifo.buf.len);
+    if (len > ws.recv_fifo.count) {
+        const small_buf = ws.recv_fifo.writableSlice(0);
+        const needed = len - ws.recv_fifo.count;
+        const buf = if (small_buf.len >= needed) small_buf else b: {
+            ws.recv_fifo.realign();
+            break :b ws.recv_fifo.writableSlice(0);
+        };
+        const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed)));
+        if (n < needed) return error.EndOfStream;
+        ws.recv_fifo.update(n);
+    }
+    ws.outstanding_len = len;
+    // TODO: improve the std lib API so this cast isn't necessary.
+    return @constCast(ws.recv_fifo.readableSliceOfLen(len));
+}
+
+fn recvReadInt(ws: *WebSocket, comptime I: type) !I {
+    const unswapped: I = @bitCast((try recv(ws, @sizeOf(I)))[0..@sizeOf(I)].*);
+    return switch (native_endian) {
+        .little => @byteSwap(unswapped),
+        .big => unswapped,
+    };
+}
+
+pub const WriteError = std.http.Server.Response.WriteError;
+
+pub fn writeMessage(ws: *WebSocket, message: []const u8, opcode: Opcode) WriteError!void {
+    const iovecs: [1]std.posix.iovec_const = .{
+        .{ .base = message.ptr, .len = message.len },
+    };
+    return writeMessagev(ws, &iovecs, opcode);
+}
+
+pub fn writeMessagev(ws: *WebSocket, message: []const std.posix.iovec_const, opcode: Opcode) WriteError!void {
+    const total_len = l: {
+        var total_len: u64 = 0;
+        for (message) |iovec| total_len += iovec.len;
+        break :l total_len;
+    };
+
+    var header_buf: [2 + 8]u8 = undefined;
+    header_buf[0] = @bitCast(@as(Header0, .{
+        .opcode = opcode,
+        .fin = true,
+    }));
+    const header = switch (total_len) {
+        0...125 => blk: {
+            header_buf[1] = @bitCast(@as(Header1, .{
+                .payload_len = @enumFromInt(total_len),
+                .mask = false,
+            }));
+            break :blk header_buf[0..2];
+        },
+        126...0xffff => blk: {
+            header_buf[1] = @bitCast(@as(Header1, .{
+                .payload_len = .len16,
+                .mask = false,
+            }));
+            std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big);
+            break :blk header_buf[0..4];
+        },
+        else => blk: {
+            header_buf[1] = @bitCast(@as(Header1, .{
+                .payload_len = .len64,
+                .mask = false,
+            }));
+            std.mem.writeInt(u64, header_buf[2..10], total_len, .big);
+            break :blk header_buf[0..10];
+        },
+    };
+
+    const response = &ws.response;
+    try response.writeAll(header);
+    for (message) |iovec|
+        try response.writeAll(iovec.base[0..iovec.len]);
+    try response.flush();
+}
lib/std/http.zig
@@ -4,6 +4,7 @@ pub const protocol = @import("http/protocol.zig");
 pub const HeadParser = @import("http/HeadParser.zig");
 pub const ChunkParser = @import("http/ChunkParser.zig");
 pub const HeaderIterator = @import("http/HeaderIterator.zig");
+pub const WebSocket = @import("http/WebSocket.zig");
 
 pub const Version = enum {
     @"HTTP/1.0",
@@ -318,6 +319,7 @@ test {
         _ = Status;
         _ = HeadParser;
         _ = ChunkParser;
+        _ = WebSocket;
         _ = @import("http/test.zig");
     }
 }