Commit 1b728e1834

Marc Tiehuis <marc@tiehu.is>
2024-06-14 03:09:55
std.float.parseFloat: fix large hex-float parsing
There were two primary issues at play here: 1. The hex float prefix was not handled correctly when the stream was reset for the fallback parsing path, which occured when the mantissa was longer max mantissa digits. 2. The implied exponent was not adjusted for hex-floats in this branch. Additionally, some of the float parsing routines have been condensed, making use of comptime. closes #20275
1 parent ffb1a6d
Changed files (4)
lib/std/fmt/parse_float/decimal.zig
@@ -241,18 +241,18 @@ pub fn Decimal(comptime T: type) type {
             var d = Self.new();
             var stream = FloatStream.init(s);
 
-            stream.skipChars2('0', '_');
+            stream.skipChars("0_");
             while (stream.scanDigit(10)) |digit| {
                 d.tryAddDigit(digit);
             }
 
-            if (stream.firstIs('.')) {
+            if (stream.firstIs(".")) {
                 stream.advance(1);
                 const marker = stream.offsetTrue();
 
                 // Skip leading zeroes
                 if (d.num_digits == 0) {
-                    stream.skipChars('0');
+                    stream.skipChars("0");
                 }
 
                 while (stream.hasLen(8) and d.num_digits + 8 < max_digits) {
@@ -292,13 +292,13 @@ pub fn Decimal(comptime T: type) type {
                     d.num_digits = max_digits;
                 }
             }
-            if (stream.firstIsLower('e')) {
+            if (stream.firstIsLower("e")) {
                 stream.advance(1);
                 var neg_exp = false;
-                if (stream.firstIs('-')) {
+                if (stream.firstIs("-")) {
                     neg_exp = true;
                     stream.advance(1);
-                } else if (stream.firstIs('+')) {
+                } else if (stream.firstIs("+")) {
                     stream.advance(1);
                 }
                 var exp_num: i32 = 0;
lib/std/fmt/parse_float/FloatStream.zig
@@ -48,30 +48,16 @@ pub fn isEmpty(self: FloatStream) bool {
     return !self.hasLen(1);
 }
 
-pub fn firstIs(self: FloatStream, c: u8) bool {
+pub fn firstIs(self: FloatStream, comptime cs: []const u8) bool {
     if (self.first()) |ok| {
-        return ok == c;
+        inline for (cs) |c| if (ok == c) return true;
     }
     return false;
 }
 
-pub fn firstIsLower(self: FloatStream, c: u8) bool {
+pub fn firstIsLower(self: FloatStream, comptime cs: []const u8) bool {
     if (self.first()) |ok| {
-        return ok | 0x20 == c;
-    }
-    return false;
-}
-
-pub fn firstIs2(self: FloatStream, c1: u8, c2: u8) bool {
-    if (self.first()) |ok| {
-        return ok == c1 or ok == c2;
-    }
-    return false;
-}
-
-pub fn firstIs3(self: FloatStream, c1: u8, c2: u8, c3: u8) bool {
-    if (self.first()) |ok| {
-        return ok == c1 or ok == c2 or ok == c3;
+        inline for (cs) |c| if (ok | 0x20 == c) return true;
     }
     return false;
 }
@@ -89,12 +75,8 @@ pub fn advance(self: *FloatStream, n: usize) void {
     self.offset += n;
 }
 
-pub fn skipChars(self: *FloatStream, c: u8) void {
-    while (self.firstIs(c)) : (self.advance(1)) {}
-}
-
-pub fn skipChars2(self: *FloatStream, c1: u8, c2: u8) void {
-    while (self.firstIs2(c1, c2)) : (self.advance(1)) {}
+pub fn skipChars(self: *FloatStream, comptime cs: []const u8) void {
+    while (self.firstIs(cs)) : (self.advance(1)) {}
 }
 
 pub fn readU64Unchecked(self: FloatStream) u64 {
lib/std/fmt/parse_float/parse.zig
@@ -100,6 +100,7 @@ const ParseInfo = struct {
 };
 
 fn parsePartialNumberBase(comptime T: type, stream: *FloatStream, negative: bool, n: *usize, comptime info: ParseInfo) ?Number(T) {
+    std.debug.assert(info.base == 10 or info.base == 16);
     const MantissaT = common.mantissaType(T);
 
     // parse initial digits before dot
@@ -107,12 +108,10 @@ fn parsePartialNumberBase(comptime T: type, stream: *FloatStream, negative: bool
     tryParseDigits(MantissaT, stream, &mantissa, info.base);
     const int_end = stream.offsetTrue();
     var n_digits = @as(isize, @intCast(stream.offsetTrue()));
-    // the base being 16 implies a 0x prefix, which shouldn't be included in the digit count
-    if (info.base == 16) n_digits -= 2;
 
     // handle dot with the following digits
     var exponent: i64 = 0;
-    if (stream.firstIs('.')) {
+    if (stream.firstIs(".")) {
         stream.advance(1);
         const marker = stream.offsetTrue();
         tryParseDigits(MantissaT, stream, &mantissa, info.base);
@@ -132,14 +131,14 @@ fn parsePartialNumberBase(comptime T: type, stream: *FloatStream, negative: bool
 
     // handle scientific format
     var exp_number: i64 = 0;
-    if (stream.firstIsLower(info.exp_char_lower)) {
+    if (stream.firstIsLower(&.{info.exp_char_lower})) {
         stream.advance(1);
         exp_number = parseScientific(stream) orelse return null;
         exponent += exp_number;
     }
 
     const len = stream.offset; // length must be complete parsed length
-    n.* = len;
+    n.* += len;
 
     if (stream.underscore_count > 0 and !validUnderscores(stream.slice, info.base)) {
         return null;
@@ -159,7 +158,7 @@ fn parsePartialNumberBase(comptime T: type, stream: *FloatStream, negative: bool
     n_digits -= info.max_mantissa_digits;
     var many_digits = false;
     stream.reset(); // re-parse from beginning
-    while (stream.firstIs3('0', '.', '_')) {
+    while (stream.firstIs("0._")) {
         // '0' = '.' + 2
         const next = stream.firstUnchecked();
         if (next != '_') {
@@ -193,6 +192,9 @@ fn parsePartialNumberBase(comptime T: type, stream: *FloatStream, negative: bool
                 break :blk @as(i64, @intCast(marker)) - @as(i64, @intCast(stream.offsetTrue()));
             }
         };
+        if (info.base == 16) {
+            exponent *= 4;
+        }
         // add back the explicit part
         exponent += exp_number;
     }
@@ -212,17 +214,19 @@ fn parsePartialNumberBase(comptime T: type, stream: *FloatStream, negative: bool
 /// significant digits and the decimal exponent.
 fn parsePartialNumber(comptime T: type, s: []const u8, negative: bool, n: *usize) ?Number(T) {
     std.debug.assert(s.len != 0);
-    var stream = FloatStream.init(s);
     const MantissaT = common.mantissaType(T);
+    n.* = 0;
 
-    if (stream.hasLen(2) and stream.atUnchecked(0) == '0' and std.ascii.toLower(stream.atUnchecked(1)) == 'x') {
-        stream.advance(2);
+    if (s.len >= 2 and s[0] == '0' and std.ascii.toLower(s[1]) == 'x') {
+        var stream = FloatStream.init(s[2..]);
+        n.* += 2;
         return parsePartialNumberBase(T, &stream, negative, n, .{
             .base = 16,
             .max_mantissa_digits = if (MantissaT == u64) 16 else 32,
             .exp_char_lower = 'p',
         });
     } else {
+        var stream = FloatStream.init(s);
         return parsePartialNumberBase(T, &stream, negative, n, .{
             .base = 10,
             .max_mantissa_digits = if (MantissaT == u64) 19 else 38,
lib/std/fmt/parse_float.zig
@@ -1,4 +1,4 @@
-const std = @import("../std.zig");
+const std = @import("std");
 const math = std.math;
 const testing = std.testing;
 const expect = testing.expect;
@@ -151,6 +151,12 @@ test "#11169" {
     try expectEqual(try parseFloat(f128, "9007199254740993.0"), 9007199254740993.0);
 }
 
+test "many_digits hex" {
+    const a: f32 = try std.fmt.parseFloat(f32, "0xffffffffffffffff.0p0");
+    const b: f32 = @floatCast(try std.fmt.parseFloat(f128, "0xffffffffffffffff.0p0"));
+    try std.testing.expectEqual(a, b);
+}
+
 test "hex.special" {
     try testing.expect(math.isNan(try parseFloat(f32, "nAn")));
     try testing.expect(math.isPositiveInf(try parseFloat(f32, "iNf")));