Commit e761aa2d2f

Steve Perkins <steve@octopart.com>
2016-11-02 23:52:00
sortCmp allows for a custom cmp function
1 parent c5b2bda
Changed files (1)
std/sort.zig
@@ -1,5 +1,8 @@
 const assert = @import("debug.zig").assert;
 const str = @import("str.zig");
+const math = @import("math.zig");
+
+pub const Cmp = math.Cmp;
 
 pub fn sort(inline T: type, array: []T) {
     if (array.len > 0) {
@@ -32,6 +35,43 @@ fn quicksort(inline T: type, array: []T, left: usize, right: usize) {
     if (i < right) quicksort(T, array, i, right);
 }
 
+// ---------------------------------------
+// sortCmp
+
+pub fn sortCmp(inline T: type, array: []T, inline cmp: fn(a: T, b: T)->Cmp) {
+    if (array.len > 0) {
+        quicksortCmp(T, array, 0, array.len - 1, cmp);
+    }
+}
+
+fn quicksortCmp(inline T: type, array: []T, left: usize, right: usize, inline cmp: fn(a: T, b: T)->Cmp) {
+    var i = left;
+    var j = right;
+    var p = (i + j) / 2;
+
+    while (i <= j) {
+        while (cmp(array[i], array[p]) == Cmp.Less) {
+            i += 1;
+        }
+        while (cmp(array[j], array[p]) == Cmp.Greater) {
+            j -= 1;
+        }
+        if (i <= j) {
+            const tmp = array[i];
+            array[i] = array[j];
+            array[j] = tmp;
+            i += 1;
+            if (j > 0) j -= 1;
+        }
+    }
+
+    if (left < j) quicksortCmp(T, array, left, j, cmp);
+    if (i < right) quicksortCmp(T, array, i, right, cmp);
+}
+
+// ---------------------------------------
+// tests
+
 fn testSort() {
     @setFnTest(this, true);
 
@@ -63,3 +103,44 @@ fn testSort() {
         assert(str.sliceEql(i32, case[0], case[1]));
     }
 }
+
+fn testSortCmp() {
+    @setFnTest(this, true);
+
+    const i32cases = [][][]i32 {
+        [][]i32{[]i32{}, []i32{}},
+        [][]i32{[]i32{1}, []i32{1}},
+        [][]i32{[]i32{0, 1}, []i32{0, 1}},
+        [][]i32{[]i32{1, 0}, []i32{0, 1}},
+        [][]i32{[]i32{1, -1, 0}, []i32{-1, 0, 1}},
+        [][]i32{[]i32{2, 1, 3}, []i32{1, 2, 3}},
+    };
+
+    for (i32cases) |case| {
+        sortCmp(i32, case[0], normalCmp);
+        assert(str.sliceEql(i32, case[0], case[1]));
+    }
+
+    const revCases = [][][]i32 {
+        [][]i32{[]i32{}, []i32{}},
+        [][]i32{[]i32{1}, []i32{1}},
+        [][]i32{[]i32{0, 1}, []i32{1, 0}},
+        [][]i32{[]i32{1, 0}, []i32{1, 0}},
+        [][]i32{[]i32{1, -1, 0}, []i32{1, 0, -1}},
+        [][]i32{[]i32{2, 1, 3}, []i32{3, 2, 1}},
+    };
+
+    for (revCases) |case| {
+        sortCmp(i32, case[0], revCmp);
+        assert(str.sliceEql(i32, case[0], case[1]));
+    }
+
+}
+
+fn normalCmp(a: i32, b: i32) -> Cmp {
+    return if (a > b) Cmp.Greater else if (a < b) Cmp.Less else Cmp.Equal;
+}
+
+fn revCmp(a: i32, b: i32) -> Cmp {
+    return if (a < b) Cmp.Greater else if (a > b) Cmp.Less else Cmp.Equal;
+}