Commit 4155d2ae24

Vexu <git@vexu.eu>
2020-10-16 22:11:35
stage2: switch ranges and multi item prongs
1 parent 3c96d79
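
For context, a sketch of the source-level construct this commit teaches stage2 to lower (standard Zig switch syntax; not code taken from the commit itself):

    const std = @import("std");

    test "switch with range and multi item prongs" {
        const x: u8 = 5;
        const tag: u8 = switch (x) {
            1, 2 => 0, // multi item prong: emitted as plain switchbr cases
            3...6 => 1, // range prong: lowered to comparisons against the target
            else => 2,
        };
        std.debug.assert(tag == 1);
    }
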
src/astgen.zig
@@ -1561,6 +1561,17 @@ fn forExpr(mod: *Module, scope: *Scope, rl: ResultLoc, for_node: *ast.Node.For)
     return &for_block.base;
 }
 
+fn getRangeNode(node: *ast.Node) ?*ast.Node.SimpleInfixOp {
+    var cur = node;
+    while (true) {
+        switch (cur.tag) {
+            .Range => return @fieldParentPtr(ast.Node.SimpleInfixOp, "base", cur),
+            .GroupedExpression => cur = @fieldParentPtr(ast.Node.GroupedExpression, "base", cur).expr,
+            else => return null,
+        }
+    }
+}
+
 fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node.Switch) InnerError!*zir.Inst {
     var block_scope: Scope.GenZIR = .{
         .parent = scope,
@@ -1581,6 +1592,7 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
     const tree = scope.tree();
     const switch_src = tree.token_locs[switch_node.switch_token].start;
     const target_ptr = try expr(mod, &block_scope.base, .ref, switch_node.expr);
+    const target = try addZIRUnOp(mod, &block_scope.base, target_ptr.src, .deref, target_ptr);
     // Add the switch instruction here so that it comes before any range checks.
     const switch_inst = (try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.SwitchBr, .{
         .target_ptr = target_ptr,
@@ -1593,24 +1605,51 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
     var cases = std.ArrayList(zir.Inst.SwitchBr.Case).init(mod.gpa);
     defer cases.deinit();
 
+    // Add comptime block containing all prong items first,
+    const item_block = try addZIRInstBlock(mod, scope, switch_src, .block_comptime_flat, .{
+        .instructions = undefined, // populated below
+    });
+    // then add block containing the switch.
+    const block = try addZIRInstBlock(mod, scope, switch_src, .block, .{
+        .instructions = undefined, // populated below
+    });
+
+    // Most result location types can be forwarded directly; however
+    // if we need to write to a pointer which has an inferred type,
+    // proper type inference requires peer type resolution on the switch case.
+    const case_rl: ResultLoc = switch (rl) {
+        .discard, .none, .ty, .ptr, .ref => rl,
+        .inferred_ptr, .bitcasted_ptr, .block_ptr => .{ .block_ptr = block },
+    };
+
+    var case_scope: Scope.GenZIR = .{
+        .parent = scope,
+        .decl = block_scope.decl,
+        .arena = block_scope.arena,
+        .instructions = .{},
+    };
+    defer case_scope.instructions.deinit(mod.gpa);
+
     // first we gather all the switch items and check else/'_' prongs
     var else_src: ?usize = null;
     var underscore_src: ?usize = null;
-    var range_inst: ?*zir.Inst = null;
+    var first_range: ?*zir.Inst = null;
+    var special_case: ?*ast.Node.SwitchCase = null;
     for (switch_node.cases()) |uncasted_case| {
         const case = uncasted_case.castTag(.SwitchCase).?;
         const case_src = tree.token_locs[case.firstToken()].start;
+        // reset without freeing to reduce allocations.
+        case_scope.instructions.items.len = 0;
+        assert(case.items_len != 0);
 
-        if (case.payload != null) {
-            return mod.fail(scope, case_src, "TODO switch case payload capture", .{});
-        }
-
+        // Check for an else/'_' prong; those are handled last.
         if (case.items_len == 1 and case.items()[0].tag == .SwitchElse) {
             if (else_src) |src| {
                 return mod.fail(scope, case_src, "multiple else prongs in switch expression", .{});
                 // TODO notes "previous else prong is here"
             }
             else_src = case_src;
+            special_case = case;
             continue;
         } else if (case.items_len == 1 and case.items()[0].tag == .Identifier and
             mem.eql(u8, tree.tokenSlice(case.items()[0].firstToken()), "_"))
@@ -1620,6 +1659,7 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
                 // TODO notes "previous '_' prong is here"
             }
             underscore_src = case_src;
+            special_case = case;
             continue;
         }
 
@@ -1631,103 +1671,107 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
             }
         }
 
-        // TODO and not range
-        if (case.items_len == 1) {
+        // If this is a simple one-item prong, it is handled directly by the switchbr.
+        if (case.items_len == 1 and getRangeNode(case.items()[0]) == null) {
             const item = try expr(mod, &item_scope.base, .none, case.items()[0]);
+            try items.append(item);
+            try switchCaseExpr(mod, &case_scope.base, case_rl, block, case);
+
             try cases.append(.{
                 .item = item,
-                .body = undefined, // populated below
+                .body = .{ .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items) },
             });
             continue;
         }
-        return mod.fail(scope, case_src, "TODO switch ranges", .{});
-    }
 
-    // Actually populate switch instruction values.
-    if (else_src != null) switch_inst.kw_args.special_prong = .@"else";
-    if (underscore_src != null) switch_inst.kw_args.special_prong = .underscore;
-    switch_inst.positionals.cases = try block_scope.arena.dupe(zir.Inst.SwitchBr.Case, cases.items);
-    switch_inst.positionals.items = try block_scope.arena.dupe(*zir.Inst, items.items);
-    switch_inst.kw_args.range = range_inst;
+        // TODO if the case has few items and no ranges it might be better
+        // to just handle them as switch prongs.
+
+        // Check if the target matches any of the items.
+        // 1, 2, 3...6 will result in
+        // target == 1 or target == 2 or (target >= 3 and target <= 6)
+        var any_ok: ?*zir.Inst = null;
+        for (case.items()) |item| {
+            if (getRangeNode(item)) |range| {
+                const start = try expr(mod, &item_scope.base, .none, range.lhs);
+                const end = try expr(mod, &item_scope.base, .none, range.rhs);
+                const range_src = tree.token_locs[range.op_token].start;
+                const range_inst = try addZIRBinOp(mod, &item_scope.base, range_src, .switch_range, start, end);
+                try items.append(range_inst);
+                if (first_range == null) first_range = range_inst;
+
+                // target >= start and target <= end
+                const range_start_ok = try addZIRBinOp(mod, &block_scope.base, range_src, .cmp_gte, target, start);
+                const range_end_ok = try addZIRBinOp(mod, &block_scope.base, range_src, .cmp_lte, target, end);
+                const range_ok = try addZIRBinOp(mod, &block_scope.base, range_src, .booland, range_start_ok, range_end_ok);
+
+                if (any_ok) |some| {
+                    any_ok = try addZIRBinOp(mod, &block_scope.base, range_src, .boolor, some, range_ok);
+                } else {
+                    any_ok = range_ok;
+                }
+                continue;
+            }
 
-    // Add comptime block containing all prong items first,
-    _ = try addZIRInstBlock(mod, scope, switch_src, .block_comptime_flat, .{
-        .instructions = try block_scope.arena.dupe(*zir.Inst, item_scope.instructions.items),
-    });
-    // then add block containing the switch.
-    const block = try addZIRInstBlock(mod, scope, switch_src, .block, .{
-        .instructions = undefined, // populated below
-    });
+            const item_inst = try expr(mod, &item_scope.base, .none, item);
+            try items.append(item_inst);
+            const cmp_ok = try addZIRBinOp(mod, &block_scope.base, item_inst.src, .cmp_eq, target, item_inst);
 
-    // Most result location types can be forwarded directly; however
-    // if we need to write to a pointer which has an inferred type,
-    // proper type inference requires peer type resolution on the switch case.
-    const case_rl: ResultLoc = switch (rl) {
-        .discard, .none, .ty, .ptr, .ref => rl,
-        .inferred_ptr, .bitcasted_ptr, .block_ptr => .{ .block_ptr = block },
-    };
+            if (any_ok) |some| {
+                any_ok = try addZIRBinOp(mod, &block_scope.base, item_inst.src, .boolor, some, cmp_ok);
+            } else {
+                any_ok = cmp_ok;
+            }
+        }
 
-    var case_scope: Scope.GenZIR = .{
-        .parent = scope,
-        .decl = block_scope.decl,
-        .arena = block_scope.arena,
-        .instructions = .{},
-    };
-    defer case_scope.instructions.deinit(mod.gpa);
+        const condbr = try addZIRInstSpecial(mod, &block_scope.base, case_src, zir.Inst.CondBr, .{
+            .condition = any_ok.?,
+            .then_body = undefined, // populated below
+            .else_body = undefined, // populated below
+        }, .{});
 
-    // And finally we fill generate the bodies of each case.
-    var case_index: usize = 0;
-    var special_case: ?*ast.Node.SwitchCase = null;
-    for (switch_node.cases()) |uncasted_case| {
-        const case = uncasted_case.castTag(.SwitchCase).?;
-        const case_src = tree.token_locs[case.firstToken()].start;
-        // reset without freeing to reduce allocations.
-        defer case_scope.instructions.items.len = 0;
+        try switchCaseExpr(mod, &case_scope.base, case_rl, block, case);
+        condbr.positionals.then_body = .{
+            .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items),
+        };
 
-        if (case.items_len == 1 and case.items()[0].tag == .SwitchElse) {
-            // validated earlier
-            special_case = case;
-            continue;
-        } else if (case.items_len == 1 and case.items()[0].tag == .Identifier and
-            mem.eql(u8, tree.tokenSlice(case.items()[0].firstToken()), "_"))
-        {
-            // validated earlier
-            special_case = case;
-            continue;
-        }
+        // reset to add the empty block
+        case_scope.instructions.items.len = 0;
+        const empty_block = try addZIRInstBlock(mod, &case_scope.base, case_src, .block, .{
+            .instructions = undefined, // populated below
+        });
+        condbr.positionals.else_body = .{
+            .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items),
+        };
 
-        if (case.items_len == 1) {
-            // Generate the body of this case.
-            const case_body = try expr(mod, &case_scope.base, case_rl, case.expr);
-            if (!case_body.tag.isNoReturn()) {
-                _ = try addZIRInst(mod, &case_scope.base, case_src, zir.Inst.Break, .{
-                    .block = block,
-                    .operand = case_body,
-                }, .{});
-            }
-            switch_inst.positionals.cases[case_index].body = .{
-                .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items),
-            };
-            case_index += 1;
-            continue;
-        }
-        return mod.fail(scope, case_src, "TODO switch ranges", .{});
+        // reset to add a break to the empty block
+        case_scope.instructions.items.len = 0;
+        _ = try addZIRInst(mod, &case_scope.base, case_src, zir.Inst.BreakVoid, .{
+            .block = empty_block,
+        }, .{});
+        empty_block.positionals.body = .{
+            .instructions = try scope.arena().dupe(*zir.Inst, case_scope.instructions.items),
+        };
     }
 
+    // All items have been generated; add their instructions to the comptime block.
+    item_block.positionals.body = .{
+        .instructions = try block_scope.arena.dupe(*zir.Inst, item_scope.instructions.items),
+    };
+
+    // Actually populate switch instruction values.
+    if (else_src != null) switch_inst.kw_args.special_prong = .@"else";
+    if (underscore_src != null) switch_inst.kw_args.special_prong = .underscore;
+    switch_inst.positionals.cases = try block_scope.arena.dupe(zir.Inst.SwitchBr.Case, cases.items);
+    switch_inst.positionals.items = try block_scope.arena.dupe(*zir.Inst, items.items);
+    switch_inst.kw_args.range = first_range;
+
     // Generate else block or a break last to finish the block.
     if (special_case) |case| {
-        const case_src = tree.token_locs[case.firstToken()].start;
-        const case_body = try expr(mod, &block_scope.base, case_rl, case.expr);
-        if (!case_body.tag.isNoReturn()) {
-            _ = try addZIRInst(mod, &block_scope.base, case_src, zir.Inst.Break, .{
-                .block = block,
-                .operand = case_body,
-            }, .{});
-        }
+        try switchCaseExpr(mod, &block_scope.base, case_rl, block, case);
     } else {
-        _ = try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.BreakVoid, .{
-            .block = block,
-        }, .{});
+        // Not handling all possible cases is a compile error.
+        _ = try addZIRNoOp(mod, &block_scope.base, switch_src, .unreach_nocheck);
     }
 
     // Set block instructions now that it is finished.
@@ -1737,15 +1781,20 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
     return &block.base;
 }
 
-/// Only used for `a...b` in switches.
-fn switchRange(mod: *Module, scope: *Scope, node: *ast.Node.SimpleInfixOp) InnerError!*zir.Inst {
+fn switchCaseExpr(mod: *Module, scope: *Scope, rl: ResultLoc, block: *zir.Inst.Block, case: *ast.Node.SwitchCase) !void {
     const tree = scope.tree();
-    const src = tree.token_locs[node.op_token].start;
-
-    const start = try expr(mod, scope, .none, node.lhs);
-    const end = try expr(mod, scope, .none, node.rhs);
+    const case_src = tree.token_locs[case.firstToken()].start;
+    if (case.payload != null) {
+        return mod.fail(scope, case_src, "TODO switch case payload capture", .{});
+    }
 
-    return try addZIRBinOp(mod, scope, src, .switch_range, start, end);
+    const case_body = try expr(mod, scope, rl, case.expr);
+    if (!case_body.tag.isNoReturn()) {
+        _ = try addZIRInst(mod, scope, case_src, zir.Inst.Break, .{
+            .block = block,
+            .operand = case_body,
+        }, .{});
+    }
 }
 
 fn ret(mod: *Module, scope: *Scope, cfe: *ast.Node.ControlFlowExpression) InnerError!*zir.Inst {
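
For a prong with several items or ranges, the new path above does not emit a switchbr case; it builds a boolean chain feeding a condbr, as the `1, 2, 3...6` comment describes. A hand-written source-level equivalent of that check (illustration only; the compiler emits it as ZIR cmp_eq/cmp_gte/cmp_lte instructions chained with booland/boolor):

    const std = @import("std");

    // Matches the prong `1, 2, 3...6`.
    fn prongMatches(target: u8) bool {
        return target == 1 or target == 2 or (target >= 3 and target <= 6);
    }

    test "range prong check" {
        std.debug.assert(prongMatches(2));
        std.debug.assert(prongMatches(5));
        std.debug.assert(!prongMatches(7));
    }
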
src/codegen.zig
@@ -758,6 +758,8 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 .br => return self.genBr(inst.castTag(.br).?),
                 .breakpoint => return self.genBreakpoint(inst.src),
                 .brvoid => return self.genBrVoid(inst.castTag(.brvoid).?),
+                .booland => return self.genBoolOp(inst.castTag(.booland).?),
+                .boolor => return self.genBoolOp(inst.castTag(.boolor).?),
                 .call => return self.genCall(inst.castTag(.call).?),
                 .cmp_lt => return self.genCmp(inst.castTag(.cmp_lt).?, .lt),
                 .cmp_lte => return self.genCmp(inst.castTag(.cmp_lte).?, .lte),
@@ -782,11 +784,11 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 .retvoid => return self.genRetVoid(inst.castTag(.retvoid).?),
                 .store => return self.genStore(inst.castTag(.store).?),
                 .sub => return self.genSub(inst.castTag(.sub).?),
+                .switchbr => return self.genSwitch(inst.castTag(.switchbr).?),
                 .unreach => return MCValue{ .unreach = {} },
                 .unwrap_optional => return self.genUnwrapOptional(inst.castTag(.unwrap_optional).?),
                 .wrap_optional => return self.genWrapOptional(inst.castTag(.wrap_optional).?),
                 .varptr => return self.genVarPtr(inst.castTag(.varptr).?),
-                .switchbr => return self.genSwitch(inst.castTag(.switchbr).?),
             }
         }
 
@@ -2030,6 +2032,12 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             return self.brVoid(inst.base.src, inst.block);
         }
 
+        fn genBoolOp(self: *Self, inst: *ir.Inst.BinOp) !MCValue {
+            switch (arch) {
+                else => return self.fail(inst.base.src, "TODO genBoolOp for {}", .{self.target.cpu.arch}),
+            }
+        }
+
         fn brVoid(self: *Self, src: usize, block: *ir.Inst.Block) !MCValue {
             // Emit a jump with a relocation. It will be patched up after the block ends.
             try block.codegen.relocs.ensureCapacity(self.gpa, block.codegen.relocs.items.len + 1);
src/ir.zig
@@ -74,6 +74,8 @@ pub const Inst = struct {
         isnonnull,
         isnull,
         iserr,
+        booland,
+        boolor,
         /// Read a value from a pointer.
         load,
         loop,
@@ -126,6 +128,8 @@ pub const Inst = struct {
                 .cmp_gt,
                 .cmp_neq,
                 .store,
+                .booland,
+                .boolor,
                 => BinOp,
 
                 .arg => Arg,
src/zir.zig
@@ -85,8 +85,12 @@ pub const Inst = struct {
         block_comptime,
         /// Same as `block_flat` but additionally makes the inner instructions execute at comptime.
         block_comptime_flat,
+        /// Boolean AND. See also `bitand`.
+        booland,
         /// Boolean NOT. See also `bitnot`.
         boolnot,
+        /// Boolean OR. See also `bitor`.
+        boolor,
         /// Return a value from a `Block`.
         @"break",
         breakpoint,
@@ -333,6 +337,8 @@ pub const Inst = struct {
                 .array_type,
                 .bitand,
                 .bitor,
+                .booland,
+                .boolor,
                 .div,
                 .mod_rem,
                 .mul,
@@ -425,6 +431,8 @@ pub const Inst = struct {
                 .block_comptime,
                 .block_comptime_flat,
                 .boolnot,
+                .booland,
+                .boolor,
                 .breakpoint,
                 .call,
                 .cmp_lt,
@@ -502,6 +510,7 @@ pub const Inst = struct {
                 .slice_start,
                 .import,
                 .switchbr,
+                .switch_range,
                 => false,
 
                 .@"break",
@@ -513,7 +522,6 @@ pub const Inst = struct {
                 .unreach_nocheck,
                 .@"unreachable",
                 .loop,
-                .switch_range,
                 => true,
             };
         }
@@ -2320,6 +2328,8 @@ const EmitZIR = struct {
                 .cmp_gte => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_gte).?, .cmp_gte),
                 .cmp_gt => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_gt).?, .cmp_gt),
                 .cmp_neq => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_neq).?, .cmp_neq),
+                .booland => try self.emitBinOp(inst.src, new_body, inst.castTag(.booland).?, .booland),
+                .boolor => try self.emitBinOp(inst.src, new_body, inst.castTag(.boolor).?, .boolor),
 
                 .bitcast => try self.emitCast(inst.src, new_body, inst.castTag(.bitcast).?, .bitcast),
                 .intcast => try self.emitCast(inst.src, new_body, inst.castTag(.intcast).?, .intcast),
src/zir_sema.zig
@@ -137,6 +137,8 @@ pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!
         .import => return analyzeInstImport(mod, scope, old_inst.castTag(.import).?),
         .switchbr => return analyzeInstSwitchBr(mod, scope, old_inst.castTag(.switchbr).?),
         .switch_range => return analyzeInstSwitchRange(mod, scope, old_inst.castTag(.switch_range).?),
+        .booland => return analyzeInstBoolOp(mod, scope, old_inst.castTag(.booland).?),
+        .boolor => return analyzeInstBoolOp(mod, scope, old_inst.castTag(.boolor).?),
     }
 }
 
@@ -1224,7 +1226,7 @@ fn analyzeInstSwitchRange(mod: *Module, scope: *Scope, inst: *zir.Inst.BinOp) In
     if (start.value()) |start_val| {
         if (end.value()) |end_val| {
             if (start_val.compare(.gte, end_val)) {
-                return mod.fail(scope, inst.base.src, "range start value is greater than the end value", .{});
+                return mod.fail(scope, inst.base.src, "range start value must be smaller than the end value", .{});
             }
         }
     }
@@ -1609,6 +1611,28 @@ fn analyzeInstBoolNot(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerEr
     return mod.addUnOp(b, inst.base.src, bool_type, .not, operand);
 }
 
+fn analyzeInstBoolOp(mod: *Module, scope: *Scope, inst: *zir.Inst.BinOp) InnerError!*Inst {
+    const bool_type = Type.initTag(.bool);
+    const uncasted_lhs = try resolveInst(mod, scope, inst.positionals.lhs);
+    const lhs = try mod.coerce(scope, bool_type, uncasted_lhs);
+    const uncasted_rhs = try resolveInst(mod, scope, inst.positionals.rhs);
+    const rhs = try mod.coerce(scope, bool_type, uncasted_rhs);
+
+    const is_bool_or = inst.base.tag == .boolor;
+
+    if (lhs.value()) |lhs_val| {
+        if (rhs.value()) |rhs_val| {
+            if (is_bool_or) {
+                return mod.constBool(scope, inst.base.src, lhs_val.toBool() or rhs_val.toBool());
+            } else {
+                return mod.constBool(scope, inst.base.src, lhs_val.toBool() and rhs_val.toBool());
+            }
+        }
+    }
+    const b = try mod.requireRuntimeBlock(scope, inst.base.src);
+    return mod.addBinOp(b, inst.base.src, bool_type, if (is_bool_or) .boolor else .booland, lhs, rhs);
+}
+
 fn analyzeInstIsNonNull(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp, invert_logic: bool) InnerError!*Inst {
     const operand = try resolveInst(mod, scope, inst.positionals.operand);
     return mod.analyzeIsNull(scope, inst.base.src, operand, invert_logic);
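
The new analyzeInstBoolOp coerces both operands to bool and, when both are comptime-known, folds the operation to a constant instead of emitting a runtime binop. A minimal sketch of that truth-table fold in plain Zig (illustration of the semantics only, not the analyzer's code path):

    const std = @import("std");

    // With both operands comptime-known, boolor/booland reduce to a constant.
    fn foldBoolOp(is_bool_or: bool, lhs: bool, rhs: bool) bool {
        if (is_bool_or) return lhs or rhs;
        return lhs and rhs;
    }

    test "bool op fold" {
        std.debug.assert(foldBoolOp(true, false, true)); // boolor
        std.debug.assert(!foldBoolOp(false, false, true)); // booland
    }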