Commit b9a95f2dd9

Travis Staloch <twostepted@gmail.com>
2021-09-09 00:19:03
sat-arithmetic: add c backend support
- modify AstGen binOpExt()/assignBinOpExt() to accept generic extended payload T - rework Sema zirSatArithmetic() to use existing sema.analyzeArithmetic() by adding an `opt_extended` parameter. - add airSatOp() to codegen/c.zig - add saturating functions to src/link/C/zig.h
1 parent 29f4189
Changed files (5)
src/codegen/c.zig
@@ -885,17 +885,17 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
             // that wrapping is UB.
             .add, .ptr_add => try airBinOp( f, inst, " + "),
             .addwrap       => try airWrapOp(f, inst, " + ", "addw_"),
-            .addsat        => return o.dg.fail("TODO: C backend: implement codegen for addsat", .{}),
+            .addsat        => return f.fail("TODO: C backend: implement codegen for addsat", .{}),
             // TODO use a different strategy for sub that communicates to the optimizer
             // that wrapping is UB.
             .sub, .ptr_sub => try airBinOp( f, inst, " - "),
             .subwrap       => try airWrapOp(f, inst, " - ", "subw_"),
-            .subsat        => return o.dg.fail("TODO: C backend: implement codegen for subsat", .{}),
+            .subsat        => return f.fail("TODO: C backend: implement codegen for subsat", .{}),
             // TODO use a different strategy for mul that communicates to the optimizer
             // that wrapping is UB.
             .mul           => try airBinOp( f, inst, " * "),
             .mulwrap       => try airWrapOp(f, inst, " * ", "mulw_"),
-            .mulsat        => return o.dg.fail("TODO: C backend: implement codegen for mulsat", .{}),
+            .mulsat        => return f.fail("TODO: C backend: implement codegen for mulsat", .{}),
             // TODO use a different strategy for div that communicates to the optimizer
             // that wrapping is UB.
             .div           => try airBinOp( f, inst, " / "),
@@ -919,6 +919,8 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
 
             .shr        => try airBinOp(f, inst, " >> "),
             .shl        => try airBinOp(f, inst, " << "),
+            .shl_sat    => return f.fail("TODO: C backend: implement codegen for mulsat", .{}),
+            
 
             .not        => try airNot(  f, inst),
 
@@ -1312,6 +1314,118 @@ fn airWrapOp(
     return ret;
 }
 
+fn airSatOp(
+    o: *Object,
+    inst: Air.Inst.Index,
+    str_op: [*:0]const u8,
+    fn_op: [*:0]const u8,
+) !CValue {
+    if (o.liveness.isUnused(inst))
+        return CValue.none;
+
+    const bin_op = o.air.instructions.items(.data)[inst].bin_op;
+    const inst_ty = o.air.typeOfIndex(inst);
+    const int_info = inst_ty.intInfo(o.dg.module.getTarget());
+    const bits = int_info.bits;
+
+    // if it's an unsigned int with non-arbitrary bit size then we can just add
+    const ok_bits = switch (bits) {
+        8, 16, 32, 64, 128 => true,
+        else => false,
+    };
+
+    if (bits > 64) {
+        return f.fail("TODO: C backend: airSatOp for large integers", .{});
+    }
+
+    var min_buf: [80]u8 = undefined;
+    const min = switch (int_info.signedness) {
+        .unsigned => "0",
+        else => switch (inst_ty.tag()) {
+            .c_short => "SHRT_MIN",
+            .c_int => "INT_MIN",
+            .c_long => "LONG_MIN",
+            .c_longlong => "LLONG_MIN",
+            .isize => "INTPTR_MIN",
+            else => blk: {
+                const val = -1 * std.math.pow(i65, 2, @intCast(i65, bits - 1));
+                break :blk std.fmt.bufPrint(&min_buf, "{d}", .{val}) catch |err| switch (err) {
+                    error.NoSpaceLeft => unreachable,
+                    else => |e| return e,
+                };
+            },
+        },
+    };
+
+    var max_buf: [80]u8 = undefined;
+    const max = switch (inst_ty.tag()) {
+        .c_short => "SHRT_MAX",
+        .c_ushort => "USHRT_MAX",
+        .c_int => "INT_MAX",
+        .c_uint => "UINT_MAX",
+        .c_long => "LONG_MAX",
+        .c_ulong => "ULONG_MAX",
+        .c_longlong => "LLONG_MAX",
+        .c_ulonglong => "ULLONG_MAX",
+        .isize => "INTPTR_MAX",
+        .usize => "UINTPTR_MAX",
+        else => blk: {
+            const pow_bits = switch (int_info.signedness) {
+                .signed => bits - 1,
+                .unsigned => bits,
+            };
+            const val = std.math.pow(u65, 2, pow_bits) - 1;
+            break :blk std.fmt.bufPrint(&max_buf, "{}", .{val}) catch |err| switch (err) {
+                error.NoSpaceLeft => unreachable,
+                else => |e| return e,
+            };
+        },
+    };
+
+    const lhs = try o.resolveInst(bin_op.lhs);
+    const rhs = try o.resolveInst(bin_op.rhs);
+    const w = o.writer();
+
+    const ret = try o.allocLocal(inst_ty, .Mut);
+    try w.print(" = zig_{s}", .{fn_op});
+
+    switch (inst_ty.tag()) {
+        .isize => try w.writeAll("isize"),
+        .c_short => try w.writeAll("short"),
+        .c_int => try w.writeAll("int"),
+        .c_long => try w.writeAll("long"),
+        .c_longlong => try w.writeAll("longlong"),
+        else => {
+            const prefix_byte: u8 = switch (int_info.signedness) {
+                .signed => 'i',
+                .unsigned => 'u',
+            };
+            for ([_]u8{ 8, 16, 32, 64 }) |nbits| {
+                if (bits <= nbits) {
+                    try w.print("{c}{d}", .{ prefix_byte, nbits });
+                    break;
+                }
+            } else {
+                unreachable;
+            }
+        },
+    }
+
+    try w.writeByte('(');
+    try o.writeCValue(w, lhs);
+    try w.writeAll(", ");
+    try o.writeCValue(w, rhs);
+
+    if (int_info.signedness == .signed) {
+        try w.print(", {s}", .{min});
+    }
+
+    try w.print(", {s});", .{max});
+    try o.indent_writer.insertNewline();
+
+    return ret;
+}
+
 fn airNot(f: *Function, inst: Air.Inst.Index) !CValue {
     if (f.liveness.isUnused(inst))
         return CValue.none;
src/codegen/llvm.zig
@@ -2038,7 +2038,7 @@ pub const FuncGen = struct {
         const rhs = try self.resolveInst(bin_op.rhs);
         const inst_ty = self.air.typeOfIndex(inst);
 
-        if (inst_ty.isFloat()) return self.builder.buildFAdd(lhs, rhs, "");
+        if (inst_ty.isAnyFloat()) return self.builder.buildFAdd(lhs, rhs, "");
         if (ty == .wrapping)
             return self.builder.buildAdd(lhs, rhs, "")
         else if (ty == .saturated) {
@@ -2060,7 +2060,7 @@ pub const FuncGen = struct {
         const rhs = try self.resolveInst(bin_op.rhs);
         const inst_ty = self.air.typeOfIndex(inst);
 
-        if (inst_ty.isFloat()) return self.builder.buildFSub(lhs, rhs, "");
+        if (inst_ty.isAnyFloat()) return self.builder.buildFSub(lhs, rhs, "");
         if (ty == .wrapping)
             return self.builder.buildSub(lhs, rhs, "")
         else if (ty == .saturated) {
@@ -2082,7 +2082,7 @@ pub const FuncGen = struct {
         const rhs = try self.resolveInst(bin_op.rhs);
         const inst_ty = self.air.typeOfIndex(inst);
 
-        if (inst_ty.isFloat()) return self.builder.buildFMul(lhs, rhs, "");
+        if (inst_ty.isAnyFloat()) return self.builder.buildFMul(lhs, rhs, "");
         if (ty == .wrapping)
             return self.builder.buildMul(lhs, rhs, "")
         else if (ty == .saturated) {
src/link/C/zig.h
@@ -356,3 +356,96 @@ static inline long long zig_subw_longlong(long long lhs, long long rhs, long lon
     return (long long)(((unsigned long long)lhs) - ((unsigned long long)rhs));
 }
 
+/*
+ * Saturating aritmetic operations: add, sub, mul, shl
+ */
+#define zig_add_sat_u(ZT, T) static inline T zig_adds_##ZT(T x, T y, T max) { \
+    return (x > max - y) ? max : x + y; \
+}
+
+#define zig_add_sat_s(ZT, T, T2) static inline T zig_adds_##ZT(T2 x, T2 y, T2 min, T2 max) { \
+    T2 res = x + y; \
+    return (res < min) ? min : (res > max) ? max : res; \
+}
+
+zig_add_sat_u( u8,    uint8_t)
+zig_add_sat_s( i8,     int8_t,  int16_t)
+zig_add_sat_u(u16,   uint16_t)
+zig_add_sat_s(i16,    int16_t,  int32_t)
+zig_add_sat_u(u32,   uint32_t)
+zig_add_sat_s(i32,    int32_t,  int64_t)
+zig_add_sat_u(u64,   uint64_t)
+zig_add_sat_s(i64,    int64_t, int128_t)
+zig_add_sat_s(isize, intptr_t, int128_t)
+zig_add_sat_s(short,    short, int)
+zig_add_sat_s(int,        int, long)
+zig_add_sat_s(long,      long, long long)
+
+#define zig_sub_sat_u(ZT, T) static inline T zig_subs_##ZT(T x, T y, T max) { \
+    return (x > max + y) ? max : x - y; \
+}
+
+#define zig_sub_sat_s(ZT, T, T2) static inline T zig_subs_##ZT(T2 x, T2 y, T2 min, T2 max) { \
+    T2 res = x - y; \
+    return (res < min) ? min : (res > max) ? max : res; \
+}
+
+zig_sub_sat_u( u8,    uint8_t)
+zig_sub_sat_s( i8,     int8_t,  int16_t)
+zig_sub_sat_u(u16,   uint16_t)
+zig_sub_sat_s(i16,    int16_t,  int32_t)
+zig_sub_sat_u(u32,   uint32_t)
+zig_sub_sat_s(i32,    int32_t,  int64_t)
+zig_sub_sat_u(u64,   uint64_t)
+zig_sub_sat_s(i64,    int64_t, int128_t)
+zig_sub_sat_s(isize, intptr_t, int128_t)
+zig_sub_sat_s(short,    short, int)
+zig_sub_sat_s(int,        int, long)
+zig_sub_sat_s(long,      long, long long)
+
+
+#define zig_mul_sat_u(ZT, T, T2) static inline T zig_muls_##ZT(T2 x, T2 y, T2 max) { \
+    T2 res = x * y; \
+    return (res > max) ? max : res; \
+}
+
+#define zig_mul_sat_s(ZT, T, T2) static inline T zig_muls_##ZT(T2 x, T2 y, T2 min, T2 max) { \
+    T2 res = x * y; \
+    return (res < min) ? min : (res > max) ? max : res; \
+}
+
+zig_mul_sat_u(u8,    uint8_t,   uint16_t)
+zig_mul_sat_s(i8,     int8_t,    int16_t)
+zig_mul_sat_u(u16,   uint16_t,  uint32_t)
+zig_mul_sat_s(i16,    int16_t,   int32_t)
+zig_mul_sat_u(u32,   uint32_t,  uint64_t)
+zig_mul_sat_s(i32,    int32_t,   int64_t)
+zig_mul_sat_u(u64,   uint64_t, uint128_t)
+zig_mul_sat_s(i64,    int64_t,  int128_t)
+zig_mul_sat_s(isize, intptr_t, int128_t)
+zig_mul_sat_s(short,    short, int)
+zig_mul_sat_s(int,        int, long)
+zig_mul_sat_s(long,      long, long long)
+
+#define zig_shl_sat_u(ZT, T, bits) static inline T zig_shls_##ZT(T x, T y, T max) { \
+    T leading_zeros = __builtin_clz(x); \
+    return (leading_zeros + y > bits) ? max : x << y; \
+}
+
+#define zig_shl_sat_s(ZT, T, bits) static inline T zig_shls_##ZT(T x, T y, T min, T max) { \
+    T leading_zeros = __builtin_clz(x & ~max); \
+    return (leading_zeros + y > bits) ? max : x << y; \
+}
+
+zig_shl_sat_u(u8,    uint8_t,   8)
+zig_shl_sat_s(i8,     int8_t,   7)
+zig_shl_sat_u(u16,   uint16_t, 16)
+zig_shl_sat_s(i16,    int16_t, 15)
+zig_shl_sat_u(u32,   uint32_t, 32)
+zig_shl_sat_s(i32,    int32_t, 31)
+zig_shl_sat_u(u64,   uint64_t, 64)
+zig_shl_sat_s(i64,    int64_t, 63)
+zig_shl_sat_s(isize, intptr_t, 63)
+zig_shl_sat_s(short,    short, 15)
+zig_shl_sat_s(int,        int, 31)
+zig_shl_sat_s(long,      long, 63)
\ No newline at end of file
src/AstGen.zig
@@ -535,7 +535,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_bit_shift_left_sat => {
-            try assignBinOpExt(gz, scope, node, .shl_with_saturation);
+            try assignBinOpExt(gz, scope, node, .shl_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_bit_shift_right => {
@@ -568,7 +568,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_sub_sat => {
-            try assignBinOpExt(gz, scope, node, .sub_with_saturation);
+            try assignBinOpExt(gz, scope, node, .sub_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_mod => {
@@ -584,7 +584,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_add_sat => {
-            try assignBinOpExt(gz, scope, node, .add_with_saturation);
+            try assignBinOpExt(gz, scope, node, .add_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_mul => {
@@ -596,24 +596,24 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_mul_sat => {
-            try assignBinOpExt(gz, scope, node, .mul_with_saturation);
+            try assignBinOpExt(gz, scope, node, .mul_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
 
         // zig fmt: off
         .bit_shift_left     => return shiftOp(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl),
-        .bit_shift_left_sat => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl_with_saturation),
+        .bit_shift_left_sat => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl_with_saturation, Zir.Inst.SaturatingArithmetic),
         .bit_shift_right    => return shiftOp(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shr),
 
         .add      => return simpleBinOp(gz, scope, rl, node, .add),
         .add_wrap => return simpleBinOp(gz, scope, rl, node, .addwrap),
-        .add_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .add_with_saturation),
+        .add_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .add_with_saturation, Zir.Inst.SaturatingArithmetic),
         .sub      => return simpleBinOp(gz, scope, rl, node, .sub),
         .sub_wrap => return simpleBinOp(gz, scope, rl, node, .subwrap),
-        .sub_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .sub_with_saturation),
+        .sub_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .sub_with_saturation, Zir.Inst.SaturatingArithmetic),
         .mul      => return simpleBinOp(gz, scope, rl, node, .mul),
         .mul_wrap => return simpleBinOp(gz, scope, rl, node, .mulwrap),
-        .mul_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .mul_with_saturation),
+        .mul_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .mul_with_saturation, Zir.Inst.SaturatingArithmetic),
         .div      => return simpleBinOp(gz, scope, rl, node, .div),
         .mod      => return simpleBinOp(gz, scope, rl, node, .mod_rem),
         .bit_and  => {
@@ -2713,6 +2713,28 @@ fn assignOp(
     _ = try gz.addBin(.store, lhs_ptr, result);
 }
 
+// TODO: is there an existing way to do this?
+// TODO: likely rename this to reflect result_loc == .none or add more params to make it more general
+fn binOpExt(
+    gz: *GenZir,
+    scope: *Scope,
+    rl: ResultLoc,
+    infix_node: Ast.Node.Index,
+    lhs_node: Ast.Node.Index,
+    rhs_node: Ast.Node.Index,
+    tag: Zir.Inst.Extended,
+    comptime T: type,
+) InnerError!Zir.Inst.Ref {
+    const lhs = try expr(gz, scope, .none, lhs_node);
+    const rhs = try expr(gz, scope, .none, rhs_node);
+    const result = try gz.addExtendedPayload(tag, T{
+        .node = gz.nodeIndexToRelative(infix_node),
+        .lhs = lhs,
+        .rhs = rhs,
+    });
+    return rvalue(gz, rl, result, infix_node);
+}
+
 // TODO: is there an existing method to accomplish this?
 // TODO: likely rename this to indicate rhs type coercion or add more params to make it more general
 fn assignBinOpExt(
@@ -2720,8 +2742,8 @@ fn assignBinOpExt(
     scope: *Scope,
     infix_node: Ast.Node.Index,
     op_inst_tag: Zir.Inst.Extended,
+    comptime T: type,
 ) InnerError!void {
-    try emitDbgNode(gz, infix_node);
     const astgen = gz.astgen;
     const tree = astgen.tree;
     const node_datas = tree.nodes.items(.data);
@@ -2730,7 +2752,7 @@ fn assignBinOpExt(
     const lhs = try gz.addUnNode(.load, lhs_ptr, infix_node);
     const lhs_type = try gz.addUnNode(.typeof, lhs, infix_node);
     const rhs = try expr(gz, scope, .{ .coerced_ty = lhs_type }, node_datas[infix_node].rhs);
-    const result = try gz.addExtendedPayload(op_inst_tag, Zir.Inst.BinNode{
+    const result = try gz.addExtendedPayload(op_inst_tag, T{
         .node = gz.nodeIndexToRelative(infix_node),
         .lhs = lhs,
         .rhs = rhs,
@@ -7903,26 +7925,6 @@ fn shiftOp(
     return rvalue(gz, rl, result, node);
 }
 
-// TODO: is there an existing way to do this?
-// TODO: likely rename this to reflect result_loc == .none or add more params to make it more general
-fn binOpExt(
-    gz: *GenZir,
-    scope: *Scope,
-    rl: ResultLoc,
-    node: Ast.Node.Index,
-    lhs_node: Ast.Node.Index,
-    rhs_node: Ast.Node.Index,
-    tag: Zir.Inst.Extended,
-) InnerError!Zir.Inst.Ref {
-    const lhs = try expr(gz, scope, .none, lhs_node);
-    const rhs = try expr(gz, scope, .none, rhs_node);
-    const result = try gz.addExtendedPayload(tag, Zir.Inst.Bin{
-        .lhs = lhs,
-        .rhs = rhs,
-    });
-    return rvalue(gz, rl, result, node);
-}
-
 fn cImport(
     gz: *GenZir,
     scope: *Scope,
src/Sema.zig
@@ -694,10 +694,11 @@ fn zirExtended(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileEr
         .c_define           => return sema.zirCDefine(           block, extended),
         .wasm_memory_size   => return sema.zirWasmMemorySize(    block, extended),
         .wasm_memory_grow   => return sema.zirWasmMemoryGrow(    block, extended),
-        .add_with_saturation=> return sema.zirSatArithmetic(     block, extended),
-        .sub_with_saturation=> return sema.zirSatArithmetic(     block, extended),
-        .mul_with_saturation=> return sema.zirSatArithmetic(     block, extended),
-        .shl_with_saturation=> return sema.zirSatArithmetic(     block, extended),
+        .add_with_saturation, 
+        .sub_with_saturation, 
+        .mul_with_saturation, 
+        .shl_with_saturation, 
+                            => return sema.zirSatArithmetic(     block, extended),
         // zig fmt: on
     }
 }
@@ -6163,7 +6164,7 @@ fn zirNegate(
     const lhs = sema.resolveInst(.zero);
     const rhs = sema.resolveInst(inst_data.operand);
 
-    return sema.analyzeArithmetic(block, tag_override, lhs, rhs, src, lhs_src, rhs_src);
+    return sema.analyzeArithmetic(block, tag_override, lhs, rhs, src, lhs_src, rhs_src, null);
 }
 
 fn zirArithmetic(
@@ -6183,7 +6184,7 @@ fn zirArithmetic(
     const lhs = sema.resolveInst(extra.lhs);
     const rhs = sema.resolveInst(extra.rhs);
 
-    return sema.analyzeArithmetic(block, zir_tag, lhs, rhs, sema.src, lhs_src, rhs_src);
+    return sema.analyzeArithmetic(block, zir_tag, lhs, rhs, sema.src, lhs_src, rhs_src, null);
 }
 
 fn zirOverflowArithmetic(
@@ -6209,10 +6210,17 @@ fn zirSatArithmetic(
     defer tracy.end();
 
     const extra = sema.code.extraData(Zir.Inst.SaturatingArithmetic, extended.operand).data;
-    const src: LazySrcLoc = .{ .node_offset = extra.node };
-    return sema.mod.fail(&block.base, src, "TODO implement Sema.zirSatArithmetic", .{});
+    sema.src = .{ .node_offset_bin_op = extra.node };
+    const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = extra.node };
+    const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = extra.node };
+    const lhs = sema.resolveInst(extra.lhs);
+    const rhs = sema.resolveInst(extra.rhs);
+
+    return sema.analyzeArithmetic(block, .extended, lhs, rhs, sema.src, lhs_src, rhs_src, extended);
 }
 
+// TODO: audit - not sure if its a good idea to reuse this, adding `opt_extended` param
+// FIXME: somehow, rhs of <<| is required to be Log2T. this should accept T
 fn analyzeArithmetic(
     sema: *Sema,
     block: *Scope.Block,
@@ -6223,6 +6231,7 @@ fn analyzeArithmetic(
     src: LazySrcLoc,
     lhs_src: LazySrcLoc,
     rhs_src: LazySrcLoc,
+    opt_extended: ?Zir.Inst.Extended.InstData,
 ) CompileError!Air.Inst.Ref {
     const lhs_ty = sema.typeOf(lhs);
     const rhs_ty = sema.typeOf(rhs);