Commit d4ec0279d3

Timon Kruiper <timonkruiper@gmail.com>
2021-01-09 16:22:43
stage2: add support for optionals in the LLVM backend
We can now codegen optionals! This includes the following instructions: - is_null - is_null_ptr - is_non_null - is_non_null_ptr - optional_payload - optional_payload_ptr - br_void Also includes a test for optionals.
1 parent 3ad9cb8
Changed files (4)
src
test
stage2
src/codegen/llvm/bindings.zig
@@ -21,9 +21,15 @@ pub const Context = opaque {
     pub const voidType = LLVMVoidTypeInContext;
     extern fn LLVMVoidTypeInContext(C: *const Context) *const Type;
 
+    pub const structType = LLVMStructTypeInContext;
+    extern fn LLVMStructTypeInContext(C: *const Context, ElementTypes: [*]*const Type, ElementCount: c_uint, Packed: LLVMBool) *const Type;
+
     pub const constString = LLVMConstStringInContext;
     extern fn LLVMConstStringInContext(C: *const Context, Str: [*]const u8, Length: c_uint, DontNullTerminate: LLVMBool) *const Value;
 
+    pub const constStruct = LLVMConstStructInContext;
+    extern fn LLVMConstStructInContext(C: *const Context, ConstantVals: [*]*const Value, Count: c_uint, Packed: LLVMBool) *const Value;
+
     pub const createBasicBlock = LLVMCreateBasicBlockInContext;
     extern fn LLVMCreateBasicBlockInContext(C: *const Context, Name: [*:0]const u8) *const BasicBlock;
 
@@ -204,6 +210,9 @@ pub const Builder = opaque {
 
     pub const buildPhi = LLVMBuildPhi;
     extern fn LLVMBuildPhi(*const Builder, Ty: *const Type, Name: [*:0]const u8) *const Value;
+
+    pub const buildExtractValue = LLVMBuildExtractValue;
+    extern fn LLVMBuildExtractValue(*const Builder, AggVal: *const Value, Index: c_uint, Name: [*:0]const u8) *const Value;
 };
 
 pub const IntPredicate = extern enum {
src/codegen/llvm.zig
@@ -397,6 +397,7 @@ pub const LLVMIRModule = struct {
                 .block => try self.genBlock(inst.castTag(.block).?),
                 .br => try self.genBr(inst.castTag(.br).?),
                 .breakpoint => try self.genBreakpoint(inst.castTag(.breakpoint).?),
+                .br_void => try self.genBrVoid(inst.castTag(.br_void).?),
                 .call => try self.genCall(inst.castTag(.call).?),
                 .cmp_eq => try self.genCmp(inst.castTag(.cmp_eq).?, .eq),
                 .cmp_gt => try self.genCmp(inst.castTag(.cmp_gt).?, .gt),
@@ -406,6 +407,10 @@ pub const LLVMIRModule = struct {
                 .cmp_neq => try self.genCmp(inst.castTag(.cmp_neq).?, .neq),
                 .condbr => try self.genCondBr(inst.castTag(.condbr).?),
                 .intcast => try self.genIntCast(inst.castTag(.intcast).?),
+                .is_non_null => try self.genIsNonNull(inst.castTag(.is_non_null).?, false),
+                .is_non_null_ptr => try self.genIsNonNull(inst.castTag(.is_non_null_ptr).?, true),
+                .is_null => try self.genIsNull(inst.castTag(.is_null).?, false),
+                .is_null_ptr => try self.genIsNull(inst.castTag(.is_null_ptr).?, true),
                 .load => try self.genLoad(inst.castTag(.load).?),
                 .loop => try self.genLoop(inst.castTag(.loop).?),
                 .not => try self.genNot(inst.castTag(.not).?),
@@ -414,6 +419,8 @@ pub const LLVMIRModule = struct {
                 .store => try self.genStore(inst.castTag(.store).?),
                 .sub => try self.genSub(inst.castTag(.sub).?),
                 .unreach => self.genUnreach(inst.castTag(.unreach).?),
+                .optional_payload => try self.genOptionalPayload(inst.castTag(.optional_payload).?, false),
+                .optional_payload_ptr => try self.genOptionalPayload(inst.castTag(.optional_payload_ptr).?, true),
                 .dbg_stmt => blk: {
                     // TODO: implement debug info
                     break :blk null;
@@ -534,21 +541,29 @@ pub const LLVMIRModule = struct {
     }
 
     fn genBr(self: *LLVMIRModule, inst: *Inst.Br) !?*const llvm.Value {
-        // Get the block that we want to break to.
         var block = self.blocks.get(inst.block).?;
-        _ = self.builder.buildBr(block.parent_bb);
 
         // If the break doesn't break a value, then we don't have to add
         // the values to the lists.
-        if (!inst.operand.ty.hasCodeGenBits()) return null;
+        if (!inst.operand.ty.hasCodeGenBits()) {
+            // TODO: in astgen these instructions should turn into `br_void` instructions.
+            _ = self.builder.buildBr(block.parent_bb);
+        } else {
+            const val = try self.resolveInst(inst.operand);
 
-        // For the phi node, we need the basic blocks and the values of the
-        // break instructions.
-        try block.break_bbs.append(self.gpa, self.builder.getInsertBlock());
+            // For the phi node, we need the basic blocks and the values of the
+            // break instructions.
+            try block.break_bbs.append(self.gpa, self.builder.getInsertBlock());
+            try block.break_vals.append(self.gpa, val);
 
-        const val = try self.resolveInst(inst.operand);
-        try block.break_vals.append(self.gpa, val);
+            _ = self.builder.buildBr(block.parent_bb);
+        }
+        return null;
+    }
 
+    fn genBrVoid(self: *LLVMIRModule, inst: *Inst.BrVoid) !?*const llvm.Value {
+        var block = self.blocks.get(inst.block).?;
+        _ = self.builder.buildBr(block.parent_bb);
         return null;
     }
 
@@ -591,6 +606,44 @@ pub const LLVMIRModule = struct {
         return null;
     }
 
+    fn genIsNonNull(self: *LLVMIRModule, inst: *Inst.UnOp, operand_is_ptr: bool) !?*const llvm.Value {
+        const operand = try self.resolveInst(inst.operand);
+
+        if (operand_is_ptr) {
+            const index_type = self.context.intType(32);
+
+            var indices: [2]*const llvm.Value = .{
+                index_type.constNull(),
+                index_type.constInt(1, false),
+            };
+
+            return self.builder.buildLoad(self.builder.buildInBoundsGEP(operand, &indices, 2, ""), "");
+        } else {
+            return self.builder.buildExtractValue(operand, 1, "");
+        }
+    }
+
+    fn genIsNull(self: *LLVMIRModule, inst: *Inst.UnOp, operand_is_ptr: bool) !?*const llvm.Value {
+        return self.builder.buildNot((try self.genIsNonNull(inst, operand_is_ptr)).?, "");
+    }
+
+    fn genOptionalPayload(self: *LLVMIRModule, inst: *Inst.UnOp, operand_is_ptr: bool) !?*const llvm.Value {
+        const operand = try self.resolveInst(inst.operand);
+
+        if (operand_is_ptr) {
+            const index_type = self.context.intType(32);
+
+            var indices: [2]*const llvm.Value = .{
+                index_type.constNull(),
+                index_type.constNull(),
+            };
+
+            return self.builder.buildInBoundsGEP(operand, &indices, 2, "");
+        } else {
+            return self.builder.buildExtractValue(operand, 0, "");
+        }
+    }
+
     fn genAdd(self: *LLVMIRModule, inst: *Inst.BinOp) !?*const llvm.Value {
         const lhs = try self.resolveInst(inst.lhs);
         const rhs = try self.resolveInst(inst.rhs);
@@ -751,6 +804,13 @@ pub const LLVMIRModule = struct {
                     // TODO: consider using buildInBoundsGEP2 for opaque pointers
                     return self.builder.buildInBoundsGEP(val, &indices, 2, "");
                 },
+                .ref_val => {
+                    const elem_value = tv.val.castTag(.ref_val).?.data;
+                    const elem_type = tv.ty.castPointer().?.data;
+                    const alloca = self.buildAlloca(try self.getLLVMType(elem_type, src));
+                    _ = self.builder.buildStore(try self.genTypedValue(src, .{ .ty = elem_type, .val = elem_value }), alloca);
+                    return alloca;
+                },
                 else => return self.fail(src, "TODO implement const of pointer type '{}'", .{tv.ty}),
             },
             .Array => {
@@ -765,6 +825,29 @@ pub const LLVMIRModule = struct {
                     return self.fail(src, "TODO handle more array values", .{});
                 }
             },
+            .Optional => {
+                if (!tv.ty.isPtrLikeOptional()) {
+                    var buf: Type.Payload.ElemType = undefined;
+                    const child_type = tv.ty.optionalChild(&buf);
+                    const llvm_child_type = try self.getLLVMType(child_type, src);
+
+                    if (tv.val.tag() == .null_value) {
+                        var optional_values: [2]*const llvm.Value = .{
+                            llvm_child_type.constNull(),
+                            self.context.intType(1).constNull(),
+                        };
+                        return self.context.constStruct(&optional_values, 2, false);
+                    } else {
+                        var optional_values: [2]*const llvm.Value = .{
+                            try self.genTypedValue(src, .{ .ty = child_type, .val = tv.val }),
+                            self.context.intType(1).constAllOnes(),
+                        };
+                        return self.context.constStruct(&optional_values, 2, false);
+                    }
+                } else {
+                    return self.fail(src, "TODO implement const of optional pointer", .{});
+                }
+            },
             else => return self.fail(src, "TODO implement const of type '{}'", .{tv.ty}),
         }
     }
@@ -790,6 +873,20 @@ pub const LLVMIRModule = struct {
                 const elem_type = try self.getLLVMType(t.elemType(), src);
                 return elem_type.arrayType(@intCast(c_uint, t.abiSize(self.module.getTarget())));
             },
+            .Optional => {
+                if (!t.isPtrLikeOptional()) {
+                    var buf: Type.Payload.ElemType = undefined;
+                    const child_type = t.optionalChild(&buf);
+
+                    var optional_types: [2]*const llvm.Type = .{
+                        try self.getLLVMType(child_type, src),
+                        self.context.intType(1),
+                    };
+                    return self.context.structType(&optional_types, 2, false);
+                } else {
+                    return self.fail(src, "TODO implement optional pointers as actual pointers", .{});
+                }
+            },
             else => return self.fail(src, "TODO implement getLLVMType for type '{}'", .{t}),
         }
     }
src/astgen.zig
@@ -453,13 +453,23 @@ pub fn expr(mod: *Module, scope: *Scope, rl: ResultLoc, node: ast.Node.Index) In
             return rvalue(mod, scope, rl, result);
         },
         .unwrap_optional => {
-            const operand = try expr(mod, scope, rl, node_datas[node].lhs);
-            const op: zir.Inst.Tag = switch (rl) {
-                .ref => .optional_payload_safe_ptr,
-                else => .optional_payload_safe,
-            };
             const src = token_starts[main_tokens[node]];
-            return addZIRUnOp(mod, scope, src, op, operand);
+            switch (rl) {
+                .ref => return addZIRUnOp(
+                    mod,
+                    scope,
+                    src,
+                    .optional_payload_safe_ptr,
+                    try expr(mod, scope, .ref, node_datas[node].lhs),
+                ),
+                else => return rvalue(mod, scope, rl, try addZIRUnOp(
+                    mod,
+                    scope,
+                    src,
+                    .optional_payload_safe,
+                    try expr(mod, scope, .none, node_datas[node].lhs),
+                )),
+            }
         },
         .block_two, .block_two_semicolon => {
             const statements = [2]ast.Node.Index{ node_datas[node].lhs, node_datas[node].rhs };
@@ -1701,7 +1711,12 @@ fn orelseCatchExpr(
 
     // This could be a pointer or value depending on the `rl` parameter.
     block_scope.break_count += 1;
-    const operand = try expr(mod, &block_scope.base, block_scope.break_result_loc, lhs);
+    const operand = try expr(
+        mod,
+        &block_scope.base,
+        if (block_scope.break_result_loc == .ref) .ref else .none,
+        lhs,
+    );
     const cond = try addZIRUnOp(mod, &block_scope.base, src, cond_op, operand);
 
     const condbr = try addZIRInstSpecial(mod, &block_scope.base, src, zir.Inst.CondBr, .{
test/stage2/llvm.zig
@@ -132,4 +132,44 @@ pub fn addCases(ctx: *TestContext) !void {
             \\}
         , "");
     }
+
+    {
+        var case = ctx.exeUsingLlvmBackend("optionals", linux_x64);
+
+        case.addCompareOutput(
+            \\fn assert(ok: bool) void {
+            \\    if (!ok) unreachable;
+            \\}
+            \\
+            \\export fn main() c_int {
+            \\    var opt_val: ?i32 = 10;
+            \\    var null_val: ?i32 = null;
+            \\
+            \\    var val1: i32 = opt_val.?;
+            \\    const val1_1: i32 = opt_val.?;
+            \\    var ptr_val1 = &(opt_val.?);
+            \\    const ptr_val1_1 = &(opt_val.?);
+            \\
+            \\    var val2: i32 = null_val orelse 20;
+            \\    const val2_2: i32 = null_val orelse 20;
+            \\
+            \\    var value: i32 = 20;
+            \\    var ptr_val2 = &(null_val orelse value);
+            \\
+            \\    const val3 = opt_val orelse 30;
+            \\
+            \\    assert(val1 == 10);
+            \\    assert(val1_1 == 10);
+            \\    assert(ptr_val1.* == 10);
+            \\    assert(ptr_val1_1.* == 10);
+            \\
+            \\    assert(val2 == 20);
+            \\    assert(val2_2 == 20);
+            \\    assert(ptr_val2.* == 20);
+            \\
+            \\    assert(val3 == 10);
+            \\    return 0;
+            \\}
+        , "");
+    }
 }