Commit 7ba1f9bfb5

Evan Haas <evan@lagerdata.com>
2022-07-25 18:53:40
translate-c: take address of functions before passing them to @ptrToInt
Fixes #12194
1 parent c8c7986
lib/std/zig/c_translation.zig
@@ -36,6 +36,9 @@ pub fn cast(comptime DestType: type, target: anytype) DestType {
                 .Int => {
                     return castInt(DestType, target);
                 },
+                .Fn => {
+                    return castInt(DestType, @ptrToInt(&target));
+                },
                 else => {},
             }
         },
@@ -45,6 +48,7 @@ pub fn cast(comptime DestType: type, target: anytype) DestType {
             }
             @compileError("cast to union type '" ++ @typeName(DestType) ++ "' from type '" ++ @typeName(SourceType) ++ "' which is not present in union");
         },
+        .Bool => return cast(usize, target) != 0,
         else => {},
     }
     return @as(DestType, target);
src/translate_c/ast.zig
@@ -36,6 +36,7 @@ pub const Node = extern union {
         /// "string"[0..end]
         string_slice,
         identifier,
+        fn_identifier,
         @"if",
         /// if (!operand) break;
         if_not_break,
@@ -335,6 +336,7 @@ pub const Node = extern union {
                 .char_literal,
                 .enum_literal,
                 .identifier,
+                .fn_identifier,
                 .warning,
                 .type,
                 .helpers_macro,
@@ -1058,6 +1060,14 @@ fn renderNode(c: *Context, node: Node) Allocator.Error!NodeIndex {
                 .data = undefined,
             });
         },
+        .fn_identifier => {
+            const payload = node.castTag(.fn_identifier).?.data;
+            return c.addNode(.{
+                .tag = .identifier,
+                .main_token = try c.addIdentifier(payload),
+                .data = undefined,
+            });
+        },
         .float_literal => {
             const payload = node.castTag(.float_literal).?.data;
             return c.addNode(.{
@@ -2234,6 +2244,7 @@ fn renderNodeGrouped(c: *Context, node: Node) !NodeIndex {
         .char_literal,
         .enum_literal,
         .identifier,
+        .fn_identifier,
         .field_access,
         .ptr_cast,
         .type,
src/translate_c.zig
@@ -1950,7 +1950,10 @@ fn transDeclRefExpr(
     const value_decl = expr.getDecl();
     const name = try c.str(@ptrCast(*const clang.NamedDecl, value_decl).getName_bytes_begin());
     const mangled_name = scope.getAlias(name);
-    var ref_expr = try Tag.identifier.create(c.arena, mangled_name);
+    var ref_expr = if (cIsFunctionDeclRef(@ptrCast(*const clang.Expr, expr)))
+        try Tag.fn_identifier.create(c.arena, mangled_name)
+    else
+        try Tag.identifier.create(c.arena, mangled_name);
 
     if (@ptrCast(*const clang.Decl, value_decl).getKind() == .Var) {
         const var_decl = @ptrCast(*const clang.VarDecl, value_decl);
@@ -1999,7 +2002,11 @@ fn transImplicitCastExpr(
         },
         .PointerToBoolean => {
             // @ptrToInt(val) != 0
-            const ptr_to_int = try Tag.ptr_to_int.create(c.arena, try transExpr(c, scope, sub_expr, .used));
+            var ptr_node = try transExpr(c, scope, sub_expr, .used);
+            if (ptr_node.tag() == .fn_identifier) {
+                ptr_node = try Tag.address_of.create(c.arena, ptr_node);
+            }
+            const ptr_to_int = try Tag.ptr_to_int.create(c.arena, ptr_node);
 
             const ne = try Tag.not_equal.create(c.arena, .{ .lhs = ptr_to_int, .rhs = Tag.zero_literal.init() });
             return maybeSuppressResult(c, scope, result_used, ne);
@@ -2042,7 +2049,7 @@ fn isBuiltinDefined(name: []const u8) bool {
 
 fn transBuiltinFnExpr(c: *Context, scope: *Scope, expr: *const clang.Expr, used: ResultUsed) TransError!Node {
     const node = try transExpr(c, scope, expr, used);
-    if (node.castTag(.identifier)) |ident| {
+    if (node.castTag(.fn_identifier)) |ident| {
         const name = ident.data;
         if (!isBuiltinDefined(name)) return fail(c, error.UnsupportedTranslation, expr.getBeginLoc(), "TODO implement function '{s}' in std.zig.c_builtins", .{name});
     }
@@ -2447,7 +2454,10 @@ fn transCCast(
     }
     if (cIsInteger(dst_type) and qualTypeIsPtr(src_type)) {
         // @intCast(dest_type, @ptrToInt(val))
-        const ptr_to_int = try Tag.ptr_to_int.create(c.arena, expr);
+        const ptr_to_int = if (expr.tag() == .fn_identifier)
+            try Tag.ptr_to_int.create(c.arena, try Tag.address_of.create(c.arena, expr))
+        else
+            try Tag.ptr_to_int.create(c.arena, expr);
         return Tag.int_cast.create(c.arena, .{ .lhs = dst_node, .rhs = ptr_to_int });
     }
     if (cIsInteger(src_type) and qualTypeIsPtr(dst_type)) {
test/behavior/translate_c_macros.h
@@ -40,3 +40,11 @@ union U {
 #define CAST_OR_CALL_WITH_PARENS(type_or_fn, val) ((type_or_fn)(val))
 
 #define NESTED_COMMA_OPERATOR (1, (2, 3))
+
+#include <stdint.h>
+#if !defined(__UINTPTR_MAX__)
+typedef _Bool uintptr_t;
+#endif
+
+#define CAST_TO_BOOL(X) (_Bool)(X)
+#define CAST_TO_UINTPTR(X) (uintptr_t)(X)
test/behavior/translate_c_macros.zig
@@ -99,3 +99,17 @@ test "nested comma operator" {
 
     try expectEqual(@as(c_int, 3), h.NESTED_COMMA_OPERATOR);
 }
+
+test "cast functions" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn foo() void {}
+    };
+    try expectEqual(true, h.CAST_TO_BOOL(S.foo));
+    try expect(h.CAST_TO_UINTPTR(S.foo) != 0);
+}
test/run_translated_c.zig
@@ -1861,4 +1861,18 @@ pub fn addCases(cases: *tests.RunTranslatedCContext) void {
         \\    return 0;
         \\}
     , "");
+
+    // The C standard does not require function pointers to be convertible to any integer type.
+    // However, POSIX requires that function pointers have the same representation as `void *`
+    // so that dlsym() can work
+    cases.add("Function to integral",
+        \\#include <stdint.h>
+        \\int main(void) {
+        \\#if defined(__UINTPTR_MAX__) && __has_include(<unistd.h>)
+        \\    uintptr_t x = main;
+        \\    x = (uintptr_t)main;
+        \\#endif
+        \\    return 0;
+        \\}
+    , "");
 }
test/translate_c.zig
@@ -3435,7 +3435,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\    var x = arg_x;
         \\    var a: bool = @as(c_int, @boolToInt(x)) != @as(c_int, 1);
         \\    var b: bool = @as(c_int, @boolToInt(a)) != @as(c_int, 0);
-        \\    var c: bool = @ptrToInt(foo) != 0;
+        \\    var c: bool = @ptrToInt(&foo) != 0;
         \\    return foo(@as(c_int, @boolToInt(c)) != @as(c_int, @boolToInt(b)));
         \\}
     });