Commit 5a8822c714

Andrew Kelley <superjoe30@gmail.com>
2015-12-16 03:17:39
fix assignment operators for struct fields
1 parent 28c5a8f
doc/langref.md
@@ -142,7 +142,7 @@ FnCallExpression : token(LParen) list(Expression, token(Comma)) token(RParen)
 
 ArrayAccessExpression : token(LBracket) Expression token(RBracket)
 
-PrefixOp : token(Not) | token(Dash) | token(Tilde)
+PrefixOp : token(Not) | token(Dash) | token(Tilde) | (token(Ampersand) option(token(Const)))
 
 PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | token(Symbol) | Goto | BlockExpression
 
@@ -157,8 +157,7 @@ KeywordLiteral : token(Unreachable) | token(Void) | token(True) | token(False)
 
 ```
 x() x[] x.y
-&x
-!x -x ~x
+!x -x ~x &x &const x
 as
 * / %
 + -
src/analyze.cpp
@@ -987,6 +987,50 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
     }
 }
 
+enum LValPurpose {
+    LValPurposeAssign,
+    LValPurposeAddressOf,
+};
+
+static TypeTableEntry *analyze_lvalue(CodeGen *g, ImportTableEntry *import, BlockContext *block_context,
+        AstNode *lhs_node, LValPurpose purpose, bool is_ptr_const)
+{
+    TypeTableEntry *expected_rhs_type = nullptr;
+    if (lhs_node->type == NodeTypeSymbol) {
+        Buf *name = &lhs_node->data.symbol;
+        VariableTableEntry *var = find_variable(block_context, name);
+        if (var) {
+            if (purpose == LValPurposeAssign && var->is_const) {
+                add_node_error(g, lhs_node,
+                    buf_sprintf("cannot assign to constant"));
+            } else if (purpose == LValPurposeAddressOf && var->is_const && !is_ptr_const) {
+                add_node_error(g, lhs_node,
+                    buf_sprintf("must use &const to get address of constant"));
+            } else {
+                expected_rhs_type = var->type;
+            }
+        } else {
+            add_node_error(g, lhs_node,
+                    buf_sprintf("use of undeclared identifier '%s'", buf_ptr(name)));
+        }
+    } else if (lhs_node->type == NodeTypeArrayAccessExpr) {
+        expected_rhs_type = analyze_array_access_expr(g, import, block_context, lhs_node);
+    } else if (lhs_node->type == NodeTypeFieldAccessExpr) {
+        alloc_codegen_node(lhs_node);
+        expected_rhs_type = analyze_field_access_expr(g, import, block_context, lhs_node);
+    } else {
+        if (purpose == LValPurposeAssign) {
+            add_node_error(g, lhs_node,
+                    buf_sprintf("assignment target must be variable, field, or array element"));
+        } else if (purpose == LValPurposeAddressOf) {
+            add_node_error(g, lhs_node,
+                    buf_sprintf("addressof target must be variable, field, or array element"));
+        }
+        expected_rhs_type = g->builtin_types.entry_invalid;
+    }
+    return expected_rhs_type;
+}
+
 static TypeTableEntry *analyze_bin_op_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -1006,38 +1050,17 @@ static TypeTableEntry *analyze_bin_op_expr(CodeGen *g, ImportTableEntry *import,
         case BinOpTypeAssignBoolOr:
             {
                 AstNode *lhs_node = node->data.bin_op_expr.op1;
-                TypeTableEntry *expected_rhs_type = nullptr;
-                if (lhs_node->type == NodeTypeSymbol) {
-                    Buf *name = &lhs_node->data.symbol;
-                    VariableTableEntry *var = find_variable(context, name);
-                    if (var) {
-                        if (var->is_const) {
-                            add_node_error(g, lhs_node,
-                                buf_sprintf("cannot assign to constant variable"));
-                        } else {
-                            if (!is_op_allowed(var->type, node->data.bin_op_expr.bin_op)) {
-                                if (var->type->id != TypeTableEntryIdInvalid) {
-                                    add_node_error(g, lhs_node,
-                                        buf_sprintf("operator not allowed for type '%s'",
-                                            buf_ptr(&var->type->name)));
-                                }
-                            } else {
-                                expected_rhs_type = var->type;
-                            }
-                        }
-                    } else {
+
+                TypeTableEntry *expected_rhs_type = analyze_lvalue(g, import, context, lhs_node,
+                        LValPurposeAssign, false);
+                if (!is_op_allowed(expected_rhs_type, node->data.bin_op_expr.bin_op)) {
+                    if (expected_rhs_type->id != TypeTableEntryIdInvalid) {
                         add_node_error(g, lhs_node,
-                                buf_sprintf("use of undeclared identifier '%s'", buf_ptr(name)));
+                            buf_sprintf("operator not allowed for type '%s'",
+                                buf_ptr(&expected_rhs_type->name)));
                     }
-                } else if (lhs_node->type == NodeTypeArrayAccessExpr) {
-                    expected_rhs_type = analyze_array_access_expr(g, import, context, lhs_node);
-                } else if (lhs_node->type == NodeTypeFieldAccessExpr) {
-                    alloc_codegen_node(lhs_node);
-                    expected_rhs_type = analyze_field_access_expr(g, import, context, lhs_node);
-                } else {
-                    add_node_error(g, lhs_node,
-                            buf_sprintf("assignment target must be variable, field, or array element"));
                 }
+
                 analyze_expression(g, import, context, expected_rhs_type, node->data.bin_op_expr.op2);
                 return g->builtin_types.entry_void;
             }
@@ -1388,6 +1411,8 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
             break;
         case NodeTypePrefixOpExpr:
             switch (node->data.prefix_op_expr.prefix_op) {
+                case PrefixOpInvalid:
+                    zig_unreachable();
                 case PrefixOpBoolNot:
                     analyze_expression(g, import, context, g->builtin_types.entry_bool,
                             node->data.prefix_op_expr.primary_expr);
@@ -1407,8 +1432,22 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                         return_type = g->builtin_types.entry_i32;
                         break;
                     }
-                case PrefixOpInvalid:
-                    zig_unreachable();
+                case PrefixOpAddressOf:
+                case PrefixOpConstAddressOf:
+                    {
+                        bool is_const = (node->data.prefix_op_expr.prefix_op == PrefixOpConstAddressOf);
+
+                        TypeTableEntry *child_type = analyze_lvalue(g, import, context,
+                                node->data.prefix_op_expr.primary_expr, LValPurposeAddressOf, is_const);
+
+                        if (child_type->id == TypeTableEntryIdInvalid) {
+                            return_type = g->builtin_types.entry_invalid;
+                            break;
+                        }
+
+                        return_type = get_pointer_to_type(g, child_type, is_const);
+                        break;
+                    }
             }
             break;
         case NodeTypeIfExpr:
src/codegen.cpp
@@ -206,7 +206,7 @@ static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
     return LLVMBuildInBoundsGEP(g->builder, array_ref_value, indices, 2, "");
 }
 
-static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node) {
+static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node, TypeTableEntry **out_type_entry) {
     assert(node->type == NodeTypeFieldAccessExpr);
 
     LLVMValueRef struct_ptr = gen_expr(g, node->data.field_access_expr.struct_expr);
@@ -217,6 +217,8 @@ static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node) {
 
     assert(codegen_field_access->field_index >= 0);
 
+    *out_type_entry = codegen_field_access->type_struct_field->type_entry;
+
     add_debug_source_node(g, node);
     return LLVMBuildStructGEP(g->builder, struct_ptr, codegen_field_access->field_index, "");
 }
@@ -243,34 +245,78 @@ static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node) {
             zig_panic("gen_field_access_expr bad array field");
         }
     } else if (struct_type->id == TypeTableEntryIdStruct) {
-        LLVMValueRef ptr = gen_field_ptr(g, node);
+        TypeTableEntry *type_entry;
+        LLVMValueRef ptr = gen_field_ptr(g, node, &type_entry);
         return LLVMBuildLoad(g->builder, ptr, "");
     } else {
         zig_panic("gen_field_access_expr bad struct type");
     }
 }
 
+static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *parent_node, AstNode *node,
+        TypeTableEntry **out_type_entry)
+{
+    LLVMValueRef target_ref;
+
+    if (node->type == NodeTypeSymbol) {
+        VariableTableEntry *var = find_variable(parent_node->codegen_node->expr_node.block_context,
+                &node->data.symbol);
+
+        // semantic checking ensures no variables are constant
+        assert(!var->is_const);
+
+        *out_type_entry = var->type;
+        target_ref = var->value_ref;
+    } else if (node->type == NodeTypeArrayAccessExpr) {
+        TypeTableEntry *array_type = get_expr_type(node->data.array_access_expr.array_ref_expr);
+        assert(array_type->id == TypeTableEntryIdArray);
+        *out_type_entry = array_type->data.array.child_type;
+        target_ref = gen_array_ptr(g, node);
+    } else if (node->type == NodeTypeFieldAccessExpr) {
+        target_ref = gen_field_ptr(g, node, out_type_entry);
+    } else {
+        zig_panic("bad assign target");
+    }
+
+    return target_ref;
+}
+
 static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypePrefixOpExpr);
     assert(node->data.prefix_op_expr.primary_expr);
 
-    LLVMValueRef expr = gen_expr(g, node->data.prefix_op_expr.primary_expr);
+    AstNode *expr_node = node->data.prefix_op_expr.primary_expr;
 
     switch (node->data.prefix_op_expr.prefix_op) {
+        case PrefixOpInvalid:
+            zig_unreachable();
         case PrefixOpNegation:
-            add_debug_source_node(g, node);
-            return LLVMBuildNeg(g->builder, expr, "");
+            {
+                LLVMValueRef expr = gen_expr(g, expr_node);
+                add_debug_source_node(g, node);
+                return LLVMBuildNeg(g->builder, expr, "");
+            }
         case PrefixOpBoolNot:
             {
+                LLVMValueRef expr = gen_expr(g, expr_node);
                 LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(expr));
                 add_debug_source_node(g, node);
                 return LLVMBuildICmp(g->builder, LLVMIntEQ, expr, zero, "");
             }
         case PrefixOpBinNot:
-            add_debug_source_node(g, node);
-            return LLVMBuildNot(g->builder, expr, "");
-        case PrefixOpInvalid:
-            zig_unreachable();
+            {
+                LLVMValueRef expr = gen_expr(g, expr_node);
+                add_debug_source_node(g, node);
+                return LLVMBuildNot(g->builder, expr, "");
+            }
+        case PrefixOpAddressOf:
+        case PrefixOpConstAddressOf:
+            {
+                add_debug_source_node(g, node);
+                TypeTableEntry *lvalue_type;
+                return gen_lvalue(g, node, expr_node, &lvalue_type);
+            }
+
     }
     zig_unreachable();
 }
@@ -571,33 +617,14 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
     return phi;
 }
 
-
 static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeBinOpExpr);
 
     AstNode *lhs_node = node->data.bin_op_expr.op1;
 
-    LLVMValueRef target_ref;
     TypeTableEntry *op1_type;
-    if (lhs_node->type == NodeTypeSymbol) {
-        VariableTableEntry *var = find_variable(node->codegen_node->expr_node.block_context,
-                &lhs_node->data.symbol);
-
-        // semantic checking ensures no variables are constant
-        assert(!var->is_const);
+    LLVMValueRef target_ref = gen_lvalue(g, node, lhs_node, &op1_type);
 
-        op1_type = var->type;
-        target_ref = var->value_ref;
-    } else if (lhs_node->type == NodeTypeArrayAccessExpr) {
-        TypeTableEntry *array_type = get_expr_type(lhs_node->data.array_access_expr.array_ref_expr);
-        assert(array_type->id == TypeTableEntryIdArray);
-        op1_type = array_type->data.array.child_type;
-        target_ref = gen_array_ptr(g, lhs_node);
-    } else if (lhs_node->type == NodeTypeFieldAccessExpr) {
-        target_ref = gen_field_ptr(g, lhs_node);
-    } else {
-        zig_panic("bad assign target");
-    }
     LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
 
     if (node->data.bin_op_expr.bin_op == BinOpTypeAssign) {
src/parser.cpp
@@ -58,6 +58,8 @@ static const char *prefix_op_str(PrefixOp prefix_op) {
         case PrefixOpNegation: return "-";
         case PrefixOpBoolNot: return "!";
         case PrefixOpBinNot: return "~";
+        case PrefixOpAddressOf: return "&";
+        case PrefixOpConstAddressOf: return "&const";
     }
     zig_unreachable();
 }
src/parser.hpp
@@ -198,6 +198,8 @@ enum PrefixOp {
     PrefixOpBoolNot,
     PrefixOpBinNot,
     PrefixOpNegation,
+    PrefixOpAddressOf,
+    PrefixOpConstAddressOf,
 };
 
 struct AstNodePrefixOpExpr {
test/run_tests.cpp
@@ -566,7 +566,7 @@ use "std.zig";
 
 export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     var foo : Foo;
-    foo.a = foo.a + 1;
+    foo.a += 1;
     foo.b = foo.a == 1;
     test_foo(foo);
     return 0;
@@ -749,7 +749,7 @@ fn f() {
     const a = 3;
     a = 4;
 }
-    )SOURCE", 1, ".tmp_source.zig:4:5: error: cannot assign to constant variable");
+    )SOURCE", 1, ".tmp_source.zig:4:5: error: cannot assign to constant");
 
     add_compile_fail_case("use of undeclared identifier", R"SOURCE(
 fn f() {
@@ -787,7 +787,7 @@ const x : i32 = 99;
 fn f() {
     x = 1;
 }
-    )SOURCE", 1, ".tmp_source.zig:4:5: error: cannot assign to constant variable");
+    )SOURCE", 1, ".tmp_source.zig:4:5: error: cannot assign to constant");
 
 
     add_compile_fail_case("missing else clause", R"SOURCE(