Commit 628e9e6d04

gwenzek <gwenzek@users.noreply.github.com>
2022-02-21 20:05:27
enable Gpu address spaces (#10884)
1 parent d8da9a0
Changed files (5)
lib
src
test
lib/std/builtin.zig
@@ -157,6 +157,12 @@ pub const AddressSpace = enum {
     gs,
     fs,
     ss,
+    // GPU address spaces
+    global,
+    constant,
+    param,
+    shared,
+    local,
 };
 
 /// This data structure is used by the Zig language code generation and
src/codegen/llvm.zig
@@ -801,6 +801,16 @@ pub const DeclGen = struct {
                 .gs => llvm.address_space.x86.gs,
                 .fs => llvm.address_space.x86.fs,
                 .ss => llvm.address_space.x86.ss,
+                else => unreachable,
+            },
+            .nvptx, .nvptx64 => switch (address_space) {
+                .generic => llvm.address_space.default,
+                .global => llvm.address_space.nvptx.global,
+                .constant => llvm.address_space.nvptx.constant,
+                .param => llvm.address_space.nvptx.param,
+                .shared => llvm.address_space.nvptx.shared,
+                .local => llvm.address_space.nvptx.local,
+                else => unreachable,
             },
             else => switch (address_space) {
                 .generic => llvm.address_space.default,
src/Sema.zig
@@ -18006,10 +18006,14 @@ pub fn analyzeAddrspace(
     const address_space = addrspace_tv.val.toEnum(std.builtin.AddressSpace);
     const target = sema.mod.getTarget();
     const arch = target.cpu.arch;
+    const is_gpu = arch == .nvptx or arch == .nvptx64;
 
     const supported = switch (address_space) {
         .generic => true,
         .gs, .fs, .ss => (arch == .i386 or arch == .x86_64) and ctx == .pointer,
+        // TODO: check that .shared and .local are left uninitialized
+        .global, .param, .shared, .local => is_gpu,
+        .constant => is_gpu and (ctx == .constant),
     };
 
     if (!supported) {
@@ -18020,7 +18024,6 @@ pub fn analyzeAddrspace(
             .constant => "constant values",
             .pointer => "pointers",
         };
-
         return sema.fail(
             block,
             src,
test/stage2/nvptx.zig
@@ -0,0 +1,57 @@
+const std = @import("std");
+const TestContext = @import("../../src/test.zig").TestContext;
+
+const nvptx = std.zig.CrossTarget{ // Cross target passed to every test case registered in this file.
+    .cpu_arch = .nvptx64, // 64-bit NVPTX CPU architecture.
+    .os_tag = .cuda, // CUDA OS tag.
+};
+
+/// Registers the stage2 NVPTX test cases on `ctx`.
+/// Each case is created with `ctx.exeUsingLlvmBackend` for the `nvptx` target
+/// and registered through `case.compiles` with the Zig source to build.
+pub fn addCases(ctx: *TestContext) !void {
+    {
+        var case = ctx.exeUsingLlvmBackend("simple addition and subtraction", nvptx);
+
+        // Plain integer arithmetic inside a `.PtxKernel` entry point.
+        case.compiles(
+            \\fn add(a: i32, b: i32) i32 {
+            \\    return a + b;
+            \\}
+            \\
+            \\pub export fn main(a: i32, out: *i32) callconv(.PtxKernel) void {
+            \\    const x = add(a, 7);
+            \\    var y = add(2, 0);
+            \\    y -= x;
+            \\    out.* = y;
+            \\}
+        );
+    }
+
+    {
+        var case = ctx.exeUsingLlvmBackend("read special registers", nvptx);
+
+        // The helper is named `threadIdX` (not `tid`) so that its local `tid`
+        // does not shadow the function declaration, which Zig rejects.
+        // `%[ret]` is the template placeholder for the named output operand.
+        case.compiles(
+            \\fn threadIdX() usize {
+            \\     const tid = asm volatile ("mov.u32 \t%[ret], %tid.x;"
+            \\         : [ret] "=r" (-> u32),
+            \\     );
+            \\     return @as(usize, tid);
+            \\}
+            \\
+            \\pub export fn main(a: []const i32, out: []i32) callconv(.PtxKernel) void {
+            \\    const i = threadIdX();
+            \\    out[i] = a[i] + 7;
+            \\}
+        );
+    }
+
+    {
+        var case = ctx.exeUsingLlvmBackend("address spaces", nvptx);
+
+        // `x` is declared `i32` so it can be stored through `out: *i32`;
+        // a `u32` value does not implicitly coerce to `i32`.
+        case.compiles(
+            \\var x: i32 addrspace(.global) = 0;
+            \\
+            \\pub export fn increment(out: *i32) callconv(.PtxKernel) void {
+            \\    x += 1;
+            \\    out.* = x;
+            \\}
+        );
+    }
+}
test/cases.zig
@@ -16,4 +16,5 @@ pub fn addCases(ctx: *TestContext) !void {
     try @import("stage2/riscv64.zig").addCases(ctx);
     try @import("stage2/plan9.zig").addCases(ctx);
     try @import("stage2/x86_64.zig").addCases(ctx);
+    try @import("stage2/nvptx.zig").addCases(ctx);
 }