Adaptive tree sort for integers in Java

Introduction

In this post, I discuss a tree sort algorithm that sorts primitive integer arrays somewhat more efficiently than java.util.Arrays.sort when the input contains few distinct values. A general, stable tree sort scans the requested array range and puts each scanned element into a (hopefully) balanced binary search tree. Whenever we scan an element that is already in the tree, there are two cases to consider:

  • if we are sorting an array of primitive component type (int, float, etc.) we could just hold a counter variable in each tree node, counting the number of occurrences of the tree node key so far in the range being scanned, or
  • if we are sorting an array of non-primitive component type (Objects in general), I see no other option than replacing the counter variable from the previous item with an ordered list of the equal objects.

As for sorting objects: once we have built the tree mapping each distinct object to the list of all objects considered “equal” to it, we traverse the tree keys in order and, at each key, dump its list into the input array range. The same goes for sorting primitive arrays: take the least key (say, x) and copy it to the beginning of the input range c times, where c is the number of times x occurred in the input range; then move to the second least key, dump it, and so on.
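Both variants can be sketched with java.util.TreeMap standing in for the balanced search tree. This is only an illustration, not the implementation given later in this post; the class name TreeSortSketch and its method names are made up for the example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

final class TreeSortSketch {

    // Primitive case: each tree key carries only an occurrence counter.
    public static void sortInts(int[] array) {
        Map<Integer, Integer> counts = new TreeMap<>();

        for (int element : array) {
            counts.merge(element, 1, Integer::sum);
        }

        int index = 0;

        // A TreeMap iterates over its keys in ascending order.
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            int key = entry.getKey();
            int count = entry.getValue();

            for (int i = 0; i < count; ++i) {
                array[index++] = key;
            }
        }
    }

    // Object case: each tree key carries the list of "equal" objects in
    // scan order, which is what keeps the sort stable.
    public static <T> void sortStable(T[] array,
                                      Comparator<? super T> comparator) {
        Map<T, List<T>> buckets = new TreeMap<>(comparator);

        for (T element : array) {
            buckets.computeIfAbsent(element, k -> new ArrayList<>())
                   .add(element);
        }

        int index = 0;

        for (List<T> bucket : buckets.values()) {
            for (T element : bucket) {
                array[index++] = element;
            }
        }
    }

    public static void main(String[] args) {
        int[] ints = { 5, 1, 5, -2, 1, 5 };
        sortInts(ints);
        System.out.println(Arrays.toString(ints)); // [-2, 1, 1, 5, 5, 5]

        String[] words = { "bb", "a", "cc", "d", "aa" };
        // Compare by length only, so equal-length words keep their order.
        sortStable(words, Comparator.comparingInt(String::length));
        System.out.println(Arrays.toString(words)); // [a, d, bb, cc, aa]
    }
}
```

Note that every element, repeated or not, pays for a logarithmic tree lookup in this sketch.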

The algorithm

Clearly, under the assumption that the underlying tree is balanced, the worst-case running time is \Theta(n \log n). We can, however, improve this to \Theta(n + k \log k) on average, where k is the number of distinct integers in the requested array range. Beyond the tree mapping each distinct key to the number of its occurrences so far, the sort maintains a hash table mapping each integer to its tree node. Given an integer x, we first ask the hash table whether x is already mapped to a tree node. If so, we just increment the counter of that node, so adding an integer we have already seen takes constant time on average. If, however, x is not yet in the tree (and the table), we add it to both in total time \mathcal{O}(\log k).
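The interplay of the two structures can be sketched with library collections: a java.util.HashMap plays the hash table and a java.util.TreeMap plays the balanced tree, with both sharing one counter cell per distinct key. This is a rough illustration under those substitutions; the class name HashedTreeSortSketch is hypothetical.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

final class HashedTreeSortSketch {

    public static void sort(int[] array) {
        // Both maps point at the same one-element counter per distinct key.
        Map<Integer, int[]> table = new HashMap<>();
        Map<Integer, int[]> tree = new TreeMap<>();

        for (int element : array) {
            int[] counter = table.get(element);

            if (counter != null) {
                // Seen before: the tree is not touched at all.
                counter[0]++;
            } else {
                // New distinct key: one O(log k) tree insertion. This
                // branch runs only k times over the whole scan.
                counter = new int[] { 1 };
                table.put(element, counter);
                tree.put(element, counter);
            }
        }

        int index = 0;

        // In-order traversal of the k keys, dumping each one count times.
        for (Map.Entry<Integer, int[]> entry : tree.entrySet()) {
            int key = entry.getKey();
            int count = entry.getValue()[0];
            Arrays.fill(array, index, index + count, key);
            index += count;
        }
    }

    public static void main(String[] args) {
        int[] data = { 7, 3, 7, 7, -1, 3 };
        sort(data);
        System.out.println(Arrays.toString(data)); // [-1, 3, 3, 7, 7, 7]
    }
}
```

The n hash lookups account for the n term of the bound and the k tree insertions for the k \log k term.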

The last important point is that, for sufficiently small k, even when the data is otherwise random, the tree sort under discussion outperforms java.util.Arrays.sort by a factor of four. However, measurements revealed that on fully random data with large k the hidden constants are huge, and the sort becomes impractical. To circumvent the issue, the implementation switches to java.util.Arrays.sort as soon as the number of distinct integers encountered exceeds a particular threshold.
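The fallback is safe because the scan only reads the input; nothing is written back until the dump phase, so the untouched array can be handed to java.util.Arrays.sort at any moment. A minimal sketch of that idea follows, again with a TreeMap in place of the hand-written tree. The class name FallbackSketch and the maxDistinct parameter are assumptions made for the example.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

final class FallbackSketch {

    // Returns true if the counting path finished, false if the sketch
    // gave up and delegated to Arrays.sort.
    public static boolean sort(int[] array, int maxDistinct) {
        Map<Integer, Integer> counts = new TreeMap<>();

        for (int element : array) {
            counts.merge(element, 1, Integer::sum);

            if (counts.size() > maxDistinct) {
                // Too many distinct keys. The array has not been modified
                // yet, so the library sort can simply take over.
                Arrays.sort(array);
                return false;
            }
        }

        int index = 0;

        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            int key = entry.getKey();
            int count = entry.getValue();
            Arrays.fill(array, index, index + count, key);
            index += count;
        }

        return true;
    }

    public static void main(String[] args) {
        int[] few = { 2, 1, 2, 1, 2 };
        System.out.println(sort(few, 2) + " " + Arrays.toString(few));
        // true [1, 1, 2, 2, 2]

        int[] many = { 4, 3, 2, 1 };
        System.out.println(sort(many, 2) + " " + Arrays.toString(many));
        // false [1, 2, 3, 4]
    }
}
```

Either way the caller ends up with a sorted array; the threshold only decides which algorithm does the work.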

net.coderodde.util.sorting.IntTreeSort.java
package net.coderodde.util.sorting;

import java.util.Arrays;

/**
 * This class implements a funky tree sort algorithm for sorting integers.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.6 (Feb 21, 2016)
 */
public class IntTreeSort {

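    /**
     * Sorts the entire array into ascending order.
     */
    public static void sort(int[] array) {
        sort(array, 0, array.length);
    }

    /**
     * Sorts the range array[fromIndex], ..., array[toIndex - 1] into
     * ascending order. Ranges with fewer than two elements are left as is.
     */
    public static void sort(int[] array, int fromIndex, int toIndex) {
        if (toIndex - fromIndex < 2) {
            return;
        }

        new IntTreeSort(array, fromIndex, toIndex).sort();
    }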

    private final int[] array;
    private final int fromIndex;
    private final int toIndex;
    private final HashTableEntry[] table;
    private final int mask;
    private TreeNode root;

    private IntTreeSort(int[] array, int fromIndex, int toIndex) {
        this.array     = array;
        this.fromIndex = fromIndex;
        this.toIndex   = toIndex;

        int capacity   = computeCapacity(toIndex - fromIndex);

        this.table     = new HashTableEntry[capacity];
        this.mask      = capacity - 1;
    }

    private static int computeCapacity(int length) {
        int ret = 1;

        while (ret < length) {
            ret <<= 1;
        }

        return ret;
    }

    private static final class TreeNode {
        int key;
        int count;
        int height;
        TreeNode left;
        TreeNode right;
        TreeNode parent;

        TreeNode(int key) {
            this.key = key;
            this.count = 1;
        }
    }

    private static final class HashTableEntry {
        int key;
        TreeNode treeNode;
        HashTableEntry nextEntry;

        HashTableEntry(int key, TreeNode treeNode, HashTableEntry nextEntry) {
            this.key = key;
            this.treeNode = treeNode;
            this.nextEntry = nextEntry;
        }
    }

    private static int height(TreeNode node) {
        return node == null ? -1 : node.height;
    }

    private int index(int element) {
        return element & mask;
    }

    private TreeNode findTreeNode(int element, int elementHash) {
        HashTableEntry entry = table[elementHash];

        while (entry != null && entry.treeNode.key != element) {
            entry = entry.nextEntry;
        }

        return entry == null ? null : entry.treeNode;
    }

    private void sort() {
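        // Once the number of distinct keys exceeds 1/240 of the range
        // length, give up and delegate to Arrays.sort.
        int maximumAllowedNodes = (toIndex - fromIndex) / 240;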
        int treeNodeCount = 0;
        int initialKey = array[fromIndex];
        root = new TreeNode(initialKey);
        table[index(initialKey)] = new HashTableEntry(initialKey,
                                                      root, 
                                                      null);

        for (int i = fromIndex + 1; i < toIndex; ++i) {
            int currentElement = array[i];
            int currentElementHash = index(currentElement);

            TreeNode treeNode = findTreeNode(currentElement, 
                                             currentElementHash);

            if (treeNode != null) {
                treeNode.count++;
            } else {
                ++treeNodeCount;
                
                if (treeNodeCount > maximumAllowedNodes) {
                    Arrays.sort(array, fromIndex, toIndex);
                    return;
                }
                
                TreeNode newnode = add(currentElement);
                HashTableEntry newentry =
                        new HashTableEntry(currentElement,
                                           newnode,
                                           table[currentElementHash]);
                table[currentElementHash] = newentry;
            }
        }

        TreeNode node = minimum(root);
        int index = fromIndex;

        while (node != null) {
            int key = node.key;
            int count = node.count;

            for (int i = 0; i < count; ++i) {
                array[index++] = key;
            }

            node = successor(node);
        }
    }

    private TreeNode minimum(TreeNode node) {
        while (node.left != null) {
            node = node.left;
        }

        return node;
    }

    private TreeNode successor(TreeNode node) {
        if (node.right != null) {
            return minimum(node.right);
        }

        TreeNode parent = node.parent;

        while (parent != null && parent.right == node) {
            node = parent;
            parent = parent.parent;
        }

        return parent;
    }

    private TreeNode add(int key) {
        TreeNode parent = null;
        TreeNode node = root;

        while (node != null) {
            if (key < node.key) {
                parent = node;
                node = node.left;
            } else if (key > node.key) {
                parent = node;
                node = node.right;
            } else {
                break;
            }
        }

        TreeNode newnode = new TreeNode(key);

        if (key < parent.key) {
            parent.left = newnode;
        } else {
            parent.right = newnode;
        }

        newnode.parent = parent;
        // Start at the new leaf so that the height of its parent is
        // refreshed as well.
        fixAfterInsertion(newnode);
        return newnode;
    }

    private void fixAfterInsertion(TreeNode node) {
        TreeNode parent = node.parent;
        TreeNode grandParent;
        TreeNode subTree;

        while (parent != null) {
            if (height(parent.left) == height(parent.right) + 2) {
                grandParent = parent.parent;

                if (height(parent.left.left) >= height(parent.left.right)) {
                    subTree = rightRotate(parent);
                } else {
                    subTree = leftRightRotate(parent);
                }

                if (grandParent == null) {
                    root = subTree;
                } else if (grandParent.left == parent) {
                    grandParent.left = subTree;
                } else {
                    grandParent.right = subTree;
                }

                if (grandParent != null) {
                    grandParent.height = Math.max(
                            height(grandParent.left),
                            height(grandParent.right)) + 1;
                }

                return;
            } else if (height(parent.right) == height(parent.left) + 2) {
                grandParent = parent.parent;

                if (height(parent.right.right) >= height(parent.right.left)) {
                    subTree = leftRotate(parent);
                } else {
                    subTree = rightLeftRotate(parent);
                }

                if (grandParent == null) {
                    root = subTree;
                } else if (grandParent.left == parent) {
                    grandParent.left = subTree;
                } else {
                    grandParent.right = subTree;
                }

                if (grandParent != null) {
                    grandParent.height =
                            Math.max(height(grandParent.left),
                                     height(grandParent.right)) + 1;
                }

                return;
            }

            parent.height = Math.max(height(parent.left), 
                                     height(parent.right)) + 1;
            parent = parent.parent;
        }
    }

    private TreeNode leftRotate(TreeNode node1) {
        TreeNode node2 = node1.right;
        node2.parent = node1.parent;
        node1.parent = node2;
        node1.right = node2.left;
        node2.left = node1;

        if (node1.right != null) {
            node1.right.parent = node1;
        }

        node1.height = Math.max(height(node1.left), height(node1.right)) + 1;
        node2.height = Math.max(height(node2.left), height(node2.right)) + 1;
        return node2;
    }

    private TreeNode rightRotate(TreeNode node1) {
        TreeNode node2 = node1.left;
        node2.parent = node1.parent;
        node1.parent = node2;
        node1.left = node2.right;
        node2.right = node1;

        if (node1.left != null) {
            node1.left.parent = node1;
        }

        node1.height = Math.max(height(node1.left), height(node1.right)) + 1;
        node2.height = Math.max(height(node2.left), height(node2.right)) + 1;
        return node2;
    }

    private TreeNode rightLeftRotate(TreeNode node1) {
        TreeNode node2 = node1.right;
        node1.right = rightRotate(node2);
        return leftRotate(node1);
    }

    private TreeNode leftRightRotate(TreeNode node1) {
        TreeNode node2 = node1.left;
        node1.left = leftRotate(node2);
        return rightRotate(node1);
    }
}
Demo.java
import java.util.Arrays;
import java.util.Random;
import net.coderodde.util.sorting.IntTreeSort;

public class Demo {

    private static final int CONSOLE_WIDTH = 80;
    private static final int DISTINCT_INTS = 1000;
    private static final int LENGTH = 10_000_000;

    public static void main(String[] args) {
        System.out.println(title("Small number of distinct integers"));
        int[] array = new int[LENGTH];
        long seed = System.nanoTime();
        Random random = new Random(seed);

        for (int i = 0; i < array.length; ++i) {
            array[i] = 3 * random.nextInt(DISTINCT_INTS);
        }

        System.out.println("Seed = " + seed);

        profile(array);
    }

    private static void profile(int[] array) {
        int[] array2 = array.clone();

        long startTime = System.nanoTime();
        Arrays.sort(array);
        long endTime = System.nanoTime();

        System.out.printf("Arrays.sort in %.2f milliseconds.\n",
                          (endTime - startTime) / 1E6);

        startTime = System.nanoTime();
        IntTreeSort.sort(array2);
        endTime = System.nanoTime();

        System.out.printf("IntTreeSort.sort in %.2f milliseconds.\n",
                          (endTime - startTime) / 1E6);

        System.out.println("Equals: " + Arrays.equals(array, array2));
    }

    public static String title(String text) {
        return title(text, '=');
    }

    private static String title(String text, char c) {
        StringBuilder sb = new StringBuilder();

        int left = (CONSOLE_WIDTH - 2 - text.length()) / 2;
        int right = CONSOLE_WIDTH - 2 - text.length() - left;

        for (int i = 0; i < left; ++i) {
            sb.append(c);
        }

        sb.append(' ').append(text).append(' ');

        for (int i = 0; i < right; ++i) {
            sb.append(c);
        }

        return sb.toString();
    }
}
