Commit 6f47513009

Charlie Stanton <charlie@shtanton.com>
2020-06-21 19:24:59
Adds std.meta.cast and uses it to simplify translate-c
1 parent 5622044
Changed files (3)
lib
src-self-hosted
test
lib/std/meta.zig
@@ -693,3 +693,98 @@ pub fn Vector(comptime len: u32, comptime child: type) type {
         },
     });
 }
+
+/// Given a type and value, cast the value to the type as c would
+pub fn cast(comptime DestType: type, target: var) DestType {
+    const TargetType = @TypeOf(target);
+    switch (@typeInfo(DestType)) {
+        .Pointer => |_| {
+            switch (@typeInfo(TargetType)) {
+                .Int => |_| {
+                    return @intToPtr(DestType, target);
+                },
+                .ComptimeInt => |_| {
+                    return @intToPtr(DestType, target);
+                },
+                .Pointer => |ptr| {
+                    return @ptrCast(DestType, @alignCast(ptr.alignment, target));
+                },
+                .Optional => |opt| {
+                    if (@typeInfo(opt.child) == .Pointer) {
+                        return @ptrCast(DestType, @alignCast(@alignOf(opt.child.Child), target));
+                    }
+                },
+                else => {},
+            }
+        },
+        .Optional => |opt| {
+            if (@typeInfo(opt.child) == .Pointer) {
+                switch (@typeInfo(TargetType)) {
+                    .Int => |_| {
+                        return @intToPtr(DestType, target);
+                    },
+                    .ComptimeInt => |_| {
+                        return @intToPtr(DestType, target);
+                    },
+                    .Pointer => |ptr| {
+                        return @ptrCast(DestType, @alignCast(ptr.alignment, target));
+                    },
+                    .Optional => |target_opt| {
+                        if (@typeInfo(target_opt.child) == .Pointer) {
+                            return @ptrCast(DestType, @alignCast(@alignOf(target_opt.child.Child), target));
+                        }
+                    },
+                    else => {},
+                }
+            }
+        },
+        .Enum => |_| {
+            if (@typeInfo(TargetType) == .Int or @typeInfo(TargetType) == .ComptimeInt) {
+                return @intToEnum(DestType, target);
+            }
+        },
+        .EnumLiteral => |_| {
+            if (@typeInfo(TargetType) == .Int or @typeInfo(TargetType) == .ComptimeInt) {
+                return @intToEnum(DestType, target);
+            }
+        },
+        .Int => |_| {
+            switch (@typeInfo(TargetType)) {
+                .Pointer => |_| {
+                    return @as(DestType, @ptrToInt(target));
+                },
+                .Optional => |opt| {
+                    if (@typeInfo(opt.child) == .Pointer) {
+                        return @as(DestType, @ptrToInt(target));
+                    }
+                },
+                .Enum => |_| {
+                    return @as(DestType, @enumToInt(target));
+                },
+                .EnumLiteral => |_| {
+                    return @as(DestType, @enumToInt(target));
+                },
+                else => {},
+            }
+        },
+        else => {},
+    }
+    return @as(DestType, target);
+}
+
+test "std.meta.cast" {
+    const E = enum(u2) {
+        Zero,
+        One,
+        Two,
+    };
+
+    var i = @as(i64, 10);
+
+    testing.expect(cast(?*c_void, 0) == @intToPtr(?*c_void, 0));
+    testing.expect(cast(*u8, 16) == @intToPtr(*u8, 16));
+    testing.expect(cast(u64, @as(u32, 10)) == @as(u64, 10));
+    testing.expect(cast(E, 1) == .One);
+    testing.expect(cast(u8, E.Two) == 2);
+    testing.expect(cast(*u64, &i).* == @as(u64, 10));
+}
src-self-hosted/translate_c.zig
@@ -5668,161 +5668,27 @@ fn parseCPrimaryExpr(c: *Context, it: *CTokenList.Iterator, source: []const u8,
 
             const lparen = try appendToken(c, .LParen, "(");
 
-            if (saw_integer_literal) {
-                //(  if (@typeInfo(dest) == .Pointer))
-                //    @intToPtr(dest, x)
-                //else
-                //    @as(dest, x) )
-                const if_node = try transCreateNodeIf(c);
-                const type_info_node = try c.createBuiltinCall("@typeInfo", 1);
-                type_info_node.params()[0] = inner_node;
-                type_info_node.rparen_token = try appendToken(c, .LParen, ")");
-                const cmp_node = try c.arena.create(ast.Node.InfixOp);
-                cmp_node.* = .{
-                    .op_token = try appendToken(c, .EqualEqual, "=="),
-                    .lhs = &type_info_node.base,
-                    .op = .EqualEqual,
-                    .rhs = try transCreateNodeEnumLiteral(c, "Pointer"),
-                };
-                if_node.condition = &cmp_node.base;
-                _ = try appendToken(c, .RParen, ")");
-
-                const int_to_ptr = try c.createBuiltinCall("@intToPtr", 2);
-                int_to_ptr.params()[0] = inner_node;
-                int_to_ptr.params()[1] = node_to_cast;
-                int_to_ptr.rparen_token = try appendToken(c, .RParen, ")");
-                if_node.body = &int_to_ptr.base;
-
-                const else_node = try transCreateNodeElse(c);
-                if_node.@"else" = else_node;
-
-                const as_node = try c.createBuiltinCall("@as", 2);
-                as_node.params()[0] = inner_node;
-                as_node.params()[1] = node_to_cast;
-                as_node.rparen_token = try appendToken(c, .RParen, ")");
-                else_node.body = &as_node.base;
-
-                const group_node = try c.arena.create(ast.Node.GroupedExpression);
-                group_node.* = .{
-                    .lparen = lparen,
-                    .expr = &if_node.base,
-                    .rparen = try appendToken(c, .RParen, ")"),
-                };
-                return &group_node.base;
-            }
-
-            //(  if (@typeInfo(@TypeOf(x)) == .Pointer)
-            //    @ptrCast(dest, @alignCast(@alignOf(dest.Child), x))
-            //else if (@typeInfo(@TypeOf(x)) == .Int and @typeInfo(dest) == .Pointer))
-            //    @intToPtr(dest, x)
-            //else
-            //    @as(dest, x) )
-
-            const if_1 = try transCreateNodeIf(c);
-            const type_info_1 = try c.createBuiltinCall("@typeInfo", 1);
-            const type_of_1 = try c.createBuiltinCall("@TypeOf", 1);
-            type_info_1.params()[0] = &type_of_1.base;
-            type_of_1.params()[0] = node_to_cast;
-            type_of_1.rparen_token = try appendToken(c, .RParen, ")");
-            type_info_1.rparen_token = try appendToken(c, .RParen, ")");
-
-            const cmp_1 = try c.arena.create(ast.Node.InfixOp);
-            cmp_1.* = .{
-                .op_token = try appendToken(c, .EqualEqual, "=="),
-                .lhs = &type_info_1.base,
-                .op = .EqualEqual,
-                .rhs = try transCreateNodeEnumLiteral(c, "Pointer"),
-            };
-            if_1.condition = &cmp_1.base;
-            _ = try appendToken(c, .RParen, ")");
-
-            const period_tok = try appendToken(c, .Period, ".");
-            const child_ident = try transCreateNodeIdentifier(c, "Child");
-            const inner_node_child = try c.arena.create(ast.Node.InfixOp);
-            inner_node_child.* = .{
-                .op_token = period_tok,
-                .lhs = inner_node,
-                .op = .Period,
-                .rhs = child_ident,
+            //(@import("std").meta.cast(dest, x))
+            const import_fn_call = try c.createBuiltinCall("@import", 1);
+            const std_token = try appendToken(c, .StringLiteral, "\"std\"");
+            const std_node = try c.arena.create(ast.Node.StringLiteral);
+            std_node.* = .{
+                .token = std_token,
             };
+            import_fn_call.params()[0] = &std_node.base;
+            import_fn_call.rparen_token = try appendToken(c, .RParen, ")");
+            const inner_field_access = try transCreateNodeFieldAccess(c, &import_fn_call.base, "meta");
+            const outer_field_access = try transCreateNodeFieldAccess(c, inner_field_access, "cast");
 
-            const align_of = try c.createBuiltinCall("@alignOf", 1);
-            align_of.params()[0] = &inner_node_child.base;
-            align_of.rparen_token = try appendToken(c, .RParen, ")");
-            // hack to get zig fmt to render a comma in builtin calls
-            _ = try appendToken(c, .Comma, ",");
-
-            const align_cast = try c.createBuiltinCall("@alignCast", 2);
-            align_cast.params()[0] = &align_of.base;
-            align_cast.params()[1] = node_to_cast;
-            align_cast.rparen_token = try appendToken(c, .RParen, ")");
-
-            const ptr_cast = try c.createBuiltinCall("@ptrCast", 2);
-            ptr_cast.params()[0] = inner_node;
-            ptr_cast.params()[1] = &align_cast.base;
-            ptr_cast.rparen_token = try appendToken(c, .RParen, ")");
-            if_1.body = &ptr_cast.base;
-
-            const else_1 = try transCreateNodeElse(c);
-            if_1.@"else" = else_1;
-
-            const if_2 = try transCreateNodeIf(c);
-            const type_info_2 = try c.createBuiltinCall("@typeInfo", 1);
-            const type_of_2 = try c.createBuiltinCall("@TypeOf", 1);
-            type_info_2.params()[0] = &type_of_2.base;
-            type_of_2.params()[0] = node_to_cast;
-            type_of_2.rparen_token = try appendToken(c, .RParen, ")");
-            type_info_2.rparen_token = try appendToken(c, .RParen, ")");
-
-            const cmp_2 = try c.arena.create(ast.Node.InfixOp);
-            cmp_2.* = .{
-                .op_token = try appendToken(c, .EqualEqual, "=="),
-                .lhs = &type_info_2.base,
-                .op = .EqualEqual,
-                .rhs = try transCreateNodeEnumLiteral(c, "Int"),
-            };
-            if_2.condition = &cmp_2.base;
-            const cmp_4 = try c.arena.create(ast.Node.InfixOp);
-            cmp_4.* = .{
-                .op_token = try appendToken(c, .Keyword_and, "and"),
-                .lhs = &cmp_2.base,
-                .op = .BoolAnd,
-                .rhs = undefined,
-            };
-            const type_info_3 = try c.createBuiltinCall("@typeInfo", 1);
-            type_info_3.params()[0] = inner_node;
-            type_info_3.rparen_token = try appendToken(c, .LParen, ")");
-            const cmp_3 = try c.arena.create(ast.Node.InfixOp);
-            cmp_3.* = .{
-                .op_token = try appendToken(c, .EqualEqual, "=="),
-                .lhs = &type_info_3.base,
-                .op = .EqualEqual,
-                .rhs = try transCreateNodeEnumLiteral(c, "Pointer"),
-            };
-            cmp_4.rhs = &cmp_3.base;
-            if_2.condition = &cmp_4.base;
-            else_1.body = &if_2.base;
-            _ = try appendToken(c, .RParen, ")");
-
-            const int_to_ptr = try c.createBuiltinCall("@intToPtr", 2);
-            int_to_ptr.params()[0] = inner_node;
-            int_to_ptr.params()[1] = node_to_cast;
-            int_to_ptr.rparen_token = try appendToken(c, .RParen, ")");
-            if_2.body = &int_to_ptr.base;
-
-            const else_2 = try transCreateNodeElse(c);
-            if_2.@"else" = else_2;
-
-            const as = try c.createBuiltinCall("@as", 2);
-            as.params()[0] = inner_node;
-            as.params()[1] = node_to_cast;
-            as.rparen_token = try appendToken(c, .RParen, ")");
-            else_2.body = &as.base;
+            const cast_fn_call = try c.createCall(outer_field_access, 2);
+            cast_fn_call.params()[0] = inner_node;
+            cast_fn_call.params()[1] = node_to_cast;
+            cast_fn_call.rtoken = try appendToken(c, .RParen, ")");
 
             const group_node = try c.arena.create(ast.Node.GroupedExpression);
             group_node.* = .{
                 .lparen = lparen,
-                .expr = &if_1.base,
+                .expr = &cast_fn_call.base,
                 .rparen = try appendToken(c, .RParen, ")"),
             };
             return &group_node.base;
test/translate_c.zig
@@ -1473,7 +1473,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
     cases.add("macro pointer cast",
         \\#define NRF_GPIO ((NRF_GPIO_Type *) NRF_GPIO_BASE)
     , &[_][]const u8{
-        \\pub const NRF_GPIO = (if (@typeInfo(@TypeOf(NRF_GPIO_BASE)) == .Pointer) @ptrCast([*c]NRF_GPIO_Type, @alignCast(@alignOf([*c]NRF_GPIO_Type.Child), NRF_GPIO_BASE)) else if (@typeInfo(@TypeOf(NRF_GPIO_BASE)) == .Int and @typeInfo([*c]NRF_GPIO_Type) == .Pointer) @intToPtr([*c]NRF_GPIO_Type, NRF_GPIO_BASE) else @as([*c]NRF_GPIO_Type, NRF_GPIO_BASE));
+        \\pub const NRF_GPIO = (@import("std").meta.cast([*c]NRF_GPIO_Type, NRF_GPIO_BASE));
     });
 
     cases.add("basic macro function",
@@ -2683,11 +2683,11 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\#define FOO(bar) baz((void *)(baz))
         \\#define BAR (void*) a
     , &[_][]const u8{
-        \\pub inline fn FOO(bar: var) @TypeOf(baz((if (@typeInfo(@TypeOf(baz)) == .Pointer) @ptrCast(?*c_void, @alignCast(@alignOf(?*c_void.Child), baz)) else if (@typeInfo(@TypeOf(baz)) == .Int and @typeInfo(?*c_void) == .Pointer) @intToPtr(?*c_void, baz) else @as(?*c_void, baz)))) {
-        \\    return baz((if (@typeInfo(@TypeOf(baz)) == .Pointer) @ptrCast(?*c_void, @alignCast(@alignOf(?*c_void.Child), baz)) else if (@typeInfo(@TypeOf(baz)) == .Int and @typeInfo(?*c_void) == .Pointer) @intToPtr(?*c_void, baz) else @as(?*c_void, baz)));
+        \\pub inline fn FOO(bar: var) @TypeOf(baz((@import("std").meta.cast(?*c_void, baz)))) {
+        \\    return baz((@import("std").meta.cast(?*c_void, baz)));
         \\}
     ,
-        \\pub const BAR = (if (@typeInfo(@TypeOf(a)) == .Pointer) @ptrCast(?*c_void, @alignCast(@alignOf(?*c_void.Child), a)) else if (@typeInfo(@TypeOf(a)) == .Int and @typeInfo(?*c_void) == .Pointer) @intToPtr(?*c_void, a) else @as(?*c_void, a));
+        \\pub const BAR = (@import("std").meta.cast(?*c_void, a));
     });
 
     cases.add("macro conditional operator",
@@ -2905,8 +2905,8 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\#define DefaultScreen(dpy) (((_XPrivDisplay)(dpy))->default_screen)
         \\
     , &[_][]const u8{
-        \\pub inline fn DefaultScreen(dpy: var) @TypeOf((if (@typeInfo(@TypeOf(dpy)) == .Pointer) @ptrCast(_XPrivDisplay, @alignCast(@alignOf(_XPrivDisplay.Child), dpy)) else if (@typeInfo(@TypeOf(dpy)) == .Int and @typeInfo(_XPrivDisplay) == .Pointer) @intToPtr(_XPrivDisplay, dpy) else @as(_XPrivDisplay, dpy)).*.default_screen) {
-        \\    return (if (@typeInfo(@TypeOf(dpy)) == .Pointer) @ptrCast(_XPrivDisplay, @alignCast(@alignOf(_XPrivDisplay.Child), dpy)) else if (@typeInfo(@TypeOf(dpy)) == .Int and @typeInfo(_XPrivDisplay) == .Pointer) @intToPtr(_XPrivDisplay, dpy) else @as(_XPrivDisplay, dpy)).*.default_screen;
+        \\pub inline fn DefaultScreen(dpy: var) @TypeOf((@import("std").meta.cast(_XPrivDisplay, dpy)).*.default_screen) {
+        \\    return (@import("std").meta.cast(_XPrivDisplay, dpy)).*.default_screen;
         \\}
     });
 
@@ -2914,9 +2914,9 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\#define NULL ((void*)0)
         \\#define FOO ((int)0x8000)
     , &[_][]const u8{
-        \\pub const NULL = (if (@typeInfo(?*c_void) == .Pointer) @intToPtr(?*c_void, 0) else @as(?*c_void, 0));
+        \\pub const NULL = (@import("std").meta.cast(?*c_void, 0));
     ,
-        \\pub const FOO = (if (@typeInfo(c_int) == .Pointer) @intToPtr(c_int, 0x8000) else @as(c_int, 0x8000));
+        \\pub const FOO = (@import("std").meta.cast(c_int, 0x8000));
     });
 
     if (std.Target.current.abi == .msvc) {