Commit 1676729c66

Jimmi Holst Christensen <jhc@dismail.de>
2022-01-01 17:36:53
fmt: Refactor parsing of placeholders into its own function
This saves on comptime format string parsing, as the compiler caches comptime calls. The catch here is that parsePlaceholder cannot take the placeholder string as a slice. It must take it as an array by value for the caching to occur. There is also some logic in here that ensures that the specifier_arg is always the same slice when the items they contain are the same. This makes the compiler stamp out fewer copies of formatType.
1 parent eee3952
Changed files (2)
lib
src
lib/std/fmt.zig
@@ -75,111 +75,20 @@ pub fn format(
     comptime fmt: []const u8,
     args: anytype,
 ) !void {
-    const ArgSetType = u32;
-
     const ArgsType = @TypeOf(args);
+    const args_type_info = @typeInfo(ArgsType);
     // XXX: meta.trait.is(.Struct)(ArgsType) doesn't seem to work...
-    if (@typeInfo(ArgsType) != .Struct) {
+    if (args_type_info != .Struct) {
         @compileError("Expected tuple or struct argument, found " ++ @typeName(ArgsType));
     }
 
-    const fields_info = meta.fields(ArgsType);
-    if (fields_info.len > @typeInfo(ArgSetType).Int.bits) {
+    const fields_info = args_type_info.Struct.fields;
+    if (fields_info.len > max_format_args) {
         @compileError("32 arguments max are supported per format call");
     }
 
-    comptime var arg_state: struct {
-        next_arg: usize = 0,
-        used_args: usize = 0,
-        args_len: usize = fields_info.len,
-
-        fn hasUnusedArgs(comptime self: *@This()) bool {
-            return @popCount(ArgSetType, self.used_args) != self.args_len;
-        }
-
-        fn nextArg(comptime self: *@This(), comptime arg_index: ?usize) comptime_int {
-            const next_index = arg_index orelse init: {
-                const arg = self.next_arg;
-                self.next_arg += 1;
-                break :init arg;
-            };
-
-            if (next_index >= self.args_len) {
-                @compileError("Too few arguments");
-            }
-
-            // Mark this argument as used
-            self.used_args |= 1 << next_index;
-
-            return next_index;
-        }
-    } = .{};
-
-    comptime var parser: struct {
-        buf: []const u8 = undefined,
-        pos: comptime_int = 0,
-
-        // Returns a decimal number or null if the current character is not a
-        // digit
-        fn number(comptime self: *@This()) ?usize {
-            var r: ?usize = null;
-
-            while (self.pos < self.buf.len) : (self.pos += 1) {
-                switch (self.buf[self.pos]) {
-                    '0'...'9' => {
-                        if (r == null) r = 0;
-                        r.? *= 10;
-                        r.? += self.buf[self.pos] - '0';
-                    },
-                    else => break,
-                }
-            }
-
-            return r;
-        }
-
-        // Returns a substring of the input starting from the current position
-        // and ending where `ch` is found or until the end if not found
-        fn until(comptime self: *@This(), comptime ch: u8) []const u8 {
-            const start = self.pos;
-
-            if (start >= self.buf.len)
-                return &[_]u8{};
-
-            while (self.pos < self.buf.len) : (self.pos += 1) {
-                if (self.buf[self.pos] == ch) break;
-            }
-            return self.buf[start..self.pos];
-        }
-
-        // Returns one character, if available
-        fn char(comptime self: *@This()) ?u8 {
-            if (self.pos < self.buf.len) {
-                const ch = self.buf[self.pos];
-                self.pos += 1;
-                return ch;
-            }
-            return null;
-        }
-
-        fn maybe(comptime self: *@This(), comptime val: u8) bool {
-            if (self.pos < self.buf.len and self.buf[self.pos] == val) {
-                self.pos += 1;
-                return true;
-            }
-            return false;
-        }
-
-        // Returns the n-th next character or null if that's past the end
-        fn peek(comptime self: *@This(), comptime n: usize) ?u8 {
-            return if (self.pos + n < self.buf.len) self.buf[self.pos + n] else null;
-        }
-    } = .{};
-
-    var options: FormatOptions = .{};
-
     @setEvalBranchQuota(2000000);
-
+    comptime var arg_state: ArgState = .{ .args_len = fields_info.len };
     comptime var i = 0;
     inline while (i < fmt.len) {
         const start_index = i;
@@ -234,134 +143,258 @@ pub fn format(
         comptime assert(fmt[i] == '}');
         i += 1;
 
-        options = .{};
-
-        // Parse the format fragment between braces
-        parser.buf = fmt[fmt_begin..fmt_end];
-        parser.pos = 0;
-
-        // Parse the positional argument number
-        const opt_pos_arg = comptime init: {
-            if (parser.maybe('[')) {
-                const arg_name = parser.until(']');
+        const placeholder = comptime parsePlaceholder(fmt[fmt_begin..fmt_end].*);
+        const arg_pos = comptime switch (placeholder.arg) {
+            .none => null,
+            .number => |pos| pos,
+            .named => |arg_name| meta.fieldIndex(ArgsType, arg_name) orelse
+                @compileError("No argument with name '" ++ arg_name ++ "'"),
+        };
 
-                if (!parser.maybe(']')) {
-                    @compileError("Expected closing ]");
-                }
+        const width = switch (placeholder.width) {
+            .none => null,
+            .number => |v| v,
+            .named => |arg_name| blk: {
+                const arg_i = comptime meta.fieldIndex(ArgsType, arg_name) orelse
+                    @compileError("No argument with name '" ++ arg_name ++ "'");
+                _ = comptime arg_state.nextArg(arg_i) orelse @compileError("Too few arguments");
+                break :blk @field(args, arg_name);
+            },
+        };
 
-                break :init meta.fieldIndex(ArgsType, arg_name) orelse
+        const precision = switch (placeholder.precision) {
+            .none => null,
+            .number => |v| v,
+            .named => |arg_name| blk: {
+                const arg_i = comptime meta.fieldIndex(ArgsType, arg_name) orelse
                     @compileError("No argument with name '" ++ arg_name ++ "'");
-            } else {
-                break :init parser.number();
-            }
+                _ = comptime arg_state.nextArg(arg_i) orelse @compileError("Too few arguments");
+                break :blk @field(args, arg_name);
+            },
         };
 
-        // Parse the format specifier
-        const specifier_arg = comptime parser.until(':');
+        const arg_to_print = comptime arg_state.nextArg(arg_pos) orelse
+            @compileError("Too few arguments");
 
-        // Skip the colon, if present
-        if (comptime parser.char()) |ch| {
-            if (ch != ':') {
-                @compileError("Expected : or }, found '" ++ [1]u8{ch} ++ "'");
-            }
+        try formatType(
+            @field(args, fields_info[arg_to_print].name),
+            placeholder.specifier_arg,
+            FormatOptions{
+                .fill = placeholder.fill,
+                .alignment = placeholder.alignment,
+                .width = width,
+                .precision = precision,
+            },
+            writer,
+            default_max_depth,
+        );
+    }
+
+    if (comptime arg_state.hasUnusedArgs()) {
+        const missing_count = arg_state.args_len - @popCount(ArgSetType, arg_state.used_args);
+        switch (missing_count) {
+            0 => unreachable,
+            1 => @compileError("Unused argument in '" ++ fmt ++ "'"),
+            else => @compileError((comptime comptimePrint("{d}", .{missing_count})) ++ " unused arguments in '" ++ fmt ++ "'"),
         }
+    }
+}
 
-        // Parse the fill character
-        // The fill parameter requires the alignment parameter to be specified
-        // too
-        if (comptime parser.peek(1)) |ch| {
-            if (comptime mem.indexOfScalar(u8, "<^>", ch) != null) {
-                options.fill = comptime parser.char().?;
-            }
+fn parsePlaceholder(comptime str: anytype) Placeholder {
+    comptime var parser = Parser{ .buf = &str };
+
+    // Parse the positional argument number
+    const arg = comptime parser.specifier() catch |err|
+        @compileError(@errorName(err));
+
+    // Parse the format specifier
+    const specifier_arg = comptime parser.until(':');
+
+    // Skip the colon, if present
+    if (comptime parser.char()) |ch| {
+        if (ch != ':') {
+            @compileError("Expected : or }, found '" ++ [1]u8{ch} ++ "'");
         }
+    }
 
-        // Parse the alignment parameter
-        if (comptime parser.peek(0)) |ch| {
-            switch (ch) {
-                '<' => {
-                    options.alignment = .Left;
-                    _ = comptime parser.char();
-                },
-                '^' => {
-                    options.alignment = .Center;
-                    _ = comptime parser.char();
-                },
-                '>' => {
-                    options.alignment = .Right;
-                    _ = comptime parser.char();
-                },
-                else => {},
-            }
+    // Parse the fill character
+    // The fill parameter requires the alignment parameter to be specified
+    // too
+    const fill = comptime if (parser.peek(1)) |ch|
+        switch (ch) {
+            '<', '^', '>' => parser.char().?,
+            else => ' ',
+        }
+    else
+        ' ';
+
+    // Parse the alignment parameter
+    const alignment: Alignment = comptime if (parser.peek(0)) |ch| init: {
+        switch (ch) {
+            '<', '^', '>' => _ = parser.char(),
+            else => {},
         }
+        break :init switch (ch) {
+            '<' => .Left,
+            '^' => .Center,
+            else => .Right,
+        };
+    } else .Right;
 
-        // Parse the width parameter
-        options.width = comptime init: {
-            if (parser.maybe('[')) {
-                const arg_name = parser.until(']');
+    // Parse the width parameter
+    const width = comptime parser.specifier() catch |err|
+        @compileError(@errorName(err));
 
-                if (!parser.maybe(']')) {
-                    @compileError("Expected closing ]");
-                }
+    // Skip the dot, if present
+    if (comptime parser.char()) |ch| {
+        if (ch != '.') {
+            @compileError("Expected . or }, found '" ++ [1]u8{ch} ++ "'");
+        }
+    }
 
-                const index = meta.fieldIndex(ArgsType, arg_name) orelse
-                    @compileError("No argument with name '" ++ arg_name ++ "'");
-                const arg_index = arg_state.nextArg(index);
+    // Parse the precision parameter
+    const precision = comptime parser.specifier() catch |err|
+        @compileError(@errorName(err));
 
-                break :init @field(args, fields_info[arg_index].name);
-            } else {
-                break :init parser.number();
-            }
-        };
+    if (comptime parser.char()) |ch| {
+        @compileError("Extraneous trailing character '" ++ [1]u8{ch} ++ "'");
+    }
+
+    return Placeholder{
+        .specifier_arg = cacheString(specifier_arg[0..specifier_arg.len].*),
+        .fill = fill,
+        .alignment = alignment,
+        .arg = arg,
+        .width = width,
+        .precision = precision,
+    };
+}
+
+fn cacheString(str: anytype) []const u8 {
+    return &str;
+}
+
+const Placeholder = struct {
+    specifier_arg: []const u8,
+    fill: u8,
+    alignment: Alignment,
+    arg: Specifier,
+    width: Specifier,
+    precision: Specifier,
+};
+
+const Specifier = union(enum) {
+    none,
+    number: usize,
+    named: []const u8,
+};
 
-        // Skip the dot, if present
-        if (comptime parser.char()) |ch| {
-            if (ch != '.') {
-                @compileError("Expected . or }, found '" ++ [1]u8{ch} ++ "'");
+const Parser = struct {
+    buf: []const u8,
+    pos: usize = 0,
+
+    // Returns a decimal number or null if the current character is not a
+    // digit
+    fn number(self: *@This()) ?usize {
+        var r: ?usize = null;
+
+        while (self.pos < self.buf.len) : (self.pos += 1) {
+            switch (self.buf[self.pos]) {
+                '0'...'9' => {
+                    if (r == null) r = 0;
+                    r.? *= 10;
+                    r.? += self.buf[self.pos] - '0';
+                },
+                else => break,
             }
         }
 
-        // Parse the precision parameter
-        options.precision = comptime init: {
-            if (parser.maybe('[')) {
-                const arg_name = parser.until(']');
+        return r;
+    }
 
-                if (!parser.maybe(']')) {
-                    @compileError("Expected closing ]");
-                }
+    // Returns a substring of the input starting from the current position
+    // and ending where `ch` is found or until the end if not found
+    fn until(self: *@This(), ch: u8) []const u8 {
+        const start = self.pos;
 
-                const arg_i = meta.fieldIndex(ArgsType, arg_name) orelse
-                    @compileError("No argument with name '" ++ arg_name ++ "'");
-                const arg_to_use = arg_state.nextArg(arg_i);
+        if (start >= self.buf.len)
+            return &[_]u8{};
 
-                break :init @field(args, fields_info[arg_to_use].name);
-            } else {
-                break :init parser.number();
-            }
-        };
+        while (self.pos < self.buf.len) : (self.pos += 1) {
+            if (self.buf[self.pos] == ch) break;
+        }
+        return self.buf[start..self.pos];
+    }
 
-        if (comptime parser.char()) |ch| {
-            @compileError("Extraneous trailing character '" ++ [1]u8{ch} ++ "'");
+    // Returns one character, if available
+    fn char(self: *@This()) ?u8 {
+        if (self.pos < self.buf.len) {
+            const ch = self.buf[self.pos];
+            self.pos += 1;
+            return ch;
         }
+        return null;
+    }
 
-        const arg_to_print = comptime arg_state.nextArg(opt_pos_arg);
-        try formatType(
-            @field(args, fields_info[arg_to_print].name),
-            specifier_arg,
-            options,
-            writer,
-            default_max_depth,
-        );
+    fn maybe(self: *@This(), val: u8) bool {
+        if (self.pos < self.buf.len and self.buf[self.pos] == val) {
+            self.pos += 1;
+            return true;
+        }
+        return false;
     }
 
-    if (comptime arg_state.hasUnusedArgs()) {
-        const missing_count = arg_state.args_len - @popCount(ArgSetType, arg_state.used_args);
-        switch (missing_count) {
-            0 => unreachable,
-            1 => @compileError("Unused argument in '" ++ fmt ++ "'"),
-            else => @compileError((comptime comptimePrint("{d}", .{missing_count})) ++ " unused arguments in '" ++ fmt ++ "'"),
+    // Returns a decimal number or null if the current character is not a
+    // digit
+    fn specifier(self: *@This()) !Specifier {
+        if (self.maybe('[')) {
+            const arg_name = self.until(']');
+
+            if (!self.maybe(']'))
+                return @field(anyerror, "Expected closing ]");
+
+            return Specifier{ .named = arg_name };
         }
+        if (self.number()) |i|
+            return Specifier{ .number = i };
+
+        return Specifier{ .none = {} };
+    }
+
+    // Returns the n-th next character or null if that's past the end
+    fn peek(self: *@This(), n: usize) ?u8 {
+        return if (self.pos + n < self.buf.len) self.buf[self.pos + n] else null;
+    }
+};
+
+const ArgSetType = u32;
+const max_format_args = @typeInfo(ArgSetType).Int.bits;
+
+const ArgState = struct {
+    next_arg: usize = 0,
+    used_args: ArgSetType = 0,
+    args_len: usize,
+
+    fn hasUnusedArgs(self: *@This()) bool {
+        return @popCount(ArgSetType, self.used_args) != self.args_len;
     }
-}
+
+    fn nextArg(self: *@This(), arg_index: ?usize) ?usize {
+        const next_index = arg_index orelse init: {
+            const arg = self.next_arg;
+            self.next_arg += 1;
+            break :init arg;
+        };
+
+        if (next_index >= self.args_len) {
+            return null;
+        }
+
+        // Mark this argument as used
+        self.used_args |= @as(ArgSetType, 1) << @intCast(u5, next_index);
+        return next_index;
+    }
+};
 
 pub fn formatAddress(value: anytype, options: FormatOptions, writer: anytype) @TypeOf(writer).Error!void {
     _ = options;
@@ -535,14 +568,19 @@ pub fn formatType(
                     if (actual_fmt.len == 0)
                         @compileError("cannot format array ref without a specifier (i.e. {s} or {*})");
                     if (info.child == u8) {
-                        if (comptime mem.indexOfScalar(u8, "sxXeE", actual_fmt[0]) != null) {
-                            return formatText(value, actual_fmt, options, writer);
+                        switch (actual_fmt[0]) {
+                            's', 'x', 'X', 'e', 'E' => {
+                                comptime checkTextFmt(actual_fmt);
+                                return formatBuf(value, options, writer);
+                            },
+                            else => {},
                         }
                     }
                     if (comptime std.meta.trait.isZigString(info.child)) {
                         for (value) |item, i| {
-                            if (i != 0) try formatText(", ", actual_fmt, options, writer);
-                            try formatText(item, actual_fmt, options, writer);
+                            comptime checkTextFmt(actual_fmt);
+                            if (i != 0) try formatBuf(", ", options, writer);
+                            try formatBuf(item, options, writer);
                         }
                         return;
                     }
@@ -560,8 +598,12 @@ pub fn formatType(
                     return formatType(mem.span(value), actual_fmt, options, writer, max_depth);
                 }
                 if (ptr_info.child == u8) {
-                    if (comptime mem.indexOfScalar(u8, "sxXeE", actual_fmt[0]) != null) {
-                        return formatText(mem.span(value), actual_fmt, options, writer);
+                    switch (actual_fmt[0]) {
+                        's', 'x', 'X', 'e', 'E' => {
+                            comptime checkTextFmt(actual_fmt);
+                            return formatBuf(mem.span(value), options, writer);
+                        },
+                        else => {},
                     }
                 }
                 @compileError("Unknown format string: '" ++ actual_fmt ++ "' for type '" ++ @typeName(T) ++ "'");
@@ -573,8 +615,12 @@ pub fn formatType(
                     return writer.writeAll("{ ... }");
                 }
                 if (ptr_info.child == u8) {
-                    if (comptime mem.indexOfScalar(u8, "sxXeE", actual_fmt[0]) != null) {
-                        return formatText(value, actual_fmt, options, writer);
+                    switch (actual_fmt[0]) {
+                        's', 'x', 'X', 'e', 'E' => {
+                            comptime checkTextFmt(actual_fmt);
+                            return formatBuf(value, options, writer);
+                        },
+                        else => {},
                     }
                 }
                 try writer.writeAll("{ ");
@@ -594,8 +640,12 @@ pub fn formatType(
                 return writer.writeAll("{ ... }");
             }
             if (info.child == u8) {
-                if (comptime mem.indexOfScalar(u8, "sxXeE", actual_fmt[0]) != null) {
-                    return formatText(&value, actual_fmt, options, writer);
+                switch (actual_fmt[0]) {
+                    's', 'x', 'X', 'e', 'E' => {
+                        comptime checkTextFmt(actual_fmt);
+                        return formatBuf(&value, options, writer);
+                    },
+                    else => {},
                 }
             }
             try writer.writeAll("{ ");
@@ -881,29 +931,28 @@ pub fn fmtIntSizeBin(value: u64) std.fmt.Formatter(formatSizeBin) {
     return .{ .data = value };
 }
 
+fn checkTextFmt(comptime fmt: []const u8) void {
+    if (fmt.len != 1)
+        @compileError("Unsupported format string '" ++ fmt ++ "' when formatting text");
+    switch (fmt[0]) {
+        'x' => @compileError("specifier 'x' has been deprecated, wrap your argument in std.fmt.fmtSliceHexLower instead"),
+        'X' => @compileError("specifier 'X' has been deprecated, wrap your argument in std.fmt.fmtSliceHexUpper instead"),
+        'e' => @compileError("specifier 'e' has been deprecated, wrap your argument in std.fmt.fmtSliceEscapeLower instead"),
+        'E' => @compileError("specifier 'E' has been deprecated, wrap your argument in std.fmt.fmtSliceEscapeUpper instead"),
+        'z' => @compileError("specifier 'z' has been deprecated, wrap your argument in std.zig.fmtId instead"),
+        'Z' => @compileError("specifier 'Z' has been deprecated, wrap your argument in std.zig.fmtEscapes instead"),
+        else => {},
+    }
+}
+
 pub fn formatText(
     bytes: []const u8,
     comptime fmt: []const u8,
     options: FormatOptions,
     writer: anytype,
 ) !void {
-    if (comptime std.mem.eql(u8, fmt, "s")) {
-        return formatBuf(bytes, options, writer);
-    } else if (comptime (std.mem.eql(u8, fmt, "x"))) {
-        @compileError("specifier 'x' has been deprecated, wrap your argument in std.fmt.fmtSliceHexLower instead");
-    } else if (comptime (std.mem.eql(u8, fmt, "X"))) {
-        @compileError("specifier 'X' has been deprecated, wrap your argument in std.fmt.fmtSliceHexUpper instead");
-    } else if (comptime (std.mem.eql(u8, fmt, "e"))) {
-        @compileError("specifier 'e' has been deprecated, wrap your argument in std.fmt.fmtSliceEscapeLower instead");
-    } else if (comptime (std.mem.eql(u8, fmt, "E"))) {
-        @compileError("specifier 'E' has been deprecated, wrap your argument in std.fmt.fmtSliceEscapeUpper instead");
-    } else if (comptime std.mem.eql(u8, fmt, "z")) {
-        @compileError("specifier 'z' has been deprecated, wrap your argument in std.zig.fmtId instead");
-    } else if (comptime std.mem.eql(u8, fmt, "Z")) {
-        @compileError("specifier 'Z' has been deprecated, wrap your argument in std.zig.fmtEscapes instead");
-    } else {
-        @compileError("Unsupported format string '" ++ fmt ++ "' when formatting text");
-    }
+    comptime checkTextFmt(fmt);
+    return formatBuf(bytes, options, writer);
 }
 
 pub fn formatAsciiChar(
src/stage1/analyze.cpp
@@ -5876,49 +5876,7 @@ static bool can_mutate_comptime_var_state(ZigValue *value) {
     zig_unreachable();
 }
 
-static bool return_type_is_cacheable(ZigType *return_type) {
-    switch (return_type->id) {
-        case ZigTypeIdInvalid:
-            zig_unreachable();
-        case ZigTypeIdMetaType:
-        case ZigTypeIdVoid:
-        case ZigTypeIdBool:
-        case ZigTypeIdUnreachable:
-        case ZigTypeIdInt:
-        case ZigTypeIdFloat:
-        case ZigTypeIdComptimeFloat:
-        case ZigTypeIdComptimeInt:
-        case ZigTypeIdEnumLiteral:
-        case ZigTypeIdUndefined:
-        case ZigTypeIdNull:
-        case ZigTypeIdBoundFn:
-        case ZigTypeIdFn:
-        case ZigTypeIdOpaque:
-        case ZigTypeIdErrorSet:
-        case ZigTypeIdEnum:
-        case ZigTypeIdPointer:
-        case ZigTypeIdVector:
-        case ZigTypeIdFnFrame:
-        case ZigTypeIdAnyFrame:
-            return true;
-
-        case ZigTypeIdArray:
-        case ZigTypeIdStruct:
-        case ZigTypeIdUnion:
-            return false;
-
-        case ZigTypeIdOptional:
-            return return_type_is_cacheable(return_type->data.maybe.child_type);
-
-        case ZigTypeIdErrorUnion:
-            return return_type_is_cacheable(return_type->data.error_union.payload_type);
-    }
-    zig_unreachable();
-}
-
 bool fn_eval_cacheable(Scope *scope, ZigType *return_type) {
-    if (!return_type_is_cacheable(return_type))
-        return false;
     while (scope) {
         if (scope->id == ScopeIdVarDecl) {
             ScopeVarDecl *var_scope = (ScopeVarDecl *)scope;