Commit e64c0941f9

Andrew Kelley <superjoe30@gmail.com>
2016-01-04 03:38:36
implement #sizeof()
closes #8
1 parent fa6e3ee
doc/langref.md
@@ -60,9 +60,11 @@ ParamDeclList : token(LParen) list(ParamDecl, token(Comma)) token(RParen)
 
 ParamDecl : token(Symbol) token(Colon) Type | token(Ellipsis)
 
-Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType | CompileTimeFnCall
+Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType | CompilerFnExpr
 
-CompileTimeFnCall : token(NumberSign) token(Symbol) token(LParen) Expression token(RParen)
+CompilerFnExpr : token(NumberSign) token(Symbol) token(LParen) Expression token(RParen)
+
+CompilerFnType : token(NumberSign) token(Symbol) token(LParen) Type token(RParen)
 
 PointerType : token(Ampersand) option(token(Const)) Type
 
@@ -152,7 +154,7 @@ ArrayAccessExpression : token(LBracket) Expression token(RBracket)
 
 PrefixOp : token(Not) | token(Dash) | token(Tilde) | (token(Ampersand) option(token(Const)))
 
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
 
 StructValueExpression : token(Type) token(LBrace) list(StructValueExpressionField, token(Comma)) token(RBrace)
 
src/analyze.cpp
@@ -59,7 +59,8 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
         case NodeTypeWhileExpr:
-        case NodeTypeCompilerFnCall:
+        case NodeTypeCompilerFnExpr:
+        case NodeTypeCompilerFnType:
             return node;
     }
     zig_panic("unreachable");
@@ -109,6 +110,20 @@ TypeTableEntry *new_type_table_entry(TypeTableEntryId id) {
     return entry;
 }
 
+static TypeTableEntry *get_number_literal_type_unsigned(CodeGen *g, uint64_t x) {
+    NumLit kind;
+    if (x <= UINT8_MAX) {
+        kind = NumLitU8;
+    } else if (x <= UINT16_MAX) {
+        kind = NumLitU16;
+    } else if (x <= UINT32_MAX) {
+        kind = NumLitU32;
+    } else {
+        kind = NumLitU64;
+    }
+    return g->num_lit_types[kind];
+}
+
 TypeTableEntry *get_pointer_to_type(CodeGen *g, TypeTableEntry *child_type, bool is_const) {
     TypeTableEntry **parent_pointer = is_const ?
         &child_type->pointer_const_parent :
@@ -279,15 +294,16 @@ static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node, ImportTableEntry
         case AstNodeTypeTypeCompilerExpr:
             {
                 AstNode *compiler_expr_node = node->data.type.compiler_expr;
-                Buf *fn_name = &compiler_expr_node->data.compiler_fn_call.name;
+                Buf *fn_name = &compiler_expr_node->data.compiler_fn_expr.name;
                 if (buf_eql_str(fn_name, "typeof")) {
-                    return analyze_expression(g, import, context, nullptr,
-                            compiler_expr_node->data.compiler_fn_call.expr);
+                    type_node->entry = analyze_expression(g, import, context, nullptr,
+                            compiler_expr_node->data.compiler_fn_expr.expr);
                 } else {
                     add_node_error(g, node,
                             buf_sprintf("invalid compiler function: '%s'", buf_ptr(fn_name)));
-                    return g->builtin_types.entry_invalid;
+                    type_node->entry = g->builtin_types.entry_invalid;
                 }
+                return type_node->entry;
             }
     }
     zig_unreachable();
@@ -625,7 +641,8 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
         case NodeTypeStructField:
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
-        case NodeTypeCompilerFnCall:
+        case NodeTypeCompilerFnExpr:
+        case NodeTypeCompilerFnType:
             zig_unreachable();
     }
 }
@@ -698,7 +715,8 @@ static void preview_types(CodeGen *g, ImportTableEntry *import, AstNode *node) {
         case NodeTypeStructField:
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
-        case NodeTypeCompilerFnCall:
+        case NodeTypeCompilerFnExpr:
+        case NodeTypeCompilerFnType:
             zig_unreachable();
     }
 }
@@ -1580,6 +1598,30 @@ static TypeTableEntry *analyze_if_var_expr(CodeGen *g, ImportTableEntry *import,
             node->data.if_var_expr.then_block, node->data.if_var_expr.else_node, node);
 }
 
+static TypeTableEntry *analyze_compiler_fn_type(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        TypeTableEntry *expected_type, AstNode *node)
+{
+    assert(node->type == NodeTypeCompilerFnType);
+
+    Buf *name = &node->data.compiler_fn_type.name;
+    if (buf_eql_str(name, "sizeof")) {
+        TypeTableEntry *type_entry = resolve_type(g, node->data.compiler_fn_type.type, import, context);
+        uint64_t size_in_bytes = type_entry->size_in_bits / 8;
+
+        TypeTableEntry *num_lit_type = get_number_literal_type_unsigned(g, size_in_bytes);
+
+        NumberLiteralNode *codegen_num_lit = &node->codegen_node->data.num_lit_node;
+        assert(!codegen_num_lit->resolved_type);
+        codegen_num_lit->resolved_type = resolve_type_compatibility(g, context, node, expected_type, num_lit_type);
+
+        return num_lit_type;
+    } else {
+        add_node_error(g, node,
+                buf_sprintf("invalid compiler function: '%s'", buf_ptr(name)));
+        return g->builtin_types.entry_invalid;
+    }
+}
+
 static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -1852,6 +1894,9 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeStructValueExpr:
             return_type = analyze_struct_val_expr(g, import, context, expected_type, node);
             break;
+        case NodeTypeCompilerFnType:
+            return_type = analyze_compiler_fn_type(g, import, context, expected_type, node);
+            break;
         case NodeTypeDirective:
         case NodeTypeFnDecl:
         case NodeTypeFnProto:
@@ -1866,7 +1911,7 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeStructDecl:
         case NodeTypeStructField:
         case NodeTypeStructValueField:
-        case NodeTypeCompilerFnCall:
+        case NodeTypeCompilerFnExpr:
             zig_unreachable();
     }
     assert(return_type);
@@ -2015,7 +2060,8 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
         case NodeTypeStructField:
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
-        case NodeTypeCompilerFnCall:
+        case NodeTypeCompilerFnExpr:
+        case NodeTypeCompilerFnType:
             zig_unreachable();
     }
 }
src/codegen.cpp
@@ -1191,6 +1191,58 @@ static LLVMValueRef gen_var_decl_expr(CodeGen *g, AstNode *node) {
             node->codegen_node->expr_node.block_context, false, &init_val);
 }
 
+static LLVMValueRef gen_number_literal_raw(CodeGen *g, AstNode *source_node,
+        NumberLiteralNode *codegen_num_lit, AstNodeNumberLiteral *num_lit_node)
+{
+    TypeTableEntry *type_entry = codegen_num_lit->resolved_type;
+    assert(type_entry);
+
+    // override the expression type for number literals
+    source_node->codegen_node->expr_node.type_entry = type_entry;
+
+    if (type_entry->id == TypeTableEntryIdInt) {
+        // here the union has int64_t and uint64_t and we purposefully read
+        // the uint64_t value in either case, because we want the twos
+        // complement representation
+
+        return LLVMConstInt(type_entry->type_ref,
+                num_lit_node->data.x_uint,
+                type_entry->data.integral.is_signed);
+    } else if (type_entry->id == TypeTableEntryIdFloat) {
+
+        return LLVMConstReal(type_entry->type_ref,
+                num_lit_node->data.x_float);
+    } else {
+        zig_panic("bad number literal type");
+    }
+}
+
+static LLVMValueRef gen_compiler_fn_type(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeCompilerFnType);
+
+    Buf *name = &node->data.compiler_fn_type.name;
+    if (buf_eql_str(name, "sizeof")) {
+        TypeTableEntry *type_entry = get_type_for_type_node(g, node->data.compiler_fn_type.type);
+        NumberLiteralNode *codegen_num_lit = &node->codegen_node->data.num_lit_node;
+        AstNodeNumberLiteral num_lit_node;
+        num_lit_node.kind = type_entry->data.num_lit.kind;
+        num_lit_node.overflow = false;
+        num_lit_node.data.x_uint = type_entry->size_in_bits / 8;
+        return gen_number_literal_raw(g, node, codegen_num_lit, &num_lit_node);
+    } else {
+        zig_unreachable();
+    }
+}
+
+static LLVMValueRef gen_number_literal(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeNumberLiteral);
+
+    NumberLiteralNode *codegen_num_lit = &node->codegen_node->data.num_lit_node;
+    assert(codegen_num_lit);
+
+    return gen_number_literal_raw(g, node, codegen_num_lit, &node->data.number_literal);
+}
+
 static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeBinOpExpr:
@@ -1228,31 +1280,7 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
         case NodeTypeAsmExpr:
             return gen_asm_expr(g, node);
         case NodeTypeNumberLiteral:
-            {
-                NumberLiteralNode *codegen_num_lit = &node->codegen_node->data.num_lit_node;
-                assert(codegen_num_lit);
-                TypeTableEntry *type_entry = codegen_num_lit->resolved_type;
-                assert(type_entry);
-
-                // override the expression type for number literals
-                node->codegen_node->expr_node.type_entry = type_entry;
-
-                if (type_entry->id == TypeTableEntryIdInt) {
-                    // here the union has int64_t and uint64_t and we purposefully read
-                    // the uint64_t value in either case, because we want the twos
-                    // complement representation
-
-                    return LLVMConstInt(type_entry->type_ref,
-                            node->data.number_literal.data.x_uint,
-                            type_entry->data.integral.is_signed);
-                } else if (type_entry->id == TypeTableEntryIdFloat) {
-
-                    return LLVMConstReal(type_entry->type_ref,
-                            node->data.number_literal.data.x_float);
-                } else {
-                    zig_panic("bad number literal type");
-                }
-            }
+            return gen_number_literal(g, node);
         case NodeTypeStringLiteral:
             {
                 Buf *str = &node->data.string_literal.buf;
@@ -1313,6 +1341,8 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
             }
         case NodeTypeStructValueExpr:
             return gen_struct_val_expr(g, node);
+        case NodeTypeCompilerFnType:
+            return gen_compiler_fn_type(g, node);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -1326,7 +1356,7 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
         case NodeTypeStructDecl:
         case NodeTypeStructField:
         case NodeTypeStructValueField:
-        case NodeTypeCompilerFnCall:
+        case NodeTypeCompilerFnExpr:
             zig_unreachable();
     }
     zig_unreachable();
src/parser.cpp
@@ -142,8 +142,10 @@ const char *node_type_str(NodeType node_type) {
             return "StructValueExpr";
         case NodeTypeStructValueField:
             return "StructValueField";
-        case NodeTypeCompilerFnCall:
-            return "CompilerFnCall";
+        case NodeTypeCompilerFnExpr:
+            return "CompilerFnExpr";
+        case NodeTypeCompilerFnType:
+            return "CompilerFnType";
     }
     zig_unreachable();
 }
@@ -410,7 +412,10 @@ void ast_print(AstNode *node, int indent) {
             fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_val_field.name));
             ast_print(node->data.struct_val_field.expr, indent + 2);
             break;
-        case NodeTypeCompilerFnCall:
+        case NodeTypeCompilerFnExpr:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            break;
+        case NodeTypeCompilerFnType:
             fprintf(stderr, "%s\n", node_type_str(node->type));
             break;
     }
@@ -996,7 +1001,32 @@ static void ast_parse_type_assume_amp(ParseContext *pc, int *token_index, AstNod
 }
 
 /*
-CompileTimeFnCall : token(NumberSign) token(Symbol) token(LParen) Expression token(RParen)
+CompilerFnType : token(NumberSign) token(Symbol) token(LParen) Type token(RParen)
+*/
+static AstNode *ast_parse_compiler_fn_type(ParseContext *pc, int *token_index, bool mandatory) {
+    Token *token = &pc->tokens->at(*token_index);
+
+    if (token->id == TokenIdNumberSign) {
+        *token_index += 1;
+    } else if (mandatory) {
+        ast_invalid_token_error(pc, token);
+    } else {
+        return nullptr;
+    }
+
+    Token *name_symbol = ast_eat_token(pc, token_index, TokenIdSymbol);
+    ast_eat_token(pc, token_index, TokenIdLParen);
+
+    AstNode *node = ast_create_node(pc, NodeTypeCompilerFnType, token);
+    ast_buf_from_token(pc, name_symbol, &node->data.compiler_fn_type.name);
+    node->data.compiler_fn_type.type = ast_parse_type(pc, token_index);
+
+    ast_eat_token(pc, token_index, TokenIdRParen);
+    return node;
+}
+
+/*
+CompilerFnExpr : token(NumberSign) token(Symbol) token(LParen) Expression token(RParen)
 */
 static AstNode *ast_parse_compiler_fn_call(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
@@ -1012,16 +1042,16 @@ static AstNode *ast_parse_compiler_fn_call(ParseContext *pc, int *token_index, b
     Token *name_symbol = ast_eat_token(pc, token_index, TokenIdSymbol);
     ast_eat_token(pc, token_index, TokenIdLParen);
 
-    AstNode *node = ast_create_node(pc, NodeTypeCompilerFnCall, token);
-    ast_buf_from_token(pc, name_symbol, &node->data.compiler_fn_call.name);
-    node->data.compiler_fn_call.expr = ast_parse_expression(pc, token_index, true);
+    AstNode *node = ast_create_node(pc, NodeTypeCompilerFnExpr, token);
+    ast_buf_from_token(pc, name_symbol, &node->data.compiler_fn_expr.name);
+    node->data.compiler_fn_expr.expr = ast_parse_expression(pc, token_index, true);
 
     ast_eat_token(pc, token_index, TokenIdRParen);
     return node;
 }
 
 /*
-Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType | CompileTimeFnCall
+Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType | CompilerFnExpr
 PointerType : token(Ampersand) option(token(Const)) Type
 ArrayType : token(LBracket) Type token(Semicolon) token(Number) token(RBracket)
 */
@@ -1029,10 +1059,10 @@ static AstNode *ast_parse_type(ParseContext *pc, int *token_index) {
     Token *token = &pc->tokens->at(*token_index);
     AstNode *node = ast_create_node(pc, NodeTypeType, token);
 
-    AstNode *compiler_fn_call = ast_parse_compiler_fn_call(pc, token_index, false);
-    if (compiler_fn_call) {
+    AstNode *compiler_fn_expr = ast_parse_compiler_fn_call(pc, token_index, false);
+    if (compiler_fn_expr) {
         node->data.type.type = AstNodeTypeTypeCompilerExpr;
-        node->data.type.compiler_expr = compiler_fn_call;
+        node->data.type.compiler_expr = compiler_fn_expr;
         return node;
     }
 
@@ -1238,7 +1268,7 @@ static AstNode *ast_parse_struct_val_expr(ParseContext *pc, int *token_index) {
 }
 
 /*
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
 */
 static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
@@ -1317,6 +1347,11 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
         return block_expr_node;
     }
 
+    AstNode *compiler_fn_type = ast_parse_compiler_fn_type(pc, token_index, false);
+    if (compiler_fn_type) {
+        return compiler_fn_type;
+    }
+
     if (!mandatory)
         return nullptr;
 
src/parser.hpp
@@ -57,7 +57,8 @@ enum NodeType {
     NodeTypeStructField,
     NodeTypeStructValueExpr,
     NodeTypeStructValueField,
-    NodeTypeCompilerFnCall,
+    NodeTypeCompilerFnExpr,
+    NodeTypeCompilerFnType,
 };
 
 struct AstNodeRoot {
@@ -332,11 +333,16 @@ struct AstNodeStructValueExpr {
     ZigList<AstNode *> fields;
 };
 
-struct AstNodeCompilerFnCall {
+struct AstNodeCompilerFnExpr {
     Buf name;
     AstNode *expr;
 };
 
+struct AstNodeCompilerFnType {
+    Buf name;
+    AstNode *type;
+};
+
 struct AstNode {
     enum NodeType type;
     int line;
@@ -376,7 +382,8 @@ struct AstNode {
         AstNodeNumberLiteral number_literal;
         AstNodeStructValueExpr struct_val_expr;
         AstNodeStructValueField struct_val_field;
-        AstNodeCompilerFnCall compiler_fn_call;
+        AstNodeCompilerFnExpr compiler_fn_expr;
+        AstNodeCompilerFnType compiler_fn_type;
         Buf symbol;
         bool bool_literal;
     } data;
test/run_tests.cpp
@@ -698,16 +698,17 @@ fn outer() -> isize {
 }
     )SOURCE", "OK\n");
 
-    add_simple_case("#typeof()", R"SOURCE(
+    add_simple_case("#sizeof() and #typeof()", R"SOURCE(
 use "std.zig";
 const x: u16 = 13;
 const z: #typeof(x) = 19;
 pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     const y: #typeof(x) = 120;
-    print_str("OK\n");
+    print_u64(#sizeof(#typeof(y)));
+    print_str("\n");
     return 0;
 }
-    )SOURCE", "OK\n");
+    )SOURCE", "2\n");
 }
 
 ////////////////////////////////////////////////////////////////////////////////////