Commit ba569bb8e9

tgschultz <tgschultz@gmail.com>
2024-10-14 03:44:42
Rewrite bit_reader and bit_writer to take advantage of current zig semantics and enhance readability (#21689)
Co-authored-by: Tanner Schultz <tgschultz@tgschultz-dl.tail7ba92.ts.net>
1 parent e2e7996
Changed files (6)
lib/std/compress/zstandard/decode/block.zig
@@ -405,7 +405,7 @@ pub const DecodeState = struct {
     };
     fn readLiteralsBits(
         self: *DecodeState,
-        bit_count_to_read: usize,
+        bit_count_to_read: u16,
     ) LiteralBitsError!u16 {
         return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
             if (self.literal_streams == .four and self.literal_stream_index < 3) {
lib/std/compress/zstandard/decode/huffman.zig
@@ -63,7 +63,7 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *
 
 fn assignWeights(
     huff_bits: *readers.ReverseBitReader,
-    accuracy_log: usize,
+    accuracy_log: u16,
     entries: *[1 << 6]Table.Fse,
     weights: *[256]u4,
 ) !usize {
@@ -73,7 +73,7 @@ fn assignWeights(
 
     while (i < 254) {
         const even_data = entries[even_state];
-        var read_bits: usize = 0;
+        var read_bits: u16 = 0;
         const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
         weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
         i += 1;
lib/std/compress/zstandard/readers.zig
@@ -42,11 +42,11 @@ pub const ReverseBitReader = struct {
         if (i == 8) return error.BitStreamHasNoStartBit;
     }
 
-    pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
+    pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: u16) error{EndOfStream}!U {
         return self.bit_reader.readBitsNoEof(U, num_bits);
     }
 
-    pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U {
+    pub fn readBits(self: *@This(), comptime U: type, num_bits: u16, out_bits: *u16) error{}!U {
         return try self.bit_reader.readBits(U, num_bits, out_bits);
     }
 
@@ -55,7 +55,7 @@ pub const ReverseBitReader = struct {
     }
 
     pub fn isEmpty(self: ReverseBitReader) bool {
-        return self.byte_reader.remaining_bytes == 0 and self.bit_reader.bit_count == 0;
+        return self.byte_reader.remaining_bytes == 0 and self.bit_reader.count == 0;
     }
 };
 
@@ -63,11 +63,11 @@ pub fn BitReader(comptime Reader: type) type {
     return struct {
         underlying: std.io.BitReader(.little, Reader),
 
-        pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
+        pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: u16) !U {
             return self.underlying.readBitsNoEof(U, num_bits);
         }
 
-        pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
+        pub fn readBits(self: *@This(), comptime U: type, num_bits: u16, out_bits: *u16) !U {
             return self.underlying.readBits(U, num_bits, out_bits);
         }
 
lib/std/io/bit_reader.zig
@@ -1,176 +1,179 @@
 const std = @import("../std.zig");
-const io = std.io;
-const assert = std.debug.assert;
-const testing = std.testing;
-const meta = std.meta;
-const math = std.math;
-
-/// Creates a stream which allows for reading bit fields from another stream
-pub fn BitReader(comptime endian: std.builtin.Endian, comptime ReaderType: type) type {
+
+//General note on endianness:
+//Big endian is packed starting in the most significant part of the byte and subsequent
+// bytes contain less significant bits. Thus we always take bits from the high
+// end and place them below existing bits in our output.
+//Little endian is packed starting in the least significant part of the byte and
+// subsequent bytes contain more significant bits. Thus we always take bits from
+// the low end and place them above existing bits in our output.
+//Regardless of endianness, within any given byte the bits are always in most
+// to least significant order.
+//Also regardless of endianness, the buffer always aligns bits to the low end
+// of the byte.
+
+/// Creates a bit reader which allows for reading bits from an underlying standard reader
+pub fn BitReader(comptime endian: std.builtin.Endian, comptime Reader: type) type {
     return struct {
-        forward_reader: ReaderType,
-        bit_buffer: u7,
-        bit_count: u3,
-
-        pub const Error = ReaderType.Error;
-        pub const Reader = io.Reader(*Self, Error, read);
-
-        const Self = @This();
-        const u8_bit_count = @bitSizeOf(u8);
-        const u7_bit_count = @bitSizeOf(u7);
-        const u4_bit_count = @bitSizeOf(u4);
-
-        pub fn init(forward_reader: ReaderType) Self {
-            return Self{
-                .forward_reader = forward_reader,
-                .bit_buffer = 0,
-                .bit_count = 0,
+        reader: Reader,
+        bits: u8 = 0,
+        count: u4 = 0,
+
+        const low_bit_mask = [9]u8{
+            0b00000000,
+            0b00000001,
+            0b00000011,
+            0b00000111,
+            0b00001111,
+            0b00011111,
+            0b00111111,
+            0b01111111,
+            0b11111111,
+        };
+
+        fn Bits(comptime T: type) type {
+            return struct {
+                T,
+                u16,
+            };
+        }
+
+        fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
+            const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
+            return .{
+                @bitCast(@as(UT, @intCast(out))),
+                num,
             };
         }
 
-        /// Reads `bits` bits from the stream and returns a specified unsigned int type
+        /// Reads `num` bits from the reader and returns a specified type
         ///  containing them in the least significant end, returning an error if the
         ///  specified number of bits could not be read.
-        pub fn readBitsNoEof(self: *Self, comptime U: type, bits: usize) !U {
-            var n: usize = undefined;
-            const result = try self.readBits(U, bits, &n);
-            if (n < bits) return error.EndOfStream;
-            return result;
+        pub fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T {
+            const b, const c = try self.readBitsTuple(T, num);
+            if (c < num) return error.EndOfStream;
+            return b;
         }
 
-        /// Reads `bits` bits from the stream and returns a specified unsigned int type
+        /// Reads `num` bits from the reader and returns a specified type
         ///  containing them in the least significant end. The number of bits successfully
         ///  read is placed in `out_bits`, as reaching the end of the stream is not an error.
-        pub fn readBits(self: *Self, comptime U: type, bits: usize, out_bits: *usize) Error!U {
-            //by extending the buffer to a minimum of u8 we can cover a number of edge cases
-            // related to shifting and casting.
-            const u_bit_count = @bitSizeOf(U);
-            const buf_bit_count = bc: {
-                assert(u_bit_count >= bits);
-                break :bc if (u_bit_count <= u8_bit_count) u8_bit_count else u_bit_count;
-            };
-            const Buf = std.meta.Int(.unsigned, buf_bit_count);
-            const BufShift = math.Log2Int(Buf);
+        pub fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T {
+            const b, const c = try self.readBitsTuple(T, num);
+            out_bits.* = c;
+            return b;
+        }
 
-            out_bits.* = @as(usize, 0);
-            if (U == u0 or bits == 0) return 0;
-            var out_buffer = @as(Buf, 0);
+        /// Reads `num` bits from the reader and returns a tuple of the specified type
+        ///  containing them in the least significant end, and the number of bits successfully
+        ///  read. Reaching the end of the stream is not an error.
+        pub fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) {
+            const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
+            const U = if (@bitSizeOf(T) < 8) u8 else UT; //it is a pain to work with <u8
 
-            if (self.bit_count > 0) {
-                const n = if (self.bit_count >= bits) @as(u3, @intCast(bits)) else self.bit_count;
-                const shift = u7_bit_count - n;
-                switch (endian) {
-                    .big => {
-                        out_buffer = @as(Buf, self.bit_buffer >> shift);
-                        if (n >= u7_bit_count)
-                            self.bit_buffer = 0
-                        else
-                            self.bit_buffer <<= n;
-                    },
-                    .little => {
-                        const value = (self.bit_buffer << shift) >> shift;
-                        out_buffer = @as(Buf, value);
-                        if (n >= u7_bit_count)
-                            self.bit_buffer = 0
-                        else
-                            self.bit_buffer >>= n;
-                    },
-                }
-                self.bit_count -= n;
-                out_bits.* = n;
-            }
-            //at this point we know bit_buffer is empty
+            //dump any bits in our buffer first
+            if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);
 
-            //copy bytes until we have enough bits, then leave the rest in bit_buffer
-            while (out_bits.* < bits) {
-                const n = bits - out_bits.*;
-                const next_byte = self.forward_reader.readByte() catch |err| switch (err) {
-                    error.EndOfStream => return @as(U, @intCast(out_buffer)),
+            var out_count: u16 = self.count;
+            var out: U = self.removeBits(self.count);
+
+            //grab all the full bytes we need and put their
+            //bits where they belong
+            const full_bytes_left = (num - out_count) / 8;
+
+            for (0..full_bytes_left) |_| {
+                const byte = self.reader.readByte() catch |err| switch (err) {
+                    error.EndOfStream => return initBits(T, out, out_count),
                     else => |e| return e,
                 };
 
                 switch (endian) {
                     .big => {
-                        if (n >= u8_bit_count) {
-                            out_buffer <<= @as(u3, @intCast(u8_bit_count - 1));
-                            out_buffer <<= 1;
-                            out_buffer |= @as(Buf, next_byte);
-                            out_bits.* += u8_bit_count;
-                            continue;
-                        }
-
-                        const shift = @as(u3, @intCast(u8_bit_count - n));
-                        out_buffer <<= @as(BufShift, @intCast(n));
-                        out_buffer |= @as(Buf, next_byte >> shift);
-                        out_bits.* += n;
-                        self.bit_buffer = @as(u7, @truncate(next_byte << @as(u3, @intCast(n - 1))));
-                        self.bit_count = shift;
+                        if (U == u8) out = 0 else out <<= 8; //shifting u8 by 8 is illegal in Zig
+                        out |= byte;
                     },
                     .little => {
-                        if (n >= u8_bit_count) {
-                            out_buffer |= @as(Buf, next_byte) << @as(BufShift, @intCast(out_bits.*));
-                            out_bits.* += u8_bit_count;
-                            continue;
-                        }
-
-                        const shift = @as(u3, @intCast(u8_bit_count - n));
-                        const value = (next_byte << shift) >> shift;
-                        out_buffer |= @as(Buf, value) << @as(BufShift, @intCast(out_bits.*));
-                        out_bits.* += n;
-                        self.bit_buffer = @as(u7, @truncate(next_byte >> @as(u3, @intCast(n))));
-                        self.bit_count = shift;
+                        const pos = @as(U, byte) << @intCast(out_count);
+                        out |= pos;
                     },
                 }
+                out_count += 8;
             }
 
-            return @as(U, @intCast(out_buffer));
-        }
+            const bits_left = num - out_count;
+            const keep = 8 - bits_left;
+
+            if (bits_left == 0) return initBits(T, out, out_count);
 
-        pub fn alignToByte(self: *Self) void {
-            self.bit_buffer = 0;
-            self.bit_count = 0;
+            const final_byte = self.reader.readByte() catch |err| switch (err) {
+                error.EndOfStream => return initBits(T, out, out_count),
+                else => |e| return e,
+            };
+
+            switch (endian) {
+                .big => {
+                    out <<= @intCast(bits_left);
+                    out |= final_byte >> @intCast(keep);
+                    self.bits = final_byte & low_bit_mask[keep];
+                },
+                .little => {
+                    const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count);
+                    out |= pos;
+                    self.bits = final_byte >> @intCast(bits_left);
+                },
+            }
+
+            self.count = @intCast(keep);
+            return initBits(T, out, num);
         }
 
-        pub fn read(self: *Self, buffer: []u8) Error!usize {
-            var out_bits: usize = undefined;
-            var out_bits_total = @as(usize, 0);
-            //@NOTE: I'm not sure this is a good idea, maybe alignToByte should be forced
-            if (self.bit_count > 0) {
-                for (buffer) |*b| {
-                    b.* = try self.readBits(u8, u8_bit_count, &out_bits);
-                    out_bits_total += out_bits;
-                }
-                const incomplete_byte = @intFromBool(out_bits_total % u8_bit_count > 0);
-                return (out_bits_total / u8_bit_count) + incomplete_byte;
+        //convenience function for removing bits from
+        //the appropriate part of the buffer based on
+        //endianness.
+        fn removeBits(self: *@This(), num: u4) u8 {
+            if (num == 8) {
+                self.count = 0;
+                return self.bits;
+            }
+
+            const keep = self.count - num;
+            const bits = switch (endian) {
+                .big => self.bits >> @intCast(keep),
+                .little => self.bits & low_bit_mask[num],
+            };
+            switch (endian) {
+                .big => self.bits &= low_bit_mask[keep],
+                .little => self.bits >>= @intCast(num),
             }
 
-            return self.forward_reader.read(buffer);
+            self.count = keep;
+            return bits;
         }
 
-        pub fn reader(self: *Self) Reader {
-            return .{ .context = self };
+        pub fn alignToByte(self: *@This()) void {
+            self.bits = 0;
+            self.count = 0;
         }
     };
 }
 
-pub fn bitReader(
-    comptime endian: std.builtin.Endian,
-    underlying_stream: anytype,
-) BitReader(endian, @TypeOf(underlying_stream)) {
-    return BitReader(endian, @TypeOf(underlying_stream)).init(underlying_stream);
+pub fn bitReader(comptime endian: std.builtin.Endian, reader: anytype) BitReader(endian, @TypeOf(reader)) {
+    return .{ .reader = reader };
 }
 
+///////////////////////////////
+
 test "api coverage" {
     const mem_be = [_]u8{ 0b11001101, 0b00001011 };
     const mem_le = [_]u8{ 0b00011101, 0b10010101 };
 
-    var mem_in_be = io.fixedBufferStream(&mem_be);
+    var mem_in_be = std.io.fixedBufferStream(&mem_be);
     var bit_stream_be = bitReader(.big, mem_in_be.reader());
 
-    var out_bits: usize = undefined;
+    var out_bits: u16 = undefined;
 
-    const expect = testing.expect;
-    const expectError = testing.expectError;
+    const expect = std.testing.expect;
+    const expectError = std.testing.expectError;
 
     try expect(1 == try bit_stream_be.readBits(u2, 1, &out_bits));
     try expect(out_bits == 1);
@@ -186,12 +189,12 @@ test "api coverage" {
     try expect(out_bits == 1);
 
     mem_in_be.pos = 0;
-    bit_stream_be.bit_count = 0;
+    bit_stream_be.count = 0;
     try expect(0b110011010000101 == try bit_stream_be.readBits(u15, 15, &out_bits));
     try expect(out_bits == 15);
 
     mem_in_be.pos = 0;
-    bit_stream_be.bit_count = 0;
+    bit_stream_be.count = 0;
     try expect(0b1100110100001011 == try bit_stream_be.readBits(u16, 16, &out_bits));
     try expect(out_bits == 16);
 
@@ -201,7 +204,7 @@ test "api coverage" {
     try expect(out_bits == 0);
     try expectError(error.EndOfStream, bit_stream_be.readBitsNoEof(u1, 1));
 
-    var mem_in_le = io.fixedBufferStream(&mem_le);
+    var mem_in_le = std.io.fixedBufferStream(&mem_le);
     var bit_stream_le = bitReader(.little, mem_in_le.reader());
 
     try expect(1 == try bit_stream_le.readBits(u2, 1, &out_bits));
@@ -218,12 +221,12 @@ test "api coverage" {
     try expect(out_bits == 1);
 
     mem_in_le.pos = 0;
-    bit_stream_le.bit_count = 0;
+    bit_stream_le.count = 0;
     try expect(0b001010100011101 == try bit_stream_le.readBits(u15, 15, &out_bits));
     try expect(out_bits == 15);
 
     mem_in_le.pos = 0;
-    bit_stream_le.bit_count = 0;
+    bit_stream_le.count = 0;
     try expect(0b1001010100011101 == try bit_stream_le.readBits(u16, 16, &out_bits));
     try expect(out_bits == 16);
 
lib/std/io/bit_writer.zig
@@ -1,153 +1,138 @@
 const std = @import("../std.zig");
-const io = std.io;
-const testing = std.testing;
-const assert = std.debug.assert;
-const math = std.math;
 
-/// Creates a stream which allows for writing bit fields to another stream
-pub fn BitWriter(comptime endian: std.builtin.Endian, comptime WriterType: type) type {
+//General note on endianness:
+//Big endian is packed starting in the most significant part of the byte and subsequent
+// bytes contain less significant bits. Thus we write out bits from the high end
+// of our input first.
+//Little endian is packed starting in the least significant part of the byte and
+// subsequent bytes contain more significant bits. Thus we write out bits from
+// the low end of our input first.
+//Regardless of endianness, within any given byte the bits are always in most
+// to least significant order.
+//Also regardless of endianness, the buffer always aligns bits to the low end
+// of the byte.
+
+/// Creates a bit writer which allows for writing bits to an underlying standard writer
+pub fn BitWriter(comptime endian: std.builtin.Endian, comptime Writer: type) type {
     return struct {
-        forward_writer: WriterType,
-        bit_buffer: u8,
-        bit_count: u4,
-
-        pub const Error = WriterType.Error;
-        pub const Writer = io.Writer(*Self, Error, write);
-
-        const Self = @This();
-        const u8_bit_count = @bitSizeOf(u8);
-        const u4_bit_count = @bitSizeOf(u4);
-
-        pub fn init(forward_writer: WriterType) Self {
-            return Self{
-                .forward_writer = forward_writer,
-                .bit_buffer = 0,
-                .bit_count = 0,
-            };
-        }
-
-        /// Write the specified number of bits to the stream from the least significant bits of
-        ///  the specified unsigned int value. Bits will only be written to the stream when there
+        writer: Writer,
+        bits: u8 = 0,
+        count: u4 = 0,
+
+        const low_bit_mask = [9]u8{
+            0b00000000,
+            0b00000001,
+            0b00000011,
+            0b00000111,
+            0b00001111,
+            0b00011111,
+            0b00111111,
+            0b01111111,
+            0b11111111,
+        };
+
+        /// Write the specified number of bits to the writer from the least significant bits of
+        ///  the specified value. Bits will only be written to the writer when there
         ///  are enough to fill a byte.
-        pub fn writeBits(self: *Self, value: anytype, bits: usize) Error!void {
-            if (bits == 0) return;
-
-            const U = @TypeOf(value);
-            comptime assert(@typeInfo(U).int.signedness == .unsigned);
-
-            //by extending the buffer to a minimum of u8 we can cover a number of edge cases
-            // related to shifting and casting.
-            const u_bit_count = @bitSizeOf(U);
-            const buf_bit_count = bc: {
-                assert(u_bit_count >= bits);
-                break :bc if (u_bit_count <= u8_bit_count) u8_bit_count else u_bit_count;
-            };
-            const Buf = std.meta.Int(.unsigned, buf_bit_count);
-            const BufShift = math.Log2Int(Buf);
-
-            const buf_value = @as(Buf, @intCast(value));
-
-            const high_byte_shift = @as(BufShift, @intCast(buf_bit_count - u8_bit_count));
-            var in_buffer = switch (endian) {
-                .big => buf_value << @as(BufShift, @intCast(buf_bit_count - bits)),
-                .little => buf_value,
-            };
-            var in_bits = bits;
-
-            if (self.bit_count > 0) {
-                const bits_remaining = u8_bit_count - self.bit_count;
-                const n = @as(u3, @intCast(if (bits_remaining > bits) bits else bits_remaining));
+        pub fn writeBits(self: *@This(), value: anytype, num: u16) !void {
+            const T = @TypeOf(value);
+            const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
+            const U = if (@bitSizeOf(T) < 8) u8 else UT; //<u8 is a pain to work with
+
+            var in: U = @as(UT, @bitCast(value));
+            var in_count: u16 = num;
+
+            if (self.count > 0) {
+                //if we can't fill the buffer, add what we have
+                const bits_free = 8 - self.count;
+                if (num < bits_free) {
+                    self.addBits(@truncate(in), @intCast(num));
+                    return;
+                }
+
+                //finish filling the buffer and flush it
+                if (num == bits_free) {
+                    self.addBits(@truncate(in), @intCast(num));
+                    return self.flushBits();
+                }
+
                 switch (endian) {
                     .big => {
-                        const shift = @as(BufShift, @intCast(high_byte_shift + self.bit_count));
-                        const v = @as(u8, @intCast(in_buffer >> shift));
-                        self.bit_buffer |= v;
-                        in_buffer <<= n;
+                        const bits = in >> @intCast(in_count - bits_free);
+                        self.addBits(@truncate(bits), bits_free);
                     },
                     .little => {
-                        const v = @as(u8, @truncate(in_buffer)) << @as(u3, @intCast(self.bit_count));
-                        self.bit_buffer |= v;
-                        in_buffer >>= n;
+                        self.addBits(@truncate(in), bits_free);
+                        in >>= @intCast(bits_free);
                     },
                 }
-                self.bit_count += n;
-                in_bits -= n;
-
-                //if we didn't fill the buffer, it's because bits < bits_remaining;
-                if (self.bit_count != u8_bit_count) return;
-                try self.forward_writer.writeByte(self.bit_buffer);
-                self.bit_buffer = 0;
-                self.bit_count = 0;
+                in_count -= bits_free;
+                try self.flushBits();
             }
-            //at this point we know bit_buffer is empty
 
-            //copy bytes until we can't fill one anymore, then leave the rest in bit_buffer
-            while (in_bits >= u8_bit_count) {
+            //write full bytes while we can
+            const full_bytes_left = in_count / 8;
+            for (0..full_bytes_left) |_| {
                 switch (endian) {
                     .big => {
-                        const v = @as(u8, @intCast(in_buffer >> high_byte_shift));
-                        try self.forward_writer.writeByte(v);
-                        in_buffer <<= @as(u3, @intCast(u8_bit_count - 1));
-                        in_buffer <<= 1;
+                        const bits = in >> @intCast(in_count - 8);
+                        try self.writer.writeByte(@truncate(bits));
                     },
                     .little => {
-                        const v = @as(u8, @truncate(in_buffer));
-                        try self.forward_writer.writeByte(v);
-                        in_buffer >>= @as(u3, @intCast(u8_bit_count - 1));
-                        in_buffer >>= 1;
+                        try self.writer.writeByte(@truncate(in));
+                        if (U == u8) in = 0 else in >>= 8;
                     },
                 }
-                in_bits -= u8_bit_count;
+                in_count -= 8;
             }
 
-            if (in_bits > 0) {
-                self.bit_count = @as(u4, @intCast(in_bits));
-                self.bit_buffer = switch (endian) {
-                    .big => @as(u8, @truncate(in_buffer >> high_byte_shift)),
-                    .little => @as(u8, @truncate(in_buffer)),
-                };
-            }
-        }
-
-        /// Flush any remaining bits to the stream.
-        pub fn flushBits(self: *Self) Error!void {
-            if (self.bit_count == 0) return;
-            try self.forward_writer.writeByte(self.bit_buffer);
-            self.bit_buffer = 0;
-            self.bit_count = 0;
+            //save the remaining bits in the buffer
+            self.addBits(@truncate(in), @intCast(in_count));
         }
 
-        pub fn write(self: *Self, buffer: []const u8) Error!usize {
-            // TODO: I'm not sure this is a good idea, maybe flushBits should be forced
-            if (self.bit_count > 0) {
-                for (buffer) |b|
-                    try self.writeBits(b, u8_bit_count);
-                return buffer.len;
+        //convenience function for adding bits to the buffer
+        //in the appropriate position based on endianness
+        fn addBits(self: *@This(), bits: u8, num: u4) void {
+            if (num == 8) self.bits = bits else switch (endian) {
+                .big => {
+                    self.bits <<= @intCast(num);
+                    self.bits |= bits & low_bit_mask[num];
+                },
+                .little => {
+                    const pos = bits << @intCast(self.count);
+                    self.bits |= pos;
+                },
             }
-
-            return self.forward_writer.write(buffer);
+            self.count += num;
         }
 
-        pub fn writer(self: *Self) Writer {
-            return .{ .context = self };
+        /// Flush any remaining bits to the writer, filling
+        /// unused bits with 0s.
+        pub fn flushBits(self: *@This()) !void {
+            if (self.count == 0) return;
+            if (endian == .big) self.bits <<= @intCast(8 - self.count);
+            try self.writer.writeByte(self.bits);
+            self.bits = 0;
+            self.count = 0;
         }
     };
 }
 
-pub fn bitWriter(
-    comptime endian: std.builtin.Endian,
-    underlying_stream: anytype,
-) BitWriter(endian, @TypeOf(underlying_stream)) {
-    return BitWriter(endian, @TypeOf(underlying_stream)).init(underlying_stream);
+pub fn bitWriter(comptime endian: std.builtin.Endian, writer: anytype) BitWriter(endian, @TypeOf(writer)) {
+    return .{ .writer = writer };
 }
 
+///////////////////////////////
+
 test "api coverage" {
     var mem_be = [_]u8{0} ** 2;
     var mem_le = [_]u8{0} ** 2;
 
-    var mem_out_be = io.fixedBufferStream(&mem_be);
+    var mem_out_be = std.io.fixedBufferStream(&mem_be);
     var bit_stream_be = bitWriter(.big, mem_out_be.writer());
 
+    const testing = std.testing;
+
     try bit_stream_be.writeBits(@as(u2, 1), 1);
     try bit_stream_be.writeBits(@as(u5, 2), 2);
     try bit_stream_be.writeBits(@as(u128, 3), 3);
@@ -169,7 +154,7 @@ test "api coverage" {
 
     try bit_stream_be.writeBits(@as(u0, 0), 0);
 
-    var mem_out_le = io.fixedBufferStream(&mem_le);
+    var mem_out_le = std.io.fixedBufferStream(&mem_le);
     var bit_stream_le = bitWriter(.little, mem_out_le.writer());
 
     try bit_stream_le.writeBits(@as(u2, 1), 1);
lib/std/io/test.zig
@@ -82,7 +82,7 @@ test "BitStreams with File Stream" {
 
         var bit_stream = io.bitReader(native_endian, file.reader());
 
-        var out_bits: usize = undefined;
+        var out_bits: u16 = undefined;
 
         try expect(1 == try bit_stream.readBits(u2, 1, &out_bits));
         try expect(out_bits == 1);