Commit 51b2f1b80b
Changed files (2)
test
src/translate_c.cpp
@@ -104,6 +104,7 @@ static TransScopeRoot *trans_scope_root_create(Context *c);
static TransScopeWhile *trans_scope_while_create(Context *c, TransScope *parent_scope);
static TransScopeBlock *trans_scope_block_create(Context *c, TransScope *parent_scope);
static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scope, Buf *wanted_name);
+static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope);
static TransScopeBlock *trans_scope_block_find(TransScope *scope);
@@ -2527,6 +2528,155 @@ static AstNode *trans_for_loop(Context *c, TransScope *parent_scope, const ForSt
return loop_block_node;
}
+static AstNode *trans_switch_stmt(Context *c, TransScope *parent_scope, const SwitchStmt *stmt) {
+ TransScopeBlock *block_scope = trans_scope_block_create(c, parent_scope);
+
+ TransScopeSwitch *switch_scope;
+
+ const DeclStmt *var_decl_stmt = stmt->getConditionVariableDeclStmt();
+ if (var_decl_stmt == nullptr) {
+ switch_scope = trans_scope_switch_create(c, &block_scope->base);
+ } else {
+ AstNode *vars_node;
+ TransScope *var_scope = trans_stmt(c, &block_scope->base, var_decl_stmt, &vars_node);
+ if (var_scope == nullptr)
+ return nullptr;
+ if (vars_node != nullptr)
+ block_scope->node->data.block.statements.append(vars_node);
+ switch_scope = trans_scope_switch_create(c, var_scope);
+ }
+ block_scope->node->data.block.statements.append(switch_scope->switch_node);
+
+ // TODO avoid name collisions
+ Buf *end_label_name = buf_create_from_str("__switch");
+ switch_scope->end_label_name = end_label_name;
+ block_scope->node->data.block.name = end_label_name;
+
+ const Expr *cond_expr = stmt->getCond();
+ assert(cond_expr != nullptr);
+
+ AstNode *expr_node = trans_expr(c, ResultUsedYes, &block_scope->base, cond_expr, TransRValue);
+ if (expr_node == nullptr)
+ return nullptr;
+ switch_scope->switch_node->data.switch_expr.expr = expr_node;
+
+ AstNode *body_node;
+ const Stmt *body_stmt = stmt->getBody();
+ if (body_stmt->getStmtClass() == Stmt::CompoundStmtClass) {
+ if (trans_compound_stmt_inline(c, &switch_scope->base, (const CompoundStmt *)body_stmt,
+ block_scope->node, nullptr))
+ {
+ return nullptr;
+ }
+ } else {
+ TransScope *body_scope = trans_stmt(c, &switch_scope->base, body_stmt, &body_node);
+ if (body_scope == nullptr)
+ return nullptr;
+ if (body_node != nullptr)
+ block_scope->node->data.block.statements.append(body_node);
+ }
+
+ if (!switch_scope->found_default && !stmt->isAllEnumCasesCovered()) {
+ AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng);
+ prong_node->data.switch_prong.expr = trans_create_node_break(c, end_label_name, nullptr);
+ switch_scope->switch_node->data.switch_expr.prongs.append(prong_node);
+ }
+
+ return block_scope->node;
+}
+
+static TransScopeSwitch *trans_scope_switch_find(TransScope *scope) {
+ while (scope != nullptr) {
+ if (scope->id == TransScopeIdSwitch) {
+ return (TransScopeSwitch *)scope;
+ }
+ scope = scope->parent;
+ }
+ return nullptr;
+}
+
+static int trans_switch_case(Context *c, TransScope *parent_scope, const CaseStmt *stmt, AstNode **out_node,
+ TransScope **out_scope) {
+ *out_node = nullptr;
+
+ if (stmt->getRHS() != nullptr) {
+ emit_warning(c, stmt->getLocStart(), "TODO support GNU switch case a ... b extension");
+ return ErrorUnexpected;
+ }
+
+ TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope);
+ assert(switch_scope != nullptr);
+
+ Buf *label_name = buf_sprintf("__case_%" PRIu32, switch_scope->case_index);
+ switch_scope->case_index += 1;
+
+ {
+ // Add the prong
+ AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng);
+ AstNode *item_node = trans_expr(c, ResultUsedYes, &switch_scope->base, stmt->getLHS(), TransRValue);
+ if (item_node == nullptr)
+ return ErrorUnexpected;
+ prong_node->data.switch_prong.items.append(item_node);
+ prong_node->data.switch_prong.expr = trans_create_node_break(c, label_name, nullptr);
+ switch_scope->switch_node->data.switch_expr.prongs.append(prong_node);
+ }
+
+ TransScopeBlock *scope_block = trans_scope_block_find(parent_scope);
+
+ AstNode *case_block = trans_create_node(c, NodeTypeBlock);
+ case_block->data.block.name = label_name;
+ case_block->data.block.statements = scope_block->node->data.block.statements;
+ scope_block->node->data.block.statements = {0};
+ scope_block->node->data.block.statements.append(case_block);
+
+ AstNode *sub_stmt_node;
+ TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node);
+ if (new_scope == nullptr)
+ return ErrorUnexpected;
+ if (sub_stmt_node != nullptr)
+ scope_block->node->data.block.statements.append(sub_stmt_node);
+
+ *out_scope = new_scope;
+ return ErrorNone;
+}
+
+static int trans_switch_default(Context *c, TransScope *parent_scope, const DefaultStmt *stmt, AstNode **out_node,
+ TransScope **out_scope)
+{
+ *out_node = nullptr;
+
+ TransScopeSwitch *switch_scope = trans_scope_switch_find(parent_scope);
+ assert(switch_scope != nullptr);
+
+ Buf *label_name = buf_sprintf("__default");
+
+ {
+ // Add the prong
+ AstNode *prong_node = trans_create_node(c, NodeTypeSwitchProng);
+ prong_node->data.switch_prong.expr = trans_create_node_break(c, label_name, nullptr);
+ switch_scope->switch_node->data.switch_expr.prongs.append(prong_node);
+ switch_scope->found_default = true;
+ }
+
+ TransScopeBlock *scope_block = trans_scope_block_find(parent_scope);
+
+ AstNode *case_block = trans_create_node(c, NodeTypeBlock);
+ case_block->data.block.name = label_name;
+ case_block->data.block.statements = scope_block->node->data.block.statements;
+ scope_block->node->data.block.statements = {0};
+ scope_block->node->data.block.statements.append(case_block);
+
+ AstNode *sub_stmt_node;
+ TransScope *new_scope = trans_stmt(c, parent_scope, stmt->getSubStmt(), &sub_stmt_node);
+ if (new_scope == nullptr)
+ return ErrorUnexpected;
+ if (sub_stmt_node != nullptr)
+ scope_block->node->data.block.statements.append(sub_stmt_node);
+
+ *out_scope = new_scope;
+ return ErrorNone;
+}
+
static AstNode *trans_string_literal(Context *c, TransScope *scope, const StringLiteral *stmt) {
switch (stmt->getKind()) {
case StringLiteral::Ascii:
@@ -2551,7 +2701,8 @@ static AstNode *trans_break_stmt(Context *c, TransScope *scope, const BreakStmt
if (cur_scope->id == TransScopeIdWhile) {
return trans_create_node(c, NodeTypeBreak);
} else if (cur_scope->id == TransScopeIdSwitch) {
- zig_panic("TODO");
+ TransScopeSwitch *switch_scope = (TransScopeSwitch *)cur_scope;
+ return trans_create_node_break(c, switch_scope->end_label_name, nullptr);
}
cur_scope = cur_scope->parent;
}
@@ -2651,14 +2802,12 @@ static int trans_stmt_extra(Context *c, TransScope *scope, const Stmt *stmt,
return wrap_stmt(out_node, out_child_scope, scope,
trans_expr(c, result_used, scope, ((const ParenExpr*)stmt)->getSubExpr(), lrvalue));
case Stmt::SwitchStmtClass:
- emit_warning(c, stmt->getLocStart(), "TODO handle C SwitchStmtClass");
- return ErrorUnexpected;
+ return wrap_stmt(out_node, out_child_scope, scope,
+ trans_switch_stmt(c, scope, (const SwitchStmt *)stmt));
case Stmt::CaseStmtClass:
- emit_warning(c, stmt->getLocStart(), "TODO handle C CaseStmtClass");
- return ErrorUnexpected;
+ return trans_switch_case(c, scope, (const CaseStmt *)stmt, out_node, out_child_scope);
case Stmt::DefaultStmtClass:
- emit_warning(c, stmt->getLocStart(), "TODO handle C DefaultStmtClass");
- return ErrorUnexpected;
+ return trans_switch_default(c, scope, (const DefaultStmt *)stmt, out_node, out_child_scope);
case Stmt::NoStmtClass:
emit_warning(c, stmt->getLocStart(), "TODO handle C NoStmtClass");
return ErrorUnexpected;
@@ -3828,6 +3977,14 @@ static TransScopeVar *trans_scope_var_create(Context *c, TransScope *parent_scop
return result;
}
+static TransScopeSwitch *trans_scope_switch_create(Context *c, TransScope *parent_scope) {
+ TransScopeSwitch *result = allocate<TransScopeSwitch>(1);
+ result->base.id = TransScopeIdSwitch;
+ result->base.parent = parent_scope;
+ result->switch_node = trans_create_node(c, NodeTypeSwitchExpr);
+ return result;
+}
+
static TransScopeBlock *trans_scope_block_find(TransScope *scope) {
while (scope != nullptr) {
if (scope->id == TransScopeIdBlock) {
test/translate_c.zig
@@ -1197,4 +1197,46 @@ pub fn addCases(cases: &tests.TranslateCContext) void {
\\ }
\\}
);
+
+ cases.add("for on int",
+ \\int switch_fn(int i) {
+ \\ int res = 0;
+ \\ switch (i) {
+ \\ case 0:
+ \\ res = 1;
+ \\ case 1:
+ \\ res = 2;
+ \\ default:
+ \\ res = 3 * i;
+ \\ break;
+ \\ case 2:
+ \\ res = 5;
+ \\ }
+ \\}
+ ,
+ \\pub fn switch_fn(i: c_int) c_int {
+ \\ var res: c_int = 0;
+ \\ __switch: {
+ \\ __case_2: {
+ \\ __default: {
+ \\ __case_1: {
+ \\ __case_0: {
+ \\ switch (i) {
+ \\ 0 => break :__case_0,
+ \\ 1 => break :__case_1,
+ \\ else => break :__default,
+ \\ 2 => break :__case_2,
+ \\ }
+ \\ }
+ \\ res = 1;
+ \\ }
+ \\ res = 2;
+ \\ }
+ \\ res = (3 * i);
+ \\ break :__switch;
+ \\ }
+ \\ res = 5;
+ \\ }
+ \\}
+ );
}