Commit 8fc52a94f4

tgschultz <tgschultz@gmail.com>
2018-05-30 17:18:11
Added custom formatter support, refactored fmt.format
1 parent 8174f97
Changed files (1)
std
std/fmt/index.zig
@@ -16,27 +16,12 @@ pub fn format(context: var, comptime Errors: type, output: fn(@typeOf(context),
         Start,
         OpenBrace,
         CloseBrace,
-        Integer,
-        IntegerWidth,
-        Float,
-        FloatWidth,
-        FloatScientific,
-        FloatScientificWidth,
-        Character,
-        Buf,
-        BufWidth,
-        Bytes,
-        BytesBase,
-        BytesWidth,
+        FormatString,
     };
 
     comptime var start_index = 0;
     comptime var state = State.Start;
     comptime var next_arg = 0;
-    comptime var radix = 0;
-    comptime var uppercase = false;
-    comptime var width = 0;
-    comptime var width_start = 0;
 
     inline for (fmt) |c, i| {
         switch (state) {
@@ -45,8 +30,10 @@ pub fn format(context: var, comptime Errors: type, output: fn(@typeOf(context),
                     if (start_index < i) {
                         try output(context, fmt[start_index..i]);
                     }
+                    start_index = i;
                     state = State.OpenBrace;
                 },
+                
                 '}' => {
                     if (start_index < i) {
                         try output(context, fmt[start_index..i]);
@@ -61,57 +48,14 @@ pub fn format(context: var, comptime Errors: type, output: fn(@typeOf(context),
                     start_index = i;
                 },
                 '}' => {
-                    try formatValue(args[next_arg], context, Errors, output);
+                    try formatType(args[next_arg], fmt[0..0], context, Errors, output);
                     next_arg += 1;
                     state = State.Start;
                     start_index = i + 1;
                 },
-                'd' => {
-                    radix = 10;
-                    uppercase = false;
-                    width = 0;
-                    state = State.Integer;
-                },
-                'x' => {
-                    radix = 16;
-                    uppercase = false;
-                    width = 0;
-                    state = State.Integer;
-                },
-                'X' => {
-                    radix = 16;
-                    uppercase = true;
-                    width = 0;
-                    state = State.Integer;
-                },
-                'c' => {
-                    state = State.Character;
-                },
-                's' => {
-                    state = State.Buf;
-                },
-                'e' => {
-                    state = State.FloatScientific;
+                else => {
+                    state = State.FormatString;
                 },
-                '.' => {
-                    state = State.Float;
-                },
-                'B' => {
-                    width = 0;
-                    radix = 1000;
-                    state = State.Bytes;
-                },
-                else => @compileError("Unknown format character: " ++ []u8{c}),
-            },
-            State.Buf => switch (c) {
-                '}' => {
-                    return output(context, args[next_arg]);
-                },
-                '0'...'9' => {
-                    width_start = i;
-                    state = State.BufWidth;
-                },
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
             },
             State.CloseBrace => switch (c) {
                 '}' => {
@@ -120,139 +64,16 @@ pub fn format(context: var, comptime Errors: type, output: fn(@typeOf(context),
                 },
                 else => @compileError("Single '}' encountered in format string"),
             },
-            State.Integer => switch (c) {
+            State.FormatString => switch(c) {
                 '}' => {
-                    try formatInt(args[next_arg], radix, uppercase, width, context, Errors, output);
+                    const s = start_index + 1;
+                    try formatType(args[next_arg], fmt[s..i], context, Errors, output);
                     next_arg += 1;
                     state = State.Start;
                     start_index = i + 1;
                 },
-                '0'...'9' => {
-                    width_start = i;
-                    state = State.IntegerWidth;
-                },
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.IntegerWidth => switch (c) {
-                '}' => {
-                    width = comptime (parseUnsigned(usize, fmt[width_start..i], 10) catch unreachable);
-                    try formatInt(args[next_arg], radix, uppercase, width, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {},
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.FloatScientific => switch (c) {
-                '}' => {
-                    try formatFloatScientific(args[next_arg], null, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {
-                    width_start = i;
-                    state = State.FloatScientificWidth;
-                },
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.FloatScientificWidth => switch (c) {
-                '}' => {
-                    width = comptime (parseUnsigned(usize, fmt[width_start..i], 10) catch unreachable);
-                    try formatFloatScientific(args[next_arg], width, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {},
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.Float => switch (c) {
-                '}' => {
-                    try formatFloatDecimal(args[next_arg], null, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {
-                    width_start = i;
-                    state = State.FloatWidth;
-                },
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.FloatWidth => switch (c) {
-                '}' => {
-                    width = comptime (parseUnsigned(usize, fmt[width_start..i], 10) catch unreachable);
-                    try formatFloatDecimal(args[next_arg], width, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {},
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.BufWidth => switch (c) {
-                '}' => {
-                    width = comptime (parseUnsigned(usize, fmt[width_start..i], 10) catch unreachable);
-                    try formatBuf(args[next_arg], width, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {},
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.Character => switch (c) {
-                '}' => {
-                    try formatAsciiChar(args[next_arg], context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.Bytes => switch (c) {
-                '}' => {
-                    try formatBytes(args[next_arg], 0, radix, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                'i' => {
-                    radix = 1024;
-                    state = State.BytesBase;
-                },
-                '0'...'9' => {
-                    width_start = i;
-                    state = State.BytesWidth;
-                },
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.BytesBase => switch (c) {
-                '}' => {
-                    try formatBytes(args[next_arg], 0, radix, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {
-                    width_start = i;
-                    state = State.BytesWidth;
-                },
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
-            State.BytesWidth => switch (c) {
-                '}' => {
-                    width = comptime (parseUnsigned(usize, fmt[width_start..i], 10) catch unreachable);
-                    try formatBytes(args[next_arg], width, radix, context, Errors, output);
-                    next_arg += 1;
-                    state = State.Start;
-                    start_index = i + 1;
-                },
-                '0'...'9' => {},
-                else => @compileError("Unexpected character in format string: " ++ []u8{c}),
-            },
+                else => {},
+            }
         }
     }
     comptime {
@@ -268,14 +89,14 @@ pub fn format(context: var, comptime Errors: type, output: fn(@typeOf(context),
     }
 }
 
-pub fn formatValue(value: var, context: var, comptime Errors: type, output: fn(@typeOf(context), []const u8) Errors!void) Errors!void {
+pub fn formatType(value: var, comptime fmt: []const u8, context: var, comptime Errors: type,
+    output: fn(@typeOf(context), []const u8)Errors!void) Errors!void
+{
     const T = @typeOf(value);
     switch (@typeId(T)) {
-        builtin.TypeId.Int => {
-            return formatInt(value, 10, false, 0, context, Errors, output);
-        },
+        builtin.TypeId.Int,
         builtin.TypeId.Float => {
-            return formatFloatScientific(value, null, context, Errors, output);
+            return formatValue(value, fmt, context, Errors, output);
         },
         builtin.TypeId.Void => {
             return output(context, "void");
@@ -285,16 +106,16 @@ pub fn formatValue(value: var, context: var, comptime Errors: type, output: fn(@
         },
         builtin.TypeId.Nullable => {
             if (value) |payload| {
-                return formatValue(payload, context, Errors, output);
+                return formatType(payload, fmt, context, Errors, output);
             } else {
                 return output(context, "null");
             }
         },
         builtin.TypeId.ErrorUnion => {
             if (value) |payload| {
-                return formatValue(payload, context, Errors, output);
+                return formatType(payload, fmt, context, Errors, output);
             } else |err| {
-                return formatValue(err, context, Errors, output);
+                return formatType(err, fmt, context, Errors, output);
             }
         },
         builtin.TypeId.ErrorSet => {
@@ -302,10 +123,60 @@ pub fn formatValue(value: var, context: var, comptime Errors: type, output: fn(@
             return output(context, @errorName(value));
         },
         builtin.TypeId.Pointer => {
-            if (@typeId(T.Child) == builtin.TypeId.Array and T.Child.Child == u8) {
-                return output(context, (value.*)[0..]);
-            } else {
-                return format(context, Errors, output, "{}@{x}", @typeName(T.Child), @ptrToInt(value));
+            switch(@typeId(T.Child)) {
+                builtin.TypeId.Array => {
+                    if(T.Child.Child == u8) {
+                        return formatText(value, fmt, context, Errors, output);
+                    }
+                },
+                builtin.TypeId.Enum,
+                builtin.TypeId.Union,
+                builtin.TypeId.Struct => {
+                    const has_cust_fmt = comptime cf: {
+                        const info = @typeInfo(T.Child);
+                        const defs = switch (info) {
+                            builtin.TypeId.Struct => |s| s.defs,
+                            builtin.TypeId.Union => |u| u.defs,
+                            builtin.TypeId.Enum => |e| e.defs,
+                            else => unreachable,
+                        };
+                        
+                        for (defs) |def| {
+                            if (mem.eql(u8, def.name, "format") and def.is_pub) {
+                                const data = def.data;
+                                switch (data) {
+                                    builtin.TypeInfo.Definition.Data.Type, 
+                                    builtin.TypeInfo.Definition.Data.Var => continue,
+                                    builtin.TypeInfo.Definition.Data.Fn => |*fn_def| {
+                                        //const FmtType = fn(@typeOf(context), []const u8)Errors!void;
+                                        //// for some reason, fn_type sees the arg `comptime []const u8` as `var`
+                                        //const TargetType = fn(T, var, var, type, FmtType) Errors!void;
+                                        
+                                        // This hack is because fn_def.fn_type != TargetType
+                                        //   for reasons I have yet to determine.
+
+                                        const fn_type_name = @typeName(@typeOf(value.format));
+                                        const value_type_name = @typeName(@typeOf(value));
+                                        const target_type_name = "(bound fn("
+                                            ++ value_type_name ++ ",var,var,var,var)var)";
+                                        if (mem.eql(u8, fn_type_name, target_type_name))
+                                        {
+                                            break :cf true;
+                                        }
+                                        
+                                    },
+                                }
+                            }
+                        }
+                        break :cf false;
+                    };
+                    
+                    if (has_cust_fmt) return value.format(fmt, context, Errors, output);
+                    return format(context, Errors, output, "{}@{x}", @typeName(T.Child),
+                            @ptrToInt(value));
+                },
+                else => return format(context, Errors, output, "{}@{x}", @typeName(T.Child),
+                            @ptrToInt(value)),
             }
         },
         else => if (@canImplicitCast([]const u8, value)) {
@@ -317,11 +188,106 @@ pub fn formatValue(value: var, context: var, comptime Errors: type, output: fn(@
     }
 }
 
+fn formatValue(value: var, comptime fmt: []const u8, context: var, comptime Errors: type,
+    output: fn(@typeOf(context), []const u8)Errors!void) Errors!void
+{
+    if (fmt.len > 0) {
+        if (fmt[0] == 'B') {
+            comptime var width: ?usize = null;
+            if (fmt.len > 1) {
+                if (fmt[1] == 'i') {
+                    if (fmt.len > 2) width = comptime (parseUnsigned(usize, fmt[2..], 10) catch unreachable);
+                    return formatBytes(value, width, 1024, context, Errors, output);
+                }
+                width = comptime (parseUnsigned(usize, fmt[1..], 10) catch unreachable);
+            }
+            return formatBytes(value, width, 1000, context, Errors, output);
+        }
+    }
+    
+    comptime var T = @typeOf(value);
+    switch (@typeId(T)) {
+        builtin.TypeId.Float => return formatFloatValue(value, fmt, context, Errors, output),
+        builtin.TypeId.Int => return formatIntValue(value, fmt, context, Errors, output),
+        else => unreachable,
+    }
+}
+
+pub fn formatIntValue(value: var, comptime fmt: []const u8, context: var, comptime Errors: type,
+    output: fn(@typeOf(context), []const u8)Errors!void) Errors!void
+{
+    comptime var radix = 10;
+    comptime var uppercase = false;
+    comptime var width = 0;
+    if (fmt.len > 0) {
+        switch (fmt[0]) {
+            'c' => {
+                if(@typeOf(value) == u8) {
+                    if(fmt.len > 1) @compileError("Unknown format character: " ++ []u8{fmt[1]});
+                    return formatAsciiChar(fmt[0], context, Errors, output);
+                }
+            },
+            'd' => {
+                radix = 10;
+                uppercase = false;
+                width = 0;
+            },
+            'x' => {
+                radix = 16;
+                uppercase = false;
+                width = 0;
+            },
+            'X' => {
+                radix = 16;
+                uppercase = true;
+                width = 0;
+            },
+            else => @compileError("Unknown format character: " ++ []u8{fmt[0]}),
+        }
+        if (fmt.len > 1) width = comptime (parseUnsigned(usize, fmt[1..], 10) catch unreachable);
+    }
+    return formatInt(value, radix, uppercase, width, context, Errors, output);
+}
+
+fn formatFloatValue(value: var, comptime fmt: []const u8, context: var, comptime Errors: type,
+    output: fn(@typeOf(context), []const u8)Errors!void) Errors!void
+{
+    comptime var width: ?usize = null;
+    comptime var float_fmt = 'e';
+    if (fmt.len > 0) {
+        float_fmt = fmt[0];
+        if(fmt.len > 1) width = comptime (parseUnsigned(usize, fmt[1..], 10) catch unreachable);
+    }
+    
+    switch (float_fmt) {
+        'e' => try formatFloatScientific(value, width, context, Errors, output),
+        '.' => try formatFloatDecimal(value, width, context, Errors, output),
+        else => @compileError("Unknown format character: " ++ []u8{float_fmt}),
+    }
+    
+}
+
+pub fn formatText(bytes: []const u8, comptime fmt: []const u8, context: var, 
+    comptime Errors: type, output: fn(@typeOf(context), []const u8)Errors!void) Errors!void
+{
+    if (fmt.len > 0) {
+        if (fmt[0] == 's') {
+            comptime var width = 0;
+            if(fmt.len > 1) width = comptime (parseUnsigned(usize, fmt[1..], 10) catch unreachable);
+            return formatBuf(bytes, width, context, Errors, output);
+        }
+        else @compileError("Unknown format character: " ++ []u8{fmt[0]});
+    }
+    return output(context, bytes);
+}
+
 pub fn formatAsciiChar(c: u8, context: var, comptime Errors: type, output: fn(@typeOf(context), []const u8) Errors!void) Errors!void {
     return output(context, (&c)[0..1]);
 }
 
-pub fn formatBuf(buf: []const u8, width: usize, context: var, comptime Errors: type, output: fn(@typeOf(context), []const u8) Errors!void) Errors!void {
+pub fn formatBuf(buf: []const u8, width: usize, context: var,
+    comptime Errors: type, output: fn(@typeOf(context), []const u8) Errors!void) Errors!void
+{
     try output(context, buf);
 
     var leftover_padding = if (width > buf.len) (width - buf.len) else return;
@@ -1048,6 +1014,38 @@ test "fmt.format" {
         const result = try bufPrint(buf1[0..], "f64: {.5}\n", value);
         assert(mem.eql(u8, result, "f64: 18014400656965630.00000\n"));
     }
+    //custom type format
+    {
+        const Vec2 = struct {
+            const SelfType = this;
+            x: f32,
+            y: f32,
+        
+            pub fn format(self: &SelfType, comptime fmt: []const u8, context: var, 
+                comptime Errors: type, output: fn(@typeOf(context), []const u8)Errors!void) 
+                Errors!void 
+            {
+                if (fmt.len > 0) {
+                    if (fmt.len > 1) unreachable;
+                    switch (fmt[0]) {
+                        //point format
+                        'p' => return std.fmt.format(context, Errors, output, "({.3},{.3})", self.x, self.y),
+                        //dimension format
+                        'd' => return std.fmt.format(context, Errors, output, "{.3}x{.3}", self.x, self.y),
+                        else => unreachable,
+                    }
+                }
+                return std.fmt.format(context, Errors, output, "({.3},{.3})", self.x, self.y);
+            }
+        };
+        
+        var buf1: [32]u8 = undefined;
+        var value = Vec2{.x = 10.2, .y = 2.22,};
+        const point_result = try bufPrint(buf1[0..], "point: {}\n", &value);
+        assert(mem.eql(u8, point_result, "point: (10.200,2.220)\n"));
+        const dim_result = try bufPrint(buf1[0..], "dim: {d}\n", &value);
+        assert(mem.eql(u8, dim_result, "dim: 10.200x2.220\n"));
+    }
 }
 
 fn testFmt(expected: []const u8, comptime template: []const u8, args: ...) !void {