Commit 9a014b52cc

Josh Wolfe <thejoshwolfe@gmail.com>
2015-11-29 22:46:05
flatten expression ast to hide operator precedence
1 parent 4466a45
Changed files (3)
src/codegen.cpp
@@ -305,17 +305,9 @@ static void find_declarations(CodeGen *g, AstNode *node) {
         case NodeTypeReturnExpr:
         case NodeTypeRoot:
         case NodeTypeBlock:
-        case NodeTypeBoolOrExpr:
+        case NodeTypeBinOpExpr:
         case NodeTypeFnCall:
         case NodeTypeRootExportDecl:
-        case NodeTypeBoolAndExpr:
-        case NodeTypeComparisonExpr:
-        case NodeTypeBinOrExpr:
-        case NodeTypeBinXorExpr:
-        case NodeTypeBinAndExpr:
-        case NodeTypeBitShiftExpr:
-        case NodeTypeAddExpr:
-        case NodeTypeMultExpr:
         case NodeTypeCastExpr:
         case NodeTypePrimaryExpr:
         case NodeTypeGroupedExpr:
@@ -481,10 +473,9 @@ static void analyze_node(CodeGen *g, AstNode *node) {
                 analyze_node(g, node->data.return_expr.expr);
             }
             break;
-        case NodeTypeBoolOrExpr:
-            analyze_node(g, node->data.bool_or_expr.op1);
-            if (node->data.bool_or_expr.op2)
-                analyze_node(g, node->data.bool_or_expr.op2);
+        case NodeTypeBinOpExpr:
+            analyze_node(g, node->data.bin_op_expr.op1);
+            analyze_node(g, node->data.bin_op_expr.op2);
             break;
         case NodeTypeFnCall:
             {
@@ -515,30 +506,6 @@ static void analyze_node(CodeGen *g, AstNode *node) {
         case NodeTypeDirective:
             // we looked at directives in the parent node
             break;
-        case NodeTypeBoolAndExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeComparisonExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeBinOrExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeBinXorExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeBinAndExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeBitShiftExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeAddExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeMultExpr:
-            zig_panic("TODO");
-            break;
         case NodeTypeCastExpr:
             zig_panic("TODO");
             break;
@@ -752,168 +719,138 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
 }
 
 static LLVMValueRef gen_mult_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeMultExpr);
-
-    LLVMValueRef val1 = gen_cast_expr(g, node->data.mult_expr.op1);
-
-    if (!node->data.mult_expr.op2)
-        return val1;
+    assert(node->type == NodeTypeBinOpExpr);
 
-    LLVMValueRef val2 = gen_cast_expr(g, node->data.mult_expr.op2);
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
-    switch (node->data.mult_expr.mult_op) {
-        case MultOpMult:
+    switch (node->data.bin_op_expr.bin_op) {
+        case BinOpTypeMult:
             // TODO types so we know float vs int
             add_debug_source_node(g, node);
             return LLVMBuildMul(g->builder, val1, val2, "");
-        case MultOpDiv:
+        case BinOpTypeDiv:
             // TODO types so we know float vs int and signed vs unsigned
             add_debug_source_node(g, node);
             return LLVMBuildSDiv(g->builder, val1, val2, "");
-        case MultOpMod:
+        case BinOpTypeMod:
             // TODO types so we know float vs int and signed vs unsigned
             add_debug_source_node(g, node);
             return LLVMBuildSRem(g->builder, val1, val2, "");
-        case MultOpInvalid:
+        default:
             zig_unreachable();
     }
     zig_unreachable();
 }
 
 static LLVMValueRef gen_add_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeAddExpr);
-
-    LLVMValueRef val1 = gen_mult_expr(g, node->data.add_expr.op1);
-
-    if (!node->data.add_expr.op2)
-        return val1;
+    assert(node->type == NodeTypeBinOpExpr);
 
-    LLVMValueRef val2 = gen_mult_expr(g, node->data.add_expr.op2);
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
-    switch (node->data.add_expr.add_op) {
-        case AddOpAdd:
+    switch (node->data.bin_op_expr.bin_op) {
+        case BinOpTypeAdd:
             add_debug_source_node(g, node);
             return LLVMBuildAdd(g->builder, val1, val2, "");
-        case AddOpSub:
+        case BinOpTypeSub:
             add_debug_source_node(g, node);
             return LLVMBuildSub(g->builder, val1, val2, "");
-        case AddOpInvalid:
+        default:
             zig_unreachable();
     }
     zig_unreachable();
 }
 
 static LLVMValueRef gen_bit_shift_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeBitShiftExpr);
-
-    LLVMValueRef val1 = gen_add_expr(g, node->data.bit_shift_expr.op1);
-
-    if (!node->data.bit_shift_expr.op2)
-        return val1;
+    assert(node->type == NodeTypeBinOpExpr);
 
-    LLVMValueRef val2 = gen_add_expr(g, node->data.bit_shift_expr.op2);
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
-    switch (node->data.bit_shift_expr.bit_shift_op) {
-        case BitShiftOpLeft:
+    switch (node->data.bin_op_expr.bin_op) {
+        case BinOpTypeBitShiftLeft:
             add_debug_source_node(g, node);
             return LLVMBuildShl(g->builder, val1, val2, "");
-        case BitShiftOpRight:
+        case BinOpTypeBitShiftRight:
             // TODO implement type system so that we know whether to do
             // logical or arithmetic shifting here.
             // signed -> arithmetic, unsigned -> logical
             add_debug_source_node(g, node);
             return LLVMBuildLShr(g->builder, val1, val2, "");
-        case BitShiftOpInvalid:
+        default:
             zig_unreachable();
     }
     zig_unreachable();
 }
 
 static LLVMValueRef gen_bin_and_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeBinAndExpr);
-
-    LLVMValueRef val1 = gen_bit_shift_expr(g, node->data.bin_and_expr.op1);
-
-    if (!node->data.bin_and_expr.op2)
-        return val1;
+    assert(node->type == NodeTypeBinOpExpr);
 
-    LLVMValueRef val2 = gen_bit_shift_expr(g, node->data.bin_and_expr.op2);
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
     add_debug_source_node(g, node);
     return LLVMBuildAnd(g->builder, val1, val2, "");
 }
 
 static LLVMValueRef gen_bin_xor_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeBinXorExpr);
+    assert(node->type == NodeTypeBinOpExpr);
 
-    LLVMValueRef val1 = gen_bin_and_expr(g, node->data.bin_xor_expr.op1);
-
-    if (!node->data.bin_xor_expr.op2)
-        return val1;
-
-    LLVMValueRef val2 = gen_bin_and_expr(g, node->data.bin_xor_expr.op2);
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
     add_debug_source_node(g, node);
     return LLVMBuildXor(g->builder, val1, val2, "");
 }
 
 static LLVMValueRef gen_bin_or_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeBinOrExpr);
-
-    LLVMValueRef val1 = gen_bin_xor_expr(g, node->data.bin_or_expr.op1);
+    assert(node->type == NodeTypeBinOpExpr);
 
-    if (!node->data.bin_or_expr.op2)
-        return val1;
-
-    LLVMValueRef val2 = gen_bin_xor_expr(g, node->data.bin_or_expr.op2);
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
     add_debug_source_node(g, node);
     return LLVMBuildOr(g->builder, val1, val2, "");
 }
 
-static LLVMIntPredicate cmp_op_to_int_predicate(CmpOp cmp_op, bool is_signed) {
+static LLVMIntPredicate cmp_op_to_int_predicate(BinOpType cmp_op, bool is_signed) {
     switch (cmp_op) {
-        case CmpOpInvalid:
+        case BinOpTypeInvalid:
             zig_unreachable();
-        case CmpOpEq:
+        case BinOpTypeCmpEq:
             return LLVMIntEQ;
-        case CmpOpNotEq:
+        case BinOpTypeCmpNotEq:
             return LLVMIntNE;
-        case CmpOpLessThan:
+        case BinOpTypeCmpLessThan:
             return is_signed ? LLVMIntSLT : LLVMIntULT;
-        case CmpOpGreaterThan:
+        case BinOpTypeCmpGreaterThan:
             return is_signed ? LLVMIntSGT : LLVMIntUGT;
-        case CmpOpLessOrEq:
+        case BinOpTypeCmpLessOrEq:
             return is_signed ? LLVMIntSLE : LLVMIntULE;
-        case CmpOpGreaterOrEq:
+        case BinOpTypeCmpGreaterOrEq:
             return is_signed ? LLVMIntSGE : LLVMIntUGE;
+        default:
+            zig_unreachable();
     }
-    zig_unreachable();
 }
 
 static LLVMValueRef gen_cmp_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeComparisonExpr);
-
-    LLVMValueRef val1 = gen_bin_or_expr(g, node->data.comparison_expr.op1);
-
-    if (!node->data.comparison_expr.op2)
-        return val1;
+    assert(node->type == NodeTypeBinOpExpr);
 
-    LLVMValueRef val2 = gen_bin_or_expr(g, node->data.comparison_expr.op2);
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
     // TODO implement type system so that we know whether to do signed or unsigned comparison here
-    LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.comparison_expr.cmp_op, true);
+    LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, true);
     add_debug_source_node(g, node);
     return LLVMBuildICmp(g->builder, pred, val1, val2, "");
 }
 
 static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeBoolAndExpr);
+    assert(node->type == NodeTypeBinOpExpr);
 
-    LLVMValueRef val1 = gen_cmp_expr(g, node->data.bool_and_expr.op1);
-
-    if (!node->data.bool_and_expr.op2)
-        return val1;
+    LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
 
     // block for when val1 == true
     LLVMBasicBlockRef true_block = LLVMAppendBasicBlock(g->cur_fn, "BoolAndTrue");
@@ -926,7 +863,7 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
     LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block);
 
     LLVMPositionBuilderAtEnd(g->builder, true_block);
-    LLVMValueRef val2 = gen_cmp_expr(g, node->data.bool_and_expr.op2);
+    LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
     add_debug_source_node(g, node);
     LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, "");
 
@@ -942,12 +879,9 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
 }
 
 static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
-    assert(expr_node->type == NodeTypeBoolOrExpr);
-
-    LLVMValueRef val1 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op1);
+    assert(expr_node->type == NodeTypeBinOpExpr);
 
-    if (!expr_node->data.bool_or_expr.op2)
-        return val1;
+    LLVMValueRef val1 = gen_expr(g, expr_node->data.bin_op_expr.op1);
 
     // block for when val1 == false
     LLVMBasicBlockRef false_block = LLVMAppendBasicBlock(g->cur_fn, "BoolOrFalse");
@@ -960,7 +894,7 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
     LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block);
 
     LLVMPositionBuilderAtEnd(g->builder, false_block);
-    LLVMValueRef val2 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op2);
+    LLVMValueRef val2 = gen_expr(g, expr_node->data.bin_op_expr.op2);
     add_debug_source_node(g, expr_node);
     LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, "");
 
@@ -975,6 +909,41 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
     return phi;
 }
 
+static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {
+    switch (node->data.bin_op_expr.bin_op) {
+        case BinOpTypeInvalid:
+            zig_unreachable();
+        case BinOpTypeBoolOr:
+            return gen_bool_or_expr(g, node);
+        case BinOpTypeBoolAnd:
+            return gen_bool_and_expr(g, node);
+        case BinOpTypeCmpEq:
+        case BinOpTypeCmpNotEq:
+        case BinOpTypeCmpLessThan:
+        case BinOpTypeCmpGreaterThan:
+        case BinOpTypeCmpLessOrEq:
+        case BinOpTypeCmpGreaterOrEq:
+            return gen_cmp_expr(g, node);
+        case BinOpTypeBinOr:
+            return gen_bin_or_expr(g, node);
+        case BinOpTypeBinXor:
+            return gen_bin_xor_expr(g, node);
+        case BinOpTypeBinAnd:
+            return gen_bin_and_expr(g, node);
+        case BinOpTypeBitShiftLeft:
+        case BinOpTypeBitShiftRight:
+            return gen_bit_shift_expr(g, node);
+        case BinOpTypeAdd:
+        case BinOpTypeSub:
+            return gen_add_expr(g, node);
+        case BinOpTypeMult:
+        case BinOpTypeDiv:
+        case BinOpTypeMod:
+            return gen_mult_expr(g, node);
+    }
+    zig_unreachable();
+}
+
 static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeReturnExpr);
     AstNode *param_node = node->data.return_expr.expr;
@@ -993,10 +962,12 @@ Expression : BoolOrExpression | ReturnExpression
 */
 static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
     switch (node->type) {
-        case NodeTypeBoolOrExpr:
-            return gen_bool_or_expr(g, node);
+        case NodeTypeBinOpExpr:
+            return gen_bin_op_expr(g, node);
         case NodeTypeReturnExpr:
             return gen_return_expr(g, node);
+        case NodeTypeCastExpr:
+            return gen_cast_expr(g, node);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -1008,15 +979,6 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
         case NodeTypeFnCall:
         case NodeTypeExternBlock:
         case NodeTypeDirective:
-        case NodeTypeBoolAndExpr:
-        case NodeTypeComparisonExpr:
-        case NodeTypeBinOrExpr:
-        case NodeTypeBinXorExpr:
-        case NodeTypeBinAndExpr:
-        case NodeTypeBitShiftExpr:
-        case NodeTypeAddExpr:
-        case NodeTypeMultExpr:
-        case NodeTypeCastExpr:
         case NodeTypePrimaryExpr:
             return gen_primary_expr(g, node);
         case NodeTypeGroupedExpr:
src/parser.cpp
@@ -10,43 +10,27 @@
 #include <stdarg.h>
 #include <stdio.h>
 
-static const char *mult_op_str(MultOp mult_op) {
-    switch (mult_op) {
-        case MultOpInvalid: return "(invalid)";
-        case MultOpMult: return "*";
-        case MultOpDiv: return "/";
-        case MultOpMod: return "%";
-    }
-    zig_unreachable();
-}
-
-static const char *add_op_str(AddOp add_op) {
-    switch (add_op) {
-        case AddOpInvalid: return "(invalid)";
-        case AddOpAdd: return "+";
-        case AddOpSub: return "-";
-    }
-    zig_unreachable();
-}
-
-static const char *bit_shift_op_str(BitShiftOp bit_shift_op) {
-    switch (bit_shift_op) {
-        case BitShiftOpInvalid: return "(invalid)";
-        case BitShiftOpLeft: return "<<";
-        case BitShiftOpRight: return ">>";
-    }
-    zig_unreachable();
-}
-
-static const char *cmp_op_str(CmpOp cmp_op) {
-    switch (cmp_op) {
-        case CmpOpInvalid: return "(invalid)";
-        case CmpOpEq: return "=";
-        case CmpOpNotEq: return "!=";
-        case CmpOpLessThan: return "<";
-        case CmpOpGreaterThan: return ">";
-        case CmpOpLessOrEq: return "<=";
-        case CmpOpGreaterOrEq: return ">=";
+static const char *bin_op_str(BinOpType bin_op) {
+    switch (bin_op) {
+        case BinOpTypeInvalid:        return "(invalid)";
+        case BinOpTypeBoolOr:         return "||";
+        case BinOpTypeBoolAnd:        return "&&";
+        case BinOpTypeCmpEq:          return "==";
+        case BinOpTypeCmpNotEq:       return "!=";
+        case BinOpTypeCmpLessThan:    return "<";
+        case BinOpTypeCmpGreaterThan: return ">";
+        case BinOpTypeCmpLessOrEq:    return "<=";
+        case BinOpTypeCmpGreaterOrEq: return ">=";
+        case BinOpTypeBinOr:          return "|";
+        case BinOpTypeBinXor:         return "^";
+        case BinOpTypeBinAnd:         return "&";
+        case BinOpTypeBitShiftLeft:   return "<<";
+        case BinOpTypeBitShiftRight:  return ">>";
+        case BinOpTypeAdd:            return "+";
+        case BinOpTypeSub:            return "-";
+        case BinOpTypeMult:           return "*";
+        case BinOpTypeDiv:            return "/";
+        case BinOpTypeMod:            return "%";
     }
     zig_unreachable();
 }
@@ -84,8 +68,8 @@ const char *node_type_str(NodeType node_type) {
             return "Type";
         case NodeTypeBlock:
             return "Block";
-        case NodeTypeBoolOrExpr:
-            return "BoolOrExpr";
+        case NodeTypeBinOpExpr:
+            return "BinOpExpr";
         case NodeTypeFnCall:
             return "FnCall";
         case NodeTypeExternBlock:
@@ -94,22 +78,6 @@ const char *node_type_str(NodeType node_type) {
             return "Directive";
         case NodeTypeReturnExpr:
             return "ReturnExpr";
-        case NodeTypeBoolAndExpr:
-            return "BoolAndExpr";
-        case NodeTypeComparisonExpr:
-            return "ComparisonExpr";
-        case NodeTypeBinOrExpr:
-            return "BinOrExpr";
-        case NodeTypeBinXorExpr:
-            return "BinXorExpr";
-        case NodeTypeBinAndExpr:
-            return "BinAndExpr";
-        case NodeTypeBitShiftExpr:
-            return "BitShiftExpr";
-        case NodeTypeAddExpr:
-            return "AddExpr";
-        case NodeTypeMultExpr:
-            return "MultExpr";
         case NodeTypeCastExpr:
             return "CastExpr";
         case NodeTypePrimaryExpr:
@@ -214,11 +182,11 @@ void ast_print(AstNode *node, int indent) {
             fprintf(stderr, "%s\n", node_type_str(node->type));
             ast_print(node->data.fn_decl.fn_proto, indent + 2);
             break;
-        case NodeTypeBoolOrExpr:
-            fprintf(stderr, "%s\n", node_type_str(node->type));
-            ast_print(node->data.bool_or_expr.op1, indent + 2);
-            if (node->data.bool_or_expr.op2)
-                ast_print(node->data.bool_or_expr.op2, indent + 2);
+        case NodeTypeBinOpExpr:
+            fprintf(stderr, "%s %s\n", node_type_str(node->type),
+                    bin_op_str(node->data.bin_op_expr.bin_op));
+            ast_print(node->data.bin_op_expr.op1, indent + 2);
+            ast_print(node->data.bin_op_expr.op2, indent + 2);
             break;
         case NodeTypeFnCall:
             fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.fn_call.name));
@@ -230,58 +198,6 @@ void ast_print(AstNode *node, int indent) {
         case NodeTypeDirective:
             fprintf(stderr, "%s\n", node_type_str(node->type));
             break;
-        case NodeTypeBoolAndExpr:
-            fprintf(stderr, "%s\n", node_type_str(node->type));
-            ast_print(node->data.bool_and_expr.op1, indent + 2);
-            if (node->data.bool_and_expr.op2)
-                ast_print(node->data.bool_and_expr.op2, indent + 2);
-            break;
-        case NodeTypeComparisonExpr:
-            fprintf(stderr, "%s %s\n", node_type_str(node->type),
-                    cmp_op_str(node->data.comparison_expr.cmp_op));
-            ast_print(node->data.comparison_expr.op1, indent + 2);
-            if (node->data.comparison_expr.op2)
-                ast_print(node->data.comparison_expr.op2, indent + 2);
-            break;
-        case NodeTypeBinOrExpr:
-            fprintf(stderr, "%s\n", node_type_str(node->type));
-            ast_print(node->data.bin_or_expr.op1, indent + 2);
-            if (node->data.bin_or_expr.op2)
-                ast_print(node->data.bin_or_expr.op2, indent + 2);
-            break;
-        case NodeTypeBinXorExpr:
-            fprintf(stderr, "%s\n", node_type_str(node->type));
-            ast_print(node->data.bin_xor_expr.op1, indent + 2);
-            if (node->data.bin_xor_expr.op2)
-                ast_print(node->data.bin_xor_expr.op2, indent + 2);
-            break;
-        case NodeTypeBinAndExpr:
-            fprintf(stderr, "%s\n", node_type_str(node->type));
-            ast_print(node->data.bin_and_expr.op1, indent + 2);
-            if (node->data.bin_and_expr.op2)
-                ast_print(node->data.bin_and_expr.op2, indent + 2);
-            break;
-        case NodeTypeBitShiftExpr:
-            fprintf(stderr, "%s %s\n", node_type_str(node->type),
-                    bit_shift_op_str(node->data.bit_shift_expr.bit_shift_op));
-            ast_print(node->data.bit_shift_expr.op1, indent + 2);
-            if (node->data.bit_shift_expr.op2)
-                ast_print(node->data.bit_shift_expr.op2, indent + 2);
-            break;
-        case NodeTypeAddExpr:
-            fprintf(stderr, "%s %s\n", node_type_str(node->type),
-                    add_op_str(node->data.add_expr.add_op));
-            ast_print(node->data.add_expr.op1, indent + 2);
-            if (node->data.add_expr.op2)
-                ast_print(node->data.add_expr.op2, indent + 2);
-            break;
-        case NodeTypeMultExpr:
-            fprintf(stderr, "%s %s\n", node_type_str(node->type),
-                    mult_op_str(node->data.mult_expr.mult_op));
-            ast_print(node->data.mult_expr.op1, indent + 2);
-            if (node->data.mult_expr.op2)
-                ast_print(node->data.mult_expr.op2, indent + 2);
-            break;
         case NodeTypeCastExpr:
             fprintf(stderr, "%s\n", node_type_str(node->type));
             ast_print(node->data.cast_expr.primary_expr, indent + 2);
@@ -709,26 +625,26 @@ static AstNode *ast_parse_cast_expression(ParseContext *pc, int *token_index, bo
     return node;
 }
 
-static MultOp tok_to_mult_op(Token *token) {
+static BinOpType tok_to_mult_op(Token *token) {
     switch (token->id) {
-        case TokenIdStar: return MultOpMult;
-        case TokenIdSlash: return MultOpDiv;
-        case TokenIdPercent: return MultOpMod;
-        default: return MultOpInvalid;
+        case TokenIdStar: return BinOpTypeMult;
+        case TokenIdSlash: return BinOpTypeDiv;
+        case TokenIdPercent: return BinOpTypeMod;
+        default: return BinOpTypeInvalid;
     }
 }
 
 /*
 MultiplyOperator : token(Star) | token(Slash) | token(Percent)
 */
-static MultOp ast_parse_mult_op(ParseContext *pc, int *token_index, bool mandatory) {
+static BinOpType ast_parse_mult_op(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
-    MultOp result = tok_to_mult_op(token);
-    if (result == MultOpInvalid) {
+    BinOpType result = tok_to_mult_op(token);
+    if (result == BinOpTypeInvalid) {
         if (mandatory) {
             ast_invalid_token_error(pc, token);
         } else {
-            return MultOpInvalid;
+            return BinOpTypeInvalid;
         }
     }
     *token_index += 1;
@@ -744,39 +660,39 @@ static AstNode *ast_parse_mult_expr(ParseContext *pc, int *token_index, bool man
         return nullptr;
 
     Token *token = &pc->tokens->at(*token_index);
-    MultOp mult_op = ast_parse_mult_op(pc, token_index, false);
-    if (mult_op == MultOpInvalid)
+    BinOpType mult_op = ast_parse_mult_op(pc, token_index, false);
+    if (mult_op == BinOpTypeInvalid)
         return operand_1;
 
     AstNode *operand_2 = ast_parse_cast_expression(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeMultExpr, token);
-    node->data.mult_expr.op1 = operand_1;
-    node->data.mult_expr.mult_op = mult_op;
-    node->data.mult_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = mult_op;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
 
-static AddOp tok_to_add_op(Token *token) {
+static BinOpType tok_to_add_op(Token *token) {
     switch (token->id) {
-        case TokenIdPlus: return AddOpAdd;
-        case TokenIdDash: return AddOpSub;
-        default: return AddOpInvalid;
+        case TokenIdPlus: return BinOpTypeAdd;
+        case TokenIdDash: return BinOpTypeSub;
+        default: return BinOpTypeInvalid;
     }
 }
 
 /*
 AdditionOperator : token(Plus) | token(Minus)
 */
-static AddOp ast_parse_add_op(ParseContext *pc, int *token_index, bool mandatory) {
+static BinOpType ast_parse_add_op(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
-    AddOp result = tok_to_add_op(token);
-    if (result == AddOpInvalid) {
+    BinOpType result = tok_to_add_op(token);
+    if (result == BinOpTypeInvalid) {
         if (mandatory) {
             ast_invalid_token_error(pc, token);
         } else {
-            return AddOpInvalid;
+            return BinOpTypeInvalid;
         }
     }
     *token_index += 1;
@@ -792,39 +708,39 @@ static AstNode *ast_parse_add_expr(ParseContext *pc, int *token_index, bool mand
         return nullptr;
 
     Token *token = &pc->tokens->at(*token_index);
-    AddOp add_op = ast_parse_add_op(pc, token_index, false);
-    if (add_op == AddOpInvalid)
+    BinOpType add_op = ast_parse_add_op(pc, token_index, false);
+    if (add_op == BinOpTypeInvalid)
         return operand_1;
 
     AstNode *operand_2 = ast_parse_mult_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeAddExpr, token);
-    node->data.add_expr.op1 = operand_1;
-    node->data.add_expr.add_op = add_op;
-    node->data.add_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = add_op;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
 
-static BitShiftOp tok_to_bit_shift_op(Token *token) {
+static BinOpType tok_to_bit_shift_op(Token *token) {
     switch (token->id) {
-        case TokenIdBitShiftLeft: return BitShiftOpLeft;
-        case TokenIdBitShiftRight: return BitShiftOpRight;
-        default: return BitShiftOpInvalid;
+        case TokenIdBitShiftLeft: return BinOpTypeBitShiftLeft;
+        case TokenIdBitShiftRight: return BinOpTypeBitShiftRight;
+        default: return BinOpTypeInvalid;
     }
 }
 
 /*
 BitShiftOperator : token(BitShiftLeft | token(BitShiftRight)
 */
-static BitShiftOp ast_parse_bit_shift_op(ParseContext *pc, int *token_index, bool mandatory) {
+static BinOpType ast_parse_bit_shift_op(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
-    BitShiftOp result = tok_to_bit_shift_op(token);
-    if (result == BitShiftOpInvalid) {
+    BinOpType result = tok_to_bit_shift_op(token);
+    if (result == BinOpTypeInvalid) {
         if (mandatory) {
             ast_invalid_token_error(pc, token);
         } else {
-            return BitShiftOpInvalid;
+            return BinOpTypeInvalid;
         }
     }
     *token_index += 1;
@@ -840,16 +756,16 @@ static AstNode *ast_parse_bit_shift_expr(ParseContext *pc, int *token_index, boo
         return nullptr;
 
     Token *token = &pc->tokens->at(*token_index);
-    BitShiftOp bit_shift_op = ast_parse_bit_shift_op(pc, token_index, false);
-    if (bit_shift_op == BitShiftOpInvalid)
+    BinOpType bit_shift_op = ast_parse_bit_shift_op(pc, token_index, false);
+    if (bit_shift_op == BinOpTypeInvalid)
         return operand_1;
 
     AstNode *operand_2 = ast_parse_add_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeBitShiftExpr, token);
-    node->data.bit_shift_expr.op1 = operand_1;
-    node->data.bit_shift_expr.bit_shift_op = bit_shift_op;
-    node->data.bit_shift_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = bit_shift_op;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
@@ -870,9 +786,10 @@ static AstNode *ast_parse_bin_and_expr(ParseContext *pc, int *token_index, bool
 
     AstNode *operand_2 = ast_parse_bit_shift_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeBinAndExpr, token);
-    node->data.bin_and_expr.op1 = operand_1;
-    node->data.bin_and_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = BinOpTypeBinAnd;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
@@ -892,9 +809,10 @@ static AstNode *ast_parse_bin_xor_expr(ParseContext *pc, int *token_index, bool
 
     AstNode *operand_2 = ast_parse_bin_and_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeBinXorExpr, token);
-    node->data.bin_xor_expr.op1 = operand_1;
-    node->data.bin_xor_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = BinOpTypeBinXor;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
@@ -914,33 +832,34 @@ static AstNode *ast_parse_bin_or_expr(ParseContext *pc, int *token_index, bool m
 
     AstNode *operand_2 = ast_parse_bin_xor_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeBinOrExpr, token);
-    node->data.bin_or_expr.op1 = operand_1;
-    node->data.bin_or_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = BinOpTypeBinOr;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
 
-static CmpOp tok_to_cmp_op(Token *token) {
+static BinOpType tok_to_cmp_op(Token *token) {
     switch (token->id) {
-        case TokenIdCmpEq: return CmpOpEq;
-        case TokenIdCmpNotEq: return CmpOpNotEq;
-        case TokenIdCmpLessThan: return CmpOpLessThan;
-        case TokenIdCmpGreaterThan: return CmpOpGreaterThan;
-        case TokenIdCmpLessOrEq: return CmpOpLessOrEq;
-        case TokenIdCmpGreaterOrEq: return CmpOpGreaterOrEq;
-        default: return CmpOpInvalid;
+        case TokenIdCmpEq: return BinOpTypeCmpEq;
+        case TokenIdCmpNotEq: return BinOpTypeCmpNotEq;
+        case TokenIdCmpLessThan: return BinOpTypeCmpLessThan;
+        case TokenIdCmpGreaterThan: return BinOpTypeCmpGreaterThan;
+        case TokenIdCmpLessOrEq: return BinOpTypeCmpLessOrEq;
+        case TokenIdCmpGreaterOrEq: return BinOpTypeCmpGreaterOrEq;
+        default: return BinOpTypeInvalid;
     }
 }
 
-static CmpOp ast_parse_comparison_operator(ParseContext *pc, int *token_index, bool mandatory) {
+static BinOpType ast_parse_comparison_operator(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
-    CmpOp result = tok_to_cmp_op(token);
-    if (result == CmpOpInvalid) {
+    BinOpType result = tok_to_cmp_op(token);
+    if (result == BinOpTypeInvalid) {
         if (mandatory) {
             ast_invalid_token_error(pc, token);
         } else {
-            return CmpOpInvalid;
+            return BinOpTypeInvalid;
         }
     }
     *token_index += 1;
@@ -956,16 +875,16 @@ static AstNode *ast_parse_comparison_expr(ParseContext *pc, int *token_index, bo
         return nullptr;
 
     Token *token = &pc->tokens->at(*token_index);
-    CmpOp cmp_op = ast_parse_comparison_operator(pc, token_index, false);
-    if (cmp_op == CmpOpInvalid)
+    BinOpType cmp_op = ast_parse_comparison_operator(pc, token_index, false);
+    if (cmp_op == BinOpTypeInvalid)
         return operand_1;
 
     AstNode *operand_2 = ast_parse_bin_or_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeComparisonExpr, token);
-    node->data.comparison_expr.op1 = operand_1;
-    node->data.comparison_expr.cmp_op = cmp_op;
-    node->data.comparison_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = cmp_op;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
@@ -985,9 +904,10 @@ static AstNode *ast_parse_bool_and_expr(ParseContext *pc, int *token_index, bool
 
     AstNode *operand_2 = ast_parse_comparison_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeBoolAndExpr, token);
-    node->data.bool_and_expr.op1 = operand_1;
-    node->data.bool_and_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = BinOpTypeBoolAnd;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
@@ -1024,9 +944,10 @@ static AstNode *ast_parse_bool_or_expr(ParseContext *pc, int *token_index, bool
 
     AstNode *operand_2 = ast_parse_bool_and_expr(pc, token_index, true);
 
-    AstNode *node = ast_create_node(NodeTypeBoolOrExpr, token);
-    node->data.bool_or_expr.op1 = operand_1;
-    node->data.bool_or_expr.op2 = operand_2;
+    AstNode *node = ast_create_node(NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = operand_1;
+    node->data.bin_op_expr.bin_op = BinOpTypeBoolOr;
+    node->data.bin_op_expr.op2 = operand_2;
 
     return node;
 }
src/parser.hpp
@@ -28,15 +28,7 @@ enum NodeType {
     NodeTypeExternBlock,
     NodeTypeDirective,
     NodeTypeReturnExpr,
-    NodeTypeBoolOrExpr,
-    NodeTypeBoolAndExpr,
-    NodeTypeComparisonExpr,
-    NodeTypeBinOrExpr,
-    NodeTypeBinXorExpr,
-    NodeTypeBinAndExpr,
-    NodeTypeBitShiftExpr,
-    NodeTypeAddExpr,
-    NodeTypeMultExpr,
+    NodeTypeBinOpExpr,
     NodeTypeCastExpr,
     NodeTypePrimaryExpr,
     NodeTypeGroupedExpr,
@@ -96,9 +88,32 @@ struct AstNodeReturnExpr {
     AstNode *expr;
 };
 
-struct AstNodeBoolOrExpr {
+enum BinOpType {
+    BinOpTypeInvalid,
+    // TODO: include assignment?
+    BinOpTypeBoolOr,
+    BinOpTypeBoolAnd,
+    BinOpTypeCmpEq,
+    BinOpTypeCmpNotEq,
+    BinOpTypeCmpLessThan,
+    BinOpTypeCmpGreaterThan,
+    BinOpTypeCmpLessOrEq,
+    BinOpTypeCmpGreaterOrEq,
+    BinOpTypeBinOr,
+    BinOpTypeBinXor,
+    BinOpTypeBinAnd,
+    BinOpTypeBitShiftLeft,
+    BinOpTypeBitShiftRight,
+    BinOpTypeAdd,
+    BinOpTypeSub,
+    BinOpTypeMult,
+    BinOpTypeDiv,
+    BinOpTypeMod,
+};
+
+struct AstNodeBinOpExpr {
     AstNode *op1;
-    // if op2 is non-null, do boolean or, otherwise nothing
+    BinOpType bin_op;
     AstNode *op2;
 };
 
@@ -122,87 +137,6 @@ struct AstNodeRootExportDecl {
     Buf name;
 };
 
-struct AstNodeBoolAndExpr {
-    AstNode *op1;
-    // if op2 is non-null, do boolean and, otherwise nothing
-    AstNode *op2;
-};
-
-enum CmpOp {
-    CmpOpInvalid,
-    CmpOpEq,
-    CmpOpNotEq,
-    CmpOpLessThan,
-    CmpOpGreaterThan,
-    CmpOpLessOrEq,
-    CmpOpGreaterOrEq,
-};
-
-struct AstNodeComparisonExpr {
-    AstNode *op1;
-    CmpOp cmp_op;
-    // if op2 is non-null, do cmp_op, otherwise nothing
-    AstNode *op2;
-};
-
-struct AstNodeBinOrExpr {
-    AstNode *op1;
-    // if op2 is non-null, do binary or, otherwise nothing
-    AstNode *op2;
-};
-
-struct AstNodeBinXorExpr {
-    AstNode *op1;
-    // if op2 is non-null, do binary xor, otherwise nothing
-    AstNode *op2;
-};
-
-struct AstNodeBinAndExpr {
-    AstNode *op1;
-    // if op2 is non-null, do binary and, otherwise nothing
-    AstNode *op2;
-};
-
-enum BitShiftOp {
-    BitShiftOpInvalid,
-    BitShiftOpLeft,
-    BitShiftOpRight,
-};
-
-struct AstNodeBitShiftExpr {
-    AstNode *op1;
-    BitShiftOp bit_shift_op;
-    // if op2 is non-null, do bit_shift_op, otherwise nothing
-    AstNode *op2;
-};
-
-enum AddOp {
-    AddOpInvalid,
-    AddOpAdd,
-    AddOpSub,
-};
-
-struct AstNodeAddExpr {
-    AstNode *op1;
-    AddOp add_op;
-    // if op2 is non-null, do add_op, otherwise nothing
-    AstNode *op2;
-};
-
-enum MultOp {
-    MultOpInvalid,
-    MultOpMult,
-    MultOpDiv,
-    MultOpMod,
-};
-
-struct AstNodeMultExpr {
-    AstNode *op1;
-    MultOp mult_op;
-    // if op2 is non-null, do mult_op, otherwise nothing
-    AstNode *op2;
-};
-
 struct AstNodeCastExpr {
     AstNode *primary_expr;
     // if type is non-null, do cast, otherwise nothing
@@ -249,18 +183,10 @@ struct AstNode {
         AstNodeParamDecl param_decl;
         AstNodeBlock block;
         AstNodeReturnExpr return_expr;
-        AstNodeBoolOrExpr bool_or_expr;
+        AstNodeBinOpExpr bin_op_expr;
         AstNodeFnCall fn_call;
         AstNodeExternBlock extern_block;
         AstNodeDirective directive;
-        AstNodeBoolAndExpr bool_and_expr;
-        AstNodeComparisonExpr comparison_expr;
-        AstNodeBinOrExpr bin_or_expr;
-        AstNodeBinXorExpr bin_xor_expr;
-        AstNodeBinAndExpr bin_and_expr;
-        AstNodeBitShiftExpr bit_shift_expr;
-        AstNodeAddExpr add_expr;
-        AstNodeMultExpr mult_expr;
         AstNodeCastExpr cast_expr;
         AstNodePrimaryExpr primary_expr;
         AstNodeGroupedExpr grouped_expr;