Commit 52bb71867d

Andrew Kelley <andrew@ziglang.org>
2019-02-22 19:28:57
implement vector negation
also fix vector behavior tests, they weren't actually testing runtime vectors, but now they are. See #903
1 parent 2fe8a08
Changed files (5)
src/codegen.cpp
@@ -3229,7 +3229,8 @@ static LLVMValueRef ir_render_br(CodeGen *g, IrExecutable *executable, IrInstruc
 static LLVMValueRef ir_render_un_op(CodeGen *g, IrExecutable *executable, IrInstructionUnOp *un_op_instruction) {
     IrUnOp op_id = un_op_instruction->op_id;
     LLVMValueRef expr = ir_llvm_value(g, un_op_instruction->value);
-    ZigType *expr_type = un_op_instruction->value->value.type;
+    ZigType *operand_type = un_op_instruction->value->value.type;
+    ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type;
 
     switch (op_id) {
         case IrUnOpInvalid:
@@ -3239,16 +3240,16 @@ static LLVMValueRef ir_render_un_op(CodeGen *g, IrExecutable *executable, IrInst
         case IrUnOpNegation:
         case IrUnOpNegationWrap:
             {
-                if (expr_type->id == ZigTypeIdFloat) {
+                if (scalar_type->id == ZigTypeIdFloat) {
                     ZigLLVMSetFastMath(g->builder, ir_want_fast_math(g, &un_op_instruction->base));
                     return LLVMBuildFNeg(g->builder, expr, "");
-                } else if (expr_type->id == ZigTypeIdInt) {
+                } else if (scalar_type->id == ZigTypeIdInt) {
                     if (op_id == IrUnOpNegationWrap) {
                         return LLVMBuildNeg(g->builder, expr, "");
                     } else if (ir_want_runtime_safety(g, &un_op_instruction->base)) {
                         LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(expr));
-                        return gen_overflow_op(g, expr_type, AddSubMulSub, zero, expr);
-                    } else if (expr_type->data.integral.is_signed) {
+                        return gen_overflow_op(g, operand_type, AddSubMulSub, zero, expr);
+                    } else if (scalar_type->data.integral.is_signed) {
                         return LLVMBuildNSWNeg(g->builder, expr, "");
                     } else {
                         return LLVMBuildNUWNeg(g->builder, expr, "");
src/ir.cpp
@@ -14620,6 +14620,41 @@ static IrInstruction *ir_analyze_maybe(IrAnalyze *ira, IrInstructionUnOp *un_op_
     zig_unreachable();
 }
 
+static ErrorMsg *ir_eval_negation_scalar(IrAnalyze *ira, IrInstruction *source_instr, ZigType *scalar_type,
+        ConstExprValue *operand_val, ConstExprValue *scalar_out_val, bool is_wrap_op)
+{
+    bool is_float = (scalar_type->id == ZigTypeIdFloat || scalar_type->id == ZigTypeIdComptimeFloat);
+
+    bool ok_type = ((scalar_type->id == ZigTypeIdInt && scalar_type->data.integral.is_signed) ||
+        scalar_type->id == ZigTypeIdComptimeInt || (is_float && !is_wrap_op));
+
+    if (!ok_type) {
+        const char *fmt = is_wrap_op ? "invalid wrapping negation type: '%s'" : "invalid negation type: '%s'";
+        return ir_add_error(ira, source_instr, buf_sprintf(fmt, buf_ptr(&scalar_type->name)));
+    }
+
+    if (is_float) {
+        float_negate(scalar_out_val, operand_val);
+    } else if (is_wrap_op) {
+        bigint_negate_wrap(&scalar_out_val->data.x_bigint, &operand_val->data.x_bigint,
+                scalar_type->data.integral.bit_count);
+    } else {
+        bigint_negate(&scalar_out_val->data.x_bigint, &operand_val->data.x_bigint);
+    }
+
+    scalar_out_val->type = scalar_type;
+    scalar_out_val->special = ConstValSpecialStatic;
+
+    if (is_wrap_op || is_float || scalar_type->id == ZigTypeIdComptimeInt) {
+        return nullptr;
+    }
+
+    if (!bigint_fits_in_bits(&scalar_out_val->data.x_bigint, scalar_type->data.integral.bit_count, true)) {
+        return ir_add_error(ira, source_instr, buf_sprintf("negation caused overflow"));
+    }
+    return nullptr;
+}
+
 static IrInstruction *ir_analyze_negation(IrAnalyze *ira, IrInstructionUnOp *instruction) {
     IrInstruction *value = instruction->value->child;
     ZigType *expr_type = value->value.type;
@@ -14628,47 +14663,50 @@ static IrInstruction *ir_analyze_negation(IrAnalyze *ira, IrInstructionUnOp *ins
 
     bool is_wrap_op = (instruction->op_id == IrUnOpNegationWrap);
 
-    bool is_float = (expr_type->id == ZigTypeIdFloat || expr_type->id == ZigTypeIdComptimeFloat);
+    ZigType *scalar_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
 
-    if ((expr_type->id == ZigTypeIdInt && expr_type->data.integral.is_signed) ||
-        expr_type->id == ZigTypeIdComptimeInt || (is_float && !is_wrap_op))
-    {
-        if (instr_is_comptime(value)) {
-            ConstExprValue *target_const_val = ir_resolve_const(ira, value, UndefBad);
-            if (!target_const_val)
-                return ira->codegen->invalid_instruction;
+    if (instr_is_comptime(value)) {
+        ConstExprValue *operand_val = ir_resolve_const(ira, value, UndefBad);
+        if (!operand_val)
+            return ira->codegen->invalid_instruction;
 
-            IrInstruction *result = ir_const(ira, &instruction->base, expr_type);
-            ConstExprValue *out_val = &result->value;
-            if (is_float) {
-                float_negate(out_val, target_const_val);
-            } else if (is_wrap_op) {
-                bigint_negate_wrap(&out_val->data.x_bigint, &target_const_val->data.x_bigint,
-                        expr_type->data.integral.bit_count);
-            } else {
-                bigint_negate(&out_val->data.x_bigint, &target_const_val->data.x_bigint);
-            }
-            if (is_wrap_op || is_float || expr_type->id == ZigTypeIdComptimeInt) {
-                return result;
+        IrInstruction *result_instruction = ir_const(ira, &instruction->base, expr_type);
+        ConstExprValue *out_val = &result_instruction->value;
+        if (expr_type->id == ZigTypeIdVector) {
+            expand_undef_array(ira->codegen, operand_val);
+            out_val->special = ConstValSpecialUndef;
+            expand_undef_array(ira->codegen, out_val);
+            size_t len = expr_type->data.vector.len;
+            for (size_t i = 0; i < len; i += 1) {
+                ConstExprValue *scalar_operand_val = &operand_val->data.x_array.data.s_none.elements[i];
+                ConstExprValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i];
+                assert(scalar_operand_val->type == scalar_type);
+                assert(scalar_out_val->type == scalar_type);
+                ErrorMsg *msg = ir_eval_negation_scalar(ira, &instruction->base, scalar_type,
+                        scalar_operand_val, scalar_out_val, is_wrap_op);
+                if (msg != nullptr) {
+                    add_error_note(ira->codegen, msg, instruction->base.source_node,
+                        buf_sprintf("when computing vector element at index %" ZIG_PRI_usize, i));
+                    return ira->codegen->invalid_instruction;
+                }
             }
-
-            if (!bigint_fits_in_bits(&out_val->data.x_bigint, expr_type->data.integral.bit_count, true)) {
-                ir_add_error(ira, &instruction->base, buf_sprintf("negation caused overflow"));
+            out_val->type = expr_type;
+            out_val->special = ConstValSpecialStatic;
+        } else {
+            if (ir_eval_negation_scalar(ira, &instruction->base, scalar_type, operand_val, out_val,
+                        is_wrap_op) != nullptr)
+            {
                 return ira->codegen->invalid_instruction;
             }
-            return result;
         }
-
-        IrInstruction *result = ir_build_un_op(&ira->new_irb,
-                instruction->base.scope, instruction->base.source_node,
-                instruction->op_id, value);
-        result->value.type = expr_type;
-        return result;
+        return result_instruction;
     }
 
-    const char *fmt = is_wrap_op ? "invalid wrapping negation type: '%s'" : "invalid negation type: '%s'";
-    ir_add_error(ira, &instruction->base, buf_sprintf(fmt, buf_ptr(&expr_type->name)));
-    return ira->codegen->invalid_instruction;
+    IrInstruction *result = ir_build_un_op(&ira->new_irb,
+            instruction->base.scope, instruction->base.source_node,
+            instruction->op_id, value);
+    result->value.type = expr_type;
+    return result;
 }
 
 static IrInstruction *ir_analyze_bin_not(IrAnalyze *ira, IrInstructionUnOp *instruction) {
test/stage1/behavior/vector.zig
@@ -5,11 +5,28 @@ const expect = std.testing.expect;
 test "vector wrap operators" {
     const S = struct {
         fn doTheTest() void {
-            const v: @Vector(4, i32) = [4]i32{ 10, 20, 30, 40 };
-            const x: @Vector(4, i32) = [4]i32{ 1, 2, 3, 4 };
-            expect(mem.eql(i32, ([4]i32)(v +% x), [4]i32{ 11, 22, 33, 44 }));
-            expect(mem.eql(i32, ([4]i32)(v -% x), [4]i32{ 9, 18, 27, 36 }));
-            expect(mem.eql(i32, ([4]i32)(v *% x), [4]i32{ 10, 40, 90, 160 }));
+            var v: @Vector(4, i32) = [4]i32{ 2147483647, -2, 30, 40 };
+            var x: @Vector(4, i32) = [4]i32{ 1, 2147483647, 3, 4 };
+            expect(mem.eql(i32, ([4]i32)(v +% x), [4]i32{ -2147483648, 2147483645, 33, 44 }));
+            expect(mem.eql(i32, ([4]i32)(v -% x), [4]i32{ 2147483646, 2147483647, 27, 36 }));
+            expect(mem.eql(i32, ([4]i32)(v *% x), [4]i32{ 2147483647, 2, 90, 160 }));
+            var z: @Vector(4, i32) = [4]i32{ 1, 2, 3, -2147483648 };
+            expect(mem.eql(i32, ([4]i32)(-%z), [4]i32{ -1, -2, -3, -2147483648 }));
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
+
+test "vector int operators" {
+    const S = struct {
+        fn doTheTest() void {
+            var v: @Vector(4, i32) = [4]i32{ 10, 20, 30, 40 };
+            var x: @Vector(4, i32) = [4]i32{ 1, 2, 3, 4 };
+            expect(mem.eql(i32, ([4]i32)(v + x), [4]i32{ 11, 22, 33, 44 }));
+            expect(mem.eql(i32, ([4]i32)(v - x), [4]i32{ 9, 18, 27, 36 }));
+            expect(mem.eql(i32, ([4]i32)(v * x), [4]i32{ 10, 40, 90, 160 }));
+            expect(mem.eql(i32, ([4]i32)(-v), [4]i32{ -10, -20, -30, -40 }));
         }
     };
     S.doTheTest();
@@ -19,11 +36,12 @@ test "vector wrap operators" {
 test "vector float operators" {
     const S = struct {
         fn doTheTest() void {
-            const v: @Vector(4, f32) = [4]f32{ 10, 20, 30, 40 };
-            const x: @Vector(4, f32) = [4]f32{ 1, 2, 3, 4 };
+            var v: @Vector(4, f32) = [4]f32{ 10, 20, 30, 40 };
+            var x: @Vector(4, f32) = [4]f32{ 1, 2, 3, 4 };
             expect(mem.eql(f32, ([4]f32)(v + x), [4]f32{ 11, 22, 33, 44 }));
             expect(mem.eql(f32, ([4]f32)(v - x), [4]f32{ 9, 18, 27, 36 }));
             expect(mem.eql(f32, ([4]f32)(v * x), [4]f32{ 10, 40, 90, 160 }));
+            expect(mem.eql(f32, ([4]f32)(-x), [4]f32{ -1, -2, -3, -4 }));
         }
     };
     S.doTheTest();
@@ -33,8 +51,8 @@ test "vector float operators" {
 test "vector bit operators" {
     const S = struct {
         fn doTheTest() void {
-            const v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 };
-            const x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 };
+            var v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 };
+            var x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 };
             expect(mem.eql(u8, ([4]u8)(v ^ x), [4]u8{ 0b01011010, 0b10100101, 0b00000000, 0b11111111 }));
             expect(mem.eql(u8, ([4]u8)(v | x), [4]u8{ 0b11111010, 0b10101111, 0b10101010, 0b11111111 }));
             expect(mem.eql(u8, ([4]u8)(v & x), [4]u8{ 0b10100000, 0b00001010, 0b10101010, 0b00000000 }));
test/compile_errors.zig
@@ -1,6 +1,18 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompileErrorContext) void {
+    cases.addTest(
+        "comptime vector overflow shows the index",
+        \\comptime {
+        \\    var a: @Vector(4, u8) = []u8{ 1, 2, 255, 4 };
+        \\    var b: @Vector(4, u8) = []u8{ 5, 6, 1, 8 };
+        \\    var x = a + b;
+        \\}
+    ,
+        ".tmp_source.zig:4:15: error: operation caused overflow",
+        ".tmp_source.zig:4:15: note: when computing vector element at index 2",
+    );
+
     cases.addTest(
         "packed struct with fields of not allowed types",
         \\const A = packed struct {
test/runtime_safety.zig
@@ -118,6 +118,47 @@ pub fn addCases(cases: *tests.CompareOutputContext) void {
         \\}
     );
 
+    cases.addRuntimeSafety("vector integer subtraction overflow",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    var a: @Vector(4, u32) = []u32{ 1, 2, 8, 4 };
+        \\    var b: @Vector(4, u32) = []u32{ 5, 6, 7, 8 };
+        \\    const x = sub(b, a);
+        \\}
+        \\fn sub(a: @Vector(4, u32), b: @Vector(4, u32)) @Vector(4, u32) {
+        \\    return a - b;
+        \\}
+    );
+
+    cases.addRuntimeSafety("vector integer multiplication overflow",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    var a: @Vector(4, u8) = []u8{ 1, 2, 200, 4 };
+        \\    var b: @Vector(4, u8) = []u8{ 5, 6, 2, 8 };
+        \\    const x = mul(b, a);
+        \\}
+        \\fn mul(a: @Vector(4, u8), b: @Vector(4, u8)) @Vector(4, u8) {
+        \\    return a * b;
+        \\}
+    );
+
+    cases.addRuntimeSafety("vector integer negation overflow",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    var a: @Vector(4, i16) = []i16{ 1, -32768, 200, 4 };
+        \\    const x = neg(a);
+        \\}
+        \\fn neg(a: @Vector(4, i16)) @Vector(4, i16) {
+        \\    return -a;
+        \\}
+    );
+
     cases.addRuntimeSafety("integer subtraction overflow",
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);