Commit c8ed813097

John Schmidt <john.schmidt.h@gmail.com>
2022-03-15 23:25:38
Implement `@mulAdd` for vectors
1 parent 3125365
Changed files (3)
src
test
behavior
src/codegen/llvm.zig
@@ -5166,7 +5166,13 @@ pub const FuncGen = struct {
             intrinsic,
             libc: [*:0]const u8,
         };
-        const strat: Strat = switch (ty.floatBits(target)) {
+
+        const scalar_ty = if (ty.zigTypeTag() == .Vector)
+            ty.elemType()
+        else
+            ty;
+
+        const strat: Strat = switch (scalar_ty.floatBits(target)) {
             16, 32, 64 => Strat.intrinsic,
             80 => if (CType.longdouble.sizeInBits(target) == 80) Strat{ .intrinsic = {} } else Strat{ .libc = "__fmax" },
             // LLVM always lowers the fma builtin for f128 to fmal, which is for `long double`.
@@ -5175,17 +5181,46 @@ pub const FuncGen = struct {
             else => unreachable,
         };
 
-        const llvm_fn = switch (strat) {
-            .intrinsic => self.getIntrinsic("llvm.fma", &.{llvm_ty}),
-            .libc => |fn_name| self.dg.object.llvm_module.getNamedFunction(fn_name) orelse b: {
-                const param_types = [_]*const llvm.Type{ llvm_ty, llvm_ty, llvm_ty };
-                const fn_type = llvm.functionType(llvm_ty, &param_types, param_types.len, .False);
-                break :b self.dg.object.llvm_module.addFunction(fn_name, fn_type);
+        switch (strat) {
+            .intrinsic => {
+                const llvm_fn = self.getIntrinsic("llvm.fma", &.{llvm_ty});
+                const params = [_]*const llvm.Value{ mulend1, mulend2, addend };
+                return self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
             },
-        };
+            .libc => |fn_name| {
+                const scalar_llvm_ty = try self.dg.llvmType(scalar_ty);
+                const llvm_fn = self.dg.object.llvm_module.getNamedFunction(fn_name) orelse b: {
+                    const param_types = [_]*const llvm.Type{ scalar_llvm_ty, scalar_llvm_ty, scalar_llvm_ty };
+                    const fn_type = llvm.functionType(scalar_llvm_ty, &param_types, param_types.len, .False);
+                    break :b self.dg.object.llvm_module.addFunction(fn_name, fn_type);
+                };
+
+                if (ty.zigTypeTag() == .Vector) {
+                    const llvm_i32 = self.context.intType(32);
+                    const vector_llvm_ty = try self.dg.llvmType(ty);
+
+                    var i: usize = 0;
+                    var vector = vector_llvm_ty.getUndef();
+                    while (i < ty.vectorLen()) : (i += 1) {
+                        const index_i32 = llvm_i32.constInt(i, .False);
+
+                        const mulend1_elem = self.builder.buildExtractElement(mulend1, index_i32, "");
+                        const mulend2_elem = self.builder.buildExtractElement(mulend2, index_i32, "");
+                        const addend_elem = self.builder.buildExtractElement(addend, index_i32, "");
 
-        const params = [_]*const llvm.Value{ mulend1, mulend2, addend };
-        return self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
+                        const params = [_]*const llvm.Value{ mulend1_elem, mulend2_elem, addend_elem };
+                        const mul_add = self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
+
+                        vector = self.builder.buildInsertElement(vector, mul_add, index_i32, "");
+                    }
+
+                    return vector;
+                } else {
+                    const params = [_]*const llvm.Value{ mulend1, mulend2, addend };
+                    return self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
+                }
+            },
+        }
     }
 
     fn airShlWithOverflow(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
src/Sema.zig
@@ -14499,19 +14499,24 @@ fn zirMulAdd(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
 
     const target = sema.mod.getTarget();
 
+    const maybe_mulend1 = try sema.resolveMaybeUndefVal(block, mulend1_src, mulend1);
+    const maybe_mulend2 = try sema.resolveMaybeUndefVal(block, mulend2_src, mulend2);
+    const maybe_addend = try sema.resolveMaybeUndefVal(block, addend_src, addend);
+
     switch (ty.zigTypeTag()) {
-        .ComptimeFloat, .Float => {
-            const maybe_mulend1 = try sema.resolveMaybeUndefVal(block, mulend1_src, mulend1);
-            const maybe_mulend2 = try sema.resolveMaybeUndefVal(block, mulend2_src, mulend2);
-            const maybe_addend = try sema.resolveMaybeUndefVal(block, addend_src, addend);
+        .ComptimeFloat, .Float, .Vector => {},
+        else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{ty}),
+    }
 
-            const runtime_src = if (maybe_mulend1) |mulend1_val| rs: {
-                if (maybe_mulend2) |mulend2_val| {
-                    if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
+    const runtime_src = if (maybe_mulend1) |mulend1_val| rs: {
+        if (maybe_mulend2) |mulend2_val| {
+            if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
 
-                    if (maybe_addend) |addend_val| {
-                        if (addend_val.isUndef()) return sema.addConstUndef(ty);
+            if (maybe_addend) |addend_val| {
+                if (addend_val.isUndef()) return sema.addConstUndef(ty);
 
+                switch (ty.zigTypeTag()) {
+                    .ComptimeFloat, .Float => {
                         const result_val = try Value.mulAdd(
                             ty,
                             mulend1_val,
@@ -14521,47 +14526,70 @@ fn zirMulAdd(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
                             target,
                         );
                         return sema.addConstant(ty, result_val);
-                    } else {
-                        break :rs addend_src;
-                    }
-                } else {
-                    if (maybe_addend) |addend_val| {
-                        if (addend_val.isUndef()) return sema.addConstUndef(ty);
-                    }
-                    break :rs mulend2_src;
-                }
-            } else rs: {
-                if (maybe_mulend2) |mulend2_val| {
-                    if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
-                }
-                if (maybe_addend) |addend_val| {
-                    if (addend_val.isUndef()) return sema.addConstUndef(ty);
-                }
-                break :rs mulend1_src;
-            };
+                    },
+                    .Vector => {
+                        const scalar_ty = ty.scalarType();
+                        switch (scalar_ty.zigTypeTag()) {
+                            .ComptimeFloat, .Float => {},
+                            else => return sema.fail(block, src, "expected vector of floats, found vector of '{}'", .{scalar_ty}),
+                        }
 
-            try sema.requireRuntimeBlock(block, runtime_src);
-            return block.addInst(.{
-                .tag = .mul_add,
-                .data = .{ .pl_op = .{
-                    .operand = addend,
-                    .payload = try sema.addExtra(Air.Bin{
-                        .lhs = mulend1,
-                        .rhs = mulend2,
-                    }),
-                } },
-            });
-        },
-        .Vector => {
-            const scalar_ty = ty.scalarType();
-            switch (scalar_ty.zigTypeTag()) {
-                .ComptimeFloat, .Float => {},
-                else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{scalar_ty}),
+                        const vec_len = ty.vectorLen();
+                        const result_ty = try Type.vector(sema.arena, vec_len, scalar_ty);
+                        var mulend1_buf: Value.ElemValueBuffer = undefined;
+                        var mulend2_buf: Value.ElemValueBuffer = undefined;
+                        var addend_buf: Value.ElemValueBuffer = undefined;
+                        const elems = try sema.arena.alloc(Value, vec_len);
+                        for (elems) |*elem, i| {
+                            const mulend1_elem_val = mulend1_val.elemValueBuffer(i, &mulend1_buf);
+                            const mulend2_elem_val = mulend2_val.elemValueBuffer(i, &mulend2_buf);
+                            const addend_elem_val = addend_val.elemValueBuffer(i, &addend_buf);
+                            elem.* = try Value.mulAdd(
+                                scalar_ty,
+                                mulend1_elem_val,
+                                mulend2_elem_val,
+                                addend_elem_val,
+                                sema.arena,
+                                target,
+                            );
+                        }
+                        return sema.addConstant(
+                            result_ty,
+                            try Value.Tag.aggregate.create(sema.arena, elems),
+                        );
+                    },
+                    else => unreachable,
+                }
+            } else {
+                break :rs addend_src;
             }
-            return sema.fail(block, src, "TODO: implement @mulAdd for vectors", .{});
-        },
-        else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{ty}),
-    }
+        } else {
+            if (maybe_addend) |addend_val| {
+                if (addend_val.isUndef()) return sema.addConstUndef(ty);
+            }
+            break :rs mulend2_src;
+        }
+    } else rs: {
+        if (maybe_mulend2) |mulend2_val| {
+            if (mulend2_val.isUndef()) return sema.addConstUndef(ty);
+        }
+        if (maybe_addend) |addend_val| {
+            if (addend_val.isUndef()) return sema.addConstUndef(ty);
+        }
+        break :rs mulend1_src;
+    };
+
+    try sema.requireRuntimeBlock(block, runtime_src);
+    return block.addInst(.{
+        .tag = .mul_add,
+        .data = .{ .pl_op = .{
+            .operand = addend,
+            .payload = try sema.addExtra(Air.Bin{
+                .lhs = mulend1,
+                .rhs = mulend2,
+            }),
+        } },
+    });
 }
 
 fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
test/behavior/muladd.zig
@@ -78,3 +78,136 @@ fn testMulAdd128() !void {
     var c: f128 = 6.25;
     try expect(@mulAdd(f128, a, b, c) == 20);
 }
+
+fn vector16() !void {
+    var a = @Vector(4, f16){ 5.5, 5.5, 5.5, 5.5 };
+    var b = @Vector(4, f16){ 2.5, 2.5, 2.5, 2.5 };
+    var c = @Vector(4, f16){ 6.25, 6.25, 6.25, 6.25 };
+    var x = @mulAdd(@Vector(4, f16), a, b, c);
+
+    // TODO use `expectEqual` instead once stage2 supports it
+    // var expected = @Vector(4, f16){ 20, 20, 20, 20 };
+    // try expectEqual(expected, x);
+
+    try expect(x[0] == 20);
+    try expect(x[1] == 20);
+    try expect(x[2] == 20);
+    try expect(x[3] == 20);
+}
+
+test "vector f16" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    comptime try vector16();
+    try vector16();
+}
+
+fn vector32() !void {
+    var a = @Vector(4, f32){ 5.5, 5.5, 5.5, 5.5 };
+    var b = @Vector(4, f32){ 2.5, 2.5, 2.5, 2.5 };
+    var c = @Vector(4, f32){ 6.25, 6.25, 6.25, 6.25 };
+    var x = @mulAdd(@Vector(4, f32), a, b, c);
+
+    // TODO use `expectEqual` instead once stage2 supports it
+    // var expected = @Vector(4, f32){ 20, 20, 20, 20 };
+    // try expectEqual(expected, x);
+
+    try expect(x[0] == 20);
+    try expect(x[1] == 20);
+    try expect(x[2] == 20);
+    try expect(x[3] == 20);
+}
+
+test "vector f32" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    comptime try vector32();
+    try vector32();
+}
+
+fn vector64() !void {
+    var a = @Vector(4, f64){ 5.5, 5.5, 5.5, 5.5 };
+    var b = @Vector(4, f64){ 2.5, 2.5, 2.5, 2.5 };
+    var c = @Vector(4, f64){ 6.25, 6.25, 6.25, 6.25 };
+    var x = @mulAdd(@Vector(4, f64), a, b, c);
+
+    // TODO use `expectEqual` instead once stage2 supports it
+    // var expected = @Vector(4, f64){ 20, 20, 20, 20 };
+    // try expectEqual(expected, x);
+
+    try expect(x[0] == 20);
+    try expect(x[1] == 20);
+    try expect(x[2] == 20);
+    try expect(x[3] == 20);
+}
+
+test "vector f64" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    comptime try vector64();
+    try vector64();
+}
+
+fn vector80() !void {
+    var a = @Vector(4, f80){ 5.5, 5.5, 5.5, 5.5 };
+    var b = @Vector(4, f80){ 2.5, 2.5, 2.5, 2.5 };
+    var c = @Vector(4, f80){ 6.25, 6.25, 6.25, 6.25 };
+    var x = @mulAdd(@Vector(4, f80), a, b, c);
+    try expect(x[0] == 20);
+    try expect(x[1] == 20);
+    try expect(x[2] == 20);
+    try expect(x[3] == 20);
+}
+
+test "vector f80" {
+    if (true) {
+        // https://github.com/ziglang/zig/issues/11030
+        return error.SkipZigTest;
+    }
+
+    comptime try vector80();
+    try vector80();
+}
+
+fn vector128() !void {
+    var a = @Vector(4, f128){ 5.5, 5.5, 5.5, 5.5 };
+    var b = @Vector(4, f128){ 2.5, 2.5, 2.5, 2.5 };
+    var c = @Vector(4, f128){ 6.25, 6.25, 6.25, 6.25 };
+    var x = @mulAdd(@Vector(4, f128), a, b, c);
+
+    // TODO use `expectEqual` instead once stage2 supports it
+    // var expected = @Vector(4, f128){ 20, 20, 20, 20 };
+    // try expectEqual(expected, x);
+
+    try expect(x[0] == 20);
+    try expect(x[1] == 20);
+    try expect(x[2] == 20);
+    try expect(x[3] == 20);
+}
+
+test "vector f128" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    comptime try vector128();
+    try vector128();
+}