Commit baaec94fe4

Travis Staloch <twostepted@gmail.com>
2021-09-15 03:34:52
sat-arithmetic: create Sema.analyzeSatArithmetic
- similar to Sema.analyzeArithmetic but uses accepts Zir.Inst.Extended.InstData - missing support for Pointer types and comptime arithmetic
1 parent cd8d8ad
Changed files (2)
src/AstGen.zig
@@ -535,7 +535,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_bit_shift_left_sat => {
-            try assignBinOpExt(gz, scope, node, .shl_with_saturation, Zir.Inst.SaturatingArithmetic);
+            try assignOpExt(gz, scope, node, .shl_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_bit_shift_right => {
@@ -568,7 +568,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_sub_sat => {
-            try assignBinOpExt(gz, scope, node, .sub_with_saturation, Zir.Inst.SaturatingArithmetic);
+            try assignOpExt(gz, scope, node, .sub_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_mod => {
@@ -584,7 +584,7 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_add_sat => {
-            try assignBinOpExt(gz, scope, node, .add_with_saturation, Zir.Inst.SaturatingArithmetic);
+            try assignOpExt(gz, scope, node, .add_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_mul => {
@@ -596,26 +596,28 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
             return rvalue(gz, rl, .void_value, node);
         },
         .assign_mul_sat => {
-            try assignBinOpExt(gz, scope, node, .mul_with_saturation, Zir.Inst.SaturatingArithmetic);
+            try assignOpExt(gz, scope, node, .mul_with_saturation, Zir.Inst.SaturatingArithmetic);
             return rvalue(gz, rl, .void_value, node);
         },
 
         // zig fmt: off
         .bit_shift_left     => return shiftOp(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl),
-        .bit_shift_left_sat => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl_with_saturation, Zir.Inst.SaturatingArithmetic),
         .bit_shift_right    => return shiftOp(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shr),
 
         .add      => return simpleBinOp(gz, scope, rl, node, .add),
         .add_wrap => return simpleBinOp(gz, scope, rl, node, .addwrap),
-        .add_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .add_with_saturation, Zir.Inst.SaturatingArithmetic),
         .sub      => return simpleBinOp(gz, scope, rl, node, .sub),
         .sub_wrap => return simpleBinOp(gz, scope, rl, node, .subwrap),
-        .sub_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .sub_with_saturation, Zir.Inst.SaturatingArithmetic),
         .mul      => return simpleBinOp(gz, scope, rl, node, .mul),
         .mul_wrap => return simpleBinOp(gz, scope, rl, node, .mulwrap),
-        .mul_sat  => return binOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .mul_with_saturation, Zir.Inst.SaturatingArithmetic),
         .div      => return simpleBinOp(gz, scope, rl, node, .div),
         .mod      => return simpleBinOp(gz, scope, rl, node, .mod_rem),
+
+        .add_sat            => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .add_with_saturation, Zir.Inst.SaturatingArithmetic),
+        .sub_sat            => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .sub_with_saturation, Zir.Inst.SaturatingArithmetic),
+        .mul_sat            => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .mul_with_saturation, Zir.Inst.SaturatingArithmetic),
+        .bit_shift_left_sat => return simpleBinOpExt(gz, scope, rl, node, node_datas[node].lhs, node_datas[node].rhs, .shl_with_saturation, Zir.Inst.SaturatingArithmetic),
+        
         .bit_and  => {
             const current_ampersand_token = main_tokens[node];
             if (token_tags[current_ampersand_token + 1] == .ampersand) {
@@ -2713,9 +2715,7 @@ fn assignOp(
     _ = try gz.addBin(.store, lhs_ptr, result);
 }
 
-// TODO: is there an existing way to do this?
-// TODO: likely rename this to reflect result_loc == .none or add more params to make it more general
-fn binOpExt(
+fn simpleBinOpExt(
     gz: *GenZir,
     scope: *Scope,
     rl: ResultLoc,
@@ -2735,9 +2735,7 @@ fn binOpExt(
     return rvalue(gz, rl, result, infix_node);
 }
 
-// TODO: is there an existing method to accomplish this?
-// TODO: likely rename this to indicate rhs type coercion or add more params to make it more general
-fn assignBinOpExt(
+fn assignOpExt(
     gz: *GenZir,
     scope: *Scope,
     infix_node: Ast.Node.Index,
src/Sema.zig
@@ -6164,7 +6164,7 @@ fn zirNegate(
     const lhs = sema.resolveInst(.zero);
     const rhs = sema.resolveInst(inst_data.operand);
 
-    return sema.analyzeArithmetic(block, tag_override, lhs, rhs, src, lhs_src, rhs_src, null);
+    return sema.analyzeArithmetic(block, tag_override, lhs, rhs, src, lhs_src, rhs_src);
 }
 
 fn zirArithmetic(
@@ -6184,7 +6184,7 @@ fn zirArithmetic(
     const lhs = sema.resolveInst(extra.lhs);
     const rhs = sema.resolveInst(extra.rhs);
 
-    return sema.analyzeArithmetic(block, zir_tag, lhs, rhs, sema.src, lhs_src, rhs_src, null);
+    return sema.analyzeArithmetic(block, zir_tag, lhs, rhs, sema.src, lhs_src, rhs_src);
 }
 
 fn zirOverflowArithmetic(
@@ -6216,11 +6216,90 @@ fn zirSatArithmetic(
     const lhs = sema.resolveInst(extra.lhs);
     const rhs = sema.resolveInst(extra.rhs);
 
-    return sema.analyzeArithmetic(block, .extended, lhs, rhs, sema.src, lhs_src, rhs_src, extended);
+    return sema.analyzeSatArithmetic(block, lhs, rhs, sema.src, lhs_src, rhs_src, extended);
+}
+
+fn analyzeSatArithmetic(
+    sema: *Sema,
+    block: *Scope.Block,
+    lhs: Air.Inst.Ref,
+    rhs: Air.Inst.Ref,
+    src: LazySrcLoc,
+    lhs_src: LazySrcLoc,
+    rhs_src: LazySrcLoc,
+    extended: Zir.Inst.Extended.InstData,
+) CompileError!Air.Inst.Ref {
+    const lhs_ty = sema.typeOf(lhs);
+    const rhs_ty = sema.typeOf(rhs);
+    const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison();
+    const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison();
+    if (lhs_zig_ty_tag == .Vector and rhs_zig_ty_tag == .Vector) {
+        if (lhs_ty.arrayLen() != rhs_ty.arrayLen()) {
+            return sema.mod.fail(&block.base, src, "vector length mismatch: {d} and {d}", .{
+                lhs_ty.arrayLen(), rhs_ty.arrayLen(),
+            });
+        }
+        return sema.mod.fail(&block.base, src, "TODO implement support for vectors in zirBinOp", .{});
+    } else if (lhs_zig_ty_tag == .Vector or rhs_zig_ty_tag == .Vector) {
+        return sema.mod.fail(&block.base, src, "mixed scalar and vector operands to binary expression: '{}' and '{}'", .{
+            lhs_ty, rhs_ty,
+        });
+    }
+
+    if (lhs_zig_ty_tag == .Pointer or rhs_zig_ty_tag == .Pointer)
+        return sema.mod.fail(&block.base, src, "TODO implement support for pointers in zirSatArithmetic", .{});
+
+    const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
+    const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{ .override = &[_]LazySrcLoc{ lhs_src, rhs_src } });
+    const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
+    const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
+
+    const scalar_type = if (resolved_type.zigTypeTag() == .Vector)
+        resolved_type.elemType()
+    else
+        resolved_type;
+
+    const scalar_tag = scalar_type.zigTypeTag();
+
+    const is_int = scalar_tag == .Int or scalar_tag == .ComptimeInt;
+
+    if (!is_int)
+        return sema.mod.fail(&block.base, src, "invalid operands to binary expression: '{s}' and '{s}'", .{
+            @tagName(lhs_zig_ty_tag), @tagName(rhs_zig_ty_tag),
+        });
+
+    if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| {
+        if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| {
+            if (lhs_val.isUndef() or rhs_val.isUndef()) {
+                return sema.addConstUndef(resolved_type);
+            }
+            // incase rhs is 0, simply return lhs without doing any calculations
+            if (rhs_val.compareWithZero(.eq)) {
+                switch (extended.opcode) {
+                    .add_with_saturation, .sub_with_saturation => return sema.addConstant(scalar_type, lhs_val),
+                    else => {},
+                }
+            }
+
+            return sema.mod.fail(&block.base, src, "TODO implement comptime saturating arithmetic for operand '{s}'", .{@tagName(extended.opcode)});
+        } else {
+            try sema.requireRuntimeBlock(block, rhs_src);
+        }
+    } else {
+        try sema.requireRuntimeBlock(block, lhs_src);
+    }
+
+    const air_tag: Air.Inst.Tag = switch (extended.opcode) {
+        .add_with_saturation => .addsat,
+        .sub_with_saturation => .subsat,
+        .mul_with_saturation => .mulsat,
+        .shl_with_saturation => .shl_sat,
+        else => return sema.mod.fail(&block.base, src, "TODO implement arithmetic for extended opcode '{s}'", .{@tagName(extended.opcode)}),
+    };
+
+    return block.addBinOp(air_tag, casted_lhs, casted_rhs);
 }
 
-// TODO: audit - not sure if its a good idea to reuse this, adding `opt_extended` param
-// FIXME: somehow, rhs of <<| is required to be Log2T. this should accept T
 fn analyzeArithmetic(
     sema: *Sema,
     block: *Scope.Block,
@@ -6231,7 +6310,6 @@ fn analyzeArithmetic(
     src: LazySrcLoc,
     lhs_src: LazySrcLoc,
     rhs_src: LazySrcLoc,
-    opt_extended: ?Zir.Inst.Extended.InstData,
 ) CompileError!Air.Inst.Ref {
     const lhs_ty = sema.typeOf(lhs);
     const rhs_ty = sema.typeOf(rhs);