Commit 435ccf706d

Andrew Kelley <andrew@ziglang.org>
2025-06-30 21:17:47
std.fs.File.Writer: fix drain implementation
it didn't account for data.len can no longer be zero
1 parent 77e839e
Changed files (1)
lib
std
lib/std/fs/File.zig
@@ -1648,29 +1648,28 @@ pub const Writer = struct {
         };
     }
 
-    pub fn drain(io_writer: *std.io.Writer, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
-        const w: *Writer = @fieldParentPtr("interface", io_writer);
+    pub fn drain(io_w: *std.io.Writer, data: []const []const u8, splat: usize) std.io.Writer.Error!usize {
+        const w: *Writer = @fieldParentPtr("interface", io_w);
         const handle = w.file.handle;
-        const buffered = io_writer.buffered();
-        var splat_buffer: [256]u8 = undefined;
+        const buffered = io_w.buffered();
         if (is_windows) {
             var i: usize = 0;
             while (i < buffered.len) {
                 const n = windows.WriteFile(handle, buffered[i..], null) catch |err| {
                     w.err = err;
                     w.pos += i;
-                    _ = io_writer.consume(i);
+                    _ = io_w.consume(i);
                     return error.WriteFailed;
                 };
                 i += n;
                 if (data.len > 0 and buffered.len - i < n) {
                     w.pos += i;
-                    return io_writer.consume(i);
+                    return io_w.consume(i);
                 }
             }
             if (i != 0 or data.len == 0 or (data.len == 1 and splat == 0)) {
                 w.pos += i;
-                return io_writer.consume(i);
+                return io_w.consume(i);
             }
             const n = windows.WriteFile(handle, data[0], null) catch |err| {
                 w.err = err;
@@ -1679,19 +1678,6 @@ pub const Writer = struct {
             w.pos += n;
             return n;
         }
-        if (data.len == 0) {
-            var i: usize = 0;
-            while (i < buffered.len) {
-                i += std.posix.write(handle, buffered) catch |err| {
-                    w.err = err;
-                    w.pos += i;
-                    _ = io_writer.consume(i);
-                    return error.WriteFailed;
-                };
-            }
-            w.pos += i;
-            return io_writer.consumeAll();
-        }
         var iovecs: [max_buffers_len]std.posix.iovec_const = undefined;
         var len: usize = 0;
         if (buffered.len > 0) {
@@ -1700,30 +1686,43 @@ pub const Writer = struct {
         }
         for (data) |d| {
             if (d.len == 0) continue;
-            if (iovecs.len - len == 0) break;
             iovecs[len] = .{ .base = d.ptr, .len = d.len };
             len += 1;
+            if (iovecs.len - len == 0) break;
         }
+        if (len == 0) return 0;
+        const pattern = data[data.len - 1];
         switch (splat) {
-            0 => if (data[data.len - 1].len != 0) {
+            0 => if (iovecs[len - 1].base == pattern.ptr) {
                 len -= 1;
             },
             1 => {},
-            else => switch (data[data.len - 1].len) {
+            else => switch (pattern.len) {
                 0 => {},
-                1 => {
+                1 => memset: {
+                    // Replace the 1-byte buffer with a bigger one.
+                    if (iovecs[len - 1].base == pattern.ptr) len -= 1;
+                    if (iovecs.len - len == 0) break :memset;
+                    const splat_buffer_candidate = io_w.buffer[io_w.end..];
+                    var backup_buffer: [64]u8 = undefined;
+                    const splat_buffer = if (splat_buffer_candidate.len >= backup_buffer.len)
+                        splat_buffer_candidate
+                    else
+                        &backup_buffer;
                     const memset_len = @min(splat_buffer.len, splat);
                     const buf = splat_buffer[0..memset_len];
-                    @memset(buf, data[data.len - 1][0]);
-                    iovecs[len - 1] = .{ .base = buf.ptr, .len = buf.len };
+                    @memset(buf, pattern[0]);
+                    iovecs[len] = .{ .base = buf.ptr, .len = buf.len };
+                    len += 1;
                     var remaining_splat = splat - buf.len;
-                    while (remaining_splat > splat_buffer.len and len < iovecs.len) {
-                        iovecs[len] = .{ .base = &splat_buffer, .len = splat_buffer.len };
-                        remaining_splat -= splat_buffer.len;
+                    while (remaining_splat > splat_buffer.len and iovecs.len - len != 0) {
+                        assert(buf.len == splat_buffer.len);
+                        iovecs[len] = .{ .base = splat_buffer.ptr, .len = splat_buffer.len };
                         len += 1;
+                        remaining_splat -= splat_buffer.len;
                     }
-                    if (remaining_splat > 0 and len < iovecs.len) {
-                        iovecs[len] = .{ .base = &splat_buffer, .len = remaining_splat };
+                    if (remaining_splat > 0 and iovecs.len - len != 0) {
+                        iovecs[len] = .{ .base = splat_buffer.ptr, .len = remaining_splat };
                         len += 1;
                     }
                     return std.posix.writev(handle, iovecs[0..len]) catch |err| {
@@ -1733,7 +1732,7 @@ pub const Writer = struct {
                 },
                 else => for (0..splat - 1) |_| {
                     if (iovecs.len - len == 0) break;
-                    iovecs[len] = .{ .base = data[data.len - 1].ptr, .len = data[data.len - 1].len };
+                    iovecs[len] = .{ .base = pattern.ptr, .len = pattern.len };
                     len += 1;
                 },
             },
@@ -1743,15 +1742,15 @@ pub const Writer = struct {
             return error.WriteFailed;
         };
         w.pos += n;
-        return io_writer.consume(n);
+        return io_w.consume(n);
     }
 
     pub fn sendFile(
-        io_writer: *std.io.Writer,
+        io_w: *std.io.Writer,
         file_reader: *Reader,
         limit: std.io.Limit,
     ) std.io.Writer.FileError!usize {
-        const w: *Writer = @fieldParentPtr("interface", io_writer);
+        const w: *Writer = @fieldParentPtr("interface", io_w);
         const out_fd = w.file.handle;
         const in_fd = file_reader.file.handle;
         // TODO try using copy_file_range on FreeBSD
@@ -1762,7 +1761,7 @@ pub const Writer = struct {
             if (w.sendfile_err != null) break :sf;
             // Linux sendfile does not support headers.
             const buffered = limit.slice(file_reader.interface.buffer);
-            if (io_writer.end != 0 or buffered.len != 0) return drain(io_writer, &.{buffered}, 1);
+            if (io_w.end != 0 or buffered.len != 0) return drain(io_w, &.{buffered}, 1);
             const max_count = 0x7ffff000; // Avoid EINVAL.
             var off: std.os.linux.off_t = undefined;
             const off_ptr: ?*std.os.linux.off_t, const count: usize = switch (file_reader.mode) {
@@ -1813,7 +1812,7 @@ pub const Writer = struct {
         if (copy_file_range_fn) |copy_file_range| cfr: {
             if (w.copy_file_range_err != null) break :cfr;
             const buffered = limit.slice(file_reader.interface.buffer);
-            if (io_writer.end != 0 or buffered.len != 0) return drain(io_writer, &.{buffered}, 1);
+            if (io_w.end != 0 or buffered.len != 0) return drain(io_w, &.{buffered}, 1);
             var off_in: i64 = undefined;
             var off_out: i64 = undefined;
             const off_in_ptr: ?*i64 = switch (file_reader.mode) {