Commit 3a059ebe4c

Luuk de Gram <luuk@degram.dev>
2022-05-23 22:06:27
wasm: Fixes for error union semantics
1 parent c90a97f
Changed files (3)
src
test
behavior
src/arch/wasm/CodeGen.zig
@@ -636,7 +636,7 @@ fn resolveInst(self: *Self, ref: Air.Inst.Ref) InnerError!WValue {
     // means we must generate it from a constant.
     const val = self.air.value(ref).?;
     const ty = self.air.typeOf(ref);
-    if (!ty.hasRuntimeBitsIgnoreComptime() and !ty.isInt()) {
+    if (!ty.hasRuntimeBitsIgnoreComptime() and !ty.isInt() and !ty.isError()) {
         gop.value_ptr.* = WValue{ .none = {} };
         return gop.value_ptr.*;
     }
@@ -804,6 +804,8 @@ fn genFunctype(gpa: Allocator, fn_info: Type.Payload.Function.Data, target: std.
         } else {
             try returns.append(typeToValtype(fn_info.return_type, target));
         }
+    } else if (fn_info.return_type.isError()) {
+        try returns.append(.i32);
     }
 
     // param types
@@ -1373,10 +1375,15 @@ fn isByRef(ty: Type, target: std.Target) bool {
         .Int => return ty.intInfo(target).bits > 64,
         .Float => return ty.floatBits(target) > 64,
         .ErrorUnion => {
-            const has_tag = ty.errorUnionSet().hasRuntimeBitsIgnoreComptime();
-            const has_pl = ty.errorUnionPayload().hasRuntimeBitsIgnoreComptime();
-            if (!has_tag or !has_pl) return false;
-            return ty.hasRuntimeBitsIgnoreComptime();
+            const err_ty = ty.errorUnionSet();
+            const pl_ty = ty.errorUnionPayload();
+            if (err_ty.errorSetCardinality() == .zero) {
+                return isByRef(pl_ty, target);
+            }
+            if (!pl_ty.hasRuntimeBitsIgnoreComptime()) {
+                return false;
+            }
+            return true;
         },
         .Optional => {
             if (ty.isPtrLikeOptional()) return false;
@@ -1624,13 +1631,14 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 fn airRet(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const operand = try self.resolveInst(un_op);
-    const ret_ty = self.decl.ty.fnReturnType();
+    const fn_info = self.decl.ty.fnInfo();
+    const ret_ty = fn_info.return_type;
 
     // result must be stored in the stack and we return a pointer
     // to the stack instead
     if (self.return_value != .none) {
-        try self.store(self.return_value, operand, self.decl.ty.fnReturnType(), 0);
-    } else if (self.decl.ty.fnInfo().cc == .C and ret_ty.hasRuntimeBitsIgnoreComptime()) {
+        try self.store(self.return_value, operand, ret_ty, 0);
+    } else if (fn_info.cc == .C and ret_ty.hasRuntimeBitsIgnoreComptime()) {
         switch (ret_ty.zigTypeTag()) {
             // Aggregate types can be lowered as a singular value
             .Struct, .Union => {
@@ -1650,7 +1658,11 @@ fn airRet(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
             else => try self.emitWValue(operand),
         }
     } else {
-        try self.emitWValue(operand);
+        if (!ret_ty.hasRuntimeBitsIgnoreComptime() and ret_ty.isError()) {
+            try self.addImm32(0);
+        } else {
+            try self.emitWValue(operand);
+        }
     }
     try self.restoreStackPointer();
     try self.addTag(.@"return");
@@ -1675,7 +1687,13 @@ fn airRetLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const operand = try self.resolveInst(un_op);
     const ret_ty = self.air.typeOf(un_op).childType();
-    if (!ret_ty.hasRuntimeBitsIgnoreComptime()) return WValue.none;
+    if (!ret_ty.hasRuntimeBitsIgnoreComptime()) {
+        if (ret_ty.isError()) {
+            try self.addImm32(0);
+        } else {
+            return WValue.none;
+        }
+    }
 
     if (!firstParamSRet(self.decl.ty.fnInfo(), self.target)) {
         const result = try self.load(operand, ret_ty, 0);
@@ -1723,8 +1741,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
 
     const sret = if (first_param_sret) blk: {
         const sret_local = try self.allocStack(ret_ty);
-        const ptr_offset = try self.buildPointerOffset(sret_local, 0, .new);
-        try self.emitWValue(ptr_offset);
+        try self.lowerToStack(sret_local);
         break :blk sret_local;
     } else WValue{ .none = {} };
 
@@ -1754,7 +1771,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         try self.addLabel(.call_indirect, fn_type_index);
     }
 
-    if (self.liveness.isUnused(inst) or !ret_ty.hasRuntimeBitsIgnoreComptime()) {
+    if (self.liveness.isUnused(inst) or (!ret_ty.hasRuntimeBitsIgnoreComptime() and !ret_ty.isError())) {
         return WValue.none;
     } else if (ret_ty.isNoReturn()) {
         try self.addTag(.@"unreachable");
@@ -1796,8 +1813,11 @@ fn store(self: *Self, lhs: WValue, rhs: WValue, ty: Type, offset: u32) InnerErro
         .ErrorUnion => {
             const err_ty = ty.errorUnionSet();
             const pl_ty = ty.errorUnionPayload();
+            if (err_ty.errorSetCardinality() == .zero) {
+                return self.store(lhs, rhs, pl_ty, 0);
+            }
             if (!pl_ty.hasRuntimeBitsIgnoreComptime()) {
-                return self.store(lhs, rhs, err_ty, 0);
+                return self.store(lhs, rhs, Type.anyerror, 0);
             }
 
             const len = @intCast(u32, ty.abiSize(self.target));
@@ -2256,6 +2276,7 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
     const target = self.target;
 
     switch (ty.zigTypeTag()) {
+        .Void => return WValue{ .none = {} },
         .Int => {
             const int_info = ty.intInfo(self.target);
             switch (int_info.signedness) {
@@ -2324,6 +2345,10 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
         },
         .ErrorUnion => {
             const error_type = ty.errorUnionSet();
+            if (error_type.errorSetCardinality() == .zero) {
+                const pl_val = if (val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef);
+                return self.lowerConstant(pl_val, ty.errorUnionPayload());
+            }
             const is_pl = val.errorUnionIsPayload();
             const err_val = if (!is_pl) val else Value.initTag(.zero);
             return self.lowerConstant(err_val, error_type);
@@ -2892,12 +2917,19 @@ fn airIsErr(self: *Self, inst: Air.Inst.Index, opcode: wasm.Opcode) InnerError!W
     const err_ty = self.air.typeOf(un_op);
     const pl_ty = err_ty.errorUnionPayload();
 
-    // load the error tag value
+    if (err_ty.errorUnionSet().errorSetCardinality() == .zero) {
+        switch (opcode) {
+            .i32_ne => return WValue{ .imm32 = 0 },
+            .i32_eq => return WValue{ .imm32 = 1 },
+            else => unreachable,
+        }
+    }
+
     try self.emitWValue(operand);
     if (pl_ty.hasRuntimeBitsIgnoreComptime()) {
         try self.addMemArg(.i32_load16_u, .{
-            .offset = operand.offset(),
-            .alignment = err_ty.errorUnionSet().abiAlignment(self.target),
+            .offset = operand.offset() + errUnionErrorOffset(pl_ty, self.target),
+            .alignment = Type.anyerror.abiAlignment(self.target),
         });
     }
 
@@ -2905,7 +2937,7 @@ fn airIsErr(self: *Self, inst: Air.Inst.Index, opcode: wasm.Opcode) InnerError!W
     try self.addImm32(0);
     try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
 
-    const is_err_tmp = try self.allocLocal(Type.initTag(.i32)); // result is always an i32
+    const is_err_tmp = try self.allocLocal(Type.i32);
     try self.addLabel(.local_set, is_err_tmp.local);
     return is_err_tmp;
 }
@@ -2917,14 +2949,18 @@ fn airUnwrapErrUnionPayload(self: *Self, inst: Air.Inst.Index, op_is_ptr: bool)
     const op_ty = self.air.typeOf(ty_op.operand);
     const err_ty = if (op_is_ptr) op_ty.childType() else op_ty;
     const payload_ty = err_ty.errorUnionPayload();
+
+    if (err_ty.errorUnionSet().errorSetCardinality() == .zero) {
+        return operand;
+    }
+
     if (!payload_ty.hasRuntimeBitsIgnoreComptime()) return WValue{ .none = {} };
-    const err_align = err_ty.abiAlignment(self.target);
-    const set_size = err_ty.errorUnionSet().abiSize(self.target);
-    const offset = mem.alignForwardGeneric(u64, set_size, err_align);
+
+    const pl_offset = errUnionPayloadOffset(payload_ty, self.target);
     if (op_is_ptr or isByRef(payload_ty, self.target)) {
-        return self.buildPointerOffset(operand, offset, .new);
+        return self.buildPointerOffset(operand, pl_offset, .new);
     }
-    return self.load(operand, payload_ty, @intCast(u32, offset));
+    return self.load(operand, payload_ty, pl_offset);
 }
 
 fn airUnwrapErrUnionError(self: *Self, inst: Air.Inst.Index, op_is_ptr: bool) InnerError!WValue {
@@ -2935,11 +2971,16 @@ fn airUnwrapErrUnionError(self: *Self, inst: Air.Inst.Index, op_is_ptr: bool) In
     const op_ty = self.air.typeOf(ty_op.operand);
     const err_ty = if (op_is_ptr) op_ty.childType() else op_ty;
     const payload_ty = err_ty.errorUnionPayload();
+
+    if (err_ty.errorUnionSet().errorSetCardinality() == .zero) {
+        return WValue{ .imm32 = 0 };
+    }
+
     if (op_is_ptr or !payload_ty.hasRuntimeBitsIgnoreComptime()) {
         return operand;
     }
 
-    return self.load(operand, err_ty.errorUnionSet(), 0);
+    return self.load(operand, Type.anyerror, errUnionErrorOffset(payload_ty, self.target));
 }
 
 fn airWrapErrUnionPayload(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -2947,22 +2988,26 @@ fn airWrapErrUnionPayload(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const operand = try self.resolveInst(ty_op.operand);
+    const err_ty = self.air.typeOfIndex(inst);
 
-    const op_ty = self.air.typeOf(ty_op.operand);
-    if (!op_ty.hasRuntimeBitsIgnoreComptime()) return operand;
-    const err_union_ty = self.air.getRefType(ty_op.ty);
-    const err_align = err_union_ty.abiAlignment(self.target);
-    const set_size = err_union_ty.errorUnionSet().abiSize(self.target);
-    const offset = mem.alignForwardGeneric(u64, set_size, err_align);
+    if (err_ty.errorUnionSet().errorSetCardinality() == .zero) {
+        return operand;
+    }
+
+    const pl_ty = self.air.typeOf(ty_op.operand);
+    if (!pl_ty.hasRuntimeBitsIgnoreComptime()) {
+        return operand;
+    }
 
-    const err_union = try self.allocStack(err_union_ty);
-    const payload_ptr = try self.buildPointerOffset(err_union, offset, .new);
-    try self.store(payload_ptr, operand, op_ty, 0);
+    const err_union = try self.allocStack(err_ty);
+    const payload_ptr = try self.buildPointerOffset(err_union, errUnionPayloadOffset(pl_ty, self.target), .new);
+    try self.store(payload_ptr, operand, pl_ty, 0);
 
     // ensure we also write '0' to the error part, so any present stack value gets overwritten by it.
     try self.emitWValue(err_union);
     try self.addImm32(0);
-    try self.addMemArg(.i32_store16, .{ .offset = err_union.offset(), .alignment = 2 });
+    const err_val_offset = errUnionErrorOffset(pl_ty, self.target);
+    try self.addMemArg(.i32_store16, .{ .offset = err_union.offset() + err_val_offset, .alignment = 2 });
 
     return err_union;
 }
@@ -2973,17 +3018,18 @@ fn airWrapErrUnionErr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const operand = try self.resolveInst(ty_op.operand);
     const err_ty = self.air.getRefType(ty_op.ty);
+    const pl_ty = err_ty.errorUnionPayload();
 
-    if (!err_ty.errorUnionPayload().hasRuntimeBitsIgnoreComptime()) return operand;
+    if (!pl_ty.hasRuntimeBitsIgnoreComptime()) {
+        return operand;
+    }
 
     const err_union = try self.allocStack(err_ty);
-    try self.store(err_union, operand, err_ty.errorUnionSet(), 0);
+    // store error value
+    try self.store(err_union, operand, Type.anyerror, errUnionErrorOffset(pl_ty, self.target));
 
     // write 'undefined' to the payload
-    const err_align = err_ty.abiAlignment(self.target);
-    const set_size = err_ty.errorUnionSet().abiSize(self.target);
-    const offset = mem.alignForwardGeneric(u64, set_size, err_align);
-    const payload_ptr = try self.buildPointerOffset(err_union, offset, .new);
+    const payload_ptr = try self.buildPointerOffset(err_union, errUnionPayloadOffset(pl_ty, self.target), .new);
     const len = @intCast(u32, err_ty.errorUnionPayload().abiSize(self.target));
     try self.memset(payload_ptr, .{ .imm32 = len }, .{ .imm32 = 0xaaaaaaaa });
 
@@ -3927,12 +3973,16 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 fn airErrUnionPayloadPtrSet(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const err_set_ty = self.air.typeOf(ty_op.operand).childType();
-    const err_ty = err_set_ty.errorUnionSet();
     const payload_ty = err_set_ty.errorUnionPayload();
     const operand = try self.resolveInst(ty_op.operand);
 
     // set error-tag to '0' to annotate error union is non-error
-    try self.store(operand, .{ .imm32 = 0 }, err_ty, 0);
+    try self.store(
+        operand,
+        .{ .imm32 = 0 },
+        Type.anyerror,
+        errUnionErrorOffset(payload_ty, self.target),
+    );
 
     if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
 
@@ -3940,11 +3990,7 @@ fn airErrUnionPayloadPtrSet(self: *Self, inst: Air.Inst.Index) InnerError!WValue
         return operand;
     }
 
-    const err_align = err_set_ty.abiAlignment(self.target);
-    const set_size = err_ty.abiSize(self.target);
-    const offset = mem.alignForwardGeneric(u64, set_size, err_align);
-
-    return self.buildPointerOffset(operand, @intCast(u32, offset), .new);
+    return self.buildPointerOffset(operand, errUnionPayloadOffset(payload_ty, self.target), .new);
 }
 
 fn airFieldParentPtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -4572,3 +4618,17 @@ fn airDbgStmt(self: *Self, inst: Air.Inst.Index) !WValue {
     } });
     return WValue{ .none = {} };
 }
+
+fn errUnionPayloadOffset(payload_ty: Type, target: std.Target) u32 {
+    if (Type.anyerror.abiAlignment(target) > payload_ty.abiAlignment(target)) {
+        return @intCast(u32, Type.anyerror.abiSize(target));
+    }
+    return 0;
+}
+
+fn errUnionErrorOffset(payload_ty: Type, target: std.Target) u32 {
+    if (Type.anyerror.abiAlignment(target) > payload_ty.abiAlignment(target)) {
+        return 0;
+    }
+    return @intCast(u32, payload_ty.abiSize(target));
+}
src/codegen.zig
@@ -714,7 +714,7 @@ pub fn generateSymbol(
             const is_payload = typed_value.val.errorUnionIsPayload();
 
             if (!payload_ty.hasRuntimeBitsIgnoreComptime()) {
-                const err_val = if (!is_payload) typed_value.val else Value.initTag(.zero);
+                const err_val = if (is_payload) Value.initTag(.zero) else typed_value.val;
                 return generateSymbol(bin_file, src_loc, .{
                     .ty = error_ty,
                     .val = err_val,
@@ -763,7 +763,7 @@ pub fn generateSymbol(
             }
 
             // Payload size is larger than error set, so emit our error set last
-            if (error_align < payload_align) {
+            if (error_align <= payload_align) {
                 const begin = code.items.len;
                 switch (try generateSymbol(bin_file, src_loc, .{
                     .ty = error_ty,
@@ -794,7 +794,7 @@ pub fn generateSymbol(
                     try code.writer().writeInt(u32, kv.value, endian);
                 },
                 else => {
-                    try code.writer().writeByteNTimes(0, @intCast(usize, typed_value.ty.abiSize(target)));
+                    try code.writer().writeByteNTimes(0, @intCast(usize, Type.anyerror.abiSize(target)));
                 },
             }
             return Result{ .appended = {} };
test/behavior/error.zig
@@ -260,7 +260,6 @@ fn testComptimeTestErrorEmptySet(x: EmptyErrorSet!i32) !void {
 }
 
 test "comptime err to int of error set with only 1 possible value" {
-    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO