In this post, we will implement a static search tree (S+ tree) for
high-throughput searching of sorted data, as introduced on Algorithmica.
We’ll mostly take the code presented there as a starting point, and optimize it
to its limits.
Input. A sorted list of \(n\) 32-bit unsigned integers vals: Vec<u32>.
Output. A data structure that supports queries \(q\), returning the smallest
element of vals that is at least \(q\), or u32::MAX if no such element exists.
Optionally, the index of this element may also be returned.
Metric. We optimize throughput: the number of (independent) queries that can be
answered per second. The typical case is that we are given a sufficiently long
queries: Vec<u32> as input, and return a corresponding answers: Vec<u32>.
Benchmarking setup. For now, we will assume that both the input and the queries
are uniformly sampled random 32-bit integers.
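To make this concrete, the benchmark input could be generated roughly as follows
(a sketch only; it assumes the rand crate, and random_input is an illustrative
name, not part of this post):

fn random_input(n: usize, num_queries: usize) -> (Vec<u32>, Vec<u32>) {
    // Uniform random values, then sorted, as required by the problem statement.
    let mut vals: Vec<u32> = (0..n).map(|_| rand::random::<u32>()).collect();
    vals.sort_unstable();
    // Uniform random queries.
    let queries: Vec<u32> = (0..num_queries).map(|_| rand::random::<u32>()).collect();
    (vals, queries)
}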
Code.
In code, this can be modelled by the trait shown in Code Snippet 1.
trait SearchIndex {
    /// Construct the data structure from a sorted vector.
    fn new(vals: &Vec<u32>) -> Self;

    /// Two functions with default implementations in terms of each other.
    fn query_one(&self, query: u32) -> u32 {
        self.query(&vec![query])[0]
    }
    fn query(&self, queries: &Vec<u32>) -> Vec<u32> {
        queries.iter().map(|&q| self.query_one(q)).collect()
    }
}
Code Snippet 1:
Trait that our solution should implement.
Related reading
The classical solution to this problem is binary search, which we will briefly
visit in the next section. A great paper on this and other search layouts is
“Array Layouts for Comparison-Based Searching” by Khuong and Morin (2017).
Algorithmica also has a case study based on that paper.
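For reference, the classical baseline can be written directly against the trait
above (a sketch; BinarySearch is an illustrative name, and partition_point does
the lower-bound work):

struct BinarySearch {
    vals: Vec<u32>,
}

impl SearchIndex for BinarySearch {
    fn new(vals: &Vec<u32>) -> Self {
        Self { vals: vals.clone() }
    }
    fn query_one(&self, q: u32) -> u32 {
        // Index of the first element >= q; u32::MAX if there is none.
        let idx = self.vals.partition_point(|&x| x < q);
        self.vals.get(idx).copied().unwrap_or(u32::MAX)
    }
}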
This post will focus on S+ trees, as introduced on Algorithmica in the
followup post, static B-trees.
I also recommend reading my work-in-progress introduction to CPU performance,
which contains some benchmarks pushing the CPU to its limits. We will use the
metrics obtained there as a baseline to understand our optimization attempts.
pub struct Eytzinger {
    vals: Vec<u32>,
}

impl Eytzinger {
    /// L: number of levels ahead to prefetch.
    pub fn search_prefetch<const L: usize>(&self, q: u32) -> u32 {
        let mut idx = 1;
        while (1 << L) * idx < self.vals.len() {
            idx = 2 * idx + (q > self.get(idx)) as usize;
            prefetch_index(&self.vals, (1 << L) * idx);
        }
        while idx < self.vals.len() {
            idx = 2 * idx + (q > self.get(idx)) as usize;
        }
        let zeros = idx.trailing_ones() + 1;
        let idx = idx >> zeros;
        self.get(idx)
    }
}
Code Snippet 2:
Search tree node, aligned to a 64-byte cache line. For now, N is always 16. The values in a node must always be sorted.
NOTE: In code I’ll just write STree, so I’ll also use S-tree instead of S+
tree, since the Algorithmica blog already showed that the non-plus S-tree
doesn’t really make sense here anyway.
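The prefetch_index helper used in search_prefetch is not shown above; a minimal
sketch, assuming x86_64 and the _mm_prefetch intrinsic (the actual helper may
differ):

fn prefetch_index<T>(s: &[T], index: usize) {
    #[cfg(target_arch = "x86_64")]
    unsafe {
        use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
        // The index may point past the end of the slice; we only form an
        // address here, and prefetching an invalid address does not fault.
        let ptr = s.as_ptr().wrapping_add(index) as *const i8;
        _mm_prefetch::<_MM_HINT_T0>(ptr);
    }
}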
/// N: #elements in a node, always 16.
/// B: branching factor <= N+1. Typically 17.
#[derive(Debug)]
pub struct STree<const B: usize, const N: usize> {
    /// The list of tree nodes.
    tree: Vec<BTreeNode<N>>,
    /// The root is at index tree[offsets[0]].
    /// Its children start at tree[offsets[1]], and so on.
    offsets: Vec<usize>,
}
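For completeness, BTreeNode<N> is the cache-line-sized node described in the
caption of Code Snippet 2; a minimal sketch of its layout (the data field name
matches the find functions used later, the rest is assumed):

#[derive(Debug, Clone, Copy)]
#[repr(align(64))]
pub struct BTreeNode<const N: usize> {
    /// The sorted values in this node; with N = 16, a node fills one 64-byte cache line.
    pub data: [u32; N],
}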
Construction TODO: When B < N, ensure that the next node's value is copied into
the previous node.
TODO: Reverse offsets.
fn search(&self, q: u32, find: impl Fn(&BTreeNode<N>, u32) -> usize) -> u32 {
    let mut k = 0;
    for o in self.offsets[1..self.offsets.len()].into_iter().rev() {
        let jump_to = find(self.node(o + k), q);
        k = k * (B + 1) + jump_to;
    }
    let o = self.offsets[0];
    // node(i) returns tree[i] using unchecked indexing.
    let idx = find(self.node(o + k), q);
    // get(i, j) returns tree[i].data[j] using unchecked indexing.
    self.get(o + k + idx / N, idx % N)
}
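The find argument abstracts over how we search within a single node. As a scalar
baseline (a sketch; find_linear is an illustrative name, and the SIMD versions
below replace it), it could simply be a linear scan for the first value that is
at least q:

impl<const N: usize> BTreeNode<N> {
    /// Scalar baseline: index of the first value >= q, or N if there is none.
    pub fn find_linear(&self, q: u32) -> usize {
        self.data.iter().position(|&x| x >= q).unwrap_or(N)
    }
}

A query is then answered as, e.g., stree.search(q, BTreeNode::find_linear).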
vpminud      32(%rsi,%r8), %ymm0, %ymm1   ; take min of data[8..] and query
vpcmpeqd     %ymm1, %ymm0, %ymm1          ; does the min equal query?
vpminud      (%rsi,%r8), %ymm0, %ymm2     ; take min of data[..8] and query
vpcmpeqd     %ymm2, %ymm0, %ymm2          ; does the min equal query?
vpackssdw    %ymm1, %ymm2, %ymm1          ; pack the two results together, interleaved as 16-bit words
vextracti128 $1, %ymm1, %xmm2             ; extract half (both halves are equal)
vpacksswb    %xmm2, %xmm1, %xmm1          ; go down to 8-bit values, but weirdly shuffled
vpshufd      $216, %xmm1, %xmm1           ; unshuffle
vpmovmskb    %xmm1, %r8d                  ; extract the high bit of each 8-bit value
orl          $65536, %r8d                 ; set bit 16, to cover the unwrap_or(N)
tzcntl       %r8d, %r15d                  ; count trailing zeros
First up: why does simd_le translate into min and cmpeq?
From checking the Intel Intrinsics Guide, we find out that there are only signed
comparisons, while our data is unsigned. For now, let’s just assume that all
values fit in 31 bits and are at most i32::MAX. Then, we can transmute our input
to Simd<i32, 8>.
Assumption
Both input values and queries are between 0 and i32::MAX.
Eventually we can fix this by either taking i32 input directly, or by shifting
u32 values to fit in the i32 range.
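A sketch of the second option: flipping the top bit (i.e. wrapping-subtracting
2^31) maps u32 to i32 while preserving the order, so it can be applied to both
values and queries up front (to_ordered_i32 is an illustrative name):

/// Order-preserving map from u32 to i32: 0 maps to i32::MIN, u32::MAX to i32::MAX.
fn to_ordered_i32(x: u32) -> i32 {
    (x ^ (1 << 31)) as i32
}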
pub fn find_ctz_signed(&self, q: u32) -> usize
where
LaneCount<N>: SupportedLaneCount,
{
- let data: Simd<u32, N> = Simd::from_slice( &self.data[0..N] );
+ let data: Simd<i32, N> = Simd::from_slice(unsafe { transmute(&self.data[0..N]) });
- let q = Simd::splat(q );
+ let q = Simd::splat(q as i32);
let mask = q.simd_le(data);
mask.first_set().unwrap_or(N)
}
vpcmpgtd     32(%rsi,%rdi), %ymm1, %ymm2  ; is query (%ymm1) > data[8..]?
vpcmpgtd     (%rsi,%rdi), %ymm1, %ymm1    ; is query (%ymm1) > data[..8]?
vpackssdw    %ymm2, %ymm1, %ymm1          ; pack results
vpxor        %ymm0, %ymm1, %ymm1          ; negate results (ymm0 is all-ones)
vextracti128 $1, %ymm1, %xmm2             ; extract u16x16
vpacksswb    %xmm2, %xmm1, %xmm1          ; shuffle
vpshufd      $216, %xmm1, %xmm1           ; extract u8x16
vpmovmskb    %xmm1, %edi                  ; extract u16 mask
orl          $65536, %edi                 ; 16 when none set
tzcntl       %edi, %edi                   ; count trailing zeros
Code Snippet 8:
The two vpminud and vpcmpeqd instructions are gone now and merged into vpcmpgtd, but instead we got a vpxor back :/
It turns out there is only a > instruction in SIMD, and not >=, and so there
is no way to avoid inverting the result.
Here, Algorithmica takes the approach of ‘pre-shuffling’ the values in each
node to compensate for the unshuffle instruction.
They also suggest using popcount instead, which is indeed what we’ll do next.
Popcount
As we saw, the drawback of the trailing zero count approach is that the order of
the lanes must be preserved. Instead, we’ll now simply count the number of lanes
with a value less than the query, so that the order of lanes doesn’t matter.
pub fn find_popcnt_portable(&self, q: u32) -> usize
where
LaneCount<N>: SupportedLaneCount,
{
let data: Simd<i32, N> = Simd::from_slice(unsafe { transmute(&self.data[0..N]) });
let q = Simd::splat(q as i32);
- let mask = q.simd_le(data);
+ let mask = q.simd_gt(data);
- mask.first_set().unwrap_or(N)
+ mask.to_bitmask().count_ones() as usize
}
Code Snippet 10:
The xor and or instructions are gone, but we are still stuck with the sequence of 5 instructions to go from the comparison results to an integer bitmask.
Ideally we would like to movmsk directly on the u16x16 output of the first
pack instruction, vpackssdw.
Unfortunately, we are again let down by AVX2: there are movemask instructions
for u8, u32, and u64 lanes, but not for u16.
Also, the vpshufd instruction is now provably useless, so it’s slightly
disappointing the compiler didn’t elide it. Time to write the SIMD by hand instead.
Manual SIMD
As it turns out, we can get away without most of the packing!
Instead of using vpmovmskb (_mm256_movemask_epi8) on 8-bit data, we can
actually just use it directly on the 16-bit output of vpackssdw!
Since the comparison sets each lane to all-zeros or all-ones, we can safely read
both the most significant bit and the middle bit (bit 7) of each 16-bit lane,
and divide the count by two at the end.
pub fn find_popcnt(&self, q: u32) -> usize {
    // We explicitly require that N is 16.
    let low: Simd<u32, 8> = Simd::from_slice(&self.data[0..N / 2]);
    let high: Simd<u32, 8> = Simd::from_slice(&self.data[N / 2..N]);
    let q_simd = Simd::<_, 8>::splat(q as i32);
    unsafe {
        use std::mem::transmute as t;
        // Transmute from u32 to i32.
        let mask_low = q_simd.simd_gt(t(low));
        let mask_high = q_simd.simd_gt(t(high));
        // Transmute from portable_simd to __m256i intrinsic types.
        let merged = _mm256_packs_epi32(t(mask_low), t(mask_high));
        // 32 bits is sufficient to hold a count of 2 per lane.
        let mask: i32 = _mm256_movemask_epi8(t(merged));
        mask.count_ones() as usize / 2
    }
}
Code Snippet 12:
Only 5 instructions total are left now. Note that there is no explicit division by 2, since this is absorbed into the pointer arithmetic in the remainder.
From now on, this is the version we will be using.
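As a usage sketch (assuming the generic search function from before), the
per-node search is simply passed in as the find closure:

// Sketch: plug the hand-written SIMD node search into the generic tree search.
let answer = stree.search(q, BTreeNode::find_popcnt);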
pub fn search_l1<const P: usize>(&self, qb: &[u32; P]) -> [u32; P] {
let offsets = self
.offsets
.iter()
.map(|o| unsafe { self.tree.as_ptr().add(*o) })
.collect_vec();
// Initial parts, and prefetch them.
let o0 = offsets[0];
let mut k: [usize; P] = qb.map(|q| {
(q as usize >> self.shift) * 64
});
let q_simd = qb.map(|q| Simd::<u32, 8>::splat(q));
- for [o, o2] in offsets.array_windows() {
+ if let Some([o, o2]) = offsets.array_windows().next() {
for i in 0..P {
let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat64(q_simd[i]);
- k[i] = k[i] * (B + 1) + jump_to;
+ k[i] = k[i] * self.l1 + jump_to;
prefetch_ptr(unsafe { o2.byte_add(k[i]) });
}
}
- for [o, o2] in offsets.array_windows() {
+ for [o, o2] in offsets[1..].array_windows() {
for i in 0..P {
let jump_to = unsafe { *o.byte_add(k[i]) }.find_splat64(q_simd[i]);
k[i] = k[i] * (B + 1) + jump_to;
prefetch_ptr(unsafe { o2.byte_add(k[i]) });
}
}
let o = offsets.last().unwrap();
from_fn(|i| {
let idx = unsafe { *o.byte_add(k[i]) }.find_splat(q_simd[i]);
unsafe { (o.byte_add(k[i]) as *const u32).add(idx).read() }
})
}
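The prefetch_ptr helper above is the raw-pointer analogue of prefetch_index; a
minimal sketch, again assuming x86_64 (the actual helper may differ):

fn prefetch_ptr<T>(ptr: *const T) {
    #[cfg(target_arch = "x86_64")]
    unsafe {
        use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
        _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8);
    }
}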
Summary
Future ideas
Overlapping subtrees
Packing data smaller
After doing a 16-bit prefix lookup, it suffices to store only 16 bits per element.
References
Khuong, Paul-Virak, and Pat Morin. 2017. “Array Layouts for Comparison-Based Searching.” ACM Journal of Experimental Algorithmics 22 (May): 1–39. https://doi.org/10.1145/3053370.