Commit 27d233cef7

Vexu <git@vexu.eu>
2020-10-12 11:20:32
stage2: basic switch validation
1 parent ad32e46
src/astgen.zig
@@ -1592,7 +1592,7 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
             kw_args.special_case = .@"else";
             else_src = case_src;
             cases[cases.len - 1] = .{
-                .values = &[_]*zir.Inst{},
+                .items = &[0]*zir.Inst{},
                 .body = undefined, // filled below
             };
             continue;
@@ -1606,7 +1606,7 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
             kw_args.special_case = .underscore;
             underscore_src = case_src;
             cases[cases.len - 1] = .{
-                .values = &[_]*zir.Inst{},
+                .items = &[0]*zir.Inst{},
                 .body = undefined, // filled below
             };
             continue;
@@ -1620,26 +1620,26 @@ fn switchExpr(mod: *Module, scope: *Scope, rl: ResultLoc, switch_node: *ast.Node
             }
         }
 
-        // Regular case, we need to fill `values`.
-        const values = try block_scope.arena.alloc(*zir.Inst, case.items_len);
+        // Regular case, we need to fill `items`.
+        const items = try block_scope.arena.alloc(*zir.Inst, case.items_len);
         for (case.items()) |item, i| {
             if (item.castTag(.Range)) |range| {
-                values[i] = try switchRange(mod, &block_scope.base, range);
+                items[i] = try switchRange(mod, &block_scope.base, range);
                 if (kw_args.support_range == null)
-                    kw_args.support_range = values[i];
+                    kw_args.support_range = items[i];
             } else {
-                values[i] = try expr(mod, &block_scope.base, .none, item);
+                items[i] = try expr(mod, &block_scope.base, .none, item);
             }
         }
         cases[case_index] = .{
-            .values = values,
+            .items = items,
             .body = undefined, // filled below
         };
         case_index += 1;
     }
 
     // Then we add the switch instruction to finish the block.
-    _ = try addZIRInst(mod, scope, switch_src, zir.Inst.Switch, .{
+    _ = try addZIRInst(mod, &block_scope.base, switch_src, zir.Inst.Switch, .{
         .target_ptr = target_ptr,
         .cases = cases,
     }, kw_args);
src/value.zig
@@ -1242,6 +1242,10 @@ pub const Value = extern union {
         return compare(a, .eq, b);
     }
 
+    pub fn hash(a: Value) u64 {
+        @panic("TODO Value.hash");
+    }
+
     /// Asserts the value is a pointer and dereferences it.
     /// Returns error.AnalysisFail if the pointer points to a Decl that failed semantic analysis.
     pub fn pointerDeref(self: Value, allocator: *Allocator) error{ AnalysisFail, OutOfMemory }!Value {
src/zir.zig
@@ -275,6 +275,8 @@ pub const Inst = struct {
         /// A switch expression.
         @"switch",
         /// A range in a switch case, `lhs...rhs`.
+        /// Only checks that `lhs >= rhs` if they are ints or floats, everything else is
+        /// validated by the .switch instruction.
         switch_range,
 
         pub fn Type(tag: Tag) type {
@@ -1018,7 +1020,7 @@ pub const Inst = struct {
         },
 
         pub const Case = struct {
-            values: []*Inst,
+            items: []*Inst,
             body: Module.Body,
         };
     };
@@ -1284,7 +1286,7 @@ const Writer = struct {
                         try stream.writeAll(",\n");
                     }
                     try stream.writeByteNTimes(' ', self.indent);
-                    try self.writeParamToStream(stream, &case.values);
+                    try self.writeParamToStream(stream, &case.items);
                     try stream.writeAll(" => ");
                     try self.writeParamToStream(stream, &case.body);
                 }
@@ -1714,7 +1716,7 @@ const Parser = struct {
                 while (true) {
                     const cur = try cases.addOne();
                     skipSpace(self);
-                    cur.values = try self.parseParameterGeneric([]*Inst, body_ctx);
+                    cur.items = try self.parseParameterGeneric([]*Inst, body_ctx);
                     skipSpace(self);
                     try requireEatBytes(self, "=>");
                     cur.body = try self.parseBody(body_ctx);
src/zir_sema.zig
@@ -135,7 +135,8 @@ pub fn analyzeInst(mod: *Module, scope: *Scope, old_inst: *zir.Inst) InnerError!
         .slice => return analyzeInstSlice(mod, scope, old_inst.castTag(.slice).?),
         .slice_start => return analyzeInstSliceStart(mod, scope, old_inst.castTag(.slice_start).?),
         .import => return analyzeInstImport(mod, scope, old_inst.castTag(.import).?),
-        .@"switch", .switch_range => @panic("TODO switch sema"),
+        .@"switch" => return analyzeInstSwitch(mod, scope, old_inst.castTag(.@"switch").?),
+        .switch_range => return analyzeInstSwitchRange(mod, scope, old_inst.castTag(.switch_range).?),
     }
 }
 
@@ -1205,6 +1206,126 @@ fn analyzeInstSliceStart(mod: *Module, scope: *Scope, inst: *zir.Inst.BinOp) Inn
     return mod.analyzeSlice(scope, inst.base.src, array_ptr, start, null, null);
 }
 
+fn analyzeInstSwitchRange(mod: *Module, scope: *Scope, inst: *zir.Inst.BinOp) InnerError!*Inst {
+    const start = try resolveInst(mod, scope, inst.positionals.lhs);
+    const end = try resolveInst(mod, scope, inst.positionals.rhs);
+
+    switch (start.ty.zigTypeTag()) {
+        .Int, .ComptimeInt, .Float, .ComptimeFloat => {},
+        else => return mod.constVoid(scope, inst.base.src),
+    }
+    switch (end.ty.zigTypeTag()) {
+        .Int, .ComptimeInt, .Float, .ComptimeFloat => {},
+        else => return mod.constVoid(scope, inst.base.src),
+    }
+    if (start.value()) |start_val| {
+        if (end.value()) |end_val| {
+            if (start_val.compare(.gte, end_val)) {
+                return mod.fail(scope, inst.base.src, "range start value is greater than the end value", .{});
+            }
+        }
+    }
+    return mod.constVoid(scope, inst.base.src);
+}
+
+fn analyzeInstSwitch(mod: *Module, scope: *Scope, inst: *zir.Inst.Switch) InnerError!*Inst {
+    const target_ptr = try resolveInst(mod, scope, inst.positionals.target_ptr);
+    const target = try mod.analyzeDeref(scope, inst.base.src, target_ptr, inst.positionals.target_ptr.src);
+    try validateSwitch(mod, scope, target, inst);
+
+    return mod.fail(scope, inst.base.src, "TODO analyzeInstSwitch", .{});
+}
+
+fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.Switch) InnerError!void {
+    // validate usage of '_' prongs
+    if (inst.kw_args.special_case == .underscore and target.ty.zigTypeTag() != .Enum) {
+        return mod.fail(scope, inst.base.src, "'_' prong only allowed when switching on non-exhaustive enums", .{});
+        // TODO notes "'_' prong here" inst.positionals.cases[last].src
+    }
+
+    // check that target type supports ranges
+    if (inst.kw_args.support_range) |some| {
+        switch (target.ty.zigTypeTag()) {
+            .Int, .ComptimeInt, .Float, .ComptimeFloat => {},
+            else => {
+                return mod.fail(scope, target.src, "ranges not allowed when switching on type {}", .{target.ty});
+                // TODO notes "range used here" some.src
+            },
+        }
+    }
+
+    // validate for duplicate items/missing else prong
+    switch (target.ty.zigTypeTag()) {
+        .Int, .ComptimeInt => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Int, .ComptimeInt", .{}),
+        .Float, .ComptimeFloat => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Float, .ComptimeFloat", .{}),
+        .Enum => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Enum", .{}),
+        .ErrorSet => return mod.fail(scope, inst.base.src, "TODO validateSwitch .ErrorSet", .{}),
+        .Union => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Union", .{}),
+        .Bool => {
+            var true_count: u8 = 0;
+            var false_count: u8 = 0;
+            for (inst.positionals.cases) |case| {
+                for (case.items) |item| {
+                    const resolved = try resolveInst(mod, scope, item);
+                    const casted = try mod.coerce(scope, Type.initTag(.bool), resolved);
+                    if ((try mod.resolveConstValue(scope, casted)).toBool()) {
+                        true_count += 1;
+                    } else {
+                        false_count += 1;
+                    }
+
+                    if (true_count > 1 or false_count > 1) {
+                        return mod.fail(scope, item.src, "duplicate switch value", .{});
+                    }
+                }
+            }
+            if ((true_count == 0 or false_count == 0) and inst.kw_args.special_case != .@"else") {
+                return mod.fail(scope, inst.base.src, "switch must handle all possibilities", .{});
+            }
+            if ((true_count == 1 and false_count == 1) and inst.kw_args.special_case == .@"else") {
+                return mod.fail(scope, inst.base.src, "unreachable else prong, all cases already handled", .{});
+            }
+        },
+        .EnumLiteral, .Void, .Fn, .Pointer, .Type => {
+            if (inst.kw_args.special_case != .@"else") {
+                return mod.fail(scope, inst.base.src, "else prong required when switching on type '{}'", .{target.ty});
+            }
+
+            var seen_values = std.HashMap(Value, usize, Value.hash, Value.eql, std.hash_map.DefaultMaxLoadPercentage).init(mod.gpa);
+            defer seen_values.deinit();
+
+            for (inst.positionals.cases) |case| {
+                for (case.items) |item| {
+                    const resolved = try resolveInst(mod, scope, item);
+                    const casted = try mod.coerce(scope, target.ty, resolved);
+                    const val = try mod.resolveConstValue(scope, casted);
+
+                    if (try seen_values.fetchPut(val, item.src)) |prev| {
+                        return mod.fail(scope, item.src, "duplicate switch value", .{});
+                        // TODO notes "previous value here" prev.value
+                    }
+                }
+            }
+        },
+
+        .ErrorUnion,
+        .NoReturn,
+        .Array,
+        .Struct,
+        .Undefined,
+        .Null,
+        .Optional,
+        .BoundFn,
+        .Opaque,
+        .Vector,
+        .Frame,
+        .AnyFrame,
+        => {
+            return mod.fail(scope, target.src, "invalid switch target type '{}'", .{target.ty});
+        },
+    }
+}
+
 fn analyzeInstImport(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst {
     const operand = try resolveConstString(mod, scope, inst.positionals.operand);