Commit 671183fa9a

Andrew Kelley <superjoe30@gmail.com>
2017-11-27 02:05:55
translate-c: support pointer casting
also avoid some unnecessary casts
1 parent 93fac5f
Changed files (2)
src/translate_c.cpp
@@ -440,6 +440,10 @@ static AstNode *trans_create_node_apint(Context *c, const llvm::APSInt &aps_int)
 
 }
 
+static const Type *qual_type_canon(QualType qt) {
+    return qt.getCanonicalType().getTypePtr();
+}
+
 static QualType get_expr_qual_type(Context *c, const Expr *expr) {
     // String literals in C are `char *` but they should really be `const char *`.
     if (expr->getStmtClass() == Stmt::ImplicitCastExprClass) {
@@ -462,10 +466,7 @@ static AstNode *get_expr_type(Context *c, const Expr *expr) {
     return trans_qual_type(c, get_expr_qual_type(c, expr), expr->getLocStart());
 }
 
-static bool expr_types_equal(Context *c, const Expr *expr1, const Expr *expr2) {
-    QualType t1 = get_expr_qual_type(c, expr1);
-    QualType t2 = get_expr_qual_type(c, expr2);
-
+static bool qual_types_equal(QualType t1, QualType t2) {
     if (t1.isConstQualified() != t2.isConstQualified()) {
         return false;
     }
@@ -482,26 +483,27 @@ static bool is_c_void_type(AstNode *node) {
     return (node->type == NodeTypeSymbol && buf_eql_str(node->data.symbol_expr.symbol, "c_void"));
 }
 
-static AstNode* trans_c_cast(Context *c, const SourceLocation &source_location, const QualType &qt, AstNode *expr) {
-    // TODO: maybe widen to increase size
-    // TODO: maybe bitcast to change sign
-    // TODO: maybe truncate to reduce size
-    return trans_create_node_fn_call_1(c, trans_qual_type(c, qt, source_location), expr);
+static bool expr_types_equal(Context *c, const Expr *expr1, const Expr *expr2) {
+    QualType t1 = get_expr_qual_type(c, expr1);
+    QualType t2 = get_expr_qual_type(c, expr2);
+
+    return qual_types_equal(t1, t2);
 }
 
-static bool qual_type_is_fn_ptr(Context *c, const QualType &qt) {
-    const Type *ty = qt.getTypePtr();
+static bool qual_type_is_ptr(QualType qt) {
+    const Type *ty = qual_type_canon(qt);
+    return ty->getTypeClass() == Type::Pointer;
+}
+
+static bool qual_type_is_fn_ptr(Context *c, QualType qt) {
+    const Type *ty = qual_type_canon(qt);
     if (ty->getTypeClass() != Type::Pointer) {
         return false;
     }
     const PointerType *pointer_ty = static_cast<const PointerType*>(ty);
     QualType child_qt = pointer_ty->getPointeeType();
     const Type *child_ty = child_qt.getTypePtr();
-    if (child_ty->getTypeClass() != Type::Paren) {
-        return false;
-    }
-    const ParenType *paren_ty = static_cast<const ParenType *>(child_ty);
-    return paren_ty->getInnerType().getTypePtr()->getTypeClass() == Type::FunctionProto;
+    return child_ty->getTypeClass() == Type::FunctionProto;
 }
 
 static uint32_t qual_type_int_bit_width(Context *c, const QualType &qt, const SourceLocation &source_loc) {
@@ -594,17 +596,26 @@ static bool qual_type_child_is_fn_proto(const QualType &qt) {
     return false;
 }
 
-static QualType resolve_any_typedef(Context *c, QualType qt) {
-    const Type * ty = qt.getTypePtr();
-    if (ty->getTypeClass() != Type::Typedef)
-        return qt;
-    const TypedefType *typedef_ty = static_cast<const TypedefType*>(ty);
-    const TypedefNameDecl *typedef_decl = typedef_ty->getDecl();
-    return typedef_decl->getUnderlyingType();
+static AstNode* trans_c_cast(Context *c, const SourceLocation &source_location, QualType dest_type,
+        QualType src_type, AstNode *expr)
+{
+    if (qual_types_equal(dest_type, src_type)) {
+        return expr;
+    }
+    if (qual_type_is_ptr(dest_type) && qual_type_is_ptr(src_type)) {
+        AstNode *ptr_cast_node = trans_create_node_builtin_fn_call_str(c, "ptrCast");
+        ptr_cast_node->data.fn_call_expr.params.append(trans_qual_type(c, dest_type, source_location));
+        ptr_cast_node->data.fn_call_expr.params.append(expr);
+        return ptr_cast_node;
+    }
+    // TODO: maybe widen to increase size
+    // TODO: maybe bitcast to change sign
+    // TODO: maybe truncate to reduce size
+    return trans_create_node_fn_call_1(c, trans_qual_type(c, dest_type, source_location), expr);
 }
 
 static bool c_is_signed_integer(Context *c, QualType qt) {
-    const Type *c_type = resolve_any_typedef(c, qt).getTypePtr();
+    const Type *c_type = qual_type_canon(qt);
     if (c_type->getTypeClass() != Type::Builtin)
         return false;
     const BuiltinType *builtin_ty = static_cast<const BuiltinType*>(c_type);
@@ -623,7 +634,7 @@ static bool c_is_signed_integer(Context *c, QualType qt) {
 }
 
 static bool c_is_unsigned_integer(Context *c, QualType qt) {
-    const Type *c_type = resolve_any_typedef(c, qt).getTypePtr();
+    const Type *c_type = qual_type_canon(qt);
     if (c_type->getTypeClass() != Type::Builtin)
         return false;
     const BuiltinType *builtin_ty = static_cast<const BuiltinType*>(c_type);
@@ -891,6 +902,11 @@ static AstNode *trans_type(Context *c, const Type *ty, const SourceLocation &sou
                         return nullptr;
                     }
                     // convert c_void to actual void (only for return type)
+                    // we do want to look at the AstNode instead of QualType, because
+                    // if they do something like:
+                    //     typedef Foo void;
+                    //     void foo(void) -> Foo;
+                    // we want to keep the return type AST node.
                     if (is_c_void_type(proto_node->data.fn_proto.return_type)) {
                         proto_node->data.fn_proto.return_type = nullptr;
                     }
@@ -1317,19 +1333,28 @@ static AstNode *trans_create_compound_assign_shift(Context *c, ResultUsed result
         if (rhs == nullptr) return nullptr;
         AstNode *coerced_rhs = trans_create_node_fn_call_1(c, rhs_type, rhs);
 
+        // operation_type(*_ref)
+        AstNode *operation_type_cast = trans_c_cast(c, rhs_location,
+            stmt->getComputationLHSType(),
+            stmt->getLHS()->getType(),
+            trans_create_node_prefix_op(c, PrefixOpDereference,
+                trans_create_node_symbol(c, tmp_var_name)));
+
+        // result_type(... >> u5(rhs))
+        AstNode *result_type_cast = trans_c_cast(c, rhs_location,
+            stmt->getComputationResultType(),
+            stmt->getComputationLHSType(),
+            trans_create_node_bin_op(c,
+                operation_type_cast,
+                bin_op,
+                coerced_rhs));
+
+        // *_ref = ...
         AstNode *assign_statement = trans_create_node_bin_op(c,
             trans_create_node_prefix_op(c, PrefixOpDereference,
                 trans_create_node_symbol(c, tmp_var_name)),
-            BinOpTypeAssign,
-            trans_c_cast(c, rhs_location,
-                stmt->getComputationResultType(),
-                trans_create_node_bin_op(c,
-                    trans_c_cast(c, rhs_location,
-                        stmt->getComputationLHSType(),
-                        trans_create_node_prefix_op(c, PrefixOpDereference,
-                            trans_create_node_symbol(c, tmp_var_name))),
-                    bin_op,
-                    coerced_rhs)));
+            BinOpTypeAssign, result_type_cast);
+
         child_scope->node->data.block.statements.append(assign_statement);
 
         if (result_used == ResultUsedYes) {
@@ -1474,7 +1499,8 @@ static AstNode *trans_implicit_cast_expr(Context *c, TransScope *scope, const Im
                 AstNode *target_node = trans_expr(c, ResultUsedYes, scope, stmt->getSubExpr(), TransRValue);
                 if (target_node == nullptr)
                     return nullptr;
-                return trans_c_cast(c, stmt->getExprLoc(), stmt->getType(), target_node);
+                return trans_c_cast(c, stmt->getExprLoc(), stmt->getType(),
+                        stmt->getSubExpr()->getType(), target_node);
             }
         case CK_FunctionToPointerDecay:
         case CK_ArrayToPointerDecay:
@@ -2177,9 +2203,23 @@ static AstNode *trans_call_expr(Context *c, ResultUsed result_used, TransScope *
     if (callee_raw_node == nullptr)
         return nullptr;
 
-    AstNode *callee_node;
+    AstNode *callee_node = nullptr;
     if (qual_type_is_fn_ptr(c, stmt->getCallee()->getType())) {
-        callee_node = trans_create_node_prefix_op(c, PrefixOpUnwrapMaybe, callee_raw_node);
+        if (stmt->getCallee()->getStmtClass() == Stmt::ImplicitCastExprClass) {
+            const ImplicitCastExpr *implicit_cast = static_cast<const ImplicitCastExpr *>(stmt->getCallee());
+            if (implicit_cast->getCastKind() == CK_FunctionToPointerDecay) {
+                if (implicit_cast->getSubExpr()->getStmtClass() == Stmt::DeclRefExprClass) {
+                    const DeclRefExpr *decl_ref = static_cast<const DeclRefExpr *>(implicit_cast->getSubExpr());
+                    const Decl *decl = decl_ref->getFoundDecl();
+                    if (decl->getKind() == Decl::Function) {
+                        callee_node = callee_raw_node;
+                    }
+                }
+            }
+        }
+        if (callee_node == nullptr) {
+            callee_node = trans_create_node_prefix_op(c, PrefixOpUnwrapMaybe, callee_raw_node);
+        }
     } else {
         callee_node = callee_raw_node;
     }
@@ -2237,7 +2277,7 @@ static AstNode *trans_c_style_cast_expr(Context *c, ResultUsed result_used, Tran
     if (sub_expr_node == nullptr)
         return nullptr;
 
-    return trans_c_cast(c, stmt->getLocStart(), stmt->getType(), sub_expr_node);
+    return trans_c_cast(c, stmt->getLocStart(), stmt->getType(), stmt->getSubExpr()->getType(), sub_expr_node);
 }
 
 static AstNode *trans_unary_expr_or_type_trait_expr(Context *c, TransScope *scope,
test/translate_c.zig
@@ -677,12 +677,12 @@ pub fn addCases(cases: &tests.TranslateCContext) {
         \\    };
         \\    a >>= @import("std").math.Log2Int(c_int)({
         \\        const _ref = &a;
-        \\        (*_ref) = c_int(c_int(*_ref) >> @import("std").math.Log2Int(c_int)(1));
+        \\        (*_ref) = ((*_ref) >> @import("std").math.Log2Int(c_int)(1));
         \\        *_ref
         \\    });
         \\    a <<= @import("std").math.Log2Int(c_int)({
         \\        const _ref = &a;
-        \\        (*_ref) = c_int(c_int(*_ref) << @import("std").math.Log2Int(c_int)(1));
+        \\        (*_ref) = ((*_ref) << @import("std").math.Log2Int(c_int)(1));
         \\        *_ref
         \\    });
         \\}
@@ -735,12 +735,12 @@ pub fn addCases(cases: &tests.TranslateCContext) {
         \\    };
         \\    a >>= @import("std").math.Log2Int(c_uint)({
         \\        const _ref = &a;
-        \\        (*_ref) = c_uint(c_uint(*_ref) >> @import("std").math.Log2Int(c_uint)(1));
+        \\        (*_ref) = ((*_ref) >> @import("std").math.Log2Int(c_uint)(1));
         \\        *_ref
         \\    });
         \\    a <<= @import("std").math.Log2Int(c_uint)({
         \\        const _ref = &a;
-        \\        (*_ref) = c_uint(c_uint(*_ref) << @import("std").math.Log2Int(c_uint)(1));
+        \\        (*_ref) = ((*_ref) << @import("std").math.Log2Int(c_uint)(1));
         \\        *_ref
         \\    });
         \\}
@@ -878,17 +878,21 @@ pub fn addCases(cases: &tests.TranslateCContext) {
 
     cases.addC("deref function pointer",
         \\void foo(void) {}
+        \\void baz(void) {}
         \\void bar(void) {
         \\    void(*f)(void) = foo;
         \\    f();
         \\    (*(f))();
+        \\    baz();
         \\}
     ,
         \\export fn foo() {}
+        \\export fn baz() {}
         \\export fn bar() {
         \\    var f: ?extern fn() = foo;
         \\    (??f)();
         \\    (??f)();
+        \\    baz();
         \\}
     );
 
@@ -1100,15 +1104,14 @@ pub fn addCases(cases: &tests.TranslateCContext) {
         \\    return x;
         \\}
     );
-}
-
 
-
-// TODO
-//float *ptrcast(int *a) {
-//    return (float *)a;
-//}
-// should translate to
-// fn ptrcast(a: ?&c_int) -> ?&f32 {
-//     return @ptrCast(?&f32, a);
-// }
+    cases.add("pointer casting",
+        \\float *ptrcast(int *a) {
+        \\    return (float *)a;
+        \\}
+    ,
+        \\fn ptrcast(a: ?&c_int) -> ?&f32 {
+        \\    return @ptrCast(?&f32, a);
+        \\}
+    );
+}