Commit fea8659b82

Andrew Kelley <andrew@ziglang.org>
2021-01-02 03:24:02
stage2: comptime function calls
* Function calls that happen in a comptime scope get called at compile-time. We do this by putting the parameters in place as constant values and then running regular function analysis on the body.
* Added `Scope.Block.dump()` for debugging purposes.
* Fixed some code to call `identifierTokenString` rather than `tokenSlice`, making it work for `@""` syntax.
* Implemented `Value.copy` for big integers.

Follow-up issues to tackle:
* Adding compile errors to the callsite instead of the callee Decl.
* Proper error notes for "called from here".
  - Related: #7555
* Branch quotas.
* ZIR support?
1 parent fb37c1b
src/astgen.zig
@@ -384,7 +384,7 @@ fn breakExpr(mod: *Module, parent_scope: *Scope, node: *ast.Node.ControlFlowExpr
             .local_val => scope = scope.cast(Scope.LocalVal).?.parent,
             .local_ptr => scope = scope.cast(Scope.LocalPtr).?.parent,
             else => if (node.getLabel()) |break_label| {
-                const label_name = try identifierTokenString(mod, parent_scope, break_label);
+                const label_name = try mod.identifierTokenString(parent_scope, break_label);
                 return mod.failTok(parent_scope, break_label, "label not found: '{s}'", .{label_name});
             } else {
                 return mod.failTok(parent_scope, src, "break expression outside loop", .{});
@@ -426,7 +426,7 @@ fn continueExpr(mod: *Module, parent_scope: *Scope, node: *ast.Node.ControlFlowE
             .local_val => scope = scope.cast(Scope.LocalVal).?.parent,
             .local_ptr => scope = scope.cast(Scope.LocalPtr).?.parent,
             else => if (node.getLabel()) |break_label| {
-                const label_name = try identifierTokenString(mod, parent_scope, break_label);
+                const label_name = try mod.identifierTokenString(parent_scope, break_label);
                 return mod.failTok(parent_scope, break_label, "label not found: '{s}'", .{label_name});
             } else {
                 return mod.failTok(parent_scope, src, "continue expression outside loop", .{});
@@ -551,7 +551,7 @@ fn varDecl(
     }
     const tree = scope.tree();
     const name_src = tree.token_locs[node.name_token].start;
-    const ident_name = try identifierTokenString(mod, scope, node.name_token);
+    const ident_name = try mod.identifierTokenString(scope, node.name_token);
 
     // Local variables shadowing detection, including function parameters.
     {
@@ -843,7 +843,7 @@ fn typeInixOp(mod: *Module, scope: *Scope, node: *ast.Node.SimpleInfixOp, op_ins
 fn enumLiteral(mod: *Module, scope: *Scope, node: *ast.Node.EnumLiteral) !*zir.Inst {
     const tree = scope.tree();
     const src = tree.token_locs[node.name].start;
-    const name = try identifierTokenString(mod, scope, node.name);
+    const name = try mod.identifierTokenString(scope, node.name);
 
     return addZIRInst(mod, scope, src, zir.Inst.EnumLiteral, .{ .name = name }, .{});
 }
@@ -864,7 +864,7 @@ fn errorSetDecl(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.Erro
 
     for (decls) |decl, i| {
         const tag = decl.castTag(.ErrorTag).?;
-        fields[i] = try identifierTokenString(mod, scope, tag.name_token);
+        fields[i] = try mod.identifierTokenString(scope, tag.name_token);
     }
 
     // analyzing the error set results in a decl ref, so we might need to dereference it
@@ -988,36 +988,16 @@ fn orelseCatchExpr(
 /// Return whether the identifier names of two tokens are equal. Resolves @"" tokens without allocating.
 /// OK in theory it could do it without allocating. This implementation allocates when the @"" form is used.
 fn tokenIdentEql(mod: *Module, scope: *Scope, token1: ast.TokenIndex, token2: ast.TokenIndex) !bool {
-    const ident_name_1 = try identifierTokenString(mod, scope, token1);
-    const ident_name_2 = try identifierTokenString(mod, scope, token2);
+    const ident_name_1 = try mod.identifierTokenString(scope, token1);
+    const ident_name_2 = try mod.identifierTokenString(scope, token2);
     return mem.eql(u8, ident_name_1, ident_name_2);
 }
 
-/// Identifier token -> String (allocated in scope.arena())
-fn identifierTokenString(mod: *Module, scope: *Scope, token: ast.TokenIndex) InnerError![]const u8 {
-    const tree = scope.tree();
-
-    const ident_name = tree.tokenSlice(token);
-    if (mem.startsWith(u8, ident_name, "@")) {
-        const raw_string = ident_name[1..];
-        var bad_index: usize = undefined;
-        return std.zig.parseStringLiteral(scope.arena(), raw_string, &bad_index) catch |err| switch (err) {
-            error.InvalidCharacter => {
-                const bad_byte = raw_string[bad_index];
-                const src = tree.token_locs[token].start;
-                return mod.fail(scope, src + 1 + bad_index, "invalid string literal character: '{c}'\n", .{bad_byte});
-            },
-            else => |e| return e,
-        };
-    }
-    return ident_name;
-}
-
 pub fn identifierStringInst(mod: *Module, scope: *Scope, node: *ast.Node.OneToken) InnerError!*zir.Inst {
     const tree = scope.tree();
     const src = tree.token_locs[node.token].start;
 
-    const ident_name = try identifierTokenString(mod, scope, node.token);
+    const ident_name = try mod.identifierTokenString(scope, node.token);
 
     return addZIRInst(mod, scope, src, zir.Inst.Str, .{ .bytes = ident_name }, .{});
 }
@@ -1936,7 +1916,7 @@ fn identifier(mod: *Module, scope: *Scope, rl: ResultLoc, ident: *ast.Node.OneTo
     defer tracy.end();
 
     const tree = scope.tree();
-    const ident_name = try identifierTokenString(mod, scope, ident.token);
+    const ident_name = try mod.identifierTokenString(scope, ident.token);
     const src = tree.token_locs[ident.token].start;
     if (mem.eql(u8, ident_name, "_")) {
         return mod.failNode(scope, &ident.base, "TODO implement '_' identifier", .{});
src/Module.zig
@@ -268,6 +268,11 @@ pub const Decl = struct {
         }
     }
 
+    /// Asserts that the `Decl` is part of AST and not ZIRModule.
+    pub fn getFileScope(self: *Decl) *Scope.File {
+        return self.scope.cast(Scope.Container).?.file_scope;
+    }
+
     fn removeDependant(self: *Decl, other: *Decl) void {
         self.dependants.removeAssertDiscard(other);
     }
@@ -776,6 +781,11 @@ pub const Scope = struct {
             results: ArrayListUnmanaged(*Inst),
             block_inst: *Inst.Block,
         };
+
+        /// For debugging purposes.
+        pub fn dump(self: *Block, mod: Module) void {
+            zir.dumpBlock(mod, self);
+        }
     };
 
     /// This is a temporary structure, references to it are valid only
@@ -992,11 +1002,11 @@ fn astGenAndAnalyzeDecl(self: *Module, decl: *Decl) !bool {
     defer tracy.end();
 
     const container_scope = decl.scope.cast(Scope.Container).?;
-    const tree = try self.getAstTree(container_scope);
+    const tree = try self.getAstTree(container_scope.file_scope);
     const ast_node = tree.root_node.decls()[decl.src_index];
     switch (ast_node.tag) {
         .FnProto => {
-            const fn_proto = @fieldParentPtr(ast.Node.FnProto, "base", ast_node);
+            const fn_proto = ast_node.castTag(.FnProto).?;
 
             decl.analysis = .in_progress;
 
@@ -1131,7 +1141,7 @@ fn astGenAndAnalyzeDecl(self: *Module, decl: *Decl) !bool {
                 for (fn_proto.params()) |param, i| {
                     const name_token = param.name_token.?;
                     const src = tree.token_locs[name_token].start;
-                    const param_name = tree.tokenSlice(name_token); // TODO: call identifierTokenString
+                    const param_name = try self.identifierTokenString(&gen_scope.base, name_token);
                     const arg = try gen_scope_arena.allocator.create(zir.Inst.Arg);
                     arg.* = .{
                         .base = .{
@@ -1496,12 +1506,10 @@ fn getSrcModule(self: *Module, root_scope: *Scope.ZIRModule) !*zir.Module {
     }
 }
 
-fn getAstTree(self: *Module, container_scope: *Scope.Container) !*ast.Tree {
+pub fn getAstTree(self: *Module, root_scope: *Scope.File) !*ast.Tree {
     const tracy = trace(@src());
     defer tracy.end();
 
-    const root_scope = container_scope.file_scope;
-
     switch (root_scope.status) {
         .never_loaded, .unloaded_success => {
             try self.failed_files.ensureCapacity(self.gpa, self.failed_files.items().len + 1);
@@ -1549,7 +1557,7 @@ pub fn analyzeContainer(self: *Module, container_scope: *Scope.Container) !void
 
     // We may be analyzing it for the first time, or this may be
     // an incremental update. This code handles both cases.
-    const tree = try self.getAstTree(container_scope);
+    const tree = try self.getAstTree(container_scope.file_scope);
     const decls = tree.root_node.decls();
 
     try self.comp.work_queue.ensureUnusedCapacity(decls.len);
@@ -3427,3 +3435,23 @@ pub fn validateVarType(mod: *Module, scope: *Scope, src: usize, ty: Type) !void
         return mod.fail(scope, src, "variable of type '{}' must be const or comptime", .{ty});
     }
 }
+
+/// Identifier token -> String (allocated in scope.arena())
+pub fn identifierTokenString(mod: *Module, scope: *Scope, token: ast.TokenIndex) InnerError![]const u8 {
+    const tree = scope.tree();
+
+    const ident_name = tree.tokenSlice(token);
+    if (mem.startsWith(u8, ident_name, "@")) {
+        const raw_string = ident_name[1..];
+        var bad_index: usize = undefined;
+        return std.zig.parseStringLiteral(scope.arena(), raw_string, &bad_index) catch |err| switch (err) {
+            error.InvalidCharacter => {
+                const bad_byte = raw_string[bad_index];
+                const src = tree.token_locs[token].start;
+                return mod.fail(scope, src + 1 + bad_index, "invalid string literal character: '{c}'\n", .{bad_byte});
+            },
+            else => |e| return e,
+        };
+    }
+    return ident_name;
+}
src/value.zig
@@ -330,11 +330,14 @@ pub const Value = extern union {
             .int_type => return self.copyPayloadShallow(allocator, Payload.IntType),
             .int_u64 => return self.copyPayloadShallow(allocator, Payload.U64),
             .int_i64 => return self.copyPayloadShallow(allocator, Payload.I64),
-            .int_big_positive => {
-                @panic("TODO implement copying of big ints");
-            },
-            .int_big_negative => {
-                @panic("TODO implement copying of big ints");
+            .int_big_positive, .int_big_negative => {
+                const old_payload = self.cast(Payload.BigInt).?;
+                const new_payload = try allocator.create(Payload.BigInt);
+                new_payload.* = .{
+                    .base = .{ .tag = self.ptr_otherwise.tag },
+                    .data = try allocator.dupe(std.math.big.Limb, old_payload.data),
+                };
+                return Value{ .ptr_otherwise = &new_payload.base };
             },
             .function => return self.copyPayloadShallow(allocator, Payload.Function),
             .extern_fn => return self.copyPayloadShallow(allocator, Payload.Decl),
src/zir.zig
@@ -1885,6 +1885,46 @@ pub fn dumpFn(old_module: IrModule, module_fn: *IrModule.Fn) void {
     module.dump();
 }
 
+/// For debugging purposes, prints a block (rendered as the body of an unnamed function) to stderr.
+pub fn dumpBlock(old_module: IrModule, module_block: *IrModule.Scope.Block) void {
+    const allocator = old_module.gpa;
+    var ctx: EmitZIR = .{
+        .allocator = allocator,
+        .decls = .{},
+        .arena = std.heap.ArenaAllocator.init(allocator),
+        .old_module = &old_module,
+        .next_auto_name = 0,
+        .names = std.StringArrayHashMap(void).init(allocator),
+        .primitive_table = std.AutoHashMap(Inst.Primitive.Builtin, *Decl).init(allocator),
+        .indent = 0,
+        .block_table = std.AutoHashMap(*ir.Inst.Block, *Inst.Block).init(allocator),
+        .loop_table = std.AutoHashMap(*ir.Inst.Loop, *Inst.Loop).init(allocator),
+        .metadata = std.AutoHashMap(*Inst, Module.MetaData).init(allocator),
+        .body_metadata = std.AutoHashMap(*Module.Body, Module.BodyMetaData).init(allocator),
+    };
+    defer ctx.metadata.deinit();
+    defer ctx.body_metadata.deinit();
+    defer ctx.block_table.deinit();
+    defer ctx.loop_table.deinit();
+    defer ctx.decls.deinit(allocator);
+    defer ctx.names.deinit();
+    defer ctx.primitive_table.deinit();
+    defer ctx.arena.deinit();
+
+    _ = ctx.emitBlock(module_block, 0) catch |err| {
+        std.debug.print("unable to dump function: {}\n", .{err});
+        return;
+    };
+    var module = Module{
+        .decls = ctx.decls.items,
+        .arena = ctx.arena,
+        .metadata = ctx.metadata,
+        .body_metadata = ctx.body_metadata,
+    };
+
+    module.dump();
+}
+
 const EmitZIR = struct {
     allocator: *Allocator,
     arena: std.heap.ArenaAllocator,
@@ -2065,6 +2105,36 @@ const EmitZIR = struct {
         return &declref_inst.base;
     }
 
+    fn emitBlock(self: *EmitZIR, module_block: *IrModule.Scope.Block, src: usize) Allocator.Error!*Decl {
+        var inst_table = std.AutoHashMap(*ir.Inst, *Inst).init(self.allocator);
+        defer inst_table.deinit();
+
+        var instructions = std.ArrayList(*Inst).init(self.allocator);
+        defer instructions.deinit();
+
+        const body: ir.Body = .{ .instructions = module_block.instructions.items };
+        try self.emitBody(body, &inst_table, &instructions);
+
+        const fn_type = try self.emitType(src, Type.initTag(.void));
+
+        const arena_instrs = try self.arena.allocator.alloc(*Inst, instructions.items.len);
+        mem.copy(*Inst, arena_instrs, instructions.items);
+
+        const fn_inst = try self.arena.allocator.create(Inst.Fn);
+        fn_inst.* = .{
+            .base = .{
+                .src = src,
+                .tag = Inst.Fn.base_tag,
+            },
+            .positionals = .{
+                .fn_type = fn_type.inst,
+                .body = .{ .instructions = arena_instrs },
+            },
+            .kw_args = .{},
+        };
+        return self.emitUnnamedDecl(&fn_inst.base);
+    }
+
     fn emitFn(self: *EmitZIR, module_fn: *IrModule.Fn, src: usize, ty: Type) Allocator.Error!*Decl {
         var inst_table = std.AutoHashMap(*ir.Inst, *Inst).init(self.allocator);
         defer inst_table.deinit();
src/zir_sema.zig
@@ -25,6 +25,8 @@ const trace = @import("tracy.zig").trace;
 const Scope = Module.Scope;
 const InnerError = Module.InnerError;
 const Decl = Module.Decl;
+const astgen = @import("astgen.zig");
+const ast = std.zig.ast;
 
 pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!*Inst {
     switch (old_inst.tag) {
@@ -826,7 +828,112 @@ fn analyzeInstCall(mod: *Module, scope: *Scope, inst: *zir.Inst.Call) InnerError
 
     const ret_type = func.ty.fnReturnType();
 
-    const b = try mod.requireRuntimeBlock(scope, inst.base.src);
+    const b = try mod.requireFunctionBlock(scope, inst.base.src);
+    if (b.is_comptime) {
+        const fn_val = try mod.resolveConstValue(scope, func);
+        const module_fn = switch (fn_val.tag()) {
+            .function => fn_val.castTag(.function).?.data,
+            .extern_fn => return mod.fail(scope, inst.base.src, "comptime call of extern function", .{}),
+            else => unreachable,
+        };
+        const callee_decl = module_fn.owner_decl;
+        const callee_file_scope = callee_decl.getFileScope();
+        const tree = mod.getAstTree(callee_file_scope) catch |err| switch (err) {
+            error.OutOfMemory => return error.OutOfMemory,
+            error.AnalysisFail => return error.AnalysisFail,
+            // TODO: make sure this gets retried and not cached
+            else => return mod.fail(scope, inst.base.src, "failed to load {s}: {s}", .{
+                callee_file_scope.sub_file_path, @errorName(err),
+            }),
+        };
+        const ast_node = tree.root_node.decls()[callee_decl.src_index];
+        const fn_proto = ast_node.castTag(.FnProto).?;
+
+        var call_arena = std.heap.ArenaAllocator.init(mod.gpa);
+        defer call_arena.deinit();
+
+        var gen_scope: Scope.GenZIR = .{
+            .decl = callee_decl,
+            .arena = &call_arena.allocator,
+            .parent = callee_decl.scope,
+        };
+        defer gen_scope.instructions.deinit(mod.gpa);
+
+        // Add a const instruction for each parameter.
+        var params_scope = &gen_scope.base;
+        for (fn_proto.params()) |param, i| {
+            const name_token = param.name_token.?;
+            const src = tree.token_locs[name_token].start;
+            const param_name = try mod.identifierTokenString(scope, name_token);
+            const arg_val = try mod.resolveConstValue(scope, casted_args[i]);
+            const arg = try astgen.addZIRInstConst(mod, params_scope, src, .{
+                .ty = casted_args[i].ty,
+                .val = arg_val,
+            });
+            const sub_scope = try call_arena.allocator.create(Scope.LocalVal);
+            sub_scope.* = .{
+                .parent = params_scope,
+                .gen_zir = &gen_scope,
+                .name = param_name,
+                .inst = arg,
+            };
+            params_scope = &sub_scope.base;
+        }
+
+        const body_node = fn_proto.getBodyNode().?; // We handle extern functions above.
+        const body_block = body_node.cast(ast.Node.Block).?;
+
+        try astgen.blockExpr(mod, params_scope, body_block);
+
+        if (gen_scope.instructions.items.len == 0 or
+            !gen_scope.instructions.items[gen_scope.instructions.items.len - 1].tag.isNoReturn())
+        {
+            const src = tree.token_locs[body_block.rbrace].start;
+            _ = try astgen.addZIRNoOp(mod, &gen_scope.base, src, .returnvoid);
+        }
+
+        if (mod.comp.verbose_ir) {
+            zir.dumpZir(mod.gpa, "fn_body_callee", callee_decl.name, gen_scope.instructions.items) catch {};
+        }
+
+        // Analyze the ZIR.
+        var inner_block: Scope.Block = .{
+            .parent = null,
+            .func = module_fn,
+            .decl = callee_decl,
+            .instructions = .{},
+            .arena = &call_arena.allocator,
+            .is_comptime = true,
+        };
+        defer inner_block.instructions.deinit(mod.gpa);
+
+        // TODO make sure compile errors that happen from this analyzeBody are reported correctly
+        // and attach to the caller Decl not the callee.
+        try analyzeBody(mod, &inner_block.base, .{
+            .instructions = gen_scope.instructions.items,
+        });
+
+        if (mod.comp.verbose_ir) {
+            inner_block.dump(mod.*);
+        }
+
+        assert(inner_block.instructions.items.len == 1);
+        const only_inst = inner_block.instructions.items[0];
+        switch (only_inst.tag) {
+            .ret => {
+                const ret_inst = only_inst.castTag(.ret).?;
+                const operand = ret_inst.operand;
+                const callee_arena = scope.arena();
+                return mod.constInst(scope, inst.base.src, .{
+                    .ty = try operand.ty.copy(callee_arena),
+                    .val = try operand.value().?.copy(callee_arena),
+                });
+            },
+            .retvoid => return mod.constVoid(scope, inst.base.src),
+            else => unreachable,
+        }
+    }
+
     return mod.addCall(b, inst.base.src, ret_type, func, casted_args);
 }
 
@@ -1509,7 +1616,7 @@ fn analyzeInstImport(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerErr
             return mod.fail(scope, inst.base.src, "unable to find '{s}'", .{operand});
         },
         else => {
-            // TODO user friendly error to string
+            // TODO: make sure this gets retried and not cached
             return mod.fail(scope, inst.base.src, "unable to open '{s}': {s}", .{ operand, @errorName(err) });
         },
     };
@@ -1912,12 +2019,12 @@ fn analyzeInstUnreachable(
 
 fn analyzeInstRet(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst {
     const operand = try resolveInst(mod, scope, inst.positionals.operand);
-    const b = try mod.requireRuntimeBlock(scope, inst.base.src);
+    const b = try mod.requireFunctionBlock(scope, inst.base.src);
     return mod.addUnOp(b, inst.base.src, Type.initTag(.noreturn), .ret, operand);
 }
 
 fn analyzeInstRetVoid(mod: *Module, scope: *Scope, inst: *zir.Inst.NoOp) InnerError!*Inst {
-    const b = try mod.requireRuntimeBlock(scope, inst.base.src);
+    const b = try mod.requireFunctionBlock(scope, inst.base.src);
     if (b.func) |func| {
         // Need to emit a compile error if returning void is not allowed.
         const void_inst = try mod.constVoid(scope, inst.base.src);
test/stage2/test.zig
@@ -318,7 +318,7 @@ pub fn addCases(ctx: *TestContext) !void {
     }
 
     {
-        var case = ctx.exe("adding numbers at runtime", linux_x64);
+        var case = ctx.exe("adding numbers at runtime and comptime", linux_x64);
         case.addCompareOutput(
             \\export fn _start() noreturn {
             \\    add(3, 4);
@@ -342,6 +342,29 @@ pub fn addCases(ctx: *TestContext) !void {
         ,
             "",
         );
+        case.addCompareOutput(
+            \\export fn _start() noreturn {
+            \\    exit();
+            \\}
+            \\
+            \\fn add(a: u32, b: u32) u32 {
+            \\    return a + b;
+            \\}
+            \\
+            \\const x = add(3, 4);
+            \\
+            \\fn exit() noreturn {
+            \\    asm volatile ("syscall"
+            \\        :
+            \\        : [number] "{rax}" (231),
+            \\          [arg1] "{rdi}" (x - 7)
+            \\        : "rcx", "r11", "memory"
+            \\    );
+            \\    unreachable;
+            \\}
+        ,
+            "",
+        );
     }
 
     {