Commit b62e2fd870

Andrew Kelley <superjoe30@gmail.com>
2017-12-01 03:46:02
ability to specify tag type of enums
see #305
1 parent 5786df9
doc/langref.html.in
@@ -137,6 +137,7 @@
             <li><a href="#builtin-divTrunc">@divTrunc</a></li>
             <li><a href="#builtin-embedFile">@embedFile</a></li>
             <li><a href="#builtin-enumTagName">@enumTagName</a></li>
+            <li><a href="#builtin-EnumTagType">@EnumTagType</a></li>
             <li><a href="#builtin-errorName">@errorName</a></li>
             <li><a href="#builtin-fence">@fence</a></li>
             <li><a href="#builtin-fieldParentPtr">@fieldParentPtr</a></li>
@@ -4256,6 +4257,11 @@ test.zig:6:2: error: found compile log statement
       <p>
       Converts an enum tag name to a slice of bytes.
       </p>
+      <h3 id="builtin-EnumTagType">@EnumTagType</h3>
+      <pre><code class="zig">@EnumTagType(T: type) -&gt; type</code></pre>
+      <p>
+      Returns the integer type that is used to store the enumeration value.
+      </p>
       <h3 id="builtin-errorName">@errorName</h3>
       <pre><code class="zig">@errorName(err: error) -&gt; []u8</code></pre>
       <p>
@@ -5837,7 +5843,7 @@ GroupedExpression = "(" Expression ")"
 
 KeywordLiteral = "true" | "false" | "null" | "continue" | "undefined" | "error" | "this" | "unreachable"
 
-ContainerDecl = option("extern" | "packed") ("struct" | "enum" | "union") "{" many(ContainerMember) "}"</code></pre>
+ContainerDecl = option("extern" | "packed") ("struct" | "union" | ("enum" option(GroupedExpression))) "{" many(ContainerMember) "}"</code></pre>
       <h2 id="zen">Zen</h2>
       <ul>
         <li>Communicate intent precisely.</li>
src/all_types.hpp
@@ -1286,6 +1286,7 @@ enum BuiltinFnId {
     BuiltinFnIdIntToPtr,
     BuiltinFnIdPtrToInt,
     BuiltinFnIdEnumTagName,
+    BuiltinFnIdEnumTagType,
     BuiltinFnIdFieldParentPtr,
     BuiltinFnIdOffsetOf,
     BuiltinFnIdInlineCall,
@@ -1911,6 +1912,7 @@ enum IrInstructionId {
     IrInstructionIdDeclRef,
     IrInstructionIdPanic,
     IrInstructionIdEnumTagName,
+    IrInstructionIdEnumTagType,
     IrInstructionIdFieldParentPtr,
     IrInstructionIdOffsetOf,
     IrInstructionIdTypeId,
@@ -2695,6 +2697,12 @@ struct IrInstructionEnumTagName {
     IrInstruction *target;
 };
 
+struct IrInstructionEnumTagType {
+    IrInstruction base;
+
+    IrInstruction *target;
+};
+
 struct IrInstructionFieldParentPtr {
     IrInstruction base;
 
src/analyze.cpp
@@ -1267,7 +1267,7 @@ TypeTableEntry *create_enum_tag_type(CodeGen *g, TypeTableEntry *enum_type, Type
     TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdEnumTag);
 
     buf_resize(&entry->name, 0);
-    buf_appendf(&entry->name, "@enumTagType(%s)", buf_ptr(&enum_type->name));
+    buf_appendf(&entry->name, "@EnumTagType(%s)", buf_ptr(&enum_type->name));
 
     entry->is_copyable = true;
     entry->data.enum_tag.enum_type = enum_type;
@@ -1391,6 +1391,25 @@ static void resolve_enum_type(CodeGen *g, TypeTableEntry *enum_type) {
     }
 
     TypeTableEntry *tag_int_type = get_smallest_unsigned_int_type(g, field_count - 1);
+    if (decl_node->data.container_decl.init_arg_expr != nullptr) {
+        TypeTableEntry *wanted_tag_int_type = analyze_type_expr(g, scope, decl_node->data.container_decl.init_arg_expr);
+        if (type_is_invalid(wanted_tag_int_type)) {
+            enum_type->data.enumeration.is_invalid = true;
+        } else if (wanted_tag_int_type->id != TypeTableEntryIdInt) {
+            enum_type->data.enumeration.is_invalid = true;
+            add_node_error(g, decl_node->data.container_decl.init_arg_expr,
+                buf_sprintf("expected integer, found '%s'", buf_ptr(&wanted_tag_int_type->name)));
+        } else if (wanted_tag_int_type->data.integral.bit_count < tag_int_type->data.integral.bit_count) {
+            enum_type->data.enumeration.is_invalid = true;
+            add_node_error(g, decl_node->data.container_decl.init_arg_expr,
+                buf_sprintf("'%s' too small to hold all bits; must be at least '%s'",
+                    buf_ptr(&wanted_tag_int_type->name), buf_ptr(&tag_int_type->name)));
+        } else {
+            tag_int_type = wanted_tag_int_type;
+        }
+    }
+
+
     TypeTableEntry *tag_type_entry = create_enum_tag_type(g, enum_type, tag_int_type);
     enum_type->data.enumeration.tag_type = tag_type_entry;
 
src/ast_render.cpp
@@ -660,7 +660,13 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
             {
                 const char *layout_str = layout_string(node->data.container_decl.layout);
                 const char *container_str = container_string(node->data.container_decl.kind);
-                fprintf(ar->f, "%s%s {\n", layout_str, container_str);
+                fprintf(ar->f, "%s%s", layout_str, container_str);
+                if (node->data.container_decl.init_arg_expr != nullptr) {
+                    fprintf(ar->f, "(");
+                    render_node_grouped(ar, node->data.container_decl.init_arg_expr);
+                    fprintf(ar->f, ")");
+                }
+                fprintf(ar->f, " {\n");
                 ar->indent += ar->indent_size;
                 for (size_t field_i = 0; field_i < node->data.container_decl.fields.length; field_i += 1) {
                     AstNode *field_node = node->data.container_decl.fields.at(field_i);
src/codegen.cpp
@@ -3537,6 +3537,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdOpaqueType:
         case IrInstructionIdSetAlignStack:
         case IrInstructionIdArgType:
+        case IrInstructionIdEnumTagType:
             zig_unreachable();
         case IrInstructionIdReturn:
             return ir_render_return(g, executable, (IrInstructionReturn *)instruction);
@@ -5049,7 +5050,8 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdBitCast, "bitCast", 2);
     create_builtin_fn(g, BuiltinFnIdIntToPtr, "intToPtr", 2);
     create_builtin_fn(g, BuiltinFnIdPtrToInt, "ptrToInt", 1);
-    create_builtin_fn(g, BuiltinFnIdEnumTagName, "enumTagName", 1);
+    create_builtin_fn(g, BuiltinFnIdEnumTagName, "enumTagName", 1); // TODO rename to memberName
+    create_builtin_fn(g, BuiltinFnIdEnumTagType, "EnumTagType", 1);
     create_builtin_fn(g, BuiltinFnIdFieldParentPtr, "fieldParentPtr", 3);
     create_builtin_fn(g, BuiltinFnIdOffsetOf, "offsetOf", 2);
     create_builtin_fn(g, BuiltinFnIdDivExact, "divExact", 2);
src/ir.cpp
@@ -551,6 +551,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionEnumTagName *) {
     return IrInstructionIdEnumTagName;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionEnumTagType *) {
+    return IrInstructionIdEnumTagType;
+}
+
 static constexpr IrInstructionId ir_instruction_id(IrInstructionFieldParentPtr *) {
     return IrInstructionIdFieldParentPtr;
 }
@@ -2270,6 +2274,17 @@ static IrInstruction *ir_build_enum_tag_name(IrBuilder *irb, Scope *scope, AstNo
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_enum_tag_type(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstruction *target)
+{
+    IrInstructionEnumTagType *instruction = ir_build_instruction<IrInstructionEnumTagType>(irb, scope, source_node);
+    instruction->target = target;
+
+    ir_ref_instruction(target, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static IrInstruction *ir_build_field_parent_ptr(IrBuilder *irb, Scope *scope, AstNode *source_node,
         IrInstruction *type_value, IrInstruction *field_name, IrInstruction *field_ptr, TypeStructField *field)
 {
@@ -3066,6 +3081,13 @@ static IrInstruction *ir_instruction_enumtagname_get_dep(IrInstructionEnumTagNam
     }
 }
 
+static IrInstruction *ir_instruction_enumtagtype_get_dep(IrInstructionEnumTagType *instruction, size_t index) {
+    switch (index) {
+        case 0: return instruction->target;
+        default: return nullptr;
+    }
+}
+
 static IrInstruction *ir_instruction_fieldparentptr_get_dep(IrInstructionFieldParentPtr *instruction, size_t index) {
     switch (index) {
         case 0: return instruction->type_value;
@@ -3326,6 +3348,8 @@ static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t
             return ir_instruction_panic_get_dep((IrInstructionPanic *) instruction, index);
         case IrInstructionIdEnumTagName:
             return ir_instruction_enumtagname_get_dep((IrInstructionEnumTagName *) instruction, index);
+        case IrInstructionIdEnumTagType:
+            return ir_instruction_enumtagtype_get_dep((IrInstructionEnumTagType *) instruction, index);
         case IrInstructionIdFieldParentPtr:
             return ir_instruction_fieldparentptr_get_dep((IrInstructionFieldParentPtr *) instruction, index);
         case IrInstructionIdOffsetOf:
@@ -4681,6 +4705,15 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
                 IrInstruction *actual_tag = ir_build_enum_tag(irb, scope, node, arg0_value);
                 return ir_build_enum_tag_name(irb, scope, node, actual_tag);
             }
+        case BuiltinFnIdEnumTagType:
+            {
+                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+                IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
+                if (arg0_value == irb->codegen->invalid_instruction)
+                    return arg0_value;
+
+                return ir_build_enum_tag_type(irb, scope, node, arg0_value);
+            }
         case BuiltinFnIdFieldParentPtr:
             {
                 AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -15831,6 +15864,27 @@ static TypeTableEntry *ir_analyze_instruction_arg_type(IrAnalyze *ira, IrInstruc
     return ira->codegen->builtin_types.entry_type;
 }
 
+static TypeTableEntry *ir_analyze_instruction_enum_tag_type(IrAnalyze *ira, IrInstructionEnumTagType *instruction) {
+    IrInstruction *target_inst = instruction->target->other;
+    TypeTableEntry *enum_type = ir_resolve_type(ira, target_inst);
+    if (type_is_invalid(enum_type))
+        return ira->codegen->builtin_types.entry_invalid;
+    if (enum_type->id != TypeTableEntryIdEnum) {
+        ir_add_error(ira, target_inst, buf_sprintf("expected enum, found '%s'", buf_ptr(&enum_type->name)));
+        return ira->codegen->builtin_types.entry_invalid;
+    }
+    ensure_complete_type(ira->codegen, enum_type);
+    if (type_is_invalid(enum_type))
+        return ira->codegen->builtin_types.entry_invalid;
+
+    TypeTableEntry *non_int_tag_type = enum_type->data.enumeration.tag_type;
+    assert(non_int_tag_type->id == TypeTableEntryIdEnumTag);
+
+    ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
+    out_val->data.x_type = non_int_tag_type->data.enum_tag.int_type;
+    return ira->codegen->builtin_types.entry_type;
+}
+
 static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -16029,6 +16083,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
             return ir_analyze_instruction_set_align_stack(ira, (IrInstructionSetAlignStack *)instruction);
         case IrInstructionIdArgType:
             return ir_analyze_instruction_arg_type(ira, (IrInstructionArgType *)instruction);
+        case IrInstructionIdEnumTagType:
+            return ir_analyze_instruction_enum_tag_type(ira, (IrInstructionEnumTagType *)instruction);
     }
     zig_unreachable();
 }
@@ -16214,6 +16270,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdAlignCast:
         case IrInstructionIdOpaqueType:
         case IrInstructionIdArgType:
+        case IrInstructionIdEnumTagType:
             return false;
         case IrInstructionIdAsm:
             {
src/ir_print.cpp
@@ -994,6 +994,12 @@ static void ir_print_arg_type(IrPrint *irp, IrInstructionArgType *instruction) {
     fprintf(irp->f, ")");
 }
 
+static void ir_print_enum_tag_type(IrPrint *irp, IrInstructionEnumTagType *instruction) {
+    fprintf(irp->f, "@EnumTagType(");
+    ir_print_other_instruction(irp, instruction->target);
+    fprintf(irp->f, ")");
+}
+
 
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
@@ -1312,6 +1318,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdArgType:
             ir_print_arg_type(irp, (IrInstructionArgType *)instruction);
             break;
+        case IrInstructionIdEnumTagType:
+            ir_print_enum_tag_type(irp, (IrInstructionEnumTagType *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
src/parser.cpp
@@ -2377,7 +2377,7 @@ static AstNode *ast_parse_use(ParseContext *pc, size_t *token_index, VisibMod vi
 }
 
 /*
-ContainerDecl = option("extern" | "packed") ("struct" | "enum" | "union") "{" many(ContainerMember) "}"
+ContainerDecl = option("extern" | "packed") ("struct" | "union" | ("enum" option(GroupedExpression))) "{" many(ContainerMember) "}"
 ContainerMember = (ContainerField | FnDef | GlobalVarDecl)
 ContainerField = Symbol option(":" Expression) ","
 */
@@ -2415,6 +2415,10 @@ static AstNode *ast_parse_container_decl(ParseContext *pc, size_t *token_index,
     node->data.container_decl.layout = layout;
     node->data.container_decl.kind = kind;
 
+    if (kind == ContainerKindEnum || kind == ContainerKindStruct) {
+        node->data.container_decl.init_arg_expr = ast_parse_grouped_expr(pc, token_index, false);
+    }
+
     ast_eat_token(pc, token_index, TokenIdLBrace);
 
     for (;;) {
@@ -2804,6 +2808,7 @@ void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *cont
         case NodeTypeContainerDecl:
             visit_node_list(&node->data.container_decl.fields, visit, context);
             visit_node_list(&node->data.container_decl.decls, visit, context);
+            visit_field(&node->data.container_decl.init_arg_expr, visit, context);
             break;
         case NodeTypeStructField:
             visit_field(&node->data.struct_field.type, visit, context);
src/translate_c.cpp
@@ -651,6 +651,14 @@ static bool c_is_unsigned_integer(Context *c, QualType qt) {
     }
 }
 
+static bool c_is_builtin_type(Context *c, QualType qt, BuiltinType::Kind kind) {
+    const Type *c_type = qual_type_canon(qt);
+    if (c_type->getTypeClass() != Type::Builtin)
+        return false;
+    const BuiltinType *builtin_ty = static_cast<const BuiltinType*>(c_type);
+    return builtin_ty->getKind() == kind;
+}
+
 static bool c_is_float(Context *c, QualType qt) {
     const Type *c_type = qt.getTypePtr();
     if (c_type->getTypeClass() != Type::Builtin)
@@ -3426,7 +3434,9 @@ static AstNode *resolve_enum_decl(Context *c, const EnumDecl *enum_decl) {
         AstNode *enum_node = trans_create_node(c, NodeTypeContainerDecl);
         enum_node->data.container_decl.kind = ContainerKindEnum;
         enum_node->data.container_decl.layout = ContainerLayoutExtern;
-        enum_node->data.container_decl.init_arg_expr = tag_int_type;
+        if (!c_is_builtin_type(c, enum_decl->getIntegerType(), BuiltinType::UInt)) {
+            enum_node->data.container_decl.init_arg_expr = tag_int_type;
+        }
 
         enum_node->data.container_decl.fields.resize(field_count);
         uint32_t i = 0;
test/cases/enum.zig
@@ -190,3 +190,27 @@ test "enum sizes" {
         assert(@sizeOf(ValueCount257) == 2);
     }
 }
+
+const Small2 = enum (u2) {
+    One,
+    Two,
+};
+const Small = enum (u2) {
+    One,
+    Two,
+    Three,
+    Four,
+};
+
+test "set enum tag type" {
+    {
+        var x = Small.One;
+        x = Small.Two;
+        comptime assert(@EnumTagType(Small) == u2);
+    }
+    {
+        var x = Small2.One;
+        x = Small2.Two;
+        comptime assert(@EnumTagType(Small2) == u2);
+    }
+}
test/compile_errors.zig
@@ -2362,4 +2362,32 @@ pub fn addCases(cases: &tests.CompileErrorContext) {
         ".tmp_source.zig:4:25: error: aoeu",
         ".tmp_source.zig:1:36: note: called from here",
         ".tmp_source.zig:12:20: note: referenced here");
+
+    cases.add("specify enum tag type that is too small",
+        \\const Small = enum (u2) {
+        \\    One,
+        \\    Two,
+        \\    Three,
+        \\    Four,
+        \\    Five,
+        \\};
+        \\
+        \\export fn entry() {
+        \\    var x = Small.One;
+        \\}
+    ,
+        ".tmp_source.zig:1:20: error: 'u2' too small to hold all bits; must be at least 'u3'");
+
+    cases.add("specify non-integer enum tag type",
+        \\const Small = enum (f32) {
+        \\    One,
+        \\    Two,
+        \\    Three,
+        \\};
+        \\
+        \\export fn entry() {
+        \\    var x = Small.One;
+        \\}
+    ,
+        ".tmp_source.zig:1:20: error: expected integer, found 'f32'");
 }