Commit ad32e46bce

Vexu <git@vexu.eu>
2020-10-11 22:52:08
stage2: switch astgen
1 parent a1d7f00
Changed files (3)
src/astgen.zig
@@ -183,6 +183,7 @@ pub fn expr(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node) InnerEr
         .VarDecl => unreachable, // Handled in `blockExpr`.
         .SwitchCase => unreachable, // Handled in `switchExpr`.
         .SwitchElse => unreachable, // Handled in `switchExpr`.
+        .Range => unreachable, // Handled in `switchExpr`.
         .Else => unreachable, // Handled explicitly the control flow expression functions.
         .Payload => unreachable, // Handled explicitly.
         .PointerPayload => unreachable, // Handled explicitly.
@@ -279,9 +280,9 @@ pub fn expr(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node) InnerEr
         .Catch => return catchExpr(mod, scope, rl, node.castTag(.Catch).?),
         .Comptime => return comptimeKeyword(mod, scope, rl, node.castTag(.Comptime).?),
         .OrElse => return orelseExpr(mod, scope, rl, node.castTag(.OrElse).?),
+        .Switch => return switchExpr(mod, scope, rl, node.castTag(.Switch).?),
 
         .Defer => return mod.failNode(scope, node, "TODO implement astgen.expr for .Defer", .{}),
-        .Range => return mod.failNode(scope, node, "TODO implement astgen.expr for .Range", .{}),
         .Await => return mod.failNode(scope, node, "TODO implement astgen.expr for .Await", .{}),
         .Resume => return mod.failNode(scope, node, "TODO implement astgen.expr for .Resume", .{}),
         .Try => return mod.failNode(scope, node, "TODO implement astgen.expr for .Try", .{}),
@@ -289,7 +290,6 @@ pub fn expr(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node) InnerEr
         .ArrayInitializerDot => return mod.failNode(scope, node, "TODO implement astgen.expr for .ArrayInitializerDot", .{}),
         .StructInitializer => return mod.failNode(scope, node, "TODO implement astgen.expr for .StructInitializer", .{}),
         .StructInitializerDot => return mod.failNode(scope, node, "TODO implement astgen.expr for .StructInitializerDot", .{}),
-        .Switch => return mod.failNode(scope, node, "TODO implement astgen.expr for .Switch", .{}),
         .Suspend => return mod.failNode(scope, node, "TODO implement astgen.expr for .Suspend", .{}),
         .Continue => return mod.failNode(scope, node, "TODO implement astgen.expr for .Continue", .{}),
         .AnyType => return mod.failNode(scope, node, "TODO implement astgen.expr for .AnyType", .{}),
@@ -1561,6 +1561,156 @@ fn forExpr(mod: *Module, scope: *Scope, rl: ResultLoc, for_node: *ast.Node.For)
     return &for_block.base;
 }
 
+fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node.Switch) InnerError!*zir.Inst {
+    var block_scope: Scope.GenZIR = .{
+        .parent = scope,
+        .decl = scope.decl().?,
+        .arena = scope.arena(),
+        .instructions = .{},
+    };
+    defer block_scope.instructions.deinit(mod.gpa);
+
+    const tree = scope.tree();
+    const switch_src = tree.token_locs[switch_node.switch_token].start;
+    const target_ptr = try expr(mod, &block_scope.base, .ref, switch_node.expr);
+    const cases = try scope.arena().alloc(zir.Inst.Switch.Case, switch_node.cases_len);
+    var kw_args: std.meta.fieldInfo(zir.Inst.Switch, "kw_args").field_type = .{};
+
+    // first we gather all the switch items and check else/'_' prongs
+    var case_index: usize = 0;
+    var else_src: ?usize = null;
+    var underscore_src: ?usize = null;
+    for (switch_node.cases()) |uncasted_case| {
+        const case = uncasted_case.castTag(.SwitchCase).?;
+        const case_src = tree.token_locs[case.firstToken()].start;
+
+        if (case.items_len == 1 and case.items()[0].tag == .SwitchElse) {
+            if (else_src) |src| {
+                return mod.fail(scope, case_src, "multiple else prongs in switch expression", .{});
+                // TODO notes "previous else prong is here"
+            }
+            kw_args.special_case = .@"else";
+            else_src = case_src;
+            cases[cases.len - 1] = .{
+                .values = &[_]*zir.Inst{},
+                .body = undefined, // filled below
+            };
+            continue;
+        } else if (case.items_len == 1 and case.items()[0].tag == .Identifier and
+            mem.eql(u8, tree.tokenSlice(case.items()[0].firstToken()), "_"))
+        {
+            if (underscore_src) |src| {
+                return mod.fail(scope, case_src, "multiple '_' prongs in switch expression", .{});
+                // TODO notes "previous '_' prong is here"
+            }
+            kw_args.special_case = .underscore;
+            underscore_src = case_src;
+            cases[cases.len - 1] = .{
+                .values = &[_]*zir.Inst{},
+                .body = undefined, // filled below
+            };
+            continue;
+        }
+
+        if (else_src) |some_else| {
+            if (underscore_src) |some_underscore| {
+                return mod.fail(scope, case_src, "else and '_' prong in switch expression", .{});
+                // TODO notes "else prong is here"
+                // TODO notes "'_' prong is here"
+            }
+        }
+
+        // Regular case, we need to fill `values`.
+        const values = try block_scope.arena.alloc(*zir.Inst, case.items_len);
+        for (case.items()) |item, i| {
+            if (item.castTag(.Range)) |range| {
+                values[i] = try switchRange(mod, &block_scope.base, range);
+                if (kw_args.support_range == null)
+                    kw_args.support_range = values[i];
+            } else {
+                values[i] = try expr(mod, &block_scope.base, .none, item);
+            }
+        }
+        cases[case_index] = .{
+            .values = values,
+            .body = undefined, // filled below
+        };
+        case_index += 1;
+    }
+
+    // Then we add the switch instruction to finish the block.
+    _ = try addZIRInst(mod, scope, switch_src, zir.Inst.Switch, .{
+        .target_ptr = target_ptr,
+        .cases = cases,
+    }, kw_args);
+    const block = try addZIRInstBlock(mod, scope, switch_src, .block, .{
+        .instructions = try block_scope.arena.dupe(*zir.Inst, block_scope.instructions.items),
+    });
+
+    // Most result location types can be forwarded directly; however
+    // if we need to write to a pointer which has an inferred type,
+    // proper type inference requires peer type resolution on the switch case.
+    const case_rl: ResultLoc = switch (rl) {
+        .discard, .none, .ty, .ptr, .ref => rl,
+        .inferred_ptr, .bitcasted_ptr, .block_ptr => .{ .block_ptr = block },
+    };
+
+    var case_scope: Scope.GenZIR = .{
+        .parent = scope,
+        .decl = block_scope.decl,
+        .arena = block_scope.arena,
+        .instructions = .{},
+    };
+    defer case_scope.instructions.deinit(mod.gpa);
+
+    // And finally we fill generate the bodies of each case.
+    case_index = 0;
+    for (switch_node.cases()) |uncasted_case| {
+        const case = uncasted_case.castTag(.SwitchCase).?;
+        const case_src = tree.token_locs[case.firstToken()].start;
+        // reset without freeing to reduce allocations.
+        defer case_scope.instructions.items.len = 0;
+
+        // What index in positionals.cases should this one be placed at.
+        // For special cases it will be at the end.
+        var cur_index = case_index;
+        if (case.items_len == 1 and case.items()[0].tag == .SwitchElse) {
+            // validated above
+            cur_index = cases.len - 1;
+        } else if (case.items_len == 1 and case.items()[0].tag == .Identifier and
+            mem.eql(u8, tree.tokenSlice(case.items()[0].firstToken()), "_"))
+        {
+            // validated above
+            cur_index = cases.len - 1;
+        }
+
+        // Generate the body of this case.
+        const case_body = try expr(mod, &case_scope.base, case_rl, case.expr);
+        if (!case_body.tag.isNoReturn()) {
+            _ = try addZIRInst(mod, &case_scope.base, case_src, zir.Inst.Break, .{
+                .block = block,
+                .operand = case_body,
+            }, .{});
+        }
+        cases[cur_index].body = .{
+            .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items),
+        };
+    }
+
+    return &block.base;
+}
+
+/// Only used for `a...b` in switches.
+fn switchRange(mod: *Module, scope: *Scope, node: *ast.Node.SimpleInfixOp) InnerError!*zir.Inst {
+    const tree = scope.tree();
+    const src = tree.token_locs[node.op_token].start;
+
+    const start = try expr(mod, scope, .none, node.lhs);
+    const end = try expr(mod, scope, .none, node.rhs);
+
+    return try addZIRBinOp(mod, scope, src, .switch_range, start, end);
+}
+
 fn ret(mod: *Module, scope: *Scope, cfe: *ast.Node.ControlFlowExpression) InnerError!*zir.Inst {
     const tree = scope.tree();
     const src = tree.token_locs[cfe.ltoken].start;
src/zir.zig
@@ -272,6 +272,10 @@ pub const Inst = struct {
         ensure_err_payload_void,
         /// Enum literal
         enum_literal,
+        /// A switch expression.
+        @"switch",
+        /// A range in a switch case, `lhs...rhs`.
+        switch_range,
 
         pub fn Type(tag: Tag) type {
             return switch (tag) {
@@ -351,6 +355,7 @@ pub const Inst = struct {
                 .error_union_type,
                 .merge_error_sets,
                 .slice_start,
+                .switch_range,
                 => BinOp,
 
                 .block,
@@ -389,6 +394,7 @@ pub const Inst = struct {
                 .enum_literal => EnumLiteral,
                 .error_set => ErrorSet,
                 .slice => Slice,
+                .@"switch" => Switch,
             };
         }
 
@@ -493,6 +499,7 @@ pub const Inst = struct {
                 .slice,
                 .slice_start,
                 .import,
+                .switch_range,
                 => false,
 
                 .@"break",
@@ -504,6 +511,7 @@ pub const Inst = struct {
                 .unreach_nocheck,
                 .@"unreachable",
                 .loop,
+                .@"switch",
                 => true,
             };
         }
@@ -987,6 +995,33 @@ pub const Inst = struct {
             sentinel: ?*Inst = null,
         },
     };
+
+    pub const Switch = struct {
+        pub const base_tag = Tag.@"switch";
+        base: Inst,
+
+        positionals: struct {
+            target_ptr: *Inst,
+            cases: []Case,
+        },
+        kw_args: struct {
+            /// if not null target must support ranges, (be int)
+            support_range: ?*Inst = null,
+            special_case: enum {
+                /// all of positionals.cases are regular cases
+                none,
+                /// last case in positionals.cases is an else case
+                @"else",
+                /// last case in positionals.cases is an underscore case
+                underscore,
+            } = .none,
+        },
+
+        pub const Case = struct {
+            values: []*Inst,
+            body: Module.Body,
+        };
+    };
 };
 
 pub const ErrorMsg = struct {
@@ -1238,6 +1273,26 @@ const Writer = struct {
                 }
                 try stream.writeByte(']');
             },
+            []Inst.Switch.Case => {
+                if (param.len == 0) {
+                    return stream.writeAll("{}");
+                }
+                try stream.writeAll("{\n");
+                self.indent += 2;
+                for (param) |*case, i| {
+                    if (i != 0) {
+                        try stream.writeAll(",\n");
+                    }
+                    try stream.writeByteNTimes(' ', self.indent);
+                    try self.writeParamToStream(stream, &case.values);
+                    try stream.writeAll(" => ");
+                    try self.writeParamToStream(stream, &case.body);
+                }
+                try stream.writeByte('\n');
+                self.indent -= 2;
+                try stream.writeByteNTimes(' ', self.indent);
+                try stream.writeByte('}');
+            },
             else => |T| @compileError("unimplemented: rendering parameter of type " ++ @typeName(T)),
         }
     }
@@ -1650,6 +1705,26 @@ const Parser = struct {
                 try requireEatBytes(self, "]");
                 return strings.toOwnedSlice();
             },
+            []Inst.Switch.Case => {
+                try requireEatBytes(self, "{");
+                skipSpace(self);
+                if (eatByte(self, '}')) return &[0]Inst.Switch.Case{};
+
+                var cases = std.ArrayList(Inst.Switch.Case).init(&self.arena.allocator);
+                while (true) {
+                    const cur = try cases.addOne();
+                    skipSpace(self);
+                    cur.values = try self.parseParameterGeneric([]*Inst, body_ctx);
+                    skipSpace(self);
+                    try requireEatBytes(self, "=>");
+                    cur.body = try self.parseBody(body_ctx);
+                    skipSpace(self);
+                    if (!eatByte(self, ',')) break;
+                }
+                skipSpace(self);
+                try requireEatBytes(self, "}");
+                return cases.toOwnedSlice();
+            },
             else => @compileError("Unimplemented: ir parseParameterGeneric for type " ++ @typeName(T)),
         }
         return self.fail("TODO parse parameter {}", .{@typeName(T)});
src/zir_sema.zig
@@ -135,6 +135,7 @@ pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!
         .slice => return analyzeInstSlice(mod, scope, old_inst.castTag(.slice).?),
         .slice_start => return analyzeInstSliceStart(mod, scope, old_inst.castTag(.slice_start).?),
         .import => return analyzeInstImport(mod, scope, old_inst.castTag(.import).?),
+        .@"switch", .switch_range => @panic("TODO switch sema"),
     }
 }