Commit 221f1d898c

Evan Haas <evan@lagerdata.com>
2021-02-06 00:37:18
translate-c: Improve function pointer handling
Omit address-of operator if operand is a function. Improve handling of function-call translation when using function pointers Fixes #4124
1 parent 1adac0a
src/clang.zig
@@ -848,7 +848,10 @@ pub const UnaryOperator = opaque {
     extern fn ZigClangUnaryOperator_getBeginLoc(*const UnaryOperator) SourceLocation;
 };
 
-pub const ValueDecl = opaque {};
+pub const ValueDecl = opaque {
+    pub const getType = ZigClangValueDecl_getType;
+    extern fn ZigClangValueDecl_getType(*const ValueDecl) QualType;
+};
 
 pub const VarDecl = opaque {
     pub const getLocation = ZigClangVarDecl_getLocation;
src/translate_c.zig
@@ -3208,6 +3208,38 @@ fn transArrayAccess(rp: RestorePoint, scope: *Scope, stmt: *const clang.ArraySub
     return maybeSuppressResult(rp, scope, result_used, &node.base);
 }
 
+/// Check if an expression is ultimately a reference to a function declaration
+/// (which means it should not be unwrapped with `.?` in translated code)
+fn cIsFunctionDeclRef(expr: *const clang.Expr) bool {
+    switch (expr.getStmtClass()) {
+        .ParenExprClass => {
+            const op_expr = @ptrCast(*const clang.ParenExpr, expr).getSubExpr();
+            return cIsFunctionDeclRef(op_expr);
+        },
+        .DeclRefExprClass => {
+            const decl_ref = @ptrCast(*const clang.DeclRefExpr, expr);
+            const value_decl = decl_ref.getDecl();
+            const qt = value_decl.getType();
+            return qualTypeChildIsFnProto(qt);
+        },
+        .ImplicitCastExprClass => {
+            const implicit_cast = @ptrCast(*const clang.ImplicitCastExpr, expr);
+            const cast_kind = implicit_cast.getCastKind();
+            if (cast_kind == .BuiltinFnToFnPtr) return true;
+            if (cast_kind == .FunctionToPointerDecay) {
+                return cIsFunctionDeclRef(implicit_cast.getSubExpr());
+            }
+            return false;
+        },
+        .UnaryOperatorClass => {
+            const un_op = @ptrCast(*const clang.UnaryOperator, expr);
+            const opcode = un_op.getOpcode();
+            return (opcode == .AddrOf or opcode == .Deref) and cIsFunctionDeclRef(un_op.getSubExpr());
+        },
+        else => return false,
+    }
+}
+
 fn transCallExpr(rp: RestorePoint, scope: *Scope, stmt: *const clang.CallExpr, result_used: ResultUsed) TransError!*ast.Node {
     const callee = stmt.getCallee();
     var raw_fn_expr = try transExpr(rp, scope, callee, .used, .r_value);
@@ -3215,24 +3247,9 @@ fn transCallExpr(rp: RestorePoint, scope: *Scope, stmt: *const clang.CallExpr, r
     var is_ptr = false;
     const fn_ty = qualTypeGetFnProto(callee.getType(), &is_ptr);
 
-    const fn_expr = if (is_ptr and fn_ty != null) blk: {
-        if (callee.getStmtClass() == .ImplicitCastExprClass) {
-            const implicit_cast = @ptrCast(*const clang.ImplicitCastExpr, callee);
-            const cast_kind = implicit_cast.getCastKind();
-            if (cast_kind == .BuiltinFnToFnPtr) break :blk raw_fn_expr;
-            if (cast_kind == .FunctionToPointerDecay) {
-                const subexpr = implicit_cast.getSubExpr();
-                if (subexpr.getStmtClass() == .DeclRefExprClass) {
-                    const decl_ref = @ptrCast(*const clang.DeclRefExpr, subexpr);
-                    const named_decl = decl_ref.getFoundDecl();
-                    if (@ptrCast(*const clang.Decl, named_decl).getKind() == .Function) {
-                        break :blk raw_fn_expr;
-                    }
-                }
-            }
-        }
-        break :blk try transCreateNodeUnwrapNull(rp.c, raw_fn_expr);
-    } else
+    const fn_expr = if (is_ptr and fn_ty != null and !cIsFunctionDeclRef(callee))
+        try transCreateNodeUnwrapNull(rp.c, raw_fn_expr)
+    else
         raw_fn_expr;
 
     const num_args = stmt.getNumArgs();
@@ -3379,6 +3396,9 @@ fn transUnaryOperator(rp: RestorePoint, scope: *Scope, stmt: *const clang.UnaryO
         else
             return transCreatePreCrement(rp, scope, stmt, .AssignSub, .MinusEqual, "-=", used),
         .AddrOf => {
+            if (cIsFunctionDeclRef(op_expr)) {
+                return transExpr(rp, scope, op_expr, used, .r_value);
+            }
             const op_node = try transCreateNodeSimplePrefixOp(rp.c, .AddressOf, .Ampersand, "&");
             op_node.rhs = try transExpr(rp, scope, op_expr, used, .r_value);
             return &op_node.base;
src/zig_clang.cpp
@@ -2773,6 +2773,11 @@ struct ZigClangSourceLocation ZigClangUnaryOperator_getBeginLoc(const struct Zig
     return bitcast(casted->getBeginLoc());
 }
 
+struct ZigClangQualType ZigClangValueDecl_getType(const struct ZigClangValueDecl *self) {
+    auto casted = reinterpret_cast<const clang::ValueDecl *>(self);
+    return bitcast(casted->getType());
+}
+
 const struct ZigClangExpr *ZigClangWhileStmt_getCond(const struct ZigClangWhileStmt *self) {
     auto casted = reinterpret_cast<const clang::WhileStmt *>(self);
     return reinterpret_cast<const struct ZigClangExpr *>(casted->getCond());
src/zig_clang.h
@@ -1200,6 +1200,8 @@ ZIG_EXTERN_C struct ZigClangQualType ZigClangUnaryOperator_getType(const struct
 ZIG_EXTERN_C const struct ZigClangExpr *ZigClangUnaryOperator_getSubExpr(const struct ZigClangUnaryOperator *);
 ZIG_EXTERN_C struct ZigClangSourceLocation ZigClangUnaryOperator_getBeginLoc(const struct ZigClangUnaryOperator *);
 
+ZIG_EXTERN_C struct ZigClangQualType ZigClangValueDecl_getType(const struct ZigClangValueDecl *);
+
 ZIG_EXTERN_C const struct ZigClangExpr *ZigClangWhileStmt_getCond(const struct ZigClangWhileStmt *);
 ZIG_EXTERN_C const struct ZigClangStmt *ZigClangWhileStmt_getBody(const struct ZigClangWhileStmt *);
 
test/run_translated_c.zig
@@ -818,4 +818,60 @@ pub fn addCases(cases: *tests.RunTranslatedCContext) void {
         \\    return 0;
         \\}
     , "");
+
+    cases.add("Address of function is no-op",
+        \\#include <stdlib.h>
+        \\#include <stdbool.h>
+        \\typedef int (*myfunc)(int);
+        \\int a(int arg) { return arg + 1;}
+        \\int b(int arg) { return arg + 2;}
+        \\int caller(myfunc fn, int arg) {
+        \\    return fn(arg);
+        \\}
+        \\int main() {
+        \\    myfunc arr[3] = {&a, &b, a};
+        \\    myfunc foo = a;
+        \\    myfunc bar = &(a);
+        \\    if (foo != bar) abort();
+        \\    if (arr[0] == arr[1]) abort();
+        \\    if (arr[0] != arr[2]) abort();
+        \\    if (caller(b, 40) != 42) abort();
+        \\    if (caller(&b, 40) != 42) abort();
+        \\    return 0;
+        \\}
+    , "");
+
+    cases.add("Obscure ways of calling functions; issue #4124",
+        \\#include <stdlib.h>
+        \\static int add(int a, int b) {
+        \\    return a + b;
+        \\}
+        \\typedef int (*adder)(int, int);
+        \\typedef void (*funcptr)(void);
+        \\int main() {
+        \\    if ((add)(1, 2) != 3) abort();
+        \\    if ((&add)(1, 2) != 3) abort();
+        \\    if (add(3, 1) != 4) abort();
+        \\    if ((*add)(2, 3) != 5) abort();
+        \\    if ((**add)(7, -1) != 6) abort();
+        \\    if ((***add)(-2, 9) != 7) abort();
+        \\
+        \\    int (*ptr)(int a, int b);
+        \\    ptr = add;
+        \\
+        \\    if (ptr(1, 2) != 3) abort();
+        \\    if ((*ptr)(3, 1) != 4) abort();
+        \\    if ((**ptr)(2, 3) != 5) abort();
+        \\    if ((***ptr)(7, -1) != 6) abort();
+        \\    if ((****ptr)(-2, 9) != 7) abort();
+        \\
+        \\    funcptr addr1 = (funcptr)(add);
+        \\    funcptr addr2 = (funcptr)(&add);
+        \\
+        \\    if (addr1 != addr2) abort();
+        \\    if (((int(*)(int, int))addr1)(1, 2) != 3) abort();
+        \\    if (((adder)addr2)(1, 2) != 3) abort();
+        \\    return 0;
+        \\}
+    , "");
 }
test/translate_c.zig
@@ -2802,8 +2802,8 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\    fn_f64(3);
         \\    fn_bool(@as(c_int, 123) != 0);
         \\    fn_bool(@as(c_int, 0) != 0);
-        \\    fn_bool(@ptrToInt(&fn_int) != 0);
-        \\    fn_int(@intCast(c_int, @ptrToInt(&fn_int)));
+        \\    fn_bool(@ptrToInt(fn_int) != 0);
+        \\    fn_int(@intCast(c_int, @ptrToInt(fn_int)));
         \\    fn_ptr(@intToPtr(?*c_void, @as(c_int, 42)));
         \\}
     });