Commit 4a40282391

Vexu <git@vexu.eu>
2020-08-12 21:30:14
stage2: implement unwrap optional
1 parent 5c1fe58
src-self-hosted/astgen.zig
@@ -106,6 +106,7 @@ pub fn expr(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node) InnerEr
         .BoolLiteral => return rlWrap(mod, scope, rl, try boolLiteral(mod, scope, node.castTag(.BoolLiteral).?)),
         .NullLiteral => return rlWrap(mod, scope, rl, try nullLiteral(mod, scope, node.castTag(.NullLiteral).?)),
         .OptionalType => return rlWrap(mod, scope, rl, try optionalType(mod, scope, node.castTag(.OptionalType).?)),
+        .UnwrapOptional => return unwrapOptional(mod, scope, rl, node.castTag(.UnwrapOptional).?),
         else => return mod.failNode(scope, node, "TODO implement astgen.Expr for {}", .{@tagName(node.tag)}),
     }
 }
@@ -305,6 +306,17 @@ fn optionalType(mod: *Module, scope: *Scope, node: *ast.Node.SimplePrefixOp) Inn
     return addZIRUnOp(mod, scope, src, .optional_type, operand);
 }
 
+fn unwrapOptional(mod: *Module, scope: *Scope, rl: ResultLoc, node: *ast.Node.SimpleSuffixOp) InnerError!*zir.Inst {
+    const tree = scope.tree();
+    const src = tree.token_locs[node.rtoken].start;
+
+    const operand = try expr(mod, scope, .lvalue, node.lhs);
+    const unwrapped_ptr = try addZIRInst(mod, scope, src, zir.Inst.UnwrapOptional, .{ .operand = operand }, .{});
+    if (rl == .lvalue) return unwrapped_ptr;
+
+    return rlWrap(mod, scope, rl, try addZIRUnOp(mod, scope, src, .deref, unwrapped_ptr));
+}
+
 /// Identifier token -> String (allocated in scope.arena())
 pub fn identifierTokenString(mod: *Module, scope: *Scope, token: ast.TokenIndex) InnerError![]const u8 {
     const tree = scope.tree();
src-self-hosted/codegen.zig
@@ -668,6 +668,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 .store => return self.genStore(inst.castTag(.store).?),
                 .sub => return self.genSub(inst.castTag(.sub).?),
                 .unreach => return MCValue{ .unreach = {} },
+                .unwrap_optional => return self.genUnwrapOptional(inst.castTag(.unwrap_optional).?),
             }
         }
 
@@ -817,6 +818,15 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             }
         }
 
+        fn genUnwrapOptional(self: *Self, inst: *ir.Inst.UnwrapOptional) !MCValue {
+            // No side effects, so if it's unreferenced, do nothing.
+            if (inst.base.isUnused())
+                return MCValue.dead;
+            switch (arch) {
+                else => return self.fail(inst.base.src, "TODO implement unwrap optional for {}", .{self.target.cpu.arch}),
+            }
+        }
+
         fn genLoad(self: *Self, inst: *ir.Inst.UnOp) !MCValue {
             const elem_ty = inst.base.ty;
             if (!elem_ty.hasCodeGenBits())
src-self-hosted/ir.zig
@@ -82,6 +82,7 @@ pub const Inst = struct {
         not,
         floatcast,
         intcast,
+        unwrap_optional,
 
         pub fn Type(tag: Tag) type {
             return switch (tag) {
@@ -124,6 +125,7 @@ pub const Inst = struct {
                 .condbr => CondBr,
                 .constant => Constant,
                 .loop => Loop,
+                .unwrap_optional => UnwrapOptional,
             };
         }
 
@@ -420,6 +422,26 @@ pub const Inst = struct {
         }
     };
 
+    pub const UnwrapOptional = struct {
+        pub const base_tag = Tag.unwrap_optional;
+        base: Inst,
+
+        operand: *Inst,
+        safety_check: bool,
+
+        pub fn operandCount(self: *const UnwrapOptional) usize {
+            return 1;
+        }
+        pub fn getOperand(self: *const UnwrapOptional, index: usize) ?*Inst {
+            var i = index;
+
+            if (i < 1)
+                return self.operand;
+            i -= 1;
+
+            return null;
+        }
+    };
 };
 
 pub const Body = struct {
src-self-hosted/Module.zig
@@ -2016,6 +2016,28 @@ pub fn addCall(
     return &inst.base;
 }
 
+pub fn addUnwrapOptional(
+    self: *Module,
+    block: *Scope.Block,
+    src: usize,
+    ty: Type,
+    operand: *Inst,
+    safety_check: bool,
+) !*Inst {
+    const inst = try block.arena.create(Inst.UnwrapOptional);
+    inst.* = .{
+        .base = .{
+            .tag = .unwrap_optional,
+            .ty = ty,
+            .src = src,
+        },
+        .operand = operand,
+        .safety_check = safety_check, 
+    };
+    try block.instructions.append(self.gpa, &inst.base);
+    return &inst.base;
+}
+
 pub fn constInst(self: *Module, scope: *Scope, src: usize, typed_value: TypedValue) !*Inst {
     const const_inst = try scope.arena().create(Inst.Constant);
     const_inst.* = .{
@@ -2488,9 +2510,9 @@ pub fn coerce(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*Inst
             if (child_type.eql(inst.ty)) {
                 return self.constInst(scope, inst.src, .{ .ty = dest_type, .val = val });
             }
-            return self.fail(scope, inst.src, "TODO optional wrap {} to {}", .{ val, inst.ty });
+            return self.fail(scope, inst.src, "TODO optional wrap {} to {}", .{ val, dest_type });
         } else if (child_type.eql(inst.ty)) {
-            return self.fail(scope, inst.src, "TODO optional wrap {}", .{inst.ty});
+            return self.fail(scope, inst.src, "TODO optional wrap {}", .{dest_type});
         }
     }
 
src-self-hosted/zir.zig
@@ -214,6 +214,8 @@ pub const Inst = struct {
         xor,
         /// Create an optional type '?T'
         optional_type,
+        /// Unwraps an optional value 'lhs.?'
+        unwrap_optional,
 
         pub fn Type(tag: Tag) type {
             return switch (tag) {
@@ -301,6 +303,7 @@ pub const Inst = struct {
                 .fntype => FnType,
                 .elemptr => ElemPtr,
                 .condbr => CondBr,
+                .unwrap_optional => UnwrapOptional,
             };
         }
 
@@ -376,6 +379,7 @@ pub const Inst = struct {
                 .typeof,
                 .xor,
                 .optional_type,
+                .unwrap_optional,
                 => false,
 
                 .@"break",
@@ -816,6 +820,18 @@ pub const Inst = struct {
         },
         kw_args: struct {},
     };
+
+    pub const UnwrapOptional = struct {
+        pub const base_tag = Tag.unwrap_optional;
+        base: Inst,
+
+        positionals: struct {
+            operand: *Inst,
+        },
+        kw_args: struct {
+            safety_check: bool = true,
+        },
+    };
 };
 
 pub const ErrorMsg = struct {
@@ -2141,6 +2157,25 @@ const EmitZIR = struct {
                     };
                     break :blk &new_inst.base;
                 },
+
+                .unwrap_optional => blk: {
+                    const old_inst = inst.castTag(.unwrap_optional).?;
+
+                    const new_inst = try self.arena.allocator.create(Inst.UnwrapOptional);
+                    new_inst.* = .{
+                        .base = .{
+                            .src = inst.src,
+                            .tag = Inst.UnwrapOptional.base_tag,
+                        },
+                        .positionals = .{
+                            .operand = try self.resolveInst(new_body, old_inst.operand),
+                        },
+                        .kw_args = .{
+                            .safety_check = old_inst.safety_check,
+                        },
+                    };
+                    break :blk &new_inst.base;
+                },
             };
             try instructions.append(new_inst);
             try inst_table.put(inst, new_inst);
src-self-hosted/zir_sema.zig
@@ -107,6 +107,7 @@ pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!
         .boolnot => return analyzeInstBoolNot(mod, scope, old_inst.castTag(.boolnot).?),
         .typeof => return analyzeInstTypeOf(mod, scope, old_inst.castTag(.typeof).?),
         .optional_type => return analyzeInstOptionalType(mod, scope, old_inst.castTag(.optional_type).?),
+        .unwrap_optional => return analyzeInstUnwrapOptional(mod, scope, old_inst.castTag(.unwrap_optional).?),
     }
 }
 
@@ -306,8 +307,19 @@ fn analyzeInstRetPtr(mod: *Module, scope: *Scope, inst: *zir.Inst.NoOp) InnerErr
 
 fn analyzeInstRef(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst {
     const operand = try resolveInst(mod, scope, inst.positionals.operand);
-    const b = try mod.requireRuntimeBlock(scope, inst.base.src);
     const ptr_type = try mod.singleConstPtrType(scope, inst.base.src, operand.ty);
+
+    if (operand.value()) |val| {
+        const ref_payload = try scope.arena().create(Value.Payload.RefVal);
+        ref_payload.* = .{ .val = val };
+    
+        return mod.constInst(scope, inst.base.src, .{
+            .ty = ptr_type,
+            .val = Value.initPayload(&ref_payload.base),
+        });
+    }
+
+    const b = try mod.requireRuntimeBlock(scope, inst.base.src);
     return mod.addUnOp(b, inst.base.src, ptr_type, .ref, operand);
 }
 
@@ -649,6 +661,34 @@ fn analyzeInstOptionalType(mod: *Module, scope: *Scope, optional: *zir.Inst.UnOp
     }));
 }
 
+fn analyzeInstUnwrapOptional(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnwrapOptional) InnerError!*Inst {
+    const operand = try resolveInst(mod, scope, unwrap.positionals.operand);
+    assert(operand.ty.zigTypeTag() == .Pointer);
+
+    if (operand.ty.elemType().zigTypeTag() != .Optional) {
+        return mod.fail(scope, unwrap.base.src, "expected optional type, found {}", .{operand.ty.elemType()});
+    }
+
+    const child_type = operand.ty.elemType().elemType();
+    const child_pointer = if (operand.ty.isConstPtr())
+        try mod.singleConstPtrType(scope, unwrap.base.src, child_type)
+    else
+        try mod.singleMutPtrType(scope, unwrap.base.src, child_type);
+
+    if (operand.value()) |val| {
+        if (val.tag() == .null_value) {
+            return mod.fail(scope, unwrap.base.src, "unable to unwrap null", .{});
+        }
+        return mod.constInst(scope, unwrap.base.src, .{
+            .ty = child_pointer,
+            .val = val,
+        });
+    }
+
+    const b = try mod.requireRuntimeBlock(scope, unwrap.base.src);
+    return mod.addUnwrapOptional(b, unwrap.base.src, child_pointer, operand, unwrap.kw_args.safety_check);
+}
+
 fn analyzeInstFnType(mod: *Module, scope: *Scope, fntype: *zir.Inst.FnType) InnerError!*Inst {
     const return_type = try resolveType(mod, scope, fntype.positionals.return_type);
 
test/stage2/compare_output.zig
@@ -31,6 +31,11 @@ pub fn addCases(ctx: *TestContext) !void {
             \\export fn _start() noreturn {
             \\    print();
             \\
+            \\    const a: u32 = 2;
+            \\    const b: ?u32 = a;
+            \\    const c = b.?;
+            \\    if (c != 2) unreachable;
+            \\
             \\    exit();
             \\}
             \\