Commit 8eee392862

Robin Voetter <robin@voetter.nl>
2023-07-01 14:27:12
spirv: fix up todos & errors from intern pool changes
This replaces the implementation of constant() which one that is directly based on the intern pool rather than the Zig type tag too.
1 parent 0a6cd25
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -537,6 +537,12 @@ pub const DeclGen = struct {
 
         fn addInt(self: *@This(), ty: Type, val: Value) !void {
             const mod = self.dg.module;
+            const len = ty.abiSize(mod);
+            if (val.isUndef(mod)) {
+                try self.addUndef(len);
+                return;
+            }
+
             const int_info = ty.intInfo(mod);
             const int_bits = switch (int_info.signedness) {
                 .signed => @as(u64, @bitCast(val.toSignedInt(mod))),
@@ -544,7 +550,6 @@ pub const DeclGen = struct {
             };
 
             // TODO: Swap endianess if the compiler is big endian.
-            const len = ty.abiSize(mod);
             try self.addBytes(std.mem.asBytes(&int_bits)[0..@as(usize, @intCast(len))]);
         }
 
@@ -667,31 +672,41 @@ pub const DeclGen = struct {
                     try self.addConstInt(u16, @as(u16, @intCast(int)));
                 },
                 .error_union => |error_union| {
+                    const err_ty = switch (error_union.val) {
+                        .err_name => ty.errorUnionSet(mod),
+                        .payload => Type.err_int,
+                    };
+                    const err_val = switch (error_union.val) {
+                        .err_name => |err_name| (try mod.intern(.{ .err = .{
+                            .ty = ty.errorUnionSet(mod).toIntern(),
+                            .name = err_name,
+                        } })).toValue(),
+                        .payload => try mod.intValue(Type.err_int, 0),
+                    };
                     const payload_ty = ty.errorUnionPayload(mod);
-                    const is_pl = val.errorUnionIsPayload(mod);
-                    const error_val = if (!is_pl) val else try mod.intValue(Type.anyerror, 0);
-
                     const eu_layout = dg.errorUnionLayout(payload_ty);
                     if (!eu_layout.payload_has_bits) {
-                        return try self.lower(Type.anyerror, error_val);
+                        // We use the error type directly as the type.
+                        try self.lower(err_ty, err_val);
+                        return;
                     }
 
                     const payload_size = payload_ty.abiSize(mod);
-                    const error_size = Type.anyerror.abiAlignment(mod);
+                    const error_size = err_ty.abiSize(mod);
                     const ty_size = ty.abiSize(mod);
                     const padding = ty_size - payload_size - error_size;
 
                     const payload_val = switch (error_union.val) {
-                        .err_name => try mod.intern(.{ .undef = payload_ty.ip_index }),
+                        .err_name => try mod.intern(.{ .undef = payload_ty.toIntern() }),
                         .payload => |payload| payload,
                     }.toValue();
 
                     if (eu_layout.error_first) {
-                        try self.lower(Type.anyerror, error_val);
+                        try self.lower(err_ty, err_val);
                         try self.lower(payload_ty, payload_val);
                     } else {
                         try self.lower(payload_ty, payload_val);
-                        try self.lower(Type.anyerror, error_val);
+                        try self.lower(err_ty, err_val);
                     }
 
                     try self.addUndef(padding);
@@ -705,9 +720,14 @@ pub const DeclGen = struct {
                 },
                 .float => try self.addFloat(ty, val),
                 .ptr => |ptr| {
+                    const ptr_ty = switch (ptr.len) {
+                        .none => ty,
+                        else => ty.slicePtrFieldType(mod),
+                    };
                     switch (ptr.addr) {
-                        .decl => |decl| try self.addDeclRef(ty, decl),
-                        .mut_decl => |mut_decl| try self.addDeclRef(ty, mut_decl.decl),
+                        .decl => |decl| try self.addDeclRef(ptr_ty, decl),
+                        .mut_decl => |mut_decl| try self.addDeclRef(ptr_ty, mut_decl.decl),
+                        .int => |int| try self.addInt(Type.usize, int.toValue()),
                         else => |tag| return dg.todo("pointer value of type {s}", .{@tagName(tag)}),
                     }
                     if (ptr.len != .none) {
@@ -979,38 +999,84 @@ pub const DeclGen = struct {
     /// the constant is more complicated however, it needs to be lowered to an indirect constant, which
     /// is then loaded using OpLoad. Such values are loaded into the UniformConstant storage class by default.
     /// This function should only be called during function code generation.
-    fn constant(self: *DeclGen, ty: Type, val: Value, repr: Repr) !IdRef {
+    fn constant(self: *DeclGen, ty: Type, arg_val: Value, repr: Repr) !IdRef {
         const mod = self.module;
         const target = self.getTarget();
         const result_ty_ref = try self.resolveType(ty, repr);
 
-        log.debug("constant: ty = {}, val = {}", .{ ty.fmt(self.module), val.fmtValue(ty, self.module) });
+        var val = arg_val;
+        switch (mod.intern_pool.indexToKey(val.toIntern())) {
+            .runtime_value => |rt| val = rt.val.toValue(),
+            else => {},
+        }
 
+        log.debug("constant: ty = {}, val = {}", .{ ty.fmt(self.module), val.fmtValue(ty, self.module) });
         if (val.isUndef(mod)) {
             return self.spv.constUndef(result_ty_ref);
         }
 
-        switch (ty.zigTypeTag(mod)) {
-            .Int => {
+        switch (mod.intern_pool.indexToKey(val.toIntern())) {
+            .int_type,
+            .ptr_type,
+            .array_type,
+            .vector_type,
+            .opt_type,
+            .anyframe_type,
+            .error_union_type,
+            .simple_type,
+            .struct_type,
+            .anon_struct_type,
+            .union_type,
+            .opaque_type,
+            .enum_type,
+            .func_type,
+            .error_set_type,
+            .inferred_error_set_type,
+            => unreachable, // types, not values
+
+            .undef => unreachable, // handled above
+            .runtime_value => unreachable, // ???
+
+            .variable,
+            .extern_func,
+            .func,
+            .enum_literal,
+            .empty_enum_value,
+            => unreachable, // non-runtime values
+
+            .simple_value => |simple_value| switch (simple_value) {
+                .undefined,
+                .void,
+                .null,
+                .empty_struct,
+                .@"unreachable",
+                .generic_poison,
+                => unreachable, // non-runtime values
+
+                .false, .true => switch (repr) {
+                    .direct => return try self.spv.constBool(result_ty_ref, val.toBool()),
+                    .indirect => return try self.spv.constInt(result_ty_ref, @intFromBool(val.toBool())),
+                },
+            },
+
+            .int => {
                 if (ty.isSignedInt(mod)) {
                     return try self.spv.constInt(result_ty_ref, val.toSignedInt(mod));
                 } else {
                     return try self.spv.constInt(result_ty_ref, val.toUnsignedInt(mod));
                 }
             },
-            .Bool => switch (repr) {
-                .direct => return try self.spv.constBool(result_ty_ref, val.toBool()),
-                .indirect => return try self.spv.constInt(result_ty_ref, @intFromBool(val.toBool())),
-            },
-            .Float => return switch (ty.floatBits(target)) {
+            .float => return switch (ty.floatBits(target)) {
                 16 => try self.spv.resolveId(.{ .float = .{ .ty = result_ty_ref, .value = .{ .float16 = val.toFloat(f16, mod) } } }),
                 32 => try self.spv.resolveId(.{ .float = .{ .ty = result_ty_ref, .value = .{ .float32 = val.toFloat(f32, mod) } } }),
                 64 => try self.spv.resolveId(.{ .float = .{ .ty = result_ty_ref, .value = .{ .float64 = val.toFloat(f64, mod) } } }),
                 80, 128 => unreachable, // TODO
                 else => unreachable,
             },
-            .ErrorSet => @panic("TODO"),
-            .ErrorUnion => @panic("TODO"),
+            .err => |err| {
+                const value = try mod.getErrorValue(err.name);
+                return try self.spv.constInt(result_ty_ref, value);
+            },
             // TODO: We can handle most pointers here (decl refs etc), because now they emit an extra
             // OpVariable that is not really required.
             else => {
@@ -1263,51 +1329,53 @@ pub const DeclGen = struct {
                 } });
             },
             .Struct => {
-                const struct_ty = mod.typeToStruct(ty).?;
-                const fields = struct_ty.fields.values();
-
-                if (ty.isSimpleTupleOrAnonStruct(mod)) {
-                    const member_types = try self.gpa.alloc(CacheRef, fields.len);
-                    defer self.gpa.free(member_types);
+                const struct_ty = switch (mod.intern_pool.indexToKey(ty.toIntern())) {
+                    .anon_struct_type => |tuple| {
+                        const member_types = try self.gpa.alloc(CacheRef, tuple.values.len);
+                        defer self.gpa.free(member_types);
 
-                    var member_index: usize = 0;
-                    for (fields) |field| {
-                        if (field.ty.ip_index != .unreachable_value or !field.ty.hasRuntimeBits(mod)) continue;
+                        var member_index: usize = 0;
+                        for (tuple.types, tuple.values) |field_ty, field_val| {
+                            if (field_val != .none or !field_ty.toType().hasRuntimeBits(mod)) continue;
 
-                        member_types[member_index] = try self.resolveType(field.ty, .indirect);
-                        member_index += 1;
-                    }
+                            member_types[member_index] = try self.resolveType(field_ty.toType(), .indirect);
+                            member_index += 1;
+                        }
 
-                    return try self.spv.resolve(.{ .struct_type = .{
-                        .member_types = member_types[0..member_index],
-                    } });
-                }
+                        return try self.spv.resolve(.{ .struct_type = .{
+                            .member_types = member_types[0..member_index],
+                        } });
+                    },
+                    .struct_type => |struct_ty| struct_ty,
+                    else => unreachable,
+                };
 
-                if (struct_ty.layout == .Packed) {
-                    return try self.resolveType(struct_ty.backing_int_ty, .direct);
+                const struct_obj = mod.structPtrUnwrap(struct_ty.index).?;
+                if (struct_obj.layout == .Packed) {
+                    return try self.resolveType(struct_obj.backing_int_ty, .direct);
                 }
 
-                const member_types = try self.gpa.alloc(CacheRef, fields.len);
-                defer self.gpa.free(member_types);
-
-                const member_names = try self.gpa.alloc(CacheString, fields.len);
-                defer self.gpa.free(member_names);
+                var member_types = std.ArrayList(CacheRef).init(self.gpa);
+                defer member_types.deinit();
 
-                var member_index: usize = 0;
-                for (fields, 0..) |field, i| {
-                    if (field.is_comptime or !field.ty.hasRuntimeBits(mod)) continue;
+                var member_names = std.ArrayList(CacheString).init(self.gpa);
+                defer member_names.deinit();
 
-                    member_types[member_index] = try self.resolveType(field.ty, .indirect);
-                    member_names[member_index] = try self.spv.resolveString(mod.intern_pool.stringToSlice(struct_ty.fields.keys()[i]));
-                    member_index += 1;
+                var it = struct_obj.runtimeFieldIterator(mod);
+                while (it.next()) |field_and_index| {
+                    const field = field_and_index.field;
+                    const index = field_and_index.index;
+                    const field_name = mod.intern_pool.stringToSlice(struct_obj.fields.keys()[index]);
+                    try member_types.append(try self.resolveType(field.ty, .indirect));
+                    try member_names.append(try self.spv.resolveString(field_name));
                 }
 
-                const name = mod.intern_pool.stringToSlice(try struct_ty.getFullyQualifiedName(self.module));
+                const name = mod.intern_pool.stringToSlice(try struct_obj.getFullyQualifiedName(self.module));
 
                 return try self.spv.resolve(.{ .struct_type = .{
                     .name = try self.spv.resolveString(name),
-                    .member_types = member_types[0..member_index],
-                    .member_names = member_names[0..member_index],
+                    .member_types = member_types.items,
+                    .member_names = member_names.items,
                 } });
             },
             .Optional => {
@@ -2512,9 +2580,9 @@ pub const DeclGen = struct {
         // just an element.
         var elem_ptr_info = ptr_ty.ptrInfo(mod);
         elem_ptr_info.flags.size = .One;
-        const elem_ptr_ty = elem_ptr_info.child.toType();
+        const elem_ptr_ty = try mod.intern_pool.get(mod.gpa, .{ .ptr_type = elem_ptr_info });
 
-        return try self.load(elem_ptr_ty, elem_ptr_id);
+        return try self.load(elem_ptr_ty.toType(), elem_ptr_id);
     }
 
     fn airGetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {