Commit c4cc796695

mlugg <mlugg@mlugg.co.uk>
2023-06-14 01:51:31
Sema: consider type bounds when refining result type of `@min`/`@max`
This is achieved through a fairly major refactor of the `analyzeMinMax` logic. The change should also handle vectors of `comptime_int` correctly, which Andrew said are supposed to work (but which currently do not).
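
For example (a minimal illustration of the new behavior, mirroring the tests added below):

```zig
const std = @import("std");

test "illustrative: @min/@max result types are refined from operand type bounds" {
    var a: u8 = 10;
    var b: u32 = 456;

    // @min(u8, u32) can never exceed maxInt(u8), so the result type is refined to u8.
    const lo = @min(a, b);
    // @max(u8, u32) may need the full u32 range, so the result type stays u32.
    const hi = @max(a, b);

    comptime std.debug.assert(@TypeOf(lo) == u8);
    comptime std.debug.assert(@TypeOf(hi) == u32);

    try std.testing.expectEqual(@as(u8, 10), lo);
    try std.testing.expectEqual(@as(u32, 456), hi);
}
```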
1 parent 5d9e8f2
Changed files (3)
src/Sema.zig
@@ -22984,104 +22984,127 @@ fn analyzeMinMax(
         else => @compileError("unreachable"),
     };
 
-    // First, find all comptime-known arguments, and get their min/max
+    // The set of runtime-known operands. Set up in the loop below.
     var runtime_known = try std.DynamicBitSet.initFull(sema.arena, operands.len);
+    // The current minmax value - initially this will always be comptime-known, then we'll add
+    // runtime values into the mix later.
     var cur_minmax: ?Air.Inst.Ref = null;
     var cur_minmax_src: LazySrcLoc = undefined; // defined if cur_minmax not null
+    // The current known scalar bounds of the value.
+    var bounds_status: enum {
+        unknown, // We've only seen undef comptime_ints so far, so do not know the bounds.
+        defined, // We've seen only integers, so the bounds are defined.
+        non_integral, // There are floats in the mix, so the bounds aren't defined.
+    } = .unknown;
+    var cur_min_scalar: Value = undefined;
+    var cur_max_scalar: Value = undefined;
+
+    // First, find all comptime-known arguments, and get their min/max
+
     for (operands, operand_srcs, 0..) |operand, operand_src, operand_idx| {
         // Resolve the value now to avoid redundant calls to `checkSimdBinOp` - we'll have to call
         // it in the runtime path anyway since the result type may have been refined
-        const uncasted_operand_val = (try sema.resolveMaybeUndefVal(operand)) orelse continue;
-        if (cur_minmax) |cur| {
-            const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
-            const cur_val = simd_op.lhs_val.?; // cur_minmax is comptime-known
-            const operand_val = simd_op.rhs_val.?; // we checked the operand was resolvable above
-
-            runtime_known.unset(operand_idx);
+        const unresolved_uncoerced_val = try sema.resolveMaybeUndefVal(operand) orelse continue;
+        const uncoerced_val = try sema.resolveLazyValue(unresolved_uncoerced_val);
+
+        runtime_known.unset(operand_idx);
+
+        switch (bounds_status) {
+            .unknown, .defined => refine_bounds: {
+                const ty = sema.typeOf(operand);
+                if (!ty.scalarType(mod).isInt(mod) and !ty.scalarType(mod).eql(Type.comptime_int, mod)) {
+                    bounds_status = .non_integral;
+                    break :refine_bounds;
+                }
+                const scalar_bounds: ?[2]Value = bounds: {
+                    if (!ty.isVector(mod)) break :bounds try uncoerced_val.intValueBounds(mod);
+                    var cur_bounds: [2]Value = try Value.intValueBounds(try uncoerced_val.elemValue(mod, 0), mod) orelse break :bounds null;
+                    const len = try sema.usizeCast(block, src, ty.vectorLen(mod));
+                    for (1..len) |i| {
+                        const elem = try uncoerced_val.elemValue(mod, i);
+                        const elem_bounds = try elem.intValueBounds(mod) orelse break :bounds null;
+                        cur_bounds = .{
+                            Value.numberMin(elem_bounds[0], cur_bounds[0], mod),
+                            Value.numberMax(elem_bounds[1], cur_bounds[1], mod),
+                        };
+                    }
+                    break :bounds cur_bounds;
+                };
+                if (scalar_bounds) |bounds| {
+                    if (bounds_status == .unknown) {
+                        cur_min_scalar = bounds[0];
+                        cur_max_scalar = bounds[1];
+                        bounds_status = .defined;
+                    } else {
+                        cur_min_scalar = opFunc(cur_min_scalar, bounds[0], mod);
+                        cur_max_scalar = opFunc(cur_max_scalar, bounds[1], mod);
+                    }
+                }
+            },
+            .non_integral => {},
+        }
 
-            if (cur_val.isUndef(mod)) continue; // result is also undef
-            if (operand_val.isUndef(mod)) {
-                cur_minmax = try sema.addConstUndef(simd_op.result_ty);
-                continue;
-            }
+        const cur = cur_minmax orelse {
+            cur_minmax = operand;
+            cur_minmax_src = operand_src;
+            continue;
+        };
 
-            const resolved_cur_val = try sema.resolveLazyValue(cur_val);
-            const resolved_operand_val = try sema.resolveLazyValue(operand_val);
+        const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src);
+        const cur_val = try sema.resolveLazyValue(simd_op.lhs_val.?); // cur_minmax is comptime-known
+        const operand_val = try sema.resolveLazyValue(simd_op.rhs_val.?); // we checked the operand was resolvable above
 
-            const vec_len = simd_op.len orelse {
-                const result_val = opFunc(resolved_cur_val, resolved_operand_val, mod);
-                cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
-                continue;
-            };
-            const elems = try sema.arena.alloc(InternPool.Index, vec_len);
-            for (elems, 0..) |*elem, i| {
-                const lhs_elem_val = try resolved_cur_val.elemValue(mod, i);
-                const rhs_elem_val = try resolved_operand_val.elemValue(mod, i);
-                elem.* = try opFunc(lhs_elem_val, rhs_elem_val, mod).intern(simd_op.scalar_ty, mod);
-            }
-            cur_minmax = try sema.addConstant(simd_op.result_ty, (try mod.intern(.{ .aggregate = .{
-                .ty = simd_op.result_ty.toIntern(),
-                .storage = .{ .elems = elems },
-            } })).toValue());
-        } else {
-            runtime_known.unset(operand_idx);
-            cur_minmax = try sema.addConstant(sema.typeOf(operand), uncasted_operand_val);
-            cur_minmax_src = operand_src;
+        const vec_len = simd_op.len orelse {
+            const result_val = opFunc(cur_val, operand_val, mod);
+            cur_minmax = try sema.addConstant(simd_op.result_ty, result_val);
+            continue;
+        };
+        const elems = try sema.arena.alloc(InternPool.Index, vec_len);
+        for (elems, 0..) |*elem, i| {
+            const lhs_elem_val = try cur_val.elemValue(mod, i);
+            const rhs_elem_val = try operand_val.elemValue(mod, i);
+            const uncoerced_elem = opFunc(lhs_elem_val, rhs_elem_val, mod);
+            elem.* = (try mod.getCoerced(uncoerced_elem, simd_op.scalar_ty)).toIntern();
         }
+        cur_minmax = try sema.addConstant(simd_op.result_ty, (try mod.intern(.{ .aggregate = .{
+            .ty = simd_op.result_ty.toIntern(),
+            .storage = .{ .elems = elems },
+        } })).toValue());
     }
 
     const opt_runtime_idx = runtime_known.findFirstSet();
 
-    const comptime_refined_ty: ?Type = if (cur_minmax) |ct_minmax_ref| refined: {
-        // Refine the comptime-known result type based on the operation
+    if (cur_minmax) |ct_minmax_ref| refine: {
+        // Refine the comptime-known result type based on the bounds. This isn't strictly necessary
+        // in the runtime case, since we'll refine the type again later, but keeping things as small
+        // as possible will allow us to emit more optimal AIR (if all the runtime operands have
+        // smaller types than the non-refined comptime type).
+
         const val = (try sema.resolveMaybeUndefVal(ct_minmax_ref)).?;
         const orig_ty = sema.typeOf(ct_minmax_ref);
 
-        if (opt_runtime_idx == null and orig_ty.eql(Type.comptime_int, mod)) {
+        if (opt_runtime_idx == null and orig_ty.scalarType(mod).eql(Type.comptime_int, mod)) {
             // If all arguments were `comptime_int`, and there are no runtime args, we'll preserve that type
-            break :refined orig_ty;
+            break :refine;
         }
 
-        const refined_ty = if (orig_ty.zigTypeTag(mod) == .Vector) blk: {
-            const elem_ty = orig_ty.childType(mod);
-            const len = orig_ty.vectorLen(mod);
-
-            if (len == 0) break :blk orig_ty;
-            if (elem_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
+        // We can't refine float types
+        if (orig_ty.scalarType(mod).isAnyFloat()) break :refine;
 
-            var cur_min: Value = try val.elemValue(mod, 0);
-            var cur_max: Value = cur_min;
-            for (1..len) |idx| {
-                const elem_val = try val.elemValue(mod, idx);
-                if (elem_val.isUndef(mod)) break :blk orig_ty; // can't refine undef
-                if (Value.order(elem_val, cur_min, mod).compare(.lt)) cur_min = elem_val;
-                if (Value.order(elem_val, cur_max, mod).compare(.gt)) cur_max = elem_val;
-            }
+        assert(bounds_status == .defined); // there was a non-comptime-int integral comptime-known arg
 
-            const refined_elem_ty = try mod.intFittingRange(cur_min, cur_max);
-            break :blk try mod.vectorType(.{
-                .len = len,
-                .child = refined_elem_ty.toIntern(),
-            });
-        } else blk: {
-            if (orig_ty.isAnyFloat()) break :blk orig_ty; // can't refine floats
-            if (val.isUndef(mod)) break :blk orig_ty; // can't refine undef
-            break :blk try mod.intFittingRange(val, val);
-        };
+        const refined_scalar_ty = try mod.intFittingRange(cur_min_scalar, cur_max_scalar);
+        const refined_ty = if (orig_ty.isVector(mod)) try mod.vectorType(.{
+            .len = orig_ty.vectorLen(mod),
+            .child = refined_scalar_ty.toIntern(),
+        }) else refined_scalar_ty;
 
-        // Apply the refined type to the current value - this isn't strictly necessary in the
-        // runtime case since we'll refine again afterwards, but keeping things as small as possible
-        // will allow us to emit more optimal AIR (if all the runtime operands have smaller types
-        // than the non-refined comptime type).
-        if (!refined_ty.eql(orig_ty, mod)) {
-            if (std.debug.runtime_safety) {
-                assert(try sema.intFitsInType(val, refined_ty, null));
-            }
-            cur_minmax = try sema.coerceInMemory(val, refined_ty);
+        // Apply the refined type to the current value
+        if (std.debug.runtime_safety) {
+            assert(try sema.intFitsInType(val, refined_ty, null));
         }
-
-        break :refined refined_ty;
-    } else null;
+        cur_minmax = try sema.coerceInMemory(val, refined_ty);
+    }
 
     const runtime_idx = opt_runtime_idx orelse return cur_minmax.?;
     const runtime_src = operand_srcs[runtime_idx];
@@ -23102,6 +23125,11 @@ fn analyzeMinMax(
         cur_minmax = operands[0];
         cur_minmax_src = runtime_src;
         runtime_known.unset(0); // don't look at this operand in the loop below
+        const scalar_ty = sema.typeOf(cur_minmax.?).scalarType(mod);
+        if (scalar_ty.isInt(mod)) {
+            cur_min_scalar = try scalar_ty.minInt(mod, scalar_ty);
+            cur_max_scalar = try scalar_ty.maxInt(mod, scalar_ty);
+        }
     }
 
     var it = runtime_known.iterator(.{});
@@ -23112,49 +23140,49 @@ fn analyzeMinMax(
         const rhs_src = operand_srcs[idx];
         const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src);
         if (known_undef) {
-            cur_minmax = try sema.addConstant(simd_op.result_ty, Value.undef);
+            cur_minmax = try sema.addConstUndef(simd_op.result_ty);
         } else {
             cur_minmax = try block.addBinOp(air_tag, simd_op.lhs, simd_op.rhs);
         }
+        // Compute the bounds of this type
+        switch (bounds_status) {
+            .unknown, .defined => refine_bounds: {
+                const scalar_ty = sema.typeOf(rhs).scalarType(mod);
+                if (scalar_ty.isAnyFloat()) {
+                    bounds_status = .non_integral;
+                    break :refine_bounds;
+                }
+                const scalar_min = try scalar_ty.minInt(mod, scalar_ty);
+                const scalar_max = try scalar_ty.maxInt(mod, scalar_ty);
+                if (bounds_status == .unknown) {
+                    cur_min_scalar = scalar_min;
+                    cur_max_scalar = scalar_max;
+                    bounds_status = .defined;
+                } else {
+                    cur_min_scalar = opFunc(cur_min_scalar, scalar_min, mod);
+                    cur_max_scalar = opFunc(cur_max_scalar, scalar_max, mod);
+                }
+            },
+            .non_integral => {},
+        }
     }
 
-    if (comptime_refined_ty) |comptime_ty| refine: {
-        // Finally, refine the type based on the comptime-known bound.
-        if (known_undef) break :refine; // can't refine undef
-        const unrefined_ty = sema.typeOf(cur_minmax.?);
-        const is_vector = unrefined_ty.zigTypeTag(mod) == .Vector;
-        const comptime_elem_ty = if (is_vector) comptime_ty.childType(mod) else comptime_ty;
-        const unrefined_elem_ty = if (is_vector) unrefined_ty.childType(mod) else unrefined_ty;
-
-        if (unrefined_elem_ty.isAnyFloat()) break :refine; // we can't refine floats
-
-        // Compute the final bounds based on the runtime type and the comptime-known bound type
-        const min_val = switch (air_tag) {
-            .min => try unrefined_elem_ty.minInt(mod, unrefined_elem_ty),
-            .max => try comptime_elem_ty.minInt(mod, comptime_elem_ty), // @max(ct, rt) >= ct
-            else => unreachable,
-        };
-        const max_val = switch (air_tag) {
-            .min => try comptime_elem_ty.maxInt(mod, comptime_elem_ty), // @min(ct, rt) <= ct
-            .max => try unrefined_elem_ty.maxInt(mod, unrefined_elem_ty),
-            else => unreachable,
-        };
-
-        // Find the smallest type which can contain these bounds
-        const final_elem_ty = try mod.intFittingRange(min_val, max_val);
-
-        const final_ty = if (is_vector)
-            try mod.vectorType(.{
-                .len = unrefined_ty.vectorLen(mod),
-                .child = final_elem_ty.toIntern(),
-            })
-        else
-            final_elem_ty;
+    // Finally, refine the type based on the known bounds.
+    const unrefined_ty = sema.typeOf(cur_minmax.?);
+    if (unrefined_ty.scalarType(mod).isAnyFloat()) {
+        // We can't refine floats, so we're done.
+        return cur_minmax.?;
+    }
+    assert(bounds_status == .defined); // there were integral runtime operands
+    const refined_scalar_ty = try mod.intFittingRange(cur_min_scalar, cur_max_scalar);
+    const refined_ty = if (unrefined_ty.isVector(mod)) try mod.vectorType(.{
+        .len = unrefined_ty.vectorLen(mod),
+        .child = refined_scalar_ty.toIntern(),
+    }) else refined_scalar_ty;
 
-        if (!final_ty.eql(unrefined_ty, mod)) {
-            // We've reduced the type - cast the result down
-            return block.addTyOp(.intcast, final_ty, cur_minmax.?);
-        }
+    if (!refined_ty.eql(unrefined_ty, mod)) {
+        // We've reduced the type - cast the result down
+        return block.addTyOp(.intcast, refined_ty, cur_minmax.?);
     }
 
     return cur_minmax.?;
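
In short, the new logic folds each operand's scalar bounds into a running bound using the operation itself (`Value.numberMin` for `@min`, `Value.numberMax` for `@max`), then picks the smallest integer type that fits the folded bound via `intFittingRange`. A rough standalone sketch of that folding rule (illustrative only, not compiler code; `Bounds` and `foldBounds` are made-up names, with `i128` standing in for arbitrary-precision values):

```zig
const std = @import("std");

const Bounds = struct { min: i128, max: i128 };

/// Illustrative only: fold per-operand scalar bounds the way the refined
/// analyzeMinMax does, using the min/max operation itself as the combiner.
/// Assumes at least one operand, as @min/@max do.
fn foldBounds(op: enum { min, max }, operand_bounds: []const Bounds) Bounds {
    var cur = operand_bounds[0];
    for (operand_bounds[1..]) |b| {
        cur = switch (op) {
            .min => .{ .min = @min(cur.min, b.min), .max = @min(cur.max, b.max) },
            .max => .{ .min = @max(cur.min, b.min), .max = @max(cur.max, b.max) },
        };
    }
    return cur;
}

test "illustrative: folded bounds determine the refined result type" {
    // Operand types u16, u32, u8, as in the behavior test added below.
    const bounds = [_]Bounds{
        .{ .min = 0, .max = std.math.maxInt(u16) },
        .{ .min = 0, .max = std.math.maxInt(u32) },
        .{ .min = 0, .max = std.math.maxInt(u8) },
    };
    // @min: the result can never exceed the smallest operand's maximum -> fits in u8.
    try std.testing.expectEqual(@as(i128, std.math.maxInt(u8)), foldBounds(.min, &bounds).max);
    // @max: the result may need the widest operand's full range -> needs u32.
    try std.testing.expectEqual(@as(i128, std.math.maxInt(u32)), foldBounds(.max, &bounds).max);
}
```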
src/value.zig
@@ -4146,6 +4146,20 @@ pub const Value = struct {
         return val.toIntern() == .generic_poison;
     }
 
+    /// For an integer (comptime or fixed-width) `val`, returns the comptime-known bounds of the value.
+    /// If `val` is not undef, the bounds are both `val`.
+    /// If `val` is undef and has a fixed-width type, the bounds are the bounds of the type.
+    /// If `val` is undef and is a `comptime_int`, returns null.
+    pub fn intValueBounds(val: Value, mod: *Module) !?[2]Value {
+        if (!val.isUndef(mod)) return .{ val, val };
+        const ty = mod.intern_pool.typeOf(val.toIntern());
+        if (ty == .comptime_int_type) return null;
+        return .{
+            try ty.toType().minInt(mod, ty.toType()),
+            try ty.toType().maxInt(mod, ty.toType()),
+        };
+    }
+
     /// This type is not copyable since it may contain pointers to its inner data.
     pub const Payload = struct {
         tag: Tag,
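
The three cases `intValueBounds` distinguishes, restated as a standalone sketch over plain optionals instead of compiler `Value`s (illustrative only; `TypeBounds` and the simplified signature are invented for the example):

```zig
const std = @import("std");

/// Hypothetical stand-in for a fixed-width integer type's range;
/// null models comptime_int, which has no inherent bounds.
const TypeBounds = ?struct { min: i128, max: i128 };

/// Illustrative mirror of Value.intValueBounds:
/// - a defined value bounds itself;
/// - an undef value of a fixed-width type is bounded by its type;
/// - an undef comptime_int has no known bounds.
fn intValueBounds(value: ?i128, type_bounds: TypeBounds) ?[2]i128 {
    if (value) |v| return .{ v, v };
    if (type_bounds) |tb| return .{ tb.min, tb.max };
    return null;
}

test "illustrative: bounds of defined and undef values" {
    try std.testing.expectEqual([2]i128{ 7, 7 }, intValueBounds(7, .{ .min = 0, .max = 255 }).?);
    try std.testing.expectEqual([2]i128{ 0, 255 }, intValueBounds(null, .{ .min = 0, .max = 255 }).?);
    try std.testing.expect(intValueBounds(null, null) == null);
}
```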
test/behavior/maximum_minimum.zig
@@ -1,6 +1,7 @@
 const std = @import("std");
 const builtin = @import("builtin");
 const mem = std.mem;
+const assert = std.debug.assert;
 const expect = std.testing.expect;
 const expectEqual = std.testing.expectEqual;
 
@@ -210,3 +211,87 @@ test "@min/@max on comptime_int" {
     try expectEqual(-2, min);
     try expectEqual(2, max);
 }
+
+test "@min/@max notices bounds from types" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+    var x: u16 = 123;
+    var y: u32 = 456;
+    var z: u8 = 10;
+
+    const min = @min(x, y, z);
+    const max = @max(x, y, z);
+
+    comptime assert(@TypeOf(min) == u8);
+    comptime assert(@TypeOf(max) == u32);
+
+    try expectEqual(z, min);
+    try expectEqual(y, max);
+}
+
+test "@min/@max notices bounds from vector types" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+    var x: @Vector(2, u16) = .{ 30, 67 };
+    var y: @Vector(2, u32) = .{ 20, 500 };
+    var z: @Vector(2, u8) = .{ 60, 15 };
+
+    const min = @min(x, y, z);
+    const max = @max(x, y, z);
+
+    comptime assert(@TypeOf(min) == @Vector(2, u8));
+    comptime assert(@TypeOf(max) == @Vector(2, u32));
+
+    try expectEqual(@Vector(2, u8){ 20, 15 }, min);
+    try expectEqual(@Vector(2, u32){ 60, 500 }, max);
+}
+
+test "@min/@max notices bounds from types when comptime-known value is undef" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+    var x: u32 = 1_000_000;
+    const y: u16 = undefined;
+    // y is comptime-known, but is undef, so bounds cannot be refined using its value
+
+    const min = @min(x, y);
+    const max = @max(x, y);
+
+    comptime assert(@TypeOf(min) == u16);
+    comptime assert(@TypeOf(max) == u32);
+
+    // Cannot assert values as one was undefined
+}
+
+test "@min/@max notices bounds from vector types when element of comptime-known vector is undef" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+    var x: @Vector(2, u32) = .{ 1_000_000, 12345 };
+    const y: @Vector(2, u16) = .{ 10, undefined };
+    // y is comptime-known, but an element is undef, so bounds cannot be refined using its value
+
+    const min = @min(x, y);
+    const max = @max(x, y);
+
+    comptime assert(@TypeOf(min) == @Vector(2, u16));
+    comptime assert(@TypeOf(max) == @Vector(2, u32));
+
+    try expectEqual(@as(u16, 10), min[0]);
+    try expectEqual(@as(u32, 1_000_000), max[0]);
+    // Cannot assert values at index 1 as one was undefined
+}