Commit 433c17aeb1

Andrew Kelley <superjoe30@gmail.com>
2016-12-11 20:27:37
IR: implement divExact builtin
1 parent 8fcb1a1
src/all_types.hpp
@@ -1409,6 +1409,7 @@ enum IrInstructionId {
     IrInstructionIdEmbedFile,
     IrInstructionIdCmpxchg,
     IrInstructionIdFence,
+    IrInstructionIdDivExact,
 };
 
 struct IrInstruction {
@@ -1884,6 +1885,13 @@ struct IrInstructionFence {
     AtomicOrder order;
 };
 
+struct IrInstructionDivExact {
+    IrInstruction base;
+
+    IrInstruction *op1;
+    IrInstruction *op2;
+};
+
 enum LValPurpose {
     LValPurposeNone,
     LValPurposeAssign,
src/codegen.cpp
@@ -1871,6 +1871,14 @@ static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutable *executable, IrInst
     return nullptr;
 }
 
+static LLVMValueRef ir_render_div_exact(CodeGen *g, IrExecutable *executable, IrInstructionDivExact *instruction) {
+    LLVMValueRef op1_val = ir_llvm_value(g, instruction->op1);
+    LLVMValueRef op2_val = ir_llvm_value(g, instruction->op2);
+
+    bool want_debug_safety = ir_want_debug_safety(g, &instruction->base);
+    return gen_div(g, want_debug_safety, op1_val, op2_val, instruction->base.type_entry, true);
+}
+
 static void set_debug_location(CodeGen *g, IrInstruction *instruction) {
     AstNode *source_node = instruction->source_node;
     Scope *scope = instruction->scope;
@@ -1964,6 +1972,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_cmpxchg(g, executable, (IrInstructionCmpxchg *)instruction);
         case IrInstructionIdFence:
             return ir_render_fence(g, executable, (IrInstructionFence *)instruction);
+        case IrInstructionIdDivExact:
+            return ir_render_div_exact(g, executable, (IrInstructionDivExact *)instruction);
         case IrInstructionIdSwitchVar:
         case IrInstructionIdContainerInitList:
         case IrInstructionIdStructInit:
src/ir.cpp
@@ -355,6 +355,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionFence *) {
     return IrInstructionIdFence;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionDivExact *) {
+    return IrInstructionIdDivExact;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrExecutable *exec, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -1413,6 +1417,23 @@ static IrInstruction *ir_build_fence_from(IrBuilder *irb, IrInstruction *old_ins
     return new_instruction;
 }
 
+static IrInstruction *ir_build_div_exact(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *op1, IrInstruction *op2) {
+    IrInstructionDivExact *instruction = ir_build_instruction<IrInstructionDivExact>(irb, scope, source_node);
+    instruction->op1 = op1;
+    instruction->op2 = op2;
+
+    ir_ref_instruction(op1);
+    ir_ref_instruction(op2);
+
+    return &instruction->base;
+}
+
+static IrInstruction *ir_build_div_exact_from(IrBuilder *irb, IrInstruction *old_instruction, IrInstruction *op1, IrInstruction *op2) {
+    IrInstruction *new_instruction = ir_build_div_exact(irb, old_instruction->scope, old_instruction->source_node, op1, op2);
+    ir_link_new_instruction(new_instruction, old_instruction);
+    return new_instruction;
+}
+
 static void ir_gen_defers_for_block(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope,
         bool gen_error_defers, bool gen_maybe_defers)
 {
@@ -2192,6 +2213,20 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
 
                 return ir_build_fence(irb, scope, node, arg0_value, AtomicOrderUnordered);
             }
+        case BuiltinFnIdDivExact:
+            {
+                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+                IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
+                if (arg0_value == irb->codegen->invalid_instruction)
+                    return arg0_value;
+
+                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+                IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope);
+                if (arg1_value == irb->codegen->invalid_instruction)
+                    return arg1_value;
+
+                return ir_build_div_exact(irb, scope, node, arg0_value, arg1_value);
+            }
         case BuiltinFnIdMemcpy:
         case BuiltinFnIdMemset:
         case BuiltinFnIdAlignof:
@@ -2203,7 +2238,6 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
         case BuiltinFnIdBreakpoint:
         case BuiltinFnIdReturnAddress:
         case BuiltinFnIdFrameAddress:
-        case BuiltinFnIdDivExact:
         case BuiltinFnIdTruncate:
         case BuiltinFnIdIntType:
             zig_panic("TODO IR gen more builtin functions");
@@ -7426,6 +7460,76 @@ static TypeTableEntry *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstructio
     return ira->codegen->builtin_types.entry_void;
 }
 
+static TypeTableEntry *ir_analyze_instruction_div_exact(IrAnalyze *ira, IrInstructionDivExact *instruction) {
+    IrInstruction *op1 = instruction->op1->other;
+    if (op1->type_entry->id == TypeTableEntryIdInvalid)
+        return ira->codegen->builtin_types.entry_invalid;
+
+    IrInstruction *op2 = instruction->op2->other;
+    if (op2->type_entry->id == TypeTableEntryIdInvalid)
+        return ira->codegen->builtin_types.entry_invalid;
+
+
+    IrInstruction *peer_instructions[] = { op1, op2 };
+    TypeTableEntry *result_type = ir_resolve_peer_types(ira, instruction->base.source_node, peer_instructions, 2);
+
+    if (result_type->id == TypeTableEntryIdInvalid)
+        return ira->codegen->builtin_types.entry_invalid;
+
+    TypeTableEntry *canon_type = get_underlying_type(result_type);
+
+    if (canon_type->id != TypeTableEntryIdInt &&
+        canon_type->id != TypeTableEntryIdNumLitInt)
+    {
+        ir_add_error(ira, &instruction->base,
+                buf_sprintf("expected integer type, found '%s'", buf_ptr(&result_type->name)));
+        // TODO if meta_type is type decl, add note pointing to type decl declaration
+        return ira->codegen->builtin_types.entry_invalid;
+    }
+
+    IrInstruction *casted_op1 = ir_get_casted_value(ira, op1, result_type);
+    if (casted_op1->type_entry->id == TypeTableEntryIdInvalid)
+        return ira->codegen->builtin_types.entry_invalid;
+
+    IrInstruction *casted_op2 = ir_get_casted_value(ira, op2, result_type);
+    if (casted_op2->type_entry->id == TypeTableEntryIdInvalid)
+        return ira->codegen->builtin_types.entry_invalid;
+
+    if (casted_op1->static_value.special == ConstValSpecialStatic &&
+        casted_op2->static_value.special == ConstValSpecialStatic)
+    {
+        ConstExprValue *op1_val = ir_resolve_const(ira, casted_op1);
+        ConstExprValue *op2_val = ir_resolve_const(ira, casted_op2);
+        assert(op1_val);
+        assert(op2_val);
+
+        if (op1_val->data.x_bignum.data.x_uint == 0) {
+            ir_add_error(ira, &instruction->base, buf_sprintf("division by zero"));
+            return ira->codegen->builtin_types.entry_invalid;
+        }
+
+        BigNum remainder;
+        if (bignum_mod(&remainder, &op1_val->data.x_bignum, &op2_val->data.x_bignum)) {
+            ir_add_error(ira, &instruction->base, buf_sprintf("integer overflow"));
+            return ira->codegen->builtin_types.entry_invalid;
+        }
+
+        if (remainder.data.x_uint != 0) {
+            ir_add_error(ira, &instruction->base, buf_sprintf("exact division had a remainder"));
+            return ira->codegen->builtin_types.entry_invalid;
+        }
+
+        bool depends_on_compile_var = casted_op1->static_value.depends_on_compile_var ||
+            casted_op2->static_value.depends_on_compile_var;
+        ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base, depends_on_compile_var);
+        bignum_div(&out_val->data.x_bignum, &op1_val->data.x_bignum, &op2_val->data.x_bignum);
+        return result_type;
+    }
+
+    ir_build_div_exact_from(&ira->new_irb, &instruction->base, casted_op1, casted_op2);
+    return result_type;
+}
+
 static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -7532,6 +7636,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
             return ir_analyze_instruction_cmpxchg(ira, (IrInstructionCmpxchg *)instruction);
         case IrInstructionIdFence:
             return ir_analyze_instruction_fence(ira, (IrInstructionFence *)instruction);
+        case IrInstructionIdDivExact:
+            return ir_analyze_instruction_div_exact(ira, (IrInstructionDivExact *)instruction);
         case IrInstructionIdCast:
         case IrInstructionIdStructFieldPtr:
         case IrInstructionIdEnumFieldPtr:
@@ -7669,6 +7775,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdMaxValue:
         case IrInstructionIdErrName:
         case IrInstructionIdEmbedFile:
+        case IrInstructionIdDivExact:
             return false;
         case IrInstructionIdAsm:
             {
@@ -7682,37 +7789,6 @@ bool ir_has_side_effects(IrInstruction *instruction) {
 // TODO port over all this commented out code into new IR way of doing things
 
 
-//static TypeTableEntry *analyze_div_exact(CodeGen *g, ImportTableEntry *import,
-//        BlockContext *context, AstNode *node)
-//{
-//    assert(node->type == NodeTypeFnCallExpr);
-//
-//    AstNode **op1 = &node->data.fn_call_expr.params.at(0);
-//    AstNode **op2 = &node->data.fn_call_expr.params.at(1);
-//
-//    TypeTableEntry *op1_type = analyze_expression(g, import, context, nullptr, *op1);
-//    TypeTableEntry *op2_type = analyze_expression(g, import, context, nullptr, *op2);
-//
-//    AstNode *op_nodes[] = {*op1, *op2};
-//    TypeTableEntry *op_types[] = {op1_type, op2_type};
-//    TypeTableEntry *result_type = resolve_peer_type_compatibility(g, import, context, node,
-//            op_nodes, op_types, 2);
-//
-//    if (result_type->id == TypeTableEntryIdInvalid) {
-//        return g->builtin_types.entry_invalid;
-//    } else if (result_type->id == TypeTableEntryIdInt) {
-//        return result_type;
-//    } else if (result_type->id == TypeTableEntryIdNumLitInt) {
-//        // check for division by zero
-//        // check for non exact division
-//        zig_panic("TODO");
-//    } else {
-//        add_node_error(g, node,
-//                buf_sprintf("expected integer type, found '%s'", buf_ptr(&result_type->name)));
-//        return g->builtin_types.entry_invalid;
-//    }
-//}
-//
 //static TypeTableEntry *analyze_truncate(CodeGen *g, ImportTableEntry *import,
 //        BlockContext *context, AstNode *node)
 //{
@@ -7920,8 +7996,6 @@ bool ir_has_side_effects(IrInstruction *instruction) {
 //        case BuiltinFnIdFrameAddress:
 //            mark_impure_fn(g, context, node);
 //            return builtin_fn->return_type;
-//        case BuiltinFnIdDivExact:
-//            return analyze_div_exact(g, import, context, node);
 //        case BuiltinFnIdTruncate:
 //            return analyze_truncate(g, import, context, node);
 //        case BuiltinFnIdIntType:
@@ -8243,18 +8317,6 @@ bool ir_has_side_effects(IrInstruction *instruction) {
 //
 
 
-//static LLVMValueRef gen_div_exact(CodeGen *g, AstNode *node) {
-//    assert(node->type == NodeTypeFnCallExpr);
-//
-//    AstNode *op1_node = node->data.fn_call_expr.params.at(0);
-//    AstNode *op2_node = node->data.fn_call_expr.params.at(1);
-//
-//    LLVMValueRef op1_val = gen_expr(g, op1_node);
-//    LLVMValueRef op2_val = gen_expr(g, op2_node);
-//
-//    return gen_div(g, node, op1_val, op2_val, get_expr_type(op1_node), true);
-//}
-//
 //static LLVMValueRef gen_truncate(CodeGen *g, AstNode *node) {
 //    assert(node->type == NodeTypeFnCallExpr);
 //
@@ -8421,8 +8483,6 @@ bool ir_has_side_effects(IrInstruction *instruction) {
 //            return gen_cmp_exchange(g, node);
 //        case BuiltinFnIdFence:
 //            return gen_fence(g, node);
-//        case BuiltinFnIdDivExact:
-//            return gen_div_exact(g, node);
 //        case BuiltinFnIdTruncate:
 //            return gen_truncate(g, node);
 //        case BuiltinFnIdUnreachable:
src/ir_print.cpp
@@ -739,6 +739,14 @@ static void ir_print_fence(IrPrint *irp, IrInstructionFence *instruction) {
     fprintf(irp->f, ")");
 }
 
+static void ir_print_div_exact(IrPrint *irp, IrInstructionDivExact *instruction) {
+    fprintf(irp->f, "@divExact(");
+    ir_print_other_instruction(irp, instruction->op1);
+    fprintf(irp->f, ", ");
+    ir_print_other_instruction(irp, instruction->op2);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -909,6 +917,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdFence:
             ir_print_fence(irp, (IrInstructionFence *)instruction);
             break;
+        case IrInstructionIdDivExact:
+            ir_print_div_exact(irp, (IrInstructionDivExact *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
test/self_hosted2.zig
@@ -285,6 +285,13 @@ fn fence() {
     x = 5678;
 }
 
+fn exactDivision() {
+    assert(divExact(55, 11) == 5);
+}
+fn divExact(a: u32, b: u32) -> u32 {
+    @divExact(a, b)
+}
+
 fn assert(ok: bool) {
     if (!ok)
         @unreachable();
@@ -314,6 +321,7 @@ fn runAllTests() {
     testErrorName();
     cmpxchg();
     fence();
+    exactDivision();
 }
 
 export nakedcc fn _start() -> unreachable {