Commit fd7feed04b

Andrew Kelley <andrew@ziglang.org>
2025-07-07 01:58:48
std.fs.File.Writer: implement positional writing
1 parent 7e2a26c
Changed files (1)
lib
std
lib/std/fs/File.zig
@@ -1240,6 +1240,8 @@ pub const Reader = struct {
     file: File,
     err: ?ReadError = null,
     mode: Reader.Mode = .positional,
+    /// Tracks the true seek position in the file. To obtain the logical
+    /// position, subtract the buffer size from this value.
     pos: u64 = 0,
     size: ?u64 = null,
     size_err: ?GetEndPosError = null,
@@ -1407,10 +1409,14 @@ pub const Reader = struct {
                 const n = posix.preadv(r.file.handle, dest, r.pos) catch |err| switch (err) {
                     error.Unseekable => {
                         r.mode = r.mode.toStreaming();
-                        if (r.pos != 0) r.seekBy(@intCast(r.pos)) catch {
-                            r.mode = .failure;
-                            return error.ReadFailed;
-                        };
+                        const pos = r.pos;
+                        if (pos != 0) {
+                            r.pos = 0;
+                            r.seekBy(@intCast(pos)) catch {
+                                r.mode = .failure;
+                                return error.ReadFailed;
+                            };
+                        }
                         return 0;
                     },
                     else => |e| {
@@ -1535,10 +1541,14 @@ pub const Reader = struct {
         const n = r.file.pread(dest, r.pos) catch |err| switch (err) {
             error.Unseekable => {
                 r.mode = r.mode.toStreaming();
-                if (r.pos != 0) r.seekBy(@intCast(r.pos)) catch {
-                    r.mode = .failure;
-                    return error.ReadFailed;
-                };
+                const pos = r.pos;
+                if (pos != 0) {
+                    r.pos = 0;
+                    r.seekBy(@intCast(pos)) catch {
+                        r.mode = .failure;
+                        return error.ReadFailed;
+                    };
+                }
                 return 0;
             },
             else => |e| {
@@ -1586,6 +1596,8 @@ pub const Writer = struct {
     file: File,
     err: ?WriteError = null,
     mode: Writer.Mode = .positional,
+    /// Tracks the true seek position in the file. To obtain the logical
+    /// position, add the buffer size to this value.
     pos: u64 = 0,
     sendfile_err: ?SendfileError = null,
     copy_file_range_err: ?CopyFileRangeError = null,
@@ -1652,32 +1664,36 @@ pub const Writer = struct {
         const w: *Writer = @fieldParentPtr("interface", io_w);
         const handle = w.file.handle;
         const buffered = io_w.buffered();
-        if (is_windows) {
-            var i: usize = 0;
-            while (i < buffered.len) {
-                const n = windows.WriteFile(handle, buffered[i..], null) catch |err| {
-                    w.err = err;
-                    w.pos += i;
-                    _ = io_w.consume(i);
-                    return error.WriteFailed;
-                };
-                i += n;
-                if (data.len > 0 and buffered.len - i < n) {
+        if (is_windows) switch (w.mode) {
+            .positional, .positional_reading => @panic("TODO"),
+            .streaming, .streaming_reading => {
+                var i: usize = 0;
+                while (i < buffered.len) {
+                    const n = windows.WriteFile(handle, buffered[i..], null) catch |err| {
+                        w.err = err;
+                        w.pos += i;
+                        _ = io_w.consume(i);
+                        return error.WriteFailed;
+                    };
+                    i += n;
+                    if (data.len > 0 and buffered.len - i < n) {
+                        w.pos += i;
+                        return io_w.consume(i);
+                    }
+                }
+                if (i != 0 or data.len == 0 or (data.len == 1 and splat == 0)) {
                     w.pos += i;
                     return io_w.consume(i);
                 }
-            }
-            if (i != 0 or data.len == 0 or (data.len == 1 and splat == 0)) {
-                w.pos += i;
-                return io_w.consume(i);
-            }
-            const n = windows.WriteFile(handle, data[0], null) catch |err| {
-                w.err = err;
-                return 0;
-            };
-            w.pos += n;
-            return n;
-        }
+                const n = windows.WriteFile(handle, data[0], null) catch |err| {
+                    w.err = err;
+                    return 0;
+                };
+                w.pos += n;
+                return n;
+            },
+            .failure => return error.WriteFailed,
+        };
         var iovecs: [max_buffers_len]std.posix.iovec_const = undefined;
         var len: usize = 0;
         if (buffered.len > 0) {
@@ -1733,12 +1749,39 @@ pub const Writer = struct {
                 },
             },
         }
-        const n = std.posix.writev(handle, iovecs[0..len]) catch |err| {
-            w.err = err;
-            return error.WriteFailed;
-        };
-        w.pos += n;
-        return io_w.consume(n);
+        switch (w.mode) {
+            .positional, .positional_reading => {
+                const n = std.posix.pwritev(handle, iovecs[0..len], w.pos) catch |err| switch (err) {
+                    error.Unseekable => {
+                        w.mode = w.mode.toStreaming();
+                        const pos = w.pos;
+                        if (pos != 0) {
+                            w.pos = 0;
+                            w.seekTo(@intCast(pos)) catch {
+                                w.mode = .failure;
+                                return error.WriteFailed;
+                            };
+                        }
+                        return 0;
+                    },
+                    else => |e| {
+                        w.err = e;
+                        return error.WriteFailed;
+                    },
+                };
+                w.pos += n;
+                return io_w.consume(n);
+            },
+            .streaming, .streaming_reading => {
+                const n = std.posix.writev(handle, iovecs[0..len]) catch |err| {
+                    w.err = err;
+                    return error.WriteFailed;
+                };
+                w.pos += n;
+                return io_w.consume(n);
+            },
+            .failure => return error.WriteFailed,
+        }
     }
 
     pub fn sendFile(
@@ -1781,10 +1824,14 @@ pub const Writer = struct {
             const n = std.os.linux.wrapped.sendfile(out_fd, in_fd, off_ptr, count) catch |err| switch (err) {
                 error.Unseekable => {
                     file_reader.mode = file_reader.mode.toStreaming();
-                    if (file_reader.pos != 0) file_reader.seekBy(@intCast(file_reader.pos)) catch {
-                        file_reader.mode = .failure;
-                        return error.ReadFailed;
-                    };
+                    const pos = file_reader.pos;
+                    if (pos != 0) {
+                        file_reader.pos = 0;
+                        file_reader.seekBy(@intCast(pos)) catch {
+                            file_reader.mode = .failure;
+                            return error.ReadFailed;
+                        };
+                    }
                     return 0;
                 },
                 else => |e| {
@@ -1877,17 +1924,19 @@ pub const Writer = struct {
     }
 
     pub fn seekTo(w: *Writer, offset: u64) SeekError!void {
-        if (w.seek_err) |err| return err;
         switch (w.mode) {
             .positional, .positional_reading => {
                 w.pos = offset;
             },
             .streaming, .streaming_reading => {
+                if (w.seek_err) |err| return err;
                 posix.lseek_SET(w.file.handle, offset) catch |err| {
                     w.seek_err = err;
                     return err;
                 };
+                w.pos = offset;
             },
+            .failure => return w.seek_err.?,
         }
     }