Commit ea6706b6f4

Author: Andrew Kelley <andrew@ziglang.org>
Date:   2021-09-29 23:04:52
stage2: LLVM backend: implement struct type fwd decls
Makes struct types able to refer to themselves.
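
This is what the re-enabled "struct point to self" behavior test (moved into test/behavior/struct.zig below) exercises. A minimal illustration of the kind of type this enables; the declarations here mirror the shape that test expects and are not themselves part of the hunks shown:

    const Val = struct {
        x: i32,
    };

    const Node = struct {
        next: *Node, // field type refers back to Node; lowering it needs an LLVM fwd decl
        val: Val,
    };

Without a forward declaration, lowering `next` would require the finished LLVM type for `Node` while `Node`'s own body is still being built.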
Parent: 1d1f6a0
Changed files (6):
src/codegen/llvm/bindings.zig
@@ -181,6 +181,13 @@ pub const Type = opaque {
     pub const constArray = LLVMConstArray;
     extern fn LLVMConstArray(ElementTy: *const Type, ConstantVals: [*]*const Value, Length: c_uint) *const Value;
 
+    pub const constNamedStruct = LLVMConstNamedStruct;
+    extern fn LLVMConstNamedStruct(
+        StructTy: *const Type,
+        ConstantVals: [*]const *const Value,
+        Count: c_uint,
+    ) *const Value;
+
     pub const getUndef = LLVMGetUndef;
     extern fn LLVMGetUndef(Ty: *const Type) *const Value;
 
@@ -666,23 +673,25 @@ pub const Builder = opaque {
         Name: [*:0]const u8,
     ) *const Value;
 
-    pub const buildMemSet = LLVMBuildMemSet;
-    extern fn LLVMBuildMemSet(
+    pub const buildMemSet = ZigLLVMBuildMemSet;
+    extern fn ZigLLVMBuildMemSet(
         B: *const Builder,
         Ptr: *const Value,
         Val: *const Value,
         Len: *const Value,
         Align: c_uint,
+        is_volatile: bool,
     ) *const Value;
 
-    pub const buildMemCpy = LLVMBuildMemCpy;
-    extern fn LLVMBuildMemCpy(
+    pub const buildMemCpy = ZigLLVMBuildMemCpy;
+    extern fn ZigLLVMBuildMemCpy(
         B: *const Builder,
         Dst: *const Value,
         DstAlign: c_uint,
         Src: *const Value,
         SrcAlign: c_uint,
         Size: *const Value,
+        is_volatile: bool,
     ) *const Value;
 };
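
With the switch from the raw LLVM-C builders to the ZigLLVM* wrappers, volatility is passed as a parameter instead of being set on the returned instruction afterwards. A call-shape sketch (operand names here are placeholders, not identifiers from this commit; the real call sites appear in the llvm.zig hunks below):

    _ = builder.buildMemSet(dest_ptr, fill_byte, len, dest_align, is_volatile);
    _ = builder.buildMemCpy(dest_ptr, dest_align, src_ptr, src_align, len, is_volatile);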
 
src/codegen/llvm.zig
@@ -164,9 +164,25 @@ pub const Object = struct {
     /// * it works for functions not all globals.
     /// Therefore, this table keeps track of the mapping.
     decl_map: std.AutoHashMapUnmanaged(*const Module.Decl, *const llvm.Value),
+    /// Maps Zig types to LLVM types. The table memory itself is backed by the GPA of
+    /// the compiler, but the Type/Value memory here is backed by `type_map_arena`.
+    /// TODO we need to remove entries from this map in response to incremental compilation
+    /// but I think the frontend won't tell us about types that get deleted because
+    /// hasCodeGenBits() is false for types.
+    type_map: TypeMap,
+    /// The backing memory for `type_map`. Periodically garbage collected after flush().
+    /// The code for doing the periodical GC is not yet implemented.
+    type_map_arena: std.heap.ArenaAllocator,
     /// Where to put the output object file, relative to bin_file.options.emit directory.
     sub_path: []const u8,
 
+    pub const TypeMap = std.HashMapUnmanaged(
+        Type,
+        *const llvm.Type,
+        Type.HashContext64,
+        std.hash_map.default_max_load_percentage,
+    );
+
     pub fn create(gpa: *Allocator, sub_path: []const u8, options: link.Options) !*Object {
         const obj = try gpa.create(Object);
         errdefer gpa.destroy(obj);
@@ -253,6 +269,8 @@ pub const Object = struct {
             .context = context,
             .target_machine = target_machine,
             .decl_map = .{},
+            .type_map = .{},
+            .type_map_arena = std.heap.ArenaAllocator.init(gpa),
             .sub_path = sub_path,
         };
     }
@@ -262,6 +280,8 @@ pub const Object = struct {
         self.llvm_module.dispose();
         self.context.dispose();
         self.decl_map.deinit(gpa);
+        self.type_map.deinit(gpa);
+        self.type_map_arena.deinit();
         self.* = undefined;
     }
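
The new `type_map` is keyed by `Type` values whose backing memory is ephemeral, which is why the `.Struct` lowering below deep-copies the key into `type_map_arena` before storing it. A minimal sketch of that idiom using the names from this diff (`o` standing in for an `*Object`, `t` for the incoming Type):

    const gop = try o.type_map.getOrPut(gpa, t);
    if (!gop.found_existing) {
        // Deep-copy the ephemeral Type so the key outlives the caller's temporary memory.
        gop.key_ptr.* = try t.copy(&o.type_map_arena.allocator);
    }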
 
@@ -725,10 +745,10 @@ pub const DeclGen = struct {
     }
 
     fn llvmType(self: *DeclGen, t: Type) error{ OutOfMemory, CodegenFail }!*const llvm.Type {
+        const gpa = self.gpa;
         log.debug("llvmType for {}", .{t});
         switch (t.zigTypeTag()) {
-            .Void => return self.context.voidType(),
-            .NoReturn => return self.context.voidType(),
+            .Void, .NoReturn => return self.context.voidType(),
             .Int => {
                 const info = t.intInfo(self.module.getTarget());
                 return self.context.intType(info.bits);
@@ -799,18 +819,38 @@ pub const DeclGen = struct {
                 return self.context.intType(16);
             },
             .Struct => {
+                const gop = try self.object.type_map.getOrPut(gpa, t);
+                if (gop.found_existing) return gop.value_ptr.*;
+
+                // The Type memory is ephemeral; since we want to store a longer-lived
+                // reference, we need to copy it here.
+                gop.key_ptr.* = try t.copy(&self.object.type_map_arena.allocator);
+
                 const struct_obj = t.castTag(.@"struct").?.data;
                 assert(struct_obj.haveFieldTypes());
-                const llvm_fields = try self.gpa.alloc(*const llvm.Type, struct_obj.fields.count());
-                defer self.gpa.free(llvm_fields);
-                for (struct_obj.fields.values()) |field, i| {
-                    llvm_fields[i] = try self.llvmType(field.ty);
+
+                const name = try struct_obj.getFullyQualifiedName(gpa);
+                defer gpa.free(name);
+
+                const llvm_struct_ty = self.context.structCreateNamed(name);
+                gop.value_ptr.* = llvm_struct_ty; // must be done before any recursive calls
+
+                var llvm_field_types: std.ArrayListUnmanaged(*const llvm.Type) = .{};
+                try llvm_field_types.ensureTotalCapacity(gpa, struct_obj.fields.count());
+                defer llvm_field_types.deinit(gpa);
+
+                for (struct_obj.fields.values()) |field| {
+                    if (!field.ty.hasCodeGenBits()) continue;
+                    llvm_field_types.appendAssumeCapacity(try self.llvmType(field.ty));
                 }
-                return self.context.structType(
-                    llvm_fields.ptr,
-                    @intCast(c_uint, llvm_fields.len),
-                    .False,
+
+                llvm_struct_ty.structSetBody(
+                    llvm_field_types.items.ptr,
+                    @intCast(c_uint, llvm_field_types.items.len),
+                    llvm.Bool.fromBool(struct_obj.layout == .Packed),
                 );
+
+                return llvm_struct_ty;
             },
             .Union => {
                 const union_obj = t.castTag(.@"union").?.data;
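
For a self-referential struct such as the `Node` shown earlier, the ordering in the `.Struct` case above is what breaks the cycle: the named-but-still-opaque LLVM type is published in `type_map` before any field type is lowered, so the recursive call terminates. A rough trace (hypothetical, for illustration only):

    // llvmType(Node): type_map miss
    //   -> structCreateNamed("...Node"), stored in type_map immediately
    //   -> lower `next: *Node`, which calls llvmType(Node) again
    //        -> type_map hit: returns the still-opaque named type
    //   -> structSetBody(...) fills in the body once all field types are lowered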
@@ -838,8 +878,8 @@ pub const DeclGen = struct {
             .Fn => {
                 const ret_ty = try self.llvmType(t.fnReturnType());
                 const params_len = t.fnParamLen();
-                const llvm_params = try self.gpa.alloc(*const llvm.Type, params_len);
-                defer self.gpa.free(llvm_params);
+                const llvm_params = try gpa.alloc(*const llvm.Type, params_len);
+                defer gpa.free(llvm_params);
                 for (llvm_params) |*llvm_param, i| {
                     llvm_param.* = try self.llvmType(t.fnParamType(i));
                 }
@@ -1073,21 +1113,26 @@ pub const DeclGen = struct {
                 return self.context.constStruct(&fields, fields.len, .False);
             },
             .Struct => {
-                const fields_len = tv.ty.structFieldCount();
+                const llvm_struct_ty = try self.llvmType(tv.ty);
                 const field_vals = tv.val.castTag(.@"struct").?.data;
                 const gpa = self.gpa;
-                const llvm_fields = try gpa.alloc(*const llvm.Value, fields_len);
-                defer gpa.free(llvm_fields);
-                for (llvm_fields) |*llvm_field, i| {
-                    llvm_field.* = try self.genTypedValue(.{
-                        .ty = tv.ty.structFieldType(i),
-                        .val = field_vals[i],
-                    });
+
+                var llvm_fields: std.ArrayListUnmanaged(*const llvm.Value) = .{};
+                try llvm_fields.ensureTotalCapacity(gpa, field_vals.len);
+                defer llvm_fields.deinit(gpa);
+
+                for (field_vals) |field_val, i| {
+                    const field_ty = tv.ty.structFieldType(i);
+                    if (!field_ty.hasCodeGenBits()) continue;
+
+                    llvm_fields.appendAssumeCapacity(try self.genTypedValue(.{
+                        .ty = field_ty,
+                        .val = field_val,
+                    }));
                 }
-                return self.context.constStruct(
-                    llvm_fields.ptr,
-                    @intCast(c_uint, llvm_fields.len),
-                    .False,
+                return llvm_struct_ty.constNamedStruct(
+                    llvm_fields.items.ptr,
+                    @intCast(c_uint, llvm_fields.items.len),
                 );
             },
             .ComptimeInt => unreachable,
@@ -1692,13 +1737,15 @@ pub const FuncGen = struct {
         const struct_field = self.air.extraData(Air.StructField, ty_pl.payload).data;
         const struct_ptr = try self.resolveInst(struct_field.struct_operand);
         const struct_ptr_ty = self.air.typeOf(struct_field.struct_operand);
-        const field_index = @intCast(c_uint, struct_field.field_index);
-        return self.fieldPtr(inst, struct_ptr, struct_ptr_ty, field_index);
+        return self.fieldPtr(inst, struct_ptr, struct_ptr_ty, struct_field.field_index);
     }
 
-    fn airStructFieldPtrIndex(self: *FuncGen, inst: Air.Inst.Index, field_index: c_uint) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
+    fn airStructFieldPtrIndex(
+        self: *FuncGen,
+        inst: Air.Inst.Index,
+        field_index: u32,
+    ) !?*const llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
 
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
         const struct_ptr = try self.resolveInst(ty_op.operand);
@@ -1707,13 +1754,13 @@ pub const FuncGen = struct {
     }
 
     fn airStructFieldVal(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
+        if (self.liveness.isUnused(inst)) return null;
 
         const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
         const struct_field = self.air.extraData(Air.StructField, ty_pl.payload).data;
+        const struct_ty = self.air.typeOf(struct_field.struct_operand);
         const struct_byval = try self.resolveInst(struct_field.struct_operand);
-        const field_index = @intCast(c_uint, struct_field.field_index);
+        const field_index = llvmFieldIndex(struct_ty, struct_field.field_index);
         return self.builder.buildExtractValue(struct_byval, field_index, "");
     }
 
@@ -2643,8 +2690,7 @@ pub const FuncGen = struct {
         const fill_char = if (val_is_undef) u8_llvm_ty.constInt(0xaa, .False) else value;
         const target = self.dg.module.getTarget();
         const dest_ptr_align = ptr_ty.ptrAlignment(target);
-        const memset = self.builder.buildMemSet(dest_ptr_u8, fill_char, len, dest_ptr_align);
-        memset.setVolatile(llvm.Bool.fromBool(ptr_ty.isVolatilePtr()));
+        _ = self.builder.buildMemSet(dest_ptr_u8, fill_char, len, dest_ptr_align, ptr_ty.isVolatilePtr());
 
         if (val_is_undef and self.dg.module.comp.bin_file.options.valgrind) {
             // TODO generate valgrind client request to mark byte range as undefined
@@ -2667,14 +2713,14 @@ pub const FuncGen = struct {
         const src_ptr_u8 = self.builder.buildBitCast(src_ptr, ptr_u8_llvm_ty, "");
         const is_volatile = src_ptr_ty.isVolatilePtr() or dest_ptr_ty.isVolatilePtr();
         const target = self.dg.module.getTarget();
-        const memcpy = self.builder.buildMemCpy(
+        _ = self.builder.buildMemCpy(
             dest_ptr_u8,
             dest_ptr_ty.ptrAlignment(target),
             src_ptr_u8,
             src_ptr_ty.ptrAlignment(target),
             len,
+            is_volatile,
         );
-        memcpy.setVolatile(llvm.Bool.fromBool(is_volatile));
         return null;
     }
 
@@ -2741,11 +2787,14 @@ pub const FuncGen = struct {
         inst: Air.Inst.Index,
         struct_ptr: *const llvm.Value,
         struct_ptr_ty: Type,
-        field_index: c_uint,
+        field_index: u32,
     ) !?*const llvm.Value {
         const struct_ty = struct_ptr_ty.childType();
         switch (struct_ty.zigTypeTag()) {
-            .Struct => return self.builder.buildStructGEP(struct_ptr, field_index, ""),
+            .Struct => {
+                const llvm_field_index = llvmFieldIndex(struct_ty, field_index);
+                return self.builder.buildStructGEP(struct_ptr, llvm_field_index, "");
+            },
             .Union => return self.unionFieldPtr(inst, struct_ptr, struct_ty, field_index),
             else => unreachable,
         }
@@ -2968,3 +3017,15 @@ fn toLlvmAtomicRmwBinOp(
         .Min => if (is_signed) llvm.AtomicRMWBinOp.Min else return .UMin,
     };
 }
+
+/// Take into account 0 bit fields.
+fn llvmFieldIndex(ty: Type, index: u32) c_uint {
+    const struct_obj = ty.castTag(.@"struct").?.data;
+    var result: c_uint = 0;
+    for (struct_obj.fields.values()[0..index]) |field| {
+        if (field.ty.hasCodeGenBits()) {
+            result += 1;
+        }
+    }
+    return result;
+}
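
Since zero-bit fields are no longer materialized in the LLVM struct body, Zig field indices and LLVM field indices can diverge; `llvmFieldIndex` counts only the preceding fields that have code-gen bits. A worked illustration (hypothetical type, not from the commit):

    // Assuming S = struct { a: void, b: i32, c: void, d: u8 }, whose LLVM body
    // is { i32, i8 }:
    //   llvmFieldIndex(S, 1) == 0   // `b`: the zero-bit `a` before it is not counted
    //   llvmFieldIndex(S, 3) == 1   // `d`: only `b` among the preceding fields counts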
src/Module.zig
@@ -813,7 +813,7 @@ pub const Struct = struct {
         is_comptime: bool,
     };
 
-    pub fn getFullyQualifiedName(s: *Struct, gpa: *Allocator) ![]u8 {
+    pub fn getFullyQualifiedName(s: *Struct, gpa: *Allocator) ![:0]u8 {
         return s.owner_decl.getFullyQualifiedName(gpa);
     }
 
src/type.zig
@@ -1785,6 +1785,8 @@ pub const Type = extern union {
                 if (is_packed) @panic("TODO packed structs");
                 var size: u64 = 0;
                 for (s.fields.values()) |field| {
+                    if (!field.ty.hasCodeGenBits()) continue;
+
                     const field_align = a: {
                         if (field.abi_align.tag() == .abi_align_default) {
                             break :a field.ty.abiAlignment(target);
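
This hunk skips zero-bit fields when summing a struct's ABI size, which is what the re-enabled "void struct fields" test below relies on. Worked through for that test's type (assuming a target where i32 has 4-byte size and alignment):

    // struct { a: void, b: i32, c: void }
    //   a: zero-bit, skipped      -> size stays 0
    //   b: i32, align 4, size 4   -> size becomes 4
    //   c: zero-bit, skipped      -> size stays 4
    //   @sizeOf(VoidStructFieldsFoo) == 4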
test/behavior/struct.zig
@@ -90,3 +90,42 @@ test "call member function directly" {
     const result = MemberFnTestFoo.member(instance);
     try expect(result == 1234);
 }
+
+test "struct point to self" {
+    var root: Node = undefined;
+    root.val.x = 1;
+
+    var node: Node = undefined;
+    node.next = &root;
+    node.val.x = 2;
+
+    root.next = &node;
+
+    try expect(node.next.next.next.val.x == 1);
+}
+
+test "void struct fields" {
+    const foo = VoidStructFieldsFoo{
+        .a = void{},
+        .b = 1,
+        .c = void{},
+    };
+    try expect(foo.b == 1);
+    try expect(@sizeOf(VoidStructFieldsFoo) == 4);
+}
+const VoidStructFieldsFoo = struct {
+    a: void,
+    b: i32,
+    c: void,
+};
+
+test "member functions" {
+    const r = MemberFnRand{ .seed = 1234 };
+    try expect(r.getSeed() == 1234);
+}
+const MemberFnRand = struct {
+    seed: u32,
+    pub fn getSeed(r: *const MemberFnRand) u32 {
+        return r.seed;
+    }
+};
test/behavior/struct_stage1.zig
@@ -16,21 +16,6 @@ test "top level fields" {
     try expectEqual(@as(i32, 1235), instance.top_level_field);
 }
 
-test "void struct fields" {
-    const foo = VoidStructFieldsFoo{
-        .a = void{},
-        .b = 1,
-        .c = void{},
-    };
-    try expect(foo.b == 1);
-    try expect(@sizeOf(VoidStructFieldsFoo) == 4);
-}
-const VoidStructFieldsFoo = struct {
-    a: void,
-    b: i32,
-    c: void,
-};
-
 const StructFoo = struct {
     a: i32,
     b: bool,
@@ -46,19 +31,6 @@ const Val = struct {
     x: i32,
 };
 
-test "struct point to self" {
-    var root: Node = undefined;
-    root.val.x = 1;
-
-    var node: Node = undefined;
-    node.next = &root;
-    node.val.x = 2;
-
-    root.next = &node;
-
-    try expect(node.next.next.next.val.x == 1);
-}
-
 test "fn call of struct field" {
     const Foo = struct {
         ptr: fn () i32,
@@ -89,17 +61,6 @@ test "store member function in variable" {
     try expect(result == 1234);
 }
 
-test "member functions" {
-    const r = MemberFnRand{ .seed = 1234 };
-    try expect(r.getSeed() == 1234);
-}
-const MemberFnRand = struct {
-    seed: u32,
-    pub fn getSeed(r: *const MemberFnRand) u32 {
-        return r.seed;
-    }
-};
-
 test "return struct byval from function" {
     const bar = makeBar2(1234, 5678);
     try expect(bar.y == 5678);