Commit 5baaf90e3c

Veikka Tuominen <git@vexu.eu>
2022-09-26 14:44:40
Sema: implement non-special inline switch prongs
1 parent cccc4c3
Changed files (4)
src/Module.zig
@@ -2445,8 +2445,8 @@ pub const SrcLoc = struct {
                 const case_nodes = tree.extra_data[extra.start..extra.end];
                 for (case_nodes) |case_node| {
                     const case = switch (node_tags[case_node]) {
-                        .switch_case_one => tree.switchCaseOne(case_node),
-                        .switch_case => tree.switchCase(case_node),
+                        .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node),
+                        .switch_case, .switch_case_inline => tree.switchCase(case_node),
                         else => unreachable,
                     };
                     const is_special = (case.ast.values.len == 0) or
@@ -2469,8 +2469,8 @@ pub const SrcLoc = struct {
                 const case_nodes = tree.extra_data[extra.start..extra.end];
                 for (case_nodes) |case_node| {
                     const case = switch (node_tags[case_node]) {
-                        .switch_case_one => tree.switchCaseOne(case_node),
-                        .switch_case => tree.switchCase(case_node),
+                        .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node),
+                        .switch_case, .switch_case_inline => tree.switchCase(case_node),
                         else => unreachable,
                     };
                     const is_special = (case.ast.values.len == 0) or
@@ -2491,8 +2491,8 @@ pub const SrcLoc = struct {
                 const case_node = src_loc.declRelativeToNodeIndex(node_off);
                 const node_tags = tree.nodes.items(.tag);
                 const case = switch (node_tags[case_node]) {
-                    .switch_case_one => tree.switchCaseOne(case_node),
-                    .switch_case => tree.switchCase(case_node),
+                    .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node),
+                    .switch_case, .switch_case_inline => tree.switchCase(case_node),
                     else => unreachable,
                 };
                 const start_tok = case.payload_token.?;
@@ -5937,8 +5937,8 @@ pub const SwitchProngSrc = union(enum) {
         var scalar_i: u32 = 0;
         for (case_nodes) |case_node| {
             const case = switch (node_tags[case_node]) {
-                .switch_case_one => tree.switchCaseOne(case_node),
-                .switch_case => tree.switchCase(case_node),
+                .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node),
+                .switch_case, .switch_case_inline => tree.switchCase(case_node),
                 else => unreachable,
             };
             if (case.ast.values.len == 0)
src/Sema.zig
@@ -162,6 +162,9 @@ pub const Block = struct {
     /// type of `err` in `else => |err|`
     switch_else_err_ty: ?Type = null,
 
+    /// Value for switch_capture in an inline case
+    inline_case_capture: Air.Inst.Ref = .none,
+
     const Param = struct {
         /// `noreturn` means `anytype`.
         ty: Type,
@@ -9002,6 +9005,30 @@ fn zirSwitchCapture(
     const operand_ptr_ty = sema.typeOf(operand_ptr);
     const operand_ty = if (operand_is_ref) operand_ptr_ty.childType() else operand_ptr_ty;
 
+    if (block.inline_case_capture != .none) {
+        const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, undefined) catch unreachable;
+        if (operand_ty.zigTypeTag() == .Union) {
+            const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?);
+            const union_obj = operand_ty.cast(Type.Payload.Union).?.data;
+            const field_ty = union_obj.fields.values()[field_index].ty;
+            if (is_ref) {
+                const ptr_field_ty = try Type.ptr(sema.arena, sema.mod, .{
+                    .pointee_type = field_ty,
+                    .mutable = operand_ptr_ty.ptrIsMutable(),
+                    .@"volatile" = operand_ptr_ty.isVolatilePtr(),
+                    .@"addrspace" = operand_ptr_ty.ptrAddressSpace(),
+                });
+                return block.addStructFieldPtr(operand_ptr, field_index, ptr_field_ty);
+            } else {
+                return block.addStructFieldVal(operand_ptr, field_index, field_ty);
+            }
+        } else if (is_ref) {
+            return sema.addConstantMaybeRef(block, operand_src, operand_ty, item_val, true);
+        } else {
+            return block.inline_case_capture;
+        }
+    }
+
     const operand = if (operand_is_ref)
         try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src)
     else
@@ -9234,14 +9261,15 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     } else 0;
 
     const special_prong = extra.data.bits.specialProng();
-    const special: struct { body: []const Zir.Inst.Index, end: usize } = switch (special_prong) {
-        .none => .{ .body = &.{}, .end = header_extra_index },
+    const special: struct { body: []const Zir.Inst.Index, end: usize, is_inline: bool } = switch (special_prong) {
+        .none => .{ .body = &.{}, .end = header_extra_index, .is_inline = false },
         .under, .@"else" => blk: {
             const body_len = @truncate(u31, sema.code.extra[header_extra_index]);
             const extra_body_start = header_extra_index + 1;
             break :blk .{
                 .body = sema.code.extra[extra_body_start..][0..body_len],
                 .end = extra_body_start + body_len,
+                .is_inline = sema.code.extra[header_extra_index] >> 31 != 0,
             };
         },
     };
@@ -9901,6 +9929,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         if (special_prong == .none) {
             return sema.fail(block, src, "switch must handle all possibilities", .{});
         }
+        if (special.is_inline) {
+            return sema.fail(block, src, "TODO special.is_inline", .{});
+        }
         if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand)) {
             return Air.Inst.Ref.unreachable_value;
         }
@@ -9927,6 +9958,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
         extra_index += 1;
         const body_len = @truncate(u31, sema.code.extra[extra_index]);
+        const is_inline = sema.code.extra[extra_index] >> 31 != 0;
         extra_index += 1;
         const body = sema.code.extra[extra_index..][0..body_len];
         extra_index += body_len;
@@ -9936,8 +9968,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
         case_block.instructions.shrinkRetainingCapacity(0);
         case_block.wip_capture_scope = wip_captures.scope;
+        case_block.inline_case_capture = .none;
 
         const item = try sema.resolveInst(item_ref);
+        if (is_inline) case_block.inline_case_capture = item;
         // `item` is already guaranteed to be constant known.
 
         const analyze_body = if (union_originally) blk: {
@@ -9989,12 +10023,118 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         const ranges_len = sema.code.extra[extra_index];
         extra_index += 1;
         const body_len = @truncate(u31, sema.code.extra[extra_index]);
+        const is_inline = sema.code.extra[extra_index] >> 31 != 0;
         extra_index += 1;
         const items = sema.code.refSlice(extra_index, items_len);
         extra_index += items_len;
 
         case_block.instructions.shrinkRetainingCapacity(0);
         case_block.wip_capture_scope = child_block.wip_capture_scope;
+        case_block.inline_case_capture = .none;
+
+        // Generate all possible cases as scalar prongs.
+        if (is_inline) {
+            const body_start = extra_index + 2 * ranges_len;
+            const body = sema.code.extra[body_start..][0..body_len];
+            const case_src = src; // TODO better source location
+            var emit_bb = false;
+
+            var range_i: usize = 0;
+            while (range_i < ranges_len) : (range_i += 1) {
+                const first_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
+                extra_index += 1;
+                const last_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
+                extra_index += 1;
+
+                const item_first_ref = try sema.resolveInst(first_ref);
+                var item_first = sema.resolveConstValue(block, .unneeded, item_first_ref, undefined) catch unreachable;
+                const item_last_ref = try sema.resolveInst(last_ref);
+                const item_last = sema.resolveConstValue(block, .unneeded, item_last_ref, undefined) catch unreachable;
+
+                while (item_first.compare(.lte, item_last, operand_ty, sema.mod)) : ({
+                    item_first = try sema.intAddScalar(block, case_src, item_first, Value.one);
+                }) {
+                    cases_len += 1;
+
+                    const item_ref = try sema.addConstant(operand_ty, item_first);
+                    case_block.inline_case_capture = item_ref;
+
+                    case_block.instructions.shrinkRetainingCapacity(0);
+                    case_block.wip_capture_scope = child_block.wip_capture_scope;
+
+                    if (emit_bb) try sema.emitBackwardBranch(block, case_src);
+                    emit_bb = true;
+
+                    _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+                        error.ComptimeBreak => {
+                            const zir_datas = sema.code.instructions.items(.data);
+                            const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                            try sema.addRuntimeBreak(&case_block, .{
+                                .block_inst = break_data.block_inst,
+                                .operand = break_data.operand,
+                                .inst = sema.comptime_break_inst,
+                            });
+                        },
+                        else => |e| return e,
+                    };
+
+                    // try wip_captures.finalize();
+
+                    try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
+                    cases_extra.appendAssumeCapacity(1); // items_len
+                    cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len));
+                    cases_extra.appendAssumeCapacity(@enumToInt(item_ref));
+                    cases_extra.appendSliceAssumeCapacity(case_block.instructions.items);
+                }
+            }
+
+            for (items) |item_ref| {
+                cases_len += 1;
+
+                const item = try sema.resolveInst(item_ref);
+                case_block.inline_case_capture = item;
+
+                case_block.instructions.shrinkRetainingCapacity(0);
+                case_block.wip_capture_scope = child_block.wip_capture_scope;
+
+                const analyze_body = if (union_originally) blk: {
+                    const item_val = sema.resolveConstValue(block, .unneeded, item, undefined) catch unreachable;
+                    const field_ty = maybe_union_ty.unionFieldType(item_val, sema.mod);
+                    break :blk field_ty.zigTypeTag() != .NoReturn;
+                } else true;
+
+                if (emit_bb) try sema.emitBackwardBranch(block, case_src);
+                emit_bb = true;
+
+                if (analyze_body) {
+                    _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+                        error.ComptimeBreak => {
+                            const zir_datas = sema.code.instructions.items(.data);
+                            const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                            try sema.addRuntimeBreak(&case_block, .{
+                                .block_inst = break_data.block_inst,
+                                .operand = break_data.operand,
+                                .inst = sema.comptime_break_inst,
+                            });
+                        },
+                        else => |e| return e,
+                    };
+                } else {
+                    _ = try case_block.addNoOp(.unreach);
+                }
+
+                // try wip_captures.finalize();
+
+                try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
+                cases_extra.appendAssumeCapacity(1); // items_len
+                cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len));
+                cases_extra.appendAssumeCapacity(@enumToInt(item));
+                cases_extra.appendSliceAssumeCapacity(case_block.instructions.items);
+            }
+
+            extra_index += body_len;
+            continue;
+        }
 
         var any_ok: Air.Inst.Ref = .none;
 
@@ -10158,6 +10298,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
         case_block.instructions.shrinkRetainingCapacity(0);
         case_block.wip_capture_scope = wip_captures.scope;
+        case_block.inline_case_capture = .none;
+        if (special.is_inline) {
+            return sema.fail(block, src, "TODO special.is_inline", .{});
+        }
 
         const analyze_body = if (union_originally)
             for (seen_union_fields) |seen_field, index| {
test/behavior/inline_switch.zig
@@ -0,0 +1,57 @@
+const std = @import("std");
+const expect = std.testing.expect;
+const builtin = @import("builtin");
+
+test "inline scalar prongs" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    var x: usize = 0;
+    switch (x) {
+        10 => |*item| try expect(@TypeOf(item) == *usize),
+        inline 11 => |*item| {
+            try expect(@TypeOf(item) == *const usize);
+            try expect(item.* == 11);
+        },
+        else => {},
+    }
+}
+
+test "inline prong ranges" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    var x: usize = 0;
+    switch (x) {
+        inline 0...20, 24 => |item| {
+            if (item > 25) @compileError("bad");
+        },
+        else => {},
+    }
+}
+
+const E = enum { a, b, c, d };
+test "inline switch enums" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    var x: E = .a;
+    switch (x) {
+        inline .a, .b => |aorb| if (aorb != .a and aorb != .b) @compileError("bad"),
+        inline .c, .d => |cord| if (cord != .c and cord != .d) @compileError("bad"),
+    }
+}
+
+const U = union(E) { a: void, b: u2, c: u3, d: u4 };
+test "inline switch unions" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    var x: U = .a;
+    switch (x) {
+        inline .a, .b => |aorb| {
+            try expect(@TypeOf(aorb) == void or @TypeOf(aorb) == u2);
+        },
+        inline .c, .d => |cord| {
+            try expect(@TypeOf(cord) == u3 or @TypeOf(cord) == u4);
+        },
+    }
+}
test/behavior.zig
@@ -180,6 +180,7 @@ test {
         _ = @import("behavior/decltest.zig");
         _ = @import("behavior/packed_struct_explicit_backing_int.zig");
         _ = @import("behavior/empty_union.zig");
+        _ = @import("behavior/inline_switch.zig");
     }
 
     if (builtin.os.tag != .wasi) {