Commit 4ab13a359d

LemonBoy <thatlemon@gmail.com>
2020-03-10 23:04:49
ir: Fix shift code for u0 operands
1 parent 300fcea
Changed files (2)
src
test
stage1
behavior
src/ir.cpp
@@ -16635,34 +16635,47 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
     IrInstGen *casted_op2;
     IrBinOp op_id = bin_op_instruction->op_id;
     if (op1->value->type->id == ZigTypeIdComptimeInt) {
+        // comptime_int has no finite bit width
         casted_op2 = op2;
 
         if (op_id == IrBinOpBitShiftLeftLossy) {
             op_id = IrBinOpBitShiftLeftExact;
         }
 
-        if (casted_op2->value->data.x_bigint.is_negative) {
+        if (!instr_is_comptime(op2)) {
+            ir_add_error(ira, &bin_op_instruction->base.base,
+                buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known"));
+            return ira->codegen->invalid_inst_gen;
+        }
+
+        ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
+        if (op2_val == nullptr)
+            return ira->codegen->invalid_inst_gen;
+
+        if (op2_val->data.x_bigint.is_negative) {
             Buf *val_buf = buf_alloc();
-            bigint_append_buf(val_buf, &casted_op2->value->data.x_bigint, 10);
-            ir_add_error(ira, &casted_op2->base, buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
+            bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10);
+            ir_add_error(ira, &casted_op2->base,
+                buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
             return ira->codegen->invalid_inst_gen;
         }
     } else {
-        assert(op1->value->type->data.integral.bit_count > 0);
+        const unsigned bit_count = op1->value->type->data.integral.bit_count;
         ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
-            op1->value->type->data.integral.bit_count - 1);
+            bit_count > 0 ? bit_count - 1 : 0);
 
         casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
         if (type_is_invalid(casted_op2->value->type))
             return ira->codegen->invalid_inst_gen;
 
-        if (instr_is_comptime(casted_op2)) {
+        // This check is only valid iff op1 has at least one bit
+        if (bit_count > 0 && instr_is_comptime(casted_op2)) {
             ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
             if (op2_val == nullptr)
                 return ira->codegen->invalid_inst_gen;
 
             BigInt bit_count_value = {0};
-            bigint_init_unsigned(&bit_count_value, op1->value->type->data.integral.bit_count);
+            bigint_init_unsigned(&bit_count_value, bit_count);
 
             if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) {
                 ErrorMsg* msg = ir_add_error(ira,
@@ -16670,14 +16683,23 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
                     buf_sprintf("RHS of shift is too large for LHS type"));
                 add_error_note(ira->codegen, msg, op1->base.source_node,
                     buf_sprintf("type %s has only %u bits",
-                        buf_ptr(&op1->value->type->name),
-                        op1->value->type->data.integral.bit_count));
+                        buf_ptr(&op1->value->type->name), bit_count));
 
                 return ira->codegen->invalid_inst_gen;
             }
         }
     }
 
+    // Fast path for zero RHS
+    if (instr_is_comptime(casted_op2)) {
+        ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
+        if (op2_val == nullptr)
+            return ira->codegen->invalid_inst_gen;
+
+        if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ)
+            return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1);
+    }
+
     if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) {
         ZigValue *op1_val = ir_resolve_const(ira, op1, UndefBad);
         if (op1_val == nullptr)
@@ -16688,12 +16710,6 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
             return ira->codegen->invalid_inst_gen;
 
         return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val);
-    } else if (op1->value->type->id == ZigTypeIdComptimeInt) {
-        ir_add_error(ira, &bin_op_instruction->base.base,
-                buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known"));
-        return ira->codegen->invalid_inst_gen;
-    } else if (instr_is_comptime(casted_op2) && bigint_cmp_zero(&casted_op2->value->data.x_bigint) == CmpEQ) {
-        return ir_build_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1, CastOpNoop);
     }
 
     return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type,
test/stage1/behavior/math.zig
@@ -453,6 +453,25 @@ fn testShrExact(x: u8) void {
     expect(shifted == 0b00101101);
 }
 
+test "shift left/right on u0 operand" {
+    const S = struct {
+        fn doTheTest() void {
+            var x: u0 = 0;
+            var y: u0 = 0;
+            expectEqual(@as(u0, 0), x << 0);
+            expectEqual(@as(u0, 0), x >> 0);
+            expectEqual(@as(u0, 0), x << y);
+            expectEqual(@as(u0, 0), x >> y);
+            expectEqual(@as(u0, 0), @shlExact(x, 0));
+            expectEqual(@as(u0, 0), @shrExact(x, 0));
+            expectEqual(@as(u0, 0), @shlExact(x, y));
+            expectEqual(@as(u0, 0), @shrExact(x, y));
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
+
 test "comptime_int addition" {
     comptime {
         expect(35361831660712422535336160538497375248 + 101752735581729509668353361206450473702 == 137114567242441932203689521744947848950);