Commit 65531c73a9

Vexu <git@vexu.eu>
2019-12-17 16:19:28
translate-c-2 switch
1 parent 0283ab8
Changed files (3)
src-self-hosted/clang.zig
@@ -77,6 +77,7 @@ pub const struct_ZigClangPredefinedExpr = @OpaqueType();
 pub const struct_ZigClangInitListExpr = @OpaqueType();
 pub const ZigClangPreprocessingRecord = @OpaqueType();
 pub const ZigClangFloatingLiteral = @OpaqueType();
+pub const ZigClangConstantExpr = @OpaqueType();
 
 pub const ZigClangBO = extern enum {
     PtrMemD,
@@ -1059,3 +1060,17 @@ pub extern fn ZigClangAPFloat_getValueAsApproximateDouble(*const ZigClangFloatin
 pub extern fn ZigClangConditionalOperator_getCond(*const ZigClangConditionalOperator) *const ZigClangExpr;
 pub extern fn ZigClangConditionalOperator_getTrueExpr(*const ZigClangConditionalOperator) *const ZigClangExpr;
 pub extern fn ZigClangConditionalOperator_getFalseExpr(*const ZigClangConditionalOperator) *const ZigClangExpr;
+
+pub extern fn ZigClangSwitchStmt_getConditionVariableDeclStmt(*const ZigClangSwitchStmt) ?*const ZigClangDeclStmt;
+pub extern fn ZigClangSwitchStmt_getCond(*const ZigClangSwitchStmt) *const ZigClangExpr;
+pub extern fn ZigClangSwitchStmt_getBody(*const ZigClangSwitchStmt) *const ZigClangStmt;
+pub extern fn ZigClangSwitchStmt_isAllEnumCasesCovered(*const ZigClangSwitchStmt) bool;
+
+pub extern fn ZigClangCaseStmt_getLHS(*const ZigClangCaseStmt) *const ZigClangExpr;
+pub extern fn ZigClangCaseStmt_getRHS(*const ZigClangCaseStmt) ?*const ZigClangExpr;
+pub extern fn ZigClangCaseStmt_getBeginLoc(*const ZigClangCaseStmt) ZigClangSourceLocation;
+pub extern fn ZigClangCaseStmt_getSubStmt(*const ZigClangCaseStmt) *const ZigClangStmt;
+
+pub extern fn ZigClangDefaultStmt_getSubStmt(*const ZigClangDefaultStmt) *const ZigClangStmt;
+
+pub extern fn ZigClangExpr_EvaluateAsConstantExpr(*const ZigClangExpr, *ZigClangExprEvalResult, ZigClangExpr_ConstExprUsage, *const ZigClangASTContext) bool;
src-self-hosted/translate_c.zig
@@ -54,7 +54,9 @@ const Scope = struct {
 
     const Switch = struct {
         base: Scope,
-        label: []const u8,
+        pending_block: *ast.Node.Block,
+        cases: *ast.Node.Switch.CaseList,
+        has_default: bool = false,
     };
 
     /// used when getting a member `a.b`
@@ -189,8 +191,8 @@ const Scope = struct {
             .Ref => null,
             .FnDef => @fieldParentPtr(FnDef, "base", scope).getAlias(name),
             .Block => @fieldParentPtr(Block, "base", scope).getAlias(name),
+            .Switch,
             .Condition => scope.parent.?.getAlias(name),
-            else => @panic("TODO Scope.getAlias"),
         };
     }
 
@@ -200,15 +202,31 @@ const Scope = struct {
             .Root => @fieldParentPtr(Root, "base", scope).contains(name),
             .FnDef => @fieldParentPtr(FnDef, "base", scope).contains(name),
             .Block => @fieldParentPtr(Block, "base", scope).contains(name),
+            .Switch,
             .Condition => scope.parent.?.contains(name),
-            else => @panic("TODO Scope.contains"),
         };
     }
 
     fn getBreakableScope(inner: *Scope) *Scope {
         var scope = inner;
-        while (scope.id != .Switch and scope.id != .Root) : (scope = scope.parent.?) {}
-        return scope;
+        while (true) {
+            switch (scope.id) {
+                .FnDef => unreachable,
+                .Switch => return scope,
+                else => scope = scope.parent.?,
+            }
+        }
+    }
+
+    fn getSwitch(inner: *Scope) *Scope.Switch {
+        var scope = inner;
+        while (true) {
+            switch (scope.id) {
+                .FnDef => unreachable,
+                .Switch => return @fieldParentPtr(Switch, "base", scope),
+                else => scope = scope.parent.?,
+            }
+        }
     }
 };
 
@@ -634,6 +652,10 @@ fn transStmt(
         .ForStmtClass => return transForLoop(rp, scope, @ptrCast(*const ZigClangForStmt, stmt)),
         .FloatingLiteralClass => return transFloatingLiteral(rp, scope, @ptrCast(*const ZigClangFloatingLiteral, stmt), result_used),
         .ConditionalOperatorClass => return transConditionalOperator(rp, scope, @ptrCast(*const ZigClangConditionalOperator, stmt), result_used),
+        .SwitchStmtClass => return transSwitch(rp, scope, @ptrCast(*const ZigClangSwitchStmt, stmt)),
+        .CaseStmtClass => return transCase(rp, scope, @ptrCast(*const ZigClangCaseStmt, stmt)),
+        .DefaultStmtClass => return transDefault(rp, scope, @ptrCast(*const ZigClangDefaultStmt, stmt)),
+        .ConstantExprClass => return transConstantExpr(rp, scope, @ptrCast(*const ZigClangExpr, stmt), result_used),
         else => {
             return revertAndWarn(
                 rp,
@@ -1374,6 +1396,154 @@ fn transForLoop(
         return &while_node.base;
 }
 
+fn transSwitch(
+    rp: RestorePoint,
+    scope: *Scope,
+    stmt: *const ZigClangSwitchStmt,
+) TransError!*ast.Node {
+    const switch_node = try transCreateNodeSwitch(rp.c);
+    var switch_scope = Scope.Switch{
+        .base = .{
+            .id = .Switch,
+            .parent = scope,
+        },
+        .cases = &switch_node.cases,
+        .pending_block = undefined,
+    };
+
+    var cond_scope = Scope.Condition{
+        .base = .{
+            .parent = scope,
+            .id = .Condition,
+        },
+    };
+    switch_node.expr = try transExpr(rp, &cond_scope.base, ZigClangSwitchStmt_getCond(stmt), .used, .r_value);
+    _ = try appendToken(rp.c, .RParen, ")");
+    _ = try appendToken(rp.c, .LBrace, "{");
+    switch_node.rbrace = try appendToken(rp.c, .RBrace, "}");
+
+    const block_scope = try Scope.Block.init(rp.c, &switch_scope.base, null);
+    // tmp block that all statements will go before being picked up by a case or default
+    const block = try transCreateNodeBlock(rp.c, null);
+    block_scope.block_node = block;
+
+    const switch_block = try transCreateNodeBlock(rp.c, null);
+    try switch_block.statements.push(&switch_node.base);
+    switch_scope.pending_block = switch_block;
+    
+
+    const last = try transStmt(rp, &block_scope.base, ZigClangSwitchStmt_getBody(stmt), .unused, .r_value);
+    _ = try appendToken(rp.c, .Semicolon, ";");
+
+    // take all pending statements
+    var it = last.cast(ast.Node.Block).?.statements.iterator(0);
+    while (it.next()) |n| {
+        try switch_scope.pending_block.statements.push(n.*);
+    }
+
+    switch_scope.pending_block.label = try appendIdentifier(rp.c, "__switch");
+    _ = try appendToken(rp.c, .Colon, ":");
+    if (!switch_scope.has_default) {
+        const else_prong = try transCreateNodeSwitchCase(rp.c, try transCreateNodeSwitchElse(rp.c));
+        else_prong.expr = &(try transCreateNodeBreak(rp.c, "__switch")).base;
+        _ = try appendToken(rp.c, .Comma, ",");
+        try switch_node.cases.push(&else_prong.base);
+    }
+    switch_scope.pending_block.rbrace = try appendToken(rp.c, .RBrace, "}");
+    return &switch_scope.pending_block.base;
+}
+    
+
+fn transCase(
+    rp: RestorePoint,
+    scope: *Scope,
+    stmt: *const ZigClangCaseStmt,
+) TransError!*ast.Node {
+    const block_scope = scope.findBlockScope(rp.c) catch unreachable;
+    const switch_scope = scope.getSwitch();
+    const label = try std.fmt.allocPrint(rp.c.a(), "__case_{}", .{switch_scope.cases.len - @boolToInt(switch_scope.has_default)});
+    _ = try appendToken(rp.c, .Semicolon, ";");
+
+    const expr = if (ZigClangCaseStmt_getRHS(stmt)) |rhs| blk: {
+        const lhs_node = try transExpr(rp, scope, ZigClangCaseStmt_getLHS(stmt), .used, .r_value);
+        const ellips = try appendToken(rp.c, .Ellipsis3, "...");
+        const rhs_node = try transExpr(rp, scope, ZigClangCaseStmt_getLHS(stmt), .used, .r_value);
+
+        const node = try rp.c.a().create(ast.Node.InfixOp);
+        node.* = .{
+            .op_token = ellips,
+            .lhs = lhs_node,
+            .op = .Range,
+            .rhs = rhs_node,
+        };
+        break :blk &node.base;
+    } else
+        try transExpr(rp, scope, ZigClangCaseStmt_getLHS(stmt), .used, .r_value);
+
+
+    const switch_prong = try transCreateNodeSwitchCase(rp.c, expr);
+    switch_prong.expr = &(try transCreateNodeBreak(rp.c, label)).base;
+    _ = try appendToken(rp.c, .Comma, ",");
+    try switch_scope.cases.push(&switch_prong.base);
+
+    const block = try transCreateNodeBlock(rp.c, null);
+    switch_scope.pending_block.label = try appendIdentifier(rp.c, label);
+    _ = try appendToken(rp.c, .Colon, ":");
+    switch_scope.pending_block.rbrace = try appendToken(rp.c, .RBrace, "}");
+    try block.statements.push(&switch_scope.pending_block.base);
+
+    // take all pending statements
+    var it = block_scope.block_node.statements.iterator(0);
+    while (it.next()) |n| {
+        try switch_scope.pending_block.statements.push(n.*);
+    }
+    block_scope.block_node.statements.shrink(0);
+
+    switch_scope.pending_block = block;
+
+    return transStmt(rp, scope, ZigClangCaseStmt_getSubStmt(stmt), .unused, .r_value);
+}
+
+fn transDefault(
+    rp: RestorePoint,
+    scope: *Scope,
+    stmt: *const ZigClangDefaultStmt,
+) TransError!*ast.Node {
+    const block_scope = scope.findBlockScope(rp.c) catch unreachable;
+    const switch_scope = scope.getSwitch();
+    const label = "__default";
+    switch_scope.has_default = true;
+    _ = try appendToken(rp.c, .Semicolon, ";");
+
+    const else_prong = try transCreateNodeSwitchCase(rp.c, try transCreateNodeSwitchElse(rp.c));
+    else_prong.expr = &(try transCreateNodeBreak(rp.c, label)).base;
+    _ = try appendToken(rp.c, .Comma, ",");
+    try switch_scope.cases.push(&else_prong.base);
+
+    const block = try transCreateNodeBlock(rp.c, null);
+    switch_scope.pending_block.label = try appendIdentifier(rp.c, label);
+    _ = try appendToken(rp.c, .Colon, ":");
+    switch_scope.pending_block.rbrace = try appendToken(rp.c, .RBrace, "}");
+    try block.statements.push(&switch_scope.pending_block.base);
+
+    // take all pending statements
+    var it = block_scope.block_node.statements.iterator(0);
+    while (it.next()) |n| {
+        try switch_scope.pending_block.statements.push(n.*);
+    }
+    block_scope.block_node.statements.shrink(0);
+
+    switch_scope.pending_block = block;
+    return transStmt(rp, scope, ZigClangDefaultStmt_getSubStmt(stmt), .unused, .r_value);
+}
+
+fn transConstantExpr(rp: RestorePoint, scope: *Scope, expr: *const ZigClangExpr, used: ResultUsed) TransError!*ast.Node {
+    var result: ZigClangExprEvalResult = undefined;
+    if (!ZigClangExpr_EvaluateAsConstantExpr(expr, &result, .EvaluateForCodeGen, rp.c.clang_context))
+        return revertAndWarn(rp, error.UnsupportedTranslation, ZigClangExpr_getBeginLoc(expr), "invalid constant expression", .{});
+    return maybeSuppressResult(rp, scope, used, try transCreateNodeAPInt(rp.c, ZigClangAPValue_getInt(&result.Val)));
+}
+
 fn transCPtrCast(
     rp: RestorePoint,
     loc: ZigClangSourceLocation,
@@ -1414,7 +1584,7 @@ fn transCPtrCast(
 fn transBreak(rp: RestorePoint, scope: *Scope) TransError!*ast.Node {
     const break_scope = scope.getBreakableScope();
     const br = try transCreateNodeBreak(rp.c, if (break_scope.id == .Switch)
-        @fieldParentPtr(Scope.Switch, "base", break_scope).label
+        "__switch"
     else
         null);
     return &br.base;
@@ -2339,6 +2509,42 @@ fn transCreateNodeContinue(c: *Context) !*ast.Node {
     return &node.base;
 }
 
+fn transCreateNodeSwitch(c: *Context) !*ast.Node.Switch {
+    const switch_tok = try appendToken(c, .Keyword_switch, "switch");
+    _ = try appendToken(c, .LParen, "(");
+
+    const node = try c.a().create(ast.Node.Switch);
+    node.* = .{
+        .switch_token = switch_tok,
+        .expr = undefined,
+        .cases = ast.Node.Switch.CaseList.init(c.a()),
+        .rbrace = undefined,
+    };
+    return node;
+}
+
+fn transCreateNodeSwitchCase(c: *Context, lhs: *ast.Node) !*ast.Node.SwitchCase {
+    const arrow_tok = try appendToken(c, .EqualAngleBracketRight, "=>");
+
+    const node = try c.a().create(ast.Node.SwitchCase);
+    node.* = .{
+        .items = ast.Node.SwitchCase.ItemList.init(c.a()),
+        .arrow_token = arrow_tok,
+        .payload = null,
+        .expr = undefined,
+    };
+    try node.items.push(lhs);
+    return node;
+}
+
+fn transCreateNodeSwitchElse(c: *Context) !*ast.Node {
+    const node = try c.a().create(ast.Node.SwitchElse);
+    node.* = .{
+        .token = try appendToken(c, .Keyword_else, "else"),
+    };
+    return &node.base;
+}
+
 const RestorePoint = struct {
     c: *Context,
     token_index: ast.TokenIndex,
@@ -2818,7 +3024,7 @@ fn appendTokenFmt(c: *Context, token_id: Token.Id, comptime format: []const u8,
 
 // TODO hook up with codegen
 fn isZigPrimitiveType(name: []const u8) bool {
-    if (name.len > 1 and std.mem.startsWith(u8, name, "u") or std.mem.startsWith(u8, name, "u")) {
+    if (name.len > 1 and (name[0] == 'u' or name[0] == 'i')) {
         for (name[1..]) |c| {
             switch (c) {
                 '0'...'9' => {},
@@ -2840,7 +3046,15 @@ fn isZigPrimitiveType(name: []const u8) bool {
     std.mem.eql(u8, name, "c_longdouble") or
     std.mem.eql(u8, name, "noreturn") or
     std.mem.eql(u8, name, "type") or
-    std.mem.eql(u8, name, "anyerror");
+    std.mem.eql(u8, name, "anyerror") or
+    std.mem.eql(u8, name, "c_short") or
+    std.mem.eql(u8, name, "c_ushort") or
+    std.mem.eql(u8, name, "c_int") or
+    std.mem.eql(u8, name, "c_uint") or
+    std.mem.eql(u8, name, "c_long") or
+    std.mem.eql(u8, name, "c_ulong") or
+    std.mem.eql(u8, name, "c_longlong") or
+    std.mem.eql(u8, name, "c_ulonglong");
 }
 
 fn isValidZigIdentifier(name: []const u8) bool {
test/translate_c.zig
@@ -849,6 +849,48 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\}
     });
 
+    cases.add_2("switch on int",
+        \\int switch_fn(int i) {
+        \\    int res = 0;
+        \\    switch (i) {
+        \\        case 0:
+        \\            res = 1;
+        \\        case 1:
+        \\            res = 2;
+        \\        default:
+        \\            res = 3 * i;
+        \\            break;
+        \\        case 2:
+        \\            res = 5;
+        \\    }
+        \\}
+    , &[_][]const u8{
+        \\pub export fn switch_fn(i: c_int) c_int {
+        \\    var res: c_int = 0;
+        \\    __switch: {
+        \\        __case_2: {
+        \\            __default: {
+        \\                __case_1: {
+        \\                    __case_0: {
+        \\                        switch (i) {
+        \\                            0 => break :__case_0,
+        \\                            1 => break :__case_1,
+        \\                            else => break :__default,
+        \\                            2 => break :__case_2,
+        \\                        }
+        \\                    }
+        \\                    res = 1;
+        \\                }
+        \\                res = 2;
+        \\            }
+        \\            res = (3 * i);
+        \\            break :__switch;
+        \\        }
+        \\        res = 5;
+        \\    }
+        \\}
+    });
+
     /////////////// Cases for only stage1 which are TODO items for stage2 ////////////////
 
     if (builtin.os != builtin.Os.windows) {
@@ -1938,48 +1980,6 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\}
     });
 
-    cases.add("switch on int",
-        \\int switch_fn(int i) {
-        \\    int res = 0;
-        \\    switch (i) {
-        \\        case 0:
-        \\            res = 1;
-        \\        case 1:
-        \\            res = 2;
-        \\        default:
-        \\            res = 3 * i;
-        \\            break;
-        \\        case 2:
-        \\            res = 5;
-        \\    }
-        \\}
-    , &[_][]const u8{
-        \\pub fn switch_fn(i: c_int) c_int {
-        \\    var res: c_int = 0;
-        \\    __switch: {
-        \\        __case_2: {
-        \\            __default: {
-        \\                __case_1: {
-        \\                    __case_0: {
-        \\                        switch (i) {
-        \\                            0 => break :__case_0,
-        \\                            1 => break :__case_1,
-        \\                            else => break :__default,
-        \\                            2 => break :__case_2,
-        \\                        }
-        \\                    }
-        \\                    res = 1;
-        \\                }
-        \\                res = 2;
-        \\            }
-        \\            res = (3 * i);
-        \\            break :__switch;
-        \\        }
-        \\        res = 5;
-        \\    }
-        \\}
-    });
-
     cases.addC("implicit casts",
         \\#include <stdbool.h>
         \\
@@ -2237,4 +2237,46 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
     , &[_][]const u8{
         \\pub const NRF_GPIO = if (@typeId(@TypeOf(NRF_GPIO_BASE)) == @import("builtin").TypeId.Pointer) @ptrCast([*c]NRF_GPIO_Type, NRF_GPIO_BASE) else if (@typeId(@TypeOf(NRF_GPIO_BASE)) == @import("builtin").TypeId.Int) @intToPtr([*c]NRF_GPIO_Type, NRF_GPIO_BASE) else @as([*c]NRF_GPIO_Type, NRF_GPIO_BASE);
     });
+
+    cases.add("switch on int",
+        \\int switch_fn(int i) {
+        \\    int res = 0;
+        \\    switch (i) {
+        \\        case 0:
+        \\            res = 1;
+        \\        case 1:
+        \\            res = 2;
+        \\        default:
+        \\            res = 3 * i;
+        \\            break;
+        \\        case 2:
+        \\            res = 5;
+        \\    }
+        \\}
+    , &[_][]const u8{
+        \\pub fn switch_fn(i: c_int) c_int {
+        \\    var res: c_int = 0;
+        \\    __switch: {
+        \\        __case_2: {
+        \\            __default: {
+        \\                __case_1: {
+        \\                    __case_0: {
+        \\                        switch (i) {
+        \\                            0 => break :__case_0,
+        \\                            1 => break :__case_1,
+        \\                            else => break :__default,
+        \\                            2 => break :__case_2,
+        \\                        }
+        \\                    }
+        \\                    res = 1;
+        \\                }
+        \\                res = 2;
+        \\            }
+        \\            res = (3 * i);
+        \\            break :__switch;
+        \\        }
+        \\        res = 5;
+        \\    }
+        \\}
+    });
 }