Commit 8ee629aa4c

Andrew Kelley <andrew@ziglang.org>
2020-07-21 21:13:15
stage2: ability for ZIR to map multiple tags to the same type
1 parent 7a1a924
Changed files (4)
src-self-hosted
test
stage2
src-self-hosted/astgen.zig
@@ -16,6 +16,15 @@ pub fn expr(mod: *Module, scope: *Scope, node: *ast.Node) InnerError!*zir.Inst {
     switch (node.tag) {
         .VarDecl => unreachable, // Handled in `blockExpr`.
 
+        .Add => return simpleInfixOp(mod, scope, node.castTag(.Add).?, .add),
+        .Sub => return simpleInfixOp(mod, scope, node.castTag(.Sub).?, .sub),
+        .BangEqual => return simpleInfixOp(mod, scope, node.castTag(.BangEqual).?, .cmp_neq),
+        .EqualEqual => return simpleInfixOp(mod, scope, node.castTag(.EqualEqual).?, .cmp_eq),
+        .GreaterThan => return simpleInfixOp(mod, scope, node.castTag(.GreaterThan).?, .cmp_gt),
+        .GreaterOrEqual => return simpleInfixOp(mod, scope, node.castTag(.GreaterOrEqual).?, .cmp_gte),
+        .LessThan => return simpleInfixOp(mod, scope, node.castTag(.LessThan).?, .cmp_lt),
+        .LessOrEqual => return simpleInfixOp(mod, scope, node.castTag(.LessOrEqual).?, .cmp_lte),
+
         .Identifier => return identifier(mod, scope, node.castTag(.Identifier).?),
         .Asm => return assembly(mod, scope, node.castTag(.Asm).?),
         .StringLiteral => return stringLiteral(mod, scope, node.castTag(.StringLiteral).?),
@@ -26,13 +35,6 @@ pub fn expr(mod: *Module, scope: *Scope, node: *ast.Node) InnerError!*zir.Inst {
         .ControlFlowExpression => return controlFlowExpr(mod, scope, node.castTag(.ControlFlowExpression).?),
         .If => return ifExpr(mod, scope, node.castTag(.If).?),
         .Assign => return assign(mod, scope, node.castTag(.Assign).?),
-        .Add => return add(mod, scope, node.castTag(.Add).?),
-        .BangEqual => return cmp(mod, scope, node.castTag(.BangEqual).?, .neq),
-        .EqualEqual => return cmp(mod, scope, node.castTag(.EqualEqual).?, .eq),
-        .GreaterThan => return cmp(mod, scope, node.castTag(.GreaterThan).?, .gt),
-        .GreaterOrEqual => return cmp(mod, scope, node.castTag(.GreaterOrEqual).?, .gte),
-        .LessThan => return cmp(mod, scope, node.castTag(.LessThan).?, .lt),
-        .LessOrEqual => return cmp(mod, scope, node.castTag(.LessOrEqual).?, .lte),
         .Period => return field(mod, scope, node.castTag(.Period).?),
         .Deref => return deref(mod, scope, node.castTag(.Deref).?),
         .BoolNot => return boolNot(mod, scope, node.castTag(.BoolNot).?),
@@ -97,14 +99,6 @@ fn varDecl(mod: *Module, scope: *Scope, node: *ast.Node.VarDecl) InnerError!Scop
         },
         .Keyword_var => {
             return mod.failNode(scope, &node.base, "TODO implement local vars", .{});
-            //const src = tree.token_locs[node.name_token].start;
-            //const alloc = mod.addZIRInst(scope, src, zir.Inst.Alloc, .{}, .{});
-            //if (node.getTrailer("type_node")) |type_node| {
-            //    const type_inst = try expr(mod, scope, type_node);
-            //    return mod.failNode(scope, type_node, "TODO implement typed var locals", .{});
-            //} else {
-            //    return mod.failTok(scope, node.mut_token, "TODO implement mutable type-inferred locals", .{});
-            //}
         },
         else => unreachable,
     }
@@ -114,7 +108,7 @@ fn boolNot(mod: *Module, scope: *Scope, node: *ast.Node.SimplePrefixOp) InnerErr
     const operand = try expr(mod, scope, node.rhs);
     const tree = scope.tree();
     const src = tree.token_locs[node.op_token].start;
-    return mod.addZIRInst(scope, src, zir.Inst.BoolNot, .{ .operand = operand }, .{});
+    return mod.addZIRUnOp(scope, src, .boolnot, operand);
 }
 
 fn assign(mod: *Module, scope: *Scope, infix_node: *ast.Node.SimpleInfixOp) InnerError!*zir.Inst {
@@ -169,33 +163,21 @@ fn field(mod: *Module, scope: *Scope, node: *ast.Node.SimpleInfixOp) InnerError!
     const field_name = try identifierStringInst(mod, scope, node.rhs.castTag(.Identifier).?);
 
     const pointer = try mod.addZIRInst(scope, src, zir.Inst.FieldPtr, .{ .object_ptr = lhs, .field_name = field_name }, .{});
-    return mod.addZIRInst(scope, src, zir.Inst.Deref, .{ .ptr = pointer }, .{});
+    return mod.addZIRUnOp(scope, src, .deref, pointer);
 }
 
 fn deref(mod: *Module, scope: *Scope, node: *ast.Node.SimpleSuffixOp) InnerError!*zir.Inst {
     const tree = scope.tree();
     const src = tree.token_locs[node.rtoken].start;
-
     const lhs = try expr(mod, scope, node.lhs);
-
-    return mod.addZIRInst(scope, src, zir.Inst.Deref, .{ .ptr = lhs }, .{});
-}
-
-fn add(mod: *Module, scope: *Scope, infix_node: *ast.Node.SimpleInfixOp) InnerError!*zir.Inst {
-    const lhs = try expr(mod, scope, infix_node.lhs);
-    const rhs = try expr(mod, scope, infix_node.rhs);
-
-    const tree = scope.tree();
-    const src = tree.token_locs[infix_node.op_token].start;
-
-    return mod.addZIRInst(scope, src, zir.Inst.Add, .{ .lhs = lhs, .rhs = rhs }, .{});
+    return mod.addZIRUnOp(scope, src, .deref, lhs);
 }
 
-fn cmp(
+fn simpleInfixOp(
     mod: *Module,
     scope: *Scope,
     infix_node: *ast.Node.SimpleInfixOp,
-    op: std.math.CompareOperator,
+    op_inst_tag: zir.Inst.Tag,
 ) InnerError!*zir.Inst {
     const lhs = try expr(mod, scope, infix_node.lhs);
     const rhs = try expr(mod, scope, infix_node.rhs);
@@ -203,11 +185,7 @@ fn cmp(
     const tree = scope.tree();
     const src = tree.token_locs[infix_node.op_token].start;
 
-    return mod.addZIRInst(scope, src, zir.Inst.Cmp, .{
-        .lhs = lhs,
-        .op = op,
-        .rhs = rhs,
-    }, .{});
+    return mod.addZIRBinOp(scope, src, op_inst_tag, lhs, rhs);
 }
 
 fn ifExpr(mod: *Module, scope: *Scope, if_node: *ast.Node.If) InnerError!*zir.Inst {
@@ -306,9 +284,9 @@ fn controlFlowExpr(
     const src = tree.token_locs[cfe.ltoken].start;
     if (cfe.rhs) |rhs_node| {
         const operand = try expr(mod, scope, rhs_node);
-        return mod.addZIRInst(scope, src, zir.Inst.Return, .{ .operand = operand }, .{});
+        return mod.addZIRUnOp(scope, src, .@"return", operand);
     } else {
-        return mod.addZIRInst(scope, src, zir.Inst.ReturnVoid, .{}, .{});
+        return mod.addZIRNoOp(scope, src, .returnvoid);
     }
 }
 
@@ -519,7 +497,7 @@ fn callExpr(mod: *Module, scope: *Scope, node: *ast.Node.Call) InnerError!*zir.I
 fn unreach(mod: *Module, scope: *Scope, unreach_node: *ast.Node.Unreachable) InnerError!*zir.Inst {
     const tree = scope.tree();
     const src = tree.token_locs[unreach_node.token].start;
-    return mod.addZIRInst(scope, src, zir.Inst.Unreachable, .{}, .{});
+    return mod.addZIRNoOp(scope, src, .@"unreachable");
 }
 
 fn getSimplePrimitiveValue(name: []const u8) ?TypedValue {
src-self-hosted/Module.zig
@@ -1215,7 +1215,7 @@ fn astGenAndAnalyzeDecl(self: *Module, decl: *Decl) !bool {
                 .return_type = return_type_inst,
                 .param_types = param_types,
             }, .{});
-            _ = try self.addZIRInst(&fn_type_scope.base, fn_src, zir.Inst.Return, .{ .operand = fn_type_inst }, .{});
+            _ = try self.addZIRUnOp(&fn_type_scope.base, fn_src, .@"return", fn_type_inst);
 
             // We need the memory for the Type to go into the arena for the Decl
             var decl_arena = std.heap.ArenaAllocator.init(self.gpa);
@@ -1256,7 +1256,15 @@ fn astGenAndAnalyzeDecl(self: *Module, decl: *Decl) !bool {
                     const name_token = param.name_token.?;
                     const src = tree.token_locs[name_token].start;
                     const param_name = tree.tokenSlice(name_token);
-                    const arg = try newZIRInst(&gen_scope_arena.allocator, src, zir.Inst.Arg, .{}, .{});
+                    const arg = try gen_scope_arena.allocator.create(zir.Inst.NoOp);
+                    arg.* = .{
+                        .base = .{
+                            .tag = .arg,
+                            .src = src,
+                        },
+                        .positionals = .{},
+                        .kw_args = .{},
+                    };
                     gen_scope.instructions.items[i] = &arg.base;
                     const sub_scope = try gen_scope_arena.allocator.create(Scope.LocalVar);
                     sub_scope.* = .{
@@ -1276,7 +1284,7 @@ fn astGenAndAnalyzeDecl(self: *Module, decl: *Decl) !bool {
                     !gen_scope.instructions.items[gen_scope.instructions.items.len - 1].tag.isNoReturn()))
                 {
                     const src = tree.token_locs[body_block.rbrace].start;
-                    _ = try self.addZIRInst(&gen_scope.base, src, zir.Inst.ReturnVoid, .{}, .{});
+                    _ = try self.addZIRNoOp(&gen_scope.base, src, .returnvoid);
                 }
 
                 const fn_zir = try gen_scope_arena.allocator.create(Fn.ZIR);
@@ -2067,14 +2075,17 @@ fn addCall(
     return &inst.base;
 }
 
-fn newZIRInst(
-    gpa: *Allocator,
+pub fn addZIRInstSpecial(
+    self: *Module,
+    scope: *Scope,
     src: usize,
     comptime T: type,
     positionals: std.meta.fieldInfo(T, "positionals").field_type,
     kw_args: std.meta.fieldInfo(T, "kw_args").field_type,
 ) !*T {
-    const inst = try gpa.create(T);
+    const gen_zir = scope.getGenZIR();
+    try gen_zir.instructions.ensureCapacity(self.gpa, gen_zir.instructions.items.len + 1);
+    const inst = try gen_zir.arena.create(T);
     inst.* = .{
         .base = .{
             .tag = T.base_tag,
@@ -2083,22 +2094,79 @@ fn newZIRInst(
         .positionals = positionals,
         .kw_args = kw_args,
     };
+    gen_zir.instructions.appendAssumeCapacity(&inst.base);
     return inst;
 }
 
-pub fn addZIRInstSpecial(
+pub fn addZIRNoOp(
     self: *Module,
     scope: *Scope,
     src: usize,
-    comptime T: type,
-    positionals: std.meta.fieldInfo(T, "positionals").field_type,
-    kw_args: std.meta.fieldInfo(T, "kw_args").field_type,
-) !*T {
+    tag: zir.Inst.Tag,
+) !*zir.Inst {
     const gen_zir = scope.getGenZIR();
     try gen_zir.instructions.ensureCapacity(self.gpa, gen_zir.instructions.items.len + 1);
-    const inst = try newZIRInst(gen_zir.arena, src, T, positionals, kw_args);
+    const inst = try gen_zir.arena.create(zir.Inst.NoOp);
+    inst.* = .{
+        .base = .{
+            .tag = tag,
+            .src = src,
+        },
+        .positionals = .{},
+        .kw_args = .{},
+    };
     gen_zir.instructions.appendAssumeCapacity(&inst.base);
-    return inst;
+    return &inst.base;
+}
+
+pub fn addZIRUnOp(
+    self: *Module,
+    scope: *Scope,
+    src: usize,
+    tag: zir.Inst.Tag,
+    operand: *zir.Inst,
+) !*zir.Inst {
+    const gen_zir = scope.getGenZIR();
+    try gen_zir.instructions.ensureCapacity(self.gpa, gen_zir.instructions.items.len + 1);
+    const inst = try gen_zir.arena.create(zir.Inst.UnOp);
+    inst.* = .{
+        .base = .{
+            .tag = tag,
+            .src = src,
+        },
+        .positionals = .{
+            .operand = operand,
+        },
+        .kw_args = .{},
+    };
+    gen_zir.instructions.appendAssumeCapacity(&inst.base);
+    return &inst.base;
+}
+
+pub fn addZIRBinOp(
+    self: *Module,
+    scope: *Scope,
+    src: usize,
+    tag: zir.Inst.Tag,
+    lhs: *zir.Inst,
+    rhs: *zir.Inst,
+) !*zir.Inst {
+    const gen_zir = scope.getGenZIR();
+    try gen_zir.instructions.ensureCapacity(self.gpa, gen_zir.instructions.items.len + 1);
+    const inst = try gen_zir.arena.create(zir.Inst.BinOp);
+    inst.* = .{
+        .base = .{
+            .tag = tag,
+            .src = src,
+        },
+        .positionals = .{
+            .lhs = lhs,
+            .rhs = rhs,
+        },
+        .kw_args = .{},
+    };
+    gen_zir.instructions.appendAssumeCapacity(&inst.base);
+    return &inst.base;
 }
 
 pub fn addZIRInst(
@@ -2252,46 +2320,51 @@ fn analyzeInstConst(self: *Module, scope: *Scope, const_inst: *zir.Inst.Const) I
 
 fn analyzeInst(self: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!*Inst {
     switch (old_inst.tag) {
-        .arg => return self.analyzeInstArg(scope, old_inst.cast(zir.Inst.Arg).?),
-        .block => return self.analyzeInstBlock(scope, old_inst.cast(zir.Inst.Block).?),
-        .@"break" => return self.analyzeInstBreak(scope, old_inst.cast(zir.Inst.Break).?),
-        .breakpoint => return self.analyzeInstBreakpoint(scope, old_inst.cast(zir.Inst.Breakpoint).?),
-        .breakvoid => return self.analyzeInstBreakVoid(scope, old_inst.cast(zir.Inst.BreakVoid).?),
-        .call => return self.analyzeInstCall(scope, old_inst.cast(zir.Inst.Call).?),
-        .compileerror => return self.analyzeInstCompileError(scope, old_inst.cast(zir.Inst.CompileError).?),
-        .@"const" => return self.analyzeInstConst(scope, old_inst.cast(zir.Inst.Const).?),
-        .declref => return self.analyzeInstDeclRef(scope, old_inst.cast(zir.Inst.DeclRef).?),
-        .declref_str => return self.analyzeInstDeclRefStr(scope, old_inst.cast(zir.Inst.DeclRefStr).?),
-        .declval => return self.analyzeInstDeclVal(scope, old_inst.cast(zir.Inst.DeclVal).?),
-        .declval_in_module => return self.analyzeInstDeclValInModule(scope, old_inst.cast(zir.Inst.DeclValInModule).?),
-        .str => return self.analyzeInstStr(scope, old_inst.cast(zir.Inst.Str).?),
+        .arg => return self.analyzeInstArg(scope, old_inst.castTag(.arg).?),
+        .block => return self.analyzeInstBlock(scope, old_inst.castTag(.block).?),
+        .@"break" => return self.analyzeInstBreak(scope, old_inst.castTag(.@"break").?),
+        .breakpoint => return self.analyzeInstBreakpoint(scope, old_inst.castTag(.breakpoint).?),
+        .breakvoid => return self.analyzeInstBreakVoid(scope, old_inst.castTag(.breakvoid).?),
+        .call => return self.analyzeInstCall(scope, old_inst.castTag(.call).?),
+        .compileerror => return self.analyzeInstCompileError(scope, old_inst.castTag(.compileerror).?),
+        .@"const" => return self.analyzeInstConst(scope, old_inst.castTag(.@"const").?),
+        .declref => return self.analyzeInstDeclRef(scope, old_inst.castTag(.declref).?),
+        .declref_str => return self.analyzeInstDeclRefStr(scope, old_inst.castTag(.declref_str).?),
+        .declval => return self.analyzeInstDeclVal(scope, old_inst.castTag(.declval).?),
+        .declval_in_module => return self.analyzeInstDeclValInModule(scope, old_inst.castTag(.declval_in_module).?),
+        .str => return self.analyzeInstStr(scope, old_inst.castTag(.str).?),
         .int => {
-            const big_int = old_inst.cast(zir.Inst.Int).?.positionals.int;
+            const big_int = old_inst.castTag(.int).?.positionals.int;
             return self.constIntBig(scope, old_inst.src, Type.initTag(.comptime_int), big_int);
         },
-        .inttype => return self.analyzeInstIntType(scope, old_inst.cast(zir.Inst.IntType).?),
-        .ptrtoint => return self.analyzeInstPtrToInt(scope, old_inst.cast(zir.Inst.PtrToInt).?),
-        .fieldptr => return self.analyzeInstFieldPtr(scope, old_inst.cast(zir.Inst.FieldPtr).?),
-        .deref => return self.analyzeInstDeref(scope, old_inst.cast(zir.Inst.Deref).?),
-        .as => return self.analyzeInstAs(scope, old_inst.cast(zir.Inst.As).?),
-        .@"asm" => return self.analyzeInstAsm(scope, old_inst.cast(zir.Inst.Asm).?),
-        .@"unreachable" => return self.analyzeInstUnreachable(scope, old_inst.cast(zir.Inst.Unreachable).?),
-        .@"return" => return self.analyzeInstRet(scope, old_inst.cast(zir.Inst.Return).?),
-        .returnvoid => return self.analyzeInstRetVoid(scope, old_inst.cast(zir.Inst.ReturnVoid).?),
-        .@"fn" => return self.analyzeInstFn(scope, old_inst.cast(zir.Inst.Fn).?),
-        .@"export" => return self.analyzeInstExport(scope, old_inst.cast(zir.Inst.Export).?),
-        .primitive => return self.analyzeInstPrimitive(scope, old_inst.cast(zir.Inst.Primitive).?),
-        .fntype => return self.analyzeInstFnType(scope, old_inst.cast(zir.Inst.FnType).?),
-        .intcast => return self.analyzeInstIntCast(scope, old_inst.cast(zir.Inst.IntCast).?),
-        .bitcast => return self.analyzeInstBitCast(scope, old_inst.cast(zir.Inst.BitCast).?),
-        .elemptr => return self.analyzeInstElemPtr(scope, old_inst.cast(zir.Inst.ElemPtr).?),
-        .add => return self.analyzeInstAdd(scope, old_inst.cast(zir.Inst.Add).?),
-        .sub => return self.analyzeInstSub(scope, old_inst.cast(zir.Inst.Sub).?),
-        .cmp => return self.analyzeInstCmp(scope, old_inst.cast(zir.Inst.Cmp).?),
-        .condbr => return self.analyzeInstCondBr(scope, old_inst.cast(zir.Inst.CondBr).?),
-        .isnull => return self.analyzeInstIsNull(scope, old_inst.cast(zir.Inst.IsNull).?),
-        .isnonnull => return self.analyzeInstIsNonNull(scope, old_inst.cast(zir.Inst.IsNonNull).?),
-        .boolnot => return self.analyzeInstBoolNot(scope, old_inst.cast(zir.Inst.BoolNot).?),
+        .inttype => return self.analyzeInstIntType(scope, old_inst.castTag(.inttype).?),
+        .ptrtoint => return self.analyzeInstPtrToInt(scope, old_inst.castTag(.ptrtoint).?),
+        .fieldptr => return self.analyzeInstFieldPtr(scope, old_inst.castTag(.fieldptr).?),
+        .deref => return self.analyzeInstDeref(scope, old_inst.castTag(.deref).?),
+        .as => return self.analyzeInstAs(scope, old_inst.castTag(.as).?),
+        .@"asm" => return self.analyzeInstAsm(scope, old_inst.castTag(.@"asm").?),
+        .@"unreachable" => return self.analyzeInstUnreachable(scope, old_inst.castTag(.@"unreachable").?),
+        .@"return" => return self.analyzeInstRet(scope, old_inst.castTag(.@"return").?),
+        .returnvoid => return self.analyzeInstRetVoid(scope, old_inst.castTag(.returnvoid).?),
+        .@"fn" => return self.analyzeInstFn(scope, old_inst.castTag(.@"fn").?),
+        .@"export" => return self.analyzeInstExport(scope, old_inst.castTag(.@"export").?),
+        .primitive => return self.analyzeInstPrimitive(scope, old_inst.castTag(.primitive).?),
+        .fntype => return self.analyzeInstFnType(scope, old_inst.castTag(.fntype).?),
+        .intcast => return self.analyzeInstIntCast(scope, old_inst.castTag(.intcast).?),
+        .bitcast => return self.analyzeInstBitCast(scope, old_inst.castTag(.bitcast).?),
+        .elemptr => return self.analyzeInstElemPtr(scope, old_inst.castTag(.elemptr).?),
+        .add => return self.analyzeInstAdd(scope, old_inst.castTag(.add).?),
+        .sub => return self.analyzeInstSub(scope, old_inst.castTag(.sub).?),
+        .cmp_lt => return self.analyzeInstCmp(scope, old_inst.castTag(.cmp_lt).?, .lt),
+        .cmp_lte => return self.analyzeInstCmp(scope, old_inst.castTag(.cmp_lte).?, .lte),
+        .cmp_eq => return self.analyzeInstCmp(scope, old_inst.castTag(.cmp_eq).?, .eq),
+        .cmp_gte => return self.analyzeInstCmp(scope, old_inst.castTag(.cmp_gte).?, .gte),
+        .cmp_gt => return self.analyzeInstCmp(scope, old_inst.castTag(.cmp_gt).?, .gt),
+        .cmp_neq => return self.analyzeInstCmp(scope, old_inst.castTag(.cmp_neq).?, .neq),
+        .condbr => return self.analyzeInstCondBr(scope, old_inst.castTag(.condbr).?),
+        .isnull => return self.analyzeInstIsNonNull(scope, old_inst.castTag(.isnull).?, true),
+        .isnonnull => return self.analyzeInstIsNonNull(scope, old_inst.castTag(.isnonnull).?, false),
+        .boolnot => return self.analyzeInstBoolNot(scope, old_inst.castTag(.boolnot).?),
     }
 }
 
@@ -2372,7 +2445,7 @@ fn analyzeInstCompileError(self: *Module, scope: *Scope, inst: *zir.Inst.Compile
     return self.fail(scope, inst.base.src, "{}", .{inst.positionals.msg});
 }
 
-fn analyzeInstArg(self: *Module, scope: *Scope, inst: *zir.Inst.Arg) InnerError!*Inst {
+fn analyzeInstArg(self: *Module, scope: *Scope, inst: *zir.Inst.NoOp) InnerError!*Inst {
     const b = try self.requireRuntimeBlock(scope, inst.base.src);
     const fn_ty = b.func.?.owner_decl.typed_value.most_recent.typed_value.ty;
     const param_index = b.instructions.items.len;
@@ -2435,7 +2508,7 @@ fn analyzeInstBlock(self: *Module, scope: *Scope, inst: *zir.Inst.Block) InnerEr
     return &block_inst.base;
 }
 
-fn analyzeInstBreakpoint(self: *Module, scope: *Scope, inst: *zir.Inst.Breakpoint) InnerError!*Inst {
+fn analyzeInstBreakpoint(self: *Module, scope: *Scope, inst: *zir.Inst.NoOp) InnerError!*Inst {
     const b = try self.requireRuntimeBlock(scope, inst.base.src);
     return self.addNoOp(b, inst.base.src, Type.initTag(.void), .breakpoint);
 }
@@ -2791,11 +2864,11 @@ fn analyzeInstElemPtr(self: *Module, scope: *Scope, inst: *zir.Inst.ElemPtr) Inn
     return self.fail(scope, inst.base.src, "TODO implement more analyze elemptr", .{});
 }
 
-fn analyzeInstSub(self: *Module, scope: *Scope, inst: *zir.Inst.Sub) InnerError!*Inst {
+fn analyzeInstSub(self: *Module, scope: *Scope, inst: *zir.Inst.BinOp) InnerError!*Inst {
     return self.fail(scope, inst.base.src, "TODO implement analysis of sub", .{});
 }
 
-fn analyzeInstAdd(self: *Module, scope: *Scope, inst: *zir.Inst.Add) InnerError!*Inst {
+fn analyzeInstAdd(self: *Module, scope: *Scope, inst: *zir.Inst.BinOp) InnerError!*Inst {
     const tracy = trace(@src());
     defer tracy.end();
 
@@ -2848,9 +2921,9 @@ fn analyzeInstAdd(self: *Module, scope: *Scope, inst: *zir.Inst.Add) InnerError!
     return self.fail(scope, inst.base.src, "TODO analyze add for {} + {}", .{ lhs.ty.zigTypeTag(), rhs.ty.zigTypeTag() });
 }
 
-fn analyzeInstDeref(self: *Module, scope: *Scope, deref: *zir.Inst.Deref) InnerError!*Inst {
-    const ptr = try self.resolveInst(scope, deref.positionals.ptr);
-    return self.analyzeDeref(scope, deref.base.src, ptr, deref.positionals.ptr.src);
+fn analyzeInstDeref(self: *Module, scope: *Scope, deref: *zir.Inst.UnOp) InnerError!*Inst {
+    const ptr = try self.resolveInst(scope, deref.positionals.operand);
+    return self.analyzeDeref(scope, deref.base.src, ptr, deref.positionals.operand.src);
 }
 
 fn analyzeDeref(self: *Module, scope: *Scope, src: usize, ptr: *Inst, ptr_src: usize) InnerError!*Inst {
@@ -2907,10 +2980,14 @@ fn analyzeInstAsm(self: *Module, scope: *Scope, assembly: *zir.Inst.Asm) InnerEr
     return &inst.base;
 }
 
-fn analyzeInstCmp(self: *Module, scope: *Scope, inst: *zir.Inst.Cmp) InnerError!*Inst {
+fn analyzeInstCmp(
+    self: *Module,
+    scope: *Scope,
+    inst: *zir.Inst.BinOp,
+    op: std.math.CompareOperator,
+) InnerError!*Inst {
     const lhs = try self.resolveInst(scope, inst.positionals.lhs);
     const rhs = try self.resolveInst(scope, inst.positionals.rhs);
-    const op = inst.positionals.op;
 
     const is_equality_cmp = switch (op) {
         .eq, .neq => true,
@@ -2964,7 +3041,7 @@ fn analyzeInstCmp(self: *Module, scope: *Scope, inst: *zir.Inst.Cmp) InnerError!
     return self.fail(scope, inst.base.src, "TODO implement more cmp analysis", .{});
 }
 
-fn analyzeInstBoolNot(self: *Module, scope: *Scope, inst: *zir.Inst.BoolNot) InnerError!*Inst {
+fn analyzeInstBoolNot(self: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst {
     const uncasted_operand = try self.resolveInst(scope, inst.positionals.operand);
     const bool_type = Type.initTag(.bool);
     const operand = try self.coerce(scope, bool_type, uncasted_operand);
@@ -2975,14 +3052,9 @@ fn analyzeInstBoolNot(self: *Module, scope: *Scope, inst: *zir.Inst.BoolNot) Inn
     return self.addUnOp(b, inst.base.src, bool_type, .not, operand);
 }
 
-fn analyzeInstIsNull(self: *Module, scope: *Scope, inst: *zir.Inst.IsNull) InnerError!*Inst {
-    const operand = try self.resolveInst(scope, inst.positionals.operand);
-    return self.analyzeIsNull(scope, inst.base.src, operand, true);
-}
-
-fn analyzeInstIsNonNull(self: *Module, scope: *Scope, inst: *zir.Inst.IsNonNull) InnerError!*Inst {
+fn analyzeInstIsNonNull(self: *Module, scope: *Scope, inst: *zir.Inst.UnOp, invert_logic: bool) InnerError!*Inst {
     const operand = try self.resolveInst(scope, inst.positionals.operand);
-    return self.analyzeIsNull(scope, inst.base.src, operand, false);
+    return self.analyzeIsNull(scope, inst.base.src, operand, invert_logic);
 }
 
 fn analyzeInstCondBr(self: *Module, scope: *Scope, inst: *zir.Inst.CondBr) InnerError!*Inst {
@@ -3031,7 +3103,7 @@ fn wantSafety(self: *Module, scope: *Scope) bool {
     };
 }
 
-fn analyzeInstUnreachable(self: *Module, scope: *Scope, unreach: *zir.Inst.Unreachable) InnerError!*Inst {
+fn analyzeInstUnreachable(self: *Module, scope: *Scope, unreach: *zir.Inst.NoOp) InnerError!*Inst {
     const b = try self.requireRuntimeBlock(scope, unreach.base.src);
     if (self.wantSafety(scope)) {
         // TODO Once we have a panic function to call, call it here instead of this.
@@ -3040,13 +3112,13 @@ fn analyzeInstUnreachable(self: *Module, scope: *Scope, unreach: *zir.Inst.Unrea
     return self.addNoOp(b, unreach.base.src, Type.initTag(.noreturn), .unreach);
 }
 
-fn analyzeInstRet(self: *Module, scope: *Scope, inst: *zir.Inst.Return) InnerError!*Inst {
+fn analyzeInstRet(self: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst {
     const operand = try self.resolveInst(scope, inst.positionals.operand);
     const b = try self.requireRuntimeBlock(scope, inst.base.src);
     return self.addUnOp(b, inst.base.src, Type.initTag(.noreturn), .ret, operand);
 }
 
-fn analyzeInstRetVoid(self: *Module, scope: *Scope, inst: *zir.Inst.ReturnVoid) InnerError!*Inst {
+fn analyzeInstRetVoid(self: *Module, scope: *Scope, inst: *zir.Inst.NoOp) InnerError!*Inst {
     const b = try self.requireRuntimeBlock(scope, inst.base.src);
     return self.addNoOp(b, inst.base.src, Type.initTag(.noreturn), .retvoid);
 }
src-self-hosted/zir.zig
@@ -79,11 +79,69 @@ pub const Inst = struct {
         elemptr,
         add,
         sub,
-        cmp,
+        cmp_lt,
+        cmp_lte,
+        cmp_eq,
+        cmp_gte,
+        cmp_gt,
+        cmp_neq,
         condbr,
         isnull,
         isnonnull,
 
+        pub fn Type(tag: Tag) type {
+            return switch (tag) {
+                .arg,
+                .breakpoint,
+                .@"unreachable",
+                .returnvoid,
+                => NoOp,
+
+                .boolnot,
+                .deref,
+                .@"return",
+                .isnull,
+                .isnonnull,
+                => UnOp,
+
+                .add,
+                .sub,
+                .cmp_lt,
+                .cmp_lte,
+                .cmp_eq,
+                .cmp_gte,
+                .cmp_gt,
+                .cmp_neq,
+                => BinOp,
+
+                .block => Block,
+                .@"break" => Break,
+                .breakvoid => BreakVoid,
+                .call => Call,
+                .declref => DeclRef,
+                .declref_str => DeclRefStr,
+                .declval => DeclVal,
+                .declval_in_module => DeclValInModule,
+                .compileerror => CompileError,
+                .@"const" => Const,
+                .str => Str,
+                .int => Int,
+                .inttype => IntType,
+                .ptrtoint => PtrToInt,
+                .fieldptr => FieldPtr,
+                .as => As,
+                .@"asm" => Asm,
+                .@"fn" => Fn,
+                .@"export" => Export,
+                .primitive => Primitive,
+                .fntype => FnType,
+                .intcast => IntCast,
+                .bitcast => BitCast,
+                .elemptr => ElemPtr,
+                .condbr => CondBr,
+            };
+        }
+
         /// Returns whether the instruction is one of the control flow "noreturn" types.
         /// Function calls do not count.
         pub fn isNoReturn(tag: Tag) bool {
@@ -114,7 +172,12 @@ pub const Inst = struct {
                 .elemptr,
                 .add,
                 .sub,
-                .cmp,
+                .cmp_lt,
+                .cmp_lte,
+                .cmp_eq,
+                .cmp_gte,
+                .cmp_gt,
+                .cmp_neq,
                 .isnull,
                 .isnonnull,
                 .boolnot,
@@ -132,63 +195,56 @@ pub const Inst = struct {
         }
     };
 
-    pub fn TagToType(tag: Tag) type {
-        return switch (tag) {
-            .arg => Arg,
-            .block => Block,
-            .@"break" => Break,
-            .breakpoint => Breakpoint,
-            .breakvoid => BreakVoid,
-            .call => Call,
-            .declref => DeclRef,
-            .declref_str => DeclRefStr,
-            .declval => DeclVal,
-            .declval_in_module => DeclValInModule,
-            .compileerror => CompileError,
-            .@"const" => Const,
-            .boolnot => BoolNot,
-            .str => Str,
-            .int => Int,
-            .inttype => IntType,
-            .ptrtoint => PtrToInt,
-            .fieldptr => FieldPtr,
-            .deref => Deref,
-            .as => As,
-            .@"asm" => Asm,
-            .@"unreachable" => Unreachable,
-            .@"return" => Return,
-            .returnvoid => ReturnVoid,
-            .@"fn" => Fn,
-            .@"export" => Export,
-            .primitive => Primitive,
-            .fntype => FnType,
-            .intcast => IntCast,
-            .bitcast => BitCast,
-            .elemptr => ElemPtr,
-            .add => Add,
-            .sub => Sub,
-            .cmp => Cmp,
-            .condbr => CondBr,
-            .isnull => IsNull,
-            .isnonnull => IsNonNull,
-        };
-    }
-
+    /// Prefer `castTag` to this.
     pub fn cast(base: *Inst, comptime T: type) ?*T {
-        if (base.tag != T.base_tag)
-            return null;
+        if (@hasField(T, "base_tag")) {
+            return base.castTag(T.base_tag);
+        }
+        inline for (@typeInfo(Tag).Enum.fields) |field| {
+            const tag = @intToEnum(Tag, field.value);
+            if (base.tag == tag) {
+                if (T == tag.Type()) {
+                    return @fieldParentPtr(T, "base", base);
+                }
+                return null;
+            }
+        }
+        unreachable;
+    }
 
-        return @fieldParentPtr(T, "base", base);
+    pub fn castTag(base: *Inst, comptime tag: Tag) ?*tag.Type() {
+        if (base.tag == tag) {
+            return @fieldParentPtr(tag.Type(), "base", base);
+        }
+        return null;
     }
 
-    pub const Arg = struct {
-        pub const base_tag = Tag.arg;
+    pub const NoOp = struct {
         base: Inst,
 
         positionals: struct {},
         kw_args: struct {},
     };
 
+    pub const UnOp = struct {
+        base: Inst,
+
+        positionals: struct {
+            operand: *Inst,
+        },
+        kw_args: struct {},
+    };
+
+    pub const BinOp = struct {
+        base: Inst,
+
+        positionals: struct {
+            lhs: *Inst,
+            rhs: *Inst,
+        },
+        kw_args: struct {},
+    };
+
     pub const Block = struct {
         pub const base_tag = Tag.block;
         base: Inst,
@@ -210,14 +266,6 @@ pub const Inst = struct {
         kw_args: struct {},
     };
 
-    pub const Breakpoint = struct {
-        pub const base_tag = Tag.breakpoint;
-        base: Inst,
-
-        positionals: struct {},
-        kw_args: struct {},
-    };
-
     pub const BreakVoid = struct {
         pub const base_tag = Tag.breakvoid;
         base: Inst,
@@ -301,16 +349,6 @@ pub const Inst = struct {
         kw_args: struct {},
     };
 
-    pub const BoolNot = struct {
-        pub const base_tag = Tag.boolnot;
-        base: Inst,
-
-        positionals: struct {
-            operand: *Inst,
-        },
-        kw_args: struct {},
-    };
-
     pub const Str = struct {
         pub const base_tag = Tag.str;
         base: Inst,
@@ -353,16 +391,6 @@ pub const Inst = struct {
         kw_args: struct {},
     };
 
-    pub const Deref = struct {
-        pub const base_tag = Tag.deref;
-        base: Inst,
-
-        positionals: struct {
-            ptr: *Inst,
-        },
-        kw_args: struct {},
-    };
-
     pub const As = struct {
         pub const base_tag = Tag.as;
         pub const builtin_name = "@as";
@@ -392,32 +420,6 @@ pub const Inst = struct {
         },
     };
 
-    pub const Unreachable = struct {
-        pub const base_tag = Tag.@"unreachable";
-        base: Inst,
-
-        positionals: struct {},
-        kw_args: struct {},
-    };
-
-    pub const Return = struct {
-        pub const base_tag = Tag.@"return";
-        base: Inst,
-
-        positionals: struct {
-            operand: *Inst,
-        },
-        kw_args: struct {},
-    };
-
-    pub const ReturnVoid = struct {
-        pub const base_tag = Tag.returnvoid;
-        base: Inst,
-
-        positionals: struct {},
-        kw_args: struct {},
-    };
-
     pub const Fn = struct {
         pub const base_tag = Tag.@"fn";
         base: Inst,
@@ -587,42 +589,6 @@ pub const Inst = struct {
         kw_args: struct {},
     };
 
-    pub const Add = struct {
-        pub const base_tag = Tag.add;
-        base: Inst,
-
-        positionals: struct {
-            lhs: *Inst,
-            rhs: *Inst,
-        },
-        kw_args: struct {},
-    };
-
-    pub const Sub = struct {
-        pub const base_tag = Tag.sub;
-        base: Inst,
-
-        positionals: struct {
-            lhs: *Inst,
-            rhs: *Inst,
-        },
-        kw_args: struct {},
-    };
-
-    /// TODO get rid of the op positional arg and make that data part of
-    /// the base Inst tag.
-    pub const Cmp = struct {
-        pub const base_tag = Tag.cmp;
-        base: Inst,
-
-        positionals: struct {
-            lhs: *Inst,
-            op: std.math.CompareOperator,
-            rhs: *Inst,
-        },
-        kw_args: struct {},
-    };
-
     pub const CondBr = struct {
         pub const base_tag = Tag.condbr;
         base: Inst,
@@ -634,26 +600,6 @@ pub const Inst = struct {
         },
         kw_args: struct {},
     };
-
-    pub const IsNull = struct {
-        pub const base_tag = Tag.isnull;
-        base: Inst,
-
-        positionals: struct {
-            operand: *Inst,
-        },
-        kw_args: struct {},
-    };
-
-    pub const IsNonNull = struct {
-        pub const base_tag = Tag.isnonnull;
-        base: Inst,
-
-        positionals: struct {
-            operand: *Inst,
-        },
-        kw_args: struct {},
-    };
 };
 
 pub const ErrorMsg = struct {
@@ -775,7 +721,7 @@ const Writer = struct {
         comptime inst_tag: Inst.Tag,
         base: *Inst,
     ) (@TypeOf(stream).Error || error{OutOfMemory})!void {
-        const SpecificInst = Inst.TagToType(inst_tag);
+        const SpecificInst = inst_tag.Type();
         const inst = @fieldParentPtr(SpecificInst, "base", base);
         const Positionals = @TypeOf(inst.positionals);
         try stream.writeAll("= " ++ @tagName(inst_tag) ++ "(");
@@ -1102,7 +1048,7 @@ const Parser = struct {
         inline for (@typeInfo(Inst.Tag).Enum.fields) |field| {
             if (mem.eql(u8, field.name, fn_name)) {
                 const tag = @field(Inst.Tag, field.name);
-                return parseInstructionGeneric(self, field.name, Inst.TagToType(tag), body_ctx, name, contents_start);
+                return parseInstructionGeneric(self, field.name, tag.Type(), tag, body_ctx, name, contents_start);
             }
         }
         return self.fail("unknown instruction '{}'", .{fn_name});
@@ -1112,6 +1058,7 @@ const Parser = struct {
         self: *Parser,
         comptime fn_name: []const u8,
         comptime InstType: type,
+        tag: Inst.Tag,
         body_ctx: ?*Body,
         inst_name: []const u8,
         contents_start: usize,
@@ -1119,7 +1066,7 @@ const Parser = struct {
         const inst_specific = try self.arena.allocator.create(InstType);
         inst_specific.base = .{
             .src = self.i,
-            .tag = InstType.base_tag,
+            .tag = tag,
         };
 
         if (InstType == Inst.Block) {
@@ -1615,12 +1562,12 @@ const EmitZIR = struct {
         }
     }
 
-    fn emitNoOp(self: *EmitZIR, src: usize, comptime T: type) Allocator.Error!*Inst {
-        const new_inst = try self.arena.allocator.create(T);
+    fn emitNoOp(self: *EmitZIR, src: usize, tag: Inst.Tag) Allocator.Error!*Inst {
+        const new_inst = try self.arena.allocator.create(Inst.NoOp);
         new_inst.* = .{
             .base = .{
                 .src = src,
-                .tag = T.base_tag,
+                .tag = tag,
             },
             .positionals = .{},
             .kw_args = .{},
@@ -1628,41 +1575,18 @@ const EmitZIR = struct {
         return &new_inst.base;
     }
 
-    fn emitCmp(
-        self: *EmitZIR,
-        src: usize,
-        new_body: ZirBody,
-        old_inst: *ir.Inst.BinOp,
-        op: std.math.CompareOperator,
-    ) Allocator.Error!*Inst {
-        const new_inst = try self.arena.allocator.create(Inst.Cmp);
-        new_inst.* = .{
-            .base = .{
-                .src = src,
-                .tag = Inst.Cmp.base_tag,
-            },
-            .positionals = .{
-                .lhs = try self.resolveInst(new_body, old_inst.lhs),
-                .rhs = try self.resolveInst(new_body, old_inst.rhs),
-                .op = op,
-            },
-            .kw_args = .{},
-        };
-        return &new_inst.base;
-    }
-
     fn emitUnOp(
         self: *EmitZIR,
         src: usize,
         new_body: ZirBody,
         old_inst: *ir.Inst.UnOp,
-        comptime I: type,
+        tag: Inst.Tag,
     ) Allocator.Error!*Inst {
-        const new_inst = try self.arena.allocator.create(I);
+        const new_inst = try self.arena.allocator.create(Inst.UnOp);
         new_inst.* = .{
             .base = .{
                 .src = src,
-                .tag = I.base_tag,
+                .tag = tag,
             },
             .positionals = .{
                 .operand = try self.resolveInst(new_body, old_inst.operand),
@@ -1677,13 +1601,13 @@ const EmitZIR = struct {
         src: usize,
         new_body: ZirBody,
         old_inst: *ir.Inst.BinOp,
-        comptime I: type,
+        tag: Inst.Tag,
     ) Allocator.Error!*Inst {
-        const new_inst = try self.arena.allocator.create(I);
+        const new_inst = try self.arena.allocator.create(Inst.BinOp);
         new_inst.* = .{
             .base = .{
                 .src = src,
-                .tag = I.base_tag,
+                .tag = tag,
             },
             .positionals = .{
                 .lhs = try self.resolveInst(new_body, old_inst.lhs),
@@ -1708,26 +1632,25 @@ const EmitZIR = struct {
             const new_inst = switch (inst.tag) {
                 .constant => unreachable, // excluded from function bodies
 
-                .arg => try self.emitNoOp(inst.src, Inst.Arg),
-                .breakpoint => try self.emitNoOp(inst.src, Inst.Breakpoint),
-                .unreach => try self.emitNoOp(inst.src, Inst.Unreachable),
-                .retvoid => try self.emitNoOp(inst.src, Inst.ReturnVoid),
-
-                .not => try self.emitUnOp(inst.src, new_body, inst.castTag(.not).?, Inst.BoolNot),
-                .ret => try self.emitUnOp(inst.src, new_body, inst.castTag(.ret).?, Inst.Return),
-                .ptrtoint => try self.emitUnOp(inst.src, new_body, inst.castTag(.ptrtoint).?, Inst.PtrToInt),
-                .isnull => try self.emitUnOp(inst.src, new_body, inst.castTag(.isnull).?, Inst.IsNull),
-                .isnonnull => try self.emitUnOp(inst.src, new_body, inst.castTag(.isnonnull).?, Inst.IsNonNull),
-
-                .add => try self.emitBinOp(inst.src, new_body, inst.castTag(.add).?, Inst.Add),
-                .sub => try self.emitBinOp(inst.src, new_body, inst.castTag(.sub).?, Inst.Sub),
-
-                .cmp_lt => try self.emitCmp(inst.src, new_body, inst.castTag(.cmp_lt).?, .lt),
-                .cmp_lte => try self.emitCmp(inst.src, new_body, inst.castTag(.cmp_lte).?, .lte),
-                .cmp_eq => try self.emitCmp(inst.src, new_body, inst.castTag(.cmp_eq).?, .eq),
-                .cmp_gte => try self.emitCmp(inst.src, new_body, inst.castTag(.cmp_gte).?, .gte),
-                .cmp_gt => try self.emitCmp(inst.src, new_body, inst.castTag(.cmp_gt).?, .gt),
-                .cmp_neq => try self.emitCmp(inst.src, new_body, inst.castTag(.cmp_neq).?, .neq),
+                .arg => try self.emitNoOp(inst.src, .arg),
+                .breakpoint => try self.emitNoOp(inst.src, .breakpoint),
+                .unreach => try self.emitNoOp(inst.src, .@"unreachable"),
+                .retvoid => try self.emitNoOp(inst.src, .returnvoid),
+
+                .not => try self.emitUnOp(inst.src, new_body, inst.castTag(.not).?, .boolnot),
+                .ret => try self.emitUnOp(inst.src, new_body, inst.castTag(.ret).?, .@"return"),
+                .ptrtoint => try self.emitUnOp(inst.src, new_body, inst.castTag(.ptrtoint).?, .ptrtoint),
+                .isnull => try self.emitUnOp(inst.src, new_body, inst.castTag(.isnull).?, .isnull),
+                .isnonnull => try self.emitUnOp(inst.src, new_body, inst.castTag(.isnonnull).?, .isnonnull),
+
+                .add => try self.emitBinOp(inst.src, new_body, inst.castTag(.add).?, .add),
+                .sub => try self.emitBinOp(inst.src, new_body, inst.castTag(.sub).?, .sub),
+                .cmp_lt => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_lt).?, .cmp_lt),
+                .cmp_lte => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_lte).?, .cmp_lte),
+                .cmp_eq => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_eq).?, .cmp_eq),
+                .cmp_gte => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_gte).?, .cmp_gte),
+                .cmp_gt => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_gt).?, .cmp_gt),
+                .cmp_neq => try self.emitBinOp(inst.src, new_body, inst.castTag(.cmp_neq).?, .cmp_neq),
 
                 .bitcast => blk: {
                     const old_inst = inst.castTag(.bitcast).?;
test/stage2/zir.zig
@@ -56,7 +56,7 @@ pub fn addCases(ctx: *TestContext) !void {
         \\  %result = add(%x0, %x1)
         \\
         \\  %expected = int(69)
-        \\  %ok = cmp(%result, eq, %expected)
+        \\  %ok = cmp_eq(%result, %expected)
         \\  %10 = condbr(%ok, {
         \\    %11 = returnvoid()
         \\  }, {