Commit cb41f0e58d

Luuk de Gram <luuk@degram.dev>
2021-07-23 16:41:04
switchbr: When prongs are sparse values, use if/else-chain
1 parent ad38fc1
Changed files (1)
src
codegen
src/codegen/wasm.zig
@@ -1275,12 +1275,17 @@ pub const Context = struct {
         const blocktype = wasm.block_empty;
         const pl_op = self.air.instructions.items(.data)[inst].pl_op;
         const target = self.resolveInst(pl_op.operand);
+        const target_ty = self.air.typeOf(pl_op.operand);
         const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload);
         var extra_index: usize = switch_br.end;
         var case_i: u32 = 0;
 
         // a map that maps each value with its index and body
-        var map = std.AutoArrayHashMap(u32, struct { index: u32, body: []const Air.Inst.Index }).init(self.gpa);
+        var map = std.AutoArrayHashMap(u32, struct {
+            index: u32,
+            body: []const Air.Inst.Index,
+            value: Value,
+        }).init(self.gpa);
         defer map.deinit();
 
         var lowest: u32 = 0;
@@ -1292,51 +1297,87 @@ pub const Context = struct {
             extra_index = case.end + items.len + case_body.len;
 
             for (items) |ref| {
-                const item_val = @intCast(u32, self.air.value(ref).?.toUnsignedInt());
-                if (item_val < lowest) {
-                    lowest = item_val;
+                const item_val = self.air.value(ref).?;
+                // safe to truncate the values as we only use them when
+                // the target's bits is 32 or lower.
+                const int_val = @truncate(u32, item_val.toUnsignedInt());
+                if (int_val < lowest) {
+                    lowest = int_val;
                 }
-                if (item_val > highest) {
-                    highest = item_val;
+                if (int_val > highest) {
+                    highest = int_val;
                 }
-                try map.put(item_val, .{ .index = case_i, .body = case_body });
+                try map.put(int_val, .{ .index = case_i, .body = case_body, .value = item_val });
             }
 
             try self.startBlock(.block, blocktype, null);
         }
 
+        // When the highest and lowest values are seperated by '50',
+        // we define it as sparse and use an if/else-chain, rather than a jump table.
+        // When the target is an integer size larger than u32, we have no way to use the value
+        // as an index, therefore we also use an if/else-chain for those cases.
+        // TODO: Benchmark this to find a proper value, LLVM seems to draw the line at '40~45'.
+        const is_sparse = target_ty.intInfo(self.target).bits > 32 or highest - lowest > 50;
+
         const else_body = self.air.extra[extra_index..][0..switch_br.data.else_body_len];
-        if (else_body.len != 0) {
+        const has_else_body = else_body.len != 0;
+        if (has_else_body) {
             try self.startBlock(.block, blocktype, null);
         }
 
-        // Generate the jump table 'br_table'.
-        // The value 'target' represents the index into the table.
-        // Each index in the table represents a label to the branch
-        // to jump to.
-        try self.startBlock(.block, blocktype, null);
-        try self.emitWValue(target);
-        try self.code.append(wasm.opcode(.br_table));
-        try leb.writeULEB128(self.code.writer(), highest - lowest + 1);
-        while (lowest <= highest) : (lowest += 1) {
-            const idx = if (map.get(lowest)) |value| blk: {
-                break :blk value.index + 1;
-            } else 0;
-            try leb.writeULEB128(self.code.writer(), idx);
-        } else if (else_body.len != 0) {
-            try leb.writeULEB128(self.code.writer(), @as(u32, 0)); // default branch
-        }
-        try self.endBlock();
-
-        if (else_body.len != 0) {
-            try self.genBody(else_body);
+        if (!is_sparse) {
+            // Generate the jump table 'br_table' when the prongs are not sparse.
+            // The value 'target' represents the index into the table.
+            // Each index in the table represents a label to the branch
+            // to jump to.
+            try self.startBlock(.block, blocktype, null);
+            try self.emitWValue(target);
+            try self.code.append(wasm.opcode(.br_table));
+            const depth = highest - lowest + @boolToInt(has_else_body);
+            try leb.writeULEB128(self.code.writer(), depth);
+            while (lowest <= highest) : (lowest += 1) {
+                const idx = if (map.get(lowest)) |value| blk: {
+                    break :blk value.index;
+                } else if (has_else_body) case_i else unreachable;
+                try leb.writeULEB128(self.code.writer(), idx);
+            } else if (has_else_body) {
+                try leb.writeULEB128(self.code.writer(), @as(u32, case_i)); // default branch
+            }
             try self.endBlock();
         }
 
+        const signedness: std.builtin.Signedness = blk: {
+            // by default we tell the operand type is unsigned (i.e. bools and enum values)
+            if (target_ty.zigTypeTag() != .Int) break :blk .unsigned;
+
+            // incase of an actual integer, we emit the correct signedness
+            break :blk target_ty.intInfo(self.target).signedness;
+        };
+
         for (map.values()) |val| {
+            // when sparse, we use if/else-chain, so emit conditional checks
+            if (is_sparse) {
+                try self.emitWValue(target);
+                try self.emitConstant(val.value, target_ty);
+                const opcode = buildOpcode(.{
+                    .valtype1 = try self.typeToValtype(target_ty),
+                    .op = .ne, // not equal, because we want to jump out of this block if it does not match the condition.
+                    .signedness = signedness,
+                });
+                try self.code.append(wasm.opcode(opcode));
+                try self.code.append(wasm.opcode(.br_if));
+                try leb.writeULEB128(self.code.writer(), @as(u32, 0));
+            }
             try self.genBody(val.body);
             try self.endBlock();
         }
+
+        if (has_else_body) {
+            try self.genBody(else_body);
+            try self.endBlock();
+        }
+
         return .none;
     }