1//===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Definition of BranchProbability shared by IR and Machine Instructions.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
14#define LLVM_SUPPORT_BRANCHPROBABILITY_H
15
16#include "llvm/Support/DataTypes.h"
17#include <algorithm>
18#include <cassert>
19#include <climits>
20#include <numeric>
21
22namespace llvm {
23
class raw_ostream;

// This class represents a branch probability as a non-negative fraction that
// is no greater than 1. It uses a fixed-point-like representation in which the
// denominator is always the constant 1<<31 (chosen for maximum precision), so a
// probability is fully described by its 32-bit numerator. Known probabilities
// have numerators in [0, 1<<31]; the numerator UINT32_MAX is reserved as a
// sentinel meaning "unknown".
class BranchProbability {
  // Numerator
  uint32_t N;

  // Denominator, which is a constant value.
  static constexpr uint32_t D = 1u << 31;
  // Sentinel numerator marking an unknown probability. It lies outside the
  // range [0, D] used by known probabilities.
  static constexpr uint32_t UnknownN = UINT32_MAX;

  // Construct a BranchProbability with only numerator assuming the denominator
  // is 1<<31. For internal use only.
  explicit BranchProbability(uint32_t n) : N(n) {}

public:
  // A default-constructed probability is unknown.
  BranchProbability() : N(UnknownN) {}
  BranchProbability(uint32_t Numerator, uint32_t Denominator);

  bool isZero() const { return N == 0; }
  bool isUnknown() const { return N == UnknownN; }

  static BranchProbability getZero() { return BranchProbability(0); }
  static BranchProbability getOne() { return BranchProbability(D); }
  static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
  // Create a BranchProbability object with the given numerator and 1<<31
  // as denominator.
  static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
  // Create a BranchProbability object from 64-bit integers.
  static BranchProbability getBranchProbability(uint64_t Numerator,
                                                uint64_t Denominator);

  // Normalize the given probabilities so that their sum becomes approximately
  // one. Unknown entries receive an even share of whatever the known entries
  // leave below one (or zero if the known entries already reach one).
  template <class ProbabilityIter>
  static void normalizeProbabilities(ProbabilityIter Begin,
                                     ProbabilityIter End);

  uint32_t getNumerator() const { return N; }
  static uint32_t getDenominator() { return D; }

  // Return (1 - Probability). The probability must be known; computing
  // D - UnknownN would silently wrap around.
  BranchProbability getCompl() const {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    return BranchProbability(D - N);
  }

  raw_ostream &print(raw_ostream &OS) const;

  void dump() const;

  /// Scale a large integer.
  ///
  /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
  /// result.
  ///
  /// \return \c Num times \c this.
  uint64_t scale(uint64_t Num) const;

  /// Scale a large integer by the inverse.
  ///
  /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
  /// Returns the floor of the result.
  ///
  /// \return \c Num divided by \c this.
  uint64_t scaleByInverse(uint64_t Num) const;

  BranchProbability &operator+=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of overflow.
    N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
    return *this;
  }

  BranchProbability &operator-=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of underflow.
    N = N < RHS.N ? 0 : N - RHS.N;
    return *this;
  }

  // Multiply two probabilities, rounding the product to nearest.
  BranchProbability &operator*=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    return *this;
  }

  // Multiply by an integer, saturating at one.
  BranchProbability &operator*=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    return *this;
  }

  // Divide by another probability, rounding the quotient to nearest. The
  // quotient is saturated at one so the result stays a valid probability:
  // without the clamp, e.g. 1/2 divided by 1/4 yields 2*D, which truncates to
  // zero in 32 bits, and other quotients could even alias UnknownN.
  BranchProbability &operator/=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS.N > 0 && "The divider cannot be zero.");
    uint64_t Quotient = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N;
    N = Quotient > D ? D : static_cast<uint32_t>(Quotient);
    return *this;
  }

  // Divide by a non-zero integer, rounding down.
  BranchProbability &operator/=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS > 0 && "The divider cannot be zero.");
    N /= RHS;
    return *this;
  }

  BranchProbability operator+(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob += RHS;
    return Prob;
  }

  BranchProbability operator-(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob -= RHS;
    return Prob;
  }

  BranchProbability operator*(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator*(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator/(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  BranchProbability operator/(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  // Equality compares raw numerators, so two unknown probabilities are equal.
  bool operator==(BranchProbability RHS) const { return N == RHS.N; }
  bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }

  bool operator<(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N < RHS.N;
  }

  bool operator>(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return RHS < *this;
  }

  bool operator<=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(RHS < *this);
  }

  bool operator>=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(*this < RHS);
  }
};
199
200inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
201  return Prob.print(OS);
202}
203
204template <class ProbabilityIter>
205void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
206                                               ProbabilityIter End) {
207  if (Begin == End)
208    return;
209
210  unsigned UnknownProbCount = 0;
211  uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
212                                 [&](uint64_t S, const BranchProbability &BP) {
213                                   if (!BP.isUnknown())
214                                     return S + BP.N;
215                                   UnknownProbCount++;
216                                   return S;
217                                 });
218
219  if (UnknownProbCount > 0) {
220    BranchProbability ProbForUnknown = BranchProbability::getZero();
221    // If the sum of all known probabilities is less than one, evenly distribute
222    // the complement of sum to unknown probabilities. Otherwise, set unknown
223    // probabilities to zeros and continue to normalize known probabilities.
224    if (Sum < BranchProbability::getDenominator())
225      ProbForUnknown = BranchProbability::getRaw(
226          (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
227
228    std::replace_if(Begin, End,
229                    [](const BranchProbability &BP) { return BP.isUnknown(); },
230                    ProbForUnknown);
231
232    if (Sum <= BranchProbability::getDenominator())
233      return;
234  }
235
236  if (Sum == 0) {
237    BranchProbability BP(1, std::distance(Begin, End));
238    std::fill(Begin, End, BP);
239    return;
240  }
241
242  for (auto I = Begin; I != End; ++I)
243    I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
244}
245
246}
247
248#endif
249