Commit fb7060d3c2

Andrew Kelley <andrew@ziglang.org>
2022-01-31 00:23:31
stage2: implement shl_exact and shr_exact
These produce an undefined value when one bits are shifted out. New AIR instruction: shr_exact.
1 parent 0c30799
Changed files (13)
src/arch/aarch64/CodeGen.zig
@@ -535,12 +535,12 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .cmp_gt  => try self.airCmp(inst, .gt),
             .cmp_neq => try self.airCmp(inst, .neq),
 
-            .bool_and => try self.airBoolOp(inst),
-            .bool_or  => try self.airBoolOp(inst),
-            .bit_and  => try self.airBitAnd(inst),
-            .bit_or   => try self.airBitOr(inst),
-            .xor      => try self.airXor(inst),
-            .shr      => try self.airShr(inst),
+            .bool_and        => try self.airBoolOp(inst),
+            .bool_or         => try self.airBoolOp(inst),
+            .bit_and         => try self.airBitAnd(inst),
+            .bit_or          => try self.airBitOr(inst),
+            .xor             => try self.airXor(inst),
+            .shr, .shr_exact => try self.airShr(inst),
 
             .alloc           => try self.airAlloc(inst),
             .ret_ptr         => try self.airRetPtr(inst),
src/arch/arm/CodeGen.zig
@@ -527,12 +527,12 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .cmp_gt  => try self.airCmp(inst, .gt),
             .cmp_neq => try self.airCmp(inst, .neq),
 
-            .bool_and => try self.airBoolOp(inst),
-            .bool_or  => try self.airBoolOp(inst),
-            .bit_and  => try self.airBitAnd(inst),
-            .bit_or   => try self.airBitOr(inst),
-            .xor      => try self.airXor(inst),
-            .shr      => try self.airShr(inst),
+            .bool_and        => try self.airBoolOp(inst),
+            .bool_or         => try self.airBoolOp(inst),
+            .bit_and         => try self.airBitAnd(inst),
+            .bit_or          => try self.airBitOr(inst),
+            .xor             => try self.airXor(inst),
+            .shr, .shr_exact => try self.airShr(inst),
 
             .alloc           => try self.airAlloc(inst),
             .ret_ptr         => try self.airRetPtr(inst),
src/arch/riscv64/CodeGen.zig
@@ -514,12 +514,12 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .cmp_gt  => try self.airCmp(inst, .gt),
             .cmp_neq => try self.airCmp(inst, .neq),
 
-            .bool_and => try self.airBoolOp(inst),
-            .bool_or  => try self.airBoolOp(inst),
-            .bit_and  => try self.airBitAnd(inst),
-            .bit_or   => try self.airBitOr(inst),
-            .xor      => try self.airXor(inst),
-            .shr      => try self.airShr(inst),
+            .bool_and        => try self.airBoolOp(inst),
+            .bool_or         => try self.airBoolOp(inst),
+            .bit_and         => try self.airBitAnd(inst),
+            .bit_or          => try self.airBitOr(inst),
+            .xor             => try self.airXor(inst),
+            .shr, .shr_exact => try self.airShr(inst),
 
             .alloc           => try self.airAlloc(inst),
             .ret_ptr         => try self.airRetPtr(inst),
src/arch/wasm/CodeGen.zig
@@ -1442,8 +1442,8 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
         .bool_and => self.airBinOp(inst, .@"and"),
         .bool_or => self.airBinOp(inst, .@"or"),
         .rem => self.airBinOp(inst, .rem),
-        .shl => self.airBinOp(inst, .shl),
-        .shr => self.airBinOp(inst, .shr),
+        .shl, .shl_exact => self.airBinOp(inst, .shl),
+        .shr, .shr_exact => self.airBinOp(inst, .shr),
         .xor => self.airBinOp(inst, .xor),
 
         .cmp_eq => self.airCmp(inst, .eq),
@@ -1531,7 +1531,6 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
         .max,
         .min,
         .assembly,
-        .shl_exact,
         .shl_sat,
         .ret_addr,
         .clz,
src/arch/x86_64/CodeGen.zig
@@ -578,12 +578,12 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .cmp_gt  => try self.airCmp(inst, .gt),
             .cmp_neq => try self.airCmp(inst, .neq),
 
-            .bool_and => try self.airBoolOp(inst),
-            .bool_or  => try self.airBoolOp(inst),
-            .bit_and  => try self.airBitAnd(inst),
-            .bit_or   => try self.airBitOr(inst),
-            .xor      => try self.airXor(inst),
-            .shr      => try self.airShr(inst),
+            .bool_and        => try self.airBoolOp(inst),
+            .bool_or         => try self.airBoolOp(inst),
+            .bit_and         => try self.airBitAnd(inst),
+            .bit_or          => try self.airBitOr(inst),
+            .xor             => try self.airXor(inst),
+            .shr, .shr_exact => try self.airShr(inst),
 
             .alloc           => try self.airAlloc(inst),
             .ret_ptr         => try self.airRetPtr(inst),
src/codegen/llvm/bindings.zig
@@ -548,6 +548,12 @@ pub const Builder = opaque {
     pub const buildAShr = LLVMBuildAShr;
     extern fn LLVMBuildAShr(*const Builder, LHS: *const Value, RHS: *const Value, Name: [*:0]const u8) *const Value;
 
+    pub const buildLShrExact = ZigLLVMBuildLShrExact;
+    extern fn ZigLLVMBuildLShrExact(*const Builder, LHS: *const Value, RHS: *const Value, Name: [*:0]const u8) *const Value;
+
+    pub const buildAShrExact = ZigLLVMBuildAShrExact;
+    extern fn ZigLLVMBuildAShrExact(*const Builder, LHS: *const Value, RHS: *const Value, Name: [*:0]const u8) *const Value;
+
     pub const buildShl = LLVMBuildShl;
     extern fn LLVMBuildShl(*const Builder, LHS: *const Value, RHS: *const Value, Name: [*:0]const u8) *const Value;
 
src/codegen/c.zig
@@ -1470,7 +1470,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
             .bit_and         => try airBinOp(f, inst, " & "),
             .bit_or          => try airBinOp(f, inst, " | "),
             .xor             => try airBinOp(f, inst, " ^ "),
-            .shr             => try airBinOp(f, inst, " >> "),
+            .shr, .shr_exact => try airBinOp(f, inst, " >> "),
             .shl, .shl_exact => try airBinOp(f, inst, " << "),
             .not             => try airNot  (f, inst),
 
src/codegen/llvm.zig
@@ -2047,7 +2047,8 @@ pub const FuncGen = struct {
                 .bit_and, .bool_and => try self.airAnd(inst),
                 .bit_or, .bool_or   => try self.airOr(inst),
                 .xor                => try self.airXor(inst),
-                .shr                => try self.airShr(inst),
+                .shr                => try self.airShr(inst, false),
+                .shr_exact          => try self.airShr(inst, true),
 
                 .cmp_eq  => try self.airCmp(inst, .eq),
                 .cmp_gt  => try self.airCmp(inst, .gt),
@@ -3633,7 +3634,7 @@ pub const FuncGen = struct {
         return self.builder.buildUShlSat(lhs, casted_rhs, "");
     }
 
-    fn airShr(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+    fn airShr(self: *FuncGen, inst: Air.Inst.Index, is_exact: bool) !?*const llvm.Value {
         if (self.liveness.isUnused(inst))
             return null;
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
@@ -3645,11 +3646,20 @@ pub const FuncGen = struct {
             self.builder.buildZExt(rhs, try self.dg.llvmType(lhs_type), "")
         else
             rhs;
+        const is_signed_int = self.air.typeOfIndex(inst).isSignedInt();
 
-        if (self.air.typeOfIndex(inst).isSignedInt()) {
-            return self.builder.buildAShr(lhs, casted_rhs, "");
+        if (is_exact) {
+            if (is_signed_int) {
+                return self.builder.buildAShrExact(lhs, casted_rhs, "");
+            } else {
+                return self.builder.buildLShrExact(lhs, casted_rhs, "");
+            }
         } else {
-            return self.builder.buildLShr(lhs, casted_rhs, "");
+            if (is_signed_int) {
+                return self.builder.buildAShr(lhs, casted_rhs, "");
+            } else {
+                return self.builder.buildLShr(lhs, casted_rhs, "");
+            }
         }
     }
 
src/Air.zig
@@ -179,6 +179,9 @@ pub const Inst = struct {
         /// Shift right. `>>`
         /// Uses the `bin_op` field.
         shr,
+        /// Shift right. The shift produces a poison value if it shifts out any non-zero bits.
+        /// Uses the `bin_op` field.
+        shr_exact,
         /// Shift left. `<<`
         /// Uses the `bin_op` field.
         shl,
@@ -738,6 +741,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
         .ptr_add,
         .ptr_sub,
         .shr,
+        .shr_exact,
         .shl,
         .shl_exact,
         .shl_sat,
src/Liveness.zig
@@ -261,6 +261,7 @@ fn analyzeInst(
         .shl_exact,
         .shl_sat,
         .shr,
+        .shr_exact,
         .atomic_store_unordered,
         .atomic_store_monotonic,
         .atomic_store_release,
src/print_air.zig
@@ -138,6 +138,7 @@ const Writer = struct {
             .shl_exact,
             .shl_sat,
             .shr,
+            .shr_exact,
             .set_union_tag,
             .min,
             .max,
src/Sema.zig
@@ -666,7 +666,8 @@ fn analyzeBodyInner(
             .ptr_type_simple              => try sema.zirPtrTypeSimple(block, inst),
             .ref                          => try sema.zirRef(block, inst),
             .ret_err_value_code           => try sema.zirRetErrValueCode(block, inst),
-            .shr                          => try sema.zirShr(block, inst),
+            .shr                          => try sema.zirShr(block, inst, .shr),
+            .shr_exact                    => try sema.zirShr(block, inst, .shr_exact),
             .slice_end                    => try sema.zirSliceEnd(block, inst),
             .slice_sentinel               => try sema.zirSliceSentinel(block, inst),
             .slice_start                  => try sema.zirSliceStart(block, inst),
@@ -721,7 +722,6 @@ fn analyzeBodyInner(
             .pop_count                    => try sema.zirPopCount(block, inst),
             .byte_swap                    => try sema.zirByteSwap(block, inst),
             .bit_reverse                  => try sema.zirBitReverse(block, inst),
-            .shr_exact                    => try sema.zirShrExact(block, inst),
             .bit_offset_of                => try sema.zirBitOffsetOf(block, inst),
             .offset_of                    => try sema.zirOffsetOf(block, inst),
             .cmpxchg_strong               => try sema.zirCmpxchg(block, inst, .cmpxchg_strong),
@@ -7472,18 +7472,30 @@ fn zirShl(
         if (rhs_val.compareWithZero(.eq)) {
             return sema.addConstant(lhs_ty, lhs_val);
         }
+        const target = sema.mod.getTarget();
         const val = switch (air_tag) {
-            .shl_exact => return sema.fail(block, lhs_src, "TODO implement Sema for comptime shl_exact", .{}),
+            .shl_exact => val: {
+                const shifted = try lhs_val.shl(rhs_val, sema.arena);
+                if (lhs_ty.zigTypeTag() == .ComptimeInt) {
+                    break :val shifted;
+                }
+                const int_info = lhs_ty.intInfo(target);
+                const truncated = try shifted.intTrunc(sema.arena, int_info.signedness, int_info.bits);
+                if (truncated.compareHetero(.eq, shifted)) {
+                    break :val shifted;
+                }
+                return sema.addConstUndef(lhs_ty);
+            },
 
             .shl_sat => if (lhs_ty.zigTypeTag() == .ComptimeInt)
                 try lhs_val.shl(rhs_val, sema.arena)
             else
-                try lhs_val.shlSat(rhs_val, lhs_ty, sema.arena, sema.mod.getTarget()),
+                try lhs_val.shlSat(rhs_val, lhs_ty, sema.arena, target),
 
             .shl => if (lhs_ty.zigTypeTag() == .ComptimeInt)
                 try lhs_val.shl(rhs_val, sema.arena)
             else
-                try lhs_val.shlTrunc(rhs_val, lhs_ty, sema.arena, sema.mod.getTarget()),
+                try lhs_val.shlTrunc(rhs_val, lhs_ty, sema.arena, target),
 
             else => unreachable,
         };
@@ -7502,19 +7514,23 @@ fn zirShl(
     return block.addBinOp(air_tag, lhs, rhs);
 }
 
-fn zirShr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
+fn zirShr(
+    sema: *Sema,
+    block: *Block,
+    inst: Zir.Inst.Index,
+    air_tag: Air.Inst.Tag,
+) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
     defer tracy.end();
 
     const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
-    const src: LazySrcLoc = .{ .node_offset_bin_op = inst_data.src_node };
     const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
     const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const lhs = sema.resolveInst(extra.lhs);
     const rhs = sema.resolveInst(extra.rhs);
 
-    if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rhs_val| {
+    const runtime_src = if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rhs_val| rs: {
         if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lhs_val| {
             const lhs_ty = sema.typeOf(lhs);
             if (lhs_val.isUndef() or rhs_val.isUndef()) {
@@ -7524,19 +7540,29 @@ fn zirShr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Ins
             if (rhs_val.compareWithZero(.eq)) {
                 return sema.addConstant(lhs_ty, lhs_val);
             }
+            if (air_tag == .shr_exact) {
+                // Detect if any ones would be shifted out.
+                const bits = @intCast(u16, rhs_val.toUnsignedInt());
+                const truncated = try lhs_val.intTrunc(sema.arena, .unsigned, bits);
+                if (!truncated.compareWithZero(.eq)) {
+                    return sema.addConstUndef(lhs_ty);
+                }
+            }
             const val = try lhs_val.shr(rhs_val, sema.arena);
             return sema.addConstant(lhs_ty, val);
+        } else {
+            // Even if lhs is not comptime known, we can still deduce certain things based
+            // on rhs.
+            // If rhs is 0, return lhs without doing any calculations.
+            if (rhs_val.compareWithZero(.eq)) {
+                return lhs;
+            }
+            break :rs lhs_src;
         }
-        // Even if lhs is not comptime known, we can still deduce certain things based
-        // on rhs.
-        // If rhs is 0, return lhs without doing any calculations.
-        else if (rhs_val.compareWithZero(.eq)) {
-            return lhs;
-        }
-    }
+    } else rhs_src;
 
-    try sema.requireRuntimeBlock(block, src);
-    return block.addBinOp(.shr, lhs, rhs);
+    try sema.requireRuntimeBlock(block, runtime_src);
+    return block.addBinOp(air_tag, lhs, rhs);
 }
 
 fn zirBitwise(
@@ -11448,12 +11474,6 @@ fn zirBitReverse(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!
     return sema.fail(block, src, "TODO: Sema.zirBitReverse", .{});
 }
 
-fn zirShrExact(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
-    const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
-    const src = inst_data.src();
-    return sema.fail(block, src, "TODO: Sema.zirShrExact", .{});
-}
-
 fn zirBitOffsetOf(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const offset = try bitOffsetOf(sema, block, inst);
     return sema.addIntUnsigned(Type.comptime_int, offset);
test/behavior/math.zig
@@ -736,8 +736,6 @@ fn testShlTrunc(x: u16) !void {
 }
 
 test "exact shift left" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
-
     try testShlExact(0b00110101);
     comptime try testShlExact(0b00110101);
 }
@@ -747,8 +745,6 @@ fn testShlExact(x: u8) !void {
 }
 
 test "exact shift right" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
-
     try testShrExact(0b10110100);
     comptime try testShrExact(0b10110100);
 }
@@ -758,8 +754,6 @@ fn testShrExact(x: u8) !void {
 }
 
 test "shift left/right on u0 operand" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
-
     const S = struct {
         fn doTheTest() !void {
             var x: u0 = 0;