Commit 12e4c648cc

Vexu <git@vexu.eu>
2020-10-17 00:09:42
stage2: implement switch validation for integers
1 parent 4155d2a
src/RangeSet.zig
@@ -0,0 +1,76 @@
+const std = @import("std");
+const Order = std.math.Order;
+const Value = @import("value.zig").Value;
+const RangeSet = @This();
+
+ranges: std.ArrayList(Range),
+
+pub const Range = struct {
+    start: Value,
+    end: Value,
+    src: usize,
+};
+
+pub fn init(allocator: *std.mem.Allocator) RangeSet {
+    return .{
+        .ranges = std.ArrayList(Range).init(allocator),
+    };
+}
+
+pub fn deinit(self: *RangeSet) void {
+    self.ranges.deinit();
+}
+
+pub fn add(self: *RangeSet, start: Value, end: Value, src: usize) !?usize {
+    for (self.ranges.items) |range| {
+        if ((start.compare(.gte, range.start) and start.compare(.lte, range.end)) or
+            (end.compare(.gte, range.start) and end.compare(.lte, range.end)))
+        {
+            // ranges overlap
+            return range.src;
+        }
+    }
+    try self.ranges.append(.{
+        .start = start,
+        .end = end,
+        .src = src,
+    });
+    return null;
+}
+
+/// Assumes a and b do not overlap
+fn lessThan(_: void, a: Range, b: Range) bool {
+    return a.start.compare(.lt, b.start);
+}
+
+pub fn spans(self: *RangeSet, start: Value, end: Value) !bool {
+    std.sort.sort(Range, self.ranges.items, {}, lessThan);
+
+    if (!self.ranges.items[0].start.eql(start) or
+        !self.ranges.items[self.ranges.items.len - 1].end.eql(end))
+    {
+        return false;
+    }
+
+    var space: Value.BigIntSpace = undefined;
+
+    var counter = try std.math.big.int.Managed.init(self.ranges.allocator);
+    defer counter.deinit();
+
+    // look for gaps
+    for (self.ranges.items[1..]) |cur, i| {
+        // i starts counting from the second item.
+        const prev = self.ranges.items[i];
+
+        // prev.end + 1 == cur.start
+        try counter.copy(prev.end.toBigInt(&space));
+        try counter.addScalar(counter.toConst(), 1);
+
+        const cur_start_int = cur.start.toBigInt(&space);
+        if (!cur_start_int.eq(counter.toConst())) {
+            return false;
+        }
+    }
+
+    return true;
+}
src/type.zig
@@ -2863,6 +2863,78 @@ pub const Type = extern union {
         };
     }
 
+    /// Asserts that self.zigTypeTag() == .Int.
+    pub fn minInt(self: Type, arena: *std.heap.ArenaAllocator, target: Target) !Value {
+        assert(self.zigTypeTag() == .Int);
+        const info = self.intInfo(target);
+
+        if (!info.signed) {
+            return Value.initTag(.zero);
+        }
+
+        if ((info.bits - 1) <= std.math.maxInt(u6)) {
+            const payload = try arena.allocator.create(Value.Payload.Int_i64);
+            payload.* = .{
+                .int = -(@as(i64, 1) << @truncate(u6, info.bits - 1)),
+            };
+            return Value.initPayload(&payload.base);
+        }
+
+        var res = try std.math.big.int.Managed.initSet(&arena.allocator, 1);
+        try res.shiftLeft(res, info.bits - 1);
+        res.negate();
+
+        const res_const = res.toConst();
+        if (res_const.positive) {
+            const val_payload = try arena.allocator.create(Value.Payload.IntBigPositive);
+            val_payload.* = .{ .limbs = res_const.limbs };
+            return Value.initPayload(&val_payload.base);
+        } else {
+            const val_payload = try arena.allocator.create(Value.Payload.IntBigNegative);
+            val_payload.* = .{ .limbs = res_const.limbs };
+            return Value.initPayload(&val_payload.base);
+        }
+    }
+
+    /// Asserts that self.zigTypeTag() == .Int.
+    pub fn maxInt(self: Type, arena: *std.heap.ArenaAllocator, target: Target) !Value {
+        assert(self.zigTypeTag() == .Int);
+        const info = self.intInfo(target);
+
+        if (info.signed and (info.bits - 1) <= std.math.maxInt(u6)) {
+            const payload = try arena.allocator.create(Value.Payload.Int_i64);
+            payload.* = .{
+                .int = (@as(i64, 1) << @truncate(u6, info.bits - 1)) - 1,
+            };
+            return Value.initPayload(&payload.base);
+        } else if (!info.signed and info.bits <= std.math.maxInt(u6)) {
+            const payload = try arena.allocator.create(Value.Payload.Int_u64);
+            payload.* = .{
+                .int = (@as(u64, 1) << @truncate(u6, info.bits)) - 1,
+            };
+            return Value.initPayload(&payload.base);
+        }
+
+        var res = try std.math.big.int.Managed.initSet(&arena.allocator, 1);
+        try res.shiftLeft(res, info.bits - @boolToInt(info.signed));
+        const one = std.math.big.int.Const{
+            .limbs = &[_]std.math.big.Limb{1},
+            .positive = true,
+        };
+        res.sub(res.toConst(), one) catch unreachable;
+
+        const res_const = res.toConst();
+        if (res_const.positive) {
+            const val_payload = try arena.allocator.create(Value.Payload.IntBigPositive);
+            val_payload.* = .{ .limbs = res_const.limbs };
+            return Value.initPayload(&val_payload.base);
+        } else {
+            const val_payload = try arena.allocator.create(Value.Payload.IntBigNegative);
+            val_payload.* = .{ .limbs = res_const.limbs };
+            return Value.initPayload(&val_payload.base);
+        }
+    }
+
     /// This enum does not directly correspond to `std.builtin.TypeId` because
     /// it has extra enum tags in it, as a way of using less memory. For example,
     /// even though Zig recognizes `*align(10) i32` and `*i32` both as Pointer types
src/zir_sema.zig
@@ -1268,7 +1268,7 @@ fn analyzeInstSwitchBr(mod: *Module, scope: *Scope, inst: *zir.Inst.SwitchBr) In
             .body = .{ .instructions = try parent_block.arena.dupe(*Inst, case_block.instructions.items) },
         };
     }
-    
+
     return mod.addSwitchBr(parent_block, inst.base.src, target_ptr, cases);
 }
 
@@ -1292,10 +1292,56 @@ fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.Sw
 
     // validate for duplicate items/missing else prong
     switch (target.ty.zigTypeTag()) {
-        .Int, .ComptimeInt => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Int, .ComptimeInt", .{}),
         .Enum => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Enum", .{}),
         .ErrorSet => return mod.fail(scope, inst.base.src, "TODO validateSwitch .ErrorSet", .{}),
         .Union => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Union", .{}),
+        .Int, .ComptimeInt => {
+            var range_set = @import("RangeSet.zig").init(mod.gpa);
+            defer range_set.deinit();
+
+            for (inst.positionals.items) |item| {
+                const maybe_src = if (item.castTag(.switch_range)) |range| blk: {
+                    const start_resolved = try resolveInst(mod, scope, range.positionals.lhs);
+                    const start_casted = try mod.coerce(scope, target.ty, start_resolved);
+                    const end_resolved = try resolveInst(mod, scope, range.positionals.rhs);
+                    const end_casted = try mod.coerce(scope, target.ty, end_resolved);
+
+                    break :blk try range_set.add(
+                        try mod.resolveConstValue(scope, start_casted),
+                        try mod.resolveConstValue(scope, end_casted),
+                        item.src,
+                    );
+                } else blk: {
+                    const resolved = try resolveInst(mod, scope, item);
+                    const casted = try mod.coerce(scope, target.ty, resolved);
+                    const value = try mod.resolveConstValue(scope, casted);
+                    break :blk try range_set.add(value, value, item.src);
+                };
+
+                if (maybe_src) |previous_src| {
+                    return mod.fail(scope, item.src, "duplicate switch value", .{});
+                    // TODO notes "previous value is here" previous_src
+                }
+            }
+
+            if (target.ty.zigTypeTag() == .Int) {
+                var arena = std.heap.ArenaAllocator.init(mod.gpa);
+                defer arena.deinit();
+
+                const start = try target.ty.minInt(&arena, mod.getTarget());
+                const end = try target.ty.maxInt(&arena, mod.getTarget());
+                if (try range_set.spans(start, end)) {
+                    if (inst.kw_args.special_prong == .@"else") {
+                        return mod.fail(scope, inst.base.src, "unreachable else prong, all cases already handled", .{});
+                    }
+                    return;
+                }
+            }
+
+            if (inst.kw_args.special_prong != .@"else") {
+                return mod.fail(scope, inst.base.src, "switch must handle all possibilities", .{});
+            }
+        },
         .Bool => {
             var true_count: u8 = 0;
             var false_count: u8 = 0;