Commit 6bd5479306

Ali Chraghi <alichraghi@proton.me>
2023-06-24 16:01:50
std.sort.block: add safety check for lessThan return value
1 parent 88284c1
lib/std/c/tokenizer.zig
@@ -1,5 +1,4 @@
 const std = @import("std");
-const mem = std.mem;
 
 pub const Token = struct {
     id: Id,
lib/std/sort/block.zig
@@ -1,3 +1,4 @@
+const builtin = @import("builtin");
 const std = @import("../std.zig");
 const sort = std.sort;
 const math = std.math;
@@ -100,8 +101,16 @@ pub fn block(
     comptime T: type,
     items: []T,
     context: anytype,
-    comptime lessThan: fn (@TypeOf(context), lhs: T, rhs: T) bool,
+    comptime lessThanFn: fn (@TypeOf(context), lhs: T, rhs: T) bool,
 ) void {
+    const lessThan = if (builtin.mode == .Debug) struct {
+        fn lessThan(ctx: @TypeOf(context), lhs: T, rhs: T) bool {
+            const lt = lessThanFn(ctx, lhs, rhs);
+            const gt = lessThanFn(ctx, rhs, lhs);
+            std.debug.assert(!(lt and gt));
+            return lt;
+        }
+    }.lessThan else lessThanFn;
 
     // Implementation ported from https://github.com/BonzaiThePenguin/WikiSort/blob/master/WikiSort.c
     var cache: [512]T = undefined;
lib/std/comptime_string_map.zig
@@ -9,18 +9,12 @@ const mem = std.mem;
 /// You can pass `struct { []const u8 }` (only keys) tuples if `V` is `void`.
 pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
     const precomputed = comptime blk: {
-        @setEvalBranchQuota(2000);
+        @setEvalBranchQuota(1500);
         const KV = struct {
             key: []const u8,
             value: V,
         };
         var sorted_kvs: [kvs_list.len]KV = undefined;
-        const lenAsc = (struct {
-            fn lenAsc(context: void, a: KV, b: KV) bool {
-                _ = context;
-                return a.key.len < b.key.len;
-            }
-        }).lenAsc;
         for (kvs_list, 0..) |kv, i| {
             if (V != void) {
                 sorted_kvs[i] = .{ .key = kv.@"0", .value = kv.@"1" };
@@ -28,7 +22,20 @@ pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
                 sorted_kvs[i] = .{ .key = kv.@"0", .value = {} };
             }
         }
-        mem.sort(KV, &sorted_kvs, {}, lenAsc);
+
+        const SortContext = struct {
+            kvs: []KV,
+
+            pub fn lessThan(ctx: @This(), a: usize, b: usize) bool {
+                return ctx.kvs[a].key.len < ctx.kvs[b].key.len;
+            }
+
+            pub fn swap(ctx: @This(), a: usize, b: usize) void {
+                return std.mem.swap(KV, &ctx.kvs[a], &ctx.kvs[b]);
+            }
+        };
+        mem.sortUnstableContext(0, sorted_kvs.len, SortContext{ .kvs = &sorted_kvs });
+
         const min_len = sorted_kvs[0].key.len;
         const max_len = sorted_kvs[sorted_kvs.len - 1].key.len;
         var len_indexes: [max_len + 1]usize = undefined;
lib/std/enums.zig
@@ -1289,10 +1289,6 @@ test "std.enums.ensureIndexer" {
     });
 }
 
-fn ascByValue(ctx: void, comptime a: EnumField, comptime b: EnumField) bool {
-    _ = ctx;
-    return a.value < b.value;
-}
 pub fn EnumIndexer(comptime E: type) type {
     if (!@typeInfo(E).Enum.is_exhaustive) {
         @compileError("Cannot create an enum indexer for a non-exhaustive enum.");
@@ -1300,7 +1296,10 @@ pub fn EnumIndexer(comptime E: type) type {
 
     const const_fields = std.meta.fields(E);
     var fields = const_fields[0..const_fields.len].*;
-    if (fields.len == 0) {
+    const min = fields[0].value;
+    const max = fields[fields.len - 1].value;
+    const fields_len = fields.len;
+    if (fields_len == 0) {
         return struct {
             pub const Key = E;
             pub const count: usize = 0;
@@ -1314,10 +1313,20 @@ pub fn EnumIndexer(comptime E: type) type {
             }
         };
     }
-    std.mem.sort(EnumField, &fields, {}, ascByValue);
-    const min = fields[0].value;
-    const max = fields[fields.len - 1].value;
-    const fields_len = fields.len;
+
+    const SortContext = struct {
+        fields: []EnumField,
+
+        pub fn lessThan(comptime ctx: @This(), comptime a: usize, comptime b: usize) bool {
+            return ctx.fields[a].value < ctx.fields[b].value;
+        }
+
+        pub fn swap(comptime ctx: @This(), comptime a: usize, comptime b: usize) void {
+            return std.mem.swap(EnumField, &ctx.fields[a], &ctx.fields[b]);
+        }
+    };
+    std.sort.insertionContext(0, fields_len, SortContext{ .fields = &fields });
+
     if (max - min == fields.len - 1) {
         return struct {
             pub const Key = E;
lib/std/sort.zig
@@ -366,7 +366,7 @@ test "sort with context in the middle of a slice" {
                 const slice = buf[0..case[0].len];
                 @memcpy(slice, case[0]);
                 sortFn(range.start, range.end, Context{ .items = slice });
-                try testing.expectEqualSlices(i32, slice[range.start..range.end], case[1][range.start..range.end]);
+                try testing.expectEqualSlices(i32, case[1][range.start..range.end], slice[range.start..range.end]);
             }
         }
     }
src/arch/x86_64/Encoding.zig
@@ -767,7 +767,7 @@ fn estimateInstructionLength(prefix: Prefix, encoding: Encoding, ops: []const Op
 }
 
 const mnemonic_to_encodings_map = init: {
-    @setEvalBranchQuota(30_000);
+    @setEvalBranchQuota(50_000);
     const encodings = @import("encodings.zig");
     var entries = encodings.table;
     std.mem.sort(encodings.Entry, &entries, {}, struct {