Commit 2020ca640e

Vexu <git@vexu.eu>
2020-10-13 17:08:15
stage2: switch emit zir
1 parent 11998d2
src/astgen.zig
@@ -1573,8 +1573,8 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
     const tree = scope.tree();
     const switch_src = tree.token_locs[switch_node.switch_token].start;
     const target_ptr = try expr(mod, &block_scope.base, .ref, switch_node.expr);
-    const cases = try scope.arena().alloc(zir.Inst.Switch.Case, switch_node.cases_len);
-    var kw_args: std.meta.fieldInfo(zir.Inst.Switch, "kw_args").field_type = .{};
+    const cases = try scope.arena().alloc(zir.Inst.SwitchBr.Case, switch_node.cases_len);
+    var kw_args: std.meta.fieldInfo(zir.Inst.SwitchBr, "kw_args").field_type = .{};
 
     // first we gather all the switch items and check else/'_' prongs
     var case_index: usize = 0;
@@ -1643,7 +1643,7 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
     }
 
     // Then we add the switch instruction to finish the block.
-    _ = try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.Switch, .{
+    _ = try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.SwitchBr, .{
         .target_ptr = target_ptr,
         .cases = cases,
     }, kw_args);
src/codegen.zig
@@ -786,7 +786,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 .unwrap_optional => return self.genUnwrapOptional(inst.castTag(.unwrap_optional).?),
                 .wrap_optional => return self.genWrapOptional(inst.castTag(.wrap_optional).?),
                 .varptr => return self.genVarPtr(inst.castTag(.varptr).?),
-                .@"switch" => return self.genSwitch(inst.castTag(.@"switch").?),
+                .switchbr => return self.genSwitch(inst.castTag(.switchbr).?),
             }
         }
 
@@ -1990,7 +1990,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             return @bitCast(MCValue, inst.codegen.mcv);
         }
 
-        fn genSwitch(self: *Self, inst: *ir.Inst.Switch) !MCValue {
+        fn genSwitch(self: *Self, inst: *ir.Inst.SwitchBr) !MCValue {
             switch (arch) {
                 else => return self.fail(inst.base.src, "TODO genSwitch for {}", .{self.target.cpu.arch}),
             }
src/ir.zig
@@ -91,7 +91,7 @@ pub const Inst = struct {
         intcast,
         unwrap_optional,
         wrap_optional,
-        @"switch",
+        switchbr,
 
         pub fn Type(tag: Tag) type {
             return switch (tag) {
@@ -138,7 +138,7 @@ pub const Inst = struct {
                 .constant => Constant,
                 .loop => Loop,
                 .varptr => VarPtr,
-                .@"switch" => Switch,
+                .switchbr => SwitchBr,
             };
         }
 
@@ -461,26 +461,45 @@ pub const Inst = struct {
         }
     };
 
-    pub const Switch = struct {
-        pub const base_tag = Tag.@"switch";
+    pub const SwitchBr = struct {
+        pub const base_tag = Tag.switchbr;
 
         base: Inst,
         target_ptr: *Inst,
         cases: []Case,
         @"else": ?Body,
+        /// Set of instructions whose lifetimes end at the start of one of the cases.
+        /// In same order as cases, deaths[0..case_0_count, case_0_count .. case_1_count, ... , case_n_count ... else_count].
+        deaths: [*]*Inst = undefined,
+        else_index: u32 = 0,
+        else_deaths: u32 = 0,
 
         pub const Case = struct {
             items: []Value,
             body: Body,
+            index: u32 = 0,
+            deaths: u32 = 0,
         };
 
-        pub fn operandCount(self: *const Switch) usize {
+        pub fn operandCount(self: *const SwitchBr) usize {
             return 1;
         }
-        pub fn getOperand(self: *const Switch, index: usize) ?*Inst {
-            return self.target_ptr;
+        pub fn getOperand(self: *const SwitchBr, index: usize) ?*Inst {
+            var i = index;
+
+            if (i < 1)
+                return self.target_ptr;
+            i -= 1;
+
+            return null;
+        }
+        pub fn caseDeaths(self: *const SwitchBr, case_index: usize) []*Inst {
+            const case = self.cases[case_index];
+            return (self.deaths + case.index)[0..case.deaths];
+        }
+        pub fn elseDeaths(self: *const SwitchBr) []*Inst {
+            return (self.deaths + self.else_deaths)[0..self.else_deaths];
         }
-        // TODO case body deaths
     };
 };
 
src/Module.zig
@@ -2098,6 +2098,29 @@ pub fn addCall(
     return &inst.base;
 }
 
+pub fn addSwitchBr(
+    self: *Module,
+    block: *Scope.Block,
+    src: usize,
+    target_ptr: *Inst,
+    cases: []Inst.SwitchBr.Case,
+    else_body: ?Module.Body,
+) !*Inst {
+    const inst = try block.arena.create(Inst.SwitchBr);
+    inst.* = .{
+        .base = .{
+            .tag = .switchbr,
+            .ty = Type.initTag(.noreturn),
+            .src = src,
+        },
+        .target_ptr = target_ptr,
+        .cases = cases,
+        .@"else" = else_body,
+    };
+    try block.instructions.append(self.gpa, &inst.base);
+    return &inst.base;
+}
+
 pub fn constInst(self: *Module, scope: *Scope, src: usize, typed_value: TypedValue) !*Inst {
     const const_inst = try scope.arena().create(Inst.Constant);
     const_inst.* = .{
src/zir.zig
@@ -273,7 +273,7 @@ pub const Inst = struct {
         /// Enum literal
         enum_literal,
         /// A switch expression.
-        @"switch",
+        switchbr,
         /// A range in a switch case, `lhs...rhs`.
         /// Only checks that `lhs >= rhs` if they are ints or floats, everything else is
         /// validated by the .switch instruction.
@@ -396,7 +396,7 @@ pub const Inst = struct {
                 .enum_literal => EnumLiteral,
                 .error_set => ErrorSet,
                 .slice => Slice,
-                .@"switch" => Switch,
+                .switchbr => SwitchBr,
             };
         }
 
@@ -513,7 +513,7 @@ pub const Inst = struct {
                 .unreach_nocheck,
                 .@"unreachable",
                 .loop,
-                .@"switch",
+                .switchbr,
                 => true,
             };
         }
@@ -998,8 +998,8 @@ pub const Inst = struct {
         },
     };
 
-    pub const Switch = struct {
-        pub const base_tag = Tag.@"switch";
+    pub const SwitchBr = struct {
+        pub const base_tag = Tag.switchbr;
         base: Inst,
 
         positionals: struct {
@@ -1275,24 +1275,24 @@ const Writer = struct {
                 }
                 try stream.writeByte(']');
             },
-            []Inst.Switch.Case => {
+            []Inst.SwitchBr.Case => {
                 if (param.len == 0) {
                     return stream.writeAll("{}");
                 }
                 try stream.writeAll("{\n");
-                self.indent += 2;
                 for (param) |*case, i| {
                     if (i != 0) {
                         try stream.writeAll(",\n");
                     }
                     try stream.writeByteNTimes(' ', self.indent);
+                    self.indent += 2;
                     try self.writeParamToStream(stream, &case.items);
                     try stream.writeAll(" => ");
                     try self.writeParamToStream(stream, &case.body);
+                    self.indent -= 2;
                 }
                 try stream.writeByte('\n');
-                self.indent -= 2;
-                try stream.writeByteNTimes(' ', self.indent);
+                try stream.writeByteNTimes(' ', self.indent - 2);
                 try stream.writeByte('}');
             },
             else => |T| @compileError("unimplemented: rendering parameter of type " ++ @typeName(T)),
@@ -1707,12 +1707,12 @@ const Parser = struct {
                 try requireEatBytes(self, "]");
                 return strings.toOwnedSlice();
             },
-            []Inst.Switch.Case => {
+            []Inst.SwitchBr.Case => {
                 try requireEatBytes(self, "{");
                 skipSpace(self);
-                if (eatByte(self, '}')) return &[0]Inst.Switch.Case{};
+                if (eatByte(self, '}')) return &[0]Inst.SwitchBr.Case{};
 
-                var cases = std.ArrayList(Inst.Switch.Case).init(&self.arena.allocator);
+                var cases = std.ArrayList(Inst.SwitchBr.Case).init(&self.arena.allocator);
                 while (true) {
                     const cur = try cases.addOne();
                     skipSpace(self);
@@ -1824,7 +1824,7 @@ pub fn dumpFn(old_module: IrModule, module_fn: *IrModule.Fn) void {
         .arena = std.heap.ArenaAllocator.init(allocator),
         .old_module = &old_module,
         .next_auto_name = 0,
-        .names = std.StringHashMap(void).init(allocator),
+        .names = std.StringArrayHashMap(void).init(allocator),
         .primitive_table = std.AutoHashMap(Inst.Primitive.Builtin, *Decl).init(allocator),
         .indent = 0,
         .block_table = std.AutoHashMap(*ir.Inst.Block, *Inst.Block).init(allocator),
@@ -2547,11 +2547,58 @@ const EmitZIR = struct {
                     };
                     break :blk &new_inst.base;
                 },
+                .switchbr => blk: {
+                    const old_inst = inst.castTag(.switchbr).?;
+                    const case_count = old_inst.cases.len + @boolToInt(old_inst.@"else" != null);
+                    const cases = try self.arena.allocator.alloc(Inst.SwitchBr.Case, case_count);
+                    const new_inst = try self.arena.allocator.create(Inst.SwitchBr);
+                    new_inst.* = .{
+                        .base = .{
+                            .src = inst.src,
+                            .tag = Inst.SwitchBr.base_tag,
+                        },
+                        .positionals = .{
+                            .target_ptr = try self.resolveInst(new_body, old_inst.target_ptr),
+                            .cases = cases,
+                        },
+                        .kw_args = .{
+                            .special_case = if (old_inst.@"else" != null) .@"else" else .none,
+                            .support_range = null,
+                        },
+                    };
 
-                .varptr => @panic("TODO"),
-                .@"switch" => {
-                    @panic("TODO");
+                    var body_tmp = std.ArrayList(*Inst).init(self.allocator);
+                    defer body_tmp.deinit();
+
+                    for (old_inst.cases) |case, i| {
+                        body_tmp.items.len = 0;
+
+                        try self.emitBody(case.body, inst_table, &body_tmp);
+                        const items = try self.arena.allocator.alloc(*Inst, case.items.len);
+                        for (case.items) |item, j| {
+                            items[j] = (try self.emitTypedValue(inst.src, .{
+                                .ty = old_inst.target_ptr.ty.elemType(),
+                                .val = item,
+                            })).inst;
+                        }
+
+                        cases[i] = .{
+                            .items = items,
+                            .body = .{ .instructions = try self.arena.allocator.dupe(*Inst, body_tmp.items) },
+                        };
+                    }
+                    if (old_inst.@"else") |some| {
+                        body_tmp.items.len = 0;
+
+                        try self.emitBody(some, inst_table, &body_tmp);
+                        cases[cases.len - 1] = .{
+                            .items = &[0]*Inst{},
+                            .body = .{ .instructions = try self.arena.allocator.dupe(*Inst, body_tmp.items) },
+                        };
+                    }
+                    break :blk &new_inst.base;
                 },
+                .varptr => @panic("TODO"),
             };
             try self.metadata.put(new_inst, .{
                 .deaths = inst.deaths,
src/zir_sema.zig
@@ -135,7 +135,7 @@ pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!
         .slice => return analyzeInstSlice(mod, scope, old_inst.castTag(.slice).?),
         .slice_start => return analyzeInstSliceStart(mod, scope, old_inst.castTag(.slice_start).?),
         .import => return analyzeInstImport(mod, scope, old_inst.castTag(.import).?),
-        .@"switch" => return analyzeInstSwitch(mod, scope, old_inst.castTag(.@"switch").?),
+        .switchbr => return analyzeInstSwitchBr(mod, scope, old_inst.castTag(.switchbr).?),
         .switch_range => return analyzeInstSwitchRange(mod, scope, old_inst.castTag(.switch_range).?),
     }
 }
@@ -1228,7 +1228,7 @@ fn analyzeInstSwitchRange(mod: *Module, scope: *Scope, inst: *zir.Inst.BinOp) In
     return mod.constVoid(scope, inst.base.src);
 }
 
-fn analyzeInstSwitch(mod: *Module, scope: *Scope, inst: *zir.Inst.Switch) InnerError!*Inst {
+fn analyzeInstSwitchBr(mod: *Module, scope: *Scope, inst: *zir.Inst.SwitchBr) InnerError!*Inst {
     const target_ptr = try resolveInst(mod, scope, inst.positionals.target_ptr);
     const target = try mod.analyzeDeref(scope, inst.base.src, target_ptr, inst.positionals.target_ptr.src);
     try validateSwitch(mod, scope, target, inst);
@@ -1239,17 +1239,7 @@ fn analyzeInstSwitch(mod: *Module, scope: *Scope, inst: *zir.Inst.Switch) InnerE
     const case_count = inst.positionals.cases.len - @boolToInt(inst.kw_args.special_case != .none);
 
     const parent_block = try mod.requireRuntimeBlock(scope, inst.base.src);
-    const switch_inst = try parent_block.arena.create(Inst.Switch);
-    switch_inst.* = .{
-        .base = .{
-            .tag = Inst.Switch.base_tag,
-            .ty = Type.initTag(.noreturn),
-            .src = inst.base.src,
-        },
-        .target_ptr = target_ptr,
-        .@"else" = null,
-        .cases = try parent_block.arena.alloc(Inst.Switch.Case, case_count),
-    };
+    const cases = try parent_block.arena.alloc(Inst.SwitchBr.Case, case_count);
 
     var case_block: Scope.Block = .{
         .parent = parent_block,
@@ -1281,25 +1271,25 @@ fn analyzeInstSwitch(mod: *Module, scope: *Scope, inst: *zir.Inst.Switch) InnerE
 
         try analyzeBody(mod, &case_block.base, case.body);
 
-        switch_inst.cases[i] = .{
+        cases[i] = .{
             .items = try parent_block.arena.dupe(Value, items_tmp.items),
             .body = .{ .instructions = try parent_block.arena.dupe(*Inst, case_block.instructions.items) },
         };
     }
 
-    if (inst.kw_args.special_case != .none) {
+    const else_body = if (inst.kw_args.special_case != .none) blk: {
         case_block.instructions.items.len = 0;
 
         try analyzeBody(mod, &case_block.base, inst.positionals.cases[case_count].body);
-        switch_inst.@"else" = .{
+        break: blk Body{
             .instructions = try parent_block.arena.dupe(*Inst, case_block.instructions.items),
         };
-    }
+    } else null;
     
-    return &switch_inst.base;
+    return mod.addSwitchBr(parent_block, inst.base.src, target_ptr, cases, else_body);
 }
 
-fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.Switch) InnerError!void {
+fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.SwitchBr) InnerError!void {
     // validate usage of '_' prongs
     if (inst.kw_args.special_case == .underscore and target.ty.zigTypeTag() != .Enum) {
         return mod.fail(scope, inst.base.src, "'_' prong only allowed when switching on non-exhaustive enums", .{});