Commit fa6e3eec46

Andrew Kelley <superjoe30@gmail.com>
2016-01-04 02:17:50
add #typeof() compiler function
1 parent b453345
doc/langref.md
@@ -60,7 +60,9 @@ ParamDeclList : token(LParen) list(ParamDecl, token(Comma)) token(RParen)
 
 ParamDecl : token(Symbol) token(Colon) Type | token(Ellipsis)
 
-Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType
+Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType | CompileTimeFnCall
+
+CompileTimeFnCall : token(NumberSign) token(Symbol) token(LParen) Expression token(RParen)
 
 PointerType : token(Ampersand) option(token(Const)) Type
 
example/rand/main.zig
@@ -2,7 +2,7 @@
 const ARRAY_SIZE : u16 = 624;
 
 /// Use `rand_init` to initialize this state.
-pub struct Rand {
+struct Rand {
     array: [u32; ARRAY_SIZE],
     index: #typeof(ARRAY_SIZE),
 
src/analyze.cpp
@@ -59,6 +59,7 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
         case NodeTypeWhileExpr:
+        case NodeTypeCompilerFnCall:
             return node;
     }
     zig_panic("unreachable");
@@ -208,7 +209,7 @@ static TypeTableEntry *get_array_type(CodeGen *g, TypeTableEntry *child_type, ui
     }
 }
 
-static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node) {
+static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node, ImportTableEntry *import, BlockContext *context) {
     assert(node->type == NodeTypeType);
     alloc_codegen_node(node);
     TypeNode *type_node = &node->codegen_node->data.type_node;
@@ -228,7 +229,7 @@ static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node) {
             }
         case AstNodeTypeTypePointer:
             {
-                resolve_type(g, node->data.type.child_type);
+                resolve_type(g, node->data.type.child_type, import, context);
                 TypeTableEntry *child_type = node->data.type.child_type->codegen_node->data.type_node.entry;
                 assert(child_type);
                 if (child_type->id == TypeTableEntryIdUnreachable) {
@@ -242,7 +243,7 @@ static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node) {
             }
         case AstNodeTypeTypeArray:
             {
-                resolve_type(g, node->data.type.child_type);
+                resolve_type(g, node->data.type.child_type, import, context);
                 TypeTableEntry *child_type = node->data.type.child_type->codegen_node->data.type_node.entry;
                 if (child_type->id == TypeTableEntryIdUnreachable) {
                     add_node_error(g, node,
@@ -263,7 +264,7 @@ static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node) {
             }
         case AstNodeTypeTypeMaybe:
             {
-                resolve_type(g, node->data.type.child_type);
+                resolve_type(g, node->data.type.child_type, import, context);
                 TypeTableEntry *child_type = node->data.type.child_type->codegen_node->data.type_node.entry;
                 assert(child_type);
                 if (child_type->id == TypeTableEntryIdUnreachable) {
@@ -275,11 +276,26 @@ static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node) {
                 type_node->entry = get_maybe_type(g, child_type);
                 return type_node->entry;
             }
+        case AstNodeTypeTypeCompilerExpr:
+            {
+                AstNode *compiler_expr_node = node->data.type.compiler_expr;
+                Buf *fn_name = &compiler_expr_node->data.compiler_fn_call.name;
+                if (buf_eql_str(fn_name, "typeof")) {
+                    return analyze_expression(g, import, context, nullptr,
+                            compiler_expr_node->data.compiler_fn_call.expr);
+                } else {
+                    add_node_error(g, node,
+                            buf_sprintf("invalid compiler function: '%s'", buf_ptr(fn_name)));
+                    return g->builtin_types.entry_invalid;
+                }
+            }
     }
     zig_unreachable();
 }
 
-static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry) {
+static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry,
+        ImportTableEntry *import)
+{
     assert(node->type == NodeTypeFnProto);
 
     for (int i = 0; i < node->data.fn_proto.directives->length; i += 1) {
@@ -310,7 +326,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t
     for (int i = 0; i < node->data.fn_proto.params.length; i += 1) {
         AstNode *child = node->data.fn_proto.params.at(i);
         assert(child->type == NodeTypeParamDecl);
-        TypeTableEntry *type_entry = resolve_type(g, child->data.param_decl.type);
+        TypeTableEntry *type_entry = resolve_type(g, child->data.param_decl.type, import, import->block_context);
         if (type_entry->id == TypeTableEntryIdUnreachable) {
             add_node_error(g, child->data.param_decl.type,
                 buf_sprintf("parameter of type 'unreachable' not allowed"));
@@ -322,7 +338,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t
         }
     }
 
-    resolve_type(g, node->data.fn_proto.return_type);
+    resolve_type(g, node->data.fn_proto.return_type, import, import->block_context);
 }
 
 static void preview_function_labels(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry) {
@@ -383,7 +399,7 @@ static void resolve_struct_type(CodeGen *g, ImportTableEntry *import, TypeTableE
         AstNode *field_node = decl_node->data.struct_decl.fields.at(i);
         TypeStructField *type_struct_field = &struct_type->data.structure.fields[i];
         type_struct_field->name = &field_node->data.struct_field.name;
-        type_struct_field->type_entry = resolve_type(g, field_node->data.struct_field.type);
+        type_struct_field->type_entry = resolve_type(g, field_node->data.struct_field.type, import, import->block_context);
 
         if (type_struct_field->type_entry->id == TypeTableEntryIdStruct) {
             resolve_struct_type(g, import, type_struct_field->type_entry);
@@ -453,7 +469,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                 fn_table_entry->import_entry = import;
                 fn_table_entry->label_table.init(8);
 
-                resolve_function_proto(g, fn_proto, fn_table_entry);
+                resolve_function_proto(g, fn_proto, fn_table_entry, import);
 
                 Buf *name = &fn_proto->data.fn_proto.name;
                 g->fn_protos.append(fn_table_entry);
@@ -512,7 +528,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                         g->fn_table.put(proto_name, fn_table_entry);
                     }
 
-                    resolve_function_proto(g, proto_node, fn_table_entry);
+                    resolve_function_proto(g, proto_node, fn_table_entry, import);
 
 
                     alloc_codegen_node(proto_node);
@@ -609,6 +625,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
         case NodeTypeStructField:
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
+        case NodeTypeCompilerFnCall:
             zig_unreachable();
     }
 }
@@ -681,6 +698,7 @@ static void preview_types(CodeGen *g, ImportTableEntry *import, AstNode *node) {
         case NodeTypeStructField:
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
+        case NodeTypeCompilerFnCall:
             zig_unreachable();
     }
 }
@@ -1132,7 +1150,7 @@ static bool is_op_allowed(TypeTableEntry *type, BinOpType op) {
 static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
-    TypeTableEntry *wanted_type = resolve_type(g, node->data.cast_expr.type);
+    TypeTableEntry *wanted_type = resolve_type(g, node->data.cast_expr.type, import, context);
     TypeTableEntry *actual_type = analyze_expression(g, import, context, nullptr, node->data.cast_expr.expr);
 
     if (wanted_type->id == TypeTableEntryIdInvalid ||
@@ -1328,7 +1346,7 @@ static VariableTableEntry *analyze_variable_declaration_raw(CodeGen *g, ImportTa
 {
     TypeTableEntry *explicit_type = nullptr;
     if (variable_declaration->type != nullptr) {
-        explicit_type = resolve_type(g, variable_declaration->type);
+        explicit_type = resolve_type(g, variable_declaration->type, import, context);
         if (explicit_type->id == TypeTableEntryIdUnreachable) {
             add_node_error(g, variable_declaration->type,
                 buf_sprintf("variable of type 'unreachable' not allowed"));
@@ -1428,7 +1446,7 @@ static TypeTableEntry *analyze_struct_val_expr(CodeGen *g, ImportTableEntry *imp
 
     AstNodeStructValueExpr *struct_val_expr = &node->data.struct_val_expr;
 
-    TypeTableEntry *type_entry = resolve_type(g, struct_val_expr->type);
+    TypeTableEntry *type_entry = resolve_type(g, struct_val_expr->type, import, context);
 
     if (type_entry->id == TypeTableEntryIdInvalid) {
         return g->builtin_types.entry_invalid;
@@ -1655,7 +1673,7 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                     AsmOutput *asm_output = node->data.asm_expr.output_list.at(i);
                     if (asm_output->return_type) {
                         node->data.asm_expr.return_count += 1;
-                        return_type = resolve_type(g, asm_output->return_type);
+                        return_type = resolve_type(g, asm_output->return_type, import, context);
                         if (node->data.asm_expr.return_count > 1) {
                             add_node_error(g, node,
                                 buf_sprintf("inline assembly allows up to one output value"));
@@ -1848,6 +1866,7 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeStructDecl:
         case NodeTypeStructField:
         case NodeTypeStructValueField:
+        case NodeTypeCompilerFnCall:
             zig_unreachable();
     }
     assert(return_type);
@@ -1996,6 +2015,7 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
         case NodeTypeStructField:
         case NodeTypeStructValueExpr:
         case NodeTypeStructValueField:
+        case NodeTypeCompilerFnCall:
             zig_unreachable();
     }
 }
src/analyze.hpp
@@ -159,9 +159,11 @@ struct CodeGen {
     struct {
         TypeTableEntry *entry_bool;
         TypeTableEntry *entry_u8;
+        TypeTableEntry *entry_u16;
         TypeTableEntry *entry_u32;
         TypeTableEntry *entry_u64;
         TypeTableEntry *entry_i8;
+        TypeTableEntry *entry_i16;
         TypeTableEntry *entry_i32;
         TypeTableEntry *entry_i64;
         TypeTableEntry *entry_isize;
src/codegen.cpp
@@ -1326,6 +1326,7 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
         case NodeTypeStructDecl:
         case NodeTypeStructField:
         case NodeTypeStructValueField:
+        case NodeTypeCompilerFnCall:
             zig_unreachable();
     }
     zig_unreachable();
@@ -1685,6 +1686,19 @@ static void define_builtin_types(CodeGen *g) {
         g->type_table.put(&entry->name, entry);
         g->builtin_types.entry_u8 = entry;
     }
+    {
+        TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdInt);
+        entry->type_ref = LLVMInt16Type();
+        buf_init_from_str(&entry->name, "u16");
+        entry->size_in_bits = 16;
+        entry->align_in_bits = 16;
+        entry->data.integral.is_signed = false;
+        entry->di_type = LLVMZigCreateDebugBasicType(g->dbuilder, buf_ptr(&entry->name),
+                entry->size_in_bits, entry->align_in_bits,
+                LLVMZigEncoding_DW_ATE_unsigned());
+        g->type_table.put(&entry->name, entry);
+        g->builtin_types.entry_u16 = entry;
+    }
     {
         TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdInt);
         entry->type_ref = LLVMInt32Type();
@@ -1725,6 +1739,19 @@ static void define_builtin_types(CodeGen *g) {
         g->type_table.put(&entry->name, entry);
         g->builtin_types.entry_i8 = entry;
     }
+    {
+        TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdInt);
+        entry->type_ref = LLVMInt16Type();
+        buf_init_from_str(&entry->name, "i16");
+        entry->size_in_bits = 16;
+        entry->align_in_bits = 16;
+        entry->data.integral.is_signed = true;
+        entry->di_type = LLVMZigCreateDebugBasicType(g->dbuilder, buf_ptr(&entry->name),
+                entry->size_in_bits, entry->align_in_bits,
+                LLVMZigEncoding_DW_ATE_signed());
+        g->type_table.put(&entry->name, entry);
+        g->builtin_types.entry_i16 = entry;
+    }
     {
         TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdInt);
         entry->type_ref = LLVMInt32Type();
src/parser.cpp
@@ -142,6 +142,8 @@ const char *node_type_str(NodeType node_type) {
             return "StructValueExpr";
         case NodeTypeStructValueField:
             return "StructValueField";
+        case NodeTypeCompilerFnCall:
+            return "CompilerFnCall";
     }
     zig_unreachable();
 }
@@ -233,6 +235,12 @@ void ast_print(AstNode *node, int indent) {
                         ast_print(node->data.type.child_type, indent + 2);
                         break;
                     }
+                case AstNodeTypeTypeCompilerExpr:
+                    {
+                        fprintf(stderr, "CompilerExprType\n");
+                        ast_print(node->data.type.compiler_expr, indent + 2);
+                        break;
+                    }
             }
             break;
         case NodeTypeReturnExpr:
@@ -402,6 +410,9 @@ void ast_print(AstNode *node, int indent) {
             fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_val_field.name));
             ast_print(node->data.struct_val_field.expr, indent + 2);
             break;
+        case NodeTypeCompilerFnCall:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            break;
     }
 }
 
@@ -985,16 +996,48 @@ static void ast_parse_type_assume_amp(ParseContext *pc, int *token_index, AstNod
 }
 
 /*
-Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType
+CompileTimeFnCall : token(NumberSign) token(Symbol) token(LParen) Expression token(RParen)
+*/
+static AstNode *ast_parse_compiler_fn_call(ParseContext *pc, int *token_index, bool mandatory) {
+    Token *token = &pc->tokens->at(*token_index);
+
+    if (token->id == TokenIdNumberSign) {
+        *token_index += 1;
+    } else if (mandatory) {
+        ast_invalid_token_error(pc, token);
+    } else {
+        return nullptr;
+    }
+
+    Token *name_symbol = ast_eat_token(pc, token_index, TokenIdSymbol);
+    ast_eat_token(pc, token_index, TokenIdLParen);
+
+    AstNode *node = ast_create_node(pc, NodeTypeCompilerFnCall, token);
+    ast_buf_from_token(pc, name_symbol, &node->data.compiler_fn_call.name);
+    node->data.compiler_fn_call.expr = ast_parse_expression(pc, token_index, true);
+
+    ast_eat_token(pc, token_index, TokenIdRParen);
+    return node;
+}
+
+/*
+Type : token(Symbol) | token(Unreachable) | token(Void) | PointerType | ArrayType | MaybeType | CompileTimeFnCall
 PointerType : token(Ampersand) option(token(Const)) Type
 ArrayType : token(LBracket) Type token(Semicolon) token(Number) token(RBracket)
 */
 static AstNode *ast_parse_type(ParseContext *pc, int *token_index) {
     Token *token = &pc->tokens->at(*token_index);
-    *token_index += 1;
-
     AstNode *node = ast_create_node(pc, NodeTypeType, token);
 
+    AstNode *compiler_fn_call = ast_parse_compiler_fn_call(pc, token_index, false);
+    if (compiler_fn_call) {
+        node->data.type.type = AstNodeTypeTypeCompilerExpr;
+        node->data.type.compiler_expr = compiler_fn_call;
+        return node;
+    }
+
+    *token_index += 1;
+
     if (token->id == TokenIdKeywordUnreachable) {
         node->data.type.type = AstNodeTypeTypePrimitive;
         buf_init_from_str(&node->data.type.primitive_name, "unreachable");
src/parser.hpp
@@ -57,6 +57,7 @@ enum NodeType {
     NodeTypeStructField,
     NodeTypeStructValueExpr,
     NodeTypeStructValueField,
+    NodeTypeCompilerFnCall,
 };
 
 struct AstNodeRoot {
@@ -97,6 +98,7 @@ enum AstNodeTypeType {
     AstNodeTypeTypePointer,
     AstNodeTypeTypeArray,
     AstNodeTypeTypeMaybe,
+    AstNodeTypeTypeCompilerExpr,
 };
 
 struct AstNodeType {
@@ -105,6 +107,7 @@ struct AstNodeType {
     AstNode *child_type;
     AstNode *array_size;
     bool is_const;
+    AstNode *compiler_expr;
 };
 
 struct AstNodeBlock {
@@ -329,6 +332,11 @@ struct AstNodeStructValueExpr {
     ZigList<AstNode *> fields;
 };
 
+struct AstNodeCompilerFnCall {
+    Buf name;
+    AstNode *expr;
+};
+
 struct AstNode {
     enum NodeType type;
     int line;
@@ -368,6 +376,7 @@ struct AstNode {
         AstNodeNumberLiteral number_literal;
         AstNodeStructValueExpr struct_val_expr;
         AstNodeStructValueField struct_val_field;
+        AstNodeCompilerFnCall compiler_fn_call;
         Buf symbol;
         bool bool_literal;
     } data;
test/run_tests.cpp
@@ -347,7 +347,7 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
     add_simple_case("hello world without libc", R"SOURCE(
 use "std.zig";
 
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     print_str("Hello, world!\n");
     return 0;
 }
@@ -357,7 +357,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     add_simple_case("a + b + c", R"SOURCE(
 use "std.zig";
 
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     if (false || false || false) { print_str("BAD 1\n"); }
     if (true && true && false)   { print_str("BAD 2\n"); }
     if (1 | 2 | 4 != 7)          { print_str("BAD 3\n"); }
@@ -379,7 +379,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     add_simple_case("short circuit", R"SOURCE(
 use "std.zig";
 
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     if (true || { print_str("BAD 1\n"); false }) {
       print_str("OK 1\n");
     }
@@ -402,7 +402,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     add_simple_case("modify operators", R"SOURCE(
 use "std.zig";
 
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     var i : i32 = 0;
     i += 5;  if (i != 5)  { print_str("BAD +=\n"); }
     i -= 2;  if (i != 3)  { print_str("BAD -=\n"); }
@@ -554,7 +554,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     add_simple_case("structs", R"SOURCE(
 use "std.zig";
 
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     var foo : Foo;
     foo.a += 1;
     foo.b = foo.a == 1;
@@ -628,7 +628,7 @@ use "std.zig";
 const g1 : i32 = 1233 + 1;
 var g2 : i32;
 
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     if (g2 != 0) { print_str("BAD\n"); }
     g2 = g1;
     if (g2 != 1234) { print_str("BAD\n"); }
@@ -639,7 +639,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
 
     add_simple_case("while loop", R"SOURCE(
 use "std.zig";
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     var i : i32 = 0;
     while (i < 4) {
         print_str("loop\n");
@@ -651,7 +651,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
 
     add_simple_case("continue and break", R"SOURCE(
 use "std.zig";
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     var i : i32 = 0;
     while (true) {
         print_str("loop\n");
@@ -667,7 +667,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
 
     add_simple_case("maybe type", R"SOURCE(
 use "std.zig";
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     const x : ?bool = true;
 
     if (const y ?= x) {
@@ -685,7 +685,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
 
     add_simple_case("implicit cast after unreachable", R"SOURCE(
 use "std.zig";
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     const x = outer();
     if (x == 1234) {
         print_str("OK\n");
@@ -695,6 +695,17 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
 fn inner() -> i32 { 1234 }
 fn outer() -> isize {
     return inner();
+}
+    )SOURCE", "OK\n");
+
+    add_simple_case("#typeof()", R"SOURCE(
+use "std.zig";
+const x: u16 = 13;
+const z: #typeof(x) = 19;
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+    const y: #typeof(x) = 120;
+    print_str("OK\n");
+    return 0;
 }
     )SOURCE", "OK\n");
 }
@@ -999,6 +1010,11 @@ fn f() -> i32 {
     (return 1) as i32
 }
     )SOURCE", 1, ".tmp_source.zig:3:16: error: invalid cast from type 'unreachable' to 'i32'");
+
+    add_compile_fail_case("invalid compiler fn", R"SOURCE(
+fn f() -> #bogus(foo) {
+}
+    )SOURCE", 1, ".tmp_source.zig:2:11: error: invalid compiler function: 'bogus'");
 }
 
 static void print_compiler_invocation(TestCase *test_case) {