Commit 0f02e29a2b
Changed files (3)
src/codegen.cpp
@@ -324,30 +324,33 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
zig_unreachable();
}
-static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
+static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
+ LLVMValueRef val1, LLVMValueRef val2,
+ TypeTableEntry *op1_type, TypeTableEntry *op2_type,
+ AstNode *node)
+{
assert(node->type == NodeTypeBinOpExpr);
-
- LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
- LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
-
- TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
- TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
assert(op1_type == op2_type);
switch (node->data.bin_op_expr.bin_op) {
case BinOpTypeBinOr:
+ case BinOpTypeAssignBitOr:
add_debug_source_node(g, node);
return LLVMBuildOr(g->builder, val1, val2, "");
case BinOpTypeBinXor:
+ case BinOpTypeAssignBitXor:
add_debug_source_node(g, node);
return LLVMBuildXor(g->builder, val1, val2, "");
case BinOpTypeBinAnd:
+ case BinOpTypeAssignBitAnd:
add_debug_source_node(g, node);
return LLVMBuildAnd(g->builder, val1, val2, "");
case BinOpTypeBitShiftLeft:
+ case BinOpTypeAssignBitShiftLeft:
add_debug_source_node(g, node);
return LLVMBuildShl(g->builder, val1, val2, "");
case BinOpTypeBitShiftRight:
+ case BinOpTypeAssignBitShiftRight:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdInt) {
return LLVMBuildAShr(g->builder, val1, val2, "");
@@ -355,6 +358,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildLShr(g->builder, val1, val2, "");
}
case BinOpTypeAdd:
+ case BinOpTypeAssignPlus:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFAdd(g->builder, val1, val2, "");
@@ -362,6 +366,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildNSWAdd(g->builder, val1, val2, "");
}
case BinOpTypeSub:
+ case BinOpTypeAssignMinus:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFSub(g->builder, val1, val2, "");
@@ -369,6 +374,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildNSWSub(g->builder, val1, val2, "");
}
case BinOpTypeMult:
+ case BinOpTypeAssignTimes:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFMul(g->builder, val1, val2, "");
@@ -376,6 +382,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
return LLVMBuildNSWMul(g->builder, val1, val2, "");
}
case BinOpTypeDiv:
+ case BinOpTypeAssignDiv:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFDiv(g->builder, val1, val2, "");
@@ -388,6 +395,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
}
}
case BinOpTypeMod:
+ case BinOpTypeAssignMod:
add_debug_source_node(g, node);
if (op1_type->id == TypeTableEntryIdFloat) {
return LLVMBuildFRem(g->builder, val1, val2, "");
@@ -409,22 +417,23 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
case BinOpTypeCmpGreaterOrEq:
case BinOpTypeInvalid:
case BinOpTypeAssign:
- case BinOpTypeAssignTimes:
- case BinOpTypeAssignDiv:
- case BinOpTypeAssignMod:
- case BinOpTypeAssignPlus:
- case BinOpTypeAssignMinus:
- case BinOpTypeAssignBitShiftLeft:
- case BinOpTypeAssignBitShiftRight:
- case BinOpTypeAssignBitAnd:
- case BinOpTypeAssignBitXor:
- case BinOpTypeAssignBitOr:
case BinOpTypeAssignBoolAnd:
case BinOpTypeAssignBoolOr:
zig_unreachable();
}
zig_unreachable();
}
+static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
+ assert(node->type == NodeTypeBinOpExpr);
+
+ LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+ LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
+
+ TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
+ TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
+ return gen_arithmetic_bin_op(g, val1, val2, op1_type, op2_type, node);
+
+}
static LLVMIntPredicate cmp_op_to_int_predicate(BinOpType cmp_op, bool is_signed) {
switch (cmp_op) {
@@ -555,11 +564,8 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
AstNode *lhs_node = node->data.bin_op_expr.op1;
- bool is_read_first = node->data.bin_op_expr.bin_op != BinOpTypeAssign;
- if (is_read_first) {
- zig_panic("TODO: implement modify assignment ops");
- }
-
+ LLVMValueRef target_ref;
+ TypeTableEntry *op1_type;
if (lhs_node->type == NodeTypeSymbol) {
LocalVariableTableEntry *var = find_local_variable(node->codegen_node->expr_node.block_context,
&lhs_node->data.symbol);
@@ -567,33 +573,30 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
// semantic checking ensures no variables are constant
assert(!var->is_const);
- LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
-
- add_debug_source_node(g, node);
- return LLVMBuildStore(g->builder, value, var->value_ref);
+ op1_type = var->type;
+ target_ref = var->value_ref;
} else if (lhs_node->type == NodeTypeArrayAccessExpr) {
- LLVMValueRef ptr = gen_array_ptr(g, lhs_node);
- LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
- add_debug_source_node(g, node);
- return LLVMBuildStore(g->builder, value, ptr);
- } else if (lhs_node->type == NodeTypeFieldAccessExpr) {
- /*
- LLVMValueRef ptr = gen_field_ptr(g, lhs_node);
- LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
- add_debug_source_node(g, node);
- return LLVMBuildStore(g->builder, value, ptr);
- */
- LLVMValueRef struct_val = gen_expr(g, lhs_node->data.field_access_expr.struct_expr);
- assert(struct_val);
- FieldAccessNode *codegen_field_access = &lhs_node->codegen_node->data.field_access_node;
- assert(codegen_field_access->field_index >= 0);
-
- LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
- add_debug_source_node(g, node);
- return LLVMBuildInsertValue(g->builder, struct_val, value, codegen_field_access->field_index, "");
+ TypeTableEntry *array_type = get_expr_type(lhs_node->data.array_access_expr.array_ref_expr);
+ assert(array_type->id == TypeTableEntryIdArray);
+ op1_type = array_type->data.array.child_type;
+ target_ref = gen_array_ptr(g, lhs_node);
} else {
zig_panic("bad assign target");
}
+ LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
+
+ if (node->data.bin_op_expr.bin_op == BinOpTypeAssign) {
+ // value is ready as is
+ } else {
+ add_debug_source_node(g, node->data.bin_op_expr.op1);
+ LLVMValueRef left_value = LLVMBuildLoad(g->builder, target_ref, "");
+
+ TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
+ value = gen_arithmetic_bin_op(g, left_value, value, op1_type, op2_type, node);
+ }
+
+ add_debug_source_node(g, node);
+ return LLVMBuildStore(g->builder, value, target_ref);
}
static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {
src/tokenizer.cpp
@@ -402,6 +402,7 @@ void tokenize(Buf *buf, Tokenization *out) {
t.cur_tok->id = TokenIdBitShiftRightEq;
end_token(&t);
t.state = TokenizeStateStart;
+ break;
default:
t.pos -= 1;
end_token(&t);
@@ -415,6 +416,7 @@ void tokenize(Buf *buf, Tokenization *out) {
t.cur_tok->id = TokenIdCmpLessOrEq;
end_token(&t);
t.state = TokenizeStateStart;
+ break;
case '<':
t.cur_tok->id = TokenIdBitShiftLeft;
t.state = TokenizeStateSawLessThanLessThan;
@@ -432,6 +434,7 @@ void tokenize(Buf *buf, Tokenization *out) {
t.cur_tok->id = TokenIdBitShiftLeftEq;
end_token(&t);
t.state = TokenizeStateStart;
+ break;
default:
t.pos -= 1;
end_token(&t);
test/run_tests.cpp
@@ -454,6 +454,29 @@ export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
}
)SOURCE", "OK 1\nOK 2\nOK 3\nOK 4\n");
+ add_simple_case("modify operators", R"SOURCE(
+use "std.zig";
+
+export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
+ let mut i : i32 = 0;
+ i += 5; if i != 5 { print_str("BAD +=\n" as string); }
+ i -= 2; if i != 3 { print_str("BAD -=\n" as string); }
+ i *= 20; if i != 60 { print_str("BAD *=\n" as string); }
+ i /= 3; if i != 20 { print_str("BAD /=\n" as string); }
+ i %= 11; if i != 9 { print_str("BAD %=\n" as string); }
+ i <<= 1; if i != 18 { print_str("BAD <<=\n" as string); }
+ i >>= 2; if i != 4 { print_str("BAD >>=\n" as string); }
+ i = 6;
+ i &= 5; if i != 4 { print_str("BAD &=\n" as string); }
+ i ^= 6; if i != 2 { print_str("BAD ^=\n" as string); }
+ i = 6;
+ i |= 3; if i != 7 { print_str("BAD |=\n" as string); }
+
+ print_str("OK\n" as string);
+ return 0;
+}
+ )SOURCE", "OK\n");
+
}
static void add_compile_failure_test_cases(void) {