Commit 0d6a7088dc

LemonBoy <thatlemon@gmail.com>
2020-11-01 19:51:42
stage1: Implement Add/Mul reduction operators
1 parent 6f3d6c1
Changed files (8)
lib/std/builtin.zig
@@ -106,6 +106,8 @@ pub const ReduceOp = enum {
     Xor,
     Min,
     Max,
+    Add,
+    Mul,
 };
 
 /// This data structure is used by the Zig language code generation and
src/stage1/all_types.hpp
@@ -2447,6 +2447,8 @@ enum ReduceOp {
     ReduceOp_xor,
     ReduceOp_min,
     ReduceOp_max,
+    ReduceOp_add,
+    ReduceOp_mul,
 };
 
 // synchronized with the code in define_builtin_compile_vars
src/stage1/codegen.cpp
@@ -5460,6 +5460,8 @@ static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, Ir
     assert(value_type->id == ZigTypeIdVector);
     ZigType *scalar_type = value_type->data.vector.elem_type;
 
+    ZigLLVMSetFastMath(g->builder, ir_want_fast_math(g, &instruction->base));
+
     LLVMValueRef result_val;
     switch (instruction->op) {
         case ReduceOp_and:
@@ -5490,6 +5492,24 @@ static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, Ir
                 result_val = ZigLLVMBuildFPMaxReduce(g->builder, value);
             } else zig_unreachable();
         } break;
+        case ReduceOp_add: {
+            if (scalar_type->id == ZigTypeIdInt) {
+                result_val = ZigLLVMBuildAddReduce(g->builder, value);
+            } else if (scalar_type->id == ZigTypeIdFloat) {
+                LLVMValueRef neutral_value = LLVMConstReal(
+                        get_llvm_type(g, scalar_type), -0.0);
+                result_val = ZigLLVMBuildFPAddReduce(g->builder, neutral_value, value);
+            } else zig_unreachable();
+        } break;
+        case ReduceOp_mul: {
+            if (scalar_type->id == ZigTypeIdInt) {
+                result_val = ZigLLVMBuildMulReduce(g->builder, value);
+            } else if (scalar_type->id == ZigTypeIdFloat) {
+                LLVMValueRef neutral_value = LLVMConstReal(
+                        get_llvm_type(g, scalar_type), 1.0);
+                result_val = ZigLLVMBuildFPMulReduce(g->builder, neutral_value, value);
+            } else zig_unreachable();
+        } break;
         default:
             zig_unreachable();
     }
src/stage1/ir.cpp
@@ -27046,7 +27046,8 @@ static ErrorMsg *ir_eval_reduce(IrAnalyze *ira, IrInst *source_instr, ReduceOp o
         return nullptr;
     }
 
-    if (op != ReduceOp_min && op != ReduceOp_max) {
+    // Evaluate and/or/xor.
+    if (op == ReduceOp_and || op == ReduceOp_or || op == ReduceOp_xor) {
         ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];
 
         copy_const_val(ira->codegen, out_value, first_elem_val);
@@ -27071,6 +27072,43 @@ static ErrorMsg *ir_eval_reduce(IrAnalyze *ira, IrInst *source_instr, ReduceOp o
         return nullptr;
     }
 
+    // Evaluate add/sub.
+    // Perform the reduction sequentially, starting from the neutral value.
+    if (op == ReduceOp_add || op == ReduceOp_mul) {
+        if (scalar_type->id == ZigTypeIdInt) {
+            if (op == ReduceOp_add) {
+                bigint_init_unsigned(&out_value->data.x_bigint, 0);
+            } else {
+                bigint_init_unsigned(&out_value->data.x_bigint, 1);
+            }
+        } else {
+            if (op == ReduceOp_add) {
+                float_init_f64(out_value, -0.0);
+            } else {
+                float_init_f64(out_value, 1.0);
+            }
+        }
+
+        for (size_t i = 0; i < len; i++) {
+            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+            IrBinOp bin_op;
+            switch (op) {
+                case ReduceOp_add: bin_op = IrBinOpAdd; break;
+                case ReduceOp_mul: bin_op = IrBinOpMult; break;
+                default: zig_unreachable();
+            }
+
+            ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
+                    out_value, bin_op, elem_val, out_value);
+            if (msg != nullptr)
+                return msg;
+        }
+
+        return nullptr;
+    }
+
+    // Evaluate min/max.
     ZigValue *candidate_elem_val = &value->data.x_array.data.s_none.elements[0];
 
     ZigValue *dummy_cmp_value = ira->codegen->pass1_arena->create<ZigValue>();
src/stage1/ir_print.cpp
@@ -1611,6 +1611,8 @@ static const char *reduce_op_str(ReduceOp op) {
         case ReduceOp_xor: return "Xor";
         case ReduceOp_min: return "Min";
         case ReduceOp_max: return "Max";
+        case ReduceOp_add: return "Add";
+        case ReduceOp_mul: return "Mul";
     }
     zig_unreachable();
 }
src/zig_llvm.cpp
@@ -1156,6 +1156,22 @@ LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val) {
     return wrap(unwrap(B)->CreateFPMinReduce(unwrap(Val)));
 }
 
+LLVMValueRef ZigLLVMBuildAddReduce(LLVMBuilderRef B, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateAddReduce(unwrap(Val)));
+}
+
+LLVMValueRef ZigLLVMBuildMulReduce(LLVMBuilderRef B, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateMulReduce(unwrap(Val)));
+}
+
+LLVMValueRef ZigLLVMBuildFPAddReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateFAddReduce(unwrap(Acc), unwrap(Val)));
+}
+
+LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val) {
+    return wrap(unwrap(B)->CreateFMulReduce(unwrap(Acc), unwrap(Val)));
+}
+
 static_assert((Triple::ArchType)ZigLLVM_UnknownArch == Triple::UnknownArch, "");
 static_assert((Triple::ArchType)ZigLLVM_arm == Triple::arm, "");
 static_assert((Triple::ArchType)ZigLLVM_armeb == Triple::armeb, "");
src/zig_llvm.h
@@ -462,6 +462,10 @@ LLVMValueRef ZigLLVMBuildIntMaxReduce(LLVMBuilderRef B, LLVMValueRef Val, bool i
 LLVMValueRef ZigLLVMBuildIntMinReduce(LLVMBuilderRef B, LLVMValueRef Val, bool is_signed);
 LLVMValueRef ZigLLVMBuildFPMaxReduce(LLVMBuilderRef B, LLVMValueRef Val);
 LLVMValueRef ZigLLVMBuildFPMinReduce(LLVMBuilderRef B, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildAddReduce(LLVMBuilderRef B, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildMulReduce(LLVMBuilderRef B, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildFPAddReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val);
+LLVMValueRef ZigLLVMBuildFPMulReduce(LLVMBuilderRef B, LLVMValueRef Acc, LLVMValueRef Val);
 
 #define ZigLLVM_DIFlags_Zero 0U
 #define ZigLLVM_DIFlags_Private 1U
test/stage1/behavior/vector.zig
@@ -4,6 +4,7 @@ const mem = std.mem;
 const math = std.math;
 const expect = std.testing.expect;
 const expectEqual = std.testing.expectEqual;
+const expectWithinEpsilon = std.testing.expectWithinEpsilon;
 const Vector = std.meta.Vector;
 
 test "implicit cast vector to array - bool" {
@@ -492,7 +493,17 @@ test "vector reduce operation" {
             const TX = @typeInfo(@TypeOf(x)).Array.child;
 
             var r = @reduce(op, @as(Vector(N, TX), x));
-            expectEqual(expected, r);
+            switch (@typeInfo(TX)) {
+                .Int, .Bool => expectEqual(expected, r),
+                .Float => {
+                    if (math.isNan(expected) != math.isNan(r)) {
+                        std.debug.panic("unexpected NaN value!", .{});
+                    } else {
+                        expectWithinEpsilon(expected, r, 0.0001);
+                    }
+                },
+                else => unreachable,
+            }
         }
         fn doTheTest() void {
             doTheTestReduce(.And, [4]bool{ true, false, true, true }, @as(bool, false));
@@ -510,14 +521,49 @@ test "vector reduce operation" {
             doTheTestReduce(.Min, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, -386));
             doTheTestReduce(.Max, [4]i32{ 1234567, -386, 0, 3 }, @as(i32, 1234567));
 
+            doTheTestReduce(.Add, [4]i32{ -9, -99, -999, -9999 }, @as(i32, -11106));
+            doTheTestReduce(.Add, [4]i64{ 9, 99, 999, 9999 }, @as(i64, 11106));
+
             doTheTestReduce(.Min, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 9));
             doTheTestReduce(.Max, [4]u32{ 99, 9999, 9, 99999 }, @as(u32, 99999));
 
+            doTheTestReduce(.Mul, [4]i32{ -9, -99, -999, 999 }, @as(i32, -889218891));
+            doTheTestReduce(.Mul, [4]i64{ 9, 99, 999, 9999 }, @as(i64, 8900199891));
+
             doTheTestReduce(.Min, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, -100.0));
             doTheTestReduce(.Max, [4]f32{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f32, 10.0e9));
 
             doTheTestReduce(.Min, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, -100.0));
             doTheTestReduce(.Max, [4]f64{ -10.3, 10.0e9, 13.0, -100.0 }, @as(f64, 10.0e9));
+
+            doTheTestReduce(.Add, [4]f32{ -1.9, 5.1, -60.3, 100.0 }, @as(f32, 42.9));
+            doTheTestReduce(.Add, [4]f64{ -1.9, 5.1, -60.3, 100.0 }, @as(f64, 42.9));
+
+            doTheTestReduce(.Mul, [4]f32{ -1.9, 5.1, -60.3, 100.0 }, @as(f32, 58430.7));
+            doTheTestReduce(.Mul, [4]f64{ -1.9, 5.1, -60.3, 100.0 }, @as(f64, 58430.7));
+
+            // Test the reduction on vectors containing NaNs.
+            const f16_nan = math.nan(f16);
+            const f32_nan = math.nan(f32);
+            const f64_nan = math.nan(f64);
+
+            doTheTestReduce(.Add, [4]f16{ -1.9, 5.1, f16_nan, 100.0 }, f16_nan);
+            doTheTestReduce(.Add, [4]f16{ -1.9, 5.1, f16_nan, 100.0 }, f16_nan);
+
+            doTheTestReduce(.Add, [4]f32{ -1.9, 5.1, f32_nan, 100.0 }, f32_nan);
+            doTheTestReduce(.Add, [4]f32{ -1.9, 5.1, f32_nan, 100.0 }, f32_nan);
+
+            doTheTestReduce(.Add, [4]f64{ -1.9, 5.1, f64_nan, 100.0 }, f64_nan);
+            doTheTestReduce(.Add, [4]f64{ -1.9, 5.1, f64_nan, 100.0 }, f64_nan);
+
+            doTheTestReduce(.Mul, [4]f16{ -1.9, 5.1, f16_nan, 100.0 }, f16_nan);
+            doTheTestReduce(.Mul, [4]f16{ -1.9, 5.1, f16_nan, 100.0 }, f16_nan);
+
+            doTheTestReduce(.Mul, [4]f32{ -1.9, 5.1, f32_nan, 100.0 }, f32_nan);
+            doTheTestReduce(.Mul, [4]f32{ -1.9, 5.1, f32_nan, 100.0 }, f32_nan);
+
+            doTheTestReduce(.Mul, [4]f64{ -1.9, 5.1, f64_nan, 100.0 }, f64_nan);
+            doTheTestReduce(.Mul, [4]f64{ -1.9, 5.1, f64_nan, 100.0 }, f64_nan);
         }
     };