Commit 6b69dcfdd2

Robin Voetter <robin@voetter.nl>
2022-08-06 18:22:56
amdgpu: add AmdgpuKernel calling convention
1 parent 5dd3d17
lib/std/builtin.zig
@@ -157,6 +157,7 @@ pub const CallingConvention = enum {
     SysV,
     Win64,
     PtxKernel,
+    AmdgpuKernel,
 };
 
 /// This data structure is used by the Zig language code generation and
src/codegen/llvm.zig
@@ -4425,7 +4425,7 @@ pub const FuncGen = struct {
                 .cmp_lt  => try self.airCmp(inst, .lt, false),
                 .cmp_lte => try self.airCmp(inst, .lte, false),
                 .cmp_neq => try self.airCmp(inst, .neq, false),
-                
+
                 .cmp_eq_optimized  => try self.airCmp(inst, .eq, true),
                 .cmp_gt_optimized  => try self.airCmp(inst, .gt, true),
                 .cmp_gte_optimized => try self.airCmp(inst, .gte, true),
@@ -9807,6 +9807,10 @@ fn toLlvmCallConv(cc: std.builtin.CallingConvention, target: std.Target) llvm.Ca
             .nvptx, .nvptx64 => .PTX_Kernel,
             else => unreachable,
         },
+        .AmdgpuKernel => return switch (target.cpu.arch) {
+            .amdgcn => .AMDGPU_KERNEL,
+            else => unreachable,
+        },
     };
 }
 
src/stage1/all_types.hpp
@@ -85,7 +85,8 @@ enum CallingConvention {
     CallingConventionAAPCSVFP,
     CallingConventionSysV,
     CallingConventionWin64,
-    CallingConventionPtxKernel
+    CallingConventionPtxKernel,
+    CallingConventionAmdgpuKernel
 };
 
 // Stage 1 supports only the generic address space
src/stage1/analyze.cpp
@@ -993,6 +993,7 @@ const char *calling_convention_name(CallingConvention cc) {
         case CallingConventionSysV: return "SysV";
         case CallingConventionWin64: return "Win64";
         case CallingConventionPtxKernel: return "PtxKernel";
+        case CallingConventionAmdgpuKernel: return "AmdgpuKernel";
     }
     zig_unreachable();
 }
@@ -1017,6 +1018,7 @@ bool calling_convention_allows_zig_types(CallingConvention cc) {
         case CallingConventionAAPCSVFP:
         case CallingConventionSysV:
         case CallingConventionWin64:
+        case CallingConventionAmdgpuKernel:
             return false;
     }
     zig_unreachable();
@@ -2019,6 +2021,9 @@ Error emit_error_unless_callconv_allowed_for_target(CodeGen *g, AstNode *source_
                 allowed_platforms = "nvptx and nvptx64";
             }
             break;
+      case CallingConventionAmdgpuKernel:
+          if (g->zig_target->arch != ZigLLVM_amdgcn)
+              allowed_platforms = "amdgcn and amdpal";
 
     }
     if (allowed_platforms != nullptr) {
@@ -3857,6 +3862,7 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
                 case CallingConventionSysV:
                 case CallingConventionWin64:
                 case CallingConventionPtxKernel:
+                case CallingConventionAmdgpuKernel:
                     add_fn_export(g, fn_table_entry, buf_ptr(&fn_table_entry->symbol_name),
                                   GlobalLinkageIdStrong, fn_cc);
                     break;
@@ -6012,7 +6018,7 @@ Error type_has_bits2(CodeGen *g, ZigType *type_entry, bool *result) {
 
 bool fn_returns_c_abi_small_struct(FnTypeId *fn_type_id) {
     ZigType *type = fn_type_id->return_type;
-    return !calling_convention_allows_zig_types(fn_type_id->cc) && 
+    return !calling_convention_allows_zig_types(fn_type_id->cc) &&
         type->id == ZigTypeIdStruct && type->abi_size <= 16;
 }
 
@@ -8698,7 +8704,7 @@ static LLVMTypeRef llvm_int_for_size(size_t size) {
 static LLVMTypeRef llvm_sse_for_size(size_t size) {
     if (size > 4)
         return LLVMDoubleType();
-    else 
+    else
         return LLVMFloatType();
 }
 
@@ -8756,7 +8762,7 @@ static Error resolve_llvm_c_abi_type(CodeGen *g, ZigType *ty) {
 
             LLVMTypeRef return_elem_types[] = {
                 LLVMVoidType(),
-                LLVMVoidType(), 
+                LLVMVoidType(),
             };
             for (uint32_t i = 0; i <= eightbyte_index; i += 1) {
                 if (type_classes[i] == X64CABIClass_INTEGER) {
src/stage1/codegen.cpp
@@ -216,6 +216,9 @@ static ZigLLVM_CallingConv get_llvm_cc(CodeGen *g, CallingConvention cc) {
             assert(g->zig_target->arch == ZigLLVM_nvptx ||
                 g->zig_target->arch == ZigLLVM_nvptx64);
                 return ZigLLVM_PTX_Kernel;
+        case CallingConventionAmdgpuKernel:
+            assert(g->zig_target->arch == ZigLLVM_amdgcn);
+            return ZigLLVM_AMDGPU_KERNEL;
 
     }
     zig_unreachable();
@@ -364,6 +367,7 @@ static bool cc_want_sret_attr(CallingConvention cc) {
         case CallingConventionSysV:
         case CallingConventionWin64:
         case CallingConventionPtxKernel:
+        case CallingConventionAmdgpuKernel:
             return true;
         case CallingConventionAsync:
         case CallingConventionUnspecified:
@@ -3515,7 +3519,7 @@ static LLVMValueRef gen_soft_float_to_int_op(CodeGen *g, LLVMValueRef value_ref,
 
     // Handle integers of non-pot bitsize by shortening them on the output
     if (result_type != wider_type) {
-        result = gen_widen_or_shorten(g, false, wider_type, result_type, result); 
+        result = gen_widen_or_shorten(g, false, wider_type, result_type, result);
     }
 
     return result;
src/stage1/ir.cpp
@@ -11753,6 +11753,7 @@ static Stage1AirInst *ir_analyze_instruction_export(IrAnalyze *ira, Stage1ZirIns
                 case CallingConventionSysV:
                 case CallingConventionWin64:
                 case CallingConventionPtxKernel:
+                case CallingConventionAmdgpuKernel:
                     add_fn_export(ira->codegen, fn_entry, buf_ptr(symbol_name), global_linkage_id, cc);
                     fn_entry->section_name = section_name;
                     break;
src/Sema.zig
@@ -8141,6 +8141,10 @@ fn funcCommon(
                 .nvptx, .nvptx64 => null,
                 else => @as([]const u8, "nvptx and nvptx64"),
             },
+            .AmdgpuKernel => switch (arch) {
+                .amdgcn => null,
+                else => @as([]const u8, "amdgcn"),
+            },
         }) |allowed_platform| {
             return sema.fail(block, cc_src, "callconv '{s}' is only available on {s}, not {s}", .{
                 @tagName(cc_workaround),