Commit 3f92eaceb6

Robin Voetter <robin@voetter.nl>
2022-11-26 16:51:53
spirv: array, structs, bitcast, call
Implements type lowering for arrays and structs, and implements instruction lowering for bitcast and call. Bitcast currently naively maps to the OpBitcast instruction - this is only valid for some primitive types, and should be improved to work with composites.
1 parent dae8b4c
Changed files (4)
src/codegen/spirv/Module.zig
@@ -222,7 +222,7 @@ pub fn resolveType(self: *Module, ty: Type) !Type.Ref {
     return @intToEnum(Type.Ref, result.index);
 }
 
-pub fn resolveTypeId(self: *Module, ty: Type) !IdRef {
+pub fn resolveTypeId(self: *Module, ty: Type) !IdResultType {
     const type_ref = try self.resolveType(ty);
     return self.typeResultId(type_ref);
 }
@@ -243,7 +243,7 @@ pub fn typeRefId(self: Module, type_ref: Type.Ref) IdRef {
 /// Note: This function does not attempt to perform any validation on the type.
 /// The type is emitted in a shallow fashion; any child types should already
 /// be emitted at this point.
-pub fn emitType(self: *Module, ty: Type) !IdResultType {
+pub fn emitType(self: *Module, ty: Type) error{OutOfMemory}!IdResultType {
     const result_id = self.allocId();
     const ref_id = result_id.toRef();
     const types = &self.sections.types_globals_constants;
@@ -347,10 +347,21 @@ pub fn emitType(self: *Module, ty: Type) !IdResultType {
         .array => {
             const info = ty.payload(.array);
             assert(info.length != 0);
+
+            const size_type = Type.initTag(.u32);
+            const size_type_id = try self.resolveTypeId(size_type);
+
+            const length_id = self.allocId();
+            try types.emit(self.gpa, .OpConstant, .{
+                .id_result_type = size_type_id,
+                .id_result = length_id,
+                .value = .{ .uint32 = info.length },
+            });
+
             try types.emit(self.gpa, .OpTypeArray, .{
                 .id_result = result_id,
                 .element_type = self.typeResultId(ty.childType()).toRef(),
-                .length = .{ .id = 0 }, // TODO: info.length must be emitted as constant!
+                .length = length_id.toRef(),
             });
             if (info.array_stride != 0) {
                 try annotations.decorate(self.gpa, ref_id, .{ .ArrayStride = .{ .array_stride = info.array_stride } });
src/codegen/spirv/type.zig
@@ -421,7 +421,7 @@ pub const Type = extern union {
             length: u32,
             /// Type has the 'ArrayStride' decoration.
             /// If zero, no stride is present.
-            array_stride: u32,
+            array_stride: u32 = 0,
         };
 
         pub const RuntimeArray = struct {
@@ -434,6 +434,7 @@ pub const Type = extern union {
 
         pub const Struct = struct {
             base: Payload = .{ .tag = .@"struct" },
+            // TODO: name
             members: []Member,
             decorations: StructDecorations,
 
@@ -444,20 +445,21 @@ pub const Type = extern union {
             pub const Member = struct {
                 ty: Ref,
                 offset: u32,
+                // TODO: name
                 decorations: MemberDecorations,
             };
 
             pub const StructDecorations = packed struct {
                 /// Type has the 'Block' decoration.
-                block: bool,
+                block: bool = false,
                 /// Type has the 'BufferBlock' decoration.
-                buffer_block: bool,
+                buffer_block: bool = false,
                 /// Type has the 'GLSLShared' decoration.
-                glsl_shared: bool,
+                glsl_shared: bool = false,
                 /// Type has the 'GLSLPacked' decoration.
-                glsl_packed: bool,
+                glsl_packed: bool = false,
                 /// Type has the 'CPacked' decoration.
-                c_packed: bool,
+                c_packed: bool = false,
             };
 
             pub const MemberDecorations = packed struct {
@@ -473,31 +475,31 @@ pub const Type = extern union {
                     col_major,
                     /// Member is not a matrix or array of matrices.
                     none,
-                },
+                } = .none,
 
                 // Regular decorations, these do not imply extra fields.
 
                 /// Member has the 'NoPerspective' decoration.
-                no_perspective: bool,
+                no_perspective: bool = false,
                 /// Member has the 'Flat' decoration.
-                flat: bool,
+                flat: bool = false,
                 /// Member has the 'Patch' decoration.
-                patch: bool,
+                patch: bool = false,
                 /// Member has the 'Centroid' decoration.
-                centroid: bool,
+                centroid: bool = false,
                 /// Member has the 'Sample' decoration.
-                sample: bool,
+                sample: bool = false,
                 /// Member has the 'Invariant' decoration.
                 /// Note: requires parent struct to have 'Block'.
-                invariant: bool,
+                invariant: bool = false,
                 /// Member has the 'Volatile' decoration.
-                @"volatile": bool,
+                @"volatile": bool = false,
                 /// Member has the 'Coherent' decoration.
-                coherent: bool,
+                coherent: bool = false,
                 /// Member has the 'NonWritable' decoration.
-                non_writable: bool,
+                non_writable: bool = false,
                 /// Member has the 'NonReadable' decoration.
-                non_readable: bool,
+                non_readable: bool = false,
 
                 // The following decorations all imply extra field(s).
 
@@ -506,27 +508,27 @@ pub const Type = extern union {
                 /// Note: If any member of a struct has the BuiltIn decoration, all members must have one.
                 /// Note: Each builtin may only be reachable once for a particular entry point.
                 /// Note: The member type may be constrained by a particular built-in, defined in the client API specification.
-                builtin: bool,
+                builtin: bool = false,
                 /// Member has the 'Stream' decoration.
                 /// This member has an extra field of type `u32`.
-                stream: bool,
+                stream: bool = false,
                 /// Member has the 'Location' decoration.
                 /// This member has an extra field of type `u32`.
-                location: bool,
+                location: bool = false,
                 /// Member has the 'Component' decoration.
                 /// This member has an extra field of type `u32`.
-                component: bool,
+                component: bool = false,
                 /// Member has the 'XfbBuffer' decoration.
                 /// This member has an extra field of type `u32`.
-                xfb_buffer: bool,
+                xfb_buffer: bool = false,
                 /// Member has the 'XfbStride' decoration.
                 /// This member has an extra field of type `u32`.
-                xfb_stride: bool,
+                xfb_stride: bool = false,
                 /// Member has the 'UserSemantic' decoration.
                 /// This member has an extra field of type `[]u8`, which is encoded
                 /// by an `u32` containing the number of chars exactly, and then the string padded to
                 /// a multiple of 4 bytes with zeroes.
-                user_semantic: bool,
+                user_semantic: bool = false,
             };
         };
 
src/codegen/spirv.zig
@@ -492,6 +492,21 @@ pub const DeclGen = struct {
 
                 return try self.spv.resolveType(SpvType.float(bits));
             },
+            .Array => {
+                const elem_ty = ty.childType();
+                const total_len_u64 = ty.arrayLen() + @boolToInt(ty.sentinel() != null);
+                const total_len = std.math.cast(u32, total_len_u64) orelse {
+                    return self.fail("array type of {} elements is too large", .{total_len_u64});
+                };
+
+                const payload = try self.spv.arena.create(SpvType.Payload.Array);
+                payload.* = .{
+                    .element_type = try self.resolveType(elem_ty),
+                    .length = total_len,
+                    .array_stride = @intCast(u32, ty.abiSize(target)),
+                };
+                return try self.spv.resolveType(SpvType.initPayload(&payload.base));
+            },
             .Fn => {
                 // TODO: Put this somewhere in Sema.zig
                 if (ty.fnIsVarArgs())
@@ -537,7 +552,37 @@ pub const DeclGen = struct {
                 };
                 return try self.spv.resolveType(SpvType.initPayload(&payload.base));
             },
+            .Struct => {
+                if (ty.isSimpleTupleOrAnonStruct()) {
+                    return self.todo("implement tuple struct type", .{});
+                }
+
+                const struct_ty = ty.castTag(.@"struct").?.data;
+
+                if (struct_ty.layout == .Packed) {
+                    return try self.resolveType(struct_ty.backing_int_ty);
+                }
+
+                const members = try self.spv.arena.alloc(SpvType.Payload.Struct.Member, struct_ty.fields.count());
+                var member_index: usize = 0;
+                for (struct_ty.fields.values()) |field| {
+                    if (field.is_comptime or !field.ty.hasRuntimeBits()) continue;
+
+                    members[member_index] = .{
+                        .ty = try self.resolveType(field.ty),
+                        .offset = field.offset,
+                        .decorations = .{},
+                    };
+                }
 
+                const payload = try self.spv.arena.create(SpvType.Payload.Struct);
+                payload.* = .{
+                    .members = members[0..member_index],
+                    .decorations = .{},
+                    .member_decoration_extra = &.{},
+                };
+                return try self.spv.resolveType(SpvType.initPayload(&payload.base));
+            },
             .Null,
             .Undefined,
             .EnumLiteral,
@@ -632,7 +677,7 @@ pub const DeclGen = struct {
             .bool_and => try self.airBinOpSimple(inst, .OpLogicalAnd),
             .bool_or  => try self.airBinOpSimple(inst, .OpLogicalOr),
 
-            .not => try self.airNot(inst),
+            .not     => try self.airNot(inst),
 
             .cmp_eq  => try self.airCmp(inst, .OpFOrdEqual,            .OpLogicalEqual,      .OpIEqual),
             .cmp_neq => try self.airCmp(inst, .OpFOrdNotEqual,         .OpLogicalNotEqual,   .OpINotEqual),
@@ -646,6 +691,7 @@ pub const DeclGen = struct {
             .block => (try self.airBlock(inst)) orelse return,
             .load  => try self.airLoad(inst),
 
+            .bitcast    => try self.airBitcast(inst),
             .br         => return self.airBr(inst),
             .breakpoint => return,
             .cond_br    => return self.airCondBr(inst),
@@ -657,6 +703,11 @@ pub const DeclGen = struct {
             .unreach    => return self.airUnreach(),
             .assembly   => (try self.airAssembly(inst)) orelse return,
 
+            .call              => (try self.airCall(inst, .auto)) orelse return,
+            .call_always_tail  => (try self.airCall(inst, .always_tail)) orelse return,
+            .call_never_tail   => (try self.airCall(inst, .never_tail)) orelse return,
+            .call_never_inline => (try self.airCall(inst, .never_inline)) orelse return,
+
             .dbg_var_ptr => return,
             .dbg_var_val => return,
             .dbg_block_begin => return,
@@ -911,6 +962,19 @@ pub const DeclGen = struct {
         return result_id.toRef();
     }
 
+    fn airBitcast(self: *DeclGen, inst: Air.Inst.Index) !IdRef {
+        const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+        const operand_id = try self.resolve(ty_op.operand);
+        const result_id = self.spv.allocId();
+        const result_type_id = try self.resolveTypeId(Type.initTag(.bool));
+        try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+            .id_result_type = result_type_id,
+            .id_result = result_id,
+            .operand = operand_id,
+        });
+        return result_id.toRef();
+    }
+
     fn airBr(self: *DeclGen, inst: Air.Inst.Index) !void {
         const br = self.air.instructions.items(.data)[inst].br;
         const block = self.blocks.get(br.block_inst).?;
@@ -1158,4 +1222,43 @@ pub const DeclGen = struct {
 
         return null;
     }
+
+    fn airCall(self: *DeclGen, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.Modifier) !?IdRef {
+        _ = modifier;
+
+        const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+        const extra = self.air.extraData(Air.Call, pl_op.payload);
+        const args = @ptrCast([]const Air.Inst.Ref, self.air.extra[extra.end..][0..extra.data.args_len]);
+        const callee_ty = self.air.typeOf(pl_op.operand);
+        const zig_fn_ty = switch (callee_ty.zigTypeTag()) {
+            .Fn => callee_ty,
+            .Pointer => return self.fail("cannot call function pointers", .{}),
+            else => unreachable,
+        };
+        const fn_info = zig_fn_ty.fnInfo();
+        const return_type = fn_info.return_type;
+
+        const result_type_id = try self.resolveTypeId(return_type);
+        const result_id = self.spv.allocId();
+        const callee_id = try self.resolve(pl_op.operand);
+
+        try self.func.body.emitRaw(self.spv.gpa, .OpFunctionCall, 3 + args.len);
+        self.func.body.writeOperand(spec.IdResultType, result_type_id);
+        self.func.body.writeOperand(spec.IdResult, result_id);
+        self.func.body.writeOperand(spec.IdRef, callee_id);
+
+        for (args) |arg| {
+            const arg_id = try self.resolve(arg);
+            const arg_ty = self.air.typeOf(arg);
+            if (!arg_ty.hasRuntimeBitsIgnoreComptime()) continue;
+
+            self.func.body.writeOperand(spec.IdRef, arg_id);
+        }
+
+        if (return_type.isNoReturn()) {
+            try self.func.body.emit(self.spv.gpa, .OpUnreachable, {});
+        }
+
+        return result_id.toRef();
+    }
 };
src/link/SpirV.zig
@@ -226,6 +226,8 @@ pub fn flushModule(self: *SpirV, comp: *Compilation, prog_node: *std.Progress.No
         const air = entry.value_ptr.air;
         const liveness = entry.value_ptr.liveness;
 
+        log.debug("generating code for {s}", .{decl.name});
+
         // Note, if `decl` is not a function, air/liveness may be undefined.
         if (try decl_gen.gen(decl_index, air, liveness)) |msg| {
             try module.failed_decls.put(module.gpa, decl_index, msg);