Commit 598db831f3

Sreehari Sreedev <sreeharisreedev1@gmail.com>
2021-07-18 11:13:04
FileProtocol: add Reader, Writer, SeekableStream
1 parent c23768a
Changed files (1)
lib
std
os
uefi
lib/std/os/uefi/protocols/file_protocol.zig
@@ -1,4 +1,6 @@
-const uefi = @import("std").os.uefi;
+const std = @import("std");
+const uefi = std.os.uefi;
+const io = std.io;
 const Guid = uefi.Guid;
 const Time = uefi.Time;
 const Status = uefi.Status;
@@ -16,6 +18,27 @@ pub const FileProtocol = extern struct {
     _set_info: fn (*const FileProtocol, *align(8) const Guid, usize, [*]const u8) callconv(.C) Status,
     _flush: fn (*const FileProtocol) callconv(.C) Status,
 
+    pub const SeekError = error{SeekError};
+    pub const GetSeekPosError = error{GetSeekPosError};
+    pub const ReadError = error{ReadError};
+    pub const WriteError = error{WriteError};
+
+    pub const SeekableStream = io.SeekableStream(*const FileProtocol, SeekError, GetSeekPosError, seekTo, seekBy, getPos, getEndPos);
+    pub const Reader = io.Reader(*const FileProtocol, ReadError, readFn);
+    pub const Writer = io.Writer(*const FileProtocol, WriteError, writeFn);
+
+    pub fn seekableStream(self: *FileProtocol) SeekableStream {
+        return .{ .context = self };
+    }
+
+    pub fn reader(self: *FileProtocol) Reader {
+        return .{ .context = self };
+    }
+
+    pub fn writer(self: *FileProtocol) Writer {
+        return .{ .context = self };
+    }
+
     pub fn open(self: *const FileProtocol, new_handle: **const FileProtocol, file_name: [*:0]const u16, open_mode: u64, attributes: u64) Status {
         return self._open(self, new_handle, file_name, open_mode, attributes);
     }
@@ -32,18 +55,66 @@ pub const FileProtocol = extern struct {
         return self._read(self, buffer_size, buffer);
     }
 
+    fn readFn(self: *const FileProtocol, buffer: []u8) ReadError!usize {
+        var size: usize = buffer.len;
+        if (.Success != self.read(&size, buffer.ptr)) return ReadError.ReadError;
+        return size;
+    }
+
     pub fn write(self: *const FileProtocol, buffer_size: *usize, buffer: [*]const u8) Status {
         return self._write(self, buffer_size, buffer);
     }
 
+    fn writeFn(self: *const FileProtocol, bytes: []const u8) WriteError!usize {
+        var size: usize = bytes.len;
+        if (.Success != self.write(&size, bytes.ptr)) return WriteError.WriteError;
+        return size;
+    }
+
     pub fn getPosition(self: *const FileProtocol, position: *u64) Status {
         return self._get_position(self, position);
     }
 
+    fn getPos(self: *const FileProtocol) GetSeekPosError!u64 {
+        var pos: u64 = undefined;
+        if (.Success != self.getPosition(&pos)) return GetSeekPosError.GetSeekPosError;
+        return pos;
+    }
+
+    fn getEndPos(self: *const FileProtocol) GetSeekPosError!u64 {
+        // preserve the old file position
+        var pos: u64 = undefined;
+        if (.Success != self.getPosition(&pos)) return GetSeekPosError.GetSeekPosError;
+        // seek to end of file to get position = file size
+        if (.Success != self.setPosition(efi_file_position_end_of_file)) return GetSeekPosError.GetSeekPosError;
+        // restore the old position
+        if (.Success != self.setPosition(pos)) return GetSeekPosError.GetSeekPosError;
+        // return the file size = position
+        return pos;
+    }
+
     pub fn setPosition(self: *const FileProtocol, position: u64) Status {
         return self._set_position(self, position);
     }
 
+    fn seekTo(self: *const FileProtocol, pos: u64) SeekError!void {
+        if (.Success != self.setPosition(pos)) return SeekError.SeekError;
+    }
+
+    fn seekBy(self: *const FileProtocol, offset: i64) SeekError!void {
+        // save the old position and calculate the delta
+        var pos: u64 = undefined;
+        if (.Success != self.getPosition(&pos)) return SeekError.SeekError;
+        const seek_back = offset < 0;
+        const amt = std.math.absCast(offset);
+        if (seek_back) {
+            pos += amt;
+        } else {
+            pos -= amt;
+        }
+        if (.Success != self.setPosition(pos)) return SeekError.SeekError;
+    }
+
     pub fn getInfo(self: *const FileProtocol, information_type: *align(8) const Guid, buffer_size: *usize, buffer: [*]u8) Status {
         return self._get_info(self, information_type, buffer_size, buffer);
     }