master
  1//===----------------------------------------------------------------------===//
  2//
  3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4// See https://llvm.org/LICENSE.txt for license information.
  5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6//
  7//===----------------------------------------------------------------------===//
  8
  9#ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
 10#define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
 11
 12#include <__algorithm/upper_bound.h>
 13#include <__config>
 14#include <__random/is_valid.h>
 15#include <__random/uniform_real_distribution.h>
 16#include <__vector/vector.h>
 17#include <initializer_list>
 18#include <iosfwd>
 19#include <numeric>
 20
 21#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 22#  pragma GCC system_header
 23#endif
 24
 25_LIBCPP_PUSH_MACROS
 26#include <__undef_macros>
 27
 28_LIBCPP_BEGIN_NAMESPACE_STD
 29
 30template <class _IntType = int>
 31class discrete_distribution {
 32  static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type");
 33
 34public:
 35  // types
 36  typedef _IntType result_type;
 37
 38  class param_type {
 39    vector<double> __p_;
 40
 41  public:
 42    typedef discrete_distribution distribution_type;
 43
 44    _LIBCPP_HIDE_FROM_ABI param_type() {}
 45    template <class _InputIterator>
 46    _LIBCPP_HIDE_FROM_ABI param_type(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {
 47      __init();
 48    }
 49#ifndef _LIBCPP_CXX03_LANG
 50    _LIBCPP_HIDE_FROM_ABI param_type(initializer_list<double> __wl) : __p_(__wl.begin(), __wl.end()) { __init(); }
 51#endif // _LIBCPP_CXX03_LANG
 52    template <class _UnaryOperation>
 53    _LIBCPP_HIDE_FROM_ABI param_type(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw);
 54
 55    _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const;
 56
 57    friend _LIBCPP_HIDE_FROM_ABI bool operator==(const param_type& __x, const param_type& __y) {
 58      return __x.__p_ == __y.__p_;
 59    }
 60    friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const param_type& __x, const param_type& __y) { return !(__x == __y); }
 61
 62  private:
 63    _LIBCPP_HIDE_FROM_ABI void __init();
 64
 65    friend class discrete_distribution;
 66
 67    template <class _CharT, class _Traits, class _IT>
 68    friend basic_ostream<_CharT, _Traits>&
 69    operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x);
 70
 71    template <class _CharT, class _Traits, class _IT>
 72    friend basic_istream<_CharT, _Traits>&
 73    operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x);
 74  };
 75
 76private:
 77  param_type __p_;
 78
 79public:
 80  // constructor and reset functions
 81  _LIBCPP_HIDE_FROM_ABI discrete_distribution() {}
 82  template <class _InputIterator>
 83  _LIBCPP_HIDE_FROM_ABI discrete_distribution(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {}
 84#ifndef _LIBCPP_CXX03_LANG
 85  _LIBCPP_HIDE_FROM_ABI discrete_distribution(initializer_list<double> __wl) : __p_(__wl) {}
 86#endif // _LIBCPP_CXX03_LANG
 87  template <class _UnaryOperation>
 88  _LIBCPP_HIDE_FROM_ABI discrete_distribution(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw)
 89      : __p_(__nw, __xmin, __xmax, __fw) {}
 90  _LIBCPP_HIDE_FROM_ABI explicit discrete_distribution(const param_type& __p) : __p_(__p) {}
 91  _LIBCPP_HIDE_FROM_ABI void reset() {}
 92
 93  // generating functions
 94  template <class _URNG>
 95  _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g) {
 96    return (*this)(__g, __p_);
 97  }
 98  template <class _URNG>
 99  _LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p);
100
101  // property functions
102  _LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const { return __p_.probabilities(); }
103
104  _LIBCPP_HIDE_FROM_ABI param_type param() const { return __p_; }
105  _LIBCPP_HIDE_FROM_ABI void param(const param_type& __p) { __p_ = __p; }
106
107  _LIBCPP_HIDE_FROM_ABI result_type min() const { return 0; }
108  _LIBCPP_HIDE_FROM_ABI result_type max() const { return __p_.__p_.size(); }
109
110  friend _LIBCPP_HIDE_FROM_ABI bool operator==(const discrete_distribution& __x, const discrete_distribution& __y) {
111    return __x.__p_ == __y.__p_;
112  }
113  friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const discrete_distribution& __x, const discrete_distribution& __y) {
114    return !(__x == __y);
115  }
116
117  template <class _CharT, class _Traits, class _IT>
118  friend basic_ostream<_CharT, _Traits>&
119  operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x);
120
121  template <class _CharT, class _Traits, class _IT>
122  friend basic_istream<_CharT, _Traits>&
123  operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x);
124};
125
126template <class _IntType>
127template <class _UnaryOperation>
128discrete_distribution<_IntType>::param_type::param_type(
129    size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) {
130  if (__nw > 1) {
131    __p_.reserve(__nw - 1);
132    double __d  = (__xmax - __xmin) / __nw;
133    double __d2 = __d / 2;
134    for (size_t __k = 0; __k < __nw; ++__k)
135      __p_.push_back(__fw(__xmin + __k * __d + __d2));
136    __init();
137  }
138}
139
140template <class _IntType>
141void discrete_distribution<_IntType>::param_type::__init() {
142  if (!__p_.empty()) {
143    if (__p_.size() > 1) {
144      double __s = std::accumulate(__p_.begin(), __p_.end(), 0.0);
145      for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i)
146        *__i /= __s;
147      vector<double> __t(__p_.size() - 1);
148      std::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin());
149      swap(__p_, __t);
150    } else {
151      __p_.clear();
152      __p_.shrink_to_fit();
153    }
154  }
155}
156
157template <class _IntType>
158vector<double> discrete_distribution<_IntType>::param_type::probabilities() const {
159  size_t __n = __p_.size();
160  vector<double> __p(__n + 1);
161  std::adjacent_difference(__p_.begin(), __p_.end(), __p.begin());
162  if (__n > 0)
163    __p[__n] = 1 - __p_[__n - 1];
164  else
165    __p[0] = 1;
166  return __p;
167}
168
169template <class _IntType>
170template <class _URNG>
171_IntType discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) {
172  static_assert(__libcpp_random_is_valid_urng<_URNG>::value, "");
173  uniform_real_distribution<double> __gen;
174  return static_cast<_IntType>(std::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - __p.__p_.begin());
175}
176
177template <class _CharT, class _Traits, class _IT>
178_LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>&
179operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x) {
180  __save_flags<_CharT, _Traits> __lx(__os);
181  typedef basic_ostream<_CharT, _Traits> _OStream;
182  __os.flags(_OStream::dec | _OStream::left | _OStream::fixed | _OStream::scientific);
183  _CharT __sp = __os.widen(' ');
184  __os.fill(__sp);
185  size_t __n = __x.__p_.__p_.size();
186  __os << __n;
187  for (size_t __i = 0; __i < __n; ++__i)
188    __os << __sp << __x.__p_.__p_[__i];
189  return __os;
190}
191
192template <class _CharT, class _Traits, class _IT>
193_LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>&
194operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x) {
195  __save_flags<_CharT, _Traits> __lx(__is);
196  typedef basic_istream<_CharT, _Traits> _Istream;
197  __is.flags(_Istream::dec | _Istream::skipws);
198  size_t __n;
199  __is >> __n;
200  vector<double> __p(__n);
201  for (size_t __i = 0; __i < __n; ++__i)
202    __is >> __p[__i];
203  if (!__is.fail())
204    swap(__x.__p_.__p_, __p);
205  return __is;
206}
207
208_LIBCPP_END_NAMESPACE_STD
209
210_LIBCPP_POP_MACROS
211
212#endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H