gecko/utils/random/weighted.go

117 lines
2.9 KiB
Go

// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package random
import (
"math"
"math/rand"
)
// Weighted implements the Sampler interface by sampling based on a heap
// structure.
//
// Node weight is defined as the node's given weight along with its
// children's recursive weights. Once sampled, a node's given weight is set
// to 0.
//
// Replacing runs in O(n) time while sampling runs in O(log(n)) time.
type Weighted struct {
	// Weights holds the caller-supplied weight of each index. Replace panics
	// if any weight, or the sum of all weights, exceeds math.MaxInt64.
	Weights []uint64

	// weights is kept separate from Weights because an entry is set to 0
	// after it has been sampled, while Weights is left untouched.
	weights []int64
	// cumWeights[i] is weights[i] plus the cumWeights of i's children
	// (indices 2i+1 and 2i+2) in an implicit binary heap, so cumWeights[0]
	// is the total remaining weight.
	cumWeights []int64
}
// init lazily rebuilds the internal heap when the exported Weights slice no
// longer has the same length as the cached weights.
func (s *Weighted) init() {
	if len(s.Weights) == len(s.weights) {
		return
	}
	s.Replace()
}
// Sample returns an index in [0, len(Weights)) chosen with probability
// proportional to that index's current weight, then sets that index's weight
// to 0 so it is not returned again until the weights are replaced (see
// Replace). Assumes CanSample returns true. Sample takes
// O(log(len(Weights))) time.
func (s *Weighted) Sample() int {
	i := s.SampleReplace()
	// Removing the sampled weight keeps the heap sums consistent.
	s.changeWeight(i, 0)
	return i
}
// SampleReplace returns an index in [0, len(Weights)) chosen with probability
// proportional to that index's current weight. The returned index keeps its
// weight, so later calls may return it again. Assumes CanSample returns true;
// otherwise the total weight is 0 and rand.Int63n panics. SampleReplace takes
// O(log(len(Weights))) time.
func (s *Weighted) SampleReplace() int {
	s.init()

	// Draw a point uniformly in [0, total weight) and walk down the heap,
	// subtracting the weight of everything that is skipped over.
	w := rand.Int63n(s.cumWeights[0])
	i := 0
	for {
		w -= s.weights[i]
		if w < 0 {
			// The point falls inside this node's own weight.
			return i
		}
		// The point lies within one of i's subtrees. The cumulative-weight
		// invariant guarantees the selected child exists.
		left := 2*i + 1
		if lw := s.cumWeights[left]; w < lw {
			i = left
		} else {
			// Skip the whole left subtree and continue in the right one.
			w -= lw
			i = left + 1
		}
	}
}
// CanSample reports whether any item still has positive weight, i.e. whether
// Sample or SampleReplace may be called. It first rebuilds the internal heap
// if Weights has changed length.
func (s *Weighted) CanSample() bool {
	s.init()
	// cumWeights[0] is the total remaining weight of the whole heap.
	return len(s.cumWeights) > 0 && s.cumWeights[0] > 0
}
// Replace all the sampled elements. Takes O(len(weights)) time.
func (s *Weighted) Replace() {
// Attempt to malloc as few times as possible
if s.weights == nil || cap(s.weights) < len(s.Weights) {
s.weights = make([]int64, len(s.Weights))
} else {
s.weights = s.weights[:len(s.Weights)]
}
if s.cumWeights == nil || cap(s.cumWeights) < len(s.Weights) {
s.cumWeights = make([]int64, len(s.Weights))
} else {
s.cumWeights = s.cumWeights[:len(s.Weights)]
}
for i, w := range s.Weights {
if w > math.MaxInt64 {
panic("Weight too large")
}
s.weights[i] = int64(w)
}
copy(s.cumWeights, s.weights)
// Initialize the heap
for i := len(s.cumWeights) - 1; i > 0; i-- {
parent := (i - 1) / 2
w := uint64(s.cumWeights[parent]) + uint64(s.cumWeights[i])
if w > math.MaxInt64 {
panic("Weight too large")
}
s.cumWeights[parent] = int64(w)
}
}
// changeWeight sets the own weight of index i to newWeight and applies the
// resulting difference to the cumulative weight of i and of every ancestor up
// to the root, keeping the heap sums consistent.
func (s *Weighted) changeWeight(i int, newWeight int64) {
	delta := newWeight - s.weights[i]
	s.weights[i] = newWeight
	for node := i; ; node = (node - 1) / 2 {
		s.cumWeights[node] += delta
		if node == 0 {
			break
		}
	}
}