Commit 0e1afb4d98

gwenzek <gwenzek@users.noreply.github.com>
2022-02-05 15:33:00
stage2: add support for Nvptx target
Sample command: `/home/guw/github/zig/stage2/bin/zig build-obj cuda_kernel.zig -target nvptx64-cuda -O ReleaseSafe`. This will create a kernel.ptx. Expose the PtxKernel calling convention from LLVM; kernels are declared as `export fn f() callconv(.PtxKernel)`.
1 parent fbc06f9
lib/std/builtin.zig
@@ -147,6 +147,7 @@ pub const CallingConvention = enum {
     AAPCS,
     AAPCSVFP,
     SysV,
+    PtxKernel,
 };
 
 /// This data structure is used by the Zig language code generation and
lib/std/target.zig
@@ -579,6 +579,8 @@ pub const Target = struct {
         raw,
         /// Plan 9 from Bell Labs
         plan9,
+        /// Nvidia PTX format
+        nvptx,
 
         pub fn fileExt(of: ObjectFormat, cpu_arch: Cpu.Arch) [:0]const u8 {
             return switch (of) {
@@ -589,6 +591,7 @@ pub const Target = struct {
                 .hex => ".ihex",
                 .raw => ".bin",
                 .plan9 => plan9Ext(cpu_arch),
+                .nvptx => ".ptx",
             };
         }
     };
@@ -1388,6 +1391,7 @@ pub const Target = struct {
             else => return switch (cpu_arch) {
                 .wasm32, .wasm64 => .wasm,
                 .spirv32, .spirv64 => .spirv,
+                .nvptx, .nvptx64 => .nvptx,
                 else => .elf,
             },
         };
lib/std/zig.zig
@@ -181,6 +181,7 @@ pub fn binNameAlloc(allocator: std.mem.Allocator, options: BinNameOptions) error
             .Obj => return std.fmt.allocPrint(allocator, "{s}{s}", .{ root_name, ofmt.fileExt(target.cpu.arch) }),
             .Lib => return std.fmt.allocPrint(allocator, "{s}{s}.a", .{ target.libPrefix(), root_name }),
         },
+        .nvptx => return std.fmt.allocPrint(allocator, "{s}", .{root_name}),
     }
 }
 
src/codegen/llvm.zig
@@ -378,7 +378,7 @@ pub const Object = struct {
         const mod = comp.bin_file.options.module.?;
         const cache_dir = mod.zig_cache_artifact_directory;
 
-        const emit_bin_path: ?[*:0]const u8 = if (comp.bin_file.options.emit) |emit|
+        var emit_bin_path: ?[*:0]const u8 = if (comp.bin_file.options.emit) |emit|
             try emit.basenamePath(arena, try arena.dupeZ(u8, comp.bin_file.intermediary_basename.?))
         else
             null;
@@ -5078,6 +5078,10 @@ fn toLlvmCallConv(cc: std.builtin.CallingConvention, target: std.Target) llvm.Ca
         },
         .Signal => .AVR_SIGNAL,
         .SysV => .X86_64_SysV,
+        .PtxKernel => return switch (target.cpu.arch) {
+            .nvptx, .nvptx64 => .PTX_Kernel,
+            else => unreachable,
+        },
     };
 }
 
src/link/NvPtx.zig
@@ -0,0 +1,152 @@
+//! NVidia PTX (Parallel Thread Execution)
+//! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+//! For this we rely on the nvptx backend of LLVM
+//! Kernel functions need to be marked both as "export" and "callconv(.PtxKernel)"
+
+const NvPtx = @This();
+
+const std = @import("std");
+const builtin = @import("builtin");
+
+const Allocator = std.mem.Allocator;
+const assert = std.debug.assert;
+const log = std.log.scoped(.link);
+
+const Module = @import("../Module.zig");
+const Compilation = @import("../Compilation.zig");
+const link = @import("../link.zig");
+const trace = @import("../tracy.zig").trace;
+const build_options = @import("build_options");
+const Air = @import("../Air.zig");
+const Liveness = @import("../Liveness.zig");
+const LlvmObject = @import("../codegen/llvm.zig").Object;
+
+/// Common linker state; `link.File` recovers the `NvPtx` from it via `@fieldParentPtr`.
+base: link.File,
+/// LLVM object that every update and flush is forwarded to.
+/// NOTE(review): `createEmpty` leaves this undefined and only `openPath` assigns it;
+/// confirm that no caller uses the result of `createEmpty` without setting it.
+llvm_object: *LlvmObject,
+
+/// Allocates an `NvPtx` whose `llvm_object` is still undefined.
+/// Returns `error.TODOArchNotSupported` when the compiler was built without LLVM or
+/// the CPU architecture is not nvptx/nvptx64, and `error.TODOOsNotSupported` when the
+/// OS is not cuda.
+pub fn createEmpty(gpa: Allocator, options: link.Options) !*NvPtx {
+    if (!build_options.have_llvm) return error.TODOArchNotSupported;
+
+    // Validate the target before allocating, so that rejecting it cannot leak the
+    // newly created `NvPtx`.
+    switch (options.target.cpu.arch) {
+        .nvptx, .nvptx64 => {},
+        else => return error.TODOArchNotSupported,
+    }
+
+    switch (options.target.os.tag) {
+        // TODO: does it also work with nvcl ?
+        .cuda => {},
+        else => return error.TODOOsNotSupported,
+    }
+
+    const nvptx = try gpa.create(NvPtx);
+    nvptx.* = .{
+        .base = .{
+            .tag = .nvptx,
+            .options = options,
+            .file = null,
+            .allocator = gpa,
+        },
+        .llvm_object = undefined,
+    };
+    return nvptx;
+}
+
+/// Creates the linker state for a PTX output, including its LLVM object.
+/// `sub_path` is only used for logging; this function opens no file.
+/// Panics when the compiler was built without LLVM; returns
+/// `error.TODOArchNotSupported` when `options.use_llvm` is false.
+pub fn openPath(allocator: Allocator, sub_path: []const u8, options: link.Options) !*NvPtx {
+    if (!build_options.have_llvm) @panic("nvptx target requires a zig compiler with llvm enabled.");
+    if (!options.use_llvm) return error.TODOArchNotSupported;
+    assert(options.object_format == .nvptx);
+
+    // Create the LLVM object before the `NvPtx`: cleaning up a half-initialized
+    // `NvPtx` with `base.destroy()` would make `deinit` use an undefined `llvm_object`.
+    const llvm_object = try LlvmObject.create(allocator, options);
+    errdefer llvm_object.destroy(allocator);
+
+    const nvptx = try createEmpty(allocator, options);
+    log.info("Opening .ptx target file {s}", .{sub_path});
+    nvptx.llvm_object = llvm_object;
+    return nvptx;
+}
+
+/// Destroys the LLVM object; does nothing when the compiler was built without LLVM.
+pub fn deinit(self: *NvPtx) void {
+    if (!build_options.have_llvm) return;
+    self.llvm_object.destroy(self.base.allocator);
+}
+
+/// Forwards a function update to the LLVM object.
+pub fn updateFunc(self: *NvPtx, module: *Module, func: *Module.Fn, air: Air, liveness: Liveness) !void {
+    if (!build_options.have_llvm) return;
+    try self.llvm_object.updateFunc(module, func, air, liveness);
+}
+
+/// Forwards a declaration update to the LLVM object.
+pub fn updateDecl(self: *NvPtx, module: *Module, decl: *Module.Decl) !void {
+    if (!build_options.have_llvm) return;
+    return self.llvm_object.updateDecl(module, decl);
+}
+
+/// Forwards the export list of `decl` to the LLVM object.
+/// Panics when non-native targets were disabled at build time and the object format
+/// the compiler itself was built for is not nvptx.
+pub fn updateDeclExports(
+    self: *NvPtx,
+    module: *Module,
+    decl: *const Module.Decl,
+    exports: []const *Module.Export,
+) !void {
+    if (!build_options.have_llvm) return;
+    if (build_options.skip_non_native and builtin.object_format != .nvptx) {
+        @panic("Attempted to compile for object format that was disabled by build configuration");
+    }
+    return self.llvm_object.updateDeclExports(module, decl, exports);
+}
+
+/// Forwards the release of `decl` to the LLVM object.
+pub fn freeDecl(self: *NvPtx, decl: *Module.Decl) void {
+    if (!build_options.have_llvm) return;
+    return self.llvm_object.freeDecl(decl);
+}
+
+/// Same as `flushModule`.
+pub fn flush(self: *NvPtx, comp: *Compilation) !void {
+    return self.flushModule(comp);
+}
+
+/// Emits the module through the LLVM object.
+/// When a binary emit location is configured it is moved to `comp.emit_asm` and
+/// `comp.bin_file.options.emit` is cleared; both changes stay on `comp` after return.
+/// Panics when non-native targets were disabled at build time.
+pub fn flushModule(self: *NvPtx, comp: *Compilation) !void {
+    if (!build_options.have_llvm) return;
+    if (build_options.skip_non_native) {
+        @panic("Attempted to compile for architecture that was disabled by build configuration");
+    }
+    const tracy = trace(@src());
+    defer tracy.end();
+
+    // `hack_comp` is the same pointer as `comp`, not a copy of the compilation.
+    const hack_comp = comp;
+    if (comp.bin_file.options.emit) |emit| {
+        hack_comp.emit_asm = .{
+            .directory = emit.directory,
+            .basename = comp.bin_file.intermediary_basename.?,
+        };
+        hack_comp.bin_file.options.emit = null;
+    }
+
+    return try self.llvm_object.flushModule(hack_comp);
+}
src/stage1/all_types.hpp
@@ -83,7 +83,8 @@ enum CallingConvention {
     CallingConventionAPCS,
     CallingConventionAAPCS,
     CallingConventionAAPCSVFP,
-    CallingConventionSysV
+    CallingConventionSysV,
+    CallingConventionPtxKernel
 };
 
 // Stage 1 supports only the generic address space
src/stage1/analyze.cpp
@@ -991,6 +991,7 @@ const char *calling_convention_name(CallingConvention cc) {
         case CallingConventionAAPCSVFP: return "AAPCSVFP";
         case CallingConventionInline: return "Inline";
         case CallingConventionSysV: return "SysV";
+        case CallingConventionPtxKernel: return "PtxKernel";
     }
     zig_unreachable();
 }
@@ -1000,6 +1001,7 @@ bool calling_convention_allows_zig_types(CallingConvention cc) {
         case CallingConventionUnspecified:
         case CallingConventionAsync:
         case CallingConventionInline:
+        case CallingConventionPtxKernel:
             return true;
         case CallingConventionC:
         case CallingConventionNaked:
@@ -2006,6 +2008,15 @@ Error emit_error_unless_callconv_allowed_for_target(CodeGen *g, AstNode *source_
         case CallingConventionSysV:
             if (g->zig_target->arch != ZigLLVM_x86_64)
                 allowed_platforms = "x86_64";
+            break;
+      case CallingConventionPtxKernel:
+            if (g->zig_target->arch != ZigLLVM_nvptx
+                && g->zig_target->arch != ZigLLVM_nvptx64)
+            {
+                allowed_platforms = "nvptx and nvptx64";
+            }
+            break;
+
     }
     if (allowed_platforms != nullptr) {
         add_node_error(g, source_node, buf_sprintf(
@@ -3827,6 +3838,7 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
                 case CallingConventionAAPCS:
                 case CallingConventionAAPCSVFP:
                 case CallingConventionSysV:
+                case CallingConventionPtxKernel:
                     add_fn_export(g, fn_table_entry, buf_ptr(&fn_table_entry->symbol_name),
                                   GlobalLinkageIdStrong, fn_cc);
                     break;
src/stage1/codegen.cpp
@@ -209,6 +209,11 @@ static ZigLLVM_CallingConv get_llvm_cc(CodeGen *g, CallingConvention cc) {
         case CallingConventionSysV:
             assert(g->zig_target->arch == ZigLLVM_x86_64);
             return ZigLLVM_X86_64_SysV;
+        case CallingConventionPtxKernel:
+            assert(g->zig_target->arch == ZigLLVM_nvptx ||
+                g->zig_target->arch == ZigLLVM_nvptx64);
+                return ZigLLVM_PTX_Kernel;
+
     }
     zig_unreachable();
 }
@@ -354,6 +359,7 @@ static bool cc_want_sret_attr(CallingConvention cc) {
         case CallingConventionAAPCS:
         case CallingConventionAAPCSVFP:
         case CallingConventionSysV:
+        case CallingConventionPtxKernel:
             return true;
         case CallingConventionAsync:
         case CallingConventionUnspecified:
src/stage1/ir.cpp
@@ -11666,6 +11666,7 @@ static Stage1AirInst *ir_analyze_instruction_export(IrAnalyze *ira, Stage1ZirIns
                 case CallingConventionAAPCS:
                 case CallingConventionAAPCSVFP:
                 case CallingConventionSysV:
+                case CallingConventionPtxKernel:
                     add_fn_export(ira->codegen, fn_entry, buf_ptr(symbol_name), global_linkage_id, cc);
                     fn_entry->section_name = section_name;
                     break;
src/link.zig
@@ -215,6 +215,7 @@ pub const File = struct {
         c: void,
         wasm: Wasm.DeclBlock,
         spirv: void,
+        nvptx: void,
     };
 
     pub const LinkFn = union {
@@ -225,6 +226,7 @@ pub const File = struct {
         c: void,
         wasm: Wasm.FnData,
         spirv: SpirV.FnData,
+        nvptx: void,
     };
 
     pub const Export = union {
@@ -235,6 +237,7 @@ pub const File = struct {
         c: void,
         wasm: void,
         spirv: void,
+        nvptx: void,
     };
 
     /// For DWARF .debug_info.
@@ -274,6 +277,7 @@ pub const File = struct {
                 .plan9 => return &(try Plan9.createEmpty(allocator, options)).base,
                 .c => unreachable, // Reported error earlier.
                 .spirv => &(try SpirV.createEmpty(allocator, options)).base,
+                .nvptx => &(try NvPtx.createEmpty(allocator, options)).base,
                 .hex => return error.HexObjectFormatUnimplemented,
                 .raw => return error.RawObjectFormatUnimplemented,
             };
@@ -292,6 +296,7 @@ pub const File = struct {
                     .wasm => &(try Wasm.createEmpty(allocator, options)).base,
                     .c => unreachable, // Reported error earlier.
                     .spirv => &(try SpirV.createEmpty(allocator, options)).base,
+                    .nvptx => &(try NvPtx.createEmpty(allocator, options)).base,
                     .hex => return error.HexObjectFormatUnimplemented,
                     .raw => return error.RawObjectFormatUnimplemented,
                 };
@@ -312,6 +317,7 @@ pub const File = struct {
             .wasm => &(try Wasm.openPath(allocator, sub_path, options)).base,
             .c => &(try C.openPath(allocator, sub_path, options)).base,
             .spirv => &(try SpirV.openPath(allocator, sub_path, options)).base,
+            .nvptx => &(try NvPtx.openPath(allocator, sub_path, options)).base,
             .hex => return error.HexObjectFormatUnimplemented,
             .raw => return error.RawObjectFormatUnimplemented,
         };
@@ -344,7 +350,7 @@ pub const File = struct {
                     .mode = determineMode(base.options),
                 });
             },
-            .c, .wasm, .spirv => {},
+            .c, .wasm, .spirv, .nvptx => {},
         }
     }
 
@@ -389,7 +395,7 @@ pub const File = struct {
                 f.close();
                 base.file = null;
             },
-            .c, .wasm, .spirv => {},
+            .c, .wasm, .spirv, .nvptx => {},
         }
     }
 
@@ -437,6 +443,7 @@ pub const File = struct {
             .wasm  => return @fieldParentPtr(Wasm,  "base", base).updateDecl(module, decl),
             .spirv => return @fieldParentPtr(SpirV, "base", base).updateDecl(module, decl),
             .plan9 => return @fieldParentPtr(Plan9, "base", base).updateDecl(module, decl),
+            .nvptx => return @fieldParentPtr(NvPtx, "base", base).updateDecl(module, decl),
             // zig fmt: on
         }
     }
@@ -456,6 +463,7 @@ pub const File = struct {
             .wasm  => return @fieldParentPtr(Wasm,  "base", base).updateFunc(module, func, air, liveness),
             .spirv => return @fieldParentPtr(SpirV, "base", base).updateFunc(module, func, air, liveness),
             .plan9 => return @fieldParentPtr(Plan9, "base", base).updateFunc(module, func, air, liveness),
+            .nvptx => return @fieldParentPtr(NvPtx, "base", base).updateFunc(module, func, air, liveness),
             // zig fmt: on
         }
     }
@@ -471,7 +479,7 @@ pub const File = struct {
             .macho => return @fieldParentPtr(MachO, "base", base).updateDeclLineNumber(module, decl),
             .c => return @fieldParentPtr(C, "base", base).updateDeclLineNumber(module, decl),
             .plan9 => @panic("TODO: implement updateDeclLineNumber for plan9"),
-            .wasm, .spirv => {},
+            .wasm, .spirv, .nvptx => {},
         }
     }
 
@@ -493,7 +501,7 @@ pub const File = struct {
             },
             .wasm => return @fieldParentPtr(Wasm, "base", base).allocateDeclIndexes(decl),
             .plan9 => return @fieldParentPtr(Plan9, "base", base).allocateDeclIndexes(decl),
-            .c, .spirv => {},
+            .c, .spirv, .nvptx => {},
         }
     }
 
@@ -551,6 +559,11 @@ pub const File = struct {
                 parent.deinit();
                 base.allocator.destroy(parent);
             },
+            .nvptx => {
+                const parent = @fieldParentPtr(NvPtx, "base", base);
+                parent.deinit();
+                base.allocator.destroy(parent);
+            },
         }
     }
 
@@ -584,6 +597,7 @@ pub const File = struct {
             .wasm => return @fieldParentPtr(Wasm, "base", base).flush(comp),
             .spirv => return @fieldParentPtr(SpirV, "base", base).flush(comp),
             .plan9 => return @fieldParentPtr(Plan9, "base", base).flush(comp),
+            .nvptx => return @fieldParentPtr(NvPtx, "base", base).flush(comp),
         }
     }
 
@@ -598,6 +612,7 @@ pub const File = struct {
             .wasm => return @fieldParentPtr(Wasm, "base", base).flushModule(comp),
             .spirv => return @fieldParentPtr(SpirV, "base", base).flushModule(comp),
             .plan9 => return @fieldParentPtr(Plan9, "base", base).flushModule(comp),
+            .nvptx => return @fieldParentPtr(NvPtx, "base", base).flushModule(comp),
         }
     }
 
@@ -612,6 +627,7 @@ pub const File = struct {
             .wasm => @fieldParentPtr(Wasm, "base", base).freeDecl(decl),
             .spirv => @fieldParentPtr(SpirV, "base", base).freeDecl(decl),
             .plan9 => @fieldParentPtr(Plan9, "base", base).freeDecl(decl),
+            .nvptx => @fieldParentPtr(NvPtx, "base", base).freeDecl(decl),
         }
     }
 
@@ -622,7 +638,7 @@ pub const File = struct {
             .macho => return @fieldParentPtr(MachO, "base", base).error_flags,
             .plan9 => return @fieldParentPtr(Plan9, "base", base).error_flags,
             .c => return .{ .no_entry_point_found = false },
-            .wasm, .spirv => return ErrorFlags{},
+            .wasm, .spirv, .nvptx => return ErrorFlags{},
         }
     }
 
@@ -644,6 +660,7 @@ pub const File = struct {
             .wasm => return @fieldParentPtr(Wasm, "base", base).updateDeclExports(module, decl, exports),
             .spirv => return @fieldParentPtr(SpirV, "base", base).updateDeclExports(module, decl, exports),
             .plan9 => return @fieldParentPtr(Plan9, "base", base).updateDeclExports(module, decl, exports),
+            .nvptx => return @fieldParentPtr(NvPtx, "base", base).updateDeclExports(module, decl, exports),
         }
     }
 
@@ -656,6 +673,7 @@ pub const File = struct {
             .c => unreachable,
             .wasm => unreachable,
             .spirv => unreachable,
+            .nvptx => unreachable,
         }
     }
 
@@ -851,6 +869,7 @@ pub const File = struct {
         wasm,
         spirv,
         plan9,
+        nvptx,
     };
 
     pub const ErrorFlags = struct {
@@ -864,6 +883,7 @@ pub const File = struct {
     pub const MachO = @import("link/MachO.zig");
     pub const SpirV = @import("link/SpirV.zig");
     pub const Wasm = @import("link/Wasm.zig");
+    pub const NvPtx = @import("link/NvPtx.zig");
 };
 
 pub fn determineMode(options: Options) fs.File.Mode {
src/Module.zig
@@ -4242,7 +4242,7 @@ fn scanDecl(iter: *ScanDeclIter, decl_sub_index: usize, flags: u4) SemaError!voi
                 // in `Decl` to notice that the line number did not change.
                 mod.comp.work_queue.writeItemAssumeCapacity(.{ .update_line_number = decl });
             },
-            .c, .wasm, .spirv => {},
+            .c, .wasm, .spirv, .nvptx => {},
         }
     }
 }
@@ -4316,6 +4316,7 @@ pub fn clearDecl(
                 .c => .{ .c = {} },
                 .wasm => .{ .wasm = link.File.Wasm.DeclBlock.empty },
                 .spirv => .{ .spirv = {} },
+                .nvptx => .{ .nvptx = {} },
             };
             decl.fn_link = switch (mod.comp.bin_file.tag) {
                 .coff => .{ .coff = {} },
@@ -4325,6 +4326,7 @@ pub fn clearDecl(
                 .c => .{ .c = {} },
                 .wasm => .{ .wasm = link.File.Wasm.FnData.empty },
                 .spirv => .{ .spirv = .{} },
+                .nvptx => .{ .nvptx = .{} },
             };
         }
         if (decl.getInnerNamespace()) |namespace| {
@@ -4652,6 +4654,7 @@ pub fn allocateNewDecl(
             .c => .{ .c = {} },
             .wasm => .{ .wasm = link.File.Wasm.DeclBlock.empty },
             .spirv => .{ .spirv = {} },
+            .nvptx => .{ .nvptx = {} },
         },
         .fn_link = switch (mod.comp.bin_file.tag) {
             .coff => .{ .coff = {} },
@@ -4661,6 +4664,7 @@ pub fn allocateNewDecl(
             .c => .{ .c = {} },
             .wasm => .{ .wasm = link.File.Wasm.FnData.empty },
             .spirv => .{ .spirv = .{} },
+            .nvptx => .{ .nvptx = .{} },
         },
         .generation = 0,
         .is_pub = false,
src/Sema.zig
@@ -3724,6 +3724,7 @@ pub fn analyzeExport(
             .c => .{ .c = {} },
             .wasm => .{ .wasm = {} },
             .spirv => .{ .spirv = {} },
+            .nvptx => .{ .nvptx = {} },
         },
         .owner_decl = owner_decl,
         .src_decl = src_decl,