Commit a333bb91ff

Andrew Kelley <andrew@ziglang.org>
2023-03-10 05:17:56
zig objcopy: support the compiler protocol
This commit extracts out server code into src/Server.zig and uses it both in the main CLI as well as `zig objcopy`. std.Build.ObjCopyStep now adds `--listen=-` to the CLI for `zig objcopy` and observes the protocol for progress and other kinds of integrations. This fixes the last two test failures of this branch when I run `zig build test` locally.
1 parent 59f5df3
lib/std/Build/ObjCopyStep.zig
@@ -113,6 +113,8 @@ fn make(step: *Step, prog_node: *std.Progress.Node) !void {
     };
 
     try argv.appendSlice(&.{ full_src_path, full_dest_path });
+
+    try argv.append("--listen=-");
     _ = try step.evalZigProcess(argv.items, prog_node);
 
     self.output_file.path = full_dest_path;
src/main.zig
@@ -26,6 +26,7 @@ const target_util = @import("target.zig");
 const crash_report = @import("crash_report.zig");
 const Module = @import("Module.zig");
 const AstGen = @import("AstGen.zig");
+const Server = @import("Server.zig");
 
 pub const std_options = struct {
     pub const wasiCwd = wasi_cwd;
@@ -3540,11 +3541,14 @@ fn serve(
 ) !void {
     const gpa = comp.gpa;
 
-    try serveStringMessage(out, .zig_version, build_options.version);
+    var server = try Server.init(.{
+        .gpa = gpa,
+        .in = in,
+        .out = out,
+    });
+    defer server.deinit();
 
     var child_pid: ?std.ChildProcess.Id = null;
-    var receive_fifo = std.fifo.LinearFifo(u8, .Dynamic).init(gpa);
-    defer receive_fifo.deinit();
 
     var progress: std.Progress = .{
         .terminal = null,
@@ -3564,7 +3568,7 @@ fn serve(
     main_progress_node.context = &progress;
 
     while (true) {
-        const hdr = try receiveMessage(in, &receive_fifo);
+        const hdr = try server.receiveMessage();
 
         switch (hdr.tag) {
             .exit => {
@@ -3580,7 +3584,7 @@ fn serve(
                     const arena = arena_instance.allocator();
                     var output: TranslateCOutput = undefined;
                     try cmdTranslateC(comp, arena, &output);
-                    try serveEmitBinPath(out, output.path, .{
+                    try server.serveEmitBinPath(output.path, .{
                         .flags = .{ .cache_hit = output.cache_hit },
                     });
                     continue;
@@ -3594,7 +3598,7 @@ fn serve(
                     var reset: std.Thread.ResetEvent = .{};
 
                     var progress_thread = try std.Thread.spawn(.{}, progressThread, .{
-                        &progress, out, &reset,
+                        &progress, &server, &reset,
                     });
                     defer {
                         reset.set();
@@ -3605,7 +3609,7 @@ fn serve(
                 }
 
                 try comp.makeBinFileExecutable();
-                try serveUpdateResults(out, comp);
+                try serveUpdateResults(&server, comp);
             },
             .run => {
                 if (child_pid != null) {
@@ -3632,14 +3636,14 @@ fn serve(
                 assert(main_progress_node.recently_updated_child == null);
                 if (child_pid) |pid| {
                     try comp.hotCodeSwap(main_progress_node, pid);
-                    try serveUpdateResults(out, comp);
+                    try serveUpdateResults(&server, comp);
                 } else {
                     if (comp.bin_file.options.output_mode == .Exe) {
                         try comp.makeBinFileWritable();
                     }
                     try comp.update(main_progress_node);
                     try comp.makeBinFileExecutable();
-                    try serveUpdateResults(out, comp);
+                    try serveUpdateResults(&server, comp);
 
                     child_pid = try runOrTestHotSwap(
                         comp,
@@ -3659,7 +3663,7 @@ fn serve(
     }
 }
 
-fn progressThread(progress: *std.Progress, out: fs.File, reset: *std.Thread.ResetEvent) void {
+fn progressThread(progress: *std.Progress, server: *const Server, reset: *std.Thread.ResetEvent) void {
     while (true) {
         if (reset.timedWait(500 * std.time.ns_per_ms)) |_| {
             // The Compilation update has completed.
@@ -3705,7 +3709,7 @@ fn progressThread(progress: *std.Progress, out: fs.File, reset: *std.Thread.Rese
 
         const progress_string = buf.slice();
 
-        serveMessage(out, .{
+        server.serveMessage(.{
             .tag = .progress,
             .bytes_len = @intCast(u32, progress_string.len),
         }, &.{
@@ -3716,100 +3720,21 @@ fn progressThread(progress: *std.Progress, out: fs.File, reset: *std.Thread.Rese
     }
 }
 
-fn serveMessage(
-    out: fs.File,
-    header: std.zig.Server.Message.Header,
-    bufs: []const []const u8,
-) !void {
-    var iovecs: [10]std.os.iovec_const = undefined;
-    iovecs[0] = .{
-        .iov_base = @ptrCast([*]const u8, &header),
-        .iov_len = @sizeOf(std.zig.Server.Message.Header),
-    };
-    for (bufs, iovecs[1 .. bufs.len + 1]) |buf, *iovec| {
-        iovec.* = .{
-            .iov_base = buf.ptr,
-            .iov_len = buf.len,
-        };
-    }
-    try out.writevAll(iovecs[0 .. bufs.len + 1]);
-}
-
-fn serveErrorBundle(out: fs.File, error_bundle: std.zig.ErrorBundle) !void {
-    const eb_hdr: std.zig.Server.Message.ErrorBundle = .{
-        .extra_len = @intCast(u32, error_bundle.extra.len),
-        .string_bytes_len = @intCast(u32, error_bundle.string_bytes.len),
-    };
-    const bytes_len = @sizeOf(std.zig.Server.Message.ErrorBundle) +
-        4 * error_bundle.extra.len + error_bundle.string_bytes.len;
-    try serveMessage(out, .{
-        .tag = .error_bundle,
-        .bytes_len = @intCast(u32, bytes_len),
-    }, &.{
-        std.mem.asBytes(&eb_hdr),
-        // TODO: implement @ptrCast between slices changing the length
-        std.mem.sliceAsBytes(error_bundle.extra),
-        error_bundle.string_bytes,
-    });
-}
-
-fn serveUpdateResults(out: fs.File, comp: *Compilation) !void {
+fn serveUpdateResults(s: *Server, comp: *Compilation) !void {
     const gpa = comp.gpa;
     var error_bundle = try comp.getAllErrorsAlloc();
     defer error_bundle.deinit(gpa);
     if (error_bundle.errorMessageCount() > 0) {
-        try serveErrorBundle(out, error_bundle);
+        try s.serveErrorBundle(error_bundle);
     } else if (comp.bin_file.options.emit) |emit| {
         const full_path = try emit.directory.join(gpa, &.{emit.sub_path});
         defer gpa.free(full_path);
-        try serveEmitBinPath(out, full_path, .{
+        try s.serveEmitBinPath(full_path, .{
             .flags = .{ .cache_hit = comp.last_update_was_cache_hit },
         });
     }
 }
 
-fn serveEmitBinPath(
-    out: fs.File,
-    fs_path: []const u8,
-    header: std.zig.Server.Message.EmitBinPath,
-) !void {
-    try serveMessage(out, .{
-        .tag = .emit_bin_path,
-        .bytes_len = @intCast(u32, fs_path.len + @sizeOf(std.zig.Server.Message.EmitBinPath)),
-    }, &.{
-        std.mem.asBytes(&header),
-        fs_path,
-    });
-}
-
-fn serveStringMessage(out: fs.File, tag: std.zig.Server.Message.Tag, s: []const u8) !void {
-    try serveMessage(out, .{
-        .tag = tag,
-        .bytes_len = @intCast(u32, s.len),
-    }, &.{s});
-}
-
-fn receiveMessage(in: fs.File, fifo: *std.fifo.LinearFifo(u8, .Dynamic)) !std.zig.Client.Message.Header {
-    const Header = std.zig.Client.Message.Header;
-
-    while (true) {
-        const buf = fifo.readableSlice(0);
-        assert(fifo.readableLength() == buf.len);
-        if (buf.len >= @sizeOf(Header)) {
-            const header = @ptrCast(*align(1) const Header, buf[0..@sizeOf(Header)]);
-            if (header.bytes_len != 0)
-                return error.InvalidClientMessage;
-            const result = header.*;
-            fifo.discard(@sizeOf(Header));
-            return result;
-        }
-
-        const write_buffer = try fifo.writableWithSize(256);
-        const amt = try in.read(write_buffer);
-        fifo.update(amt);
-    }
-}
-
 const ModuleDepIterator = struct {
     split: mem.SplitIterator(u8),
 
src/objcopy.zig
@@ -4,22 +4,25 @@ const fs = std.fs;
 const elf = std.elf;
 const Allocator = std.mem.Allocator;
 const File = std.fs.File;
+const assert = std.debug.assert;
+
 const main = @import("main.zig");
 const fatal = main.fatal;
 const cleanExit = main.cleanExit;
+const Server = @import("Server.zig");
 
 pub fn cmdObjCopy(
     gpa: Allocator,
     arena: Allocator,
     args: []const []const u8,
 ) !void {
-    _ = gpa;
     var i: usize = 0;
     var opt_out_fmt: ?std.Target.ObjectFormat = null;
     var opt_input: ?[]const u8 = null;
     var opt_output: ?[]const u8 = null;
     var only_section: ?[]const u8 = null;
     var pad_to: ?u64 = null;
+    var listen = false;
     while (i < args.len) : (i += 1) {
         const arg = args[i];
         if (!mem.startsWith(u8, arg, "-")) {
@@ -54,6 +57,8 @@ pub fn cmdObjCopy(
             i += 1;
             if (i >= args.len) fatal("expected another argument after '{s}'", .{arg});
             only_section = args[i];
+        } else if (mem.eql(u8, arg, "--listen=-")) {
+            listen = true;
         } else if (mem.startsWith(u8, arg, "--only-section=")) {
             only_section = arg["--output-target=".len..];
         } else if (mem.eql(u8, arg, "--pad-to")) {
@@ -102,10 +107,44 @@ pub fn cmdObjCopy(
                 .only_section = only_section,
                 .pad_to = pad_to,
             });
-            return cleanExit();
         },
         else => fatal("unsupported output object format: {s}", .{@tagName(out_fmt)}),
     }
+
+    if (listen) {
+        var server = try Server.init(.{
+            .gpa = gpa,
+            .in = std.io.getStdIn(),
+            .out = std.io.getStdOut(),
+        });
+        defer server.deinit();
+
+        var seen_update = false;
+        while (true) {
+            const hdr = try server.receiveMessage();
+            switch (hdr.tag) {
+                .exit => {
+                    return cleanExit();
+                },
+                .update => {
+                    if (seen_update) {
+                        std.debug.print("zig objcopy only supports 1 update for now\n", .{});
+                        std.process.exit(1);
+                    }
+                    seen_update = true;
+
+                    try server.serveEmitBinPath(output, .{
+                        .flags = .{ .cache_hit = false },
+                    });
+                },
+                else => {
+                    std.debug.print("unsupported message: {s}", .{@tagName(hdr.tag)});
+                    std.process.exit(1);
+                },
+            }
+        }
+    }
+    return cleanExit();
 }
 
 const usage =
@@ -417,7 +456,7 @@ const HexWriter = struct {
         }
 
         fn Address(address: u32) Record {
-            std.debug.assert(address > 0xFFFF);
+            assert(address > 0xFFFF);
             const segment = @intCast(u16, address / 0x10000);
             if (address > 0xFFFFF) {
                 return Record{
@@ -460,7 +499,7 @@ const HexWriter = struct {
             const BUFSIZE = 1 + (1 + 2 + 1 + MAX_PAYLOAD_LEN + 1) * 2 + linesep.len;
             var outbuf: [BUFSIZE]u8 = undefined;
             const payload_bytes = self.getPayloadBytes();
-            std.debug.assert(payload_bytes.len <= MAX_PAYLOAD_LEN);
+            assert(payload_bytes.len <= MAX_PAYLOAD_LEN);
 
             const line = try std.fmt.bufPrint(&outbuf, ":{0X:0>2}{1X:0>4}{2X:0>2}{3s}{4X:0>2}" ++ linesep, .{
                 @intCast(u8, payload_bytes.len),
src/Server.zig
@@ -0,0 +1,113 @@
+in: std.fs.File,
+out: std.fs.File,
+receive_fifo: std.fifo.LinearFifo(u8, .Dynamic),
+
+pub const Options = struct {
+    gpa: Allocator,
+    in: std.fs.File,
+    out: std.fs.File,
+};
+
+pub fn init(options: Options) !Server {
+    var s: Server = .{
+        .in = options.in,
+        .out = options.out,
+        .receive_fifo = std.fifo.LinearFifo(u8, .Dynamic).init(options.gpa),
+    };
+    try s.serveStringMessage(.zig_version, build_options.version);
+    return s;
+}
+
+pub fn deinit(s: *Server) void {
+    s.receive_fifo.deinit();
+    s.* = undefined;
+}
+
+pub fn receiveMessage(s: *Server) !InMessage.Header {
+    const Header = InMessage.Header;
+    const fifo = &s.receive_fifo;
+
+    while (true) {
+        const buf = fifo.readableSlice(0);
+        assert(fifo.readableLength() == buf.len);
+        if (buf.len >= @sizeOf(Header)) {
+            const header = @ptrCast(*align(1) const Header, buf[0..@sizeOf(Header)]);
+            if (header.bytes_len != 0)
+                return error.InvalidClientMessage;
+            const result = header.*;
+            fifo.discard(@sizeOf(Header));
+            return result;
+        }
+
+        const write_buffer = try fifo.writableWithSize(256);
+        const amt = try s.in.read(write_buffer);
+        fifo.update(amt);
+    }
+}
+
+pub fn serveStringMessage(s: *Server, tag: OutMessage.Tag, msg: []const u8) !void {
+    return s.serveMessage(.{
+        .tag = tag,
+        .bytes_len = @intCast(u32, msg.len),
+    }, &.{msg});
+}
+
+pub fn serveMessage(
+    s: *const Server,
+    header: OutMessage.Header,
+    bufs: []const []const u8,
+) !void {
+    var iovecs: [10]std.os.iovec_const = undefined;
+    iovecs[0] = .{
+        .iov_base = @ptrCast([*]const u8, &header),
+        .iov_len = @sizeOf(OutMessage.Header),
+    };
+    for (bufs, iovecs[1 .. bufs.len + 1]) |buf, *iovec| {
+        iovec.* = .{
+            .iov_base = buf.ptr,
+            .iov_len = buf.len,
+        };
+    }
+    try s.out.writevAll(iovecs[0 .. bufs.len + 1]);
+}
+
+pub fn serveEmitBinPath(
+    s: *Server,
+    fs_path: []const u8,
+    header: std.zig.Server.Message.EmitBinPath,
+) !void {
+    try s.serveMessage(.{
+        .tag = .emit_bin_path,
+        .bytes_len = @intCast(u32, fs_path.len + @sizeOf(std.zig.Server.Message.EmitBinPath)),
+    }, &.{
+        std.mem.asBytes(&header),
+        fs_path,
+    });
+}
+
+pub fn serveErrorBundle(s: *Server, error_bundle: std.zig.ErrorBundle) !void {
+    const eb_hdr: std.zig.Server.Message.ErrorBundle = .{
+        .extra_len = @intCast(u32, error_bundle.extra.len),
+        .string_bytes_len = @intCast(u32, error_bundle.string_bytes.len),
+    };
+    const bytes_len = @sizeOf(std.zig.Server.Message.ErrorBundle) +
+        4 * error_bundle.extra.len + error_bundle.string_bytes.len;
+    try s.serveMessage(.{
+        .tag = .error_bundle,
+        .bytes_len = @intCast(u32, bytes_len),
+    }, &.{
+        std.mem.asBytes(&eb_hdr),
+        // TODO: implement @ptrCast between slices changing the length
+        std.mem.sliceAsBytes(error_bundle.extra),
+        error_bundle.string_bytes,
+    });
+}
+
+const OutMessage = std.zig.Server.Message;
+const InMessage = std.zig.Client.Message;
+
+const Server = @This();
+const std = @import("std");
+const build_options = @import("build_options");
+const Allocator = std.mem.Allocator;
+const assert = std.debug.assert;
CMakeLists.txt
@@ -623,6 +623,7 @@ set(ZIG_STAGE2_SOURCES
     "${CMAKE_SOURCE_DIR}/src/print_targets.zig"
     "${CMAKE_SOURCE_DIR}/src/print_zir.zig"
     "${CMAKE_SOURCE_DIR}/src/register_manager.zig"
+    "${CMAKE_SOURCE_DIR}/src/Server.zig"
     "${CMAKE_SOURCE_DIR}/src/target.zig"
     "${CMAKE_SOURCE_DIR}/src/tracy.zig"
     "${CMAKE_SOURCE_DIR}/src/translate_c.zig"