master
  1//===-- High Precision Decimal ----------------------------------*- C++ -*-===//
  2//
  3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4// See https://llvm.org/LICENSE.txt for license information.
  5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6//
  7//===----------------------------------------------------------------------===//
  8
  9// -----------------------------------------------------------------------------
 10//                               **** WARNING ****
 11// This file is shared with libc++. You should also be careful when adding
 12// dependencies to this file, since it needs to build for all libc++ targets.
 13// -----------------------------------------------------------------------------
 14
 15#ifndef LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
 16#define LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
 17
 18#include "src/__support/CPP/limits.h"
 19#include "src/__support/ctype_utils.h"
 20#include "src/__support/macros/config.h"
 21#include "src/__support/str_to_integer.h"
 22#include <stdint.h>
 23
 24namespace LIBC_NAMESPACE_DECL {
 25namespace internal {
 26
 27struct LShiftTableEntry {
 28  uint32_t new_digits;
 29  char const *power_of_five;
 30};
 31
 32// -----------------------------------------------------------------------------
 33//                               **** WARNING ****
 34// This interface is shared with libc++, if you change this interface you need
 35// to update it in both libc and libc++.
 36// -----------------------------------------------------------------------------
 37// This is used in both this file and in the main str_to_float.h.
 38// TODO: Figure out where to put this.
 39enum class RoundDirection { Up, Down, Nearest };
 40
 41// This is based on the HPD data structure described as part of the Simple
 42// Decimal Conversion algorithm by Nigel Tao, described at this link:
 43// https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html
 44class HighPrecisionDecimal {
 45
 46  // This precomputed table speeds up left shifts by having the number of new
 47  // digits that will be added by multiplying 5^i by 2^i. If the number is less
 48  // than 5^i then it will add one fewer digit. There are only 60 entries since
 49  // that's the max shift amount.
 50  // This table was generated by the script at
 51  // libc/utils/mathtools/GenerateHPDConstants.py
 52  static constexpr LShiftTableEntry LEFT_SHIFT_DIGIT_TABLE[] = {
 53      {0, ""},
 54      {1, "5"},
 55      {1, "25"},
 56      {1, "125"},
 57      {2, "625"},
 58      {2, "3125"},
 59      {2, "15625"},
 60      {3, "78125"},
 61      {3, "390625"},
 62      {3, "1953125"},
 63      {4, "9765625"},
 64      {4, "48828125"},
 65      {4, "244140625"},
 66      {4, "1220703125"},
 67      {5, "6103515625"},
 68      {5, "30517578125"},
 69      {5, "152587890625"},
 70      {6, "762939453125"},
 71      {6, "3814697265625"},
 72      {6, "19073486328125"},
 73      {7, "95367431640625"},
 74      {7, "476837158203125"},
 75      {7, "2384185791015625"},
 76      {7, "11920928955078125"},
 77      {8, "59604644775390625"},
 78      {8, "298023223876953125"},
 79      {8, "1490116119384765625"},
 80      {9, "7450580596923828125"},
 81      {9, "37252902984619140625"},
 82      {9, "186264514923095703125"},
 83      {10, "931322574615478515625"},
 84      {10, "4656612873077392578125"},
 85      {10, "23283064365386962890625"},
 86      {10, "116415321826934814453125"},
 87      {11, "582076609134674072265625"},
 88      {11, "2910383045673370361328125"},
 89      {11, "14551915228366851806640625"},
 90      {12, "72759576141834259033203125"},
 91      {12, "363797880709171295166015625"},
 92      {12, "1818989403545856475830078125"},
 93      {13, "9094947017729282379150390625"},
 94      {13, "45474735088646411895751953125"},
 95      {13, "227373675443232059478759765625"},
 96      {13, "1136868377216160297393798828125"},
 97      {14, "5684341886080801486968994140625"},
 98      {14, "28421709430404007434844970703125"},
 99      {14, "142108547152020037174224853515625"},
100      {15, "710542735760100185871124267578125"},
101      {15, "3552713678800500929355621337890625"},
102      {15, "17763568394002504646778106689453125"},
103      {16, "88817841970012523233890533447265625"},
104      {16, "444089209850062616169452667236328125"},
105      {16, "2220446049250313080847263336181640625"},
106      {16, "11102230246251565404236316680908203125"},
107      {17, "55511151231257827021181583404541015625"},
108      {17, "277555756156289135105907917022705078125"},
109      {17, "1387778780781445675529539585113525390625"},
110      {18, "6938893903907228377647697925567626953125"},
111      {18, "34694469519536141888238489627838134765625"},
112      {18, "173472347597680709441192448139190673828125"},
113      {19, "867361737988403547205962240695953369140625"},
114  };
115
116  // The maximum amount we can shift is the number of bits used in the
117  // accumulator, minus the number of bits needed to represent the base (in this
118  // case 4).
119  static constexpr uint32_t MAX_SHIFT_AMOUNT = sizeof(uint64_t) - 4;
120
121  // 800 is an arbitrary number of digits, but should be
122  // large enough for any practical number.
123  static constexpr uint32_t MAX_NUM_DIGITS = 800;
124
125  uint32_t num_digits = 0;
126  int32_t decimal_point = 0;
127  bool truncated = false;
128  uint8_t digits[MAX_NUM_DIGITS];
129
130private:
131  LIBC_INLINE bool should_round_up(int32_t round_to_digit,
132                                   RoundDirection round) {
133    if (round_to_digit < 0 ||
134        static_cast<uint32_t>(round_to_digit) >= this->num_digits) {
135      return false;
136    }
137
138    // The above condition handles all cases where all of the trailing digits
139    // are zero. In that case, if the rounding mode is up, then this number
140    // should be rounded up. Similarly, if the rounding mode is down, then it
141    // should always round down.
142    if (round == RoundDirection::Up) {
143      return true;
144    } else if (round == RoundDirection::Down) {
145      return false;
146    }
147    // Else round to nearest.
148
149    // If we're right in the middle and there are no extra digits
150    if (this->digits[round_to_digit] == 5 &&
151        static_cast<uint32_t>(round_to_digit + 1) == this->num_digits) {
152
153      // Round up if we've truncated (since that means the result is slightly
154      // higher than what's represented.)
155      if (this->truncated) {
156        return true;
157      }
158
159      // If this exactly halfway, round to even.
160      if (round_to_digit == 0)
161        // When the input is ".5".
162        return false;
163      return this->digits[round_to_digit - 1] % 2 != 0;
164    }
165    // If there are digits after round_to_digit, they must be non-zero since we
166    // trim trailing zeroes after all operations that change digits.
167    return this->digits[round_to_digit] >= 5;
168  }
169
170  // Takes an amount to left shift and returns the number of new digits needed
171  // to store the result based on LEFT_SHIFT_DIGIT_TABLE.
172  LIBC_INLINE uint32_t get_num_new_digits(uint32_t lshift_amount) {
173    const char *power_of_five =
174        LEFT_SHIFT_DIGIT_TABLE[lshift_amount].power_of_five;
175    uint32_t new_digits = LEFT_SHIFT_DIGIT_TABLE[lshift_amount].new_digits;
176    uint32_t digit_index = 0;
177    while (power_of_five[digit_index] != 0) {
178      if (digit_index >= this->num_digits) {
179        return new_digits - 1;
180      }
181      if (this->digits[digit_index] !=
182          internal::b36_char_to_int(power_of_five[digit_index])) {
183        return new_digits -
184               ((this->digits[digit_index] <
185                 internal::b36_char_to_int(power_of_five[digit_index]))
186                    ? 1
187                    : 0);
188      }
189      ++digit_index;
190    }
191    return new_digits;
192  }
193
194  // Trim all trailing 0s
195  LIBC_INLINE void trim_trailing_zeroes() {
196    while (this->num_digits > 0 && this->digits[this->num_digits - 1] == 0) {
197      --this->num_digits;
198    }
199    if (this->num_digits == 0) {
200      this->decimal_point = 0;
201    }
202  }
203
204  // Perform a digitwise binary non-rounding right shift on this value by
205  // shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
206  // prevent overflow.
207  LIBC_INLINE void right_shift(uint32_t shift_amount) {
208    uint32_t read_index = 0;
209    uint32_t write_index = 0;
210
211    uint64_t accumulator = 0;
212
213    const uint64_t shift_mask = (uint64_t(1) << shift_amount) - 1;
214
215    // Warm Up phase: we don't have enough digits to start writing, so just
216    // read them into the accumulator.
217    while (accumulator >> shift_amount == 0) {
218      uint64_t read_digit = 0;
219      // If there are still digits to read, read the next one, else the digit is
220      // assumed to be 0.
221      if (read_index < this->num_digits) {
222        read_digit = this->digits[read_index];
223      }
224      accumulator = accumulator * 10 + read_digit;
225      ++read_index;
226    }
227
228    // Shift the decimal point by the number of digits it took to fill the
229    // accumulator.
230    this->decimal_point -= read_index - 1;
231
232    // Middle phase: we have enough digits to write, as well as more digits to
233    // read. Keep reading until we run out of digits.
234    while (read_index < this->num_digits) {
235      uint64_t read_digit = this->digits[read_index];
236      uint64_t write_digit = accumulator >> shift_amount;
237      accumulator &= shift_mask;
238      this->digits[write_index] = static_cast<uint8_t>(write_digit);
239      accumulator = accumulator * 10 + read_digit;
240      ++read_index;
241      ++write_index;
242    }
243
244    // Cool Down phase: All of the readable digits have been read, so just write
245    // the remainder, while treating any more digits as 0.
246    while (accumulator > 0) {
247      uint64_t write_digit = accumulator >> shift_amount;
248      accumulator &= shift_mask;
249      if (write_index < MAX_NUM_DIGITS) {
250        this->digits[write_index] = static_cast<uint8_t>(write_digit);
251        ++write_index;
252      } else if (write_digit > 0) {
253        this->truncated = true;
254      }
255      accumulator = accumulator * 10;
256    }
257    this->num_digits = write_index;
258    this->trim_trailing_zeroes();
259  }
260
261  // Perform a digitwise binary non-rounding left shift on this value by
262  // shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
263  // prevent overflow.
264  LIBC_INLINE void left_shift(uint32_t shift_amount) {
265    uint32_t new_digits = this->get_num_new_digits(shift_amount);
266
267    int32_t read_index = static_cast<int32_t>(this->num_digits - 1);
268    uint32_t write_index = this->num_digits + new_digits;
269
270    uint64_t accumulator = 0;
271
272    // No Warm Up phase. Since we're putting digits in at the top and taking
273    // digits from the bottom we don't have to wait for the accumulator to fill.
274
275    // Middle phase: while we have more digits to read, keep reading as well as
276    // writing.
277    while (read_index >= 0) {
278      accumulator += static_cast<uint64_t>(this->digits[read_index])
279                     << shift_amount;
280      uint64_t next_accumulator = accumulator / 10;
281      uint64_t write_digit = accumulator - (10 * next_accumulator);
282      --write_index;
283      if (write_index < MAX_NUM_DIGITS) {
284        this->digits[write_index] = static_cast<uint8_t>(write_digit);
285      } else if (write_digit != 0) {
286        this->truncated = true;
287      }
288      accumulator = next_accumulator;
289      --read_index;
290    }
291
292    // Cool Down phase: there are no more digits to read, so just write the
293    // remaining digits in the accumulator.
294    while (accumulator > 0) {
295      uint64_t next_accumulator = accumulator / 10;
296      uint64_t write_digit = accumulator - (10 * next_accumulator);
297      --write_index;
298      if (write_index < MAX_NUM_DIGITS) {
299        this->digits[write_index] = static_cast<uint8_t>(write_digit);
300      } else if (write_digit != 0) {
301        this->truncated = true;
302      }
303      accumulator = next_accumulator;
304    }
305
306    this->num_digits += new_digits;
307    if (this->num_digits > MAX_NUM_DIGITS) {
308      this->num_digits = MAX_NUM_DIGITS;
309    }
310    this->decimal_point += new_digits;
311    this->trim_trailing_zeroes();
312  }
313
314public:
315  // num_string is assumed to be a string of numeric characters. It doesn't
316  // handle leading spaces.
317  LIBC_INLINE
318  HighPrecisionDecimal(
319      const char *__restrict num_string,
320      const size_t num_len = cpp::numeric_limits<size_t>::max()) {
321    bool saw_dot = false;
322    size_t num_cur = 0;
323    // This counts the digits in the number, even if there isn't space to store
324    // them all.
325    uint32_t total_digits = 0;
326    while (num_cur < num_len &&
327           (isdigit(num_string[num_cur]) || num_string[num_cur] == '.')) {
328      if (num_string[num_cur] == '.') {
329        if (saw_dot) {
330          break;
331        }
332        this->decimal_point = static_cast<int32_t>(total_digits);
333        saw_dot = true;
334      } else {
335        if (num_string[num_cur] == '0' && this->num_digits == 0) {
336          --this->decimal_point;
337          ++num_cur;
338          continue;
339        }
340        ++total_digits;
341        if (this->num_digits < MAX_NUM_DIGITS) {
342          this->digits[this->num_digits] = static_cast<uint8_t>(
343              internal::b36_char_to_int(num_string[num_cur]));
344          ++this->num_digits;
345        } else if (num_string[num_cur] != '0') {
346          this->truncated = true;
347        }
348      }
349      ++num_cur;
350    }
351
352    if (!saw_dot)
353      this->decimal_point = static_cast<int32_t>(total_digits);
354
355    if (num_cur < num_len &&
356        (num_string[num_cur] == 'e' || num_string[num_cur] == 'E')) {
357      ++num_cur;
358      if (isdigit(num_string[num_cur]) || num_string[num_cur] == '+' ||
359          num_string[num_cur] == '-') {
360        auto result =
361            strtointeger<int32_t>(num_string + num_cur, 10, num_len - num_cur);
362        if (result.has_error()) {
363          // TODO: handle error
364        }
365        int32_t add_to_exponent = result.value;
366
367        // Here we do this operation as int64 to avoid overflow.
368        int64_t temp_exponent = static_cast<int64_t>(this->decimal_point) +
369                                static_cast<int64_t>(add_to_exponent);
370
371        // Theoretically these numbers should be MAX_BIASED_EXPONENT for long
372        // double, but that should be ~16,000 which is much less than 1 << 30.
373        if (temp_exponent > (1 << 30)) {
374          temp_exponent = (1 << 30);
375        } else if (temp_exponent < -(1 << 30)) {
376          temp_exponent = -(1 << 30);
377        }
378        this->decimal_point = static_cast<int32_t>(temp_exponent);
379      }
380    }
381
382    this->trim_trailing_zeroes();
383  }
384
385  // Binary shift left (shift_amount > 0) or right (shift_amount < 0)
386  LIBC_INLINE void shift(int shift_amount) {
387    if (shift_amount == 0) {
388      return;
389    }
390    // Left
391    else if (shift_amount > 0) {
392      while (static_cast<uint32_t>(shift_amount) > MAX_SHIFT_AMOUNT) {
393        this->left_shift(MAX_SHIFT_AMOUNT);
394        shift_amount -= MAX_SHIFT_AMOUNT;
395      }
396      this->left_shift(static_cast<uint32_t>(shift_amount));
397    }
398    // Right
399    else {
400      while (static_cast<uint32_t>(shift_amount) < -MAX_SHIFT_AMOUNT) {
401        this->right_shift(MAX_SHIFT_AMOUNT);
402        shift_amount += MAX_SHIFT_AMOUNT;
403      }
404      this->right_shift(static_cast<uint32_t>(-shift_amount));
405    }
406  }
407
408  // Round the number represented to the closest value of unsigned int type T.
409  // This is done ignoring overflow.
410  template <class T>
411  LIBC_INLINE T
412  round_to_integer_type(RoundDirection round = RoundDirection::Nearest) {
413    T result = 0;
414    uint32_t cur_digit = 0;
415
416    while (static_cast<int32_t>(cur_digit) < this->decimal_point &&
417           cur_digit < this->num_digits) {
418      result = result * 10 + (this->digits[cur_digit]);
419      ++cur_digit;
420    }
421
422    // If there are implicit 0s at the end of the number, include those.
423    while (static_cast<int32_t>(cur_digit) < this->decimal_point) {
424      result *= 10;
425      ++cur_digit;
426    }
427    return result +
428           static_cast<T>(this->should_round_up(this->decimal_point, round));
429  }
430
431  // Extra functions for testing.
432
433  LIBC_INLINE uint8_t *get_digits() { return this->digits; }
434  LIBC_INLINE uint32_t get_num_digits() { return this->num_digits; }
435  LIBC_INLINE int32_t get_decimal_point() { return this->decimal_point; }
436  LIBC_INLINE void set_truncated(bool trunc) { this->truncated = trunc; }
437};
438
439} // namespace internal
440} // namespace LIBC_NAMESPACE_DECL
441
442#endif // LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H