mirror of https://github.com/poanetwork/gecko.git
117 lines
2.9 KiB
Go
117 lines
2.9 KiB
Go
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
|
|
// See the file LICENSE for licensing terms.
|
|
|
|
package random
|
|
|
|
import (
|
|
"math"
|
|
"math/rand"
|
|
)
|
|
|
|
// Weighted implements the Sampler interface by sampling from weights that are
// arranged as an implicit binary heap.
//
// The weight of a node is defined as the node's own given weight plus,
// recursively, the weights of its children. Once a node has been sampled, its
// given weight is set to 0.
//
// Replacing runs in O(n) time while sampling runs in O(log(n)) time.
type Weighted struct {
	// Weights holds the given weight of every index that can be sampled.
	Weights []uint64

	// weights is a working copy of Weights. It is kept separate because an
	// entry is set to 0 after its index has been sampled.
	//
	// cumWeights[i] is weights[i] plus the cumulative weights of the children
	// of i, which live at indices 2*i+1 and 2*i+2.
	weights, cumWeights []int64
}
|
|
|
|
func (s *Weighted) init() {
|
|
if len(s.Weights) != len(s.weights) {
|
|
s.Replace()
|
|
}
|
|
}
|
|
|
|
// Sample returns a number in [0, len(weights)) with probability proportional to
|
|
// the weight of the item at that index. Assumes Len > 0. Sample takes
|
|
// O(log(len(weights))) time.
|
|
func (s *Weighted) Sample() int {
|
|
i := s.SampleReplace()
|
|
s.changeWeight(i, 0)
|
|
return i
|
|
}
|
|
|
|
// SampleReplace returns a number in [0, len(weights)) with probability
|
|
// proportional to the weight of the item at that index. Assumes CanSample
|
|
// returns true. Sample takes O(log(len(weights))) time. The returned index is
|
|
// not removed.
|
|
func (s *Weighted) SampleReplace() int {
|
|
s.init()
|
|
for w, i := rand.Int63n(s.cumWeights[0]), 0; ; {
|
|
w -= s.weights[i]
|
|
if w < 0 {
|
|
return i
|
|
}
|
|
|
|
i = i*2 + 1 // We shouldn't return the root, so check the left child
|
|
|
|
if lw := s.cumWeights[i]; lw <= w {
|
|
// If the weight is greater than the left weight, you should move to
|
|
// the right child
|
|
w -= lw
|
|
i++
|
|
}
|
|
}
|
|
}
|
|
|
|
// CanSample returns the number of items left that can be sampled
|
|
func (s *Weighted) CanSample() bool {
|
|
s.init()
|
|
return len(s.cumWeights) > 0 && s.cumWeights[0] > 0
|
|
}
|
|
|
|
// Replace all the sampled elements. Takes O(len(weights)) time.
|
|
func (s *Weighted) Replace() {
|
|
// Attempt to malloc as few times as possible
|
|
if s.weights == nil || cap(s.weights) < len(s.Weights) {
|
|
s.weights = make([]int64, len(s.Weights))
|
|
} else {
|
|
s.weights = s.weights[:len(s.Weights)]
|
|
}
|
|
if s.cumWeights == nil || cap(s.cumWeights) < len(s.Weights) {
|
|
s.cumWeights = make([]int64, len(s.Weights))
|
|
} else {
|
|
s.cumWeights = s.cumWeights[:len(s.Weights)]
|
|
}
|
|
|
|
for i, w := range s.Weights {
|
|
if w > math.MaxInt64 {
|
|
panic("Weight too large")
|
|
}
|
|
s.weights[i] = int64(w)
|
|
}
|
|
|
|
copy(s.cumWeights, s.weights)
|
|
|
|
// Initialize the heap
|
|
for i := len(s.cumWeights) - 1; i > 0; i-- {
|
|
parent := (i - 1) / 2
|
|
w := uint64(s.cumWeights[parent]) + uint64(s.cumWeights[i])
|
|
if w > math.MaxInt64 {
|
|
panic("Weight too large")
|
|
}
|
|
s.cumWeights[parent] = int64(w)
|
|
}
|
|
}
|
|
|
|
func (s *Weighted) changeWeight(i int, newWeight int64) {
|
|
change := s.weights[i] - newWeight
|
|
|
|
s.weights[i] = newWeight
|
|
|
|
// Decrease my weight and all my parents weights.
|
|
s.cumWeights[i] -= change
|
|
for i > 0 {
|
|
i = (i - 1) / 2
|
|
s.cumWeights[i] -= change
|
|
}
|
|
}
|