Commit bd89a73d52

Andrew Kelley <andrew@ziglang.org>
2022-05-30 23:16:28
Sema: implement functions generic across callconv() or align()
1 parent 7e98b04
Changed files (3)
src
test
behavior
src/Sema.zig
@@ -6547,7 +6547,7 @@ fn zirFunc(
         inst,
         0,
         target_util.defaultAddressSpace(target, .function),
-        null,
+        FuncLinkSection.default,
         cc,
         ret_ty,
         false,
@@ -6660,15 +6660,26 @@ fn handleExternLibName(
     return sema.gpa.dupeZ(u8, lib_name);
 }
 
+const FuncLinkSection = union(enum) {
+    generic,
+    default,
+    explicit: [*:0]const u8,
+};
+
 fn funcCommon(
     sema: *Sema,
     block: *Block,
     src_node_offset: i32,
     func_inst: Zir.Inst.Index,
-    alignment: u32,
-    address_space: std.builtin.AddressSpace,
-    section: ?[*:0]const u8,
-    cc: std.builtin.CallingConvention,
+    /// null means generic poison
+    alignment: ?u32,
+    /// null means generic poison
+    address_space: ?std.builtin.AddressSpace,
+    /// outer null means generic poison; inner null means default link section
+    section: FuncLinkSection,
+    /// null means generic poison
+    cc: ?std.builtin.CallingConvention,
+    /// this might be Type.generic_poison
     bare_return_type: Type,
     var_args: bool,
     inferred_error_set: bool,
@@ -6679,7 +6690,11 @@ fn funcCommon(
 ) CompileError!Air.Inst.Ref {
     const ret_ty_src: LazySrcLoc = .{ .node_offset_fn_type_ret_ty = src_node_offset };
 
-    var is_generic = bare_return_type.tag() == .generic_poison;
+    var is_generic = bare_return_type.tag() == .generic_poison or
+        alignment == null or
+        address_space == null or
+        section == .generic or
+        cc == null;
     // Check for generic params.
     for (block.params.items) |param| {
         if (param.ty.tag() == .generic_poison) is_generic = true;
@@ -6700,25 +6715,28 @@ fn funcCommon(
     errdefer if (maybe_inferred_error_set_node) |node| sema.gpa.destroy(node);
     // Note: no need to errdefer since this will still be in its default state at the end of the function.
 
+    const target = sema.mod.getTarget();
     const fn_ty: Type = fn_ty: {
         // Hot path for some common function types.
         // TODO can we eliminate some of these Type tag values? seems unnecessarily complicated.
-        if (!is_generic and block.params.items.len == 0 and !var_args and
-            alignment == 0 and !inferred_error_set)
+        if (!is_generic and block.params.items.len == 0 and !var_args and !inferred_error_set and
+            alignment.? == 0 and
+            address_space.? == target_util.defaultAddressSpace(target, .function) and
+            section == .default)
         {
-            if (bare_return_type.zigTypeTag() == .NoReturn and cc == .Unspecified) {
+            if (bare_return_type.zigTypeTag() == .NoReturn and cc.? == .Unspecified) {
                 break :fn_ty Type.initTag(.fn_noreturn_no_args);
             }
 
-            if (bare_return_type.zigTypeTag() == .Void and cc == .Unspecified) {
+            if (bare_return_type.zigTypeTag() == .Void and cc.? == .Unspecified) {
                 break :fn_ty Type.initTag(.fn_void_no_args);
             }
 
-            if (bare_return_type.zigTypeTag() == .NoReturn and cc == .Naked) {
+            if (bare_return_type.zigTypeTag() == .NoReturn and cc.? == .Naked) {
                 break :fn_ty Type.initTag(.fn_naked_noreturn_no_args);
             }
 
-            if (bare_return_type.zigTypeTag() == .Void and cc == .C) {
+            if (bare_return_type.zigTypeTag() == .Void and cc.? == .C) {
                 break :fn_ty Type.initTag(.fn_ccc_void_no_args);
             }
         }
@@ -6764,21 +6782,33 @@ fn funcCommon(
             });
         };
 
+        // stage1 bug workaround
+        const cc_workaround = cc orelse undefined;
+        const align_workaround = alignment orelse @as(u32, undefined);
+
         break :fn_ty try Type.Tag.function.create(sema.arena, .{
             .param_types = param_types,
             .comptime_params = comptime_params.ptr,
             .return_type = return_type,
-            .cc = cc,
-            .alignment = alignment,
+            .cc = cc_workaround,
+            .cc_is_generic = cc == null,
+            .alignment = align_workaround,
+            .align_is_generic = alignment == null,
+            .section_is_generic = section == .generic,
+            .addrspace_is_generic = address_space == null,
             .is_var_args = var_args,
             .is_generic = is_generic,
         });
     };
 
     if (sema.owner_decl.owns_tv) {
-        sema.owner_decl.@"linksection" = section;
-        sema.owner_decl.@"align" = alignment;
-        sema.owner_decl.@"addrspace" = address_space;
+        switch (section) {
+            .generic => sema.owner_decl.@"linksection" = undefined,
+            .default => sema.owner_decl.@"linksection" = null,
+            .explicit => |s| sema.owner_decl.@"linksection" = s,
+        }
+        if (alignment) |a| sema.owner_decl.@"align" = a;
+        if (address_space) |a| sema.owner_decl.@"addrspace" = a;
     }
 
     if (is_extern) {
@@ -16780,13 +16810,16 @@ fn zirFuncFancy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
         break :blk lib_name;
     } else null;
 
-    const @"align": u32 = if (extra.data.bits.has_align_body) blk: {
+    const @"align": ?u32 = if (extra.data.bits.has_align_body) blk: {
         const body_len = sema.code.extra[extra_index];
         extra_index += 1;
         const body = sema.code.extra[extra_index..][0..body_len];
         extra_index += body.len;
 
         const val = try sema.resolveGenericBody(block, align_src, body, inst, Type.u16);
+        if (val.tag() == .generic_poison) {
+            break :blk null;
+        }
         const alignment = @intCast(u32, val.toUnsignedInt(target));
         if (alignment == target_util.defaultFunctionAlignment(target)) {
             break :blk 0;
@@ -16796,7 +16829,12 @@ fn zirFuncFancy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
     } else if (extra.data.bits.has_align_ref) blk: {
         const align_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
         extra_index += 1;
-        const align_tv = try sema.resolveInstConst(block, align_src, align_ref);
+        const align_tv = sema.resolveInstConst(block, align_src, align_ref) catch |err| switch (err) {
+            error.GenericPoison => {
+                break :blk null;
+            },
+            else => |e| return e,
+        };
         const alignment = @intCast(u32, align_tv.val.toUnsignedInt(target));
         if (alignment == target_util.defaultFunctionAlignment(target)) {
             break :blk 0;
@@ -16805,7 +16843,7 @@ fn zirFuncFancy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
         }
     } else 0;
 
-    const @"addrspace": std.builtin.AddressSpace = if (extra.data.bits.has_addrspace_body) blk: {
+    const @"addrspace": ?std.builtin.AddressSpace = if (extra.data.bits.has_addrspace_body) blk: {
         const body_len = sema.code.extra[extra_index];
         extra_index += 1;
         const body = sema.code.extra[extra_index..][0..body_len];
@@ -16813,32 +16851,48 @@ fn zirFuncFancy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
 
         const addrspace_ty = try sema.getBuiltinType(block, addrspace_src, "AddressSpace");
         const val = try sema.resolveGenericBody(block, addrspace_src, body, inst, addrspace_ty);
+        if (val.tag() == .generic_poison) {
+            break :blk null;
+        }
         break :blk val.toEnum(std.builtin.AddressSpace);
     } else if (extra.data.bits.has_addrspace_ref) blk: {
         const addrspace_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
         extra_index += 1;
-        const addrspace_tv = try sema.resolveInstConst(block, addrspace_src, addrspace_ref);
+        const addrspace_tv = sema.resolveInstConst(block, addrspace_src, addrspace_ref) catch |err| switch (err) {
+            error.GenericPoison => {
+                break :blk null;
+            },
+            else => |e| return e,
+        };
         break :blk addrspace_tv.val.toEnum(std.builtin.AddressSpace);
     } else target_util.defaultAddressSpace(target, .function);
 
-    const @"linksection": ?[*:0]const u8 = if (extra.data.bits.has_section_body) {
+    const @"linksection": FuncLinkSection = if (extra.data.bits.has_section_body) blk: {
         const body_len = sema.code.extra[extra_index];
         extra_index += 1;
         const body = sema.code.extra[extra_index..][0..body_len];
         extra_index += body.len;
 
         const val = try sema.resolveGenericBody(block, section_src, body, inst, Type.initTag(.const_slice_u8));
+        if (val.tag() == .generic_poison) {
+            break :blk FuncLinkSection{ .generic = {} };
+        }
         _ = val;
         return sema.fail(block, section_src, "TODO implement linksection on functions", .{});
-    } else if (extra.data.bits.has_section_ref) {
+    } else if (extra.data.bits.has_section_ref) blk: {
         const section_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
         extra_index += 1;
-        const section_tv = try sema.resolveInstConst(block, section_src, section_ref);
+        const section_tv = sema.resolveInstConst(block, section_src, section_ref) catch |err| switch (err) {
+            error.GenericPoison => {
+                break :blk FuncLinkSection{ .generic = {} };
+            },
+            else => |e| return e,
+        };
         _ = section_tv;
         return sema.fail(block, section_src, "TODO implement linksection on functions", .{});
-    } else null;
+    } else FuncLinkSection{ .default = {} };
 
-    const cc: std.builtin.CallingConvention = if (extra.data.bits.has_cc_body) blk: {
+    const cc: ?std.builtin.CallingConvention = if (extra.data.bits.has_cc_body) blk: {
         const body_len = sema.code.extra[extra_index];
         extra_index += 1;
         const body = sema.code.extra[extra_index..][0..body_len];
@@ -16846,13 +16900,21 @@ fn zirFuncFancy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
 
         const cc_ty = try sema.getBuiltinType(block, addrspace_src, "CallingConvention");
         const val = try sema.resolveGenericBody(block, cc_src, body, inst, cc_ty);
+        if (val.tag() == .generic_poison) {
+            break :blk null;
+        }
         break :blk val.toEnum(std.builtin.CallingConvention);
     } else if (extra.data.bits.has_cc_ref) blk: {
         const cc_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
         extra_index += 1;
-        const cc_tv = try sema.resolveInstConst(block, cc_src, cc_ref);
+        const cc_tv = sema.resolveInstConst(block, cc_src, cc_ref) catch |err| switch (err) {
+            error.GenericPoison => {
+                break :blk null;
+            },
+            else => |e| return e,
+        };
         break :blk cc_tv.val.toEnum(std.builtin.CallingConvention);
-    } else .Unspecified;
+    } else std.builtin.CallingConvention.Unspecified;
 
     const ret_ty: Type = if (extra.data.bits.has_ret_ty_body) blk: {
         const body_len = sema.code.extra[extra_index];
src/type.zig
@@ -6120,6 +6120,10 @@ pub const Type = extern union {
                 cc: std.builtin.CallingConvention,
                 is_var_args: bool,
                 is_generic: bool,
+                align_is_generic: bool = false,
+                cc_is_generic: bool = false,
+                section_is_generic: bool = false,
+                addrspace_is_generic: bool = false,
 
                 pub fn paramIsComptime(self: @This(), i: usize) bool {
                     assert(i < self.param_types.len);
test/behavior/align.zig
@@ -334,25 +334,44 @@ fn simple4() align(4) i32 {
     return 0x19;
 }
 
-test "generic function with align param" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+test "function align expression depends on generic parameter" {
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
 
     // function alignment is a compile error on wasm32/wasm64
     if (native_arch == .wasm32 or native_arch == .wasm64) return error.SkipZigTest;
     if (native_arch == .thumb) return error.SkipZigTest;
 
-    try expect(whyWouldYouEverDoThis(1) == 0x1);
-    try expect(whyWouldYouEverDoThis(4) == 0x1);
-    try expect(whyWouldYouEverDoThis(8) == 0x1);
+    const S = struct {
+        fn doTheTest() !void {
+            try expect(foobar(1) == 2);
+            try expect(foobar(4) == 5);
+            try expect(foobar(8) == 9);
+        }
+
+        fn foobar(comptime align_bytes: u8) align(align_bytes) u8 {
+            return align_bytes + 1;
+        }
+    };
+    try S.doTheTest();
+    comptime try S.doTheTest();
 }
 
-fn whyWouldYouEverDoThis(comptime align_bytes: u8) align(align_bytes) u8 {
-    _ = align_bytes;
-    return 0x1;
+test "function callconv expression depends on generic parameter" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+
+    const S = struct {
+        fn doTheTest() !void {
+            try expect(foobar(.C, 1) == 2);
+            try expect(foobar(.Unspecified, 2) == 3);
+        }
+
+        fn foobar(comptime cc: std.builtin.CallingConvention, arg: u8) callconv(cc) u8 {
+            return arg + 1;
+        }
+    };
+    try S.doTheTest();
+    comptime try S.doTheTest();
 }
 
 test "runtime known array index has best alignment possible" {