Commit 51b2f1b80b

Jimmi Holst Christensen <jhc@liab.dk>
2018-03-08 10:29:29
Translate C can now translate switch statements again
1 parent bb80daf
Changed files (2)
src/translate_c.cpp
@@ -104,6 +104,7 @@ static TransScopeRoot *trans_scope_root_create(Context *c);
 static TransScopeWhile *trans_scope_while_create(Context *c, TransScope *parent_scope);
 static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope);
 static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scope, Buf *wanted_name);
+static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope);
 
 static TransScopeBlock *trans_scope_block_find(TransScope *scope);
 
@@ -2527,6 +2528,155 @@ static AstNode *trans_for_loop(Context *c, TransScope *parent_scope, const ForSt
     return loop_block_node;
 }
 
+static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const SwitchStmt *stmt) {
+    TransScopeBlock *block_scope = trans_scope_block_create(c, parent_scope);
+
+    TransScopeSwitch *switch_scope;
+
+    const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt();
+    if (var_decl_stmt == nullptr) {
+        switch_scope = trans_scope_switch_create(c, &block_scope->base);
+    } else {
+        AstNode *vars_node;
+        TransScope *var_scope = trans_stmt(c, &block_scope->base, var_decl_stmt, &vars_node);
+        if (var_scope == nullptr)
+            return nullptr;
+        if (vars_node != nullptr)
+            block_scope->node->data.block.statements.append(vars_node);
+        switch_scope = trans_scope_switch_create(c, var_scope);
+    }
+    block_scope->node->data.block.statements.append(switch_scope->switch_node);
+
+    // TODO avoid name collisions
+    Buf *end_label_name = buf_create_from_str("__switch");
+    switch_scope->end_label_name = end_label_name;
+    block_scope->node->data.block.name = end_label_name;
+
+    const Expr *cond_expr = stmt->getCond();
+    assert(cond_expr != nullptr);
+
+    AstNode *expr_node = trans_expr(c, ResultUsedYes, &block_scope->base, cond_expr, TransRValue);
+    if (expr_node == nullptr)
+        return nullptr;
+    switch_scope->switch_node->data.switch_expr.expr = expr_node;
+
+    AstNode *body_node;
+    const Stmt *body_stmt = stmt->getBody();
+    if (body_stmt->getStmtClass() == Stmt::CompoundStmtClass) {
+        if (trans_compound_stmt_inline(c, &switch_scope->base, (const CompoundStmt *)body_stmt,
+                                       block_scope->node, nullptr))
+        {
+            return nullptr;
+        }
+    } else {
+        TransScope *body_scope = trans_stmt(c, &switch_scope->base, body_stmt, &body_node);
+        if (body_scope == nullptr)
+            return nullptr;
+        if (body_node != nullptr)
+            block_scope->node->data.block.statements.append(body_node);
+    }
+
+    if (!switch_scope->found_default && !stmt->isAllEnumCasesCovered()) {
+        AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng);
+        prong_node->data.switch_prong.expr = trans_create_node_break(c, end_label_name, nullptr);
+        switch_scope->switch_node->data.switch_expr.prongs.append(prong_node);
+    }
+
+    return block_scope->node;
+}
+
+static TransScopeSwitch *trans_scope_switch_find(TransScope *scope) {
+    while (scope != nullptr) {
+        if (scope->id == TransScopeIdSwitch) {
+            return (TransScopeSwitch *)scope;
+        }
+        scope = scope->parent;
+    }
+    return nullptr;
+}
+
+static int trans_switch_case(Context *c, TransScope *parent_scope, const CaseStmt *stmt, AstNode **out_node,
+                             TransScope **out_scope) {
+    *out_node = nullptr;
+
+    if (stmt->getRHS() != nullptr) {
+        emit_warning(c, stmt->getLocStart(), "TODO support GNU switch case a ... b extension");
+        return ErrorUnexpected;
+    }
+
+    TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope);
+    assert(switch_scope != nullptr);
+
+    Buf *label_name = buf_sprintf("__case_%" PRIu32, switch_scope->case_index);
+    switch_scope->case_index += 1;
+
+    {
+        // Add the prong
+        AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng);
+        AstNode *item_node = trans_expr(c, ResultUsedYes, &switch_scope->base, stmt->getLHS(), TransRValue);
+        if (item_node == nullptr)
+            return ErrorUnexpected;
+        prong_node->data.switch_prong.items.append(item_node);
+        prong_node->data.switch_prong.expr = trans_create_node_break(c, label_name, nullptr);
+        switch_scope->switch_node->data.switch_expr.prongs.append(prong_node);
+    }
+
+    TransScopeBlock *scope_block = trans_scope_block_find(parent_scope);
+
+    AstNode *case_block = trans_create_node(c, NodeTypeBlock);
+    case_block->data.block.name = label_name;
+    case_block->data.block.statements = scope_block->node->data.block.statements;
+    scope_block->node->data.block.statements = {0};
+    scope_block->node->data.block.statements.append(case_block);
+
+    AstNode *sub_stmt_node;
+    TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node);
+    if (new_scope == nullptr)
+        return ErrorUnexpected;
+    if (sub_stmt_node != nullptr)
+        scope_block->node->data.block.statements.append(sub_stmt_node);
+
+    *out_scope = new_scope;
+    return ErrorNone;
+}
+
+static int trans_switch_default(Context *c, TransScope *parent_scope, const DefaultStmt *stmt, AstNode **out_node,
+                                TransScope **out_scope)
+{
+    *out_node = nullptr;
+
+    TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope);
+    assert(switch_scope != nullptr);
+
+    Buf *label_name = buf_sprintf("__default");
+
+    {
+        // Add the prong
+        AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng);
+        prong_node->data.switch_prong.expr = trans_create_node_break(c, label_name, nullptr);
+        switch_scope->switch_node->data.switch_expr.prongs.append(prong_node);
+        switch_scope->found_default = true;
+    }
+
+    TransScopeBlock *scope_block = trans_scope_block_find(parent_scope);
+
+    AstNode *case_block = trans_create_node(c, NodeTypeBlock);
+    case_block->data.block.name = label_name;
+    case_block->data.block.statements = scope_block->node->data.block.statements;
+    scope_block->node->data.block.statements = {0};
+    scope_block->node->data.block.statements.append(case_block);
+
+    AstNode *sub_stmt_node;
+    TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node);
+    if (new_scope == nullptr)
+        return ErrorUnexpected;
+    if (sub_stmt_node != nullptr)
+        scope_block->node->data.block.statements.append(sub_stmt_node);
+
+    *out_scope = new_scope;
+    return ErrorNone;
+}
+
 static AstNode *trans_string_literal(Context *c, TransScope *scope, const StringLiteral *stmt) {
     switch (stmt->getKind()) {
         case StringLiteral::Ascii:
@@ -2551,7 +2701,8 @@ static AstNode *trans_break_stmt(Context *c, TransScope *scope, const BreakStmt
         if (cur_scope->id == TransScopeIdWhile) {
             return trans_create_node(c, NodeTypeBreak);
         } else if (cur_scope->id == TransScopeIdSwitch) {
-            zig_panic("TODO");
+            TransScopeSwitch *switch_scope = (TransScopeSwitch *)cur_scope;
+            return trans_create_node_break(c, switch_scope->end_label_name, nullptr);
         }
         cur_scope = cur_scope->parent;
     }
@@ -2651,14 +2802,12 @@ static int trans_stmt_extra(Context *c, TransScope *scope, const Stmt *stmt,
             return wrap_stmt(out_node, out_child_scope, scope,
                     trans_expr(c, result_used, scope, ((const ParenExpr*)stmt)->getSubExpr(), lrvalue));
         case Stmt::SwitchStmtClass:
-            emit_warning(c, stmt->getLocStart(), "TODO handle C SwitchStmtClass");
-            return ErrorUnexpected;
+            return wrap_stmt(out_node, out_child_scope, scope,
+                             trans_switch_stmt(c, scope, (const SwitchStmt *)stmt));
         case Stmt::CaseStmtClass:
-            emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass");
-            return ErrorUnexpected;
+            return trans_switch_case(c, scope, (const CaseStmt *)stmt, out_node, out_child_scope);
         case Stmt::DefaultStmtClass:
-            emit_warning(c, stmt->getLocStart(), "TODO handle C DefaultStmtClass");
-            return ErrorUnexpected;
+            return trans_switch_default(c, scope, (const DefaultStmt *)stmt, out_node, out_child_scope);
         case Stmt::NoStmtClass:
             emit_warning(c, stmt->getLocStart(), "TODO handle C NoStmtClass");
             return ErrorUnexpected;
@@ -3828,6 +3977,14 @@ static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scop
     return result;
 }
 
+static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope) {
+    TransScopeSwitch *result = allocate<TransScopeSwitch>(1);
+    result->base.id = TransScopeIdSwitch;
+    result->base.parent = parent_scope;
+    result->switch_node = trans_create_node(c, NodeTypeSwitchExpr);
+    return result;
+}
+
 static TransScopeBlock *trans_scope_block_find(TransScope *scope) {
     while (scope != nullptr) {
         if (scope->id == TransScopeIdBlock) {
test/translate_c.zig
@@ -1197,4 +1197,46 @@ pub fn addCases(cases: &tests.TranslateCContext) void {
        \\    }
        \\}
     );
+
+    cases.add("for on int",
+        \\int switch_fn(int i) {
+        \\    int res = 0;
+        \\    switch (i) {
+        \\        case 0:
+        \\            res = 1;
+        \\        case 1:
+        \\            res = 2;
+        \\        default:
+        \\            res = 3 * i;
+        \\            break;
+        \\        case 2:
+        \\            res = 5;
+        \\    }
+        \\}
+    ,
+       \\pub fn switch_fn(i: c_int) c_int {
+       \\    var res: c_int = 0;
+       \\    __switch: {
+       \\        __case_2: {
+       \\            __default: {
+       \\                __case_1: {
+       \\                    __case_0: {
+       \\                        switch (i) {
+       \\                            0 => break :__case_0,
+       \\                            1 => break :__case_1,
+       \\                            else => break :__default,
+       \\                            2 => break :__case_2,
+       \\                        }
+       \\                    }
+       \\                    res = 1;
+       \\                }
+       \\                res = 2;
+       \\            }
+       \\            res = (3 * i);
+       \\            break :__switch;
+       \\        }
+       \\        res = 5;
+       \\    }
+       \\}
+    );
 }