Commit 9716a1c3ab

Evan Haas <evan@lagerdata.com>
2022-02-21 22:18:18
translate-c: Add support for cast-to-union
Fixes #10955
1 parent 4a0b037
lib/std/zig/c_translation.zig
@@ -30,6 +30,12 @@ pub fn cast(comptime DestType: type, target: anytype) DestType {
                 else => {},
             }
         },
+        .Union => |info| {
+            inline for (info.fields) |field| {
+                if (field.field_type == SourceType) return @unionInit(DestType, field.name, target);
+            }
+            @compileError("cast to union type '" ++ @typeName(DestType) ++ "' from type '" ++ @typeName(SourceType) ++ "' which is not present in union");
+        },
         else => {},
     }
     return @as(DestType, target);
src/clang.zig
@@ -258,6 +258,14 @@ pub const CaseStmt = opaque {
     extern fn ZigClangCaseStmt_getSubStmt(*const CaseStmt) *const Stmt;
 };
 
+pub const CastExpr = opaque {
+    pub const getCastKind = ZigClangCastExpr_getCastKind;
+    extern fn ZigClangCastExpr_getCastKind(*const CastExpr) CK;
+
+    pub const getTargetFieldForToUnionCast = ZigClangCastExpr_getTargetFieldForToUnionCast;
+    extern fn ZigClangCastExpr_getTargetFieldForToUnionCast(*const CastExpr, QualType, QualType) ?*const FieldDecl;
+};
+
 pub const CharacterLiteral = opaque {
     pub const getBeginLoc = ZigClangCharacterLiteral_getBeginLoc;
     extern fn ZigClangCharacterLiteral_getBeginLoc(*const CharacterLiteral) SourceLocation;
src/translate_c.zig
@@ -1791,14 +1791,31 @@ fn transCStyleCastExprClass(
     stmt: *const clang.CStyleCastExpr,
     result_used: ResultUsed,
 ) TransError!Node {
+    const cast_expr = @ptrCast(*const clang.CastExpr, stmt);
     const sub_expr = stmt.getSubExpr();
-    const cast_node = (try transCCast(
+    const dst_type = stmt.getType();
+    const src_type = sub_expr.getType();
+    const sub_expr_node = try transExpr(c, scope, sub_expr, .used);
+    const loc = stmt.getBeginLoc();
+
+    const cast_node = if (cast_expr.getCastKind() == .ToUnion) blk: {
+        const field_decl = cast_expr.getTargetFieldForToUnionCast(dst_type, src_type).?; // C syntax error if target field is null
+        const field_name = try c.str(@ptrCast(*const clang.NamedDecl, field_decl).getName_bytes_begin());
+
+        const union_ty = try transQualType(c, scope, dst_type, loc);
+
+        const inits = [1]ast.Payload.ContainerInit.Initializer{.{ .name = field_name, .value = sub_expr_node }};
+        break :blk try Tag.container_init.create(c.arena, .{
+            .lhs = union_ty,
+            .inits = try c.arena.dupe(ast.Payload.ContainerInit.Initializer, &inits),
+        });
+    } else (try transCCast(
         c,
         scope,
-        stmt.getBeginLoc(),
-        stmt.getType(),
-        sub_expr.getType(),
-        try transExpr(c, scope, sub_expr, .used),
+        loc,
+        dst_type,
+        src_type,
+        sub_expr_node,
     ));
     return maybeSuppressResult(c, scope, result_used, cast_node);
 }
@@ -2370,7 +2387,7 @@ fn cIntTypeForEnum(enum_qt: clang.QualType) clang.QualType {
     return enum_decl.getIntegerType();
 }
 
-// when modifying this function, make sure to also update std.meta.cast
+// when modifying this function, make sure to also update std.zig.c_translation.cast
 fn transCCast(
     c: *Context,
     scope: *Scope,
src/zig_clang.cpp
@@ -2986,6 +2986,18 @@ const struct ZigClangCompoundStmt *ZigClangStmtExpr_getSubStmt(const struct ZigC
     return reinterpret_cast<const ZigClangCompoundStmt *>(casted->getSubStmt());
 }
 
+enum ZigClangCK ZigClangCastExpr_getCastKind(const struct ZigClangCastExpr *self) {
+    auto casted = reinterpret_cast<const clang::CastExpr *>(self);
+    return (ZigClangCK)casted->getCastKind();
+}
+
+const struct ZigClangFieldDecl *ZigClangCastExpr_getTargetFieldForToUnionCast(const struct ZigClangCastExpr *self, ZigClangQualType union_type, ZigClangQualType op_type) {
+    clang::QualType union_qt = bitcast(union_type);
+    clang::QualType op_qt = bitcast(op_type);
+    auto casted = reinterpret_cast<const clang::CastExpr *>(self);
+    return reinterpret_cast<const ZigClangFieldDecl *>(casted->getTargetFieldForToUnionCast(union_qt, op_qt));
+}
+
 struct ZigClangSourceLocation ZigClangCharacterLiteral_getBeginLoc(const struct ZigClangCharacterLiteral *self) {
     auto casted = reinterpret_cast<const clang::CharacterLiteral *>(self);
     return bitcast(casted->getBeginLoc());
src/zig_clang.h
@@ -103,6 +103,7 @@ struct ZigClangBuiltinType;
 struct ZigClangCStyleCastExpr;
 struct ZigClangCallExpr;
 struct ZigClangCaseStmt;
+struct ZigClangCastExpr;
 struct ZigClangCharacterLiteral;
 struct ZigClangChooseExpr;
 struct ZigClangCompoundAssignOperator;
@@ -1317,6 +1318,9 @@ ZIG_EXTERN_C struct ZigClangQualType ZigClangDecayedType_getDecayedType(const st
 
 ZIG_EXTERN_C const struct ZigClangCompoundStmt *ZigClangStmtExpr_getSubStmt(const struct ZigClangStmtExpr *);
 
+ZIG_EXTERN_C enum ZigClangCK ZigClangCastExpr_getCastKind(const struct ZigClangCastExpr *);
+ZIG_EXTERN_C const struct ZigClangFieldDecl *ZigClangCastExpr_getTargetFieldForToUnionCast(const struct ZigClangCastExpr *, struct ZigClangQualType, struct ZigClangQualType);
+
 ZIG_EXTERN_C struct ZigClangSourceLocation ZigClangCharacterLiteral_getBeginLoc(const struct ZigClangCharacterLiteral *);
 ZIG_EXTERN_C enum ZigClangCharacterLiteral_CharacterKind ZigClangCharacterLiteral_getKind(const struct ZigClangCharacterLiteral *);
 ZIG_EXTERN_C unsigned ZigClangCharacterLiteral_getValue(const struct ZigClangCharacterLiteral *);
test/behavior/translate_c_macros.h
@@ -15,6 +15,11 @@ struct Foo {
     int a;
 };
 
+union U {
+    long l;
+    double d;
+};
+
 #define SIZE_OF_FOO sizeof(struct Foo)
 
 #define MAP_FAILED	((void *) -1)
@@ -30,3 +35,5 @@ struct Foo {
 #define IGNORE_ME_8(x) (volatile void)(x)
 #define IGNORE_ME_9(x) (const volatile void)(x)
 #define IGNORE_ME_10(x) (volatile const void)(x)
+
+#define UNION_CAST(X) (union U)(X)
test/behavior/translate_c_macros.zig
@@ -47,3 +47,16 @@ test "cast negative integer to pointer" {
 
     try expectEqual(@intToPtr(?*anyopaque, @bitCast(usize, @as(isize, -1))), h.MAP_FAILED);
 }
+
+test "casting to union with a macro" {
+    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO Sema.zirUnionInitPtr
+
+    const l: c_long = 42;
+    const d: f64 = 2.0;
+
+    var casted = h.UNION_CAST(l);
+    try expectEqual(l, casted.l);
+
+    casted = h.UNION_CAST(d);
+    try expectEqual(d, casted.d);
+}
test/run_translated_c.zig
@@ -1829,4 +1829,26 @@ pub fn addCases(cases: *tests.RunTranslatedCContext) void {
         \\    return 0;
         \\}
     , "");
+
+    cases.add("Cast-to-union. Issue #10955",
+        \\#include <stdlib.h>
+        \\struct S { int x; };
+        \\union U {
+        \\    long l;
+        \\    double d;
+        \\    struct S s;
+        \\};
+        \\union U bar(union U u) { return u; }
+        \\int main(void) {
+        \\    union U u = (union U) 42L;
+        \\    if (u.l != 42L) abort();
+        \\    u = (union U) 2.0;
+        \\    if (u.d != 2.0) abort();
+        \\    u = bar((union U)4.0);
+        \\    if (u.d != 4.0) abort();
+        \\    u = (union U)(struct S){ .x = 5 };
+        \\    if (u.s.x != 5) abort();
+        \\    return 0;
+        \\}
+    , "");
 }