Commit f46d7304b1

Veikka Tuominen <git@vexu.eu>
2022-08-05 17:15:31
stage2: add runtime safety for invalid enum values
1 parent 19d5ffc
src/arch/aarch64/CodeGen.zig
@@ -753,6 +753,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .float_to_int_optimized,
             => return self.fail("TODO implement optimized float mode", .{}),
 
+            .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
             // zig fmt: on
src/arch/arm/CodeGen.zig
@@ -768,6 +768,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .float_to_int_optimized,
             => return self.fail("TODO implement optimized float mode", .{}),
 
+            .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
             // zig fmt: on
src/arch/riscv64/CodeGen.zig
@@ -693,6 +693,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .float_to_int_optimized,
             => return self.fail("TODO implement optimized float mode", .{}),
 
+            .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
             // zig fmt: on
src/arch/sparc64/CodeGen.zig
@@ -705,6 +705,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .float_to_int_optimized,
             => @panic("TODO implement optimized float mode"),
 
+            .is_named_enum_value => @panic("TODO implement is_named_enum_value"),
+
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
             // zig fmt: on
src/arch/wasm/CodeGen.zig
@@ -1621,6 +1621,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
         .tag_name,
         .err_return_trace,
         .set_err_return_trace,
+        .is_named_enum_value,
         => |tag| return self.fail("TODO: Implement wasm inst: {s}", .{@tagName(tag)}),
 
         .add_optimized,
src/arch/x86_64/CodeGen.zig
@@ -775,6 +775,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .float_to_int_optimized,
             => return self.fail("TODO implement optimized float mode", .{}),
 
+            .is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
+
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
             // zig fmt: on
src/codegen/c.zig
@@ -1952,6 +1952,8 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
             .reduce_optimized,
             .float_to_int_optimized,
             => return f.fail("TODO implement optimized float mode", .{}),
+
+            .is_named_enum_value => return f.fail("TODO: C backend: implement is_named_enum_value", .{}),
             // zig fmt: on
         };
         switch (result_value) {
src/codegen/llvm.zig
@@ -201,6 +201,8 @@ pub const Object = struct {
     /// * it works for functions not all globals.
     /// Therefore, this table keeps track of the mapping.
     decl_map: std.AutoHashMapUnmanaged(Module.Decl.Index, *const llvm.Value),
+    /// Serves the same purpose as `decl_map` but only used for the `is_named_enum_value` instruction.
+    named_enum_map: std.AutoHashMapUnmanaged(Module.Decl.Index, *const llvm.Value),
     /// Maps Zig types to LLVM types. The table memory itself is backed by the GPA of
     /// the compiler, but the Type/Value memory here is backed by `type_map_arena`.
     /// TODO we need to remove entries from this map in response to incremental compilation
@@ -377,6 +379,7 @@ pub const Object = struct {
             .target_data = target_data,
             .target = options.target,
             .decl_map = .{},
+            .named_enum_map = .{},
             .type_map = .{},
             .type_map_arena = std.heap.ArenaAllocator.init(gpa),
             .di_type_map = .{},
@@ -396,6 +399,7 @@ pub const Object = struct {
         self.llvm_module.dispose();
         self.context.dispose();
         self.decl_map.deinit(gpa);
+        self.named_enum_map.deinit(gpa);
         self.type_map.deinit(gpa);
         self.type_map_arena.deinit();
         self.extern_collisions.deinit(gpa);
@@ -4180,6 +4184,8 @@ pub const FuncGen = struct {
                 .union_init     => try self.airUnionInit(inst),
                 .prefetch       => try self.airPrefetch(inst),
 
+                .is_named_enum_value => try self.airIsNamedEnumValue(inst),
+
                 .reduce           => try self.airReduce(inst, false),
                 .reduce_optimized => try self.airReduce(inst, true),
 
@@ -7882,6 +7888,87 @@ pub const FuncGen = struct {
         }
     }
 
+    fn airIsNamedEnumValue(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const un_op = self.air.instructions.items(.data)[inst].un_op;
+        const operand = try self.resolveInst(un_op);
+        const enum_ty = self.air.typeOf(un_op);
+
+        const llvm_fn = try self.getIsNamedEnumValueFunction(enum_ty);
+        const params = [_]*const llvm.Value{operand};
+        return self.builder.buildCall(llvm_fn, &params, params.len, .Fast, .Auto, "");
+    }
+
+    fn getIsNamedEnumValueFunction(self: *FuncGen, enum_ty: Type) !*const llvm.Value {
+        const enum_decl = enum_ty.getOwnerDecl();
+
+        // TODO: detect when the type changes and re-emit this function.
+        const gop = try self.dg.object.named_enum_map.getOrPut(self.dg.gpa, enum_decl);
+        if (gop.found_existing) return gop.value_ptr.*;
+        errdefer assert(self.dg.object.named_enum_map.remove(enum_decl));
+
+        var arena_allocator = std.heap.ArenaAllocator.init(self.gpa);
+        defer arena_allocator.deinit();
+        const arena = arena_allocator.allocator();
+
+        const mod = self.dg.module;
+        const llvm_fn_name = try std.fmt.allocPrintZ(arena, "__zig_is_named_enum_value_{s}", .{
+            try mod.declPtr(enum_decl).getFullyQualifiedName(mod),
+        });
+
+        var int_tag_type_buffer: Type.Payload.Bits = undefined;
+        const int_tag_ty = enum_ty.intTagType(&int_tag_type_buffer);
+        const param_types = [_]*const llvm.Type{try self.dg.lowerType(int_tag_ty)};
+
+        const llvm_ret_ty = try self.dg.lowerType(Type.bool);
+        const fn_type = llvm.functionType(llvm_ret_ty, &param_types, param_types.len, .False);
+        const fn_val = self.dg.object.llvm_module.addFunction(llvm_fn_name, fn_type);
+        fn_val.setLinkage(.Internal);
+        fn_val.setFunctionCallConv(.Fast);
+        self.dg.addCommonFnAttributes(fn_val);
+        gop.value_ptr.* = fn_val;
+
+        const prev_block = self.builder.getInsertBlock();
+        const prev_debug_location = self.builder.getCurrentDebugLocation2();
+        defer {
+            self.builder.positionBuilderAtEnd(prev_block);
+            if (self.di_scope != null) {
+                self.builder.setCurrentDebugLocation2(prev_debug_location);
+            }
+        }
+
+        const entry_block = self.dg.context.appendBasicBlock(fn_val, "Entry");
+        self.builder.positionBuilderAtEnd(entry_block);
+        self.builder.clearCurrentDebugLocation();
+
+        const fields = enum_ty.enumFields();
+        const named_block = self.dg.context.appendBasicBlock(fn_val, "Named");
+        const unnamed_block = self.dg.context.appendBasicBlock(fn_val, "Unnamed");
+        const tag_int_value = fn_val.getParam(0);
+        const switch_instr = self.builder.buildSwitch(tag_int_value, unnamed_block, @intCast(c_uint, fields.count()));
+
+        for (fields.keys()) |_, field_index| {
+            const this_tag_int_value = int: {
+                var tag_val_payload: Value.Payload.U32 = .{
+                    .base = .{ .tag = .enum_field_index },
+                    .data = @intCast(u32, field_index),
+                };
+                break :int try self.dg.lowerValue(.{
+                    .ty = enum_ty,
+                    .val = Value.initPayload(&tag_val_payload.base),
+                });
+            };
+            switch_instr.addCase(this_tag_int_value, named_block);
+        }
+        self.builder.positionBuilderAtEnd(named_block);
+        _ = self.builder.buildRet(self.dg.context.intType(1).constInt(1, .False));
+
+        self.builder.positionBuilderAtEnd(unnamed_block);
+        _ = self.builder.buildRet(self.dg.context.intType(1).constInt(0, .False));
+        return fn_val;
+    }
+
     fn airTagName(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
         if (self.liveness.isUnused(inst)) return null;
 
src/Air.zig
@@ -660,6 +660,10 @@ pub const Inst = struct {
         /// Uses the `pl_op` field with payload `AtomicRmw`. Operand is `ptr`.
         atomic_rmw,
 
+        /// Returns true if enum tag value has a name.
+        /// Uses the `un_op` field.
+        is_named_enum_value,
+
         /// Given an enum tag value, returns the tag name. The enum type may be non-exhaustive.
         /// Result type is always `[:0]const u8`.
         /// Uses the `un_op` field.
@@ -1057,6 +1061,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
         .is_non_err,
         .is_err_ptr,
         .is_non_err_ptr,
+        .is_named_enum_value,
         => return Type.bool,
 
         .const_ty => return Type.type,
src/Liveness.zig
@@ -291,6 +291,7 @@ pub fn categorizeOperand(
         .is_non_err_ptr,
         .ptrtoint,
         .bool_to_int,
+        .is_named_enum_value,
         .tag_name,
         .error_name,
         .sqrt,
@@ -858,6 +859,7 @@ fn analyzeInst(
         .bool_to_int,
         .ret,
         .ret_load,
+        .is_named_enum_value,
         .tag_name,
         .error_name,
         .sqrt,
src/print_air.zig
@@ -170,6 +170,7 @@ const Writer = struct {
             .bool_to_int,
             .ret,
             .ret_load,
+            .is_named_enum_value,
             .tag_name,
             .error_name,
             .sqrt,
src/Sema.zig
@@ -6933,8 +6933,12 @@ fn zirIntToEnum(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
     }
 
     try sema.requireRuntimeBlock(block, src, operand_src);
-    // TODO insert safety check to make sure the value matches an enum value
-    return block.addTyOp(.intcast, dest_ty, operand);
+    const result = try block.addTyOp(.intcast, dest_ty, operand);
+    if (block.wantSafety() and !dest_ty.isNonexhaustiveEnum() and sema.mod.comp.bin_file.options.use_llvm) {
+        const ok = try block.addUnOp(.is_named_enum_value, result);
+        try sema.addSafetyCheck(block, ok, .invalid_enum_value);
+    }
+    return result;
 }
 
 /// Pointer in, pointer out.
@@ -15887,6 +15891,11 @@ fn zirTagName(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
         const field_name = enum_ty.enumFieldName(field_index);
         return sema.addStrLit(block, field_name);
     }
+    try sema.requireRuntimeBlock(block, src, operand_src);
+    if (block.wantSafety() and sema.mod.comp.bin_file.options.use_llvm) {
+        const ok = try block.addUnOp(.is_named_enum_value, casted_operand);
+        try sema.addSafetyCheck(block, ok, .invalid_enum_value);
+    }
     // In case the value is runtime-known, we have an AIR instruction for this instead
     // of trying to lower it in Sema because an optimization pass may result in the operand
     // being comptime-known, which would let us elide the `tag_name` AIR instruction.
@@ -20019,6 +20028,7 @@ pub const PanicId = enum {
     integer_part_out_of_bounds,
     corrupt_switch,
     shift_rhs_too_big,
+    invalid_enum_value,
 };
 
 fn addSafetyCheck(
@@ -20316,6 +20326,7 @@ fn safetyPanic(
         .integer_part_out_of_bounds => "integer part of floating point value out of bounds",
         .corrupt_switch => "switch on corrupt value",
         .shift_rhs_too_big => "shift amount is greater than the type size",
+        .invalid_enum_value => "invalid enum value",
     };
 
     const msg_inst = msg_inst: {
test/cases/safety/@intToEnum - no matching tag value.zig
@@ -1,9 +1,11 @@
 const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
-    _ = message;
     _ = stack_trace;
-    std.process.exit(0);
+    if (std.mem.eql(u8, message, "invalid enum value")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
 }
 const Foo = enum {
     A,
@@ -18,6 +20,7 @@ fn bar(a: u2) Foo {
     return @intToEnum(Foo, a);
 }
 fn baz(_: Foo) void {}
+
 // run
-// backend=stage1
+// backend=llvm
 // target=native
test/cases/safety/@tagName on corrupted enum value.zig
@@ -10,6 +10,7 @@ pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noretur
 
 const E = enum(u32) {
     X = 1,
+    Y = 2,
 };
 
 pub fn main() !void {
@@ -21,5 +22,5 @@ pub fn main() !void {
 }
 
 // run
-// backend=stage1
+// backend=llvm
 // target=native
test/cases/safety/@tagName on corrupted union value.zig
@@ -10,6 +10,7 @@ pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noretur
 
 const U = union(enum(u32)) {
     X: u8,
+    Y: i8,
 };
 
 pub fn main() !void {
@@ -22,5 +23,5 @@ pub fn main() !void {
 }
 
 // run
-// backend=stage1
+// backend=llvm
 // target=native