[WIP] Static search trees

In this post, we will implement a static search tree (S+ tree) for high-throughput searching of sorted data, as introduced on Algorithmica. We’ll mostly take the code presented there as a starting point, and optimize it to its limits.

Code: github:RagnarGrootKoerkamp/suffix-array-searching.

Introduction

Problem statement

Input. A sorted list of \(n\) 32-bit unsigned integers vals: Vec<u32>.

Output. A data structure that supports queries \(q\), returning the smallest element of vals that is at least \(q\), or u32::MAX if no such element exists. Optionally, the index of this element may also be returned.

Metric. We optimize throughput, i.e., the number of (independent) queries that can be answered per second. The typical setting is that we get a sufficiently long queries: Vec<u32> as input and return a corresponding answers: Vec<u32>.

Benchmarking setup. For now, we will assume that both the input values and the queries are uniformly sampled random 32-bit integers.
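
For concreteness, the benchmark input can be generated roughly as follows (a sketch using the rand crate; not necessarily the exact setup behind the numbers in this post):

fn gen_input(n: usize, num_queries: usize) -> (Vec<u32>, Vec<u32>) {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    // Uniform random values, sorted to form the input.
    let mut vals: Vec<u32> = (0..n).map(|_| rng.gen()).collect();
    vals.sort_unstable();
    // Uniform random queries.
    let queries: Vec<u32> = (0..num_queries).map(|_| rng.gen()).collect();
    (vals, queries)
}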

Code. In code, this can be modelled by the trait shown in Code Snippet 1.

trait SearchIndex {
    /// Construct the data structure from a sorted vector.
    fn new(vals: &Vec<u32>) -> Self;

    /// Two functions with default implementations in terms of each other.
    fn query_one(&self, query: u32) -> u32 {
        self.query(&vec![query])[0]
    }
    fn query(&self, queries: &Vec<u32>) -> Vec<u32> {
        queries.iter().map(|&q| self.query_one(q)).collect()
    }
}
Code Snippet 1: Trait that our solution should implement.

The classical solution to this problem is binary search, which we will briefly visit in the next section. A great paper on this and other search layouts is “Array Layouts for Comparison-Based Searching” by Khuong and Morin (2017). Algorithmica also has a case study based on that paper.

This post will focus on S+ trees, as introduced on Algorithmica in the follow-up post, static B-trees.

I also recommend reading my work-in-progress introduction to CPU performance, which contains some benchmarks pushing the CPU to its limits. We will use the metrics obtained there as a baseline to understand our optimization attempts.

Also helpful is the Intel Intrinsics Guide when looking into SIMD instructions.

Binary search and Eytzinger layout

pub struct SortedVec {
    vals: Vec<u32>,
}

impl SortedVec {
    pub fn binary_search_std(&self, q: u32) -> u32 {
        // binary_search returns Ok(i) on an exact match and Err(i) with the
        // insertion point otherwise; either way index i is the answer
        // (assuming a u32::MAX sentinel keeps it in bounds).
        let idx = self.vals.binary_search(&q).unwrap_or_else(|i| i);
        self.vals[idx]
    }
}

pub struct Eytzinger {
    vals: Vec<u32>,
}

impl Eytzinger {
    /// L: number of levels ahead to prefetch.
    pub fn search_prefetch<const L: usize>(&self, q: u32) -> u32 {
        let mut idx = 1;
        while (1 << L) * idx < self.vals.len() {
            idx = 2 * idx + (q > self.get(idx)) as usize;
            prefetch_index(&self.vals, (1 << L) * idx);
        }
        while idx < self.vals.len() {
            idx = 2 * idx + (q > self.get(idx)) as usize;
        }
        // Cancel the final run of right-turns (plus one more shift): the
        // answer is the last ancestor where we went left (q <= value).
        let zeros = idx.trailing_ones() + 1;
        let idx = idx >> zeros;
        self.get(idx)
    }
}

B-trees, B+ trees, and S-trees

#[repr(align(64))]
pub struct BTreeNode<const N: usize> {
    data: [u32; N],
}
Code Snippet 2: Search tree node, aligned to a 64-byte cache line. For now, N is always 16. The values in a node must always be sorted.

NOTE: In code I’ll just write STree, so I’ll also use S-tree instead of S+ tree, since the Algorithmica blog already showed that the non-plus S-tree doesn’t really make sense here anyway.

/// N: #elements in a node, always 16.
/// B: branching factor <= N+1. Typically 17.
#[derive(Debug)]
pub struct STree<const B: usize, const N: usize> {
    /// The list of tree nodes.
    tree: Vec<BTreeNode<N>>,
    /// The root is at index tree[offsets[0]].
    /// Its children start at tree[offsets[1]], and so on.
    offsets: Vec<usize>,
}

Construction TODO: When B<N, ensure that the next node’s value is copied into the previous node.

TODO: Reverse offsets.

fn search(&self, q: u32, find: impl Fn(&BTreeNode<N>, u32) -> usize) -> u32 {
    let mut k = 0;
    for o in self.offsets[1..].iter().rev() {
        let jump_to = find(self.node(o + k), q);
        k = k * (B + 1) + jump_to;
    }

    let o = self.offsets[0];
    // node(i) returns tree[i] using unchecked indexing.
    let idx = find(self.node(o + k), q);
    // get(i, j) returns tree[i].data[j] using unchecked indexing.
    self.get(o + k + idx / N, idx % N)
}
Code Snippet 4: Initial code, directly adapted from https://en.algorithmica.org/hpc/data-structures/s-tree/#searching, with a customizable find function.

Optimizing find

Linear

pub fn find_linear(&self, q: u32) -> usize {
    for i in 0..N {
        if self.data[i] >= q {
            return i;
        }
    }
    N
}

Hugepages

  • 2MB vs 32MB
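
As a sketch of what requesting huge pages can look like (an illustration, not necessarily the allocator used for the benchmarks here): reserve the tree’s backing memory with 2 MB alignment and hint the kernel via madvise. The libc crate is assumed, and MADV_HUGEPAGE is only a hint.

use std::alloc::{alloc, Layout};

/// Allocate `bytes` of 2 MB-aligned memory and hint the kernel to back it
/// with (transparent) huge pages.
unsafe fn alloc_huge(bytes: usize) -> *mut u8 {
    const HUGE_PAGE_SIZE: usize = 2 * 1024 * 1024;
    let size = bytes.next_multiple_of(HUGE_PAGE_SIZE);
    let layout = Layout::from_size_align(size, HUGE_PAGE_SIZE).unwrap();
    let ptr = alloc(layout);
    // This is only a hint; the kernel is free to keep using 4 KB pages.
    libc::madvise(ptr as *mut libc::c_void, size, libc::MADV_HUGEPAGE);
    ptr
}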

Autovectorization

pub fn find_linear_count(&self, q: u32) -> usize {
    let mut count = 0;
    for i in 0..N {
        if self.data[i] < q {
            count += 1;
        }
    }
    count
}

vmovdqu      (%rax,%rcx), %ymm1
vmovdqu      32(%rax,%rcx), %ymm2
vpbroadcastd %xmm0, %ymm0
vpmaxud      %ymm0, %ymm2, %ymm3
vpcmpeqd     %ymm3, %ymm2, %ymm2
vpmaxud      %ymm0, %ymm1, %ymm0
vpcmpeqd     %ymm0, %ymm1, %ymm0
vpackssdw    %ymm2, %ymm0, %ymm0
vpcmpeqd     %ymm1, %ymm1, %ymm1
vpxor        %ymm1, %ymm0, %ymm0
vextracti128 $1, %ymm0, %xmm1
vpacksswb    %xmm1, %xmm0, %xmm0
vpshufd      $216, %xmm0, %xmm0
vpmovmskb    %xmm0, %ecx
popcntl      %ecx, %ecx
Code Snippet 7: linear-count is auto-vectorized :)

Trailing zeros

pub fn find_ctz(&self, q: u32) -> usize
where
    LaneCount<N>: SupportedLaneCount,
{
    let data: Simd<u32, N> = Simd::from_slice(&self.data[0..N]);
    let q = Simd::splat(q);
    let mask = q.simd_le(data);
    mask.first_set().unwrap_or(N)
}
vpminud      32(%rsi,%r8), %ymm0, %ymm1  ; take min of data[8..] and query
vpcmpeqd     %ymm1, %ymm0, %ymm1         ; does the min equal query?
vpminud      (%rsi,%r8), %ymm0, %ymm2    ; take min of data[..8] and query
vpcmpeqd     %ymm2, %ymm0, %ymm2         ; does the min equal query?
vpackssdw    %ymm1, %ymm2, %ymm1         ; pack the two results together, interleaved as 16bit words
vextracti128 $1, %ymm1, %xmm2            ; extract half (both halves are equal)
vpacksswb    %xmm2, %xmm1, %xmm1         ; go down to 8bit values, but weirdly shuffled
vpshufd      $216, %xmm1, %xmm1          ; unshuffle
vpmovmskb    %xmm1, %r8d                 ; extract the high bit of each 8bit value.
orl          $65536,%r8d                 ; set bit 16, to cover the unwrap_or(N)
tzcntl       %r8d,%r15d                  ; count trailing zeros

First up: why does simd_le translate into min and cmpeq?

From checking the Intel Intrinsics Guide, we find out that there are only signed comparisons, while our data is unsigned. For now, let’s just assume that all values fit in 31 bits and are at most i32::MAX. Then, we can transmute our input to Simd<i32, 8>.

Assumption

Both input values and queries are between 0 and i32::MAX.

Eventually we can fix this by either taking i32 input directly, or by shifting u32 values to fit in the i32 range.
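For example, the ‘shift into the i32 range’ fix could be an order-preserving transform applied to both the values (at construction time) and the queries. A minimal sketch (not used in the code below, which instead relies on the assumption):

/// Map u32 to i32 such that the signed order of the outputs matches the
/// unsigned order of the inputs: flip the sign bit.
fn to_signed_order(x: u32) -> i32 {
    (x ^ (1 << 31)) as i32
}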

 pub fn find_ctz_signed(&self, q: u32) -> usize
 where
     LaneCount<N>: SupportedLaneCount,
 {
-    let data: Simd<u32, N> = Simd::from_slice(                   &self.data[0..N]   );
+    let data: Simd<i32, N> = Simd::from_slice(unsafe { transmute(&self.data[0..N]) });
-    let q = Simd::splat(q       );
+    let q = Simd::splat(q as i32);
     let mask = q.simd_le(data);
     mask.first_set().unwrap_or(N)
 }

vpcmpgtd     32(%rsi,%rdi), %ymm1, %ymm2 ; is query(%ymm1) > data[8..]?
vpcmpgtd     (%rsi,%rdi), %ymm1, %ymm1   ; is query(%ymm1) > data[..8]?
vpackssdw    %ymm2, %ymm1, %ymm1         ; pack results
vpxor        %ymm0, %ymm1, %ymm1         ; negate results (ymm0 is all-ones)
vextracti128 $1, %ymm1, %xmm2            ; extract u16x16
vpacksswb    %xmm2, %xmm1, %xmm1         ; shuffle
vpshufd      $216, %xmm1, %xmm1          ; extract u8x16
vpmovmskb    %xmm1, %edi                 ; extract u16 mask
orl          $65536,%edi                 ; 16 when none set
tzcntl       %edi,%edi                   ; count trailing zeros
Code Snippet 8: The two vpminud and vpcmpeqd instructions are gone now and merged into vpcmpgtd, but instead we got a vpxor back :/

It turns out that AVX2 only has a greater-than comparison, not greater-or-equal, and so there is no way to avoid inverting the result.

Here, Algorithmica takes the approach of ‘pre-shuffling’ the values in each node to cancel out the unshuffle instruction. They also suggest using popcount instead, which is indeed what we’ll do next.

Popcount

As we saw, the drawback of the trailing zero count approach is that the order of the lanes must be preserved. Instead, we’ll now simply count the number of lanes with a value less than the query, so that the order of lanes doesn’t matter.

 pub fn find_popcnt_portable(&self, q: u32) -> usize
 where
     LaneCount<N>: SupportedLaneCount,
 {
     let data: Simd<i32, N> = Simd::from_slice(unsafe { transmute(&self.data[0..N]) });
     let q = Simd::splat(q as i32);
-    let mask = q.simd_le(data);
+    let mask = q.simd_gt(data);
-    mask.first_set().unwrap_or(N)
+    mask.to_bitmask().count_ones() as usize
 }

vpcmpgtd     32(%rsi,%rdi), %ymm0, %ymm1
vpcmpgtd     (%rsi,%rdi), %ymm0, %ymm0
vpackssdw    %ymm1, %ymm0, %ymm0     ; 1
vextracti128 $1, %ymm0, %xmm1        ; 2
vpacksswb    %xmm1, %xmm0, %xmm0     ; 3
vpshufd      $216, %xmm0, %xmm0      ; 4
vpmovmskb    %xmm0, %edi             ; 5
popcntl      %edi, %edi
Code Snippet 10: the xor and or instructions are gone, but we are still stuck with the sequence of 5 instructions to go from the comparison results to an integer bitmask.

Ideally we would like to movmsk directly on the u16x16 output of the first pack instruction, vpackssdw. Unfortunately, we are again let down by AVX2: there are movemask instructions for 8-bit, 32-bit, and 64-bit lanes, but not for 16-bit.

Also, the vpshufd instruction is now provably useless: since we only popcount the mask, the shuffle cannot change the result. It’s slightly disappointing that the compiler didn’t elide it. Time to write the SIMD by hand instead.

Manual SIMD

As it turns out, we can get away without most of the packing! Instead of using vpmovmskb (_mm256_movemask_epi8) on 8-bit data, we can just use it directly on the 16-bit output of vpackssdw. Since the comparison sets each 16-bit lane to all-zeros or all-ones, vpmovmskb simply reads the high bit of both bytes in each lane, and we divide the resulting count by two at the end.

pub fn find_popcnt(&self, q: u32) -> usize {
    // We explicitly require that N is 16.
    let low: Simd<u32, 8> = Simd::from_slice(&self.data[0..N / 2]);
    let high: Simd<u32, 8> = Simd::from_slice(&self.data[N / 2..N]);
    let q_simd = Simd::<_, 8>::splat(q as i32);
    unsafe {
        use std::mem::transmute as t;
        // Transmute from u32 to i32.
        let mask_low = q_simd.simd_gt(t(low));
        let mask_high = q_simd.simd_gt(t(high));
        // Transmute from portable_simd to __m256i intrinsic types.
        let merged = _mm256_packs_epi32(t(mask_low), t(mask_high));
        // Each matching 16-bit lane contributes two set bits to the mask.
        let mask: i32 = _mm256_movemask_epi8(merged);
        mask.count_ones() as usize / 2
    }
}
Code Snippet 11: The hand-written SIMD version of find, using a popcount over the 16-bit packed comparison mask.

vpcmpgtd     (%rsi,%rdi), %ymm0, %ymm1
vpcmpgtd     32(%rsi,%rdi), %ymm0, %ymm0
vpackssdw    %ymm0, %ymm1, %ymm0
vpmovmskb    %ymm0, %edi
popcntl      %edi, %edi
Code Snippet 12: Only 5 instructions total are left now. Note that there is no explicit division by 2, since it is absorbed into the pointer arithmetic in the rest of the search.

From now on, this is the version we will be using.

Optimizing the search

Batching

fn batch<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
    let mut k = [0; P];
    for [o, o2] in self.offsets.array_windows() {
        for i in 0..P {
            let jump_to = self.node(o + k[i]).find(qb[i]);
            k[i] = k[i] * (B + 1) + jump_to;
        }
    }

    let o = self.offsets.last().unwrap();
    from_fn(|i| {
        let idx = self.node(o + k[i]).find(qb[i]);
        self.get(o + k[i] + idx / N, idx % N)
    })
}
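
To connect this back to the throughput-style query interface from the problem statement, the query stream can be processed in chunks of P. A sketch (it assumes batch is exposed, uses the nightly array_chunks alongside the array_windows the code already uses, and leaves the final partial chunk to a scalar fallback):

pub fn query_batched<const P: usize>(tree: &STree<17, 16>, queries: &[u32]) -> Vec<u32> {
    let mut answers = Vec::with_capacity(queries.len());
    for chunk in queries.array_chunks::<P>() {
        answers.extend_from_slice(&tree.batch(chunk));
    }
    // The remaining queries.len() % P queries would be answered one by one.
    answers
}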

Prefetching

 fn batch<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let mut k = [0; P];
     for [o, o2] in self.offsets.array_windows() {
         for i in 0..P {
             let jump_to = self.node(o + k[i]).find(qb[i]);
             k[i] = k[i] * (B + 1) + jump_to;
+            prefetch_index(&self.tree, o2 + k[i]);
         }
     }

     let o = self.offsets.last().unwrap();
     from_fn(|i| {
         let idx = self.node(o + k[i]).find(qb[i]);
         self.get(o + k[i] + idx / N, idx % N)
     })
 }

Pointer arithmetic

Up-front splat

 pub fn batch_splat<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let mut k = [0; P];
+    let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

     for [o, o2] in self.offsets.array_windows() {
         for i in 0..P {
-            let jump_to = self.node(o + k[i]).find      (qb[i]    );
+            let jump_to = self.node(o + k[i]).find_splat(q_simd[i]);
             k[i] = k[i] * (B + 1) + jump_to;
             prefetch_index(&self.tree, o2 + k[i]);
         }
     }

     let o = self.offsets.last().unwrap();
     from_fn(|i| {
-        let idx = self.node(o + k[i]).find      (qb[i]    );
+        let idx = self.node(o + k[i]).find_splat(q_simd[i]);
         self.get(o + k[i] + idx / N, idx % N)
     })
 }

Pointer indexing

 pub fn batch_ptr<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let mut k = [0; P];
     let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

+    // offsets[l] is a pointer to self.tree[self.offsets[l]]
+    let offsets = self.offsets.iter()
+        .map(|o| unsafe { self.tree.as_ptr().add(*o) })
+        .collect_vec();

     for [o, o2] in offsets.array_windows() {
         for i in 0..P {
-            let jump_to = self.node(o  +  k[i])  .find_splat(q_simd[i]);
+            let jump_to = unsafe { *o.add(k[i]) }.find_splat(q_simd[i]);
             k[i] = k[i] * (B + 1) + jump_to;
-            prefetch_index(&self.tree, o2 + k[i]);
+            prefetch_ptr(unsafe { o2.add(k[i]) });
         }
     }

     let o = offsets.last().unwrap();
     from_fn(|i| {
-        let idx = self.node(o  +  k[i])  .find_splat(q_simd[i]);
+        let idx = unsafe { *o.add(k[i]) }.find_splat(q_simd[i]);
-        self.get(o + k[i] + idx / N, idx % N)
+        unsafe { *(*o.add(k[i] + idx / N)).data.get_unchecked(idx % N) }
     })
 }

Byte-based pointers

 pub fn batch_byte_ptr<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let mut k = [0; P];
     let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

     let offsets = self
         .offsets
         .iter()
         .map(|o| unsafe { self.tree.as_ptr().add(*o) })
         .collect_vec();

     for [o, o2] in offsets.array_windows() {
         for i in 0..P {
-            let jump_to = unsafe { *o.     add(k[i]) }.find_splat(q_simd[i]);
+            let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat(q_simd[i]);
-            k[i] = k[i] * (B + 1) + jump_to     ;
+            k[i] = k[i] * (B + 1) + jump_to * 64;
-            prefetch_ptr(unsafe { o2.     add(k[i]) });
+            prefetch_ptr(unsafe { o2.byte_add(k[i]) });
         }
     }

     let o = offsets.last().unwrap();
     from_fn(|i| {
-        let idx = self.node(o     +    k[i])  .find_splat(q_simd[i]);
+        let idx = unsafe { *o.byte_add(k[i]) }.find_splat(q_simd[i]);
         unsafe { (o.byte_add(k[i]) as *const u32).add(idx).read() }
     })
 }

More efficient multiplication

 pub fn find_splat64(&self, q_simd: Simd<u32, 8>) -> usize {
     let low: Simd<u32, 8> = Simd::from_slice(&self.data[0..N / 2]);
     let high: Simd<u32, 8> = Simd::from_slice(&self.data[N / 2..N]);
     unsafe {
         use std::mem::transmute as t;
         let q_simd: Simd<i32, 8> = t(q_simd);
         let mask_low = q_simd.simd_gt(t(low));
         let mask_high = q_simd.simd_gt(t(high));
         let merged = _mm256_packs_epi32(t(mask_low), t(mask_high));
         let mask = _mm256_movemask_epi8(merged);
-        mask.count_ones() as usize / 2
+        mask.count_ones() as usize * 32
     }
 }

 pub fn batch_byte_ptr2<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let mut k = [0; P];
     let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

     let offsets = self
         .offsets
         .iter()
         .map(|o| unsafe { self.tree.as_ptr().add(*o) })
         .collect_vec();

     for [o, o2] in offsets.array_windows() {
         for i in 0..P {
-            let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat  (q_simd[i]);
+            let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat64(q_simd[i]);
-            k[i] = k[i] * (B + 1) + jump_to * 64;
+            k[i] = k[i] * (B + 1) + jump_to     ;
             prefetch_ptr(unsafe { o2.byte_add(k[i]) });
         }
     }

     let o = offsets.last().unwrap();
     from_fn(|i| {
         let idx = unsafe { *o.byte_add(k[i]) }.find_splat(q_simd[i]);
         unsafe { (o.byte_add(k[i]) as *const u32).add(idx).read() }
     })
 }

Skip prefetch

Interleave

Optimizing the tree layout

‘Flipped’ tree

Reverse offsets

  • forward vs original reverse direction

Full table

  • wasted memory
  • faster pointer arithmetic?
 pub fn batch_ptr3_full<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let mut k = [0; P];
     let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

+    let o = self.tree.as_ptr();

-    for [o, o2] in offsets.array_windows() {
+    for _h      in 0..self.offsets.len() - 1 {
         for i in 0..P {
             let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat64(q_simd[i]);
-            k[i] = k[i] * (B + 1) + jump_to     ;
+            k[i] = k[i] * (B + 1) + jump_to + 64;
             prefetch_ptr(unsafe { o.byte_add(k[i]) });
         }
     }

     from_fn(|i| {
         let idx = unsafe { *o.byte_add(k[i]) }.find_splat(q_simd[i]);
         unsafe { (o.byte_add(k[i]) as *const u32).add(idx).read() }
     })
 }

Node size \(B=15\).

Summary

What we have now.

Beyond search trees

Prefix lookup

 pub fn search_prefix<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let offsets = self
         .offsets
         .iter()
         .map(|o| unsafe { self.tree.as_ptr().add(*o) })
         .collect_vec();

     // Initial parts, and prefetch them.
     let o0 = offsets[0];
-    let mut k = [0; P];
+    let mut k = qb.map(|q| {
+        (q as usize >> self.shift) * 64
+    });
     let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

     for [o, o2] in offsets.array_windows() {
         for i in 0..P {
             let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat64(q_simd[i]);
             k[i] = k[i] * (B + 1) + jump_to;
             prefetch_ptr(unsafe { o2.byte_add(k[i]) });
         }
     }

     let o = offsets.last().unwrap();
     from_fn(|i| {
         let idx = unsafe { *o.byte_add(k[i]) }.find_splat(q_simd[i]);
         unsafe { (o.byte_add(k[i]) as *const u32).add(idx).read() }
     })
 }

A more compact layout

 pub fn search<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let offsets = self
         .offsets
         .iter()
         .map(|o| unsafe { self.tree.as_ptr().add(*o) })
         .collect_vec();

     // Initial parts, and prefetch them.
     let o0 = offsets[0];
+    let mut k: [usize; P] = [0; P];
+    let parts: [usize; P] = qb.map(|q| {
+        // byte offset of the part.
+        (q as usize >> self.shift) * self.bpp * 64
+    });
     let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

     for [o, o2] in offsets.array_windows() {
         for i in 0..P {
-            let jump_to = unsafe { *o.byte_add(           k[i]) }.find_splat64(q_simd[i]);
+            let jump_to = unsafe { *o.byte_add(parts[i] + k[i]) }.find_splat64(q_simd[i]);
             k[i] = k[i] * (B + 1) + jump_to;
-            prefetch_ptr(unsafe { o2.byte_add(           k[i]) });
+            prefetch_ptr(unsafe { o2.byte_add(parts[i] + k[i]) });
         }
     }

     let o = offsets.last().unwrap();
     from_fn(|i| {
-        let idx = unsafe { *o.byte_add(           k[i]) }.find_splat(q_simd[i]);
+        let idx = unsafe { *o.byte_add(parts[i] + k[i]) }.find_splat(q_simd[i]);
-        unsafe { (o.byte_add(           k[i]) as *const u32).add(idx).read() }
+        unsafe { (o.byte_add(parts[i] + k[i]) as *const u32).add(idx).read() }
     })
 }

The best of both

 pub fn search_l1<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
     let offsets = self
         .offsets
         .iter()
         .map(|o| unsafe { self.tree.as_ptr().add(*o) })
         .collect_vec();

     // Initial parts, and prefetch them.
     let o0 = offsets[0];
     let mut k: [usize; P] = qb.map(|q| {
          (q as usize >> self.shift) * 64
     });
     let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));

-    for         [o, o2]  in offsets.array_windows()        {
+    if let Some([o, o2]) = offsets.array_windows().next() {
         for i in 0..P {
             let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat64(q_simd[i]);
-            k[i] = k[i] * (B + 1) + jump_to;
+            k[i] = k[i] * self.l1 + jump_to;
             prefetch_ptr(unsafe { o2.byte_add(k[i]) });
         }
     }

-    for [o, o2] in offsets     .array_windows() {
+    for [o, o2] in offsets[1..].array_windows() {
         for i in 0..P {
             let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat64(q_simd[i]);
             k[i] = k[i] * (B + 1) + jump_to;
             prefetch_ptr(unsafe { o2.byte_add(k[i]) });
         }
     }

     let o = offsets.last().unwrap();
     from_fn(|i| {
         let idx = unsafe { *o.byte_add(k[i]) }.find_splat(q_simd[i]);

         unsafe { (o.byte_add(k[i]) as *const u32).add(idx).read() }
     })
 }

Summary

Future ideas

Overlapping subtrees

Packing data smaller

  • After doing a 16-bit prefix lookup, it suffices to store only 16 bits per element (sketched below).
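
A sketch of what such a node of 16-bit values could look like (my own illustration, not implemented in the repo): a 64-byte cache line then holds 32 values instead of 16, and the same count-of-smaller-values trick still works.

use std::simd::prelude::*;

#[repr(align(64))]
pub struct BTreeNode16 {
    /// 32 values of 16 bits each fill one cache line.
    data: [u16; 32],
}

impl BTreeNode16 {
    /// Count the number of values less than q, i.e. the child to jump to.
    /// (The unsigned comparison may again want the signedness trick from above.)
    pub fn find(&self, q: u16) -> usize {
        let q = Simd::<u16, 16>::splat(q);
        let low = Simd::<u16, 16>::from_slice(&self.data[0..16]);
        let high = Simd::<u16, 16>::from_slice(&self.data[16..32]);
        (q.simd_gt(low).to_bitmask().count_ones()
            + q.simd_gt(high).to_bitmask().count_ones()) as usize
    }
}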

References

Khuong, Paul-Virak, and Pat Morin. 2017. “Array Layouts for Comparison-Based Searching.” ACM Journal of Experimental Algorithmics 22 (May): 1–39. https://doi.org/10.1145/3053370.