Commit 7db17a2d89

Vexu <git@vexu.eu>
2020-10-16 16:01:05
stage2: redesign switchbr
Switchbr now only handles single item prongs. Ranges and multi item prongs are checked with condbrs after the switchbr.
1 parent 95467f3
src/astgen.zig
@@ -1570,16 +1570,33 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
     };
     defer block_scope.instructions.deinit(mod.gpa);
 
+    var item_scope: Scope.GenZIR = .{
+        .parent = scope,
+        .decl = scope.decl().?,
+        .arena = scope.arena(),
+        .instructions = .{},
+    };
+    defer item_scope.instructions.deinit(mod.gpa);
+
     const tree = scope.tree();
     const switch_src = tree.token_locs[switch_node.switch_token].start;
     const target_ptr = try expr(mod, &block_scope.base, .ref, switch_node.expr);
-    const cases = try scope.arena().alloc(zir.Inst.SwitchBr.Case, switch_node.cases_len);
-    var kw_args: std.meta.fieldInfo(zir.Inst.SwitchBr, "kw_args").field_type = .{};
+    // Add the switch instruction here so that it comes before any range checks.
+    const switch_inst = (try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.SwitchBr, .{
+        .target_ptr = target_ptr,
+        .cases = undefined, // populated below
+        .items = &[_]*zir.Inst{}, // populated below
+    }, .{})).castTag(.switchbr).?;
+
+    var items = std.ArrayList(*zir.Inst).init(mod.gpa);
+    defer items.deinit();
+    var cases = std.ArrayList(zir.Inst.SwitchBr.Case).init(mod.gpa);
+    defer cases.deinit();
 
     // first we gather all the switch items and check else/'_' prongs
-    var case_index: usize = 0;
     var else_src: ?usize = null;
     var underscore_src: ?usize = null;
+    var range_inst: ?*zir.Inst = null;
     for (switch_node.cases()) |uncasted_case| {
         const case = uncasted_case.castTag(.SwitchCase).?;
         const case_src = tree.token_locs[case.firstToken()].start;
@@ -1593,12 +1610,7 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
                 return mod.fail(scope, case_src, "multiple else prongs in switch expression", .{});
                 // TODO notes "previous else prong is here"
             }
-            kw_args.special_case = .@"else";
             else_src = case_src;
-            cases[cases.len - 1] = .{
-                .items = &[0]*zir.Inst{},
-                .body = undefined, // filled below
-            };
             continue;
         } else if (case.items_len == 1 and case.items()[0].tag == .Identifier and
             mem.eql(u8, tree.tokenSlice(case.items()[0].firstToken()), "_"))
@@ -1607,48 +1619,44 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
                 return mod.fail(scope, case_src, "multiple '_' prongs in switch expression", .{});
                 // TODO notes "previous '_' prong is here"
             }
-            kw_args.special_case = .underscore;
             underscore_src = case_src;
-            cases[cases.len - 1] = .{
-                .items = &[0]*zir.Inst{},
-                .body = undefined, // filled below
-            };
             continue;
         }
 
         if (else_src) |some_else| {
             if (underscore_src) |some_underscore| {
-                return mod.fail(scope, case_src, "else and '_' prong in switch expression", .{});
+                return mod.fail(scope, switch_src, "else and '_' prong in switch expression", .{});
                 // TODO notes "else prong is here"
                 // TODO notes "'_' prong is here"
             }
         }
 
-        // Regular case, we need to fill `items`.
-        const items = try block_scope.arena.alloc(*zir.Inst, case.items_len);
-        for (case.items()) |item, i| {
-            if (item.castTag(.Range)) |range| {
-                items[i] = try switchRange(mod, &block_scope.base, range);
-                if (kw_args.support_range == null)
-                    kw_args.support_range = items[i];
-            } else {
-                items[i] = try expr(mod, &block_scope.base, .none, item);
-            }
+        // TODO and not range
+        if (case.items_len == 1) {
+            const item = try expr(mod, &item_scope.base, .none, case.items()[0]);
+            try cases.append(.{
+                .item = item,
+                .body = undefined, // populated below
+            });
+            continue;
         }
-        cases[case_index] = .{
-            .items = items,
-            .body = undefined, // filled below
-        };
-        case_index += 1;
+        return mod.fail(scope, case_src, "TODO switch ranges", .{});
     }
 
-    // Then we add the switch instruction to finish the block.
-    _ = try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.SwitchBr, .{
-        .target_ptr = target_ptr,
-        .cases = cases,
-    }, kw_args);
+    // Actually populate switch instruction values.
+    if (else_src != null) switch_inst.kw_args.special_prong = .@"else";
+    if (underscore_src != null) switch_inst.kw_args.special_prong = .underscore;
+    switch_inst.positionals.cases = try block_scope.arena.dupe(zir.Inst.SwitchBr.Case, cases.items);
+    switch_inst.positionals.items = try block_scope.arena.dupe(*zir.Inst, items.items);
+    switch_inst.kw_args.range = range_inst;
+
+    // Add comptime block containing all prong items first,
+    _ = try addZIRInstBlock(mod, scope, switch_src, .block_comptime_flat, .{
+        .instructions = try block_scope.arena.dupe(*zir.Inst, item_scope.instructions.items),
+    });
+    // then add block containing the switch.
     const block = try addZIRInstBlock(mod, scope, switch_src, .block, .{
-        .instructions = try block_scope.arena.dupe(*zir.Inst, block_scope.instructions.items),
+        .instructions = undefined, // populated below
     });
 
     // Most result location types can be forwarded directly; however
@@ -1668,39 +1676,64 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
     defer case_scope.instructions.deinit(mod.gpa);
 
     // And finally we fill generate the bodies of each case.
-    case_index = 0;
+    var case_index: usize = 0;
+    var special_case: ?*ast.Node.SwitchCase = null;
     for (switch_node.cases()) |uncasted_case| {
         const case = uncasted_case.castTag(.SwitchCase).?;
         const case_src = tree.token_locs[case.firstToken()].start;
         // reset without freeing to reduce allocations.
         defer case_scope.instructions.items.len = 0;
 
-        // What index in positionals.cases should this one be placed at.
-        // For special cases it will be at the end.
-        var cur_index = case_index;
         if (case.items_len == 1 and case.items()[0].tag == .SwitchElse) {
-            // validated above
-            cur_index = cases.len - 1;
+            // validated earlier
+            special_case = case;
+            continue;
         } else if (case.items_len == 1 and case.items()[0].tag == .Identifier and
             mem.eql(u8, tree.tokenSlice(case.items()[0].firstToken()), "_"))
         {
-            // validated above
-            cur_index = cases.len - 1;
+            // validated earlier
+            special_case = case;
+            continue;
         }
 
-        // Generate the body of this case.
-        const case_body = try expr(mod, &case_scope.base, case_rl, case.expr);
+        if (case.items_len == 1) {
+            // Generate the body of this case.
+            const case_body = try expr(mod, &case_scope.base, case_rl, case.expr);
+            if (!case_body.tag.isNoReturn()) {
+                _ = try addZIRInst(mod, &case_scope.base, case_src, zir.Inst.Break, .{
+                    .block = block,
+                    .operand = case_body,
+                }, .{});
+            }
+            switch_inst.positionals.cases[case_index].body = .{
+                .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items),
+            };
+            case_index += 1;
+            continue;
+        }
+        return mod.fail(scope, case_src, "TODO switch ranges", .{});
+    }
+
+    // Generate else block or a break last to finish the block.
+    if (special_case) |case| {
+        const case_src = tree.token_locs[case.firstToken()].start;
+        const case_body = try expr(mod, &block_scope.base, case_rl, case.expr);
         if (!case_body.tag.isNoReturn()) {
-            _ = try addZIRInst(mod, &case_scope.base, case_src, zir.Inst.Break, .{
+            _ = try addZIRInst(mod, &block_scope.base, case_src, zir.Inst.Break, .{
                 .block = block,
                 .operand = case_body,
             }, .{});
         }
-        cases[cur_index].body = .{
-            .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items),
-        };
+    } else {
+        _ = try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.BreakVoid, .{
+            .block = block,
+        }, .{});
     }
 
+    // Set block instructions now that it is finished.
+    block.positionals.body = .{
+        .instructions = try block_scope.arena.dupe(*zir.Inst, block_scope.instructions.items),
+    };
     return &block.base;
 }
 
src/ir.zig
@@ -467,15 +467,12 @@ pub const Inst = struct {
         base: Inst,
         target_ptr: *Inst,
         cases: []Case,
-        @"else": ?Body,
         /// Set of instructions whose lifetimes end at the start of one of the cases.
         /// In same order as cases, deaths[0..case_0_count, case_0_count .. case_1_count, ... , case_n_count ... else_count].
         deaths: [*]*Inst = undefined,
-        else_index: u32 = 0,
-        else_deaths: u32 = 0,
 
         pub const Case = struct {
-            items: []Value,
+            item: Value,
             body: Body,
             index: u32 = 0,
             deaths: u32 = 0,
@@ -497,9 +494,6 @@ pub const Inst = struct {
             const case = self.cases[case_index];
             return (self.deaths + case.index)[0..case.deaths];
         }
-        pub fn elseDeaths(self: *const SwitchBr) []*Inst {
-            return (self.deaths + self.else_deaths)[0..self.else_deaths];
-        }
     };
 };
 
src/Module.zig
@@ -2122,18 +2122,16 @@ pub fn addSwitchBr(
     src: usize,
     target_ptr: *Inst,
     cases: []Inst.SwitchBr.Case,
-    else_body: ?Module.Body,
 ) !*Inst {
     const inst = try block.arena.create(Inst.SwitchBr);
     inst.* = .{
         .base = .{
             .tag = .switchbr,
-            .ty = Type.initTag(.noreturn),
+            .ty = Type.initTag(.void),
             .src = src,
         },
         .target_ptr = target_ptr,
         .cases = cases,
-        .@"else" = else_body,
     };
     try block.instructions.append(self.gpa, &inst.base);
     return &inst.base;
src/zir.zig
@@ -501,7 +501,7 @@ pub const Inst = struct {
                 .slice,
                 .slice_start,
                 .import,
-                .switch_range,
+                .switchbr,
                 => false,
 
                 .@"break",
@@ -513,7 +513,7 @@ pub const Inst = struct {
                 .unreach_nocheck,
                 .@"unreachable",
                 .loop,
-                .switchbr,
+                .switch_range,
                 => true,
             };
         }
@@ -1005,22 +1005,21 @@ pub const Inst = struct {
         positionals: struct {
             target_ptr: *Inst,
             cases: []Case,
+            /// List of all individual items and ranges
+            items: []*Inst,
         },
         kw_args: struct {
-            /// if not null target must support ranges, (be int)
-            support_range: ?*Inst = null,
-            special_case: enum {
-                /// all of positionals.cases are regular cases
+            /// Pointer to first range if such exists.
+            range: ?*Inst = null,
+            special_prong: enum {
                 none,
-                /// last case in positionals.cases is an else case
                 @"else",
-                /// last case in positionals.cases is an underscore case
                 underscore,
             } = .none,
         },
 
         pub const Case = struct {
-            items: []*Inst,
+            item: *Inst,
             body: Module.Body,
         };
     };
@@ -1286,7 +1285,7 @@ const Writer = struct {
                     }
                     try stream.writeByteNTimes(' ', self.indent);
                     self.indent += 2;
-                    try self.writeParamToStream(stream, &case.items);
+                    try self.writeParamToStream(stream, &case.item);
                     try stream.writeAll(" => ");
                     try self.writeParamToStream(stream, &case.body);
                     self.indent -= 2;
@@ -1716,7 +1715,7 @@ const Parser = struct {
                 while (true) {
                     const cur = try cases.addOne();
                     skipSpace(self);
-                    cur.items = try self.parseParameterGeneric([]*Inst, body_ctx);
+                    cur.item = try self.parseParameterGeneric(*Inst, body_ctx);
                     skipSpace(self);
                     try requireEatBytes(self, "=>");
                     cur.body = try self.parseBody(body_ctx);
@@ -2549,8 +2548,7 @@ const EmitZIR = struct {
                 },
                 .switchbr => blk: {
                     const old_inst = inst.castTag(.switchbr).?;
-                    const case_count = old_inst.cases.len + @boolToInt(old_inst.@"else" != null);
-                    const cases = try self.arena.allocator.alloc(Inst.SwitchBr.Case, case_count);
+                    const cases = try self.arena.allocator.alloc(Inst.SwitchBr.Case, old_inst.cases.len);
                     const new_inst = try self.arena.allocator.create(Inst.SwitchBr);
                     new_inst.* = .{
                         .base = .{
@@ -2560,11 +2558,9 @@ const EmitZIR = struct {
                         .positionals = .{
                             .target_ptr = try self.resolveInst(new_body, old_inst.target_ptr),
                             .cases = cases,
+                            .items = &[_]*Inst{}, // TODO this should actually be populated
                         },
-                        .kw_args = .{
-                            .special_case = if (old_inst.@"else" != null) .@"else" else .none,
-                            .support_range = null,
-                        },
+                        .kw_args = .{},
                     };
 
                     var body_tmp = std.ArrayList(*Inst).init(self.allocator);
@@ -2574,25 +2570,13 @@ const EmitZIR = struct {
                         body_tmp.items.len = 0;
 
                         try self.emitBody(case.body, inst_table, &body_tmp);
-                        const items = try self.arena.allocator.alloc(*Inst, case.items.len);
-                        for (case.items) |item, j| {
-                            items[j] = (try self.emitTypedValue(inst.src, .{
-                                .ty = old_inst.target_ptr.ty.elemType(),
-                                .val = item,
-                            })).inst;
-                        }
+                        const item = (try self.emitTypedValue(inst.src, .{
+                            .ty = old_inst.target_ptr.ty.elemType(),
+                            .val = case.item,
+                        })).inst;
 
                         cases[i] = .{
-                            .items = items,
-                            .body = .{ .instructions = try self.arena.allocator.dupe(*Inst, body_tmp.items) },
-                        };
-                    }
-                    if (old_inst.@"else") |some| {
-                        body_tmp.items.len = 0;
-
-                        try self.emitBody(some, inst_table, &body_tmp);
-                        cases[cases.len - 1] = .{
-                            .items = &[0]*Inst{},
+                            .item = item,
                             .body = .{ .instructions = try self.arena.allocator.dupe(*Inst, body_tmp.items) },
                         };
                     }
@@ -2846,7 +2830,7 @@ pub fn dumpZir(allocator: *Allocator, kind: []const u8, decl_name: [*:0]const u8
         .block_table = std.AutoHashMap(*Inst.Block, []const u8).init(allocator),
         .loop_table = std.AutoHashMap(*Inst.Loop, []const u8).init(allocator),
         .arena = std.heap.ArenaAllocator.init(allocator),
-        .indent = 2,
+        .indent = 4,
         .next_instr_index = 0,
     };
     defer write.arena.deinit();
src/zir_sema.zig
@@ -553,10 +553,13 @@ fn analyzeInstBlockFlat(mod: *Module, scope: *Scope, inst: *zir.Inst.Block, is_c
 
     try analyzeBody(mod, &child_block.base, inst.positionals.body);
 
-    const copied_instructions = try parent_block.arena.dupe(*Inst, child_block.instructions.items);
-    try parent_block.instructions.appendSlice(mod.gpa, copied_instructions);
+    try parent_block.instructions.appendSlice(mod.gpa, child_block.instructions.items);
 
-    return copied_instructions[copied_instructions.len - 1];
+    // comptime blocks won't generate any runtime values
+    if (child_block.instructions.items.len == 0)
+        return mod.constVoid(scope, inst.base.src);
+
+    return parent_block.instructions.items[parent_block.instructions.items.len - 1];
 }
 
 fn analyzeInstBlock(mod: *Module, scope: *Scope, inst: *zir.Inst.Block, is_comptime: bool) InnerError!*Inst {
@@ -1235,11 +1238,8 @@ fn analyzeInstSwitchBr(mod: *Module, scope: *Scope, inst: *zir.Inst.SwitchBr) In
 
     // TODO comptime execution
 
-    // excludes else and '_' cases
-    const case_count = inst.positionals.cases.len - @boolToInt(inst.kw_args.special_case != .none);
-
     const parent_block = try mod.requireRuntimeBlock(scope, inst.base.src);
-    const cases = try parent_block.arena.alloc(Inst.SwitchBr.Case, case_count);
+    const cases = try parent_block.arena.alloc(Inst.SwitchBr.Case, inst.positionals.cases.len);
 
     var case_block: Scope.Block = .{
         .parent = parent_block,
@@ -1251,58 +1251,39 @@ fn analyzeInstSwitchBr(mod: *Module, scope: *Scope, inst: *zir.Inst.SwitchBr) In
     };
     defer case_block.instructions.deinit(mod.gpa);
 
-    var items_tmp = std.ArrayList(Value).init(mod.gpa);
-    defer items_tmp.deinit();
-
-    for (inst.positionals.cases[0..case_count]) |case, i| {
+    for (inst.positionals.cases[0..inst.positionals.cases.len]) |case, i| {
         // Reset without freeing.
         case_block.instructions.items.len = 0;
-        items_tmp.items.len = 0;
 
-        for (case.items) |item| {
-            if (item.castTag(.switch_range)) |range| {
-                return mod.fail(scope, item.src, "genSwitch expand range", .{});
-            }
-            const resolved = try resolveInst(mod, scope, item);
-            const casted = try mod.coerce(scope, target.ty, resolved);
-            const val = try mod.resolveConstValue(scope, casted);
-            try items_tmp.append(val);
-        }
+        const resolved = try resolveInst(mod, scope, case.item);
+        const casted = try mod.coerce(scope, target.ty, resolved);
+        const item = try mod.resolveConstValue(scope, casted);
 
         try analyzeBody(mod, &case_block.base, case.body);
 
         cases[i] = .{
-            .items = try parent_block.arena.dupe(Value, items_tmp.items),
+            .item = item,
             .body = .{ .instructions = try parent_block.arena.dupe(*Inst, case_block.instructions.items) },
         };
     }
-
-    const else_body = if (inst.kw_args.special_case != .none) blk: {
-        case_block.instructions.items.len = 0;
-
-        try analyzeBody(mod, &case_block.base, inst.positionals.cases[case_count].body);
-        break: blk Body{
-            .instructions = try parent_block.arena.dupe(*Inst, case_block.instructions.items),
-        };
-    } else null;
     
-    return mod.addSwitchBr(parent_block, inst.base.src, target_ptr, cases, else_body);
+    return mod.addSwitchBr(parent_block, inst.base.src, target_ptr, cases);
 }
 
 fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.SwitchBr) InnerError!void {
     // validate usage of '_' prongs
-    if (inst.kw_args.special_case == .underscore and target.ty.zigTypeTag() != .Enum) {
+    if (inst.kw_args.special_prong == .underscore and target.ty.zigTypeTag() != .Enum) {
         return mod.fail(scope, inst.base.src, "'_' prong only allowed when switching on non-exhaustive enums", .{});
         // TODO notes "'_' prong here" inst.positionals.cases[last].src
     }
 
     // check that target type supports ranges
-    if (inst.kw_args.support_range) |some| {
+    if (inst.kw_args.range) |range_inst| {
         switch (target.ty.zigTypeTag()) {
             .Int, .ComptimeInt, .Float, .ComptimeFloat => {},
             else => {
                 return mod.fail(scope, target.src, "ranges not allowed when switching on type {}", .{target.ty});
-                // TODO notes "range used here" some.src
+                // TODO notes "range used here" range_inst.src
             },
         }
     }
@@ -1317,46 +1298,42 @@ fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.Sw
         .Bool => {
             var true_count: u8 = 0;
             var false_count: u8 = 0;
-            for (inst.positionals.cases) |case| {
-                for (case.items) |item| {
-                    const resolved = try resolveInst(mod, scope, item);
-                    const casted = try mod.coerce(scope, Type.initTag(.bool), resolved);
-                    if ((try mod.resolveConstValue(scope, casted)).toBool()) {
-                        true_count += 1;
-                    } else {
-                        false_count += 1;
-                    }
+            for (inst.positionals.items) |item| {
+                const resolved = try resolveInst(mod, scope, item);
+                const casted = try mod.coerce(scope, Type.initTag(.bool), resolved);
+                if ((try mod.resolveConstValue(scope, casted)).toBool()) {
+                    true_count += 1;
+                } else {
+                    false_count += 1;
+                }
 
-                    if (true_count > 1 or false_count > 1) {
-                        return mod.fail(scope, item.src, "duplicate switch value", .{});
-                    }
+                if (true_count > 1 or false_count > 1) {
+                    return mod.fail(scope, item.src, "duplicate switch value", .{});
                 }
             }
-            if ((true_count == 0 or false_count == 0) and inst.kw_args.special_case != .@"else") {
+            if ((true_count == 0 or false_count == 0) and inst.kw_args.special_prong != .@"else") {
                 return mod.fail(scope, inst.base.src, "switch must handle all possibilities", .{});
             }
-            if ((true_count == 1 and false_count == 1) and inst.kw_args.special_case == .@"else") {
+            if ((true_count == 1 and false_count == 1) and inst.kw_args.special_prong == .@"else") {
                 return mod.fail(scope, inst.base.src, "unreachable else prong, all cases already handled", .{});
             }
         },
         .EnumLiteral, .Void, .Fn, .Pointer, .Type => {
-            if (inst.kw_args.special_case != .@"else") {
+            if (inst.kw_args.special_prong != .@"else") {
                 return mod.fail(scope, inst.base.src, "else prong required when switching on type '{}'", .{target.ty});
             }
 
             var seen_values = std.HashMap(Value, usize, Value.hash, Value.eql, std.hash_map.DefaultMaxLoadPercentage).init(mod.gpa);
             defer seen_values.deinit();
 
-            for (inst.positionals.cases) |case| {
-                for (case.items) |item| {
-                    const resolved = try resolveInst(mod, scope, item);
-                    const casted = try mod.coerce(scope, target.ty, resolved);
-                    const val = try mod.resolveConstValue(scope, casted);
+            for (inst.positionals.items) |item| {
+                const resolved = try resolveInst(mod, scope, item);
+                const casted = try mod.coerce(scope, target.ty, resolved);
+                const val = try mod.resolveConstValue(scope, casted);
 
-                    if (try seen_values.fetchPut(val, item.src)) |prev| {
-                        return mod.fail(scope, item.src, "duplicate switch value", .{});
-                        // TODO notes "previous value here" prev.value
-                    }
+                if (try seen_values.fetchPut(val, item.src)) |prev| {
+                    return mod.fail(scope, item.src, "duplicate switch value", .{});
+                    // TODO notes "previous value here" prev.value
                 }
             }
         },