Commit 0f02e29a2b

Josh Wolfe <thejoshwolfe@gmail.com>
2015-12-13 03:47:37
codegen and tests for modify operators. closes #16
1 parent 5cb5f5d
Changed files (3)
src/codegen.cpp
@@ -324,30 +324,33 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
     zig_unreachable();
 }
 
-static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
+static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
+    LLVMValueRef val1, LLVMValueRef val2,
+    TypeTableEntry *op1_type, TypeTableEntry *op2_type,
+    AstNode *node)
+{
     assert(node->type == NodeTypeBinOpExpr);
-
-    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
-    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
-
-    TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
-    TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
     assert(op1_type == op2_type);
 
     switch (node->data.bin_op_expr.bin_op) {
         case BinOpTypeBinOr:
+        case BinOpTypeAssignBitOr:
             add_debug_source_node(g, node);
             return LLVMBuildOr(g->builder, val1, val2, "");
         case BinOpTypeBinXor:
+        case BinOpTypeAssignBitXor:
             add_debug_source_node(g, node);
             return LLVMBuildXor(g->builder, val1, val2, "");
         case BinOpTypeBinAnd:
+        case BinOpTypeAssignBitAnd:
             add_debug_source_node(g, node);
             return LLVMBuildAnd(g->builder, val1, val2, "");
         case BinOpTypeBitShiftLeft:
+        case BinOpTypeAssignBitShiftLeft:
             add_debug_source_node(g, node);
             return LLVMBuildShl(g->builder, val1, val2, "");
         case BinOpTypeBitShiftRight:
+        case BinOpTypeAssignBitShiftRight:
             add_debug_source_node(g, node);
             if (op1_type->id == TypeTableEntryIdInt) {
                 return LLVMBuildAShr(g->builder, val1, val2, "");
@@ -355,6 +358,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
                 return LLVMBuildLShr(g->builder, val1, val2, "");
             }
         case BinOpTypeAdd:
+        case BinOpTypeAssignPlus:
             add_debug_source_node(g, node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFAdd(g->builder, val1, val2, "");
@@ -362,6 +366,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
                 return LLVMBuildNSWAdd(g->builder, val1, val2, "");
             }
         case BinOpTypeSub:
+        case BinOpTypeAssignMinus:
             add_debug_source_node(g, node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFSub(g->builder, val1, val2, "");
@@ -369,6 +374,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
                 return LLVMBuildNSWSub(g->builder, val1, val2, "");
             }
         case BinOpTypeMult:
+        case BinOpTypeAssignTimes:
             add_debug_source_node(g, node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFMul(g->builder, val1, val2, "");
@@ -376,6 +382,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
                 return LLVMBuildNSWMul(g->builder, val1, val2, "");
             }
         case BinOpTypeDiv:
+        case BinOpTypeAssignDiv:
             add_debug_source_node(g, node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFDiv(g->builder, val1, val2, "");
@@ -388,6 +395,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
                 }
             }
         case BinOpTypeMod:
+        case BinOpTypeAssignMod:
             add_debug_source_node(g, node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFRem(g->builder, val1, val2, "");
@@ -409,22 +417,23 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
         case BinOpTypeCmpGreaterOrEq:
         case BinOpTypeInvalid:
         case BinOpTypeAssign:
-        case BinOpTypeAssignTimes:
-        case BinOpTypeAssignDiv:
-        case BinOpTypeAssignMod:
-        case BinOpTypeAssignPlus:
-        case BinOpTypeAssignMinus:
-        case BinOpTypeAssignBitShiftLeft:
-        case BinOpTypeAssignBitShiftRight:
-        case BinOpTypeAssignBitAnd:
-        case BinOpTypeAssignBitXor:
-        case BinOpTypeAssignBitOr:
         case BinOpTypeAssignBoolAnd:
         case BinOpTypeAssignBoolOr:
             zig_unreachable();
     }
     zig_unreachable();
 }
+static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeBinOpExpr);
+
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
+
+    TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
+    TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
+    return gen_arithmetic_bin_op(g, val1, val2, op1_type, op2_type, node);
+
+}
 
 static LLVMIntPredicate cmp_op_to_int_predicate(BinOpType cmp_op, bool is_signed) {
     switch (cmp_op) {
@@ -555,11 +564,8 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
 
     AstNode *lhs_node = node->data.bin_op_expr.op1;
 
-    bool is_read_first = node->data.bin_op_expr.bin_op != BinOpTypeAssign;
-    if (is_read_first) {
-        zig_panic("TODO: implement modify assignment ops");
-    }
-
+    LLVMValueRef target_ref;
+    TypeTableEntry *op1_type;
     if (lhs_node->type == NodeTypeSymbol) {
         LocalVariableTableEntry *var = find_local_variable(node->codegen_node->expr_node.block_context,
                 &lhs_node->data.symbol);
@@ -567,33 +573,30 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
         // semantic checking ensures no variables are constant
         assert(!var->is_const);
 
-        LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
-
-        add_debug_source_node(g, node);
-        return LLVMBuildStore(g->builder, value, var->value_ref);
+        op1_type = var->type;
+        target_ref = var->value_ref;
     } else if (lhs_node->type == NodeTypeArrayAccessExpr) {
-        LLVMValueRef ptr = gen_array_ptr(g, lhs_node);
-        LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
-        add_debug_source_node(g, node);
-        return LLVMBuildStore(g->builder, value, ptr);
-    } else if (lhs_node->type == NodeTypeFieldAccessExpr) {
-        /*
-        LLVMValueRef ptr = gen_field_ptr(g, lhs_node);
-        LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
-        add_debug_source_node(g, node);
-        return LLVMBuildStore(g->builder, value, ptr);
-        */
-        LLVMValueRef struct_val = gen_expr(g, lhs_node->data.field_access_expr.struct_expr);
-        assert(struct_val);
-        FieldAccessNode *codegen_field_access = &lhs_node->codegen_node->data.field_access_node;
-        assert(codegen_field_access->field_index >= 0);
-
-        LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
-        add_debug_source_node(g, node);
-        return LLVMBuildInsertValue(g->builder, struct_val, value, codegen_field_access->field_index, "");
+        TypeTableEntry *array_type = get_expr_type(lhs_node->data.array_access_expr.array_ref_expr);
+        assert(array_type->id == TypeTableEntryIdArray);
+        op1_type = array_type->data.array.child_type;
+        target_ref = gen_array_ptr(g, lhs_node);
     } else {
         zig_panic("bad assign target");
     }
+    LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
+
+    if (node->data.bin_op_expr.bin_op == BinOpTypeAssign) {
+        // value is ready as is
+    } else {
+        add_debug_source_node(g, node->data.bin_op_expr.op1);
+        LLVMValueRef left_value = LLVMBuildLoad(g->builder, target_ref, "");
+
+        TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
+        value = gen_arithmetic_bin_op(g, left_value, value, op1_type, op2_type, node);
+    }
+
+    add_debug_source_node(g, node);
+    return LLVMBuildStore(g->builder, value, target_ref);
 }
 
 static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {
src/tokenizer.cpp
@@ -402,6 +402,7 @@ void tokenize(Buf *buf, Tokenization *out) {
                         t.cur_tok->id = TokenIdBitShiftRightEq;
                         end_token(&t);
                         t.state = TokenizeStateStart;
+                        break;
                     default:
                         t.pos -= 1;
                         end_token(&t);
@@ -415,6 +416,7 @@ void tokenize(Buf *buf, Tokenization *out) {
                         t.cur_tok->id = TokenIdCmpLessOrEq;
                         end_token(&t);
                         t.state = TokenizeStateStart;
+                        break;
                     case '<':
                         t.cur_tok->id = TokenIdBitShiftLeft;
                         t.state = TokenizeStateSawLessThanLessThan;
@@ -432,6 +434,7 @@ void tokenize(Buf *buf, Tokenization *out) {
                         t.cur_tok->id = TokenIdBitShiftLeftEq;
                         end_token(&t);
                         t.state = TokenizeStateStart;
+                        break;
                     default:
                         t.pos -= 1;
                         end_token(&t);
test/run_tests.cpp
@@ -454,6 +454,29 @@ export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
 }
     )SOURCE", "OK 1\nOK 2\nOK 3\nOK 4\n");
 
+    add_simple_case("modify operators", R"SOURCE(
+use "std.zig";
+
+export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
+    let mut i : i32 = 0;
+    i += 5;  if i != 5  { print_str("BAD +=\n" as string); }
+    i -= 2;  if i != 3  { print_str("BAD -=\n" as string); }
+    i *= 20; if i != 60 { print_str("BAD *=\n" as string); }
+    i /= 3;  if i != 20 { print_str("BAD /=\n" as string); }
+    i %= 11; if i != 9  { print_str("BAD %=\n" as string); }
+    i <<= 1; if i != 18 { print_str("BAD <<=\n" as string); }
+    i >>= 2; if i != 4  { print_str("BAD >>=\n" as string); }
+    i = 6;
+    i &= 5;  if i != 4  { print_str("BAD &=\n" as string); }
+    i ^= 6;  if i != 2  { print_str("BAD ^=\n" as string); }
+    i = 6;
+    i |= 3;  if i != 7  { print_str("BAD |=\n" as string); }
+
+    print_str("OK\n" as string);
+    return 0;
+}
+    )SOURCE", "OK\n");
+
 }
 
 static void add_compile_failure_test_cases(void) {