Commit 87b7c28c9a

Andrew Kelley <superjoe30@gmail.com>
2016-09-27 04:33:33
cstr.len and cstr.cmp can run at compile time
closes #140
1 parent 7ce7e2c
Changed files (3)
src/all_types.hpp
@@ -1111,33 +1111,6 @@ struct FnTableEntry {
     ZigList<AstNode *> goto_list;
 };
 
-struct EvalVar {
-    Buf *name;
-    ConstExprValue value;
-};
-
-struct EvalScope {
-    BlockContext *block_context;
-    ZigList<EvalVar> vars;
-};
-
-struct EvalFnRoot {
-    CodeGen *codegen;
-    FnTableEntry *fn;
-    AstNode *call_node;
-    size_t branch_quota;
-    size_t branches_used;
-    AstNode *exceeded_quota_node;
-    bool abort;
-};
-
-struct EvalFn {
-    EvalFnRoot *root;
-    FnTableEntry *fn;
-    ConstExprValue *return_expr;
-    ZigList<EvalScope*> scope_stack;
-};
-
 enum BuiltinFnId {
     BuiltinFnIdInvalid,
     BuiltinFnIdMemcpy,
src/eval.cpp
@@ -2,6 +2,34 @@
 #include "analyze.hpp"
 #include "error.hpp"
 
+struct EvalVar {
+    Buf *name;
+    ConstExprValue value;
+};
+
+struct EvalScope {
+    BlockContext *block_context;
+    ZigList<EvalVar> vars;
+};
+
+struct EvalFnRoot {
+    CodeGen *codegen;
+    FnTableEntry *fn;
+    AstNode *call_node;
+    size_t branch_quota;
+    size_t branches_used;
+    AstNode *exceeded_quota_node;
+    bool abort;
+};
+
+struct EvalFn {
+    EvalFnRoot *root;
+    FnTableEntry *fn;
+    ConstExprValue *return_expr;
+    ZigList<EvalScope*> scope_stack;
+};
+
+
 static bool eval_fn_args(EvalFnRoot *efr, FnTableEntry *fn, ConstExprValue *args, ConstExprValue *out_val);
 
 bool const_values_equal(ConstExprValue *a, ConstExprValue *b, TypeTableEntry *type_entry) {
@@ -94,9 +122,9 @@ static bool eval_return(EvalFn *ef, AstNode *node, ConstExprValue *out) {
 }
 
 static bool eval_bool_bin_op_bool(bool a, BinOpType bin_op, bool b) {
-    if (bin_op == BinOpTypeBoolOr) {
+    if (bin_op == BinOpTypeBoolOr || bin_op == BinOpTypeAssignBoolOr) {
         return a || b;
-    } else if (bin_op == BinOpTypeBoolAnd) {
+    } else if (bin_op == BinOpTypeBoolAnd || bin_op == BinOpTypeAssignBoolAnd) {
         return a && b;
     } else {
         zig_unreachable();
@@ -180,6 +208,31 @@ static int eval_const_expr_bin_op_bignum(ConstExprValue *op1_val, ConstExprValue
     return 0;
 }
 
+bool eval_const_expr_bin_op_handle_errors(EvalFn *ef, AstNode *node,
+        ConstExprValue *op1_val, TypeTableEntry *op1_type,
+        BinOpType bin_op, ConstExprValue *op2_val, TypeTableEntry *op2_type, ConstExprValue *out_val)
+{
+    int err;
+    if ((err = eval_const_expr_bin_op(op1_val, op1_type, bin_op, op2_val, op2_type, out_val))) {
+        ef->root->abort = true;
+        if (err == ErrorDivByZero) {
+            ErrorMsg *msg = add_node_error(ef->root->codegen, ef->root->fn->fn_def_node,
+                    buf_sprintf("function evaluation caused division by zero"));
+            add_error_note(ef->root->codegen, msg, ef->root->call_node, buf_sprintf("called from here"));
+            add_error_note(ef->root->codegen, msg, node, buf_sprintf("division by zero here"));
+        } else if (err == ErrorOverflow) {
+            ErrorMsg *msg = add_node_error(ef->root->codegen, ef->root->fn->fn_def_node,
+                    buf_sprintf("function evaluation caused overflow"));
+            add_error_note(ef->root->codegen, msg, ef->root->call_node, buf_sprintf("called from here"));
+            add_error_note(ef->root->codegen, msg, node, buf_sprintf("overflow occurred here"));
+        } else {
+            zig_unreachable();
+        }
+        return true;
+    }
+    return false;
+}
+
 int eval_const_expr_bin_op(ConstExprValue *op1_val, TypeTableEntry *op1_type,
         BinOpType bin_op, ConstExprValue *op2_val, TypeTableEntry *op2_type, ConstExprValue *out_val)
 {
@@ -190,25 +243,12 @@ int eval_const_expr_bin_op(ConstExprValue *op1_val, TypeTableEntry *op1_type,
 
     switch (bin_op) {
         case BinOpTypeAssign:
-        case BinOpTypeAssignTimes:
-        case BinOpTypeAssignTimesWrap:
-        case BinOpTypeAssignDiv:
-        case BinOpTypeAssignMod:
-        case BinOpTypeAssignPlus:
-        case BinOpTypeAssignPlusWrap:
-        case BinOpTypeAssignMinus:
-        case BinOpTypeAssignMinusWrap:
-        case BinOpTypeAssignBitShiftLeft:
-        case BinOpTypeAssignBitShiftLeftWrap:
-        case BinOpTypeAssignBitShiftRight:
-        case BinOpTypeAssignBitAnd:
-        case BinOpTypeAssignBitXor:
-        case BinOpTypeAssignBitOr:
-        case BinOpTypeAssignBoolAnd:
-        case BinOpTypeAssignBoolOr:
-            zig_unreachable();
+            *out_val = *op2_val;
+            return 0;
         case BinOpTypeBoolOr:
         case BinOpTypeBoolAnd:
+        case BinOpTypeAssignBoolAnd:
+        case BinOpTypeAssignBoolOr:
             assert(op1_type->id == TypeTableEntryIdBool);
             assert(op2_type->id == TypeTableEntryIdBool);
             out_val->data.x_bool = eval_bool_bin_op_bool(op1_val->data.x_bool, bin_op, op2_val->data.x_bool);
@@ -264,30 +304,43 @@ int eval_const_expr_bin_op(ConstExprValue *op1_val, TypeTableEntry *op1_type,
                 return 0;
             }
         case BinOpTypeAdd:
+        case BinOpTypeAssignPlus:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_add, op1_type, false);
         case BinOpTypeAddWrap:
+        case BinOpTypeAssignPlusWrap:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_add, op1_type, true);
         case BinOpTypeBinOr:
+        case BinOpTypeAssignBitOr:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_or, op1_type, false);
         case BinOpTypeBinXor:
+        case BinOpTypeAssignBitXor:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_xor, op1_type, false);
         case BinOpTypeBinAnd:
+        case BinOpTypeAssignBitAnd:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_and, op1_type, false);
         case BinOpTypeBitShiftLeft:
+        case BinOpTypeAssignBitShiftLeft:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_shl, op1_type, false);
         case BinOpTypeBitShiftLeftWrap:
+        case BinOpTypeAssignBitShiftLeftWrap:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_shl, op1_type, true);
         case BinOpTypeBitShiftRight:
+        case BinOpTypeAssignBitShiftRight:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_shr, op1_type, false);
         case BinOpTypeSub:
+        case BinOpTypeAssignMinus:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_sub, op1_type, false);
         case BinOpTypeSubWrap:
+        case BinOpTypeAssignMinusWrap:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_sub, op1_type, true);
         case BinOpTypeMult:
+        case BinOpTypeAssignTimes:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_mul, op1_type, false);
         case BinOpTypeMultWrap:
+        case BinOpTypeAssignTimesWrap:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_mul, op1_type, true);
         case BinOpTypeDiv:
+        case BinOpTypeAssignDiv:
             {
                 bool is_int = false;
                 bool is_float = false;
@@ -309,6 +362,7 @@ int eval_const_expr_bin_op(ConstExprValue *op1_val, TypeTableEntry *op1_type,
                 }
             }
         case BinOpTypeMod:
+        case BinOpTypeAssignMod:
             return eval_const_expr_bin_op_bignum(op1_val, op2_val, out_val, bignum_mod, op1_type, false);
         case BinOpTypeUnwrapMaybe:
             zig_panic("TODO");
@@ -320,13 +374,66 @@ int eval_const_expr_bin_op(ConstExprValue *op1_val, TypeTableEntry *op1_type,
     zig_unreachable();
 }
 
-static bool eval_bin_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val) {
-    assert(node->type == NodeTypeBinOpExpr);
+static EvalVar *find_var(EvalFn *ef, Buf *name) {
+    size_t scope_index = ef->scope_stack.length - 1;
+    while (scope_index != SIZE_MAX) {
+        EvalScope *scope = ef->scope_stack.at(scope_index);
+        for (size_t var_i = 0; var_i < scope->vars.length; var_i += 1) {
+            EvalVar *var = &scope->vars.at(var_i);
+            if (buf_eql_buf(var->name, name)) {
+                return var;
+            }
+        }
+        scope_index -= 1;
+    }
 
+    return nullptr;
+}
+
+static bool eval_get_lvalue(EvalFn *ef, AstNode *node, ConstExprValue **lvalue) {
+    if (node->type == NodeTypeSymbol) {
+        Buf *name = node->data.symbol_expr.symbol;
+        EvalVar *var = find_var(ef, name);
+        assert(var);
+        *lvalue = &var->value;
+    } else {
+        zig_panic("TODO eval other lvalue types");
+    }
+    return false;
+}
+
+static bool eval_bin_op_assign(EvalFn *ef, AstNode *node, ConstExprValue *out_val) {
     AstNode *op1 = node->data.bin_op_expr.op1;
     AstNode *op2 = node->data.bin_op_expr.op2;
     BinOpType bin_op = node->data.bin_op_expr.bin_op;
 
+    TypeTableEntry *op2_type = get_resolved_expr(op2)->type_entry;
+    assert(op2_type);
+
+    ConstExprValue *assign_result_val;
+    if (eval_get_lvalue(ef, op1, &assign_result_val)) return true;
+
+    ConstExprValue op1_val = *assign_result_val;
+
+    ConstExprValue op2_val = {0};
+    if (eval_expr(ef, op2, &op2_val)) return true;
+
+    if (eval_const_expr_bin_op_handle_errors(ef, node, &op1_val, op2_type, bin_op, &op2_val, op2_type,
+        assign_result_val))
+    {
+        return true;
+    }
+
+    out_val->ok = true;
+    out_val->depends_on_compile_var = false;
+    return false;
+}
+
+static bool eval_bin_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val) {
+    assert(node->type == NodeTypeBinOpExpr);
+
+    BinOpType bin_op = node->data.bin_op_expr.bin_op;
+
     switch (bin_op) {
         case BinOpTypeAssign:
         case BinOpTypeAssignTimes:
@@ -345,7 +452,7 @@ static bool eval_bin_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val)
         case BinOpTypeAssignBitOr:
         case BinOpTypeAssignBoolAnd:
         case BinOpTypeAssignBoolOr:
-            zig_panic("TODO");
+            return eval_bin_op_assign(ef, node, out_val);
         case BinOpTypeBoolOr:
         case BinOpTypeBoolAnd:
         case BinOpTypeCmpEq:
@@ -376,6 +483,10 @@ static bool eval_bin_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val)
             zig_unreachable();
     }
 
+    AstNode *op1 = node->data.bin_op_expr.op1;
+    AstNode *op2 = node->data.bin_op_expr.op2;
+
+
     TypeTableEntry *op1_type = get_resolved_expr(op1)->type_entry;
     TypeTableEntry *op2_type = get_resolved_expr(op2)->type_entry;
 
@@ -388,22 +499,7 @@ static bool eval_bin_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val)
     ConstExprValue op2_val = {0};
     if (eval_expr(ef, op2, &op2_val)) return true;
 
-    int err;
-    if ((err = eval_const_expr_bin_op(&op1_val, op1_type, bin_op, &op2_val, op2_type, out_val))) {
-        ef->root->abort = true;
-        if (err == ErrorDivByZero) {
-            ErrorMsg *msg = add_node_error(ef->root->codegen, ef->root->fn->fn_def_node,
-                    buf_sprintf("function evaluation caused division by zero"));
-            add_error_note(ef->root->codegen, msg, ef->root->call_node, buf_sprintf("called from here"));
-            add_error_note(ef->root->codegen, msg, node, buf_sprintf("division by zero here"));
-        } else if (err == ErrorOverflow) {
-            ErrorMsg *msg = add_node_error(ef->root->codegen, ef->root->fn->fn_def_node,
-                    buf_sprintf("function evaluation caused overflow"));
-            add_error_note(ef->root->codegen, msg, ef->root->call_node, buf_sprintf("called from here"));
-            add_error_note(ef->root->codegen, msg, node, buf_sprintf("overflow occurred here"));
-        } else {
-            zig_unreachable();
-        }
+    if (eval_const_expr_bin_op_handle_errors(ef, node, &op1_val, op1_type, bin_op, &op2_val, op2_type, out_val)) {
         return true;
     }
 
@@ -412,22 +508,6 @@ static bool eval_bin_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val)
     return false;
 }
 
-static EvalVar *find_var(EvalFn *ef, Buf *name) {
-    size_t scope_index = ef->scope_stack.length - 1;
-    while (scope_index != SIZE_MAX) {
-        EvalScope *scope = ef->scope_stack.at(scope_index);
-        for (size_t var_i = 0; var_i < scope->vars.length; var_i += 1) {
-            EvalVar *var = &scope->vars.at(var_i);
-            if (buf_eql_buf(var->name, name)) {
-                return var;
-            }
-        }
-        scope_index -= 1;
-    }
-
-    return nullptr;
-}
-
 static bool eval_symbol_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val) {
     assert(node->type == NodeTypeSymbol);
 
@@ -456,7 +536,7 @@ static bool eval_container_init_expr(EvalFn *ef, AstNode *node, ConstExprValue *
     ContainerInitKind kind = container_init_expr->kind;
 
     if (container_init_expr->enum_type) {
-        zig_panic("TODO");
+        zig_panic("TODO eval enum init");
     }
 
     TypeTableEntry *container_type = resolve_expr_type(container_init_expr->type);
@@ -514,7 +594,7 @@ static bool eval_container_init_expr(EvalFn *ef, AstNode *node, ConstExprValue *
                 elem_val->depends_on_compile_var;
         }
     } else {
-        zig_panic("TODO");
+        zig_panic("TODO init more container kinds");
     }
 
 
@@ -874,7 +954,7 @@ static bool eval_fn_call_builtin(EvalFn *ef, AstNode *node, ConstExprValue *out_
         case BuiltinFnIdEmbedFile:
         case BuiltinFnIdCmpExchange:
         case BuiltinFnIdTruncate:
-            zig_panic("TODO");
+            zig_panic("TODO builtin function");
         case BuiltinFnIdBreakpoint:
         case BuiltinFnIdInvalid:
         case BuiltinFnIdFrameAddress:
@@ -909,7 +989,7 @@ static bool eval_fn_call_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val
     if (fn_ref_expr->type == NodeTypeFieldAccessExpr &&
         fn_ref_expr->data.field_access_expr.is_member_fn)
     {
-        zig_panic("TODO");
+        zig_panic("TODO field access member fn");
     }
 
     if (!fn_table_entry) {
@@ -941,7 +1021,7 @@ static bool eval_field_access_expr(EvalFn *ef, AstNode *node, ConstExprValue *ou
     if (struct_type->id == TypeTableEntryIdArray) {
         Buf *name = node->data.field_access_expr.field_name;
         assert(buf_eql_str(name, "len"));
-        zig_panic("TODO");
+        zig_panic("TODO field access array");
     } else if (struct_type->id == TypeTableEntryIdStruct || (struct_type->id == TypeTableEntryIdPointer &&
                struct_type->data.pointer.child_type->id == TypeTableEntryIdStruct))
     {
@@ -954,17 +1034,17 @@ static bool eval_field_access_expr(EvalFn *ef, AstNode *node, ConstExprValue *ou
             *out_val = *field_value;
             assert(out_val->ok);
         } else {
-            zig_panic("TODO");
+            zig_panic("TODO field access struct");
         }
     } else if (struct_type->id == TypeTableEntryIdMetaType) {
         TypeTableEntry *child_type = resolve_expr_type(struct_expr);
         if (child_type->id == TypeTableEntryIdPureError) {
             *out_val = get_resolved_expr(node)->const_val;
         } else {
-            zig_panic("TODO");
+            zig_panic("TODO field access meta type");
         }
     } else if (struct_type->id == TypeTableEntryIdNamespace) {
-        zig_panic("TODO");
+        zig_panic("TODO field access namespace");
     } else {
         zig_unreachable();
     }
@@ -989,7 +1069,7 @@ static bool eval_for_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val) {
     Buf *elem_var_name = elem_node->data.symbol_expr.symbol;
 
     if (node->data.for_expr.elem_is_ptr) {
-        zig_panic("TODO");
+        zig_panic("TODO for elem is ptr");
     }
 
     Buf *index_var_name = nullptr;
@@ -1062,7 +1142,7 @@ static bool eval_array_access_expr(EvalFn *ef, AstNode *node, ConstExprValue *ou
 
     if (array_type->id == TypeTableEntryIdPointer) {
         if (index_int >= array_val.data.x_ptr.len) {
-            zig_panic("TODO");
+            zig_panic("TODO array access pointer");
         }
         *out_val = *array_val.data.x_ptr.ptr[index_int];
     } else if (array_type->id == TypeTableEntryIdStruct) {
@@ -1071,7 +1151,7 @@ static bool eval_array_access_expr(EvalFn *ef, AstNode *node, ConstExprValue *ou
         ConstExprValue *len_value = array_val.data.x_struct.fields[1];
         uint64_t len_int = len_value->data.x_bignum.data.x_uint;
         if (index_int >= len_int) {
-            zig_panic("TODO");
+            zig_panic("TODO array access slice");
         }
 
         ConstExprValue *ptr_value = array_val.data.x_struct.fields[0];
@@ -1079,7 +1159,7 @@ static bool eval_array_access_expr(EvalFn *ef, AstNode *node, ConstExprValue *ou
     } else if (array_type->id == TypeTableEntryIdArray) {
         uint64_t array_len = array_type->data.array.len;
         if (index_int >= array_len) {
-            zig_panic("TODO");
+            zig_panic("TODO array access array");
         }
         *out_val = *array_val.data.x_array.fields[index_int];
     } else {
@@ -1152,7 +1232,7 @@ static bool eval_prefix_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_v
                     return true;
                 }
             } else if (expr_type->id == TypeTableEntryIdFloat) {
-                zig_panic("TODO");
+                zig_panic("TODO prefix op on floats");
             } else {
                 zig_unreachable();
             }
@@ -1162,7 +1242,7 @@ static bool eval_prefix_op_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_v
         case PrefixOpError:
         case PrefixOpUnwrapError:
         case PrefixOpUnwrapMaybe:
-            zig_panic("TODO");
+            zig_panic("TODO more prefix operations");
         case PrefixOpInvalid:
             zig_unreachable();
     }
@@ -1308,7 +1388,7 @@ static bool eval_expr(EvalFn *ef, AstNode *node, ConstExprValue *out) {
         case NodeTypeErrorType:
         case NodeTypeTypeLiteral:
         case NodeTypeVarLiteral:
-            zig_panic("TODO");
+            zig_panic("TODO expr node");
         case NodeTypeRoot:
         case NodeTypeFnProto:
         case NodeTypeFnDef:
std/cstr.zig
@@ -6,22 +6,22 @@ const assert = debug.assert;
 
 const strlen = len;
 
-// TODO fix https://github.com/andrewrk/zig/issues/140
-// and then make this able to run at compile time
-#static_eval_enable(false)
 pub fn len(ptr: &const u8) -> usize {
     var count: usize = 0;
     while (ptr[count] != 0; count += 1) {}
     return count;
 }
 
-// TODO fix https://github.com/andrewrk/zig/issues/140
-// and then make this able to run at compile time
-#static_eval_enable(false)
-pub fn cmp(a: &const u8, b: &const u8) -> i32 {
+pub fn cmp(a: &const u8, b: &const u8) -> i8 {
     var index: usize = 0;
     while (a[index] == b[index] && a[index] != 0; index += 1) {}
-    return a[index] - b[index];
+    return if (a[index] > b[index]) {
+        1
+    } else if (a[index] < b[index]) {
+        -1
+    } else {
+        0
+    };
 }
 
 pub fn toSliceConst(str: &const u8) -> []const u8 {
@@ -145,3 +145,13 @@ fn testSimpleCBuf() {
     %%buf2.resize(4);
     assert(buf.startsWithCBuf(&buf2));
 }
+
+#attribute("test")
+fn testCompileTimeStrCmp() {
+    assert(@constEval(cmp(c"aoeu", c"aoez") == -1));
+}
+
+#attribute("test")
+fn testCompileTimeStrLen() {
+    assert(@constEval(len(c"123456789") == 9));
+}