Commit 7c9979a02e

Veikka Tuominen <git@vexu.eu>
2022-08-11 21:45:15
stage2: generate a switch for `@errSetCast` safety
1 parent fa50e17
src/arch/aarch64/CodeGen.zig
@@ -778,6 +778,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             => return self.fail("TODO implement optimized float mode", .{}),
 
             .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+            .error_set_has_value => return self.fail("TODO implement error_set_has_value", .{}),
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
src/arch/arm/CodeGen.zig
@@ -769,6 +769,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             => return self.fail("TODO implement optimized float mode", .{}),
 
             .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+            .error_set_has_value => return self.fail("TODO implement error_set_has_value", .{}),
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
src/arch/riscv64/CodeGen.zig
@@ -694,6 +694,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             => return self.fail("TODO implement optimized float mode", .{}),
 
             .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+            .error_set_has_value => return self.fail("TODO implement error_set_has_value", .{}),
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
src/arch/sparc64/CodeGen.zig
@@ -706,6 +706,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             => @panic("TODO implement optimized float mode"),
 
             .is_named_enum_value => @panic("TODO implement is_named_enum_value"),
+            .error_set_has_value => @panic("TODO implement error_set_has_value"),
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
src/arch/wasm/CodeGen.zig
@@ -1694,6 +1694,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
         .err_return_trace,
         .set_err_return_trace,
         .is_named_enum_value,
+        .error_set_has_value,
         => |tag| return self.fail("TODO: Implement wasm inst: {s}", .{@tagName(tag)}),
 
         .add_optimized,
src/arch/x86_64/CodeGen.zig
@@ -776,6 +776,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             => return self.fail("TODO implement optimized float mode", .{}),
 
             .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+            .error_set_has_value => return self.fail("TODO implement error_set_has_value", .{}),
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
src/codegen/c.zig
@@ -1954,6 +1954,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
             => return f.fail("TODO implement optimized float mode", .{}),
 
             .is_named_enum_value => return f.fail("TODO: C backend: implement is_named_enum_value", .{}),
+            .error_set_has_value => return f.fail("TODO: C backend: implement error_set_has_value", .{}),
             // zig fmt: on
         };
         switch (result_value) {
src/codegen/llvm.zig
@@ -4247,6 +4247,7 @@ pub const FuncGen = struct {
                 .prefetch       => try self.airPrefetch(inst),
 
                 .is_named_enum_value => try self.airIsNamedEnumValue(inst),
+                .error_set_has_value => try self.airErrorSetHasValue(inst),
 
                 .reduce           => try self.airReduce(inst, false),
                 .reduce_optimized => try self.airReduce(inst, true),
@@ -7983,6 +7984,53 @@ pub const FuncGen = struct {
         }
     }
 
+    fn airErrorSetHasValue(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+        const operand = try self.resolveInst(ty_op.operand);
+        const error_set_ty = self.air.getRefType(ty_op.ty);
+
+        const names = error_set_ty.errorSetNames();
+        const valid_block = self.dg.context.appendBasicBlock(self.llvm_func, "Valid");
+        const invalid_block = self.dg.context.appendBasicBlock(self.llvm_func, "Invalid");
+        const end_block = self.context.appendBasicBlock(self.llvm_func, "End");
+        const switch_instr = self.builder.buildSwitch(operand, invalid_block, @intCast(c_uint, names.len));
+
+        for (names) |name| {
+            const err_int = self.dg.module.global_error_set.get(name).?;
+            const this_tag_int_value = int: {
+                var tag_val_payload: Value.Payload.U64 = .{
+                    .base = .{ .tag = .int_u64 },
+                    .data = err_int,
+                };
+                break :int try self.dg.lowerValue(.{
+                    .ty = Type.u16,
+                    .val = Value.initPayload(&tag_val_payload.base),
+                });
+            };
+            switch_instr.addCase(this_tag_int_value, valid_block);
+        }
+        self.builder.positionBuilderAtEnd(valid_block);
+        _ = self.builder.buildBr(end_block);
+
+        self.builder.positionBuilderAtEnd(invalid_block);
+        _ = self.builder.buildBr(end_block);
+
+        self.builder.positionBuilderAtEnd(end_block);
+
+        const llvm_type = self.dg.context.intType(1);
+        const incoming_values: [2]*const llvm.Value = .{
+            llvm_type.constInt(1, .False), llvm_type.constInt(0, .False),
+        };
+        const incoming_blocks: [2]*const llvm.BasicBlock = .{
+            valid_block, invalid_block,
+        };
+        const phi_node = self.builder.buildPhi(llvm_type, "");
+        phi_node.addIncoming(&incoming_values, &incoming_blocks, 2);
+        return phi_node;
+    }
+
     fn airIsNamedEnumValue(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
         if (self.liveness.isUnused(inst)) return null;
 
src/Air.zig
@@ -673,6 +673,10 @@ pub const Inst = struct {
         /// Uses the `un_op` field.
         error_name,
 
+        /// Returns true if error set has error with value.
+        /// Uses the `ty_op` field.
+        error_set_has_value,
+
         /// Constructs a vector, tuple, struct, or array value out of runtime-known elements.
         /// Some of the elements may be comptime-known.
         /// Uses the `ty_pl` field, payload is index of an array of elements, each of which
@@ -1062,6 +1066,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
         .is_err_ptr,
         .is_non_err_ptr,
         .is_named_enum_value,
+        .error_set_has_value,
         => return Type.bool,
 
         .const_ty => return Type.type,
src/Liveness.zig
@@ -267,6 +267,7 @@ pub fn categorizeOperand(
         .byte_swap,
         .bit_reverse,
         .splat,
+        .error_set_has_value,
         => {
             const o = air_datas[inst].ty_op;
             if (o.operand == operand_ref) return matchOperandSmallIndex(l, inst, 0, .none);
@@ -842,6 +843,7 @@ fn analyzeInst(
         .byte_swap,
         .bit_reverse,
         .splat,
+        .error_set_has_value,
         => {
             const o = inst_datas[inst].ty_op;
             return trackOperands(a, new_set, inst, main_tomb, .{ o.operand, .none, .none });
src/print_air.zig
@@ -243,6 +243,7 @@ const Writer = struct {
             .popcount,
             .byte_swap,
             .bit_reverse,
+            .error_set_has_value,
             => try w.writeTyOp(s, inst),
 
             .block,
src/Sema.zig
@@ -17359,17 +17359,10 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
     }
 
     try sema.requireRuntimeBlock(block, src, operand_src);
-    if (block.wantSafety() and !dest_ty.isAnyError()) {
+    if (block.wantSafety() and !dest_ty.isAnyError() and sema.mod.comp.bin_file.options.use_llvm) {
         const err_int_inst = try block.addBitCast(Type.u16, operand);
-        // TODO: Output a switch instead of chained OR's.
-        var found_match: Air.Inst.Ref = undefined;
-        for (dest_ty.errorSetNames()) |dest_err_name, i| {
-            const dest_err_int = (try sema.mod.getErrorValue(dest_err_name)).value;
-            const dest_err_int_inst = try sema.addIntUnsigned(Type.u16, dest_err_int);
-            const next_match = try block.addBinOp(.cmp_eq, dest_err_int_inst, err_int_inst);
-            found_match = if (i == 0) next_match else try block.addBinOp(.bool_or, found_match, next_match);
-        }
-        try sema.addSafetyCheck(block, found_match, .invalid_error_code);
+        const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
+        try sema.addSafetyCheck(block, ok, .invalid_error_code);
     }
     return block.addBitCast(dest_ty, operand);
 }