Commit c92cc5798f

Robin Voetter <robin@voetter.nl>
2023-05-20 18:02:30
spirv: make constant handle float, errorset, errorunion
This is in preparation of removing indirect lowering again. Also modifies constant() to accept a repr so that both direct as well as indirect representations can be generated. Indirect is not yet used, but will be used for globals.
1 parent 65157d3
Changed files (2)
src
codegen
src/codegen/spirv/Module.zig
@@ -774,6 +774,16 @@ pub fn changePtrStorageClass(self: *Module, ptr_ty_ref: Type.Ref, new_storage_cl
     return try self.resolveType(Type.initPayload(&payload.base));
 }
 
+pub fn constComposite(self: *Module, ty_ref: Type.Ref, members: []const IdRef) !IdRef {
+    const result_id = self.allocId();
+    try self.sections.types_globals_constants.emit(self.gpa, .OpSpecConstantComposite, .{
+        .id_result_type = self.typeId(ty_ref),
+        .id_result = result_id,
+        .constituents = members,
+    });
+    return result_id;
+}
+
 pub fn emitConstant(
     self: *Module,
     ty_id: IdRef,
src/codegen/spirv.zig
@@ -242,7 +242,7 @@ pub const DeclGen = struct {
                 return self.spv.declPtr(spv_decl_index).result_id;
             }
 
-            return try self.constant(ty, val);
+            return try self.constant(ty, val, .direct);
         }
         const index = Air.refToIndex(inst).?;
         return self.inst_results.get(index).?; // Assertion means instruction does not dominate usage.
@@ -1021,14 +1021,16 @@ pub const DeclGen = struct {
     /// the constant is more complicated however, it needs to be lowered to an indirect constant, which
     /// is then loaded using OpLoad. Such values are loaded into the UniformConstant storage class by default.
     /// This function should only be called during function code generation.
-    fn constant(self: *DeclGen, ty: Type, val: Value) !IdRef {
+    fn constant(self: *DeclGen, ty: Type, val: Value, repr: Repr) !IdRef {
         const target = self.getTarget();
         const section = &self.spv.sections.types_globals_constants;
-        const result_ty_ref = try self.resolveType(ty, .direct);
+        const result_ty_ref = try self.resolveType(ty, repr);
         const result_ty_id = self.typeId(result_ty_ref);
-        const result_id = self.spv.allocId();
+
+        log.debug("constant: ty = {}, val = {}", .{ ty.fmt(self.module), val.fmtValue(ty, self.module) });
 
         if (val.isUndef()) {
+            const result_id = self.spv.allocId();
             try section.emit(self.spv.gpa, .OpUndef, .{
                 .id_result_type = result_ty_id,
                 .id_result = result_id,
@@ -1039,24 +1041,76 @@ pub const DeclGen = struct {
         switch (ty.zigTypeTag()) {
             .Int => {
                 if (ty.isSignedInt()) {
-                    try self.genConstInt(result_ty_ref, result_id, val.toSignedInt(target));
+                    return try self.constInt(result_ty_ref, val.toSignedInt(target));
                 } else {
-                    try self.genConstInt(result_ty_ref, result_id, val.toUnsignedInt(target));
+                    return try self.constInt(result_ty_ref, val.toUnsignedInt(target));
+                }
+            },
+            .Bool => switch (repr) {
+                .direct => {
+                    const result_id = self.spv.allocId();
+                    const operands = .{ .id_result_type = result_ty_id, .id_result = result_id };
+                    if (val.toBool()) {
+                        try section.emit(self.spv.gpa, .OpConstantTrue, operands);
+                    } else {
+                        try section.emit(self.spv.gpa, .OpConstantFalse, operands);
+                    }
+                    return result_id;
+                },
+                .indirect => return try self.constInt(result_ty_ref, @boolToInt(val.toBool())),
+            },
+            .Float => {
+                const result_id = self.spv.allocId();
+                switch (ty.floatBits(target)) {
+                    16 => try self.spv.emitConstant(result_ty_id, result_id, .{ .float32 = val.toFloat(f16) }),
+                    32 => try self.spv.emitConstant(result_ty_id, result_id, .{ .float32 = val.toFloat(f32) }),
+                    64 => try self.spv.emitConstant(result_ty_id, result_id, .{ .float64 = val.toFloat(f64) }),
+                    80, 128 => unreachable, // TODO
+                    else => unreachable,
                 }
+                return result_id;
             },
-            .Bool => {
-                const operands = .{ .id_result_type = result_ty_id, .id_result = result_id };
-                if (val.toBool()) {
-                    try section.emit(self.spv.gpa, .OpConstantTrue, operands);
+            .ErrorSet => {
+                const value = switch (val.tag()) {
+                    .@"error" => blk: {
+                        const err_name = val.castTag(.@"error").?.data.name;
+                        const kv = try self.module.getErrorValue(err_name);
+                        break :blk @intCast(u16, kv.value);
+                    },
+                    .zero => 0,
+                    else => unreachable,
+                };
+
+                return try self.constInt(result_ty_ref, value);
+            },
+            .ErrorUnion => {
+                const payload_ty = ty.errorUnionPayload();
+                const is_pl = val.errorUnionIsPayload();
+                const error_val = if (!is_pl) val else Value.initTag(.zero);
+
+                const eu_layout = self.errorUnionLayout(payload_ty);
+                if (!eu_layout.payload_has_bits) {
+                    return try self.constant(Type.anyerror, error_val, repr);
+                }
+
+                const payload_val = if (val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef);
+
+                var members: [2]IdRef = undefined;
+                if (eu_layout.error_first) {
+                    members[0] = try self.constant(Type.anyerror, error_val, .indirect);
+                    members[1] = try self.constant(payload_ty, payload_val, .indirect);
                 } else {
-                    try section.emit(self.spv.gpa, .OpConstantFalse, operands);
+                    members[0] = try self.constant(payload_ty, payload_val, .indirect);
+                    members[1] = try self.constant(Type.anyerror, error_val, .indirect);
                 }
+                return try self.spv.constComposite(result_ty_ref, &members);
             },
             // TODO: We can handle most pointers here (decl refs etc), because now they emit an extra
             // OpVariable that is not really required.
             else => {
                 // The value cannot be generated directly, so generate it as an indirect constant,
                 // and then perform an OpLoad.
+                const result_id = self.spv.allocId();
                 const alignment = ty.abiAlignment(target);
                 const spv_decl_index = try self.spv.allocDecl(.global);
 
@@ -1078,10 +1132,9 @@ pub const DeclGen = struct {
                 });
                 // TODO: Convert bools? This logic should hook into `load`. It should be a dead
                 // path though considering .Bool is handled above.
+                return result_id;
             },
         }
-
-        return result_id;
     }
 
     /// Turn a Zig type into a SPIR-V Type, and return its type result-id.