Commit 19037014e5

Andrew Kelley <superjoe30@gmail.com>
2016-11-19 05:52:42
IR: more maybe type support
1 parent 31565ef
src/ast_render.cpp
@@ -342,13 +342,23 @@ static void print_symbol(AstRender *ar, Buf *symbol) {
     fprintf(ar->f, "@\"%s\"", buf_ptr(&escaped));
 }
 
-static void render_node(AstRender *ar, AstNode *node) {
+static void render_node_extra(AstRender *ar, AstNode *node, bool grouped);
+
+static void render_node_grouped(AstRender *ar, AstNode *node) {
+    return render_node_extra(ar, node, true);
+}
+
+static void render_node_ungrouped(AstRender *ar, AstNode *node) {
+    return render_node_extra(ar, node, false);
+}
+
+static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
     switch (node->type) {
         case NodeTypeRoot:
             for (size_t i = 0; i < node->data.root.top_level_decls.length; i += 1) {
                 AstNode *child = node->data.root.top_level_decls.at(i);
                 print_indent(ar);
-                render_node(ar, child);
+                render_node_grouped(ar, child);
 
                 if (child->type == NodeTypeUse ||
                     child->type == NodeTypeVariableDeclaration ||
@@ -381,7 +391,7 @@ static void render_node(AstRender *ar, AstNode *node) {
                         print_symbol(ar, param_decl->data.param_decl.name);
                         fprintf(ar->f, ": ");
                     }
-                    render_node(ar, param_decl->data.param_decl.type);
+                    render_node_grouped(ar, param_decl->data.param_decl.type);
 
                     if (arg_i + 1 < arg_count || is_var_args) {
                         fprintf(ar->f, ", ");
@@ -394,14 +404,14 @@ static void render_node(AstRender *ar, AstNode *node) {
 
                 AstNode *return_type_node = node->data.fn_proto.return_type;
                 fprintf(ar->f, " -> ");
-                render_node(ar, return_type_node);
+                render_node_grouped(ar, return_type_node);
                 break;
             }
         case NodeTypeFnDef:
             {
-                render_node(ar, node->data.fn_def.fn_proto);
+                render_node_grouped(ar, node->data.fn_def.fn_proto);
                 fprintf(ar->f, " ");
-                render_node(ar, node->data.fn_def.body);
+                render_node_grouped(ar, node->data.fn_def.body);
                 break;
             }
         case NodeTypeBlock:
@@ -414,7 +424,7 @@ static void render_node(AstRender *ar, AstNode *node) {
             for (size_t i = 0; i < node->data.block.statements.length; i += 1) {
                 AstNode *statement = node->data.block.statements.at(i);
                 print_indent(ar);
-                render_node(ar, statement);
+                render_node_grouped(ar, statement);
                 if (i != node->data.block.statements.length - 1)
                     fprintf(ar->f, ";");
                 fprintf(ar->f, "\n");
@@ -427,14 +437,14 @@ static void render_node(AstRender *ar, AstNode *node) {
             {
                 const char *return_str = return_string(node->data.return_expr.kind);
                 fprintf(ar->f, "%s ", return_str);
-                render_node(ar, node->data.return_expr.expr);
+                render_node_grouped(ar, node->data.return_expr.expr);
                 break;
             }
         case NodeTypeDefer:
             {
                 const char *defer_str = defer_string(node->data.defer.kind);
                 fprintf(ar->f, "%s ", defer_str);
-                render_node(ar, node->data.return_expr.expr);
+                render_node_grouped(ar, node->data.return_expr.expr);
                 break;
             }
         case NodeTypeVariableDeclaration:
@@ -447,11 +457,11 @@ static void render_node(AstRender *ar, AstNode *node) {
 
                 if (node->data.variable_declaration.type) {
                     fprintf(ar->f, ": ");
-                    render_node(ar, node->data.variable_declaration.type);
+                    render_node_grouped(ar, node->data.variable_declaration.type);
                 }
                 if (node->data.variable_declaration.expr) {
                     fprintf(ar->f, " = ");
-                    render_node(ar, node->data.variable_declaration.expr);
+                    render_node_grouped(ar, node->data.variable_declaration.expr);
                 }
                 break;
             }
@@ -460,15 +470,15 @@ static void render_node(AstRender *ar, AstNode *node) {
                 const char *pub_str = visib_mod_string(node->data.type_decl.top_level_decl.visib_mod);
                 const char *var_name = buf_ptr(node->data.type_decl.symbol);
                 fprintf(ar->f, "%stype %s = ", pub_str, var_name);
-                render_node(ar, node->data.type_decl.child_type);
+                render_node_grouped(ar, node->data.type_decl.child_type);
                 break;
             }
         case NodeTypeBinOpExpr:
-            fprintf(ar->f, "(");
-            render_node(ar, node->data.bin_op_expr.op1);
+            if (!grouped) fprintf(ar->f, "(");
+            render_node_ungrouped(ar, node->data.bin_op_expr.op1);
             fprintf(ar->f, " %s ", bin_op_str(node->data.bin_op_expr.bin_op));
-            render_node(ar, node->data.bin_op_expr.op2);
-            fprintf(ar->f, ")");
+            render_node_ungrouped(ar, node->data.bin_op_expr.op2);
+            if (!grouped) fprintf(ar->f, ")");
             break;
         case NodeTypeNumberLiteral:
             switch (node->data.number_literal.bignum->kind) {
@@ -511,7 +521,7 @@ static void render_node(AstRender *ar, AstNode *node) {
                 PrefixOp op = node->data.prefix_op_expr.prefix_op;
                 fprintf(ar->f, "%s", prefix_op_str(op));
 
-                render_node(ar, node->data.prefix_op_expr.primary_expr);
+                render_node_ungrouped(ar, node->data.prefix_op_expr.primary_expr);
                 break;
             }
         case NodeTypeFnCallExpr:
@@ -520,7 +530,7 @@ static void render_node(AstRender *ar, AstNode *node) {
             } else {
                 fprintf(ar->f, "(");
             }
-            render_node(ar, node->data.fn_call_expr.fn_ref_expr);
+            render_node_ungrouped(ar, node->data.fn_call_expr.fn_ref_expr);
             if (!node->data.fn_call_expr.is_builtin) {
                 fprintf(ar->f, ")");
             }
@@ -530,21 +540,21 @@ static void render_node(AstRender *ar, AstNode *node) {
                 if (i != 0) {
                     fprintf(ar->f, ", ");
                 }
-                render_node(ar, param);
+                render_node_grouped(ar, param);
             }
             fprintf(ar->f, ")");
             break;
         case NodeTypeArrayAccessExpr:
-            render_node(ar, node->data.array_access_expr.array_ref_expr);
+            render_node_ungrouped(ar, node->data.array_access_expr.array_ref_expr);
             fprintf(ar->f, "[");
-            render_node(ar, node->data.array_access_expr.subscript);
+            render_node_grouped(ar, node->data.array_access_expr.subscript);
             fprintf(ar->f, "]");
             break;
         case NodeTypeFieldAccessExpr:
             {
                 AstNode *lhs = node->data.field_access_expr.struct_expr;
                 Buf *rhs = node->data.field_access_expr.field_name;
-                render_node(ar, lhs);
+                render_node_ungrouped(ar, lhs);
                 fprintf(ar->f, ".");
                 print_symbol(ar, rhs);
                 break;
@@ -565,7 +575,7 @@ static void render_node(AstRender *ar, AstNode *node) {
                     print_indent(ar);
                     print_symbol(ar, field_node->data.struct_field.name);
                     fprintf(ar->f, ": ");
-                    render_node(ar, field_node->data.struct_field.type);
+                    render_node_grouped(ar, field_node->data.struct_field.type);
                     fprintf(ar->f, ",\n");
                 }
 
@@ -574,9 +584,8 @@ static void render_node(AstRender *ar, AstNode *node) {
                 break;
             }
         case NodeTypeContainerInitExpr:
-            fprintf(ar->f, "(");
-            render_node(ar, node->data.container_init_expr.type);
-            fprintf(ar->f, "){");
+            render_node_ungrouped(ar, node->data.container_init_expr.type);
+            fprintf(ar->f, "{");
             assert(node->data.container_init_expr.entries.length == 0);
             fprintf(ar->f, "}");
             break;
@@ -584,13 +593,13 @@ static void render_node(AstRender *ar, AstNode *node) {
             {
                 fprintf(ar->f, "[");
                 if (node->data.array_type.size) {
-                    render_node(ar, node->data.array_type.size);
+                    render_node_grouped(ar, node->data.array_type.size);
                 }
                 fprintf(ar->f, "]");
                 if (node->data.array_type.is_const) {
                     fprintf(ar->f, "const ");
                 }
-                render_node(ar, node->data.array_type.child_type);
+                render_node_ungrouped(ar, node->data.array_type.child_type);
                 break;
             }
         case NodeTypeErrorType:
@@ -622,7 +631,7 @@ static void render_node(AstRender *ar, AstNode *node) {
                             buf_ptr(asm_output->constraint));
                     if (asm_output->return_type) {
                         fprintf(ar->f, "-> ");
-                        render_node(ar, asm_output->return_type);
+                        render_node_grouped(ar, asm_output->return_type);
                     } else {
                         fprintf(ar->f, "%s", buf_ptr(asm_output->variable_name));
                     }
@@ -642,7 +651,7 @@ static void render_node(AstRender *ar, AstNode *node) {
                     fprintf(ar->f, "[%s] \"%s\" (",
                             buf_ptr(asm_input->asm_symbolic_name),
                             buf_ptr(asm_input->constraint));
-                    render_node(ar, asm_input->expr);
+                    render_node_grouped(ar, asm_input->expr);
                     fprintf(ar->f, ")");
                 }
                 fprintf(ar->f, "\n");
@@ -660,13 +669,13 @@ static void render_node(AstRender *ar, AstNode *node) {
             {
                 const char *inline_str = node->data.while_expr.is_inline ? "inline " : "";
                 fprintf(ar->f, "%swhile (", inline_str);
-                render_node(ar, node->data.while_expr.condition);
+                render_node_grouped(ar, node->data.while_expr.condition);
                 if (node->data.while_expr.continue_expr) {
                     fprintf(ar->f, "; ");
-                    render_node(ar, node->data.while_expr.continue_expr);
+                    render_node_grouped(ar, node->data.while_expr.continue_expr);
                 }
                 fprintf(ar->f, ") ");
-                render_node(ar, node->data.while_expr.body);
+                render_node_grouped(ar, node->data.while_expr.body);
                 break;
             }
         case NodeTypeThisLiteral:
@@ -680,6 +689,18 @@ static void render_node(AstRender *ar, AstNode *node) {
                 fprintf(ar->f, "%s", bool_str);
                 break;
             }
+        case NodeTypeIfBoolExpr:
+            {
+                fprintf(ar->f, "if (");
+                render_node_grouped(ar, node->data.if_bool_expr.condition);
+                fprintf(ar->f, ") ");
+                render_node_grouped(ar, node->data.if_bool_expr.then_block);
+                if (node->data.if_bool_expr.else_node) {
+                    fprintf(ar->f, "else ");
+                    render_node_grouped(ar, node->data.if_bool_expr.else_node);
+                }
+                break;
+            }
         case NodeTypeFnDecl:
         case NodeTypeParamDecl:
         case NodeTypeErrorValueDecl:
@@ -690,7 +711,6 @@ static void render_node(AstRender *ar, AstNode *node) {
         case NodeTypeUse:
         case NodeTypeNullLiteral:
         case NodeTypeZeroesLiteral:
-        case NodeTypeIfBoolExpr:
         case NodeTypeIfVarExpr:
         case NodeTypeForExpr:
         case NodeTypeSwitchExpr:
@@ -711,5 +731,5 @@ void ast_render(FILE *f, AstNode *node, int indent_size) {
     ar.indent_size = indent_size;
     ar.indent = 0;
 
-    render_node(&ar, node);
+    render_node_grouped(&ar, node);
 }
src/codegen.cpp
@@ -3242,9 +3242,19 @@ static void get_c_type(CodeGen *g, TypeTableEntry *type_entry, Buf *out_buf) {
                 buf_appendf(out_buf, "%s%s *", const_str, buf_ptr(&child_buf));
                 break;
             }
+        case TypeTableEntryIdMaybe:
+            {
+                TypeTableEntry *child_type = type_entry->data.maybe.child_type;
+                if (child_type->id == TypeTableEntryIdPointer ||
+                    child_type->id == TypeTableEntryIdFn)
+                {
+                    return get_c_type(g, child_type, out_buf);
+                } else {
+                    zig_unreachable();
+                }
+            }
         case TypeTableEntryIdArray:
         case TypeTableEntryIdStruct:
-        case TypeTableEntryIdMaybe:
         case TypeTableEntryIdErrorUnion:
         case TypeTableEntryIdPureError:
         case TypeTableEntryIdEnum:
src/ir.cpp
@@ -3113,8 +3113,11 @@ static TypeTableEntry *ir_analyze_instruction_decl_var(IrAnalyze *ira, IrInstruc
     var->type = result_type;
     assert(var->type);
 
-    ConstExprValue *mem_slot = &ira->exec_context.mem_slot_list[var->mem_slot_index];
-    *mem_slot = casted_init_value->static_value;
+    if (var->mem_slot_index != SIZE_MAX) {
+        assert(var->mem_slot_index < ira->exec_context.mem_slot_count);
+        ConstExprValue *mem_slot = &ira->exec_context.mem_slot_list[var->mem_slot_index];
+        *mem_slot = casted_init_value->static_value;
+    }
 
     ir_build_var_decl_from(&ira->new_irb, &decl_var_instruction->base, var, var_type, casted_init_value);
 
@@ -3341,6 +3344,70 @@ static TypeTableEntry *ir_analyze_dereference(IrAnalyze *ira, IrInstructionUnOp
     return child_type;
 }
 
+static TypeTableEntry *ir_analyze_maybe(IrAnalyze *ira, IrInstructionUnOp *un_op_instruction) {
+    IrInstruction *value = un_op_instruction->value->other;
+    TypeTableEntry *type_entry = ir_resolve_type(ira, value);
+    TypeTableEntry *canon_type = get_underlying_type(type_entry);
+    switch (canon_type->id) {
+        case TypeTableEntryIdInvalid:
+            return ira->codegen->builtin_types.entry_invalid;
+        case TypeTableEntryIdVar:
+        case TypeTableEntryIdTypeDecl:
+            zig_unreachable();
+        case TypeTableEntryIdMetaType:
+        case TypeTableEntryIdVoid:
+        case TypeTableEntryIdBool:
+        case TypeTableEntryIdInt:
+        case TypeTableEntryIdFloat:
+        case TypeTableEntryIdPointer:
+        case TypeTableEntryIdArray:
+        case TypeTableEntryIdStruct:
+        case TypeTableEntryIdNumLitFloat:
+        case TypeTableEntryIdNumLitInt:
+        case TypeTableEntryIdUndefLit:
+        case TypeTableEntryIdNullLit:
+        case TypeTableEntryIdMaybe:
+        case TypeTableEntryIdErrorUnion:
+        case TypeTableEntryIdPureError:
+        case TypeTableEntryIdEnum:
+        case TypeTableEntryIdUnion:
+        case TypeTableEntryIdFn:
+        case TypeTableEntryIdNamespace:
+        case TypeTableEntryIdBlock:
+        case TypeTableEntryIdGenericFn:
+            {
+                ConstExprValue *out_val = ir_build_const_from(ira, &un_op_instruction->base,
+                        value->static_value.depends_on_compile_var);
+                out_val->data.x_type = get_maybe_type(ira->codegen, type_entry);
+                return ira->codegen->builtin_types.entry_type;
+            }
+        case TypeTableEntryIdUnreachable:
+            add_node_error(ira->codegen, un_op_instruction->base.source_node,
+                    buf_sprintf("type '%s' not nullable", buf_ptr(&type_entry->name)));
+            // TODO if it's a type decl, put an error note here pointing to the decl
+            return ira->codegen->builtin_types.entry_invalid;
+    }
+    zig_unreachable();
+}
+
+static TypeTableEntry *ir_analyze_unwrap_maybe(IrAnalyze *ira, IrInstructionUnOp *un_op_instruction) {
+    IrInstruction *value = un_op_instruction->value->other;
+    TypeTableEntry *type_entry = value->type_entry;
+    if (type_entry->id == TypeTableEntryIdInvalid) {
+        return type_entry;
+    } else if (type_entry->id == TypeTableEntryIdMaybe) {
+        if (value->static_value.special != ConstValSpecialRuntime) {
+            zig_panic("TODO compile time eval unwrap maybe");
+        }
+        ir_build_un_op_from(&ira->new_irb, &un_op_instruction->base, IrUnOpUnwrapMaybe, value);
+        return type_entry->data.maybe.child_type;
+    } else {
+        add_node_error(ira->codegen, un_op_instruction->base.source_node,
+            buf_sprintf("expected maybe type, found '%s'", buf_ptr(&type_entry->name)));
+        return ira->codegen->builtin_types.entry_invalid;
+    }
+}
+
 static TypeTableEntry *ir_analyze_instruction_un_op(IrAnalyze *ira, IrInstructionUnOp *un_op_instruction) {
     IrUnOp op_id = un_op_instruction->op_id;
     switch (op_id) {
@@ -3417,34 +3484,7 @@ static TypeTableEntry *ir_analyze_instruction_un_op(IrAnalyze *ira, IrInstructio
         case IrUnOpDereference:
             return ir_analyze_dereference(ira, un_op_instruction);
         case IrUnOpMaybe:
-            zig_panic("TODO analyze PrefixOpMaybe");
-            //{
-            //    TypeTableEntry *type_entry = analyze_expression(g, import, context, nullptr, *expr_node);
-
-            //    if (type_entry->id == TypeTableEntryIdInvalid) {
-            //        return type_entry;
-            //    } else if (type_entry->id == TypeTableEntryIdMetaType) {
-            //        TypeTableEntry *meta_type = resolve_type(g, *expr_node);
-            //        if (meta_type->id == TypeTableEntryIdInvalid) {
-            //            return g->builtin_types.entry_invalid;
-            //        } else if (meta_type->id == TypeTableEntryIdUnreachable) {
-            //            add_node_error(g, node, buf_create_from_str("unable to wrap unreachable in maybe type"));
-            //            return g->builtin_types.entry_invalid;
-            //        } else {
-            //            return resolve_expr_const_val_as_type(g, node, get_maybe_type(g, meta_type), false);
-            //        }
-            //    } else if (type_entry->id == TypeTableEntryIdUnreachable) {
-            //        add_node_error(g, *expr_node, buf_sprintf("unable to wrap unreachable in maybe type"));
-            //        return g->builtin_types.entry_invalid;
-            //    } else {
-            //        ConstExprValue *target_const_val = &get_resolved_expr(*expr_node)->const_val;
-            //        TypeTableEntry *maybe_type = get_maybe_type(g, type_entry);
-            //        if (!target_const_val->ok) {
-            //            return maybe_type;
-            //        }
-            //        return resolve_expr_const_val_as_non_null(g, node, maybe_type, target_const_val);
-            //    }
-            //}
+            return ir_analyze_maybe(ira, un_op_instruction);
         case IrUnOpError:
             return ir_analyze_unary_prefix_op_err(ira, un_op_instruction);
         case IrUnOpUnwrapError:
@@ -3463,20 +3503,7 @@ static TypeTableEntry *ir_analyze_instruction_un_op(IrAnalyze *ira, IrInstructio
             //    }
             //}
         case IrUnOpUnwrapMaybe:
-            zig_panic("TODO analyze PrefixOpUnwrapMaybe");
-            //{
-            //    TypeTableEntry *type_entry = analyze_expression(g, import, context, nullptr, *expr_node);
-
-            //    if (type_entry->id == TypeTableEntryIdInvalid) {
-            //        return type_entry;
-            //    } else if (type_entry->id == TypeTableEntryIdMaybe) {
-            //        return type_entry->data.maybe.child_type;
-            //    } else {
-            //        add_node_error(g, *expr_node,
-            //            buf_sprintf("expected maybe type, got '%s'", buf_ptr(&type_entry->name)));
-            //        return g->builtin_types.entry_invalid;
-            //    }
-            //}
+            return ir_analyze_unwrap_maybe(ira, un_op_instruction);
         case IrUnOpErrorReturn:
             zig_panic("TODO analyze IrUnOpErrorReturn");
         case IrUnOpMaybeReturn:
std/builtin.zig
@@ -1,24 +1,31 @@
 // These functions are provided when not linking against libc because LLVM
 // sometimes generates code that calls them.
 
-// TODO dest should be nullable and return value should be nullable
-export fn memset(dest: &u8, c: u8, n: usize) -> &u8 {
+export fn memset(dest: ?&u8, c: u8, n: usize) -> ?&u8 {
     @setDebugSafety(this, false);
 
+    if (n == 0)
+        return dest;
+
+    const d = ??dest;
     var index: usize = 0;
     while (index != n; index += 1)
-        dest[index] = c;
+        d[index] = c;
 
     return dest;
 }
 
-// TODO dest, source, and return value should be nullable
-export fn memcpy(noalias dest: &u8, noalias src: &const u8, n: usize) -> &u8 {
+export fn memcpy(noalias dest: ?&u8, noalias src: ?&const u8, n: usize) -> ?&u8 {
     @setDebugSafety(this, false);
 
+    if (n == 0)
+        return dest;
+
+    const d = ??dest;
+    const s = ??src;
     var index: usize = 0;
     while (index != n; index += 1)
-        dest[index] = src[index];
+        d[index] = s[index];
 
     return dest;
 }