Commit c03a04a589

Andrew Kelley <andrew@ziglang.org>
2021-08-06 04:15:59
stage2: return type expressions of generic functions
* The ZIR encoding for function instructions now has a body for the return type. This lets Sema for generic functions do the same thing it does for parameters, handling `error.GenericPoison` in the evaluation of the return type by marking the function as generic.
* Sema: fix missing block around the new Decl arena finalization. The missing block led to memory corruption.
* Added some floating point support to the LLVM backend, but didn't get far enough to pass any new tests.
1 parent e9e3a29
src/codegen/llvm/bindings.zig
@@ -31,6 +31,21 @@ pub const Context = opaque {
     pub const intType = LLVMIntTypeInContext;
     extern fn LLVMIntTypeInContext(C: *const Context, NumBits: c_uint) *const Type;
 
+    pub const halfType = LLVMHalfTypeInContext;
+    extern fn LLVMHalfTypeInContext(C: *const Context) *const Type;
+
+    pub const floatType = LLVMFloatTypeInContext;
+    extern fn LLVMFloatTypeInContext(C: *const Context) *const Type;
+
+    pub const doubleType = LLVMDoubleTypeInContext;
+    extern fn LLVMDoubleTypeInContext(C: *const Context) *const Type;
+
+    pub const x86FP80Type = LLVMX86FP80TypeInContext;
+    extern fn LLVMX86FP80TypeInContext(C: *const Context) *const Type;
+
+    pub const fp128Type = LLVMFP128TypeInContext;
+    extern fn LLVMFP128TypeInContext(C: *const Context) *const Type;
+
     pub const voidType = LLVMVoidTypeInContext;
     extern fn LLVMVoidTypeInContext(C: *const Context) *const Type;
 
@@ -127,6 +142,9 @@ pub const Type = opaque {
     pub const constInt = LLVMConstInt;
     extern fn LLVMConstInt(IntTy: *const Type, N: c_ulonglong, SignExtend: Bool) *const Value;
 
+    pub const constReal = LLVMConstReal;
+    extern fn LLVMConstReal(RealTy: *const Type, N: f64) *const Value;
+
     pub const constArray = LLVMConstArray;
     extern fn LLVMConstArray(ElementTy: *const Type, ConstantVals: [*]*const Value, Length: c_uint) *const Value;
 
src/codegen/llvm.zig
@@ -575,6 +575,14 @@ pub const DeclGen = struct {
                 const info = t.intInfo(self.module.getTarget());
                 return self.context.intType(info.bits);
             },
+            .Float => switch (t.floatBits(self.module.getTarget())) {
+                16 => return self.context.halfType(),
+                32 => return self.context.floatType(),
+                64 => return self.context.doubleType(),
+                80 => return self.context.x86FP80Type(),
+                128 => return self.context.fp128Type(),
+                else => unreachable,
+            },
             .Bool => return self.context.intType(1),
             .Pointer => {
                 if (t.isSlice()) {
@@ -661,7 +669,6 @@ pub const DeclGen = struct {
 
             .BoundFn => @panic("TODO remove BoundFn from the language"),
 
-            .Float,
             .Enum,
             .Union,
             .Opaque,
@@ -699,6 +706,13 @@ pub const DeclGen = struct {
                 }
                 return llvm_int;
             },
+            .Float => {
+                if (tv.ty.floatBits(self.module.getTarget()) <= 64) {
+                    const llvm_ty = try self.llvmType(tv.ty);
+                    return llvm_ty.constReal(tv.val.toFloat(f64));
+                }
+                return self.todo("bitcast to f128 from an integer", .{});
+            },
             .Pointer => switch (tv.val.tag()) {
                 .decl_ref => {
                     if (tv.ty.isSlice()) {
src/AstGen.zig
@@ -1041,6 +1041,7 @@ fn fnProtoExpr(
     fn_proto: ast.full.FnProto,
 ) InnerError!Zir.Inst.Ref {
     const astgen = gz.astgen;
+    const gpa = astgen.gpa;
     const tree = astgen.tree;
     const token_tags = tree.tokens.items(.tag);
 
@@ -1083,7 +1084,6 @@ fn fnProtoExpr(
                     .param_anytype;
                 _ = try gz.addStrTok(tag, param_name, name_token);
             } else {
-                const gpa = astgen.gpa;
                 const param_type_node = param.type_expr;
                 assert(param_type_node != 0);
                 var param_gz = gz.makeSubBlock(scope);
@@ -1113,15 +1113,13 @@ fn fnProtoExpr(
     if (is_inferred_error) {
         return astgen.failTok(maybe_bang, "function prototype may not have inferred error set", .{});
     }
-    const return_type_inst = try AstGen.expr(
-        gz,
-        scope,
-        .{ .ty = .type_type },
-        fn_proto.ast.return_type,
-    );
+    var ret_gz = gz.makeSubBlock(scope);
+    defer ret_gz.instructions.deinit(gpa);
+    const ret_ty = try expr(&ret_gz, scope, coerced_type_rl, fn_proto.ast.return_type);
+    const ret_br = try ret_gz.addBreak(.break_inline, 0, ret_ty);
 
     const cc: Zir.Inst.Ref = if (fn_proto.ast.callconv_expr != 0)
-        try AstGen.expr(
+        try expr(
             gz,
             scope,
             .{ .ty = .calling_convention_type },
@@ -1133,7 +1131,8 @@ fn fnProtoExpr(
     const result = try gz.addFunc(.{
         .src_node = fn_proto.ast.proto_node,
         .param_block = 0,
-        .ret_ty = return_type_inst,
+        .ret_ty = ret_gz.instructions.items,
+        .ret_br = ret_br,
         .body = &[0]Zir.Inst.Index{},
         .cc = cc,
         .align_inst = align_inst,
@@ -3005,12 +3004,10 @@ fn fnDecl(
         break :inst try comptimeExpr(&decl_gz, params_scope, .{ .ty = .const_slice_u8_type }, fn_proto.ast.section_expr);
     };
 
-    const return_type_inst = try AstGen.expr(
-        &decl_gz,
-        params_scope,
-        .{ .ty = .type_type },
-        fn_proto.ast.return_type,
-    );
+    var ret_gz = gz.makeSubBlock(params_scope);
+    defer ret_gz.instructions.deinit(gpa);
+    const ret_ty = try expr(&decl_gz, params_scope, coerced_type_rl, fn_proto.ast.return_type);
+    const ret_br = try ret_gz.addBreak(.break_inline, 0, ret_ty);
 
     const cc: Zir.Inst.Ref = blk: {
         if (fn_proto.ast.callconv_expr != 0) {
@@ -3021,7 +3018,7 @@ fn fnDecl(
                     .{},
                 );
             }
-            break :blk try AstGen.expr(
+            break :blk try expr(
                 &decl_gz,
                 params_scope,
                 .{ .ty = .calling_convention_type },
@@ -3046,7 +3043,8 @@ fn fnDecl(
         }
         break :func try decl_gz.addFunc(.{
             .src_node = decl_node,
-            .ret_ty = return_type_inst,
+            .ret_ty = ret_gz.instructions.items,
+            .ret_br = ret_br,
             .param_block = block_inst,
             .body = &[0]Zir.Inst.Index{},
             .cc = cc,
@@ -3085,7 +3083,8 @@ fn fnDecl(
         break :func try decl_gz.addFunc(.{
             .src_node = decl_node,
             .param_block = block_inst,
-            .ret_ty = return_type_inst,
+            .ret_ty = ret_gz.instructions.items,
+            .ret_br = ret_br,
             .body = fn_gz.instructions.items,
             .cc = cc,
             .align_inst = .none, // passed in the per-decl data
@@ -3430,7 +3429,8 @@ fn testDecl(
     const func_inst = try decl_block.addFunc(.{
         .src_node = node,
         .param_block = block_inst,
-        .ret_ty = .void_type,
+        .ret_ty = &.{},
+        .ret_br = 0,
         .body = fn_block.instructions.items,
         .cc = .none,
         .align_inst = .none,
@@ -9127,7 +9127,8 @@ const GenZir = struct {
         src_node: ast.Node.Index,
         body: []const Zir.Inst.Index,
         param_block: Zir.Inst.Index,
-        ret_ty: Zir.Inst.Ref,
+        ret_ty: []const Zir.Inst.Index,
+        ret_br: Zir.Inst.Index,
         cc: Zir.Inst.Ref,
         align_inst: Zir.Inst.Ref,
         lib_name: u32,
@@ -9137,7 +9138,6 @@ const GenZir = struct {
         is_extern: bool,
     }) !Zir.Inst.Ref {
         assert(args.src_node != 0);
-        assert(args.ret_ty != .none);
         const astgen = gz.astgen;
         const gpa = astgen.gpa;
 
@@ -9179,7 +9179,7 @@ const GenZir = struct {
             try astgen.extra.ensureUnusedCapacity(
                 gpa,
                 @typeInfo(Zir.Inst.ExtendedFunc).Struct.fields.len +
-                    args.body.len + src_locs.len +
+                    args.ret_ty.len + args.body.len + src_locs.len +
                     @boolToInt(args.lib_name != 0) +
                     @boolToInt(args.align_inst != .none) +
                     @boolToInt(args.cc != .none),
@@ -9187,7 +9187,7 @@ const GenZir = struct {
             const payload_index = astgen.addExtraAssumeCapacity(Zir.Inst.ExtendedFunc{
                 .src_node = gz.nodeIndexToRelative(args.src_node),
                 .param_block = args.param_block,
-                .return_type = args.ret_ty,
+                .ret_body_len = @intCast(u32, args.ret_ty.len),
                 .body_len = @intCast(u32, args.body.len),
             });
             if (args.lib_name != 0) {
@@ -9199,10 +9199,14 @@ const GenZir = struct {
             if (args.align_inst != .none) {
                 astgen.extra.appendAssumeCapacity(@enumToInt(args.align_inst));
             }
+            astgen.extra.appendSliceAssumeCapacity(args.ret_ty);
             astgen.extra.appendSliceAssumeCapacity(args.body);
             astgen.extra.appendSliceAssumeCapacity(src_locs);
 
             const new_index = @intCast(Zir.Inst.Index, astgen.instructions.len);
+            if (args.ret_br != 0) {
+                astgen.instructions.items(.data)[args.ret_br].@"break".block_inst = new_index;
+            }
             astgen.instructions.appendAssumeCapacity(.{
                 .tag = .extended,
                 .data = .{ .extended = .{
@@ -9222,23 +9226,27 @@ const GenZir = struct {
             gz.instructions.appendAssumeCapacity(new_index);
             return indexToRef(new_index);
         } else {
-            try gz.astgen.extra.ensureUnusedCapacity(
+            try astgen.extra.ensureUnusedCapacity(
                 gpa,
                 @typeInfo(Zir.Inst.Func).Struct.fields.len +
-                    args.body.len + src_locs.len,
+                    args.ret_ty.len + args.body.len + src_locs.len,
             );
 
-            const payload_index = gz.astgen.addExtraAssumeCapacity(Zir.Inst.Func{
+            const payload_index = astgen.addExtraAssumeCapacity(Zir.Inst.Func{
                 .param_block = args.param_block,
-                .return_type = args.ret_ty,
+                .ret_body_len = @intCast(u32, args.ret_ty.len),
                 .body_len = @intCast(u32, args.body.len),
             });
-            gz.astgen.extra.appendSliceAssumeCapacity(args.body);
-            gz.astgen.extra.appendSliceAssumeCapacity(src_locs);
+            astgen.extra.appendSliceAssumeCapacity(args.ret_ty);
+            astgen.extra.appendSliceAssumeCapacity(args.body);
+            astgen.extra.appendSliceAssumeCapacity(src_locs);
 
             const tag: Zir.Inst.Tag = if (args.is_inferred_error) .func_inferred else .func;
-            const new_index = @intCast(Zir.Inst.Index, gz.astgen.instructions.len);
-            gz.astgen.instructions.appendAssumeCapacity(.{
+            const new_index = @intCast(Zir.Inst.Index, astgen.instructions.len);
+            if (args.ret_br != 0) {
+                astgen.instructions.items(.data)[args.ret_br].@"break".block_inst = new_index;
+            }
+            astgen.instructions.appendAssumeCapacity(.{
                 .tag = tag,
                 .data = .{ .pl_node = .{
                     .src_node = gz.nodeIndexToRelative(args.src_node),
src/Module.zig
@@ -842,6 +842,9 @@ pub const Fn = struct {
 
     pub fn getInferredErrorSet(func: *Fn) ?*std.StringHashMapUnmanaged(void) {
         const ret_ty = func.owner_decl.ty.fnReturnType();
+        if (ret_ty.tag() == .generic_poison) {
+            return null;
+        }
         if (ret_ty.zigTypeTag() == .ErrorUnion) {
             if (ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| {
                 return &payload.data.map;
src/print_air.zig
@@ -222,7 +222,7 @@ const Writer = struct {
         const extra = w.air.extraData(Air.Block, ty_pl.payload);
         const body = w.air.extra[extra.end..][0..extra.data.body_len];
 
-        try s.writeAll("{\n");
+        try s.print("{}, {{\n", .{w.air.getRefType(ty_pl.ty)});
         const old_indent = w.indent;
         w.indent += 2;
         try w.writeBody(s, body);
src/Sema.zig
@@ -2618,128 +2618,130 @@ fn analyzeCall(
             break :new_func gop.key_ptr.*;
         };
 
-        try namespace.anon_decls.ensureUnusedCapacity(gpa, 1);
-
-        // Create a Decl for the new function.
-        const new_decl = try mod.allocateNewDecl(namespace, module_fn.owner_decl.src_node);
-        // TODO better names for generic function instantiations
-        const name_index = mod.getNextAnonNameIndex();
-        new_decl.name = try std.fmt.allocPrintZ(gpa, "{s}__anon_{d}", .{
-            module_fn.owner_decl.name, name_index,
-        });
-        new_decl.src_line = module_fn.owner_decl.src_line;
-        new_decl.is_pub = module_fn.owner_decl.is_pub;
-        new_decl.is_exported = module_fn.owner_decl.is_exported;
-        new_decl.has_align = module_fn.owner_decl.has_align;
-        new_decl.has_linksection = module_fn.owner_decl.has_linksection;
-        new_decl.zir_decl_index = module_fn.owner_decl.zir_decl_index;
-        new_decl.alive = true; // This Decl is called at runtime.
-        new_decl.has_tv = true;
-        new_decl.owns_tv = true;
-        new_decl.analysis = .in_progress;
-        new_decl.generation = mod.generation;
-
-        namespace.anon_decls.putAssumeCapacityNoClobber(new_decl, {});
-
-        var new_decl_arena = std.heap.ArenaAllocator.init(sema.gpa);
-        errdefer new_decl_arena.deinit();
-
-        // Re-run the block that creates the function, with the comptime parameters
-        // pre-populated inside `inst_map`. This causes `param_comptime` and
-        // `param_anytype_comptime` ZIR instructions to be ignored, resulting in a
-        // new, monomorphized function, with the comptime parameters elided.
-        var child_sema: Sema = .{
-            .mod = mod,
-            .gpa = gpa,
-            .arena = sema.arena,
-            .code = fn_zir,
-            .owner_decl = new_decl,
-            .namespace = namespace,
-            .func = null,
-            .owner_func = null,
-            .comptime_args = try new_decl_arena.allocator.alloc(TypedValue, uncasted_args.len),
-            .comptime_args_fn_inst = module_fn.zir_body_inst,
-            .preallocated_new_func = new_module_func,
-        };
-        defer child_sema.deinit();
-
-        var child_block: Scope.Block = .{
-            .parent = null,
-            .sema = &child_sema,
-            .src_decl = new_decl,
-            .instructions = .{},
-            .inlining = null,
-            .is_comptime = true,
-        };
-        defer {
-            child_block.instructions.deinit(gpa);
-            child_block.params.deinit(gpa);
-        }
-
-        try child_sema.inst_map.ensureUnusedCapacity(gpa, @intCast(u32, uncasted_args.len));
-        var arg_i: usize = 0;
-        for (fn_info.param_body) |inst| {
-            const is_comptime = switch (zir_tags[inst]) {
-                .param_comptime, .param_anytype_comptime => true,
-                .param, .param_anytype => false,
-                else => continue,
-            } or func_ty_info.paramIsComptime(arg_i);
-            const arg_src = call_src; // TODO: better source location
-            const arg = uncasted_args[arg_i];
-            if (try sema.resolveMaybeUndefVal(block, arg_src, arg)) |arg_val| {
-                const child_arg = try child_sema.addConstant(sema.typeOf(arg), arg_val);
-                child_sema.inst_map.putAssumeCapacityNoClobber(inst, child_arg);
-            } else if (is_comptime) {
-                return sema.failWithNeededComptime(block, arg_src);
+        {
+            try namespace.anon_decls.ensureUnusedCapacity(gpa, 1);
+
+            // Create a Decl for the new function.
+            const new_decl = try mod.allocateNewDecl(namespace, module_fn.owner_decl.src_node);
+            // TODO better names for generic function instantiations
+            const name_index = mod.getNextAnonNameIndex();
+            new_decl.name = try std.fmt.allocPrintZ(gpa, "{s}__anon_{d}", .{
+                module_fn.owner_decl.name, name_index,
+            });
+            new_decl.src_line = module_fn.owner_decl.src_line;
+            new_decl.is_pub = module_fn.owner_decl.is_pub;
+            new_decl.is_exported = module_fn.owner_decl.is_exported;
+            new_decl.has_align = module_fn.owner_decl.has_align;
+            new_decl.has_linksection = module_fn.owner_decl.has_linksection;
+            new_decl.zir_decl_index = module_fn.owner_decl.zir_decl_index;
+            new_decl.alive = true; // This Decl is called at runtime.
+            new_decl.has_tv = true;
+            new_decl.owns_tv = true;
+            new_decl.analysis = .in_progress;
+            new_decl.generation = mod.generation;
+
+            namespace.anon_decls.putAssumeCapacityNoClobber(new_decl, {});
+
+            var new_decl_arena = std.heap.ArenaAllocator.init(sema.gpa);
+            errdefer new_decl_arena.deinit();
+
+            // Re-run the block that creates the function, with the comptime parameters
+            // pre-populated inside `inst_map`. This causes `param_comptime` and
+            // `param_anytype_comptime` ZIR instructions to be ignored, resulting in a
+            // new, monomorphized function, with the comptime parameters elided.
+            var child_sema: Sema = .{
+                .mod = mod,
+                .gpa = gpa,
+                .arena = sema.arena,
+                .code = fn_zir,
+                .owner_decl = new_decl,
+                .namespace = namespace,
+                .func = null,
+                .owner_func = null,
+                .comptime_args = try new_decl_arena.allocator.alloc(TypedValue, uncasted_args.len),
+                .comptime_args_fn_inst = module_fn.zir_body_inst,
+                .preallocated_new_func = new_module_func,
+            };
+            defer child_sema.deinit();
+
+            var child_block: Scope.Block = .{
+                .parent = null,
+                .sema = &child_sema,
+                .src_decl = new_decl,
+                .instructions = .{},
+                .inlining = null,
+                .is_comptime = true,
+            };
+            defer {
+                child_block.instructions.deinit(gpa);
+                child_block.params.deinit(gpa);
             }
-            arg_i += 1;
-        }
-        const new_func_inst = try child_sema.resolveBody(&child_block, fn_info.param_body);
-        const new_func_val = try child_sema.resolveConstValue(&child_block, .unneeded, new_func_inst);
-        const new_func = new_func_val.castTag(.function).?.data;
-        assert(new_func == new_module_func);
 
-        arg_i = 0;
-        for (fn_info.param_body) |inst| {
-            switch (zir_tags[inst]) {
-                .param_comptime, .param_anytype_comptime, .param, .param_anytype => {},
-                else => continue,
+            try child_sema.inst_map.ensureUnusedCapacity(gpa, @intCast(u32, uncasted_args.len));
+            var arg_i: usize = 0;
+            for (fn_info.param_body) |inst| {
+                const is_comptime = switch (zir_tags[inst]) {
+                    .param_comptime, .param_anytype_comptime => true,
+                    .param, .param_anytype => false,
+                    else => continue,
+                } or func_ty_info.paramIsComptime(arg_i);
+                const arg_src = call_src; // TODO: better source location
+                const arg = uncasted_args[arg_i];
+                if (try sema.resolveMaybeUndefVal(block, arg_src, arg)) |arg_val| {
+                    const child_arg = try child_sema.addConstant(sema.typeOf(arg), arg_val);
+                    child_sema.inst_map.putAssumeCapacityNoClobber(inst, child_arg);
+                } else if (is_comptime) {
+                    return sema.failWithNeededComptime(block, arg_src);
+                }
+                arg_i += 1;
             }
-            const arg = child_sema.inst_map.get(inst).?;
-            const arg_val = (child_sema.resolveMaybeUndefValAllowVariables(&child_block, .unneeded, arg) catch unreachable).?;
+            const new_func_inst = try child_sema.resolveBody(&child_block, fn_info.param_body);
+            const new_func_val = try child_sema.resolveConstValue(&child_block, .unneeded, new_func_inst);
+            const new_func = new_func_val.castTag(.function).?.data;
+            assert(new_func == new_module_func);
+
+            arg_i = 0;
+            for (fn_info.param_body) |inst| {
+                switch (zir_tags[inst]) {
+                    .param_comptime, .param_anytype_comptime, .param, .param_anytype => {},
+                    else => continue,
+                }
+                const arg = child_sema.inst_map.get(inst).?;
+                const arg_val = (child_sema.resolveMaybeUndefValAllowVariables(&child_block, .unneeded, arg) catch unreachable).?;
 
-            if (arg_val.tag() == .generic_poison) {
-                child_sema.comptime_args[arg_i] = .{
-                    .ty = Type.initTag(.noreturn),
-                    .val = Value.initTag(.unreachable_value),
-                };
-            } else {
-                child_sema.comptime_args[arg_i] = .{
-                    .ty = try child_sema.typeOf(arg).copy(&new_decl_arena.allocator),
-                    .val = try arg_val.copy(&new_decl_arena.allocator),
-                };
+                if (arg_val.tag() == .generic_poison) {
+                    child_sema.comptime_args[arg_i] = .{
+                        .ty = Type.initTag(.noreturn),
+                        .val = Value.initTag(.unreachable_value),
+                    };
+                } else {
+                    child_sema.comptime_args[arg_i] = .{
+                        .ty = try child_sema.typeOf(arg).copy(&new_decl_arena.allocator),
+                        .val = try arg_val.copy(&new_decl_arena.allocator),
+                    };
+                }
+
+                arg_i += 1;
             }
 
-            arg_i += 1;
-        }
+            // Populate the Decl ty/val with the function and its type.
+            new_decl.ty = try child_sema.typeOf(new_func_inst).copy(&new_decl_arena.allocator);
+            new_decl.val = try Value.Tag.function.create(&new_decl_arena.allocator, new_func);
+            new_decl.analysis = .complete;
 
-        // Populate the Decl ty/val with the function and its type.
-        new_decl.ty = try child_sema.typeOf(new_func_inst).copy(&new_decl_arena.allocator);
-        new_decl.val = try Value.Tag.function.create(&new_decl_arena.allocator, new_func);
-        new_decl.analysis = .complete;
+            // The generic function Decl is guaranteed to be the first dependency
+            // of each of its instantiations.
+            assert(new_decl.dependencies.keys().len == 0);
+            try mod.declareDeclDependency(new_decl, module_fn.owner_decl);
 
-        // Queue up a `codegen_func` work item for the new Fn. The `comptime_args` field
-        // will be populated, ensuring it will have `analyzeBody` called with the ZIR
-        // parameters mapped appropriately.
-        try mod.comp.bin_file.allocateDeclIndexes(new_decl);
-        try mod.comp.work_queue.writeItem(.{ .codegen_func = new_func });
+            // Queue up a `codegen_func` work item for the new Fn. The `comptime_args` field
+            // will be populated, ensuring it will have `analyzeBody` called with the ZIR
+            // parameters mapped appropriately.
+            try mod.comp.bin_file.allocateDeclIndexes(new_decl);
+            try mod.comp.work_queue.writeItem(.{ .codegen_func = new_func });
 
-        try new_decl.finalizeNewArena(&new_decl_arena);
-
-        // The generic function Decl is guaranteed to be the first dependency
-        // of each of its instantiations.
-        assert(new_decl.dependencies.keys().len == 0);
-        try mod.declareDeclDependency(new_decl, module_fn.owner_decl);
+            try new_decl.finalizeNewArena(&new_decl_arena);
+        }
 
         break :res try sema.finishGenericCall(
             block,
@@ -3478,12 +3480,15 @@ fn zirFunc(
 
     const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
     const extra = sema.code.extraData(Zir.Inst.Func, inst_data.payload_index);
+    var extra_index = extra.end;
+    const ret_ty_body = sema.code.extra[extra_index..][0..extra.data.ret_body_len];
+    extra_index += ret_ty_body.len;
 
     var body_inst: Zir.Inst.Index = 0;
     var src_locs: Zir.Inst.Func.SrcLocs = undefined;
     if (extra.data.body_len != 0) {
         body_inst = inst;
-        const extra_index = extra.end + extra.data.body_len;
+        extra_index += extra.data.body_len;
         src_locs = sema.code.extraData(Zir.Inst.Func.SrcLocs, extra_index).data;
     }
 
@@ -3496,7 +3501,7 @@ fn zirFunc(
         block,
         inst_data.src_node,
         body_inst,
-        extra.data.return_type,
+        ret_ty_body,
         cc,
         Value.initTag(.null_value),
         false,
@@ -3512,7 +3517,7 @@ fn funcCommon(
     block: *Scope.Block,
     src_node_offset: i32,
     body_inst: Zir.Inst.Index,
-    zir_return_type: Zir.Inst.Ref,
+    ret_ty_body: []const Zir.Inst.Index,
     cc: std.builtin.CallingConvention,
     align_val: Value,
     var_args: bool,
@@ -3523,7 +3528,37 @@ fn funcCommon(
 ) CompileError!Air.Inst.Ref {
     const src: LazySrcLoc = .{ .node_offset = src_node_offset };
     const ret_ty_src: LazySrcLoc = .{ .node_offset_fn_type_ret_ty = src_node_offset };
-    const bare_return_type = try sema.resolveType(block, ret_ty_src, zir_return_type);
+
+    // The return type body might be a type expression that depends on generic parameters.
+    // In such case we need to use a generic_poison value for the return type and mark
+    // the function as generic.
+    var is_generic = false;
+    const bare_return_type: Type = ret_ty: {
+        if (ret_ty_body.len == 0) break :ret_ty Type.initTag(.void);
+
+        const err = err: {
+            // Make sure any nested param instructions don't clobber our work.
+            const prev_params = block.params;
+            block.params = .{};
+            defer {
+                block.params.deinit(sema.gpa);
+                block.params = prev_params;
+            }
+            if (sema.resolveBody(block, ret_ty_body)) |ret_ty_inst| {
+                if (sema.analyzeAsType(block, ret_ty_src, ret_ty_inst)) |ret_ty| {
+                    break :ret_ty ret_ty;
+                } else |err| break :err err;
+            } else |err| break :err err;
+        };
+        switch (err) {
+            error.GenericPoison => {
+                // The type is not available until the generic instantiation.
+                is_generic = true;
+                break :ret_ty Type.initTag(.generic_poison);
+            },
+            else => |e| return e,
+        }
+    };
 
     const mod = sema.mod;
 
@@ -3540,8 +3575,9 @@ fn funcCommon(
 
     const fn_ty: Type = fn_ty: {
         // Hot path for some common function types.
-        if (block.params.items.len == 0 and !var_args and align_val.tag() == .null_value and
-            !inferred_error_set)
+        // TODO can we eliminate some of these Type tag values? seems unnecessarily complicated.
+        if (!is_generic and block.params.items.len == 0 and !var_args and
+            align_val.tag() == .null_value and !inferred_error_set)
         {
             if (bare_return_type.zigTypeTag() == .NoReturn and cc == .Unspecified) {
                 break :fn_ty Type.initTag(.fn_noreturn_no_args);
@@ -3560,7 +3596,6 @@ fn funcCommon(
             }
         }
 
-        var is_generic = false;
         const param_types = try sema.arena.alloc(Type, block.params.items.len);
         const comptime_params = try sema.arena.alloc(bool, block.params.items.len);
         for (block.params.items) |param, i| {
@@ -3574,7 +3609,9 @@ fn funcCommon(
             return mod.fail(&block.base, src, "TODO implement support for function prototypes to have alignment specified", .{});
         }
 
-        const return_type = if (!inferred_error_set) bare_return_type else blk: {
+        const return_type = if (!inferred_error_set or bare_return_type.tag() == .generic_poison)
+            bare_return_type
+        else blk: {
             const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, .{
                 .func = new_func,
                 .map = .{},
@@ -6944,6 +6981,9 @@ fn zirFuncExtended(
         break :blk align_tv.val;
     } else Value.initTag(.null_value);
 
+    const ret_ty_body = sema.code.extra[extra_index..][0..extra.data.ret_body_len];
+    extra_index += ret_ty_body.len;
+
     var body_inst: Zir.Inst.Index = 0;
     var src_locs: Zir.Inst.Func.SrcLocs = undefined;
     if (extra.data.body_len != 0) {
@@ -6960,7 +7000,7 @@ fn zirFuncExtended(
         block,
         extra.data.src_node,
         body_inst,
-        extra.data.return_type,
+        ret_ty_body,
         cc,
         align_val,
         is_var_args,
src/Zir.zig
@@ -2272,11 +2272,13 @@ pub const Inst = struct {
     /// 0. lib_name: u32, // null terminated string index, if has_lib_name is set
     /// 1. cc: Ref, // if has_cc is set
     /// 2. align: Ref, // if has_align is set
-    /// 3. body: Index // for each body_len
-    /// 4. src_locs: Func.SrcLocs // if body_len != 0
+    /// 3. return_type: Index // for each ret_body_len
+    /// 4. body: Index // for each body_len
+    /// 5. src_locs: Func.SrcLocs // if body_len != 0
     pub const ExtendedFunc = struct {
         src_node: i32,
-        return_type: Ref,
+        /// If this is 0 it means a void return type.
+        ret_body_len: u32,
         /// Points to the block that contains the param instructions for this function.
         param_block: Index,
         body_len: u32,
@@ -2312,10 +2314,12 @@ pub const Inst = struct {
     };
 
     /// Trailing:
-    /// 0. body: Index // for each body_len
-    /// 1. src_locs: SrcLocs // if body_len != 0
+    /// 0. return_type: Index // for each ret_body_len
+    /// 1. body: Index // for each body_len
+    /// 2. src_locs: SrcLocs // if body_len != 0
     pub const Func = struct {
-        return_type: Ref,
+        /// If this is 0 it means a void return type.
+        ret_body_len: u32,
         /// Points to the block that contains the param instructions for this function.
         param_block: Index,
         body_len: u32,
@@ -4344,15 +4348,21 @@ const Writer = struct {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         const src = inst_data.src();
         const extra = self.code.extraData(Inst.Func, inst_data.payload_index);
-        const body = self.code.extra[extra.end..][0..extra.data.body_len];
+        var extra_index = extra.end;
+
+        const ret_ty_body = self.code.extra[extra_index..][0..extra.data.ret_body_len];
+        extra_index += ret_ty_body.len;
+
+        const body = self.code.extra[extra_index..][0..extra.data.body_len];
+        extra_index += body.len;
+
         var src_locs: Zir.Inst.Func.SrcLocs = undefined;
         if (body.len != 0) {
-            const extra_index = extra.end + body.len;
             src_locs = self.code.extraData(Zir.Inst.Func.SrcLocs, extra_index).data;
         }
         return self.writeFuncCommon(
             stream,
-            extra.data.return_type,
+            ret_ty_body,
             inferred_error_set,
             false,
             false,
@@ -4387,6 +4397,9 @@ const Writer = struct {
             break :blk align_inst;
         };
 
+        const ret_ty_body = self.code.extra[extra_index..][0..extra.data.ret_body_len];
+        extra_index += ret_ty_body.len;
+
         const body = self.code.extra[extra_index..][0..extra.data.body_len];
         extra_index += body.len;
 
@@ -4396,7 +4409,7 @@ const Writer = struct {
         }
         return self.writeFuncCommon(
             stream,
-            extra.data.return_type,
+            ret_ty_body,
             small.is_inferred_error,
             small.is_var_args,
             small.is_extern,
@@ -4478,7 +4491,7 @@ const Writer = struct {
     fn writeFuncCommon(
         self: *Writer,
         stream: anytype,
-        ret_ty: Inst.Ref,
+        ret_ty_body: []const Inst.Index,
         inferred_error_set: bool,
         var_args: bool,
         is_extern: bool,
@@ -4488,7 +4501,13 @@ const Writer = struct {
         src: LazySrcLoc,
         src_locs: Zir.Inst.Func.SrcLocs,
     ) !void {
-        try self.writeInstRef(stream, ret_ty);
+        try stream.writeAll("ret_ty={\n");
+        self.indent += 2;
+        try self.writeBody(stream, ret_ty_body);
+        self.indent -= 2;
+        try stream.writeByteNTimes(' ', self.indent);
+        try stream.writeAll("}");
+
         try self.writeOptionalInstRef(stream, ", cc=", cc);
         try self.writeOptionalInstRef(stream, ", align=", align_inst);
         try self.writeFlag(stream, ", vargs", var_args);
@@ -4496,9 +4515,9 @@ const Writer = struct {
         try self.writeFlag(stream, ", inferror", inferred_error_set);
 
         if (body.len == 0) {
-            try stream.writeAll(", {}) ");
+            try stream.writeAll(", body={}) ");
         } else {
-            try stream.writeAll(", {\n");
+            try stream.writeAll(", body={\n");
             self.indent += 2;
             try self.writeBody(stream, body);
             self.indent -= 2;
@@ -4932,6 +4951,7 @@ fn findDeclsBody(
 
 pub const FnInfo = struct {
     param_body: []const Inst.Index,
+    ret_ty_body: []const Inst.Index,
     body: []const Inst.Index,
     total_params_len: u32,
 };
@@ -4942,13 +4962,22 @@ pub fn getFnInfo(zir: Zir, fn_inst: Inst.Index) FnInfo {
     const info: struct {
         param_block: Inst.Index,
         body: []const Inst.Index,
+        ret_ty_body: []const Inst.Index,
     } = switch (tags[fn_inst]) {
         .func, .func_inferred => blk: {
             const inst_data = datas[fn_inst].pl_node;
             const extra = zir.extraData(Inst.Func, inst_data.payload_index);
-            const body = zir.extra[extra.end..][0..extra.data.body_len];
+            var extra_index: usize = extra.end;
+
+            const ret_ty_body = zir.extra[extra_index..][0..extra.data.ret_body_len];
+            extra_index += ret_ty_body.len;
+
+            const body = zir.extra[extra_index..][0..extra.data.body_len];
+            extra_index += body.len;
+
             break :blk .{
                 .param_block = extra.data.param_block,
+                .ret_ty_body = ret_ty_body,
                 .body = body,
             };
         },
@@ -4961,9 +4990,13 @@ pub fn getFnInfo(zir: Zir, fn_inst: Inst.Index) FnInfo {
             extra_index += @boolToInt(small.has_lib_name);
             extra_index += @boolToInt(small.has_cc);
             extra_index += @boolToInt(small.has_align);
+            const ret_ty_body = zir.extra[extra_index..][0..extra.data.ret_body_len];
+            extra_index += ret_ty_body.len;
             const body = zir.extra[extra_index..][0..extra.data.body_len];
+            extra_index += body.len;
             break :blk .{
                 .param_block = extra.data.param_block,
+                .ret_ty_body = ret_ty_body,
                 .body = body,
             };
         },
@@ -4983,6 +5016,7 @@ pub fn getFnInfo(zir: Zir, fn_inst: Inst.Index) FnInfo {
     }
     return .{
         .param_body = param_body,
+        .ret_ty_body = info.ret_ty_body,
         .body = info.body,
         .total_params_len = total_params_len,
     };
test/behavior/generics.zig
@@ -1,4 +1,5 @@
 const std = @import("std");
+const builtin = @import("builtin");
 const testing = std.testing;
 const expect = testing.expect;
 const expectEqual = testing.expectEqual;
@@ -14,3 +15,58 @@ test "one param, explicit comptime" {
 fn checkSize(comptime T: type) usize {
     return @sizeOf(T);
 }
+
+test "simple generic fn" {
+    try expect(max(i32, 3, -1) == 3);
+    try expect(max(u8, 1, 100) == 100);
+    if (!builtin.zig_is_stage2) {
+        // TODO: stage2 is incorrectly emitting the following:
+        // error: cast of value 1.23e-01 to type 'f32' loses information
+        try expect(max(f32, 0.123, 0.456) == 0.456);
+    }
+    try expect(add(2, 3) == 5);
+}
+
+fn max(comptime T: type, a: T, b: T) T {
+    if (!builtin.zig_is_stage2) {
+        // TODO: stage2 is incorrectly emitting AIR that allocates a result
+        // value, stores to it, but then returns void instead of the result.
+        return if (a > b) a else b;
+    }
+    if (a > b) {
+        return a;
+    } else {
+        return b;
+    }
+}
+
+fn add(comptime a: i32, b: i32) i32 {
+    return (comptime a) + b;
+}
+
+const the_max = max(u32, 1234, 5678);
+test "compile time generic eval" {
+    try expect(the_max == 5678);
+}
+
+fn gimmeTheBigOne(a: u32, b: u32) u32 {
+    return max(u32, a, b);
+}
+
+fn shouldCallSameInstance(a: u32, b: u32) u32 {
+    return max(u32, a, b);
+}
+
+fn sameButWithFloats(a: f64, b: f64) f64 {
+    return max(f64, a, b);
+}
+
+test "fn with comptime args" {
+    try expect(gimmeTheBigOne(1234, 5678) == 5678);
+    try expect(shouldCallSameInstance(34, 12) == 34);
+    if (!builtin.zig_is_stage2) {
+        // TODO: stage2 LLVM backend needs to use fcmp instead of icmp.
+        // Probably AIR should just have different instructions for floats.
+        try expect(sameButWithFloats(0.43, 0.49) == 0.49);
+    }
+}
test/behavior/generics_stage1.zig
@@ -3,44 +3,7 @@ const testing = std.testing;
 const expect = testing.expect;
 const expectEqual = testing.expectEqual;
 
-test "simple generic fn" {
-    try expect(max(i32, 3, -1) == 3);
-    try expect(max(f32, 0.123, 0.456) == 0.456);
-    try expect(add(2, 3) == 5);
-}
-
-fn max(comptime T: type, a: T, b: T) T {
-    return if (a > b) a else b;
-}
-
-fn add(comptime a: i32, b: i32) i32 {
-    return (comptime a) + b;
-}
-
-const the_max = max(u32, 1234, 5678);
-test "compile time generic eval" {
-    try expect(the_max == 5678);
-}
-
-fn gimmeTheBigOne(a: u32, b: u32) u32 {
-    return max(u32, a, b);
-}
-
-fn shouldCallSameInstance(a: u32, b: u32) u32 {
-    return max(u32, a, b);
-}
-
-fn sameButWithFloats(a: f64, b: f64) f64 {
-    return max(f64, a, b);
-}
-
-test "fn with comptime args" {
-    try expect(gimmeTheBigOne(1234, 5678) == 5678);
-    try expect(shouldCallSameInstance(34, 12) == 34);
-    try expect(sameButWithFloats(0.43, 0.49) == 0.49);
-}
-
-test "var params" {
+test "anytype params" {
     try expect(max_i32(12, 34) == 34);
     try expect(max_f64(1.2, 3.4) == 3.4);
 }