Commit 38b2d62092
Changed files (4)
src
stage1
src/stage1/ir.cpp
@@ -9900,6 +9900,100 @@ static Stage1AirInst *ir_analyze_math_op(IrAnalyze *ira, Scope *scope, AstNode *
return ir_implicit_cast(ira, result_instruction, type_entry);
}
+static Stage1AirInst *ir_analyze_truncate(IrAnalyze *ira, Scope *scope, AstNode *source_node,
+ ZigType *dest_scalar_type, AstNode *dest_type_node,
+ Stage1AirInst *operand, AstNode *operand_node)
+{
+ if (dest_scalar_type->id != ZigTypeIdInt &&
+ dest_scalar_type->id != ZigTypeIdComptimeInt)
+ {
+ ir_add_error_node(ira, dest_type_node,
+ buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_scalar_type->name)));
+ return ira->codegen->invalid_inst_gen;
+ }
+
+ ZigType *src_type = operand->value->type;
+ bool is_vector = (src_type->id == ZigTypeIdVector);
+ ZigType *src_scalar_type = is_vector ?
+ src_type->data.vector.elem_type : src_type;
+
+ ZigType *dest_type = is_vector ?
+ get_vector_type(ira->codegen, src_type->data.vector.len, dest_scalar_type) :
+ dest_scalar_type;
+
+ if (src_scalar_type->id != ZigTypeIdInt && src_scalar_type->id != ZigTypeIdComptimeInt) {
+ ir_add_error_node(ira, operand_node,
+ buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_scalar_type->name)));
+ return ira->codegen->invalid_inst_gen;
+ }
+
+ if (dest_scalar_type->id == ZigTypeIdComptimeInt) {
+ return ir_implicit_cast2(ira, scope, operand_node, operand, dest_type);
+ }
+
+ if (src_scalar_type->id != ZigTypeIdComptimeInt) {
+ if (src_scalar_type->data.integral.is_signed != dest_scalar_type->data.integral.is_signed) {
+ const char *sign_str = dest_scalar_type->data.integral.is_signed ? "signed" : "unsigned";
+ ir_add_error_node(ira, operand_node, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_scalar_type->name)));
+ return ira->codegen->invalid_inst_gen;
+ } else if (src_scalar_type->data.integral.bit_count > 0 && src_scalar_type->data.integral.bit_count < dest_scalar_type->data.integral.bit_count) {
+ ir_add_error_node(ira, operand_node, buf_sprintf("type '%s' has fewer bits than destination type '%s'",
+ buf_ptr(&src_scalar_type->name), buf_ptr(&dest_scalar_type->name)));
+ return ira->codegen->invalid_inst_gen;
+ }
+ }
+
+ if (instr_is_comptime(operand)) {
+ ZigValue *val = ir_resolve_const(ira, operand, UndefBad);
+ if (val == nullptr)
+ return ira->codegen->invalid_inst_gen;
+
+ if (!is_vector) {
+ Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type);
+ bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint,
+ dest_scalar_type->data.integral.bit_count,
+ dest_scalar_type->data.integral.is_signed);
+ return result;
+ }
+
+ Stage1AirInst *result_instruction = ir_const(ira, scope, source_node, dest_type);
+ ZigValue *out_val = result_instruction->value;
+ expand_undef_array(ira->codegen, operand->value);
+ out_val->special = ConstValSpecialUndef;
+ expand_undef_array(ira->codegen, out_val);
+ size_t len = dest_type->data.vector.len;
+ for (size_t i = 0; i < len; i += 1) {
+ ZigValue *scalar_operand_val = &operand->value->data.x_array.data.s_none.elements[i];
+ ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i];
+ assert(scalar_operand_val->type == dest_scalar_type);
+ assert(scalar_out_val->type == dest_scalar_type);
+
+ bigint_truncate(&scalar_out_val->data.x_bigint,
+ &scalar_operand_val->data.x_bigint,
+ dest_scalar_type->data.integral.bit_count,
+ dest_scalar_type->data.integral.is_signed);
+
+ scalar_out_val->type = dest_scalar_type;
+ scalar_out_val->special = ConstValSpecialStatic;
+ }
+ out_val->type = dest_type;
+ out_val->special = ConstValSpecialStatic;
+ return result_instruction;
+ }
+
+ if (src_scalar_type->data.integral.bit_count == 0 ||
+ dest_scalar_type->data.integral.bit_count == 0)
+ {
+ Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type);
+ if (!is_vector) {
+ bigint_init_unsigned(&result->value->data.x_bigint, 0);
+ }
+ return result;
+ }
+
+ return ir_build_truncate_gen(ira, scope, source_node, dest_type, operand);
+}
+
static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *bin_op_instruction) {
Stage1AirInst *op1 = bin_op_instruction->op1->child;
if (type_is_invalid(op1->value->type))
@@ -9951,6 +10045,12 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
// comptime_int has no finite bit width
casted_op2 = op2;
+ if (op_id == IrBinOpShlSat) {
+ ir_add_error_node(ira, bin_op_instruction->base.source_node,
+ buf_sprintf("saturating shift on a comptime_int which has unlimited bits"));
+ return ira->codegen->invalid_inst_gen;
+ }
+
if (op_id == IrBinOpBitShiftLeftLossy) {
op_id = IrBinOpBitShiftLeftExact;
}
@@ -9972,6 +10072,13 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
return ira->codegen->invalid_inst_gen;
}
+ } else if (op_id == IrBinOpShlSat) {
+ casted_op2 = ir_analyze_truncate(ira,
+ bin_op_instruction->base.scope, bin_op_instruction->base.source_node,
+ op1_scalar_type, bin_op_instruction->op1->source_node,
+ op2, bin_op_instruction->op2->source_node);
+ if (type_is_invalid(casted_op2->value->type))
+ return ira->codegen->invalid_inst_gen;
} else {
const unsigned bit_count = op1_scalar_type->data.integral.bit_count;
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
@@ -10030,8 +10137,9 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
return ir_analyze_math_op(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1_type, op1_val, op_id, op2_val);
}
- return ir_build_bin_op_gen(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1->value->type,
- op_id, op1, casted_op2, bin_op_instruction->safety_check_on);
+ return ir_build_bin_op_gen(ira,
+ bin_op_instruction->base.scope, bin_op_instruction->base.source_node,
+ op1->value->type, op_id, op1, casted_op2, bin_op_instruction->safety_check_on);
}
static bool ok_float_op(IrBinOp op) {
@@ -11035,6 +11143,7 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns
case IrBinOpBitShiftLeftExact:
case IrBinOpBitShiftRightLossy:
case IrBinOpBitShiftRightExact:
+ case IrBinOpShlSat:
return ir_analyze_bit_shift(ira, bin_op_instruction);
case IrBinOpBinOr:
case IrBinOpBinXor:
@@ -11057,7 +11166,6 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns
case IrBinOpAddSat:
case IrBinOpSubSat:
case IrBinOpMultSat:
- case IrBinOpShlSat:
return ir_analyze_bin_op_math(ira, bin_op_instruction);
case IrBinOpArrayCat:
return ir_analyze_array_cat(ira, bin_op_instruction);
@@ -20017,59 +20125,13 @@ static Stage1AirInst *ir_analyze_instruction_truncate(IrAnalyze *ira, Stage1ZirI
if (type_is_invalid(dest_type))
return ira->codegen->invalid_inst_gen;
- if (dest_type->id != ZigTypeIdInt &&
- dest_type->id != ZigTypeIdComptimeInt)
- {
- ir_add_error(ira, dest_type_value, buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_type->name)));
- return ira->codegen->invalid_inst_gen;
- }
-
- Stage1AirInst *target = instruction->target->child;
- ZigType *src_type = target->value->type;
- if (type_is_invalid(src_type))
- return ira->codegen->invalid_inst_gen;
-
- if (src_type->id != ZigTypeIdInt &&
- src_type->id != ZigTypeIdComptimeInt)
- {
- ir_add_error(ira, target, buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_type->name)));
+ Stage1AirInst *operand = instruction->target->child;
+ if (type_is_invalid(operand->value->type))
return ira->codegen->invalid_inst_gen;
- }
-
- if (dest_type->id == ZigTypeIdComptimeInt) {
- return ir_implicit_cast2(ira, instruction->target->scope, instruction->target->source_node, target, dest_type);
- }
- if (src_type->id != ZigTypeIdComptimeInt) {
- if (src_type->data.integral.is_signed != dest_type->data.integral.is_signed) {
- const char *sign_str = dest_type->data.integral.is_signed ? "signed" : "unsigned";
- ir_add_error(ira, target, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_type->name)));
- return ira->codegen->invalid_inst_gen;
- } else if (src_type->data.integral.bit_count > 0 && src_type->data.integral.bit_count < dest_type->data.integral.bit_count) {
- ir_add_error(ira, target, buf_sprintf("type '%s' has fewer bits than destination type '%s'",
- buf_ptr(&src_type->name), buf_ptr(&dest_type->name)));
- return ira->codegen->invalid_inst_gen;
- }
- }
-
- if (instr_is_comptime(target)) {
- ZigValue *val = ir_resolve_const(ira, target, UndefBad);
- if (val == nullptr)
- return ira->codegen->invalid_inst_gen;
-
- Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type);
- bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint,
- dest_type->data.integral.bit_count, dest_type->data.integral.is_signed);
- return result;
- }
-
- if (src_type->data.integral.bit_count == 0 || dest_type->data.integral.bit_count == 0) {
- Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type);
- bigint_init_unsigned(&result->value->data.x_bigint, 0);
- return result;
- }
-
- return ir_build_truncate_gen(ira, instruction->base.scope, instruction->base.source_node, dest_type, target);
+ return ir_analyze_truncate(ira, instruction->base.scope, instruction->base.source_node,
+ dest_type, instruction->dest_type->source_node,
+ operand, instruction->target->source_node);
}
static Stage1AirInst *ir_analyze_int_cast(IrAnalyze *ira, Scope *scope, AstNode *source_node,
test/behavior/saturating_arithmetic_stage1.zig
@@ -0,0 +1,22 @@
+const std = @import("std");
+const expect = std.testing.expect;
+
+test "saturating shl uses the LHS type" {
+ const lhs_const: u8 = 1;
+ var lhs_var: u8 = 1;
+
+ const rhs_const: usize = 8;
+ var rhs_var: usize = 8;
+
+ try expect((lhs_const <<| 8) == 255);
+ try expect((lhs_const <<| rhs_const) == 255);
+ try expect((lhs_const <<| rhs_var) == 255);
+
+ try expect((lhs_var <<| 8) == 255);
+ try expect((lhs_var <<| rhs_const) == 255);
+ try expect((lhs_var <<| rhs_var) == 255);
+
+ try expect((@as(u8, 1) <<| 8) == 255);
+ try expect((@as(u8, 1) <<| rhs_const) == 255);
+ try expect((@as(u8, 1) <<| rhs_var) == 255);
+}
test/behavior/truncate_stage1.zig
@@ -0,0 +1,13 @@
+const std = @import("std");
+const expect = std.testing.expect;
+
+test "truncate on vectors" {
+ const S = struct {
+ fn doTheTest() !void {
+ var v1: @Vector(4, u16) = .{ 0xaabb, 0xccdd, 0xeeff, 0x1122 };
+ var v2 = @truncate(u8, v1);
+ try expect(std.mem.eql(u8, &@as([4]u8, v2), &[4]u8{ 0xbb, 0xdd, 0xff, 0x22 }));
+ }
+ };
+ try S.doTheTest();
+}
test/behavior.zig
@@ -171,6 +171,7 @@ test {
_ = @import("behavior/popcount_stage1.zig");
_ = @import("behavior/ptrcast_stage1.zig");
_ = @import("behavior/reflection.zig");
+ _ = @import("behavior/saturating_arithmetic_stage1.zig");
_ = @import("behavior/select.zig");
_ = @import("behavior/shuffle.zig");
_ = @import("behavior/sizeof_and_typeof_stage1.zig");
@@ -181,6 +182,7 @@ test {
_ = @import("behavior/switch_prong_err_enum.zig");
_ = @import("behavior/switch_prong_implicit_cast.zig");
_ = @import("behavior/switch_stage1.zig");
+ _ = @import("behavior/truncate_stage1.zig");
_ = @import("behavior/try.zig");
_ = @import("behavior/tuple.zig");
_ = @import("behavior/type.zig");