Commit 131bfe2f74

Garrett <138411610+garrettlennoxbeck@users.noreply.github.com>
2023-07-09 04:49:31
std.json: expose innerParse and add .allocate option (#16312)
1 parent 89396ff
Changed files (3)
lib/std/json/static.zig
@@ -34,6 +34,14 @@ pub const ParseOptions = struct {
     /// The default for `parseFromTokenSource` with a `*std.json.Reader` is `std.json.default_max_value_len`.
     /// Ignored for `parseFromValue` and `parseFromValueLeaky`.
     max_value_len: ?usize = null,
+
+    /// This determines whether strings should always be copied,
+    /// or whether a reference to the given buffer should be preferred if possible.
+    /// The default for `parseFromSlice` or `parseFromTokenSource` with a `*std.json.Scanner` input
+    /// is `.alloc_if_needed`.
+    /// The default with a `*std.json.Reader` input is `.alloc_always`.
+    /// Ignored for `parseFromValue` and `parseFromValueLeaky`.
+    allocate: ?AllocWhen = null,
 };
 
 pub fn Parsed(comptime T: type) type {
@@ -113,7 +121,6 @@ pub fn parseFromTokenSourceLeaky(
     if (@TypeOf(scanner_or_reader.*) == Scanner) {
         assert(scanner_or_reader.is_end_of_input);
     }
-
     var resolved_options = options;
     if (resolved_options.max_value_len == null) {
         if (@TypeOf(scanner_or_reader.*) == Scanner) {
@@ -122,8 +129,15 @@ pub fn parseFromTokenSourceLeaky(
             resolved_options.max_value_len = default_max_value_len;
         }
     }
+    if (resolved_options.allocate == null) {
+        if (@TypeOf(scanner_or_reader.*) == Scanner) {
+            resolved_options.allocate = .alloc_if_needed;
+        } else {
+            resolved_options.allocate = .alloc_always;
+        }
+    }
 
-    const value = try internalParse(T, allocator, scanner_or_reader, resolved_options);
+    const value = try innerParse(T, allocator, scanner_or_reader, resolved_options);
 
     assert(.end_of_document == try scanner_or_reader.next());
 
@@ -181,7 +195,14 @@ pub const ParseFromValueError = std.fmt.ParseIntError || std.fmt.ParseFloatError
     LengthMismatch,
 };
 
-fn internalParse(
+/// This is an internal function called recursively
+/// during the implementation of `parseFromTokenSourceLeaky` and similar.
+/// It is exposed primarily to enable custom `jsonParse()` methods to call back into the `parseFrom*` system,
+/// such as if you're implementing a custom container of type `T`;
+/// you can call `innerParse(T, ...)` for each of the container's items.
+/// Note that `null` fields are not allowed on the `options` when calling this function.
+/// (The `options` you get in your `jsonParse` method has no `null` fields.)
+pub fn innerParse(
     comptime T: type,
     allocator: Allocator,
     source: anytype,
@@ -220,7 +241,7 @@ fn internalParse(
                     return null;
                 },
                 else => {
-                    return try internalParse(optionalInfo.child, allocator, source, options);
+                    return try innerParse(optionalInfo.child, allocator, source, options);
                 },
             }
         },
@@ -250,16 +271,17 @@ fn internalParse(
             var name_token: ?Token = try source.nextAllocMax(allocator, .alloc_if_needed, options.max_value_len.?);
             const field_name = switch (name_token.?) {
                 inline .string, .allocated_string => |slice| slice,
-                else => return error.UnexpectedToken,
+                else => {
+                    return error.UnexpectedToken;
+                },
             };
 
             inline for (unionInfo.fields) |u_field| {
                 if (std.mem.eql(u8, u_field.name, field_name)) {
                     // Free the name token now in case we're using an allocator that optimizes freeing the last allocated object.
-                    // (Recursing into internalParse() might trigger more allocations.)
+                    // (Recursing into innerParse() might trigger more allocations.)
                     freeAllocated(allocator, name_token.?);
                     name_token = null;
-
                     if (u_field.type == void) {
                         // void isn't really a json type, but we can support void payload union tags with {} as a value.
                         if (.object_begin != try source.next()) return error.UnexpectedToken;
@@ -267,7 +289,7 @@ fn internalParse(
                         result = @unionInit(T, u_field.name, {});
                     } else {
                         // Recurse.
-                        result = @unionInit(T, u_field.name, try internalParse(u_field.type, allocator, source, options));
+                        result = @unionInit(T, u_field.name, try innerParse(u_field.type, allocator, source, options));
                     }
                     break;
                 }
@@ -287,7 +309,7 @@ fn internalParse(
 
                 var r: T = undefined;
                 inline for (0..structInfo.fields.len) |i| {
-                    r[i] = try internalParse(structInfo.fields[i].type, allocator, source, options);
+                    r[i] = try innerParse(structInfo.fields[i].type, allocator, source, options);
                 }
 
                 if (.array_end != try source.next()) return error.UnexpectedToken;
@@ -307,32 +329,35 @@ fn internalParse(
             while (true) {
                 var name_token: ?Token = try source.nextAllocMax(allocator, .alloc_if_needed, options.max_value_len.?);
                 const field_name = switch (name_token.?) {
-                    .object_end => break, // No more fields.
                     inline .string, .allocated_string => |slice| slice,
-                    else => return error.UnexpectedToken,
+                    .object_end => { // No more fields.
+                        break;
+                    },
+                    else => {
+                        return error.UnexpectedToken;
+                    },
                 };
 
                 inline for (structInfo.fields, 0..) |field, i| {
                     if (field.is_comptime) @compileError("comptime fields are not supported: " ++ @typeName(T) ++ "." ++ field.name);
                     if (std.mem.eql(u8, field.name, field_name)) {
                         // Free the name token now in case we're using an allocator that optimizes freeing the last allocated object.
-                        // (Recursing into internalParse() might trigger more allocations.)
+                        // (Recursing into innerParse() might trigger more allocations.)
                         freeAllocated(allocator, name_token.?);
                         name_token = null;
-
                         if (fields_seen[i]) {
                             switch (options.duplicate_field_behavior) {
                                 .use_first => {
                                     // Parse and ignore the redundant value.
                                     // We don't want to skip the value, because we want type checking.
-                                    _ = try internalParse(field.type, allocator, source, options);
+                                    _ = try innerParse(field.type, allocator, source, options);
                                     break;
                                 },
                                 .@"error" => return error.DuplicateField,
                                 .use_last => {},
                             }
                         }
-                        @field(r, field.name) = try internalParse(field.type, allocator, source, options);
+                        @field(r, field.name) = try innerParse(field.type, allocator, source, options);
                         fields_seen[i] = true;
                         break;
                     }
@@ -418,7 +443,7 @@ fn internalParse(
             switch (ptrInfo.size) {
                 .One => {
                     const r: *ptrInfo.child = try allocator.create(ptrInfo.child);
-                    r.* = try internalParse(ptrInfo.child, allocator, source, options);
+                    r.* = try innerParse(ptrInfo.child, allocator, source, options);
                     return r;
                 },
                 .Slice => {
@@ -438,7 +463,7 @@ fn internalParse(
                                 }
 
                                 try arraylist.ensureUnusedCapacity(1);
-                                arraylist.appendAssumeCapacity(try internalParse(ptrInfo.child, allocator, source, options));
+                                arraylist.appendAssumeCapacity(try innerParse(ptrInfo.child, allocator, source, options));
                             }
 
                             if (ptrInfo.sentinel) |some| {
@@ -459,7 +484,7 @@ fn internalParse(
                                 return try value_list.toOwnedSliceSentinel(@as(*const u8, @ptrCast(sentinel_ptr)).*);
                             }
                             if (ptrInfo.is_const) {
-                                switch (try source.nextAllocMax(allocator, .alloc_if_needed, options.max_value_len.?)) {
+                                switch (try source.nextAllocMax(allocator, options.allocate.?, options.max_value_len.?)) {
                                     inline .string, .allocated_string => |slice| return slice,
                                     else => unreachable,
                                 }
@@ -495,7 +520,7 @@ fn internalParseArray(
     var r: T = undefined;
     var i: usize = 0;
     while (i < len) : (i += 1) {
-        r[i] = try internalParse(Child, allocator, source, options);
+        r[i] = try innerParse(Child, allocator, source, options);
     }
 
     if (.array_end != try source.next()) return error.UnexpectedToken;
lib/std/json/static_test.zig
@@ -7,6 +7,7 @@ const parseFromSlice = @import("./static.zig").parseFromSlice;
 const parseFromSliceLeaky = @import("./static.zig").parseFromSliceLeaky;
 const parseFromTokenSource = @import("./static.zig").parseFromTokenSource;
 const parseFromTokenSourceLeaky = @import("./static.zig").parseFromTokenSourceLeaky;
+const innerParse = @import("./static.zig").innerParse;
 const parseFromValue = @import("./static.zig").parseFromValue;
 const parseFromValueLeaky = @import("./static.zig").parseFromValueLeaky;
 const ParseOptions = @import("./static.zig").ParseOptions;
@@ -801,3 +802,102 @@ test "parse into vector" {
     try testing.expectApproxEqAbs(@as(f32, 2.5), parsed.value.vec_f32[1], 0.0000001);
     try testing.expectEqual(@Vector(4, i32){ 4, 5, 6, 7 }, parsed.value.vec_i32);
 }
+
+fn assertKey(
+    allocator: Allocator,
+    test_string: []const u8,
+    scanner: anytype,
+) !void {
+    const token_outer = try scanner.nextAlloc(allocator, .alloc_always);
+    switch (token_outer) {
+        .allocated_string => |string| {
+            try testing.expectEqualSlices(u8, string, test_string);
+            allocator.free(string);
+        },
+        else => return error.UnexpectedToken,
+    }
+}
+test "json parse partial" {
+    const Inner = struct {
+        num: u32,
+        yes: bool,
+    };
+    var str =
+        \\{
+        \\  "outer": {
+        \\    "key1": {
+        \\      "num": 75,
+        \\      "yes": true
+        \\    },
+        \\    "key2": {
+        \\      "num": 95,
+        \\      "yes": false
+        \\    }
+        \\  }
+        \\}
+    ;
+    var allocator = testing.allocator;
+    var scanner = JsonScanner.initCompleteInput(allocator, str);
+    defer scanner.deinit();
+
+    var arena = ArenaAllocator.init(allocator);
+    defer arena.deinit();
+
+    // Peel off the outer object
+    try testing.expectEqual(try scanner.next(), .object_begin);
+    try assertKey(allocator, "outer", &scanner);
+    try testing.expectEqual(try scanner.next(), .object_begin);
+    try assertKey(allocator, "key1", &scanner);
+
+    // Parse the inner object to an Inner struct
+    const inner_token = try innerParse(
+        Inner,
+        arena.allocator(),
+        &scanner,
+        .{ .max_value_len = scanner.input.len },
+    );
+    try testing.expectEqual(inner_token.num, 75);
+    try testing.expectEqual(inner_token.yes, true);
+
+    // Get the next key
+    try assertKey(allocator, "key2", &scanner);
+    const inner_token_2 = try innerParse(
+        Inner,
+        arena.allocator(),
+        &scanner,
+        .{ .max_value_len = scanner.input.len },
+    );
+    try testing.expectEqual(inner_token_2.num, 95);
+    try testing.expectEqual(inner_token_2.yes, false);
+    try testing.expectEqual(try scanner.next(), .object_end);
+}
+
+test "json parse allocate when streaming" {
+    const T = struct {
+        not_const: []u8,
+        is_const: []const u8,
+    };
+    var str =
+        \\{
+        \\  "not_const": "non const string",
+        \\  "is_const": "const string"
+        \\}
+    ;
+    var allocator = testing.allocator;
+    var arena = ArenaAllocator.init(allocator);
+    defer arena.deinit();
+
+    var stream = std.io.fixedBufferStream(str);
+    var json_reader = jsonReader(std.testing.allocator, stream.reader());
+
+    const parsed = parseFromTokenSourceLeaky(T, arena.allocator(), &json_reader, .{}) catch |err| {
+        json_reader.deinit();
+        return err;
+    };
+    // Deinit our reader to invalidate its buffer
+    json_reader.deinit();
+
+    // If either of these was invalidated, it would be full of '0xAA'
+    try testing.expectEqualSlices(u8, parsed.not_const, "non const string");
+    try testing.expectEqualSlices(u8, parsed.is_const, "const string");
+}
lib/std/json.zig
@@ -88,6 +88,7 @@ pub const parseFromSlice = @import("json/static.zig").parseFromSlice;
 pub const parseFromSliceLeaky = @import("json/static.zig").parseFromSliceLeaky;
 pub const parseFromTokenSource = @import("json/static.zig").parseFromTokenSource;
 pub const parseFromTokenSourceLeaky = @import("json/static.zig").parseFromTokenSourceLeaky;
+pub const innerParse = @import("json/static.zig").innerParse;
 pub const parseFromValue = @import("json/static.zig").parseFromValue;
 pub const parseFromValueLeaky = @import("json/static.zig").parseFromValueLeaky;
 pub const ParseError = @import("json/static.zig").ParseError;