Commit 9a8851515b

Andrew Kelley <superjoe30@gmail.com>
2016-01-02 08:06:06
basic maybe type working
1 parent b3ac5c1
Changed files (6)
doc
vim
syntax
example
maybe_type
src
test
doc/vim/syntax/zig.vim
@@ -19,7 +19,7 @@ syn keyword zigType bool i8 u8 i16 u16 i32 u32 i64 u64 isize usize f32 f64 f128
 
 syn keyword zigBoolean true false
 
-syn match zigOperator display "\%(+\|-\|/\|*\|=\|\^\|&\||\|!\|>\|<\|%\)=\?"
+syn match zigOperator display "\%(+\|-\|/\|*\|=\|\^\|&\|?\||\|!\|>\|<\|%\)=\?"
 syn match zigOperator display "&&\|||"
 syn match zigArrowCharacter display "->"
 
example/maybe_type/main.zig
@@ -2,7 +2,7 @@ export executable "maybe_type";
 
 use "std.zig";
 
-fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
     const x : ?bool = true;
 
     if (const y ?= x) {
src/analyze.cpp
@@ -137,13 +137,44 @@ static TypeTableEntry *get_maybe_type(CodeGen *g, TypeTableEntry *child_type) {
         return child_type->maybe_parent;
     } else {
         TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdMaybe);
-        // TODO entry->type_ref
+        // create a struct with a boolean whether this is the null value
+        assert(child_type->type_ref);
+        LLVMTypeRef elem_types[] = {
+            child_type->type_ref,
+            LLVMInt1Type(),
+        };
+        entry->type_ref = LLVMStructType(elem_types, 2, false);
         buf_resize(&entry->name, 0);
         buf_appendf(&entry->name, "?%s", buf_ptr(&child_type->name));
-        // TODO entry->size_in_bits
-        // TODO entry->align_in_bits
+        entry->size_in_bits = child_type->size_in_bits + 8;
+        entry->align_in_bits = child_type->align_in_bits;
         assert(child_type->di_type);
-        // TODO entry->di_type
+
+
+        LLVMZigDIScope *compile_unit_scope = LLVMZigCompileUnitToScope(g->compile_unit);
+        LLVMZigDIFile *di_file = nullptr;
+        unsigned line = 0;
+        entry->di_type = LLVMZigCreateReplaceableCompositeType(g->dbuilder,
+            LLVMZigTag_DW_structure_type(), buf_ptr(&entry->name),
+            compile_unit_scope, di_file, line);
+
+        LLVMZigDIType *di_element_types[] = {
+            LLVMZigCreateDebugMemberType(g->dbuilder, LLVMZigTypeToScope(entry->di_type),
+                    "val", di_file, line, child_type->size_in_bits, child_type->align_in_bits, 0, 0,
+                    child_type->di_type),
+            LLVMZigCreateDebugMemberType(g->dbuilder, LLVMZigTypeToScope(entry->di_type),
+                    "maybe", di_file, line, 8, 8, 8, 0,
+                    child_type->di_type),
+        };
+        LLVMZigDIType *replacement_di_type = LLVMZigCreateDebugStructType(g->dbuilder,
+                compile_unit_scope,
+                buf_ptr(&entry->name),
+                di_file, line, entry->size_in_bits, entry->align_in_bits, 0,
+                nullptr, di_element_types, 2, 0, nullptr, "");
+
+        LLVMZigReplaceTemporary(g->dbuilder, entry->di_type, replacement_di_type);
+        entry->di_type = replacement_di_type;
+
         entry->data.maybe.child_type = child_type;
 
         g->type_table.put(&entry->name, entry);
@@ -814,13 +845,35 @@ static TypeTableEntry *resolve_type_compatibility(CodeGen *g, BlockContext *cont
         return expected_type;
     }
 
+    if (expected_type->id == TypeTableEntryIdMaybe &&
+        actual_type->id == TypeTableEntryIdMaybe)
+    {
+        TypeTableEntry *expected_child = expected_type->data.maybe.child_type;
+        TypeTableEntry *actual_child = actual_type->data.maybe.child_type;
+        return resolve_type_compatibility(g, context, node, expected_child, actual_child);
+    }
+
+    // implicit conversion from non maybe type to maybe type
+    if (expected_type->id == TypeTableEntryIdMaybe) {
+        TypeTableEntry *resolved_type = resolve_type_compatibility(g, context, node,
+                expected_type->data.maybe.child_type, actual_type);
+        if (resolved_type->id == TypeTableEntryIdInvalid) {
+            return resolved_type;
+        }
+        node->codegen_node->expr_node.implicit_maybe_cast.op = CastOpMaybeWrap;
+        node->codegen_node->expr_node.implicit_maybe_cast.after_type = expected_type;
+        node->codegen_node->expr_node.implicit_maybe_cast.source_node = node;
+        context->cast_expr_alloca_list.append(&node->codegen_node->expr_node.implicit_maybe_cast);
+        return expected_type;
+    }
+
     // implicit widening conversion
     if (expected_type->id == TypeTableEntryIdInt &&
         actual_type->id == TypeTableEntryIdInt &&
         expected_type->data.integral.is_signed == actual_type->data.integral.is_signed &&
         expected_type->size_in_bits > actual_type->size_in_bits)
     {
-        node->codegen_node->expr_node.implicit_cast.type = expected_type;
+        node->codegen_node->expr_node.implicit_cast.after_type = expected_type;
         node->codegen_node->expr_node.implicit_cast.op = CastOpIntWidenOrShorten;
         node->codegen_node->expr_node.implicit_cast.source_node = node;
         return expected_type;
@@ -831,7 +884,7 @@ static TypeTableEntry *resolve_type_compatibility(CodeGen *g, BlockContext *cont
         actual_type->id == TypeTableEntryIdArray &&
         actual_type->data.array.child_type == g->builtin_types.entry_u8)
     {
-        node->codegen_node->expr_node.implicit_cast.type = expected_type;
+        node->codegen_node->expr_node.implicit_cast.after_type = expected_type;
         node->codegen_node->expr_node.implicit_cast.op = CastOpArrayToString;
         node->codegen_node->expr_node.implicit_cast.source_node = node;
         context->cast_expr_alloca_list.append(&node->codegen_node->expr_node.implicit_cast);
@@ -1077,7 +1130,7 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
 
     CastNode *cast_node = &node->codegen_node->data.cast_node;
     cast_node->source_node = node;
-    cast_node->type = wanted_type;
+    cast_node->after_type = wanted_type;
 
     // special casing this for now, TODO think about casting and do a general solution
     if (wanted_type == g->builtin_types.entry_isize &&
@@ -1489,6 +1542,7 @@ static TypeTableEntry *analyze_if_var_expr(CodeGen *g, ImportTableEntry *import,
     assert(node->type == NodeTypeIfVarExpr);
 
     BlockContext *child_context = new_block_context(node, context);
+    node->codegen_node->data.if_var_node.block_context = child_context;
 
     analyze_variable_declaration_raw(g, import, child_context, node, &node->data.if_var_expr.var_decl, true);
 
src/analyze.hpp
@@ -272,10 +272,11 @@ struct FieldAccessNode {
 };
 
 enum CastOp {
+    CastOpNothing,
     CastOpPtrToInt,
     CastOpIntWidenOrShorten,
     CastOpArrayToString,
-    CastOpNothing,
+    CastOpMaybeWrap,
 };
 
 struct CastNode {
@@ -283,7 +284,7 @@ struct CastNode {
     // if op is CastOpArrayToString, this will be a pointer to
     // the string struct on the stack
     LLVMValueRef ptr;
-    TypeTableEntry *type;
+    TypeTableEntry *after_type;
     AstNode *source_node;
 };
 
@@ -294,7 +295,8 @@ struct ExprNode {
     BlockContext *block_context;
 
     // may be null for no cast
-    CastNode implicit_cast;
+    CastNode implicit_cast; // happens first
+    CastNode implicit_maybe_cast; // happens second
 };
 
 struct NumberLiteralNode {
@@ -315,6 +317,10 @@ struct StructValExprNode {
     AstNode *source_node;
 };
 
+struct IfVarNode {
+    BlockContext *block_context;
+};
+
 struct CodeGenNode {
     union {
         TypeNode type_node; // for NodeTypeType
@@ -330,6 +336,7 @@ struct CodeGenNode {
         VarDeclNode var_decl_node; // for NodeTypeVariableDeclaration
         StructValFieldNode struct_val_field_node; // for NodeTypeStructValueField
         StructValExprNode struct_val_expr_node; // for NodeTypeStructValueExpr
+        IfVarNode if_var_node; // for NodeTypeStructValueExpr
     } data;
     ExprNode expr_node; // for all the expression nodes
 };
src/codegen.cpp
@@ -65,6 +65,11 @@ void codegen_set_libc_path(CodeGen *g, Buf *libc_path) {
 static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node);
 static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node, TypeTableEntry **out_type_entry);
 static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lvalue);
+static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVariableDeclaration *var_decl,
+        BlockContext *block_context, bool unwrap_maybe, LLVMValueRef *init_val);
+static LLVMValueRef gen_assign_raw(CodeGen *g, AstNode *source_node, BinOpType bin_op,
+        LLVMValueRef target_ref, LLVMValueRef value,
+        TypeTableEntry *op1_type, TypeTableEntry *op2_type);
     
 
 static TypeTableEntry *get_type_for_type_node(CodeGen *g, AstNode *type_node) {
@@ -132,7 +137,7 @@ static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str, bool c) {
 }
 
 static TypeTableEntry *get_expr_type(AstNode *node) {
-    TypeTableEntry *cast_type = node->codegen_node->expr_node.implicit_cast.type;
+    TypeTableEntry *cast_type = node->codegen_node->expr_node.implicit_cast.after_type;
     return cast_type ? cast_type : node->codegen_node->expr_node.type_entry;
 }
 
@@ -367,6 +372,22 @@ static LLVMValueRef gen_bare_cast(CodeGen *g, AstNode *node, LLVMValueRef expr_v
     switch (cast_node->op) {
         case CastOpNothing:
             return expr_val;
+        case CastOpMaybeWrap:
+            {
+                assert(cast_node->ptr);
+                assert(wanted_type->id == TypeTableEntryIdMaybe);
+
+                add_debug_source_node(g, node);
+                LLVMValueRef val_ptr = LLVMBuildStructGEP(g->builder, cast_node->ptr, 0, "");
+                gen_assign_raw(g, node, BinOpTypeAssign,
+                        val_ptr, expr_val, wanted_type->data.maybe.child_type, actual_type);
+
+                add_debug_source_node(g, node);
+                LLVMValueRef maybe_ptr = LLVMBuildStructGEP(g->builder, cast_node->ptr, 1, "");
+                LLVMBuildStore(g->builder, LLVMConstAllOnes(LLVMInt1Type()), maybe_ptr);
+
+                return cast_node->ptr;
+            }
         case CastOpPtrToInt:
             return LLVMBuildPtrToInt(g->builder, expr_val, wanted_type->type_ref, "");
         case CastOpIntWidenOrShorten:
@@ -423,34 +444,33 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
 
 }
 
-static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
+static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node,
     LLVMValueRef val1, LLVMValueRef val2,
     TypeTableEntry *op1_type, TypeTableEntry *op2_type,
-    AstNode *node)
+    BinOpType bin_op)
 {
-    assert(node->type == NodeTypeBinOpExpr);
     assert(op1_type == op2_type);
 
-    switch (node->data.bin_op_expr.bin_op) {
+    switch (bin_op) {
         case BinOpTypeBinOr:
         case BinOpTypeAssignBitOr:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             return LLVMBuildOr(g->builder, val1, val2, "");
         case BinOpTypeBinXor:
         case BinOpTypeAssignBitXor:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             return LLVMBuildXor(g->builder, val1, val2, "");
         case BinOpTypeBinAnd:
         case BinOpTypeAssignBitAnd:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             return LLVMBuildAnd(g->builder, val1, val2, "");
         case BinOpTypeBitShiftLeft:
         case BinOpTypeAssignBitShiftLeft:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             return LLVMBuildShl(g->builder, val1, val2, "");
         case BinOpTypeBitShiftRight:
         case BinOpTypeAssignBitShiftRight:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             if (op1_type->id == TypeTableEntryIdInt) {
                 return LLVMBuildAShr(g->builder, val1, val2, "");
             } else {
@@ -458,7 +478,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
             }
         case BinOpTypeAdd:
         case BinOpTypeAssignPlus:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFAdd(g->builder, val1, val2, "");
             } else {
@@ -466,7 +486,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
             }
         case BinOpTypeSub:
         case BinOpTypeAssignMinus:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFSub(g->builder, val1, val2, "");
             } else {
@@ -474,7 +494,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
             }
         case BinOpTypeMult:
         case BinOpTypeAssignTimes:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFMul(g->builder, val1, val2, "");
             } else {
@@ -482,7 +502,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
             }
         case BinOpTypeDiv:
         case BinOpTypeAssignDiv:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFDiv(g->builder, val1, val2, "");
             } else {
@@ -495,7 +515,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g,
             }
         case BinOpTypeMod:
         case BinOpTypeAssignMod:
-            add_debug_source_node(g, node);
+            add_debug_source_node(g, source_node);
             if (op1_type->id == TypeTableEntryIdFloat) {
                 return LLVMBuildFRem(g->builder, val1, val2, "");
             } else {
@@ -530,7 +550,7 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
 
     TypeTableEntry *op1_type = get_expr_type(node->data.bin_op_expr.op1);
     TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
-    return gen_arithmetic_bin_op(g, val1, val2, op1_type, op2_type, node);
+    return gen_arithmetic_bin_op(g, node, val1, val2, op1_type, op2_type, node->data.bin_op_expr.bin_op);
 
 }
 
@@ -660,7 +680,7 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
 static LLVMValueRef gen_struct_memcpy(CodeGen *g, AstNode *source_node, LLVMValueRef src, LLVMValueRef dest,
         TypeTableEntry *type_entry)
 {
-    assert(type_entry->id == TypeTableEntryIdStruct);
+    assert(type_entry->id == TypeTableEntryIdStruct || type_entry->id == TypeTableEntryIdMaybe);
 
     LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
 
@@ -679,6 +699,30 @@ static LLVMValueRef gen_struct_memcpy(CodeGen *g, AstNode *source_node, LLVMValu
     return LLVMBuildCall(g->builder, g->memcpy_fn_val, params, 5, "");
 }
 
+static LLVMValueRef gen_assign_raw(CodeGen *g, AstNode *source_node, BinOpType bin_op,
+        LLVMValueRef target_ref, LLVMValueRef value,
+        TypeTableEntry *op1_type, TypeTableEntry *op2_type)
+{
+    if (op1_type->id == TypeTableEntryIdStruct) {
+        assert(op2_type->id == TypeTableEntryIdStruct);
+        assert(op1_type == op2_type);
+        assert(bin_op == BinOpTypeAssign);
+
+        return gen_struct_memcpy(g, source_node, value, target_ref, op1_type);
+    }
+
+    if (bin_op != BinOpTypeAssign) {
+        assert(source_node->type == NodeTypeBinOpExpr);
+        add_debug_source_node(g, source_node->data.bin_op_expr.op1);
+        LLVMValueRef left_value = LLVMBuildLoad(g->builder, target_ref, "");
+
+        value = gen_arithmetic_bin_op(g, source_node, left_value, value, op1_type, op2_type, bin_op);
+    }
+
+    add_debug_source_node(g, source_node);
+    return LLVMBuildStore(g->builder, value, target_ref);
+}
+
 static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeBinOpExpr);
 
@@ -692,23 +736,7 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
 
     LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
 
-    if (op1_type->id == TypeTableEntryIdStruct) {
-        assert(op2_type->id == TypeTableEntryIdStruct);
-        assert(op1_type == op2_type);
-        assert(node->data.bin_op_expr.bin_op == BinOpTypeAssign);
-
-        return gen_struct_memcpy(g, node, value, target_ref, op1_type);
-    }
-
-    if (node->data.bin_op_expr.bin_op != BinOpTypeAssign) {
-        add_debug_source_node(g, node->data.bin_op_expr.op1);
-        LLVMValueRef left_value = LLVMBuildLoad(g->builder, target_ref, "");
-
-        value = gen_arithmetic_bin_op(g, left_value, value, op1_type, op2_type, node);
-    }
-
-    add_debug_source_node(g, node);
-    return LLVMBuildStore(g->builder, value, target_ref);
+    return gen_assign_raw(g, node, node->data.bin_op_expr.bin_op, target_ref, value, op1_type, op2_type);
 }
 
 static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {
@@ -769,18 +797,14 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
     }
 }
 
-static LLVMValueRef gen_if_bool_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeIfBoolExpr);
-    assert(node->data.if_bool_expr.condition);
-    assert(node->data.if_bool_expr.then_block);
-
-    LLVMValueRef cond_value = gen_expr(g, node->data.if_bool_expr.condition);
-
-    TypeTableEntry *then_type = get_expr_type(node->data.if_bool_expr.then_block);
+static LLVMValueRef gen_if_bool_expr_raw(CodeGen *g, AstNode *source_node, LLVMValueRef cond_value,
+        AstNode *then_node, AstNode *else_node)
+{
+    TypeTableEntry *then_type = get_expr_type(then_node);
     bool use_expr_value = (then_type->id != TypeTableEntryIdUnreachable &&
                            then_type->id != TypeTableEntryIdVoid);
 
-    if (node->data.if_bool_expr.else_node) {
+    if (else_node) {
         LLVMBasicBlockRef then_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "Then");
         LLVMBasicBlockRef else_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "Else");
         LLVMBasicBlockRef endif_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "EndIf");
@@ -788,13 +812,13 @@ static LLVMValueRef gen_if_bool_expr(CodeGen *g, AstNode *node) {
         LLVMBuildCondBr(g->builder, cond_value, then_block, else_block);
 
         LLVMPositionBuilderAtEnd(g->builder, then_block);
-        LLVMValueRef then_expr_result = gen_expr(g, node->data.if_bool_expr.then_block);
-        if (get_expr_type(node->data.if_bool_expr.then_block)->id != TypeTableEntryIdUnreachable)
+        LLVMValueRef then_expr_result = gen_expr(g, then_node);
+        if (get_expr_type(then_node)->id != TypeTableEntryIdUnreachable)
             LLVMBuildBr(g->builder, endif_block);
 
         LLVMPositionBuilderAtEnd(g->builder, else_block);
-        LLVMValueRef else_expr_result = gen_expr(g, node->data.if_bool_expr.else_node);
-        if (get_expr_type(node->data.if_bool_expr.else_node)->id != TypeTableEntryIdUnreachable)
+        LLVMValueRef else_expr_result = gen_expr(g, else_node);
+        if (get_expr_type(else_node)->id != TypeTableEntryIdUnreachable)
             LLVMBuildBr(g->builder, endif_block);
 
         LLVMPositionBuilderAtEnd(g->builder, endif_block);
@@ -818,17 +842,49 @@ static LLVMValueRef gen_if_bool_expr(CodeGen *g, AstNode *node) {
     LLVMBuildCondBr(g->builder, cond_value, then_block, endif_block);
 
     LLVMPositionBuilderAtEnd(g->builder, then_block);
-    gen_expr(g, node->data.if_bool_expr.then_block);
-    if (get_expr_type(node->data.if_bool_expr.then_block)->id != TypeTableEntryIdUnreachable)
+    gen_expr(g, then_node);
+    if (get_expr_type(then_node)->id != TypeTableEntryIdUnreachable)
         LLVMBuildBr(g->builder, endif_block);
 
     LLVMPositionBuilderAtEnd(g->builder, endif_block);
     return nullptr;
 }
 
+static LLVMValueRef gen_if_bool_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeIfBoolExpr);
+    assert(node->data.if_bool_expr.condition);
+    assert(node->data.if_bool_expr.then_block);
+
+    LLVMValueRef cond_value = gen_expr(g, node->data.if_bool_expr.condition);
+
+    return gen_if_bool_expr_raw(g, node, cond_value,
+            node->data.if_bool_expr.then_block,
+            node->data.if_bool_expr.else_node);
+}
+
 static LLVMValueRef gen_if_var_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeIfVarExpr);
-    zig_panic("TODO gen_if_var_expr");
+    assert(node->data.if_var_expr.var_decl.expr);
+
+    BlockContext *old_block_context = g->cur_block_context;
+    BlockContext *new_block_context = node->codegen_node->data.if_var_node.block_context;
+
+    LLVMValueRef init_val;
+    gen_var_decl_raw(g, node, &node->data.if_var_expr.var_decl, new_block_context, true, &init_val);
+
+    // test if value is the maybe state
+    add_debug_source_node(g, node);
+    LLVMValueRef maybe_field_ptr = LLVMBuildStructGEP(g->builder, init_val, 1, "");
+    LLVMValueRef cond_value = LLVMBuildLoad(g->builder, maybe_field_ptr, "");
+
+    g->cur_block_context = new_block_context;
+
+    LLVMValueRef return_value = gen_if_bool_expr_raw(g, node, cond_value,
+            node->data.if_var_expr.then_block,
+            node->data.if_var_expr.else_node);
+
+    g->cur_block_context = old_block_context;
+    return return_value;
 }
 
 static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *implicit_return_type) {
@@ -1058,6 +1114,55 @@ static LLVMValueRef gen_continue(CodeGen *g, AstNode *node) {
     return LLVMBuildBr(g->builder, dest_block);
 }
 
+static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVariableDeclaration *var_decl,
+        BlockContext *block_context, bool unwrap_maybe, LLVMValueRef *init_value)
+{
+    VariableTableEntry *variable = find_variable(block_context, &var_decl->symbol);
+
+    assert(variable);
+    assert(variable->is_ptr);
+
+    if (var_decl->expr) {
+        *init_value = gen_expr(g, var_decl->expr);
+    } else {
+        *init_value = LLVMConstNull(variable->type->type_ref);
+    }
+    if (variable->type->id == TypeTableEntryIdVoid) {
+        return nullptr;
+    } else {
+        LLVMValueRef store_instr;
+        LLVMValueRef value;
+        if (unwrap_maybe) {
+            assert(var_decl->expr);
+            add_debug_source_node(g, source_node);
+            LLVMValueRef maybe_field_ptr = LLVMBuildStructGEP(g->builder, *init_value, 0, "");
+            // TODO if it's a struct we might not want to load the pointer
+            value = LLVMBuildLoad(g->builder, maybe_field_ptr, "");
+        } else {
+            value = *init_value;
+        }
+        if ((variable->type->id == TypeTableEntryIdStruct || variable->type->id == TypeTableEntryIdMaybe) &&
+            var_decl->expr)
+        {
+            store_instr = gen_struct_memcpy(g, source_node, value, variable->value_ref, variable->type);
+        } else {
+            add_debug_source_node(g, source_node);
+            store_instr = LLVMBuildStore(g->builder, value, variable->value_ref);
+        }
+
+        LLVMZigDILocation *debug_loc = LLVMZigGetDebugLoc(source_node->line + 1, source_node->column + 1,
+                g->cur_block_context->di_scope);
+        LLVMZigInsertDeclare(g->dbuilder, variable->value_ref, variable->di_loc_var, debug_loc, store_instr);
+        return nullptr;
+    }
+}
+
+static LLVMValueRef gen_var_decl_expr(CodeGen *g, AstNode *node) {
+    LLVMValueRef init_val;
+    return gen_var_decl_raw(g, node, &node->data.variable_declaration,
+            node->codegen_node->expr_node.block_context, false, &init_val);
+}
+
 static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeBinOpExpr:
@@ -1065,38 +1170,7 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
         case NodeTypeReturnExpr:
             return gen_return_expr(g, node);
         case NodeTypeVariableDeclaration:
-            {
-                VariableTableEntry *variable = find_variable(
-                        node->codegen_node->expr_node.block_context,
-                        &node->data.variable_declaration.symbol);
-
-                assert(variable);
-                assert(variable->is_ptr);
-
-                LLVMValueRef value;
-                if (node->data.variable_declaration.expr) {
-                    value = gen_expr(g, node->data.variable_declaration.expr);
-                } else {
-                    value = LLVMConstNull(variable->type->type_ref);
-                }
-                if (variable->type->id == TypeTableEntryIdVoid) {
-                    return nullptr;
-                } else {
-                    LLVMValueRef store_instr;
-                    if (variable->type->id == TypeTableEntryIdStruct && node->data.variable_declaration.expr) {
-                        store_instr = gen_struct_memcpy(g, node, value, variable->value_ref, variable->type);
-                    } else {
-                        add_debug_source_node(g, node);
-                        store_instr = LLVMBuildStore(g->builder, value, variable->value_ref);
-                    }
-
-                    LLVMZigDILocation *debug_loc = LLVMZigGetDebugLoc(node->line + 1, node->column + 1,
-                            g->cur_block_context->di_scope);
-                    LLVMZigInsertDeclare(g->dbuilder, variable->value_ref, variable->di_loc_var,
-                            debug_loc, store_instr);
-                    return nullptr;
-                }
-            }
+            return gen_var_decl_expr(g, node);
         case NodeTypeCastExpr:
             return gen_cast_expr(g, node);
         case NodeTypePrefixOpExpr:
@@ -1174,7 +1248,9 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
                     assert(variable->value_ref);
                     if (variable->type->id == TypeTableEntryIdArray) {
                         return variable->value_ref;
-                    } else if (variable->type->id == TypeTableEntryIdStruct) {
+                    } else if (variable->type->id == TypeTableEntryIdStruct ||
+                               variable->type->id == TypeTableEntryIdMaybe)
+                    {
                         return variable->value_ref;
                     } else {
                         add_debug_source_node(g, node);
@@ -1225,6 +1301,12 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
     zig_unreachable();
 }
 
+static LLVMValueRef gen_cast_node(CodeGen *g, AstNode *node, LLVMValueRef val, TypeTableEntry *before_type,
+        CastNode *cast_node)
+{
+    return cast_node->after_type ? gen_bare_cast(g, node, val, before_type, cast_node->after_type, cast_node) : val;
+}
+
 static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
     LLVMValueRef val = gen_expr_no_cast(g, node);
 
@@ -1234,11 +1316,17 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
 
     assert(node->codegen_node);
 
-    TypeTableEntry *actual_type = node->codegen_node->expr_node.type_entry;
-    TypeTableEntry *cast_type = node->codegen_node->expr_node.implicit_cast.type;
+    {
+        TypeTableEntry *before_type = node->codegen_node->expr_node.type_entry;
+        val = gen_cast_node(g, node, val, before_type, &node->codegen_node->expr_node.implicit_cast);
+    }
+
+    {
+        TypeTableEntry *before_type = node->codegen_node->expr_node.implicit_cast.after_type;
+        val = gen_cast_node(g, node, val, before_type, &node->codegen_node->expr_node.implicit_maybe_cast);
+    }
 
-    return cast_type ? gen_bare_cast(g, node, val, actual_type, cast_type,
-            &node->codegen_node->expr_node.implicit_cast) : val;
+    return val;
 }
 
 static void build_label_blocks(CodeGen *g, AstNode *block_node) {
@@ -1460,7 +1548,7 @@ static void do_code_gen(CodeGen *g) {
             for (int cea_i = 0; cea_i < block_context->cast_expr_alloca_list.length; cea_i += 1) {
                 CastNode *cast_node = block_context->cast_expr_alloca_list.at(cea_i);
                 add_debug_source_node(g, cast_node->source_node);
-                cast_node->ptr = LLVMBuildAlloca(g->builder, cast_node->type->type_ref, "");
+                cast_node->ptr = LLVMBuildAlloca(g->builder, cast_node->after_type->type_ref, "");
             }
 
             // allocate structs which are struct value expressions
test/run_tests.cpp
@@ -674,6 +674,24 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     return 0;
 }
     )SOURCE", "loop\nloop\nloop\nloop\n");
+
+    add_simple_case("maybe type", R"SOURCE(
+use "std.zig";
+export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+    const x : ?bool = true;
+
+    if (const y ?= x) {
+        if (y) {
+            print_str("x is true\n");
+        } else {
+            print_str("x is false\n");
+        }
+    } else {
+        print_str("x is none\n");
+    }
+    return 0;
+}
+    )SOURCE", "x is true\n");
 }
 
 ////////////////////////////////////////////////////////////////////////////////////