Commit de62dc884e

Parker Liu <flyfish30@users.noreply.github.com>
2025-04-02 22:07:41
translate-c: fix function prototype decalared inside a function
* If a function prototype is declarated inside a function, do not translate it to a top-level extern function declaration. Similar to extern local variable, just wrapped it into a block-local struct. * Add a new extern_local_fn tag of aro_translate_c node for present extern local function declaration. * When a function body has a C function prototype declaration, it adds an extern local function declaration. Subsequent function references will look for this function declaration.
1 parent 9dfdf35
Changed files (4)
lib
compiler
src
test
lib/compiler/aro_translate_c/ast.zig
@@ -57,6 +57,8 @@ pub const Node = extern union {
         static_local_var,
         /// const ExternLocal_name = struct { init }
         extern_local_var,
+        /// const ExternLocal_name = struct { init }
+        extern_local_fn,
         /// var name = init.*
         mut_str,
         func,
@@ -367,7 +369,13 @@ pub const Node = extern union {
                 .c_pointer, .single_pointer => Payload.Pointer,
                 .array_type, .null_sentinel_array_type => Payload.Array,
                 .arg_redecl, .alias, .fail_decl => Payload.ArgRedecl,
-                .var_simple, .pub_var_simple, .static_local_var, .extern_local_var, .mut_str => Payload.SimpleVarDecl,
+                .var_simple,
+                .pub_var_simple,
+                .static_local_var,
+                .extern_local_var,
+                .extern_local_fn,
+                .mut_str,
+                => Payload.SimpleVarDecl,
                 .enum_constant => Payload.EnumConstant,
                 .array_filler => Payload.ArrayFiller,
                 .pub_inline_fn => Payload.PubInlineFn,
@@ -1265,8 +1273,11 @@ fn renderNode(c: *Context, node: Node) Allocator.Error!NodeIndex {
                 } },
             });
         },
-        .extern_local_var => {
-            const payload = node.castTag(.extern_local_var).?.data;
+        .extern_local_var, .extern_local_fn => {
+            const payload = if (node.tag() == .extern_local_var)
+                node.castTag(.extern_local_var).?.data
+            else
+                node.castTag(.extern_local_fn).?.data;
 
             const const_tok = try c.addToken(.keyword_const, "const");
             _ = try c.addIdentifier(payload.name);
@@ -2293,7 +2304,7 @@ fn renderNullSentinelArrayType(c: *Context, len: usize, elem_type: Node) !NodeIn
 fn addSemicolonIfNeeded(c: *Context, node: Node) !void {
     switch (node.tag()) {
         .warning => unreachable,
-        .var_decl, .var_simple, .arg_redecl, .alias, .block, .empty_block, .block_single, .@"switch", .static_local_var, .extern_local_var, .mut_str => {},
+        .var_decl, .var_simple, .arg_redecl, .alias, .block, .empty_block, .block_single, .@"switch", .static_local_var, .extern_local_var, .extern_local_fn, .mut_str => {},
         .while_true => {
             const payload = node.castTag(.while_true).?.data;
             return addSemicolonIfNotBlock(c, payload);
@@ -2390,6 +2401,7 @@ fn renderNodeGrouped(c: *Context, node: Node) !NodeIndex {
         .builtin_extern,
         .static_local_var,
         .extern_local_var,
+        .extern_local_fn,
         .mut_str,
         .macro_arithmetic,
         => {
lib/compiler/aro_translate_c.zig
@@ -1502,19 +1502,29 @@ pub fn ScopeExtra(comptime ScopeExtraContext: type, comptime ScopeExtraType: typ
                 return scope.base.parent.?.getAlias(name);
             }
 
-            /// Finds the (potentially) mangled struct name for a locally scoped extern variable given the original declaration name.
+            /// Finds the (potentially) mangled struct name for a locally scoped extern variable or function given the original declaration name.
             ///
             /// Block scoped extern declarations translate to:
             ///     const MangledStructName = struct {extern [qualifiers] original_extern_variable_name: [type]};
             /// This finds MangledStructName given original_extern_variable_name for referencing correctly in transDeclRefExpr()
             pub fn getLocalExternAlias(scope: *Block, name: []const u8) ?[]const u8 {
                 for (scope.statements.items) |node| {
-                    if (node.tag() == .extern_local_var) {
-                        const parent_node = node.castTag(.extern_local_var).?;
-                        const init_node = parent_node.data.init.castTag(.var_decl).?;
-                        if (std.mem.eql(u8, init_node.data.name, name)) {
-                            return parent_node.data.name;
-                        }
+                    switch (node.tag()) {
+                        .extern_local_var => {
+                            const parent_node = node.castTag(.extern_local_var).?;
+                            const init_node = parent_node.data.init.castTag(.var_decl).?;
+                            if (std.mem.eql(u8, init_node.data.name, name)) {
+                                return parent_node.data.name;
+                            }
+                        },
+                        .extern_local_fn => {
+                            const parent_node = node.castTag(.extern_local_fn).?;
+                            const init_node = parent_node.data.init.castTag(.func).?;
+                            if (std.mem.eql(u8, init_node.data.name.?, name)) {
+                                return parent_node.data.name;
+                            }
+                        },
+                        else => {},
                     }
                 }
                 return null;
src/translate_c.zig
@@ -325,7 +325,7 @@ fn declVisitorNamesOnly(c: *Context, decl: *const clang.Decl) Error!void {
 fn declVisitor(c: *Context, decl: *const clang.Decl) Error!void {
     switch (decl.getKind()) {
         .Function => {
-            return visitFnDecl(c, @as(*const clang.FunctionDecl, @ptrCast(decl)));
+            return transFnDecl(c, &c.global_scope.base, @as(*const clang.FunctionDecl, @ptrCast(decl)));
         },
         .Typedef => {
             try transTypeDef(c, &c.global_scope.base, @as(*const clang.TypedefNameDecl, @ptrCast(decl)));
@@ -367,7 +367,7 @@ fn transFileScopeAsm(c: *Context, scope: *Scope, file_scope_asm: *const clang.Fi
     try scope.appendNode(comptime_node);
 }
 
-fn visitFnDecl(c: *Context, fn_decl: *const clang.FunctionDecl) Error!void {
+fn transFnDecl(c: *Context, scope: *Scope, fn_decl: *const clang.FunctionDecl) Error!void {
     const fn_name = try c.str(@as(*const clang.NamedDecl, @ptrCast(fn_decl)).getName_bytes_begin());
     if (c.global_scope.sym_table.contains(fn_name))
         return; // Avoid processing this decl twice
@@ -375,7 +375,7 @@ fn visitFnDecl(c: *Context, fn_decl: *const clang.FunctionDecl) Error!void {
     // Skip this declaration if a proper definition exists
     if (!fn_decl.isThisDeclarationADefinition()) {
         if (fn_decl.getDefinition()) |def|
-            return visitFnDecl(c, def);
+            return transFnDecl(c, scope, def);
     }
 
     const fn_decl_loc = fn_decl.getLocation();
@@ -446,6 +446,9 @@ fn visitFnDecl(c: *Context, fn_decl: *const clang.FunctionDecl) Error!void {
     };
 
     if (!decl_ctx.has_body) {
+        if (scope.id != .root) {
+            return addLocalExternFnDecl(c, scope, fn_name, Node.initPayload(&proto_node.base));
+        }
         return addTopLevelDecl(c, fn_name, Node.initPayload(&proto_node.base));
     }
 
@@ -455,7 +458,7 @@ fn visitFnDecl(c: *Context, fn_decl: *const clang.FunctionDecl) Error!void {
     block_scope.return_type = return_qt;
     defer block_scope.deinit();
 
-    const scope = &block_scope.base;
+    const top_scope = &block_scope.base;
 
     var param_id: c_uint = 0;
     for (proto_node.data.params) |*param| {
@@ -508,7 +511,7 @@ fn visitFnDecl(c: *Context, fn_decl: *const clang.FunctionDecl) Error!void {
             break :blk;
         }
 
-        const rhs = transZeroInitExpr(c, scope, fn_decl_loc, return_qt.getTypePtr()) catch |err| switch (err) {
+        const rhs = transZeroInitExpr(c, top_scope, fn_decl_loc, return_qt.getTypePtr()) catch |err| switch (err) {
             error.OutOfMemory => |e| return e,
             error.UnsupportedTranslation,
             error.UnsupportedType,
@@ -1874,7 +1877,7 @@ fn transDeclStmtOne(
             try transEnumDecl(c, scope, @as(*const clang.EnumDecl, @ptrCast(decl)));
         },
         .Function => {
-            try visitFnDecl(c, @as(*const clang.FunctionDecl, @ptrCast(decl)));
+            try transFnDecl(c, scope, @as(*const clang.FunctionDecl, @ptrCast(decl)));
         },
         else => {
             const decl_name = try c.str(decl.getDeclKindName());
@@ -1903,11 +1906,19 @@ fn transDeclRefExpr(
     const name = try c.str(@as(*const clang.NamedDecl, @ptrCast(value_decl)).getName_bytes_begin());
     const mangled_name = scope.getAlias(name);
     const decl_is_var = @as(*const clang.Decl, @ptrCast(value_decl)).getKind() == .Var;
-    const potential_local_extern = if (decl_is_var) ((@as(*const clang.VarDecl, @ptrCast(value_decl)).getStorageClass() == .Extern) and (scope.id != .root)) else false;
+    const storage_class = @as(*const clang.VarDecl, @ptrCast(value_decl)).getStorageClass();
+    const potential_local_extern = if (decl_is_var) ((storage_class == .Extern) and (scope.id != .root)) else false;
 
     var confirmed_local_extern = false;
+    var confirmed_local_extern_fn = false;
     var ref_expr = val: {
         if (cIsFunctionDeclRef(@as(*const clang.Expr, @ptrCast(expr)))) {
+            if (scope.id != .root) {
+                if (scope.getLocalExternAlias(name)) |v| {
+                    confirmed_local_extern_fn = true;
+                    break :val try Tag.identifier.create(c.arena, v);
+                }
+            }
             break :val try Tag.fn_identifier.create(c.arena, mangled_name);
         } else if (potential_local_extern) {
             if (scope.getLocalExternAlias(name)) |v| {
@@ -1934,6 +1945,11 @@ fn transDeclRefExpr(
                 .field_name = name, // by necessity, name will always == mangled_name
             });
         }
+    } else if (confirmed_local_extern_fn) {
+        ref_expr = try Tag.field_access.create(c.arena, .{
+            .lhs = ref_expr,
+            .field_name = name, // by necessity, name will always == mangled_name
+        });
     }
     scope.skipVariableDiscard(mangled_name);
     return ref_expr;
@@ -4213,6 +4229,23 @@ fn addTopLevelDecl(c: *Context, name: []const u8, decl_node: Node) !void {
     }
 }
 
+/// Add an "extern" function prototype declaration that's been declared within a scoped block.
+/// Similar to static local variables, this will be wrapped in a struct to work with Zig's syntax requirements.
+///
+fn addLocalExternFnDecl(c: *Context, scope: *Scope, name: []const u8, decl_node: Node) !void {
+    const bs: *Scope.Block = try scope.findBlockScope(c);
+
+    // Special naming convention for local extern function wrapper struct,
+    // this named "ExternLocal_[name]".
+    const struct_name = try std.fmt.allocPrint(c.arena, "{s}_{s}", .{ Scope.Block.extern_inner_prepend, name });
+
+    // Outer Node for the wrapper struct
+    const node = try Tag.extern_local_fn.create(c.arena, .{ .name = struct_name, .init = decl_node });
+
+    try bs.statements.append(node);
+    try bs.discardVariable(c, struct_name);
+}
+
 fn transQualTypeInitializedStringLiteral(c: *Context, elem_ty: Node, string_lit: *const clang.StringLiteral) TypeError!Node {
     const string_lit_size = string_lit.getLength();
     const array_size = @as(usize, @intCast(string_lit_size));
test/translate_c.zig
@@ -3537,9 +3537,12 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\    return bar(1, 2);
         \\}
     , &[_][]const u8{
-        \\pub extern fn bar(c_int, c_int) c_int;
         \\pub export fn foo() c_int {
-        \\    return bar(@as(c_int, 1), @as(c_int, 2));
+        \\    const ExternLocal_bar = struct {
+        \\        pub extern fn bar(c_int, c_int) c_int;
+        \\    };
+        \\    _ = &ExternLocal_bar;
+        \\    return ExternLocal_bar.bar(@as(c_int, 1), @as(c_int, 2));
         \\}
     });