Commit 50357dad45

Andrew Kelley <superjoe30@gmail.com>
2015-12-24 08:00:23
add struct value expression
1 parent 9ce36ba
doc/langref.md
@@ -144,7 +144,11 @@ ArrayAccessExpression : token(LBracket) Expression token(RBracket)
 
 PrefixOp : token(Not) | token(Dash) | token(Tilde) | (token(Ampersand) option(token(Const)))
 
-PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | token(Symbol) | Goto | BlockExpression
+PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | Goto | BlockExpression | token(Symbol) | StructValueExpression
+
+StructValueExpression : token(Type) token(LBrace) list(StructValueExpressionField, token(Comma)) token(RBrace)
+
+StructValueExpressionField : token(Dot) token(Symbol) token(Eq) Expression
 
 Goto: token(Goto) token(Symbol)
 
example/structs/structs.zig
@@ -21,6 +21,8 @@ pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
 
     test_byval_assign();
 
+    test_initializer();
+
     print_str("OK\n");
     return 0;
 }
@@ -78,3 +80,8 @@ fn test_byval_assign() {
     if foo2.a != 1234 { print_str("BAD - byval assignment failed\n"); }
 
 }
+
+fn test_initializer() {
+    const val = Val { .x = 42 };
+    if val.x != 42 { print_str("BAD\n"); }
+}
src/analyze.cpp
@@ -52,6 +52,8 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeFieldAccessExpr:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
+        case NodeTypeStructValueExpr:
+        case NodeTypeStructValueField:
             return node;
     }
     zig_panic("unreachable");
@@ -529,6 +531,8 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
         case NodeTypeAsmExpr:
         case NodeTypeFieldAccessExpr:
         case NodeTypeStructField:
+        case NodeTypeStructValueExpr:
+        case NodeTypeStructValueField:
             zig_unreachable();
     }
 }
@@ -594,6 +598,8 @@ static void preview_types(CodeGen *g, ImportTableEntry *import, AstNode *node) {
         case NodeTypeAsmExpr:
         case NodeTypeFieldAccessExpr:
         case NodeTypeStructField:
+        case NodeTypeStructValueExpr:
+        case NodeTypeStructValueField:
             zig_unreachable();
     }
 }
@@ -1060,6 +1066,7 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
 enum LValPurpose {
     LValPurposeAssign,
     LValPurposeAddressOf,
+    LValPurposeNotLVal,
 };
 
 static TypeTableEntry *analyze_lvalue(CodeGen *g, ImportTableEntry *import, BlockContext *block_context,
@@ -1269,6 +1276,62 @@ static TypeTableEntry *analyze_number_literal_expr(CodeGen *g, ImportTableEntry
     }
 }
 
+static TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name, int *index) {
+    assert(type_entry->id == TypeTableEntryIdStruct);
+    for (int i = 0; i < type_entry->data.structure.field_count; i += 1) {
+        TypeStructField *field = &type_entry->data.structure.fields[i];
+        if (buf_eql_buf(field->name, name)) {
+            *index = i;
+            return field;
+        }
+    }
+    return nullptr;
+}
+
+static TypeTableEntry *analyze_struct_val_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        TypeTableEntry *expected_type, AstNode *node)
+{
+    assert(node->type == NodeTypeStructValueExpr);
+
+    AstNodeStructValueExpr *struct_val_expr = &node->data.struct_val_expr;
+
+    TypeTableEntry *type_entry = resolve_type(g, struct_val_expr->type);
+
+    if (type_entry->id == TypeTableEntryIdInvalid) {
+        return g->builtin_types.entry_invalid;
+    } else if (type_entry->id != TypeTableEntryIdStruct) {
+        add_node_error(g, node,
+            buf_sprintf("type '%s' is not a struct", buf_ptr(&type_entry->name)));
+        return g->builtin_types.entry_invalid;
+    }
+
+    assert(node->codegen_node);
+    node->codegen_node->data.struct_val_expr_node.type_entry = type_entry;
+    node->codegen_node->data.struct_val_expr_node.source_node = node;
+    context->struct_val_expr_alloca_list.append(&node->codegen_node->data.struct_val_expr_node);
+
+    for (int i = 0; i < struct_val_expr->fields.length; i += 1) {
+        AstNode *val_field_node = struct_val_expr->fields.at(i);
+        int field_index;
+        TypeStructField *type_field = find_struct_type_field(type_entry,
+                &val_field_node->data.struct_val_field.name, &field_index);
+
+        if (!type_field) {
+            add_node_error(g, val_field_node,
+                buf_sprintf("type '%s' is not a struct", buf_ptr(&type_entry->name)));
+            continue;
+        }
+
+        alloc_codegen_node(val_field_node);
+        val_field_node->codegen_node->data.struct_val_field_node.index = field_index;
+
+        analyze_expression(g, import, context, type_field->type_entry,
+                val_field_node->data.struct_val_field.expr);
+    }
+
+    return type_entry;
+}
+
 static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -1545,6 +1608,9 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                 }
                 break;
             }
+        case NodeTypeStructValueExpr:
+            return_type = analyze_struct_val_expr(g, import, context, expected_type, node);
+            break;
         case NodeTypeDirective:
         case NodeTypeFnDecl:
         case NodeTypeFnProto:
@@ -1558,6 +1624,7 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeLabel:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
+        case NodeTypeStructValueField:
             zig_unreachable();
     }
     assert(return_type);
@@ -1690,6 +1757,8 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
         case NodeTypeAsmExpr:
         case NodeTypeFieldAccessExpr:
         case NodeTypeStructField:
+        case NodeTypeStructValueExpr:
+        case NodeTypeStructValueField:
             zig_unreachable();
     }
 }
src/analyze.hpp
@@ -18,6 +18,7 @@ struct BlockContext;
 struct TypeTableEntry;
 struct VariableTableEntry;
 struct CastNode;
+struct StructValExprNode;
 
 struct TypeTableEntryPointer {
     TypeTableEntry *child_type;
@@ -223,6 +224,7 @@ struct BlockContext {
     BlockContext *parent; // null when this is the root
     HashMap<Buf *, VariableTableEntry *, buf_hash, buf_eql_buf> variable_table;
     ZigList<CastNode *> cast_expr_alloca_list;
+    ZigList<StructValExprNode *> struct_val_expr_alloca_list;
     LLVMZigDIScope *di_scope;
 };
 
@@ -292,6 +294,16 @@ struct VarDeclNode {
     TypeTableEntry *type;
 };
 
+struct StructValFieldNode {
+    int index;
+};
+
+struct StructValExprNode {
+    TypeTableEntry *type_entry;
+    LLVMValueRef ptr;
+    AstNode *source_node;
+};
+
 struct CodeGenNode {
     union {
         TypeNode type_node; // for NodeTypeType
@@ -305,6 +317,8 @@ struct CodeGenNode {
         CastNode cast_node; // for NodeTypeCastExpr
         NumberLiteralNode num_lit_node; // for NodeTypeNumberLiteral
         VarDeclNode var_decl_node; // for NodeTypeVariableDeclaration
+        StructValFieldNode struct_val_field_node; // for NodeTypeStructValueField
+        StructValExprNode struct_val_expr_node; // for NodeTypeStructValueExpr
     } data;
     ExprNode expr_node; // for all the expression nodes
 };
src/codegen.cpp
@@ -657,6 +657,28 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
     return phi;
 }
 
+static LLVMValueRef gen_struct_memcpy(CodeGen *g, AstNode *source_node, LLVMValueRef src, LLVMValueRef dest,
+        TypeTableEntry *type_entry)
+{
+    assert(type_entry->id == TypeTableEntryIdStruct);
+
+    LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
+
+    add_debug_source_node(g, source_node);
+    LLVMValueRef src_ptr = LLVMBuildBitCast(g->builder, src, ptr_u8, "");
+    LLVMValueRef dest_ptr = LLVMBuildBitCast(g->builder, dest, ptr_u8, "");
+
+    LLVMValueRef params[] = {
+        dest_ptr, // dest pointer
+        src_ptr, // source pointer
+        LLVMConstInt(LLVMIntType(g->pointer_size_bytes * 8), type_entry->size_in_bits / 8, false), // byte count
+        LLVMConstInt(LLVMInt32Type(), type_entry->align_in_bits / 8, false), // align in bytes
+        LLVMConstNull(LLVMInt1Type()), // is volatile
+    };
+
+    return LLVMBuildCall(g->builder, g->memcpy_fn_val, params, 5, "");
+}
+
 static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeBinOpExpr);
 
@@ -675,21 +697,7 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
         assert(op1_type == op2_type);
         assert(node->data.bin_op_expr.bin_op == BinOpTypeAssign);
 
-        LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
-
-        add_debug_source_node(g, node);
-        LLVMValueRef src_ptr = LLVMBuildBitCast(g->builder, value, ptr_u8, "");
-        LLVMValueRef dest_ptr = LLVMBuildBitCast(g->builder, target_ref, ptr_u8, "");
-
-        LLVMValueRef params[] = {
-            dest_ptr, // dest pointer
-            src_ptr, // source pointer
-            LLVMConstInt(LLVMIntType(g->pointer_size_bytes * 8), op1_type->size_in_bits / 8, false), // byte count
-            LLVMConstInt(LLVMInt32Type(), op1_type->align_in_bits / 8, false), // align in bits
-            LLVMConstNull(LLVMInt1Type()), // is volatile
-        };
-
-        return LLVMBuildCall(g->builder, g->memcpy_fn_val, params, 5, "");
+        return gen_struct_memcpy(g, node, value, target_ref, op1_type);
     }
 
     if (node->data.bin_op_expr.bin_op != BinOpTypeAssign) {
@@ -970,6 +978,34 @@ static LLVMValueRef gen_asm_expr(CodeGen *g, AstNode *node) {
     return LLVMBuildCall(g->builder, asm_fn, param_values, input_and_output_count, "");
 }
 
+static LLVMValueRef gen_struct_val_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeStructValueExpr);
+
+    TypeTableEntry *type_entry = get_expr_type(node);
+
+    assert(type_entry->id == TypeTableEntryIdStruct);
+
+    int field_count = type_entry->data.structure.field_count;
+    assert(field_count == node->data.struct_val_expr.fields.length);
+
+    StructValExprNode *struct_val_expr_node = &node->codegen_node->data.struct_val_expr_node;
+    LLVMValueRef tmp_struct_ptr = struct_val_expr_node->ptr;
+
+    for (int i = 0; i < field_count; i += 1) {
+        AstNode *field_node = node->data.struct_val_expr.fields.at(i);
+        int index = field_node->codegen_node->data.struct_val_field_node.index;
+        TypeStructField *type_struct_field = &type_entry->data.structure.fields[index];
+        assert(buf_eql_buf(type_struct_field->name, &field_node->data.struct_val_field.name));
+
+        add_debug_source_node(g, field_node);
+        LLVMValueRef field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, index, "");
+        LLVMValueRef value = gen_expr(g, field_node->data.struct_val_field.expr);
+        LLVMBuildStore(g->builder, value, field_ptr);
+    }
+
+    return tmp_struct_ptr;
+}
+
 static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeBinOpExpr:
@@ -994,8 +1030,13 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
                 if (variable->type->id == TypeTableEntryIdVoid) {
                     return nullptr;
                 } else {
-                    add_debug_source_node(g, node);
-                    LLVMValueRef store_instr = LLVMBuildStore(g->builder, value, variable->value_ref);
+                    LLVMValueRef store_instr;
+                    if (variable->type->id == TypeTableEntryIdStruct && node->data.variable_declaration.expr) {
+                        store_instr = gen_struct_memcpy(g, node, value, variable->value_ref, variable->type);
+                    } else {
+                        add_debug_source_node(g, node);
+                        store_instr = LLVMBuildStore(g->builder, value, variable->value_ref);
+                    }
 
                     LLVMZigDILocation *debug_loc = LLVMZigGetDebugLoc(node->line + 1, node->column + 1,
                             g->cur_block_context->di_scope);
@@ -1035,7 +1076,7 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
                 TypeTableEntry *type_entry = codegen_num_lit->resolved_type;
                 assert(type_entry);
 
-                // TODO this is kinda iffy. make sure josh is on board with this
+                // override the expression type for number literals
                 node->codegen_node->expr_node.type_entry = type_entry;
 
                 if (type_entry->id == TypeTableEntryIdInt) {
@@ -1104,6 +1145,8 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
                 LLVMPositionBuilderAtEnd(g->builder, basic_block);
                 return nullptr;
             }
+        case NodeTypeStructValueExpr:
+            return gen_struct_val_expr(g, node);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -1116,6 +1159,7 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
         case NodeTypeUse:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
+        case NodeTypeStructValueField:
             zig_unreachable();
     }
     zig_unreachable();
@@ -1358,6 +1402,14 @@ static void do_code_gen(CodeGen *g) {
                 add_debug_source_node(g, cast_node->source_node);
                 cast_node->ptr = LLVMBuildAlloca(g->builder, cast_node->type->type_ref, "");
             }
+
+            // allocate structs which are struct value expressions
+            for (int alloca_i = 0; alloca_i < block_context->struct_val_expr_alloca_list.length; alloca_i += 1) {
+                StructValExprNode *struct_val_expr_node = block_context->struct_val_expr_alloca_list.at(alloca_i);
+                add_debug_source_node(g, struct_val_expr_node->source_node);
+                struct_val_expr_node->ptr = LLVMBuildAlloca(g->builder,
+                        struct_val_expr_node->type_entry->type_ref, "");
+            }
         }
 
         TypeTableEntry *implicit_return_type = codegen_fn_def->implicit_return_type;
src/parser.cpp
@@ -128,6 +128,10 @@ const char *node_type_str(NodeType node_type) {
             return "StructDecl";
         case NodeTypeStructField:
             return "StructField";
+        case NodeTypeStructValueExpr:
+            return "StructValueExpr";
+        case NodeTypeStructValueField:
+            return "StructValueField";
     }
     zig_unreachable();
 }
@@ -341,6 +345,18 @@ void ast_print(AstNode *node, int indent) {
             fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_field.name));
             ast_print(node->data.struct_field.type, indent + 2);
             break;
+        case NodeTypeStructValueExpr:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            ast_print(node->data.struct_val_expr.type, indent + 2);
+            for (int i = 0; i < node->data.struct_val_expr.fields.length; i += 1) {
+                AstNode *child = node->data.struct_val_expr.fields.at(i);
+                ast_print(child, indent + 2);
+            }
+            break;
+        case NodeTypeStructValueField:
+            fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_val_field.name));
+            ast_print(node->data.struct_val_field.expr, indent + 2);
+            break;
     }
 }
 
@@ -1035,7 +1051,51 @@ static AstNode *ast_parse_grouped_expr(ParseContext *pc, int *token_index, bool
 }
 
 /*
-PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | token(Symbol) | Goto | BlockExpression
+StructValueExpression : token(Symbol) token(LBrace) list(StructValueExpressionField, token(Comma)) token(RBrace)
+StructValueExpressionField : token(Dot) token(Symbol) token(Eq) Expression
+*/
+static AstNode *ast_parse_struct_val_expr(ParseContext *pc, int *token_index) {
+    Token *first_token = &pc->tokens->at(*token_index);
+    AstNode *node = ast_create_node(pc, NodeTypeStructValueExpr, first_token);
+
+    node->data.struct_val_expr.type = ast_parse_type(pc, token_index);
+
+    ast_eat_token(pc, token_index, TokenIdLBrace);
+
+    for (;;) {
+        Token *token = &pc->tokens->at(*token_index);
+        *token_index += 1;
+
+        if (token->id == TokenIdRBrace) {
+            return node;
+        } else if (token->id == TokenIdDot) {
+            Token *field_name_tok = ast_eat_token(pc, token_index, TokenIdSymbol);
+            ast_eat_token(pc, token_index, TokenIdEq);
+
+            AstNode *field_node = ast_create_node(pc, NodeTypeStructValueField, token);
+
+            ast_buf_from_token(pc, field_name_tok, &field_node->data.struct_val_field.name);
+            field_node->data.struct_val_field.expr = ast_parse_expression(pc, token_index, true);
+
+            node->data.struct_val_expr.fields.append(field_node);
+
+            Token *comma_tok = &pc->tokens->at(*token_index);
+            if (comma_tok->id == TokenIdComma) {
+                *token_index += 1;
+            } else if (comma_tok->id != TokenIdRBrace) {
+                ast_invalid_token_error(pc, comma_tok);
+            } else {
+                *token_index += 1;
+                return node;
+            }
+        } else {
+            ast_invalid_token_error(pc, token);
+        }
+    }
+}
+
+/*
+PrimaryExpression : token(Number) | token(String) | KeywordLiteral | GroupedExpression | Goto | BlockExpression | token(Symbol) | StructValueExpression
 */
 static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
@@ -1069,10 +1129,16 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
         *token_index += 1;
         return node;
     } else if (token->id == TokenIdSymbol) {
-        AstNode *node = ast_create_node(pc, NodeTypeSymbol, token);
-        ast_buf_from_token(pc, token, &node->data.symbol);
-        *token_index += 1;
-        return node;
+        Token *next_token = &pc->tokens->at(*token_index + 1);
+
+        if (next_token->id == TokenIdLBrace) {
+            return ast_parse_struct_val_expr(pc, token_index);
+        } else {
+            *token_index += 1;
+            AstNode *node = ast_create_node(pc, NodeTypeSymbol, token);
+            ast_buf_from_token(pc, token, &node->data.symbol);
+            return node;
+        }
     } else if (token->id == TokenIdKeywordGoto) {
         AstNode *node = ast_create_node(pc, NodeTypeGoto, token);
         *token_index += 1;
src/parser.hpp
@@ -50,6 +50,8 @@ enum NodeType {
     NodeTypeAsmExpr,
     NodeTypeStructDecl,
     NodeTypeStructField,
+    NodeTypeStructValueExpr,
+    NodeTypeStructValueField,
 };
 
 struct AstNodeRoot {
@@ -296,6 +298,16 @@ struct AstNodeNumberLiteral {
     } data;
 };
 
+struct AstNodeStructValueField {
+    Buf name;
+    AstNode *expr;
+};
+
+struct AstNodeStructValueExpr {
+    AstNode *type;
+    ZigList<AstNode *> fields;
+};
+
 struct AstNode {
     enum NodeType type;
     int line;
@@ -330,6 +342,8 @@ struct AstNode {
         AstNodeStructField struct_field;
         AstNodeStringLiteral string_literal;
         AstNodeNumberLiteral number_literal;
+        AstNodeStructValueExpr struct_val_expr;
+        AstNodeStructValueField struct_val_field;
         Buf symbol;
         bool bool_literal;
     } data;
test/run_tests.cpp
@@ -575,6 +575,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     }
     test_point_to_self();
     test_byval_assign();
+    test_initializer();
     print_str("OK\n");
     return 0;
 }
@@ -624,6 +625,10 @@ fn test_byval_assign() {
     foo2 = foo1;
 
     if foo2.a != 1234 { print_str("BAD - byval assignment failed\n"); }
+}
+fn test_initializer() {
+    const val = Val { .x = 42 };
+    if val.x != 42 { print_str("BAD\n"); }
 }
     )SOURCE", "OK\n");