# An efficient probability distribution data structure

A probability distribution is a list of nodes, each holding an object and a positive weight that determines how likely the object is to be sampled. Whenever we ask the distribution for an object (that is, sample the distribution), the data structure returns an object with probability equal to its weight divided by the sum of all the weights. For example, if the list is $\langle (A, 1.0), (B, 1.0), (C, 3.0) \rangle$, each sample returns $A$ or $B$ with probability 20% each, and $C$ with probability 60%.
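The arithmetic behind the example can be checked with a few lines of Java. This is only a sketch: the class name `WeightNormalizationSketch` and the method `toProbabilities` are made up for illustration and are not part of the data structures discussed in this post.

```java
import java.util.Arrays;

final class WeightNormalizationSketch {

    // Converts positive weights to sampling probabilities: p_i = w_i / sum(w).
    static double[] toProbabilities(double[] weights) {
        double total = 0.0;

        for (double w : weights) {
            total += w;
        }

        double[] probabilities = new double[weights.length];

        for (int i = 0; i < weights.length; ++i) {
            probabilities[i] = weights[i] / total;
        }

        return probabilities;
    }

    public static void main(String[] args) {
        // Weights of A, B and C from the example.
        double[] p = toProbabilities(new double[] {1.0, 1.0, 3.0});
        System.out.println(Arrays.toString(p)); // [0.2, 0.2, 0.6]
    }
}
```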

One obvious data structure is an array-based list (such as ArrayList in Java or vector in C++). Both are implemented in such a way that appending a new object/weight pair runs in amortized constant time, yet sampling and removing an element run in $\Theta(n)$ time in the worst case.

In this post I will explain how to implement a data structure that supports all three fundamental operations (insertion, sampling, removal) in logarithmic time in all cases. For the sake of comparison, I will also supply three other data structures.

#### The application programming interface

Clearly, in our probability distribution data structures, we need at least three methods:

• addElement,
• sampleElement,
• removeElement.

Now, let us proceed to defining an abstract base class:

net.coderodde.stat.AbstractProbabilityDistribution.java
package net.coderodde.stat;

import java.util.Objects;
import java.util.Random;

public abstract class AbstractProbabilityDistribution<E> {

protected double totalWeight;
protected final Random random;

protected AbstractProbabilityDistribution() {
this(new Random());
}

protected AbstractProbabilityDistribution(Random random) {
this.random =
Objects.requireNonNull(random,
"The random number generator is null.");
}

public abstract boolean isEmpty();
public abstract int size();
public abstract boolean addElement(E element, double weight);
public abstract E sampleElement();
public abstract boolean contains(E element);
public abstract boolean removeElement(E element);
public abstract void clear();

protected void checkWeight(double weight) {
if (Double.isNaN(weight)) {
throw new IllegalArgumentException("The element weight is NaN.");
}

if (weight <= 0.0) {
throw new IllegalArgumentException(
"The element weight must be positive. Received " + weight);
}

if (Double.isInfinite(weight)) {
// Once here, 'weight' is positive infinity.
throw new IllegalArgumentException(
"The element weight is infinite.");
}
}

protected void checkNotEmpty(int size) {
if (size == 0) {
throw new IllegalStateException(
"This probability distribution is empty.");
}
}
}


#### Trivial implementation: ArrayProbabilityDistribution.java

First, we start with a naïve implementation: the distribution is implemented by a vector of entries, each holding the actual element and its (positive) weight. When adding a new entry, we append it to the end of the underlying vector. When removing an entry, we remove it from the vector and shift the right part of the vector one position to the left. Clearly, addition runs in amortized constant time when using list data structures such as ArrayList or vector. Removal, unfortunately, is $\Theta(n)$ in both the worst and the average case. Also, in order to sample an element, we may need to visit as many entries as there are in the data structure, which yields linear time for sampling in both the worst and the average case.
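To see the sampling scan in isolation, here is a minimal sketch that works on a plain array of weights and takes the random value as a parameter, so the result is deterministic. The class `LinearScanSketch` and the method `indexFor` are hypothetical names used only for this illustration.

```java
final class LinearScanSketch {

    // Returns the index of the entry whose bucket contains 'value', assuming
    // 0 <= value < sum(weights). Bucket i starts at the sum of the weights
    // before entry i and is weights[i] wide.
    static int indexFor(double[] weights, double value) {
        for (int i = 0; i < weights.length; ++i) {
            if (value < weights[i]) {
                return i;
            }

            value -= weights[i];
        }

        // Only reachable through rounding error or an out-of-range value.
        return weights.length - 1;
    }

    public static void main(String[] args) {
        double[] weights = {1.0, 1.0, 3.0};        // A, B, C
        System.out.println(indexFor(weights, 0.5)); // 0 (A)
        System.out.println(indexFor(weights, 1.5)); // 1 (B)
        System.out.println(indexFor(weights, 2.0)); // 2 (C)
    }
}
```

In the worst case the loop visits every entry, which is where the linear sampling time comes from.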

All in all, the trivial implementation looks like this:

net.coderodde.stat.support.ArrayProbabilityDistribution.java
package net.coderodde.stat.support;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import net.coderodde.stat.AbstractProbabilityDistribution;

public class ArrayProbabilityDistribution<E>
extends AbstractProbabilityDistribution<E> {

private static final class Entry<E> {

private final E element;
private double weight;

Entry(E element, double weight) {
this.element = element;
this.weight = weight;
}

E getElement() {
return element;
}

double getWeight() {
return weight;
}

void setWeight(double weight) {
this.weight = weight;
}
}

private final List<Entry<E>> storage = new ArrayList<>();
private final Map<E, Entry<E>> map = new HashMap<>();

public ArrayProbabilityDistribution() {
this(new Random());
}

public ArrayProbabilityDistribution(Random random) {
super(random);
}

@Override
public boolean addElement(E element, double weight) {
checkWeight(weight);
Entry<E> entry = map.get(element);

if (entry != null) {
entry.setWeight(entry.getWeight() + weight);
} else {
entry = new Entry<>(element, weight);
map.put(element, entry);
// Without this the entry could never be sampled.
storage.add(entry);
}

totalWeight += weight;
return true;
}

@Override
public E sampleElement() {
checkNotEmpty();
double value = random.nextDouble() * totalWeight;
int distributionSize = storage.size();

for (int i = 0; i < distributionSize; ++i) {
Entry<E> entry = storage.get(i);
double currentWeight = entry.getWeight();

if (value < currentWeight) {
return entry.getElement();
}

value -= currentWeight;
}

throw new IllegalStateException("Should not get here.");
}

@Override
public boolean removeElement(E element) {
Entry<E> entry = map.remove(element);

if (entry == null) {
return false;
}

totalWeight -= entry.getWeight();
storage.remove(entry);
return true;
}

@Override
public void clear() {
map.clear();
storage.clear();
totalWeight = 0.0;
}

@Override
public boolean isEmpty() {
return storage.isEmpty();
}

@Override
public int size() {
return storage.size();
}

@Override
public boolean contains(E element) {
return map.containsKey(element);
}

protected void checkNotEmpty() {
checkNotEmpty(storage.size());
}

public static void main(String[] args) {
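AbstractProbabilityDistribution<Integer> pd =
new ArrayProbabilityDistribution<>();

// Example weights, chosen only for illustration.
pd.addElement(0, 1.0);
pd.addElement(1, 1.0);
pd.addElement(2, 1.0);
pd.addElement(3, 3.0);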

int[] counts = new int[4];

for (int i = 0; i < 1000; ++i) {
Integer myint = pd.sampleElement();
counts[myint]++;
}

System.out.println(Arrays.toString(counts));
}
}
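The main method above prints raw counts only. A quick sanity check is to turn the counts into relative frequencies and compare them with the weights divided by their sum. The following standalone sketch does that with a plain array of weights; `SamplingFrequencySketch` and `frequencies` are hypothetical names, and the weights are example values.

```java
import java.util.Arrays;
import java.util.Random;

final class SamplingFrequencySketch {

    // Draws 'samples' indices with probability proportional to 'weights'
    // and returns the observed relative frequency of each index.
    static double[] frequencies(double[] weights, int samples, Random random) {
        double total = 0.0;

        for (double w : weights) {
            total += w;
        }

        int[] counts = new int[weights.length];

        for (int s = 0; s < samples; ++s) {
            double value = random.nextDouble() * total;
            int index = 0;

            // Linear scan; the bound check guards against rounding error.
            while (index < weights.length - 1 && value >= weights[index]) {
                value -= weights[index];
                ++index;
            }

            counts[index]++;
        }

        double[] result = new double[weights.length];

        for (int i = 0; i < weights.length; ++i) {
            result[i] = (double) counts[i] / samples;
        }

        return result;
    }

    public static void main(String[] args) {
        double[] weights = {1.0, 1.0, 1.0, 3.0};
        // Expected frequencies are roughly 1/6, 1/6, 1/6 and 1/2.
        System.out.println(
                Arrays.toString(frequencies(weights, 100_000, new Random())));
    }
}
```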


#### Improving the trivial implementation: BinarySearchProbabilityDistribution.java

The next probability distribution extends the one above by storing an accumulated weight in each entry: the sum of the weights of all the entries preceding it in the storage list. Since the accumulated weights form a non-decreasing sequence as we march from the beginning of the storage list towards its end, and the storage list supports random access in $\Theta(1)$, we can apply binary search to the list in order to sample an element.
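The idea can be tried out on plain arrays before looking at the full class. The sketch below is a simplified variant, not the search loop of the class itself: it precomputes the lower bound of every bucket and then looks for the last bucket whose lower bound does not exceed the sampled value. The names `PrefixSumSearchSketch`, `lowerBounds` and `indexFor` are assumptions made for this example.

```java
final class PrefixSumSearchSketch {

    // prefix[i] is the sum of the weights before entry i, that is, the lower
    // bound of the bucket of entry i.
    static double[] lowerBounds(double[] weights) {
        double[] prefix = new double[weights.length];
        double sum = 0.0;

        for (int i = 0; i < weights.length; ++i) {
            prefix[i] = sum;
            sum += weights[i];
        }

        return prefix;
    }

    // Finds the last index i with prefix[i] <= value, assuming value >= 0.
    static int indexFor(double[] prefix, double value) {
        int left = 0;
        int right = prefix.length - 1;

        while (left < right) {
            // Upper middle, so that 'left = middle' always makes progress.
            int middle = left + ((right - left + 1) >> 1);

            if (prefix[middle] <= value) {
                left = middle;
            } else {
                right = middle - 1;
            }
        }

        return left;
    }

    public static void main(String[] args) {
        double[] prefix = lowerBounds(new double[] {1.0, 1.0, 3.0});
        System.out.println(indexFor(prefix, 0.5)); // 0
        System.out.println(indexFor(prefix, 1.5)); // 1
        System.out.println(indexFor(prefix, 2.0)); // 2
    }
}
```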

The binary search version has the same running times for insertion and removal as the very first one (ArrayProbabilityDistribution), yet it improves both the average and the worst case running times of sampling to $\Theta(\log n)$. The code follows:

package net.coderodde.stat.support;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import net.coderodde.stat.AbstractProbabilityDistribution;

/**
* This class implements a probability distribution data structure that
* maintains an accumulated sum of weights and thus allows sampling the elements
* in worst-case logarithmic time.
*
* @author Rodion "rodde" Efremov
* @version 1.61 (Sep 30, 2016)
*/
public class BinarySearchProbabilityDistribution<E>
extends AbstractProbabilityDistribution<E> {

/**
* This class implements the actual entry in the distribution.
*
* @param <E> the actual element type.
*/
private static final class Entry<E> {

private final E element;

private double weight;

private double accumulatedWeight;

Entry(E element, double weight, double accumulatedWeight) {
this.element = element;
this.weight = weight;
this.accumulatedWeight = accumulatedWeight;
}

E getElement() {
return element;
}

double getWeight() {
return weight;
}

void setWeight(double weight) {
this.weight = weight;
}

double getAccumulatedWeight() {
return accumulatedWeight;
}

void addToAccumulatedWeight(double delta) {
accumulatedWeight += delta;
}
}

/**
* This map maps each element stored in this probability distribution to its
* respective entry.
*/
private final Map<E, Entry<E>> map = new HashMap<>();

/**
* Holds the actual distribution entries.
*/
private final List<Entry<E>> storage = new ArrayList<>();

/**
* Constructs this probability distribution with default random number
* generator.
*/
public BinarySearchProbabilityDistribution() {
this(new Random());
}

/**
* Constructs this probability distribution with given random number
* generator.
*
* @param random the random number generator.
*/
public BinarySearchProbabilityDistribution(Random random) {
super(random);
}

/**
* {@inheritDoc }
*/
@Override
public boolean addElement(E element, double weight) {
checkWeight(weight);
Entry<E> entry = map.get(element);

if (entry == null) {
// A new entry starts where the current total weight ends.
entry = new Entry<>(element, weight, totalWeight);
map.put(element, entry);
storage.add(entry);
} else {
// Grow the existing entry and shift every entry after it.
entry.setWeight(entry.getWeight() + weight);

for (int i = storage.indexOf(entry) + 1; i < storage.size(); ++i) {
storage.get(i).addToAccumulatedWeight(weight);
}
}

totalWeight += weight;
return true;
}

/**
* {@inheritDoc }
*/
@Override
public E sampleElement() {
checkNotEmpty();
double value = totalWeight * random.nextDouble();

int left = 0;
int right = storage.size() - 1;

while (left < right) {
int middle = left + ((right - left) >> 1);
Entry<E> middleEntry = storage.get(middle);
double lowerBound = middleEntry.getAccumulatedWeight();
double upperBound = lowerBound + middleEntry.getWeight();

if (lowerBound <= value && value < upperBound) {
return middleEntry.getElement();
}

if (value < lowerBound) {
right = middle - 1;
} else {
left = middle + 1;
}
}

return storage.get(left).getElement();
}

/**
* {@inheritDoc }
*/
@Override
public boolean contains(E element) {
return map.containsKey(element);
}

/**
* {@inheritDoc }
*/
@Override
public boolean removeElement(E element) {
Entry<E> entry = map.remove(element);

if (entry == null) {
return false;
}

int index = storage.indexOf(entry);
storage.remove(index);

// Shift every entry that followed the removed one.
for (int i = index; i < storage.size(); ++i) {
storage.get(i).addToAccumulatedWeight(-entry.getWeight());
}

totalWeight -= entry.getWeight();
return true;
}

/**
* {@inheritDoc }
*/
@Override
public void clear() {
map.clear();
storage.clear();
totalWeight = 0.0;
}

/**
* {@inheritDoc }
*/
@Override
public boolean isEmpty() {
return storage.isEmpty();
}

/**
* {@inheritDoc }
*/
@Override
public int size() {
return storage.size();
}

private void checkNotEmpty() {
checkNotEmpty(storage.size());
}

public static void main(String[] args) {
BinarySearchProbabilityDistribution<Integer> d = new BinarySearchProbabilityDistribution<>();

d.removeElement(3);

System.out.println("");

binarySearchProbabilityDistributionDemo();
}

private static void binarySearchProbabilityDistributionDemo() {
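BinarySearchProbabilityDistribution<Integer> pd =
new BinarySearchProbabilityDistribution<>();

// Example weights, chosen only for illustration.
pd.addElement(0, 1.0);
pd.addElement(1, 1.0);
pd.addElement(2, 1.0);
pd.addElement(3, 3.0);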

int[] counts = new int[4];

for (int i = 0; i < 100; ++i) {
Integer myint = pd.sampleElement();
counts[myint]++;
System.out.println(myint);
}

System.out.println(Arrays.toString(counts));
}
}


#### Efficient removal of elements: LinkedListProbabilityDistribution.java

The third data structure improves the removal operation to constant time in all cases, yet runs the sampling operation in linear time in the worst and average cases. The idea is to use a doubly-linked list for storing the entries. We also maintain a hash table based map that maps each element in the distribution to its node in the linked list. To add a new element, we create a new linked list node for it and put the node into the map; both actions run in constant time, with the possible exception of the situation where the hash map has to be expanded, which is $\Theta(n)$. Even so, adding to a hash map runs in amortized constant time. As for the removal operation, we access the linked list node through the map and unlink it; both actions run in constant time as well. However, in order to sample an element, we need to march through about half the list on average and through the entire list in the worst case, which is clearly linear time. The implementation follows:

package net.coderodde.stat.support;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import net.coderodde.stat.AbstractProbabilityDistribution;

/**
* This class implements a probability distribution relying on a linked list.
* The running times of the main methods are as follows:
*
* <table>
* <tr>
* <td>Method</td>
* <td>Complexity</td>
* </tr>
* <tr>
* <td><tt>addElement</tt></td>
* <td><tt>amortized constant time</tt>,</td>
* </tr>
* <tr>
* <td><tt>sampleElement</tt></td>
* <td><tt>O(n)</tt>,</td>
* </tr>
* <tr>
* <td><tt>removeElement</tt></td>
* <td><tt>O(1)</tt>.</td>
* </tr>
* </table>
*
* This probability distribution class is best used whenever it is modified
* frequently compared to the number of queries made.
*
* @param <E> the actual type of the elements stored in this probability
*            distribution.
*
* @author Rodion "rodde" Efremov
* @version 1.61 (Sep 30, 2016)
*/
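public class LinkedListProbabilityDistribution<E>
extends AbstractProbabilityDistribution<E> {

private static final class LinkedListNode<E> {

private final E element;
private double weight;
private LinkedListNode<E> prev;
private LinkedListNode<E> next;

LinkedListNode(E element, double weight) {
this.element = element;
this.weight  = weight;
}

E getElement() {
return element;
}

double getWeight() {
return weight;
}

void setWeight(double weight) {
this.weight = weight;
}

LinkedListNode<E> getPreviousLinkedListNode() {
return prev;
}

LinkedListNode<E> getNextLinkedListNode() {
return next;
}

void setPreviousLinkedListNode(LinkedListNode<E> node) {
prev = node;
}

void setNextLinkedListNode(LinkedListNode<E> node) {
next = node;
}
}

/**
* This map maps the elements to their respective linked list nodes.
*/
private final Map<E, LinkedListNode<E>> map = new HashMap<>();

/**
* Stores the very first linked list node in this probability distribution.
*/
private LinkedListNode<E> linkedListHead;

/**
* Stores the very last linked list node in this probability distribution.
*/
private LinkedListNode<E> linkedListTail;

/**
* Construct a new probability distribution.
*/
public LinkedListProbabilityDistribution() {
super();
}

/**
* Constructs a new probability distribution using the input random number
* generator.
*
* @param random the random number generator to use.
*/
public LinkedListProbabilityDistribution(Random random) {
super(random);
}

/**
* {@inheritDoc }
*/
@Override
public boolean addElement(E element, double weight) {
checkWeight(weight);
LinkedListNode<E> node = map.get(element);

if (node == null) {
node = new LinkedListNode<>(element, weight);

if (linkedListHead == null) {
linkedListHead = node;
linkedListTail = node;
} else {
// Append the new node to the tail of the list.
linkedListTail.setNextLinkedListNode(node);
node.setPreviousLinkedListNode(linkedListTail);
linkedListTail = node;
}

map.put(element, node);
} else {
node.setWeight(node.getWeight() + weight);
}

totalWeight += weight;
return true;
}

/**
* {@inheritDoc }
*/
@Override
public E sampleElement() {
checkNotEmpty(map.size());
double value = random.nextDouble() * totalWeight;

for (LinkedListNode<E> node = linkedListHead;
node != null;
node = node.getNextLinkedListNode()) {
if (value < node.getWeight()) {
return node.getElement();
}

value -= node.getWeight();
}

throw new IllegalStateException("Should not get here.");
}

/**
* {@inheritDoc }
*/
@Override
public boolean contains(E element) {
return map.containsKey(element);
}

/**
* {@inheritDoc }
*/
@Override
public boolean removeElement(E element) {
LinkedListNode<E> node = map.remove(element);

if (node == null) {
return false;
}

totalWeight -= node.getWeight();
unlink(node);
return true;
}

/**
* {@inheritDoc }
*/
@Override
public void clear() {
totalWeight = 0.0;
map.clear();
linkedListHead = null;
linkedListTail = null;
}

/**
* {@inheritDoc }
*/
@Override
public boolean isEmpty() {
return map.isEmpty();
}

/**
* {@inheritDoc }
*/
@Override
public int size() {
return map.size();
}

private void unlink(LinkedListNode<E> node) {
LinkedListNode<E> left  = node.getPreviousLinkedListNode();
LinkedListNode<E> right = node.getNextLinkedListNode();

if (left != null) {
left.setNextLinkedListNode(right);
} else {
linkedListHead = right;
}

if (right != null) {
right.setPreviousLinkedListNode(left);
} else {
linkedListTail = left;
}
}

public static void main(String[] args) {
LinkedListProbabilityDistribution<Integer> pd =
new LinkedListProbabilityDistribution<>();

// Example weights, chosen only for illustration.
pd.addElement(0, 1.0);
pd.addElement(1, 1.0);
pd.addElement(2, 1.0);
pd.addElement(3, 3.0);

int[] counts = new int[4];

for (int i = 0; i < 100; ++i) {
Integer myint = pd.sampleElement();
counts[myint]++;
System.out.println(myint);
}

System.out.println(Arrays.toString(counts));
}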
}
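As a side note, `java.util.LinkedHashMap` already combines a hash table with a doubly-linked list running through its entries, so the same trade-off (cheap removal, linear sampling) can be sketched on top of it. This is an alternative illustration rather than the implementation above; the class `LinkedHashMapDistributionSketch` and its method names are hypothetical, and the sampled value is passed in explicitly to keep the example deterministic.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class LinkedHashMapDistributionSketch<E> {

    // Iterates in insertion order; removal by key takes expected constant time.
    private final Map<E, Double> weights = new LinkedHashMap<>();
    private double totalWeight;

    void add(E element, double weight) {
        // Adds the weight to the existing one if the element is present.
        weights.merge(element, weight, Double::sum);
        totalWeight += weight;
    }

    boolean remove(E element) {
        Double weight = weights.remove(element);

        if (weight == null) {
            return false;
        }

        totalWeight -= weight;
        return true;
    }

    double getTotalWeight() {
        return totalWeight;
    }

    // Linear scan over the entries; expects 0 <= value < getTotalWeight().
    // Returns null if the sketch holds no elements.
    E elementFor(double value) {
        E last = null;

        for (Map.Entry<E, Double> entry : weights.entrySet()) {
            if (value < entry.getValue()) {
                return entry.getKey();
            }

            value -= entry.getValue();
            last = entry.getKey();
        }

        return last;
    }

    public static void main(String[] args) {
        LinkedHashMapDistributionSketch<String> d =
                new LinkedHashMapDistributionSketch<>();
        d.add("A", 1.0);
        d.add("B", 1.0);
        d.add("C", 3.0);
        System.out.println(d.elementFor(1.5)); // B
        d.remove("B");
        System.out.println(d.elementFor(1.5)); // C
    }
}
```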


#### Efficient binary tree based implementation: BinaryTreeProbabilityDistribution.java

Finally, we can discuss the tree based probability distribution. The main idea is to organize the element entries in a balanced binary tree. For that we define two different categories of nodes:

• leaf nodes,
• relay nodes.

Only leaf nodes store the actual elements. Each relay or leaf node $n$ is augmented with the following fields:

• element: stores the actual element in a leaf node; unused in the relay nodes,
• weight: if the node is a leaf node, stores the weight of element. Otherwise (if the node is a relay node), stores the sum of all weights in all the leaves of the subtree rooted at $n$,
• isRelayNode: set to true only in relay nodes,
• leftChild: the left subtree of $n$, or null if this node is a leaf node,
• rightChild: the right subtree of $n$, or null if this node is a leaf node,
• parent: the parent relay node of this node,
• numberOfLeafNodes: if $n$ is a leaf node, the field is set to one, otherwise it stores the number of leaf nodes in the subtree rooted at $n$.

Next, the invariant is that relay nodes have exactly two child nodes. Any of the two children of a relay node can be relay or leaf nodes.

###### Inserting an element

In order to insert an element in the data structure, we start from the root node. If the root is null (the distribution is empty), we create a leaf and set it as the root. Otherwise, if the root is a leaf $n$ (only one element in the structure), we create a new relay node $r$, attach $n$ and the new leaf as its children, and set $r$ as the root. In case the root node is a relay node, we consult its two children and descend to the one with the smaller numberOfLeafNodes value, repeating this until we reach a leaf, which we then replace with a relay node holding that leaf and the new one. That way we keep the entire tree balanced. Note that this is possible because we do not have to maintain any order in the tree: upon insertion, we just put the new leaf node into the subtree with fewer leaves. See Figure 1 for a demonstration of inserting some elements in the tree.

Figure 1: Inserting elements
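The balancing effect of the insertion rule can be observed with a stripped-down tree that tracks only leaf counts. This is a sketch under simplifying assumptions (no elements, no weights); `BalancedInsertionSketch`, `insertLeaf` and `height` are names invented for this illustration.

```java
final class BalancedInsertionSketch {

    static final class Node {
        Node left;
        Node right;
        Node parent;
        int leaves = 1; // 1 for a leaf, number of leaves below for a relay

        boolean isRelay() {
            return left != null;
        }
    }

    Node root;

    void insertLeaf() {
        Node leaf = new Node();

        if (root == null) {
            root = leaf;
            return;
        }

        // Descend towards the child with fewer leaves (ties go right).
        Node current = root;

        while (current.isRelay()) {
            current = current.left.leaves < current.right.leaves
                    ? current.left
                    : current.right;
        }

        // Put a relay node where the reached leaf was.
        Node relay = new Node();
        relay.parent = current.parent;
        relay.left = current;
        relay.right = leaf;
        relay.leaves = 2;

        if (current.parent == null) {
            root = relay;
        } else if (current.parent.left == current) {
            current.parent.left = relay;
        } else {
            current.parent.right = relay;
        }

        current.parent = relay;
        leaf.parent = relay;

        // Every ancestor of the relay gained one leaf.
        for (Node n = relay.parent; n != null; n = n.parent) {
            n.leaves++;
        }
    }

    // Number of edges on the longest path from 'node' down to a leaf.
    static int height(Node node) {
        if (node == null || !node.isRelay()) {
            return 0;
        }

        return 1 + Math.max(height(node.left), height(node.right));
    }

    public static void main(String[] args) {
        BalancedInsertionSketch tree = new BalancedInsertionSketch();

        for (int i = 1; i <= 8; ++i) {
            tree.insertLeaf();
            System.out.println(i + " leaves, height " + height(tree.root));
        }
    }
}
```

After eight insertions the sketch holds a perfect tree of height 3.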

To delete an element from the binary tree probability distribution, let $n$ be the leaf node containing the element to remove. If $n$ is the root, we simply set the root to null. Otherwise, the parent of $n$ is a relay node $r$ by the data structure invariant. We take the other child of $r$ (call it $n'$) and replace $r$ with $n'$. Note that even if we keep removing, say, the rightmost leaf node, the entire tree remains balanced (try reading Figure 1 “backwards” for an illustration).
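The pointer surgery of a deletion is short enough to sketch separately. The example below builds a three-leaf tree by hand and removes leaves from it; it is an illustration only, with hypothetical names (`LeafRemovalSketch`, `remove`) and with relay weights kept as the sum of the leaf weights below them.

```java
final class LeafRemovalSketch {

    static final class Node {
        final String name; // null for relay nodes
        double weight;
        Node left;
        Node right;
        Node parent;

        // Creates a leaf.
        Node(String name, double weight) {
            this.name = name;
            this.weight = weight;
        }

        // Creates a relay node over two existing nodes.
        Node(Node left, Node right) {
            this.name = null;
            this.left = left;
            this.right = right;
            this.weight = left.weight + right.weight;
            left.parent = this;
            right.parent = this;
        }
    }

    // Removes 'leaf' from the tree rooted at 'root'; returns the new root.
    static Node remove(Node root, Node leaf) {
        Node relay = leaf.parent;

        if (relay == null) {
            return null; // The leaf was the only node.
        }

        Node sibling = relay.left == leaf ? relay.right : relay.left;
        Node grandparent = relay.parent;
        sibling.parent = grandparent;

        if (grandparent == null) {
            return sibling; // The sibling becomes the new root.
        }

        // The sibling takes the place of the relay node.
        if (grandparent.left == relay) {
            grandparent.left = sibling;
        } else {
            grandparent.right = sibling;
        }

        // Every remaining ancestor loses the weight of the removed leaf.
        for (Node n = grandparent; n != null; n = n.parent) {
            n.weight -= leaf.weight;
        }

        return root;
    }

    public static void main(String[] args) {
        Node a = new Node("A", 1.0);
        Node b = new Node("B", 1.0);
        Node c = new Node("C", 3.0);
        Node root = new Node(new Node(a, b), c);

        root = remove(root, a);
        System.out.println(root.left.name + " " + root.weight); // B 4.0
    }
}
```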

Finally, as for sampling, we toss a coin $c \in [0, \sum w_i)$ and descend the tree from the root until we reach the leaf node whose “bucket” contains $c$. Since the depth of a balanced tree is logarithmic in its size, we have the running time of $\Theta(\log n)$ for all cases. The Java implementation follows:

package net.coderodde.stat.support;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import net.coderodde.stat.AbstractProbabilityDistribution;

public class BinaryTreeProbabilityDistribution<E>
extends AbstractProbabilityDistribution<E> {

private static final class Node<E> {

private final E element;
private double weight;
private final boolean isRelayNode;
private Node<E> leftChild;
private Node<E> rightChild;
private Node<E> parent;
private int numberOfLeafNodes;

Node(E element, double weight) {
this.element           = element;
this.weight            = weight;
this.numberOfLeafNodes = 1;
this.isRelayNode       = false;
}

Node(double weight) {
this.element           = null;
this.weight            = weight;
this.numberOfLeafNodes = 1;
this.isRelayNode       = true;
}

@Override
public String toString() {
if (isRelayNode) {
return "[" + String.format("%.3f", getWeight()) + " : "
+ numberOfLeafNodes + "]";
}

return "(" + String.format("%.3f", getWeight()) + " : "
+ element + ")";
}

E getElement() {
return element;
}

double getWeight() {
return weight;
}

void setWeight(double weight) {
this.weight = weight;
}

int getNumberOfLeaves() {
return numberOfLeafNodes;
}

void setNumberOfLeaves(int numberOfLeaves) {
this.numberOfLeafNodes = numberOfLeaves;
}

Node<E> getLeftChild() {
return leftChild;
}

void setLeftChild(Node<E> block) {
this.leftChild = block;
}

Node<E> getRightChild() {
return rightChild;
}

void setRightChild(Node<E> block) {
this.rightChild = block;
}

Node<E> getParent() {
return parent;
}

void setParent(Node<E> block) {
this.parent = block;
}

boolean isRelayNode() {
return isRelayNode;
}
}

private final Map<E, Node<E>> map = new HashMap<>();
private Node<E> root;

public BinaryTreeProbabilityDistribution() {
this(new Random());
}

public BinaryTreeProbabilityDistribution(Random random) {
super(random);
}

@Override
public boolean addElement(E element, double weight) {
checkWeight(weight);
Node<E> node = map.get(element);

if (node == null) {
node = new Node<>(element, weight);
insert(node);
map.put(element, node);
} else {
// The leaf and all the relay nodes above it must grow by 'weight'.
for (Node<E> n = node; n != null; n = n.getParent()) {
n.setWeight(n.getWeight() + weight);
}
}

totalWeight += weight;
return true;
}

@Override
public boolean contains(E element) {
return map.containsKey(element);
}

@Override
public E sampleElement() {
checkNotEmpty(map.size());
double value = totalWeight * random.nextDouble();
Node<E> node = root;

while (node.isRelayNode()) {
if (value < node.getLeftChild().getWeight()) {
node = node.getLeftChild();
} else {
value -= node.getLeftChild().getWeight();
node = node.getRightChild();
}
}

return node.getElement();
}

@Override
public boolean removeElement(E element) {
Node<E> node = map.remove(element);

if (node == null) {
return false;
}

delete(node);
totalWeight -= node.getWeight();
return true;
}

@Override
public void clear() {
root = null;
map.clear();
totalWeight = 0.0;
}

@Override
public boolean isEmpty() {
return map.isEmpty();
}

@Override
public int size() {
return map.size();
}

private void bypassLeafNode(Node<E> leafNodeToBypass,
Node<E> newNode) {
Node<E> relayNode = new Node<>(leafNodeToBypass.getWeight());
Node<E> parentOfCurrentNode = leafNodeToBypass.getParent();

relayNode.setLeftChild(leafNodeToBypass);
relayNode.setRightChild(newNode);

leafNodeToBypass.setParent(relayNode);
newNode.setParent(relayNode);

if (parentOfCurrentNode == null) {
root = relayNode;
} else if (parentOfCurrentNode.getLeftChild() == leafNodeToBypass) {
relayNode.setParent(parentOfCurrentNode);
parentOfCurrentNode.setLeftChild(relayNode);
} else {
relayNode.setParent(parentOfCurrentNode);
parentOfCurrentNode.setRightChild(relayNode);
}

// The relay node starts with the weight of the bypassed leaf and a leaf
// count of one; add the new leaf to it and to all its ancestors.
updateMetadata(relayNode, newNode.getWeight(), 1);
}

private void insert(Node<E> node) {
if (root == null) {
root = node;
return;
}

Node<E> currentNode = root;

while (currentNode.isRelayNode()) {
if (currentNode.getLeftChild().getNumberOfLeaves() <
currentNode.getRightChild().getNumberOfLeaves()) {
currentNode = currentNode.getLeftChild();
} else {
currentNode = currentNode.getRightChild();
}
}

bypassLeafNode(currentNode, node);
}

private void delete(Node<E> leafToDelete) {
Node<E> relayNode = leafToDelete.getParent();

if (relayNode == null) {
root = null;
return;
}

Node<E> parentOfRelayNode = relayNode.getParent();
Node<E> siblingLeaf = relayNode.getLeftChild() == leafToDelete ?
relayNode.getRightChild() :
relayNode.getLeftChild();

if (parentOfRelayNode == null) {
root = siblingLeaf;
siblingLeaf.setParent(null);
return;
}

if (parentOfRelayNode.getLeftChild() == relayNode) {
parentOfRelayNode.setLeftChild(siblingLeaf);
} else {
parentOfRelayNode.setRightChild(siblingLeaf);
}

siblingLeaf.setParent(parentOfRelayNode);

// Subtract the removed leaf from every remaining ancestor.
updateMetadata(parentOfRelayNode, -leafToDelete.getWeight(), -1);
}

private void updateMetadata(Node<E> node,
double weightDelta,
int nodeDelta) {
while (node != null) {
node.setNumberOfLeaves(node.getNumberOfLeaves() + nodeDelta);
node.setWeight(node.getWeight() + weightDelta);
node = node.getParent();
}
}

public static void main(String[] args) {
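AbstractProbabilityDistribution<Integer> pd =
new BinaryTreeProbabilityDistribution<>();

// Example weights, chosen only for illustration.
pd.addElement(0, 1.0);
pd.addElement(1, 1.0);
pd.addElement(2, 1.0);
pd.addElement(3, 3.0);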

int[] counts = new int[4];

for (int i = 0; i < 1000; ++i) {
Integer myint = pd.sampleElement();
counts[myint]++;
}

System.out.println(Arrays.toString(counts));
}
}
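To trace the descent performed by sampleElement by hand, it helps to fix the random value and use a tiny hand-built tree. The sketch below does that for the weights 1, 1 and 3; the class `TreeDescentSketch` and its members are hypothetical and exist only for this illustration.

```java
final class TreeDescentSketch {

    static final class Node {
        final String element; // null for relay nodes
        final double weight;  // leaf weight, or sum of the leaf weights below
        final Node left;
        final Node right;

        // Creates a leaf.
        Node(String element, double weight) {
            this.element = element;
            this.weight = weight;
            this.left = null;
            this.right = null;
        }

        // Creates a relay node over two existing nodes.
        Node(Node left, Node right) {
            this.element = null;
            this.weight = left.weight + right.weight;
            this.left = left;
            this.right = right;
        }
    }

    // Finds the leaf whose bucket contains 'value'; expects
    // 0 <= value < root.weight.
    static String elementFor(Node root, double value) {
        Node node = root;

        while (node.left != null) {
            if (value < node.left.weight) {
                node = node.left;
            } else {
                // Skip the whole left subtree.
                value -= node.left.weight;
                node = node.right;
            }
        }

        return node.element;
    }

    public static void main(String[] args) {
        Node root = new Node(new Node(new Node("A", 1.0), new Node("B", 1.0)),
                             new Node("C", 3.0));

        System.out.println(elementFor(root, 0.5)); // A
        System.out.println(elementFor(root, 1.5)); // B
        System.out.println(elementFor(root, 2.0)); // C
    }
}
```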

ArrayProbabilityDistribution

| Operation | Worst case | Average case | Best case |
|---|---|---|---|
| addElement | $\Theta(n)$ | $\Theta(1)$ | $\Theta(1)$ |
| sampleElement | $\Theta(n)$ | $\Theta(n)$ | $\Theta(1)$ |
| removeElement | $\Theta(n)$ | $\Theta(n)$ | $\Theta(1)$ |

BinarySearchProbabilityDistribution

| Operation | Worst case | Average case | Best case |
|---|---|---|---|
| addElement | $\Theta(n)$ | $\Theta(1)$ | $\Theta(1)$ |
| sampleElement | $\Theta(\log n)$ | $\Theta(\log n)$ | $\Theta(1)$ |
| removeElement | $\Theta(n)$ | $\Theta(n)$ | $\Theta(1)$ |

LinkedListProbabilityDistribution

| Operation | Worst case | Average case | Best case |
|---|---|---|---|
| addElement | $\Theta(n)$ | $\Theta(1)$ | $\Theta(1)$ |
| sampleElement | $\Theta(n)$ | $\Theta(n)$ | $\Theta(1)$ |
| removeElement | $\Theta(1)$ | $\Theta(1)$ | $\Theta(1)$ |

BinaryTreeProbabilityDistribution

| Operation | Worst case | Average case | Best case |
|---|---|---|---|
| addElement | $\Theta(\log n)$ | $\Theta(\log n)$ | $\Theta(\log n)$ |
| sampleElement | $\Theta(\log n)$ | $\Theta(\log n)$ | $\Theta(\log n)$ |
| removeElement | $\Theta(\log n)$ | $\Theta(\log n)$ | $\Theta(\log n)$ |