Commit 56c059077c

Timon Kruiper <timonkruiper@gmail.com>
2021-01-08 19:28:34
stage2: add initial impl of control flow in LLVM backend
The following TZIR instructions have been implemented in the backend: - all cmp operators (lt, lte, gt, gte, eq, neq) - block - br - condbr The following LLVMIR is generated for a simple assert function: ``` define void @assert(i1 %0) { Entry: %1 = alloca i1, align 1 store i1 %0, i1* %1, align 1 %2 = load i1, i1* %1, align 1 %3 = xor i1 %2, true br i1 %3, label %Then, label %Else Then: ; preds = %Entry call void @llvm.debugtrap() unreachable Else: ; preds = %Entry br label %Block Block: ; preds = %Else ret void } ``` See tests for more examples.
1 parent 3715ed7
Changed files (3)
src
test
stage2
src/codegen/llvm/bindings.zig
@@ -24,6 +24,9 @@ pub const Context = opaque {
     pub const constString = LLVMConstStringInContext;
     extern fn LLVMConstStringInContext(C: *const Context, Str: [*]const u8, Length: c_uint, DontNullTerminate: LLVMBool) *const Value;
 
+    /// Creates a named basic block without taking a function argument (compare
+    /// `appendBasicBlock` below, which takes `Fn`). The backend's `genBlock` creates
+    /// its "Block" this way and attaches it afterwards with
+    /// `Value.appendExistingBasicBlock`.
+    pub const createBasicBlock = LLVMCreateBasicBlockInContext;
+    extern fn LLVMCreateBasicBlockInContext(C: *const Context, Name: [*:0]const u8) *const BasicBlock;
+
     pub const appendBasicBlock = LLVMAppendBasicBlockInContext;
     extern fn LLVMAppendBasicBlockInContext(C: *const Context, Fn: *const Value, Name: [*:0]const u8) *const BasicBlock;
 
@@ -38,6 +41,12 @@ pub const Value = opaque {
     pub const getFirstBasicBlock = LLVMGetFirstBasicBlock;
     extern fn LLVMGetFirstBasicBlock(Fn: *const Value) ?*const BasicBlock;
 
+    /// Attaches an already created basic block `BB` to the function `Fn`.
+    pub const appendExistingBasicBlock = LLVMAppendExistingBasicBlock;
+    extern fn LLVMAppendExistingBasicBlock(Fn: *const Value, BB: *const BasicBlock) void;
+
+    /// Adds `Count` incoming (value, block) pairs to `PhiNode`.
+    /// NOTE(review): `IncomingValues` and `IncomingBlocks` are presumably parallel
+    /// arrays that must both hold at least `Count` entries; confirm against the
+    /// LLVM-C documentation.
+    pub const addIncoming = LLVMAddIncoming;
+    extern fn LLVMAddIncoming(PhiNode: *const Value, IncomingValues: [*]*const Value, IncomingBlocks: [*]*const BasicBlock, Count: c_uint) void;
+
     pub const getNextInstruction = LLVMGetNextInstruction;
     extern fn LLVMGetNextInstruction(Inst: *const Value) ?*const Value;
 };
@@ -183,6 +192,31 @@ pub const Builder = opaque {
 
     pub const buildInBoundsGEP = LLVMBuildInBoundsGEP;
     extern fn LLVMBuildInBoundsGEP(B: *const Builder, Pointer: *const Value, Indices: [*]*const Value, NumIndices: c_uint, Name: [*:0]const u8) *const Value;
+
+    /// Binding for LLVM-C `LLVMBuildICmp`; `Op` selects the integer predicate.
+    pub const buildICmp = LLVMBuildICmp;
+    extern fn LLVMBuildICmp(B: *const Builder, Op: IntPredicate, LHS: *const Value, RHS: *const Value, Name: [*:0]const u8) *const Value;
+
+    /// Binding for LLVM-C `LLVMBuildBr`; `Dest` is the branch target.
+    pub const buildBr = LLVMBuildBr;
+    extern fn LLVMBuildBr(B: *const Builder, Dest: *const BasicBlock) *const Value;
+
+    /// Binding for LLVM-C `LLVMBuildCondBr`; `If` is the condition value, `Then`
+    /// and `Else` are the two branch targets.
+    pub const buildCondBr = LLVMBuildCondBr;
+    extern fn LLVMBuildCondBr(B: *const Builder, If: *const Value, Then: *const BasicBlock, Else: *const BasicBlock) *const Value;
+
+    /// Binding for LLVM-C `LLVMBuildPhi`; incoming values are attached afterwards
+    /// with `Value.addIncoming`.
+    pub const buildPhi = LLVMBuildPhi;
+    extern fn LLVMBuildPhi(B: *const Builder, Ty: *const Type, Name: [*:0]const u8) *const Value;
+};
+
+/// Integer comparison predicate passed as `Op` to `Builder.buildICmp`.
+/// The `U*` and `S*` variants are the unsigned and signed forms of the ordered
+/// comparisons; `EQ` and `NE` have a single form.
+/// NOTE(review): the numeric values presumably mirror LLVM-C's `LLVMIntPredicate`
+/// and must stay in sync with it; confirm against llvm-c/Core.h.
+pub const IntPredicate = extern enum {
+    EQ = 32,
+    NE = 33,
+    UGT = 34,
+    UGE = 35,
+    ULT = 36,
+    ULE = 37,
+    SGT = 38,
+    SGE = 39,
+    SLT = 40,
+    SLE = 41,
 };
 
 pub const BasicBlock = opaque {
src/codegen/llvm.zig
@@ -5,6 +5,7 @@ const Compilation = @import("../Compilation.zig");
 const llvm = @import("llvm/bindings.zig");
 const link = @import("../link.zig");
 const log = std.log.scoped(.codegen);
+const math = std.math;
 
 const Module = @import("../Module.zig");
 const TypedValue = @import("../TypedValue.zig");
@@ -154,6 +155,8 @@ pub const LLVMIRModule = struct {
 
     /// This stores the LLVM values used in a function, such that they can be
     /// referred to in other instructions. This table is cleared before every function is generated.
+    /// TODO: Change this to a stack of Branch. Currently we store all the values from all the blocks
+    /// in here, however if a block ends, the instructions can be thrown away.
     func_inst_table: std.AutoHashMapUnmanaged(*Inst, *const llvm.Value) = .{},
 
     /// These fields are used to refer to the LLVM value of the function parameters in an Arg instruction.
@@ -165,6 +168,18 @@ pub const LLVMIRModule = struct {
     /// to the top of the function.
     latest_alloca_inst: ?*const llvm.Value = null,
 
+    /// The LLVM function whose body is currently being generated. It is assigned
+    /// right before a function body is lowered and is `undefined` until then.
+    /// `genCondBr` and `genBlock` attach their basic blocks to it.
+    llvm_func: *const llvm.Value = undefined,
+
+    /// This data structure is used to implement breaking to blocks.
+    /// An entry is inserted by `genBlock` while it lowers the block's body and is
+    /// removed again when `genBlock` returns; `genBr` looks the target block up here.
+    blocks: std.AutoHashMapUnmanaged(*Inst.Block, struct {
+        /// The basic block that a break to this block branches to.
+        parent_bb: *const llvm.BasicBlock,
+        /// Basic blocks from which a value-carrying break was emitted. Points at a
+        /// list that lives in the `genBlock` call that inserted this entry.
+        break_bbs: *BreakBasicBlocks,
+        /// The broken values, in the same order as `break_bbs`. Points at a list
+        /// that lives in the `genBlock` call that inserted this entry.
+        break_vals: *BreakValues,
+    }) = .{},
+
+    /// List types backing `break_bbs` and `break_vals` above.
+    const BreakBasicBlocks = std.ArrayListUnmanaged(*const llvm.BasicBlock);
+    const BreakValues = std.ArrayListUnmanaged(*const llvm.Value);
+
     pub fn create(allocator: *Allocator, sub_path: []const u8, options: link.Options) !*LLVMIRModule {
         const self = try allocator.create(LLVMIRModule);
         errdefer allocator.destroy(self);
@@ -252,6 +267,8 @@ pub const LLVMIRModule = struct {
         self.func_inst_table.deinit(self.gpa);
         self.gpa.free(self.object_path);
 
+        self.blocks.deinit(self.gpa);
+
         allocator.destroy(self);
     }
 
@@ -349,32 +366,9 @@ pub const LLVMIRModule = struct {
             self.entry_block = self.context.appendBasicBlock(llvm_func, "Entry");
             self.builder.positionBuilderAtEnd(self.entry_block);
             self.latest_alloca_inst = null;
+            self.llvm_func = llvm_func;
 
-            const instructions = func.body.instructions;
-            for (instructions) |inst| {
-                const opt_llvm_val: ?*const llvm.Value = switch (inst.tag) {
-                    .add => try self.genAdd(inst.castTag(.add).?),
-                    .alloc => try self.genAlloc(inst.castTag(.alloc).?),
-                    .arg => try self.genArg(inst.castTag(.arg).?),
-                    .bitcast => try self.genBitCast(inst.castTag(.bitcast).?),
-                    .breakpoint => try self.genBreakpoint(inst.castTag(.breakpoint).?),
-                    .call => try self.genCall(inst.castTag(.call).?),
-                    .intcast => try self.genIntCast(inst.castTag(.intcast).?),
-                    .load => try self.genLoad(inst.castTag(.load).?),
-                    .not => try self.genNot(inst.castTag(.not).?),
-                    .ret => try self.genRet(inst.castTag(.ret).?),
-                    .retvoid => self.genRetVoid(inst.castTag(.retvoid).?),
-                    .store => try self.genStore(inst.castTag(.store).?),
-                    .sub => try self.genSub(inst.castTag(.sub).?),
-                    .unreach => self.genUnreach(inst.castTag(.unreach).?),
-                    .dbg_stmt => blk: {
-                        // TODO: implement debug info
-                        break :blk null;
-                    },
-                    else => |tag| return self.fail(src, "TODO implement LLVM codegen for Zir instruction: {}", .{tag}),
-                };
-                if (opt_llvm_val) |llvm_val| try self.func_inst_table.putNoClobber(self.gpa, inst, llvm_val);
-            }
+            try self.genBody(func.body);
         } else if (typed_value.val.castTag(.extern_fn)) |extern_fn| {
             _ = try self.resolveLLVMFunction(extern_fn.data, src);
         } else {
@@ -382,6 +376,42 @@ pub const LLVMIRModule = struct {
         }
     }
 
+    /// Lowers every instruction of `body`, in order, by dispatching on its tag.
+    /// When a generator yields an LLVM value it is recorded in `func_inst_table`
+    /// under the originating instruction. Tags without a generator fail with a
+    /// "TODO" codegen error reported at the instruction's source location.
+    fn genBody(self: *LLVMIRModule, body: ir.Body) error{ OutOfMemory, CodegenFail }!void {
+        for (body.instructions) |inst| {
+            const maybe_llvm_value = switch (inst.tag) {
+                .add => try self.genAdd(inst.castTag(.add).?),
+                .alloc => try self.genAlloc(inst.castTag(.alloc).?),
+                .arg => try self.genArg(inst.castTag(.arg).?),
+                .bitcast => try self.genBitCast(inst.castTag(.bitcast).?),
+                .block => try self.genBlock(inst.castTag(.block).?),
+                .br => try self.genBr(inst.castTag(.br).?),
+                .breakpoint => try self.genBreakpoint(inst.castTag(.breakpoint).?),
+                .call => try self.genCall(inst.castTag(.call).?),
+                .cmp_eq => try self.genCmp(inst.castTag(.cmp_eq).?, .eq),
+                .cmp_gt => try self.genCmp(inst.castTag(.cmp_gt).?, .gt),
+                .cmp_gte => try self.genCmp(inst.castTag(.cmp_gte).?, .gte),
+                .cmp_lt => try self.genCmp(inst.castTag(.cmp_lt).?, .lt),
+                .cmp_lte => try self.genCmp(inst.castTag(.cmp_lte).?, .lte),
+                .cmp_neq => try self.genCmp(inst.castTag(.cmp_neq).?, .neq),
+                .condbr => try self.genCondBr(inst.castTag(.condbr).?),
+                .intcast => try self.genIntCast(inst.castTag(.intcast).?),
+                .load => try self.genLoad(inst.castTag(.load).?),
+                .not => try self.genNot(inst.castTag(.not).?),
+                .ret => try self.genRet(inst.castTag(.ret).?),
+                .retvoid => self.genRetVoid(inst.castTag(.retvoid).?),
+                .store => try self.genStore(inst.castTag(.store).?),
+                .sub => try self.genSub(inst.castTag(.sub).?),
+                .unreach => self.genUnreach(inst.castTag(.unreach).?),
+                // TODO: implement debug info
+                .dbg_stmt => null,
+                else => |tag| return self.fail(inst.src, "TODO implement LLVM codegen for Zir instruction: {}", .{tag}),
+            };
+
+            // Instructions that produced no value have nothing to record.
+            const llvm_value = maybe_llvm_value orelse continue;
+            try self.func_inst_table.putNoClobber(self.gpa, inst, llvm_value);
+        }
+    }
+
     fn genCall(self: *LLVMIRModule, inst: *Inst.Call) !?*const llvm.Value {
         if (inst.func.value()) |func_value| {
             const fn_decl = if (func_value.castTag(.extern_fn)) |extern_fn|
@@ -436,6 +466,99 @@ pub const LLVMIRModule = struct {
         return null;
     }
 
+    /// Lowers a comparison instruction to an LLVM `icmp` and returns its result.
+    ///
+    /// The supported-type check and the signedness of the predicate are taken from
+    /// the operand type (`inst.lhs.ty`), not from `inst.base.ty`. The latter is the
+    /// type of the instruction's result (presumably always `bool` for a
+    /// comparison; confirm against semantic analysis), so it cannot distinguish
+    /// signed from unsigned operands, and ordered comparisons on signed integers
+    /// would be lowered with unsigned predicates.
+    /// NOTE(review): assumes `lhs` and `rhs` have already been coerced to a common
+    /// type, so the type of `lhs` is representative; confirm.
+    ///
+    /// Fails with a "TODO" codegen error when the operands are neither integers nor bool.
+    fn genCmp(self: *LLVMIRModule, inst: *Inst.BinOp, op: math.CompareOperator) !?*const llvm.Value {
+        const lhs = try self.resolveInst(inst.lhs);
+        const rhs = try self.resolveInst(inst.rhs);
+
+        const operand_ty = inst.lhs.ty;
+        if (!operand_ty.isInt() and operand_ty.tag() != .bool)
+            return self.fail(inst.base.src, "TODO implement 'genCmp' for type {}", .{operand_ty});
+
+        // Equality is sign-agnostic; the ordered predicates come in signed and
+        // unsigned flavours.
+        const is_signed = operand_ty.isSignedInt();
+        const operation = switch (op) {
+            .eq => .EQ,
+            .neq => .NE,
+            .lt => @as(llvm.IntPredicate, if (is_signed) .SLT else .ULT),
+            .lte => @as(llvm.IntPredicate, if (is_signed) .SLE else .ULE),
+            .gt => @as(llvm.IntPredicate, if (is_signed) .SGT else .UGT),
+            .gte => @as(llvm.IntPredicate, if (is_signed) .SGE else .UGE),
+        };
+
+        return self.builder.buildICmp(operation, lhs, rhs, "");
+    }
+
+    /// Lowers a block instruction.
+    ///
+    /// A fresh "Block" basic block is created as the target for every `br` that
+    /// breaks to this block. The block is registered in `self.blocks` while its
+    /// body is generated, so that `genBr` can find the target and record the
+    /// broken values. Afterwards the target block is attached to the current
+    /// function and the builder is positioned at its end.
+    ///
+    /// Returns null when the block's type has no codegen bits; otherwise returns
+    /// a phi node merging the values recorded by the breaks.
+    ///
+    /// NOTE(review): if generating the body fails, `parent_bb` is never attached to
+    /// the function; whether it then needs explicit deletion is not visible here.
+    fn genBlock(self: *LLVMIRModule, inst: *Inst.Block) !?*const llvm.Value {
+        const parent_bb = self.context.createBasicBlock("Block");
+
+        // 5 breaks to a block seems like a reasonable default.
+        // Each list gets its cleanup registered immediately, so a failure of a later
+        // allocation (the second list or the map insertion) cannot leak it.
+        var break_bbs = try BreakBasicBlocks.initCapacity(self.gpa, 5);
+        defer break_bbs.deinit(self.gpa);
+        var break_vals = try BreakValues.initCapacity(self.gpa, 5);
+        defer break_vals.deinit(self.gpa);
+
+        try self.blocks.putNoClobber(self.gpa, inst, .{
+            .parent_bb = parent_bb,
+            .break_bbs = &break_bbs,
+            .break_vals = &break_vals,
+        });
+        // Defers run in reverse order: the map entry, which points at the two
+        // lists above, is removed before the lists themselves are freed.
+        defer self.blocks.removeAssertDiscard(inst);
+
+        try self.genBody(inst.body);
+
+        self.llvm_func.appendExistingBasicBlock(parent_bb);
+        self.builder.positionBuilderAtEnd(parent_bb);
+
+        // If the block does not return a value, we don't have to create a phi node.
+        if (!inst.base.ty.hasCodeGenBits()) return null;
+
+        const phi_node = self.builder.buildPhi(try self.getLLVMType(inst.base.ty, inst.base.src), "");
+        phi_node.addIncoming(
+            break_vals.items.ptr,
+            break_bbs.items.ptr,
+            @intCast(c_uint, break_vals.items.len),
+        );
+        return phi_node;
+    }
+
+    /// Lowers a `br` instruction: an unconditional branch to the basic block of
+    /// the enclosing block instruction it breaks to.
+    ///
+    /// When the operand's type has codegen bits, the operand value and the basic
+    /// block the branch is emitted from are appended to the target's break lists,
+    /// from which `genBlock` later builds the phi node.
+    /// Always returns null: the branch itself produces no value.
+    fn genBr(self: *LLVMIRModule, inst: *Inst.Br) !?*const llvm.Value {
+        // Get the block that we want to break to.
+        const block = self.blocks.get(inst.block).?;
+
+        // If the break doesn't break a value, then we don't have to add
+        // the values to the lists.
+        if (inst.operand.ty.hasCodeGenBits()) {
+            // Resolve the operand before emitting the terminator. `resolveInst` is
+            // not visible here; should it ever need to emit instructions, they must
+            // not end up after the `br` in this basic block.
+            const val = try self.resolveInst(inst.operand);
+
+            // For the phi node, we need the basic blocks and the values of the
+            // break instructions.
+            try block.break_bbs.append(self.gpa, self.builder.getInsertBlock());
+            try block.break_vals.append(self.gpa, val);
+        }
+
+        _ = self.builder.buildBr(block.parent_bb);
+        return null;
+    }
+
+    /// Lowers a conditional branch.
+    ///
+    /// "Then" and "Else" basic blocks are appended to the current function and
+    /// each body is generated into its block. The builder is then moved back to
+    /// the block it was in on entry (also when generating a body fails), where the
+    /// conditional branch on the resolved condition is emitted.
+    /// Always returns null: the branch produces no value.
+    fn genCondBr(self: *LLVMIRModule, inst: *Inst.CondBr) !?*const llvm.Value {
+        const cond = try self.resolveInst(inst.condition);
+
+        const then_bb = self.context.appendBasicBlock(self.llvm_func, "Then");
+        const else_bb = self.context.appendBasicBlock(self.llvm_func, "Else");
+
+        // Generate both arms first. The scope's defer returns the builder to the
+        // originating block, so the branch below is emitted there.
+        {
+            const origin_bb = self.builder.getInsertBlock();
+            defer self.builder.positionBuilderAtEnd(origin_bb);
+
+            self.builder.positionBuilderAtEnd(then_bb);
+            try self.genBody(inst.then_body);
+
+            self.builder.positionBuilderAtEnd(else_bb);
+            try self.genBody(inst.else_body);
+        }
+
+        _ = self.builder.buildCondBr(cond, then_bb, else_bb);
+        return null;
+    }
+
     fn genNot(self: *LLVMIRModule, inst: *Inst.UnOp) !?*const llvm.Value {
         return self.builder.buildNot(try self.resolveInst(inst.operand), "");
     }
@@ -509,6 +632,9 @@ pub const LLVMIRModule = struct {
     /// Use this instead of builder.buildAlloca, because this function makes sure to
     /// put the alloca instruction at the top of the function!
     fn buildAlloca(self: *LLVMIRModule, t: *const llvm.Type) *const llvm.Value {
+        const prev_block = self.builder.getInsertBlock();
+        defer self.builder.positionBuilderAtEnd(prev_block);
+
         if (self.latest_alloca_inst) |latest_alloc| {
             // builder.positionBuilder adds it before the instruction,
             // but we want to put it after the last alloca instruction.
@@ -521,7 +647,6 @@ pub const LLVMIRModule = struct {
                 self.builder.positionBuilder(self.entry_block, first_inst);
             }
         }
-        defer self.builder.positionBuilderAtEnd(self.entry_block);
 
         const val = self.builder.buildAlloca(t, "");
         self.latest_alloca_inst = val;
test/stage2/llvm.zig
@@ -40,4 +40,75 @@ pub fn addCases(ctx: *TestContext) !void {
             \\}
         , "hello world!" ++ std.cstr.line_sep);
     }
+
+    {
+        var case = ctx.exeUsingLlvmBackend("simple if statement", linux_x64);
+
+        case.addCompareOutput(
+            \\fn add(a: i32, b: i32) i32 {
+            \\    return a + b;
+            \\}
+            \\
+            \\fn assert(ok: bool) void {
+            \\    if (!ok) unreachable;
+            \\}
+            \\
+            \\export fn main() c_int {
+            \\    assert(add(1,2) == 3);
+            \\    return 0;
+            \\}
+        , "");
+    }
+
+    {
+        var case = ctx.exeUsingLlvmBackend("blocks", linux_x64);
+
+        case.addCompareOutput(
+            \\fn assert(ok: bool) void {
+            \\    if (!ok) unreachable;
+            \\}
+            \\
+            \\fn foo(ok: bool) i32 {
+            \\    const val: i32 = blk: {
+            \\        var x: i32 = 1;
+            \\        if (!ok) break :blk x + 9;
+            \\        break :blk x + 19;
+            \\    };
+            \\    return val + 10;
+            \\}
+            \\
+            \\export fn main() c_int {
+            \\    assert(foo(false) == 20);
+            \\    assert(foo(true) == 30);
+            \\    return 0;
+            \\}
+        , "");
+    }
+
+    {
+        var case = ctx.exeUsingLlvmBackend("nested blocks", linux_x64);
+
+        case.addCompareOutput(
+            \\fn assert(ok: bool) void {
+            \\    if (!ok) unreachable;
+            \\}
+            \\
+            \\fn foo(ok: bool) i32 {
+            \\    var val: i32 = blk: {
+            \\        const val2: i32 = another: {
+            \\            if (!ok) break :blk 10;
+            \\            break :another 10;
+            \\        };
+            \\        break :blk val2 + 10;
+            \\    };
+            \\    return val;
+            \\}
+            \\
+            \\export fn main() c_int {
+            \\    assert(foo(false) == 10);
+            \\    assert(foo(true) == 20);
+            \\    return 0;
+            \\}
+        , "");
+    }
 }