Commit 2f0e4e9cb2

Andrew Kelley <superjoe30@gmail.com>
2015-12-08 20:25:30
codegen does signed, unsigned, and floating point math
1 parent 3e06ed0
src/codegen.cpp
@@ -210,6 +210,10 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
     LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
     LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
+    TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
+    TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
+    assert(op1_type == op2_type);
+
     switch (node->data.bin_op_expr.bin_op) {
         case BinOpTypeBinOr:
             add_debug_source_node(g, node);
@@ -224,29 +228,51 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
             add_debug_source_node(g, node);
             return LLVMBuildShl(g->builder, val1, val2, "");
         case BinOpTypeBitShiftRight:
-            // TODO implement type system so that we know whether to do
-            // logical or arithmetic shifting here.
-            // signed -> arithmetic, unsigned -> logical
             add_debug_source_node(g, node);
-            return LLVMBuildLShr(g->builder, val1, val2, "");
+            if (op1_type->is_signed_int) {
+                return LLVMBuildAShr(g->builder, val1, val2, "");
+            } else {
+                return LLVMBuildLShr(g->builder, val1, val2, "");
+            }
         case BinOpTypeAdd:
             add_debug_source_node(g, node);
-            return LLVMBuildAdd(g->builder, val1, val2, "");
+            if (op1_type->is_float) {
+                return LLVMBuildFAdd(g->builder, val1, val2, "");
+            } else {
+                return LLVMBuildNSWAdd(g->builder, val1, val2, "");
+            }
         case BinOpTypeSub:
             add_debug_source_node(g, node);
-            return LLVMBuildSub(g->builder, val1, val2, "");
+            if (op1_type->is_float) {
+                return LLVMBuildFSub(g->builder, val1, val2, "");
+            } else {
+                return LLVMBuildNSWSub(g->builder, val1, val2, "");
+            }
         case BinOpTypeMult:
-            // TODO types so we know float vs int
             add_debug_source_node(g, node);
-            return LLVMBuildMul(g->builder, val1, val2, "");
+            if (op1_type->is_float) {
+                return LLVMBuildFMul(g->builder, val1, val2, "");
+            } else {
+                return LLVMBuildNSWMul(g->builder, val1, val2, "");
+            }
         case BinOpTypeDiv:
-            // TODO types so we know float vs int and signed vs unsigned
             add_debug_source_node(g, node);
-            return LLVMBuildSDiv(g->builder, val1, val2, "");
+            if (op1_type->is_float) {
+                return LLVMBuildFDiv(g->builder, val1, val2, "");
+            } else if (op1_type->is_signed_int) {
+                return LLVMBuildSDiv(g->builder, val1, val2, "");
+            } else {
+                return LLVMBuildUDiv(g->builder, val1, val2, "");
+            }
         case BinOpTypeMod:
-            // TODO types so we know float vs int and signed vs unsigned
             add_debug_source_node(g, node);
-            return LLVMBuildSRem(g->builder, val1, val2, "");
+            if (op1_type->is_float) {
+                return LLVMBuildFRem(g->builder, val1, val2, "");
+            } else if (op1_type->is_signed_int) {
+                return LLVMBuildSRem(g->builder, val1, val2, "");
+            } else {
+                return LLVMBuildURem(g->builder, val1, val2, "");
+            }
         case BinOpTypeBoolOr:
         case BinOpTypeBoolAnd:
         case BinOpTypeCmpEq:
@@ -281,16 +307,43 @@ static LLVMIntPredicate cmp_op_to_int_predicate(BinOpType cmp_op, bool is_signed
     }
 }
 
+static LLVMRealPredicate cmp_op_to_real_predicate(BinOpType cmp_op) {
+    switch (cmp_op) {
+        case BinOpTypeCmpEq:
+            return LLVMRealOEQ;
+        case BinOpTypeCmpNotEq:
+            return LLVMRealONE;
+        case BinOpTypeCmpLessThan:
+            return LLVMRealOLT;
+        case BinOpTypeCmpGreaterThan:
+            return LLVMRealOGT;
+        case BinOpTypeCmpLessOrEq:
+            return LLVMRealOLE;
+        case BinOpTypeCmpGreaterOrEq:
+            return LLVMRealOGE;
+        default:
+            zig_unreachable();
+    }
+}
+
 static LLVMValueRef gen_cmp_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeBinOpExpr);
 
     LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
     LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
-    // TODO implement type system so that we know whether to do signed or unsigned comparison here
-    LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, true);
+    TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
+    TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
+    assert(op1_type == op2_type);
+
     add_debug_source_node(g, node);
-    return LLVMBuildICmp(g->builder, pred, val1, val2, "");
+    if (op1_type->is_float) {
+        LLVMRealPredicate pred = cmp_op_to_real_predicate(node->data.bin_op_expr.bin_op);
+        return LLVMBuildFCmp(g->builder, pred, val1, val2, "");
+    } else {
+        LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, op1_type->is_signed_int);
+        return LLVMBuildICmp(g->builder, pred, val1, val2, "");
+    }
 }
 
 static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
@@ -847,12 +900,26 @@ static void define_primitive_types(CodeGen *g) {
         buf_init_from_str(&entry->name, "i32");
         entry->size_in_bits = 32;
         entry->align_in_bits = 32;
+        entry->is_signed_int = true;
         entry->di_type = LLVMZigCreateDebugBasicType(g->dbuilder, buf_ptr(&entry->name),
                 entry->size_in_bits, entry->align_in_bits,
                 LLVMZigEncoding_DW_ATE_signed());
         g->type_table.put(&entry->name, entry);
         g->builtin_types.entry_i32 = entry;
     }
+    {
+        TypeTableEntry *entry = new_type_table_entry();
+        entry->type_ref = LLVMFloatType();
+        buf_init_from_str(&entry->name, "f32");
+        entry->size_in_bits = 32;
+        entry->align_in_bits = 32;
+        entry->is_float = true;
+        entry->di_type = LLVMZigCreateDebugBasicType(g->dbuilder, buf_ptr(&entry->name),
+                entry->size_in_bits, entry->align_in_bits,
+                LLVMZigEncoding_DW_ATE_float());
+        g->type_table.put(&entry->name, entry);
+        g->builtin_types.entry_f32 = entry;
+    }
     {
         TypeTableEntry *entry = new_type_table_entry();
         entry->type_ref = LLVMVoidType();
@@ -918,6 +985,8 @@ static void init(CodeGen *g, Buf *source_path) {
     g->builder = LLVMCreateBuilder();
     g->dbuilder = LLVMZigCreateDIBuilder(g->module, true);
 
+    LLVMZigSetFastMath(g->builder, true);
+
 
     define_primitive_types(g);
 
@@ -1058,6 +1127,8 @@ static void to_c_type(CodeGen *g, AstNode *type_node, Buf *out_buf) {
     } else if (type_entry == g->builtin_types.entry_i32) {
         g->c_stdint_used = true;
         buf_init_from_str(out_buf, "int32_t");
+    } else if (type_entry == g->builtin_types.entry_f32) {
+        buf_init_from_str(out_buf, "float");
     } else if (type_entry == g->builtin_types.entry_unreachable) {
         buf_init_from_str(out_buf, "__attribute__((__noreturn__)) void");
     } else if (type_entry == g->builtin_types.entry_bool) {
src/semantic_info.hpp
@@ -21,6 +21,8 @@ struct TypeTableEntry {
     LLVMZigDIType *di_type;
     uint64_t size_in_bits;
     uint64_t align_in_bits;
+    bool is_signed_int;
+    bool is_float;
 
     TypeTableEntry *pointer_child;
     bool pointer_is_const;
@@ -82,6 +84,7 @@ struct CodeGen {
         TypeTableEntry *entry_bool;
         TypeTableEntry *entry_u8;
         TypeTableEntry *entry_i32;
+        TypeTableEntry *entry_f32;
         TypeTableEntry *entry_string_literal;
         TypeTableEntry *entry_void;
         TypeTableEntry *entry_unreachable;
src/zig_llvm.cpp
@@ -185,6 +185,10 @@ unsigned LLVMZigEncoding_DW_ATE_signed(void) {
     return dwarf::DW_ATE_signed;
 }
 
+unsigned LLVMZigEncoding_DW_ATE_float(void) {
+    return dwarf::DW_ATE_float;
+}
+
 unsigned LLVMZigLang_DW_LANG_C99(void) {
     return dwarf::DW_LANG_C99;
 }
@@ -322,6 +326,16 @@ LLVMZigDILocation *LLVMZigGetDebugLoc(unsigned line, unsigned col, LLVMZigDIScop
     return reinterpret_cast<LLVMZigDILocation*>(debug_loc.get());
 }
 
+void LLVMZigSetFastMath(LLVMBuilderRef builder_wrapped, bool on_state) {
+    if (on_state) {
+        FastMathFlags fmf;
+        fmf.setUnsafeAlgebra();
+        unwrap(builder_wrapped)->SetFastMathFlags(fmf);
+    } else {
+        unwrap(builder_wrapped)->clearFastMathFlags();
+    }
+}
+
 //------------------------------------
 
 enum FloatAbi {
src/zig_llvm.hpp
@@ -55,6 +55,7 @@ LLVMZigDISubroutineType *LLVMZigCreateSubroutineType(LLVMZigDIBuilder *dibuilder
 
 unsigned LLVMZigEncoding_DW_ATE_unsigned(void);
 unsigned LLVMZigEncoding_DW_ATE_signed(void);
+unsigned LLVMZigEncoding_DW_ATE_float(void);
 unsigned LLVMZigLang_DW_LANG_C99(void);
 unsigned LLVMZigTag_DW_auto_variable(void);
 unsigned LLVMZigTag_DW_arg_variable(void);
@@ -96,6 +97,8 @@ LLVMValueRef LLVMZigInsertDeclare(LLVMZigDIBuilder *dibuilder, LLVMValueRef stor
         LLVMZigDILocalVariable *var_info, LLVMZigDILocation *debug_loc, LLVMValueRef insert_before_instr);
 LLVMZigDILocation *LLVMZigGetDebugLoc(unsigned line, unsigned col, LLVMZigDIScope *scope);
 
+void LLVMZigSetFastMath(LLVMBuilderRef builder_wrapped, bool on_state);
+
 
 /*
  * This stuff is not LLVM API but it depends on the LLVM C++ API so we put it here.