Commit 38b2d62092

Andrew Kelley <andrew@ziglang.org>
2021-12-08 23:19:13
stage1: saturating shl operates using LHS type
Saturating shift left (`<<|`) previously used the `ir_analyze_bin_op_math` codepath rather than the `ir_analyze_bit_shift` codepath, leading to it doing peer type resolution (incorrect) instead of using the LHS type as the number of bits to do the saturating against. This required implementing SIMD vector support for `@truncate`. Additionall, this commit adds a compile error for saturating shift left on a comptime_int. stage2 does not pass these new behavior tests yet. closes #10298
1 parent 64e2bfa
src/stage1/ir.cpp
@@ -9900,6 +9900,100 @@ static Stage1AirInst *ir_analyze_math_op(IrAnalyze *ira, Scope *scope, AstNode *
     return ir_implicit_cast(ira, result_instruction, type_entry);
 }
 
+static Stage1AirInst *ir_analyze_truncate(IrAnalyze *ira, Scope *scope, AstNode *source_node,
+        ZigType *dest_scalar_type, AstNode *dest_type_node,
+        Stage1AirInst *operand, AstNode *operand_node)
+{
+    if (dest_scalar_type->id != ZigTypeIdInt &&
+        dest_scalar_type->id != ZigTypeIdComptimeInt)
+    {
+        ir_add_error_node(ira, dest_type_node,
+            buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_scalar_type->name)));
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    ZigType *src_type = operand->value->type;
+    bool is_vector = (src_type->id == ZigTypeIdVector);
+    ZigType *src_scalar_type = is_vector ?
+        src_type->data.vector.elem_type : src_type;
+
+    ZigType *dest_type = is_vector ?
+        get_vector_type(ira->codegen, src_type->data.vector.len, dest_scalar_type) :
+        dest_scalar_type;
+
+    if (src_scalar_type->id != ZigTypeIdInt && src_scalar_type->id != ZigTypeIdComptimeInt) {
+        ir_add_error_node(ira, operand_node,
+            buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_scalar_type->name)));
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    if (dest_scalar_type->id == ZigTypeIdComptimeInt) {
+        return ir_implicit_cast2(ira, scope, operand_node, operand, dest_type);
+    }
+
+    if (src_scalar_type->id != ZigTypeIdComptimeInt) {
+        if (src_scalar_type->data.integral.is_signed != dest_scalar_type->data.integral.is_signed) {
+            const char *sign_str = dest_scalar_type->data.integral.is_signed ? "signed" : "unsigned";
+            ir_add_error_node(ira, operand_node, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_scalar_type->name)));
+            return ira->codegen->invalid_inst_gen;
+        } else if (src_scalar_type->data.integral.bit_count > 0 && src_scalar_type->data.integral.bit_count < dest_scalar_type->data.integral.bit_count) {
+            ir_add_error_node(ira, operand_node, buf_sprintf("type '%s' has fewer bits than destination type '%s'",
+                        buf_ptr(&src_scalar_type->name), buf_ptr(&dest_scalar_type->name)));
+            return ira->codegen->invalid_inst_gen;
+        }
+    }
+
+    if (instr_is_comptime(operand)) {
+        ZigValue *val = ir_resolve_const(ira, operand, UndefBad);
+        if (val == nullptr)
+            return ira->codegen->invalid_inst_gen;
+
+        if (!is_vector) {
+            Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type);
+            bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint,
+                    dest_scalar_type->data.integral.bit_count,
+                    dest_scalar_type->data.integral.is_signed);
+            return result;
+        }
+
+        Stage1AirInst *result_instruction = ir_const(ira, scope, source_node, dest_type);
+        ZigValue *out_val = result_instruction->value;
+        expand_undef_array(ira->codegen, operand->value);
+        out_val->special = ConstValSpecialUndef;
+        expand_undef_array(ira->codegen, out_val);
+        size_t len = dest_type->data.vector.len;
+        for (size_t i = 0; i < len; i += 1) {
+            ZigValue *scalar_operand_val = &operand->value->data.x_array.data.s_none.elements[i];
+            ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i];
+            assert(scalar_operand_val->type == dest_scalar_type);
+            assert(scalar_out_val->type == dest_scalar_type);
+
+            bigint_truncate(&scalar_out_val->data.x_bigint,
+                    &scalar_operand_val->data.x_bigint,
+                    dest_scalar_type->data.integral.bit_count,
+                    dest_scalar_type->data.integral.is_signed);
+
+            scalar_out_val->type = dest_scalar_type;
+            scalar_out_val->special = ConstValSpecialStatic;
+        }
+        out_val->type = dest_type;
+        out_val->special = ConstValSpecialStatic;
+        return result_instruction;
+    }
+
+    if (src_scalar_type->data.integral.bit_count == 0 ||
+        dest_scalar_type->data.integral.bit_count == 0)
+    {
+        Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type);
+        if (!is_vector) {
+            bigint_init_unsigned(&result->value->data.x_bigint, 0);
+        }
+        return result;
+    }
+
+    return ir_build_truncate_gen(ira, scope, source_node, dest_type, operand);
+}
+
 static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *bin_op_instruction) {
     Stage1AirInst *op1 = bin_op_instruction->op1->child;
     if (type_is_invalid(op1->value->type))
@@ -9951,6 +10045,12 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
         // comptime_int has no finite bit width
         casted_op2 = op2;
 
+        if (op_id == IrBinOpShlSat) {
+            ir_add_error_node(ira, bin_op_instruction->base.source_node,
+                buf_sprintf("saturating shift on a comptime_int which has unlimited bits"));
+            return ira->codegen->invalid_inst_gen;
+        }
+
         if (op_id == IrBinOpBitShiftLeftLossy) {
             op_id = IrBinOpBitShiftLeftExact;
         }
@@ -9972,6 +10072,13 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
                 buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
             return ira->codegen->invalid_inst_gen;
         }
+    } else if (op_id == IrBinOpShlSat) {
+        casted_op2 = ir_analyze_truncate(ira,
+                bin_op_instruction->base.scope, bin_op_instruction->base.source_node,
+                op1_scalar_type, bin_op_instruction->op1->source_node,
+                op2, bin_op_instruction->op2->source_node);
+        if (type_is_invalid(casted_op2->value->type))
+            return ira->codegen->invalid_inst_gen;
     } else {
         const unsigned bit_count = op1_scalar_type->data.integral.bit_count;
         ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
@@ -10030,8 +10137,9 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
         return ir_analyze_math_op(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1_type, op1_val, op_id, op2_val);
     }
 
-    return ir_build_bin_op_gen(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1->value->type,
-            op_id, op1, casted_op2, bin_op_instruction->safety_check_on);
+    return ir_build_bin_op_gen(ira,
+            bin_op_instruction->base.scope, bin_op_instruction->base.source_node,
+            op1->value->type, op_id, op1, casted_op2, bin_op_instruction->safety_check_on);
 }
 
 static bool ok_float_op(IrBinOp op) {
@@ -11035,6 +11143,7 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns
         case IrBinOpBitShiftLeftExact:
         case IrBinOpBitShiftRightLossy:
         case IrBinOpBitShiftRightExact:
+        case IrBinOpShlSat:
             return ir_analyze_bit_shift(ira, bin_op_instruction);
         case IrBinOpBinOr:
         case IrBinOpBinXor:
@@ -11057,7 +11166,6 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns
         case IrBinOpAddSat:
         case IrBinOpSubSat:
         case IrBinOpMultSat:
-        case IrBinOpShlSat:
             return ir_analyze_bin_op_math(ira, bin_op_instruction);
         case IrBinOpArrayCat:
             return ir_analyze_array_cat(ira, bin_op_instruction);
@@ -20017,59 +20125,13 @@ static Stage1AirInst *ir_analyze_instruction_truncate(IrAnalyze *ira, Stage1ZirI
     if (type_is_invalid(dest_type))
         return ira->codegen->invalid_inst_gen;
 
-    if (dest_type->id != ZigTypeIdInt &&
-        dest_type->id != ZigTypeIdComptimeInt)
-    {
-        ir_add_error(ira, dest_type_value, buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_type->name)));
-        return ira->codegen->invalid_inst_gen;
-    }
-
-    Stage1AirInst *target = instruction->target->child;
-    ZigType *src_type = target->value->type;
-    if (type_is_invalid(src_type))
-        return ira->codegen->invalid_inst_gen;
-
-    if (src_type->id != ZigTypeIdInt &&
-        src_type->id != ZigTypeIdComptimeInt)
-    {
-        ir_add_error(ira, target, buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_type->name)));
+    Stage1AirInst *operand = instruction->target->child;
+    if (type_is_invalid(operand->value->type))
         return ira->codegen->invalid_inst_gen;
-    }
-
-    if (dest_type->id == ZigTypeIdComptimeInt) {
-        return ir_implicit_cast2(ira, instruction->target->scope, instruction->target->source_node, target, dest_type);
-    }
 
-    if (src_type->id != ZigTypeIdComptimeInt) {
-        if (src_type->data.integral.is_signed != dest_type->data.integral.is_signed) {
-            const char *sign_str = dest_type->data.integral.is_signed ? "signed" : "unsigned";
-            ir_add_error(ira, target, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_type->name)));
-            return ira->codegen->invalid_inst_gen;
-        } else if (src_type->data.integral.bit_count > 0 && src_type->data.integral.bit_count < dest_type->data.integral.bit_count) {
-            ir_add_error(ira, target, buf_sprintf("type '%s' has fewer bits than destination type '%s'",
-                        buf_ptr(&src_type->name), buf_ptr(&dest_type->name)));
-            return ira->codegen->invalid_inst_gen;
-        }
-    }
-
-    if (instr_is_comptime(target)) {
-        ZigValue *val = ir_resolve_const(ira, target, UndefBad);
-        if (val == nullptr)
-            return ira->codegen->invalid_inst_gen;
-
-        Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type);
-        bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint,
-                dest_type->data.integral.bit_count, dest_type->data.integral.is_signed);
-        return result;
-    }
-
-    if (src_type->data.integral.bit_count == 0 || dest_type->data.integral.bit_count == 0) {
-        Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type);
-        bigint_init_unsigned(&result->value->data.x_bigint, 0);
-        return result;
-    }
-
-    return ir_build_truncate_gen(ira, instruction->base.scope, instruction->base.source_node, dest_type, target);
+    return ir_analyze_truncate(ira, instruction->base.scope, instruction->base.source_node,
+            dest_type, instruction->dest_type->source_node,
+            operand, instruction->target->source_node);
 }
 
 static Stage1AirInst *ir_analyze_int_cast(IrAnalyze *ira, Scope *scope, AstNode *source_node,
test/behavior/saturating_arithmetic_stage1.zig
@@ -0,0 +1,22 @@
+const std = @import("std");
+const expect = std.testing.expect;
+
+test "saturating shl uses the LHS type" {
+    const lhs_const: u8 = 1;
+    var lhs_var: u8 = 1;
+
+    const rhs_const: usize = 8;
+    var rhs_var: usize = 8;
+
+    try expect((lhs_const <<| 8) == 255);
+    try expect((lhs_const <<| rhs_const) == 255);
+    try expect((lhs_const <<| rhs_var) == 255);
+
+    try expect((lhs_var <<| 8) == 255);
+    try expect((lhs_var <<| rhs_const) == 255);
+    try expect((lhs_var <<| rhs_var) == 255);
+
+    try expect((@as(u8, 1) <<| 8) == 255);
+    try expect((@as(u8, 1) <<| rhs_const) == 255);
+    try expect((@as(u8, 1) <<| rhs_var) == 255);
+}
test/behavior/truncate_stage1.zig
@@ -0,0 +1,13 @@
+const std = @import("std");
+const expect = std.testing.expect;
+
+test "truncate on vectors" {
+    const S = struct {
+        fn doTheTest() !void {
+            var v1: @Vector(4, u16) = .{ 0xaabb, 0xccdd, 0xeeff, 0x1122 };
+            var v2 = @truncate(u8, v1);
+            try expect(std.mem.eql(u8, &@as([4]u8, v2), &[4]u8{ 0xbb, 0xdd, 0xff, 0x22 }));
+        }
+    };
+    try S.doTheTest();
+}
test/behavior.zig
@@ -171,6 +171,7 @@ test {
                 _ = @import("behavior/popcount_stage1.zig");
                 _ = @import("behavior/ptrcast_stage1.zig");
                 _ = @import("behavior/reflection.zig");
+                _ = @import("behavior/saturating_arithmetic_stage1.zig");
                 _ = @import("behavior/select.zig");
                 _ = @import("behavior/shuffle.zig");
                 _ = @import("behavior/sizeof_and_typeof_stage1.zig");
@@ -181,6 +182,7 @@ test {
                 _ = @import("behavior/switch_prong_err_enum.zig");
                 _ = @import("behavior/switch_prong_implicit_cast.zig");
                 _ = @import("behavior/switch_stage1.zig");
+                _ = @import("behavior/truncate_stage1.zig");
                 _ = @import("behavior/try.zig");
                 _ = @import("behavior/tuple.zig");
                 _ = @import("behavior/type.zig");