40x Faster Binary Search

Ragnar Groot Koerkamp

  • Did IMO & ICPC; currently head-of-jury for NWERC.
  • Some time at Google.
  • Quit and solved all of projecteuler.net (700+) during Covid.
  • Just finished a PhD on high-throughput bioinformatics @ ETH Zurich.
    • Lots of sequenced DNA that needs processing.
    • Many static datasets, e.g. a 3GB human genome.
    • Revisiting basic algorithms and optimizing them to the limit.
    • Good in theory \(\neq\) fast in practice.

Problem Statement

  • Input: a static, sorted list of 32-bit integers.
  • Queries: given \(q\), find the smallest value in the list \(\geq q\).
trait SearchIndex {
    // Initialize the data structure.
    fn new(sorted_data: &[u32]) -> Self;
    // Return the smallest value >=q.
    fn query(&self, q: u32) -> u32;
}

  1. Compare \(q=5\) with the middle element \(x\).
  2. Recurse on the left half if \(q\leq x\), on the right half if \(q>x\).
  3. End when 1 element is left, after \(\lceil\lg(n+1)\rceil\) steps (sketch below).
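
A minimal sketch of this classic lower-bound loop (the function name and shape are mine, not from the talk):

fn binary_search(data: &[u32], q: u32) -> u32 {
    let (mut lo, mut hi) = (0, data.len()); // invariant: answer index is in lo..=hi
    while lo < hi {
        let mid = (lo + hi) / 2;
        if q <= data[mid] {
            hi = mid;     // left half: data[mid] may still be the answer
        } else {
            lo = mid + 1; // right half: data[mid] < q is too small
        }
    }
    data[lo] // smallest value >= q (panics if q exceeds every value)
}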

Binary Search: latency is more than \(O(\lg n)\)

Array Indexing: \(O(n^{0.35})\) latency!

Heap Layout: efficient caching + prefetching

  • Binary search: the top of the tree is spread out; each cache line holds only 1 useful value.
  • Eytzinger layout: the top layers of the tree are clustered into cache lines.
    • Also allows prefetching! (sketch below)
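
A minimal sketch of a branchless Eytzinger search, assuming the classic 1-indexed heap layout (tree[0] unused; node k has children 2k and 2k+1); this is the standard formulation, not the talk's exact code:

fn eytzinger_lower_bound(tree: &[u32], q: u32) -> usize {
    let mut k = 1;
    while k < tree.len() {
        // Branchless descent; since index 16*k is known ahead of time,
        // the node 4 levels down could be prefetched here.
        k = 2 * k + (tree[k] < q) as usize;
    }
    // Cancel the trailing right-turns; the last left-turn is the answer.
    k >> (k.trailing_ones() + 1) // 1-based index of smallest value >= q; 0 if none
}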

Heap Layout: close to array indexing!

Static Search Trees / B-trees

  • Fully use each cache line: 16 values per 64-byte block.

#[repr(align(64))]       // Each block fills exactly one 512-bit cache line.
struct Block { data: [u32; 16] }
struct Tree {
    tree: Vec<Block>,    // Blocks.
    offsets: Vec<usize>, // Index of first block in each layer.
}
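
The snippets below also use a block size N and a node accessor; a minimal sketch of both (assumed here, not shown on the slides):

const N: usize = 16; // values per block

impl Tree {
    // Access block `i` of the flattened tree.
    fn node(&self, i: usize) -> &Block {
        &self.tree[i]
    }
}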

Static Search Trees: Slower than Eytzinger?!

Up next: assembly-level optimizations :)

Optimizing find: linear scan baseline

fn find_linear(block: &Block, q: u32) -> usize {
    for i in 0..N {
        if block.data[i] >= q {
            // Early break causes branch mispredictions
            // and prevents auto-vectorization!
            return i;
        }
    }
    return N;
}

Optimizing find: auto-vectorization

fn find_count(block: &Block, q: u32) -> usize {
    let mut count = 0;
    for i in 0..N {
        if block.data[i] < q {
            count += 1;
        }
    }
    count
}
vmovdqu      (%rax,%rcx), %ymm1     ; load data[..8]
vmovdqu      32(%rax,%rcx), %ymm2   ; load data[8..]
vpbroadcastd %xmm0, %ymm0           ; 'splat' the query value
vpmaxud      %ymm0, %ymm2, %ymm3    ; v
vpcmpeqd     %ymm3, %ymm2, %ymm2    ; v
vpmaxud      %ymm0, %ymm1, %ymm0    ; v
vpcmpeqd     %ymm0, %ymm1, %ymm0    ; 4x compare query with values
vpackssdw    %ymm2, %ymm0, %ymm0    ;
vpcmpeqd     %ymm1, %ymm1, %ymm1    ; v
vpxor        %ymm1, %ymm0, %ymm0    ; 2x negate result
vextracti128 $1, %ymm0, %xmm1       ; v
vpacksswb    %xmm1, %xmm0, %xmm0    ; v
vpshufd      $216, %xmm0, %xmm0     ; v
vpmovmskb    %xmm0, %ecx            ; 4x extract mask
popcntl      %ecx, %ecx

Optimizing find: popcount

#![feature(portable_simd)]
use std::simd::{cmp::SimdPartialOrd, Simd};
fn find_popcount(block: &Block, q: u32) -> usize {
    let data: Simd<u32, N> = Simd::from_slice(&block.data[0..N]);
    let q = Simd::splat(q);
    let mask = data.simd_lt(q);             // x[i] < q
    mask.to_bitmask().count_ones() as usize // count_ones
}
vpcmpgtd     (%rsi,%rdi), %ymm0, %ymm1
vpcmpgtd     32(%rsi,%rdi), %ymm0, %ymm0
vpackssdw    %ymm1, %ymm0, %ymm0     ; interleave 16bit low halves
vextracti128 $1, %ymm0, %xmm1        ; 1
vpacksswb    %xmm1, %xmm0, %xmm0     ; 2
vpshufd      $216, %xmm0, %xmm0      ; 3 unshuffle interleaving
vpmovmskb    %xmm0, %edi             ; 4 instructions to extract bitmask
popcntl      %edi, %edi

Optimizing find: manual movemask_epi8

use std::arch::x86_64::{_mm256_movemask_epi8, _mm256_packs_epi32};
fn find_manual(block: &Block, q: i32) -> usize { // i32 now! signed compare: values must fit in i32
    let low:  Simd<i32, 8> = Simd::<u32, 8>::from_slice(&block.data[0..N / 2]).cast();
    let high: Simd<i32, 8> = Simd::<u32, 8>::from_slice(&block.data[N / 2..N]).cast();
    let q_simd = Simd::<i32, 8>::splat(q);
    let mask_low  = q_simd.simd_gt(low );                 // compare with low  half
    let mask_high = q_simd.simd_gt(high);                 // compare with high half
    let merged = unsafe {                                 // interleave (lane order is
        _mm256_packs_epi32(mask_low.to_int().into(), mask_high.to_int().into())
    };                                                    //  shuffled; popcount doesn't care)
    let mask: i32 = unsafe { _mm256_movemask_epi8(merged) }; // movemask_epi8!
    mask.count_ones() as usize / 2        // each match sets 2 bytes: correct for double-counting
}
 vpcmpgtd     (%rsi,%rdi), %ymm0, %ymm1     ; 1 cycle, in parallel
 vpcmpgtd     32(%rsi,%rdi), %ymm0, %ymm0   ; 1 cycle, in parallel
 vpackssdw    %ymm0, %ymm1, %ymm0           ; 1 cycle
-vextracti128 $1, %ymm0, %xmm1     ;
-vpacksswb    %xmm1, %xmm0, %xmm0  ;
-vpshufd      $216, %xmm0, %xmm0   ;
-vpmovmskb    %xmm0, %edi          ; 4 instructions emulating movemask_epi16
+vpmovmskb    %ymm0, %edi                   ; 2 cycles        movemask_epi8
 popcntl      %edi, %edi                    ; 1 cycle
+                                  ; /2 is folded into pointer arithmetic

The results: branchless is great!

Throughput, not latency

Batching: Many queries in parallel

-fn query                (&self, q :   u32    ) ->  u32     {
+fn batch<const B: usize>(&self, qs: &[u32; B]) -> [u32; B] {
-    let mut k =  0    ;                                    // current index
+    let mut k = [0; B];                                    // current indices
     for [o, _o2] in self.offsets.array_windows() {         // walk down the tree
+        for i in 0..B {
             let jump_to = self.node(o + k[i]).find(qs[i]); // call `find`
             k[i] = k[i] * (N + 1) + jump_to;               // each node has N+1 children
+        }
     }

     let o = self.offsets.last().unwrap();
+    from_fn(|i| {
         let idx = self.node(o + k[i]).find(qs[i]);
         self.tree[o + k[i] + idx / N].data[idx % N]        // return value
+    })
 }
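
Hypothetical usage, feeding the tree 32 queries at a time (nightly array_chunks, like the array_windows above):

#![feature(array_chunks)]
fn answer_all(tree: &Tree, queries: &[u32], out: &mut Vec<u32>) {
    for chunk in queries.array_chunks::<32>() {
        out.extend(tree.batch(chunk)); // 32 independent searches in flight
    }
    // (a full version would also handle the tail of len % 32 queries)
}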

Batching: up to 2.5x faster!

Prefetching

 fn batch<const B: usize>(&self, qs: &[u32; B]) -> [u32; B] {
     let mut k = [0; B];                                    // current indices
     for [o, o2] in self.offsets.array_windows() {          // walk down the tree
         for i in 0..B {
             let jump_to = self.node(o + k[i]).find(qs[i]); // call `find`
             k[i] = k[i] * (N + 1) + jump_to;               // update index
+            prefetch_index(&self.tree, o2 + k[i]);         // prefetch next layer
         }
     }

     let o = self.offsets.last().unwrap();
     from_fn(|i| {
         let idx = self.node(o + k[i]).find(qs[i]);
         self.tree[o + k[i] + idx / N].data[idx % N]        // return value
     })
 }
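
prefetch_index is a small helper; a minimal sketch using the x86 prefetch intrinsic (assumed, not shown on the slides):

use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

// Hint the CPU to pull blocks[index] into cache. Prefetches never fault,
// so a slightly out-of-range index only wastes a memory request.
fn prefetch_index(blocks: &[Block], index: usize) {
    let ptr = blocks.as_ptr().wrapping_add(index) as *const i8;
    unsafe { _mm_prefetch::<_MM_HINT_T0>(ptr) };
}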

Prefetching: 30% faster again!

Optimizing pointer arithmetic: more gains

  • Convert all indices and offsets to byte units up front, avoiding repeated index-to-address conversions (sketch below).
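
A minimal sketch of the idea (the helper name is mine):

// Read the block at a byte offset into the tree, skipping the
// `index * size_of::<Block>()` scaling on every access.
unsafe fn node_at(tree: &[Block], byte_offset: usize) -> &Block {
    unsafe { &*tree.as_ptr().byte_add(byte_offset) }
}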

Interleaving: more pressure on the RAM, -20%

  • Interleave multiple batches, each at a different layer of the tree, to keep more memory requests in flight.

Tree layout: internal nodes store minima

  • For query 5.5, we walk down the left subtree.
  • Returning 6 reads a new cache line.

Tree layout: internal nodes store maxima

  • For query 5.5, we walk down the middle subtree.
  • Returning 6 reads the same cache line.

Tree layout: another 10% gained!

Conclusion

  • With 4GB input: 40x speedup!
    • binary search: 1150ns
    • static search tree: 27ns
  • The latency of an Eytzinger-layout search is about 3x the latency of a single RAM access.
  • The throughput of the static search tree is about 1/3 of the raw RAM throughput.
  • RAM is slow, caches are fast.
  • Grinding through assembly works; be your own compiler!
  • Theoretical big-O complexities are useless here: caches invalidate the RAM model's uniform-access assumption.

P99 CONF