Commit fb16633ecb

Matt Knight <mattnite@protonmail.com>
2021-04-10 23:21:59
C backend: add/sub/mul wrapping for the C backend
1 parent 62d27fc
Changed files (3)
src
codegen
link
test
stage2
src/codegen/c.zig
@@ -846,18 +846,15 @@ pub fn genBody(o: *Object, body: ir.Body) error{ AnalysisFail, OutOfMemory }!voi
             // TODO use a different strategy for add that communicates to the optimizer
             // that wrapping is UB.
             .add => try genBinOp(o, inst.castTag(.add).?, " + "),
-            // TODO make this do wrapping arithmetic for signed ints
-            .addwrap => try genBinOp(o, inst.castTag(.add).?, " + "),
+            .addwrap => try genWrapOp(o, .add, inst.castTag(.addwrap).?),
             // TODO use a different strategy for sub that communicates to the optimizer
             // that wrapping is UB.
             .sub => try genBinOp(o, inst.castTag(.sub).?, " - "),
-            // TODO make this do wrapping arithmetic for signed ints
-            .subwrap => try genBinOp(o, inst.castTag(.sub).?, " - "),
+            .subwrap => try genWrapOp(o, .sub, inst.castTag(.subwrap).?),
             // TODO use a different strategy for mul that communicates to the optimizer
             // that wrapping is UB.
             .mul => try genBinOp(o, inst.castTag(.sub).?, " * "),
-            // TODO make this do wrapping multiplication for signed ints
-            .mulwrap => try genBinOp(o, inst.castTag(.sub).?, " * "),
+            .mulwrap => try genWrapOp(o, .mul, inst.castTag(.mulwrap).?),
             // TODO use a different strategy for div that communicates to the optimizer
             // that wrapping is UB.
             .div => try genBinOp(o, inst.castTag(.div).?, " / "),
@@ -1042,6 +1039,131 @@ fn genStore(o: *Object, inst: *Inst.BinOp) !CValue {
     return CValue.none;
 }
 
+const WrappingOp = enum {
+    add,
+    sub,
+    mul,
+};
+
+fn genWrapOp(o: *Object, op: WrappingOp, inst: *Inst.BinOp) !CValue {
+    if (inst.base.isUnused())
+        return CValue.none;
+
+    const is_signed = inst.base.ty.isSignedInt();
+
+    // if it's an unsigned int with non-arbitrary bit size then we can just add
+    if (!is_signed and inst.base.ty.tag() != .int_unsigned) {
+        return try genBinOp(o, inst, switch (op) {
+            .add => " + ",
+            .sub => " - ",
+            .mul => " * ",
+        });
+    }
+
+    var min_buf: [80]u8 = undefined;
+    const min = if (!is_signed)
+        "0"
+    else switch (inst.base.ty.tag()) {
+        .c_short => "SHRT_MIN",
+        .c_int => "INT_MIN",
+        .c_long => "LONG_MIN",
+        .c_longlong => "LLONG_MIN",
+        .isize => "INTPTR_MIN",
+        else => blk: {
+            // should be able to use undefined here since all the target specifics are handled
+            const bits = inst.base.ty.intInfo(@as(std.Target, undefined)).bits;
+            assert(bits <= 64); // TODO: large integers
+            const val = -1 * std.math.pow(i64, 2, @intCast(i64, bits - 1));
+            break :blk std.fmt.bufPrint(&min_buf, "{}", .{val}) catch |e|
+            // doesn't fit in some upwards error set, but should never happen
+                return if (e == error.NoSpaceLeft) unreachable else e;
+        },
+    };
+
+    var max_buf: [80]u8 = undefined;
+    const max = switch (inst.base.ty.tag()) {
+        .c_short => "SHRT_MAX",
+        .c_ushort => "USHRT_MAX",
+        .c_int => "INT_MAX",
+        .c_uint => "UINT_MAX",
+        .c_long => "LONG_MAX",
+        .c_ulong => "ULONG_MAX",
+        .c_longlong => "LLONG_MAX",
+        .c_ulonglong => "ULLONG_MAX",
+        .isize => "INTPTR_MAX",
+        .usize => "UINTPTR_MAX",
+        else => blk: {
+            // should be able to use undefined here since all the target specifics are handled
+            const bits = inst.base.ty.intInfo(@as(std.Target, undefined)).bits;
+            assert(bits <= 64); // TODO: large integers
+            const val = std.math.pow(u64, 2, if (is_signed) (bits - 1) else bits) - 1;
+            break :blk std.fmt.bufPrint(&max_buf, "{}", .{val}) catch |e|
+            // doesn't fit in some upwards error set, but should never happen
+                return if (e == error.NoSpaceLeft) unreachable else e;
+        },
+    };
+
+    const lhs = try o.resolveInst(inst.lhs);
+    const rhs = try o.resolveInst(inst.rhs);
+    const w = o.writer();
+
+    const ret = try o.allocLocal(inst.base.ty, .Mut);
+    try w.writeAll(" = zig_");
+    try w.writeAll(switch (op) {
+        .add => "addw_",
+        .sub => "subw_",
+        .mul => return o.dg.fail(.{ .node_offset = 0 }, "TODO: C backend: implement wrapping multiplication operator", .{}),
+    });
+
+    switch (inst.base.ty.tag()) {
+        .u8 => try w.writeAll("u8"),
+        .i8 => try w.writeAll("i8"),
+        .u16 => try w.writeAll("u16"),
+        .i16 => try w.writeAll("i16"),
+        .u32 => try w.writeAll("u32"),
+        .i32 => try w.writeAll("i32"),
+        .u64 => try w.writeAll("u64"),
+        .i64 => try w.writeAll("i64"),
+        .isize => try w.writeAll("isize"),
+        .c_short => try w.writeAll("short"),
+        .c_int => try w.writeAll("int"),
+        .c_long => try w.writeAll("long"),
+        .c_longlong => try w.writeAll("longlong"),
+        .int_signed, .int_unsigned => {
+            if (is_signed) {
+                try w.writeByte('i');
+            } else {
+                try w.writeByte('u');
+            }
+
+            const info_bits = inst.base.ty.intInfo(@as(std.Target, undefined)).bits;
+            inline for (.{ 8, 16, 32, 64 }) |nbits| {
+                if (info_bits <= nbits) {
+                    try w.print("{d}", .{nbits});
+                    break;
+                }
+            } else {
+                return o.dg.fail(.{ .node_offset = 0 }, "TODO: C backend: implement integer types larger than 64 bits", .{});
+            }
+        },
+        else => unreachable,
+    }
+
+    try w.writeByte('(');
+    try o.writeCValue(w, lhs);
+    try w.writeAll(", ");
+    try o.writeCValue(w, rhs);
+
+    if (is_signed) {
+        try w.print(", {s}", .{min});
+    }
+
+    try w.print(", {s});", .{max});
+    try o.indent_writer.insertNewline();
+
+    return ret;
+}
+
 fn genBinOp(o: *Object, inst: *Inst.BinOp, operator: []const u8) !CValue {
     if (inst.base.isUnused())
         return CValue.none;
src/link/C/zig.h
@@ -60,9 +60,161 @@
 #define zig_breakpoint() raise(SIGTRAP)
 #endif
 
+
+#define ZIG_UADDW(Type, lhs, rhs, max)                      \
+    Type thresh = max - rhs;                                \
+    if (lhs > thresh) {                                     \
+        return lhs - thresh - 1;                            \
+    } else {                                                \
+        return lhs + rhs;                                   \
+    }
+
+#define ZIG_SADDW(Type, lhs, rhs, min, max)                 \
+    if ((lhs > 0) && (rhs > 0)) {                           \
+        Type thresh = max - rhs;                            \
+        if (lhs > thresh) {                                 \
+            return min + lhs - thresh - 1;                  \
+        }                                                   \
+    } else if ((lhs < 0) && (rhs < 0)) {                    \
+        Type thresh = min - rhs;                            \
+        if (lhs < thresh) {                                 \
+            return max + lhs - thresh + 1;                  \
+        }                                                   \
+    }                                                       \
+                                                            \
+    return lhs + rhs;
+
+#define ZIG_USUBW(lhs, rhs, max)                            \
+    if (lhs < rhs) {                                        \
+        return max - rhs - lhs + 1;                         \
+    } else {                                                \
+        return lhs - rhs;                                   \
+    }
+
+#define ZIG_SSUBW(Type, lhs, rhs, min, max)                 \
+    if ((lhs > 0) && (rhs < 0)) {                           \
+        Type thresh = lhs - max;                            \
+        if (rhs < thresh) {                                 \
+            return min + (thresh - rhs - 1);                \
+        }                                                   \
+    } else if ((lhs < 0) && (rhs > 0)) {                    \
+        Type thresh = lhs - min;                            \
+        if (rhs > thresh) {                                 \
+            return max - (rhs - thresh - 1);                \
+        }                                                   \
+    }                                                       \
+    return lhs - rhs;
+
 #include <stdint.h>
 #include <stddef.h>
+#include <limits.h>
 #define int128_t __int128
 #define uint128_t unsigned __int128
 ZIG_EXTERN_C void *memcpy (void *ZIG_RESTRICT, const void *ZIG_RESTRICT, size_t);
 
+/* Wrapping addition operators */
+static inline uint8_t zig_addw_u8(uint8_t lhs, uint8_t rhs, uint8_t max) {
+    ZIG_UADDW(uint8_t, lhs, rhs, max);
+}
+
+static inline int8_t zig_addw_i8(int8_t lhs, int8_t rhs, int8_t min, int8_t max) {
+    ZIG_SADDW(int8_t, lhs, rhs, min, max);
+}
+
+static inline uint16_t zig_addw_u16(uint16_t lhs, uint16_t rhs, uint16_t max) {
+    ZIG_UADDW(uint16_t, lhs, rhs, max);
+}
+
+static inline int16_t zig_addw_i16(int16_t lhs, int16_t rhs, int16_t min, int16_t max) {
+    ZIG_SADDW(int16_t, lhs, rhs, min, max);
+}
+
+static inline uint32_t zig_addw_u32(uint32_t lhs, uint32_t rhs, uint32_t max) {
+    ZIG_UADDW(uint32_t, lhs, rhs, max);
+}
+
+static inline int32_t zig_addw_i32(int32_t lhs, int32_t rhs, int32_t min, int32_t max) {
+    ZIG_SADDW(int32_t, lhs, rhs, min, max);
+}
+
+static inline uint64_t zig_addw_u64(uint64_t lhs, uint64_t rhs, uint64_t max) {
+    ZIG_UADDW(uint64_t, lhs, rhs, max);
+}
+
+static inline int64_t zig_addw_i64(int64_t lhs, int64_t rhs, int64_t min, int64_t max) {
+    ZIG_SADDW(int64_t, lhs, rhs, min, max);
+}
+
+static inline intptr_t zig_addw_isize(intptr_t lhs, intptr_t rhs, intptr_t min, intptr_t max) {
+    return (intptr_t)(((uintptr_t)lhs) + ((uintptr_t)rhs));
+}
+
+static inline short zig_addw_short(short lhs, short rhs, short min, short max) {
+    return (short)(((unsigned short)lhs) + ((unsigned short)rhs));
+}
+
+static inline int zig_addw_int(int lhs, int rhs, int min, int max) {
+    return (int)(((unsigned)lhs) + ((unsigned)rhs));
+}
+
+static inline long zig_addw_long(long lhs, long rhs, long min, long max) {
+    return (long)(((unsigned long)lhs) + ((unsigned long)rhs));
+}
+
+static inline long long zig_addw_longlong(long long lhs, long long rhs, long long min, long long max) {
+    return (long long)(((unsigned long long)lhs) + ((unsigned long long)rhs));
+}
+
+/* Wrapping subtraction operators */
+static inline uint8_t zig_subw_u8(uint8_t lhs, uint8_t rhs, uint8_t max) {
+    ZIG_USUBW(lhs, rhs, max);
+}
+
+static inline int8_t zig_subw_i8(int8_t lhs, int8_t rhs, int8_t min, int8_t max) {
+    ZIG_SSUBW(int8_t, lhs, rhs, min, max);
+}
+
+static inline uint16_t zig_subw_u16(uint16_t lhs, uint16_t rhs, uint16_t max) {
+    ZIG_USUBW(lhs, rhs, max);
+}
+
+static inline int16_t zig_subw_i16(int16_t lhs, int16_t rhs, int16_t min, int16_t max) {
+    ZIG_SSUBW(int16_t, lhs, rhs, min, max);
+}
+
+static inline uint32_t zig_subw_u32(uint32_t lhs, uint32_t rhs, uint32_t max) {
+    ZIG_USUBW(lhs, rhs, max);
+}
+
+static inline int32_t zig_subw_i32(int32_t lhs, int32_t rhs, int32_t min, int32_t max) {
+    ZIG_SSUBW(int32_t, lhs, rhs, min, max);
+}
+
+static inline uint64_t zig_subw_u64(uint64_t lhs, uint64_t rhs, uint64_t max) {
+    ZIG_USUBW(lhs, rhs, max);
+}
+
+static inline int64_t zig_subw_i64(int64_t lhs, int64_t rhs, int64_t min, int64_t max) {
+    ZIG_SSUBW(int64_t, lhs, rhs, min, max);
+}
+
+static inline intptr_t zig_subw_isize(intptr_t lhs, intptr_t rhs, intptr_t min, intptr_t max) {
+    return (intptr_t)(((uintptr_t)lhs) - ((uintptr_t)rhs));
+}
+
+static inline short zig_subw_short(short lhs, short rhs, short min, short max) {
+    return (short)(((unsigned short)lhs) - ((unsigned short)rhs));
+}
+
+static inline int zig_subw_int(int lhs, int rhs, int min, int max) {
+    return (int)(((unsigned)lhs) - ((unsigned)rhs));
+}
+
+static inline long zig_subw_long(long lhs, long rhs, long min, long max) {
+    return (long)(((unsigned long)lhs) - ((unsigned long)rhs));
+}
+
+static inline long long zig_subw_longlong(long long lhs, long long rhs, long long min, long long max) {
+    return (long long)(((unsigned long long)lhs) - ((unsigned long long)rhs));
+}
+
test/stage2/cbe.zig
@@ -824,6 +824,55 @@ pub fn addCases(ctx: *TestContext) !void {
         , "");
     }
 
+    {
+        // TODO: move these cases into the programs themselves once stage 2 has array literals
+        // TODO: add u64 tests, ran into issues with the literal generated for std.math.maxInt(u64)
+        var case = ctx.exeFromCompiledC("Wrapping operations", .{});
+        const programs = comptime blk: {
+            const cases = .{
+                // Addition
+                .{ u3, "+%", 1, 1, 2 },
+                .{ u3, "+%", 7, 1, 0 },
+                .{ i3, "+%", 1, 1, 2 },
+                .{ i3, "+%", 3, 2, -3 },
+                .{ i3, "+%", -3, -2, 3 },
+                .{ c_int, "+%", 1, 1, 2 },
+                .{ c_int, "+%", std.math.maxInt(c_int), 2, std.math.minInt(c_int) + 1 },
+                .{ c_int, "+%", std.math.minInt(c_int) + 1, -2, std.math.maxInt(c_int) },
+
+                // Subtraction
+                .{ u3, "-%", 2, 1, 1 },
+                .{ u3, "-%", 0, 1, 7 },
+                .{ i3, "-%", 2, 1, 1 },
+                .{ i3, "-%", 3, -2, -3 },
+                .{ i3, "-%", -3, 2, 3 },
+                .{ c_int, "-%", 2, 1, 1 },
+                .{ c_int, "-%", std.math.maxInt(c_int), -2, std.math.minInt(c_int) + 1 },
+                .{ c_int, "-%", std.math.minInt(c_int) + 1, 2, std.math.maxInt(c_int) },
+            };
+
+            var ret: [cases.len][:0]const u8 = undefined;
+            for (cases) |c, i| ret[i] = std.fmt.comptimePrint(
+                \\export fn main() i32 {{
+                \\    var lhs: {0} = {2};
+                \\    var rhs: {0} = {3};
+                \\    var expected: {0} = {4};
+                \\
+                \\    if (expected != lhs {1s} rhs) {{
+                \\        return 1;
+                \\    }} else {{
+                \\        return 0;
+                \\    }}
+                \\}}
+                \\
+            , c);
+
+            break :blk ret;
+        };
+
+        inline for (programs) |prog| case.addCompareOutput(prog, "");
+    }
+
     ctx.h("simple header", linux_x64,
         \\export fn start() void{}
     ,