Commit b2e970d157

Dmitry Matveyev <greenfork.lists@yandex.com>
2021-08-20 13:52:48
std.json: Add support for recursive objects to std.json.parse (#9307)
* Add support for recursive objects to std.json.parse * Remove previously defined error set * Try with function which returns an error set * Don't analyze already inferred types * Add comptime to inferred_type parameter * Make ParseInternalError accept only a single argument * Add public `ParseError` for `parse` function * Use error.Foo syntax for errors instead of a named error set * Better formatting * Update to latest code changes
1 parent cfb2827
Changed files (3)
lib/std/fmt/parse_float.zig
@@ -349,7 +349,9 @@ fn caseInEql(a: []const u8, b: []const u8) bool {
     return true;
 }
 
-pub fn parseFloat(comptime T: type, s: []const u8) !T {
+/// Error set returned by `parseFloat`.
+pub const ParseFloatError = error{InvalidCharacter};
+
+pub fn parseFloat(comptime T: type, s: []const u8) ParseFloatError!T {
     if (s.len == 0 or (s.len == 1 and (s[0] == '+' or s[0] == '-'))) {
         return error.InvalidCharacter;
     }
lib/std/fmt.zig
@@ -1757,6 +1757,7 @@ test "parseUnsigned" {
 }
 
 pub const parseFloat = @import("fmt/parse_float.zig").parseFloat;
+pub const ParseFloatError = @import("fmt/parse_float.zig").ParseFloatError;
 pub const parseHexFloat = @import("fmt/parse_hex_float.zig").parseHexFloat;
 
 test {
lib/std/json.zig
@@ -1468,7 +1468,9 @@ pub const ParseOptions = struct {
     allow_trailing_data: bool = false,
 };
 
-fn skipValue(tokens: *TokenStream) !void {
+// Error set returned by `skipValue`: its own depth error plus anything the token stream can report.
+const SkipValueError = error{UnexpectedJsonDepth} || TokenStream.Error;
+
+fn skipValue(tokens: *TokenStream) SkipValueError!void {
     const original_depth = tokens.stackUsed();
 
     // Return an error if no value is found
@@ -1530,7 +1532,84 @@ test "skipValue" {
     }
 }
 
-fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options: ParseOptions) !T {
+/// Returns the error set used as the error half of `parseInternal`'s return type for `T`.
+fn ParseInternalError(comptime T: type) type {
+    // Begin with no types marked as already inferred; `ParseInternalErrorImpl`
+    // consults this list to avoid infinite recursion on recursive type definitions.
+    const no_inferred_types = [_]type{};
+    return ParseInternalErrorImpl(T, &no_inferred_types);
+}
+
+/// Computes the error set for parsing into `T` by walking `T`'s type info and
+/// merging the error sets of every nested type.
+///
+/// `inferred_types` lists the types whose error sets are already being
+/// computed further up the call chain. Meeting one of them again means the
+/// type definition is (mutually) recursive; `error{}` is returned for it so the
+/// walk terminates, while the outer call for that type still contributes its
+/// errors to the final set.
+///
+/// Untagged unions and pointers that are neither single-item nor slices are
+/// rejected with a compile error.
+fn ParseInternalErrorImpl(comptime T: type, comptime inferred_types: []const type) type {
+    for (inferred_types) |ty| {
+        if (T == ty) return error{};
+    }
+
+    // Every nested lookup below must see `T` as already in progress.
+    const visited = inferred_types ++ [_]type{T};
+
+    switch (@typeInfo(T)) {
+        .Bool => return error{UnexpectedToken},
+        .Float, .ComptimeFloat => return error{UnexpectedToken} || std.fmt.ParseFloatError,
+        .Int, .ComptimeInt => {
+            return error{ UnexpectedToken, InvalidNumber, Overflow } ||
+                std.fmt.ParseIntError || std.fmt.ParseFloatError;
+        },
+        .Optional => |optionalInfo| {
+            return ParseInternalErrorImpl(optionalInfo.child, visited);
+        },
+        .Enum => return error{ UnexpectedToken, InvalidEnumTag } || std.fmt.ParseIntError ||
+            std.meta.IntToEnumError,
+        .Union => |unionInfo| {
+            if (unionInfo.tag_type) |_| {
+                // Any member of the union may be attempted, so merge all of them.
+                var errors = error{NoUnionMembersMatched};
+                for (unionInfo.fields) |u_field| {
+                    errors = errors || ParseInternalErrorImpl(u_field.field_type, visited);
+                }
+                return errors;
+            } else {
+                @compileError("Unable to parse into untagged union '" ++ @typeName(T) ++ "'");
+            }
+        },
+        .Struct => |structInfo| {
+            var errors = error{
+                DuplicateJSONField,
+                UnexpectedEndOfJson,
+                UnexpectedToken,
+                UnexpectedValue,
+                UnknownField,
+                MissingField,
+            } || SkipValueError || TokenStream.Error;
+            for (structInfo.fields) |field| {
+                errors = errors || ParseInternalErrorImpl(field.field_type, visited);
+            }
+            return errors;
+        },
+        .Array => |arrayInfo| {
+            return error{ UnexpectedEndOfJson, UnexpectedToken } || TokenStream.Error ||
+                UnescapeValidStringError ||
+                ParseInternalErrorImpl(arrayInfo.child, visited);
+        },
+        .Pointer => |ptrInfo| {
+            // Errors shared by both supported pointer sizes.
+            const errors = error{AllocatorRequired} || std.mem.Allocator.Error;
+            switch (ptrInfo.size) {
+                .One => {
+                    return errors || ParseInternalErrorImpl(ptrInfo.child, visited);
+                },
+                .Slice => {
+                    return errors || error{ UnexpectedEndOfJson, UnexpectedToken } ||
+                        ParseInternalErrorImpl(ptrInfo.child, visited) ||
+                        UnescapeValidStringError || TokenStream.Error;
+                },
+                else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"),
+            }
+        },
+        // Any other kind of type contributes no errors to the set.
+        else => return error{},
+    }
+    unreachable;
+}
+
+fn parseInternal(
+    comptime T: type,
+    token: Token,
+    tokens: *TokenStream,
+    options: ParseOptions,
+) ParseInternalError(T)!T {
     switch (@typeInfo(T)) {
         .Bool => {
             return switch (token) {
@@ -1794,7 +1873,11 @@ fn parseInternal(comptime T: type, token: Token, tokens: *TokenStream, options:
     unreachable;
 }
 
-pub fn parse(comptime T: type, tokens: *TokenStream, options: ParseOptions) !T {
+/// Error set returned by `parse` for `T`: everything in `ParseInternalError(T)`
+/// plus `error.UnexpectedEndOfJson` and `TokenStream.Error`.
+pub fn ParseError(comptime T: type) type {
+    const top_level_errors = error{UnexpectedEndOfJson} || TokenStream.Error;
+    return ParseInternalError(T) || top_level_errors;
+}
+
+pub fn parse(comptime T: type, tokens: *TokenStream, options: ParseOptions) ParseError(T)!T {
     const token = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
     const r = try parseInternal(T, token, tokens, options);
     errdefer parseFree(T, r, options);
@@ -2181,6 +2264,45 @@ test "parse into struct ignoring unknown fields" {
     try testing.expectEqualSlices(u8, "zig", r.language);
 }
 
+// Fixture for the test below: a tagged union that refers to itself through a slice.
+const ParseIntoRecursiveUnionDefinitionValue = union(enum) {
+    integer: i64,
+    array: []const ParseIntoRecursiveUnionDefinitionValue,
+};
+
+test "parse into recursive union definition" {
+    const Wrapper = struct {
+        values: ParseIntoRecursiveUnionDefinitionValue,
+    };
+    const options = ParseOptions{ .allocator = testing.allocator };
+
+    var stream = std.json.TokenStream.init("{\"values\":[58]}");
+    const parsed = try parse(Wrapper, &stream, options);
+    defer parseFree(Wrapper, parsed, options);
+
+    // The JSON array selects the `array` member; its only element is an integer.
+    const first = parsed.values.array[0];
+    try testing.expectEqual(@as(i64, 58), first.integer);
+}
+
+// Fixtures for the test below: two tagged unions that refer to each other
+// through slices, forming a mutually recursive pair.
+const ParseIntoDoubleRecursiveUnionValueFirst = union(enum) {
+    integer: i64,
+    array: []const ParseIntoDoubleRecursiveUnionValueSecond,
+};
+
+const ParseIntoDoubleRecursiveUnionValueSecond = union(enum) {
+    boolean: bool,
+    array: []const ParseIntoDoubleRecursiveUnionValueFirst,
+};
+
+test "parse into double recursive union definition" {
+    const Wrapper = struct {
+        values: ParseIntoDoubleRecursiveUnionValueFirst,
+    };
+    const options = ParseOptions{ .allocator = testing.allocator };
+
+    var stream = std.json.TokenStream.init("{\"values\":[[58]]}");
+    const parsed = try parse(Wrapper, &stream, options);
+    defer parseFree(Wrapper, parsed, options);
+
+    // Outer array -> `First.array`, inner array -> `Second.array`, element -> `First.integer`.
+    const outer = parsed.values.array[0];
+    const inner = outer.array[0];
+    try testing.expectEqual(@as(i64, 58), inner.integer);
+}
+
 /// A non-stream JSON parser which constructs a tree of Value's.
 pub const Parser = struct {
     allocator: *Allocator,
@@ -2418,10 +2540,12 @@ pub const Parser = struct {
     }
 };
 
+/// Error set returned by `unescapeValidString`.
+pub const UnescapeValidStringError = error{InvalidUnicodeHexSymbol};
+
 /// Unescape a JSON string
 /// Only to be used on strings already validated by the parser
 /// (note the unreachable statements and lack of bounds checking)
-pub fn unescapeValidString(output: []u8, input: []const u8) !void {
+pub fn unescapeValidString(output: []u8, input: []const u8) UnescapeValidStringError!void {
     var inIndex: usize = 0;
     var outIndex: usize = 0;