External libraries are allowed at first for convenience.
I’m assuming that the set of cities is fixed.
Or at least, that each row is sampled randomly from the set of available cities.
Or really, I’m just scanning a prefix of 100k rows, assuming all cities
appear there.
Along the way, I use that the longest line in the input file has 33
characters, and that cities are uniquely determined by their first and last 8 characters.
This post is a log of ideas, timings, and general progress.
My processor is an i7-10750H CPU @ 2.60GHz.
Clock speed is capped at 3.6GHz. (It can run
at 4.6GHz, but 3.6GHz tends to be more stable for benchmarking.)
\(64\) GB of RAM.
It has 6 cores.
Hyperthreading is disabled.
The file should be cached in memory by the OS.
Table 1:
Results after improving ILP, @4.6GHz, measured inside the program, excluding munmap at exit.
Input: a 13GB file containing \(10^9\) lines of the form Amsterdam;10.5.
First a city name. There are \(413\) cities, and each row has a name chosen at
random from the set. Their lengths range from \(3\) to \(26\) bytes.
Then a temperature, formatted as -?\d?\d\.\d, i.e. a possibly negative number
with one or two integral digits and exactly one decimal.
Each temperature is drawn from a normal distribution for each city.
Output: a sorted list of cities of the form <city>: <min>/<avg>/<max>,
each formatted with one decimal place.
Let’s see what’s slow here: cargo flamegraph --open (or just flamegraph).
Takeaways:
35% of time is next_match, i.e. searching for \n and/or ;.
14% of time is parsing the f32.
35% of time is accessing the hashmap.
Not sure what exactly is the remainder. We’ll figure that out once it becomes relevant.
Bytes instead of strings: 72s
Strings in Rust are checked to be valid UTF-8. Using byte slices (&[u8]) is
usually faster. We have to do some slightly ugly conversions from byte slice back
to string for parsing floats and printing, but it’s worth it. This basically
removes next_match from the flamegraph.
Commit here. (It’s neither pretty nor interesting.)
This already saves 21 seconds, from 105 to 84. Pretty great!
Manual parsing: 61s
Instead of parsing the input as an f32 float, we can parse it manually into a
fixed-precision i32 signed integer:
type V = i32;

fn parse(mut s: &[u8]) -> V {
    let neg = if s[0] == b'-' {
        s = &s[1..];
        true
    } else {
        false
    };
    // s = abc.d
    let (a, b, c, d) = match s {
        [c, b'.', d] => (0, 0, c - b'0', d - b'0'),
        [b, c, b'.', d] => (0, b - b'0', c - b'0', d - b'0'),
        [a, b, c, b'.', d] => (a - b'0', b - b'0', c - b'0', d - b'0'),
        [c] => (0, 0, 0, c - b'0'),
        [b, c] => (0, b - b'0', c - b'0', 0),
        [a, b, c] => (a - b'0', b - b'0', c - b'0', 0),
        _ => panic!("Unknown pattern {:?}", std::str::from_utf8(s).unwrap()),
    };
    let v = a as V * 1000 + b as V * 100 + c as V * 10 + d as V;
    if neg { -v } else { v }
}
Code Snippet 2:
A custom parsing function using matching on the pattern. commit.
Inline hash keys: 50s
Currently the hashmap is from &str to Record, where all &str are slices of
the input string. All this indirection is probably slow.
So we instead would like to store keys inline as [u8; 8] (basically a u64).
It turns out that the first 8 characters of each city name are almost enough for
uniqueness. Only Alexandra and Alexandria coincide, so we’ll xor in the
length of the string to make them unique.
One drawback is that the hashmap must now store the full name corresponding to
the key as well.
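A rough sketch of such a key function (not the exact code from the commit): pack the first up-to-8 bytes of the name into a u64 and mix in the length to separate Alexandra and Alexandria.
fn to_key(name: &[u8]) -> u64 {
    // Pack the first (up to) 8 bytes of the name into a u64.
    let mut bytes = [0u8; 8];
    let n = name.len().min(8);
    bytes[..n].copy_from_slice(&name[..n]);
    // Mix in the length to disambiguate names with equal 8-byte prefixes.
    u64::from_ne_bytes(bytes) ^ name.len() as u64
}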
The default hash table in rust uses a pretty slow hash function. Let’s instead
use fxhash::FxHashMap. For u64 keys, the hash function is simply
multiplication by a constant. This gives another 10 seconds speedup.
- let mut h = HashMap::new();
+ let mut h = FxHashMap::default();
Now that we’ve addressed the obvious hot parts, let’s make a new graph.
Yeah well great… I suppose everything is inlined or so. But actually the
debuginfo should still be there. idk…
Perf it is
cargo flamegraph uses perf record under the hood. So we can just perf report and see what’s there.
Some snippets. Numbers on the left are percentage of samples on that line.
3.85 │2d0:┌─→ movzbl 0x0(%rbp,%rbx,1),%r15d   // read a byte
1.24 │    │   cmp    $0xa,%r15b               // compare to \n
0.69 │    │↓  je     300                      // handle the line if \n
2.07 │    │   inc    %rbx                     // increment position
     │    ├── cmp    %rbx,%rcx                // compare to end of data
5.43 │    └── jne    2d0                      // next iteration
Code Snippet 5:
The column on the left indicates that in total 13% of time is spent looking for newlines.
6.25 │330:┌─→ cmpb   $0x3b,0x0(%rbp,%r13,1)   // read a byte
3.40 │    │↓  je     350                      // handle if found
3.28 │    │   inc    %r13                     // increment position
     │    ├── cmp    %r13,%rbx                // compare to length of the line
2.53 │    └── jne    330                      // next iteration
     │    ↓   jmp    c0e                      // fall through to panic handler
Code Snippet 6:
15% of time is spent looking for semicolons.
Code Snippet 7:
Converting from [u8; 8] to u64, i.e. an unaligned read, is surprisingly slow?
Then there are quite a few instructions for indexing the hash table, adding up to
around 20% in total.
Parsing takes around 5%.
Something simple: allocating the right size: 41s
We can stat the input file for its size and allocate exactly the right amount of space.
This saves around half a second.
let mut data = vec![];
+ let stat = std::fs::metadata(filename).unwrap();
+ data.reserve(stat.len() as usize + 1);
let mut file = std::fs::File::open(filename).unwrap();
file.read_to_end(&mut data).unwrap();
Code Snippet 8:
reserving space
memchr for scanning: 47s
memchr(byte, text) is a libc function that returns the first index of the
byte in the text.
But well.. it turns out this is a non-inlined function call after all and things
slow down. But anyway, here’s the diff:
let mut h = FxHashMap::default();
- for line in data.split(|&c| c == b'\n') {
- let (name, value) = line.split_once(|&c| c == b';').unwrap();
+ let mut data = &data[..];
+ loop {
+ let Some(separator) = memchr(b';', data) else {
+ break;
+ };
+ let end = memchr(b'\n', &data[separator..]).unwrap();
+ let name = &data[..separator];
+ let value = &data[separator + 1..separator + end];
h.entry(to_key(name))
.or_insert((Record::default(), name))
.0
.add(parse(value));
+ data = &data[separator + end + 1..];
}
One ‘problem’ with memchr is that it is made for scanning long ranges, and is
not super flexible. So let’s roll our own.
Manual SIMD: 29s
We make sure that data is aligned to SIMD boundaries and iterate over it \(32\)
characters at a time. We compare all \(32\) characters at once against the target
character and convert the results to a bitmask. The number of trailing zeros
indicates the position of the first match. If the bitmask is \(0\), there are no matches
and we try the next \(32\) characters.
This turns out to be slightly slower. I’m not exactly sure why, but we can
profile and iterate from here.
/// Number of SIMD lanes. AVX2 has 256 bits, so 32 lanes.
const L: usize = 32;
/// The Simd type.
type S = Simd<u8, L>;

/// Find the regions between \n and ; (names) and between ; and \n (values),
/// and calls `callback` for each line.
#[inline(always)]
fn iter_lines<'a>(data: &'a [u8], mut callback: impl FnMut(&'a [u8], &'a [u8])) {
    unsafe {
        // TODO: Handle the tail.
        let simd_data: &[S] = data.align_to::<S>().1;

        let sep = S::splat(b';');
        let end = S::splat(b'\n');

        let mut start_pos = 0;
        let mut i = 0;
        let mut eq_sep = sep.simd_eq(simd_data[i]).to_bitmask();
        let mut eq_end = end.simd_eq(simd_data[i]).to_bitmask();

        // TODO: Handle the tail.
        while i < simd_data.len() - 2 {
            // find ; separator
            // TODO if?
            while eq_sep == 0 {
                i += 1;
                eq_sep = sep.simd_eq(simd_data[i]).to_bitmask();
                eq_end = end.simd_eq(simd_data[i]).to_bitmask();
            }
            let offset = eq_sep.trailing_zeros();
            eq_sep ^= 1 << offset;
            let sep_pos = L * i + offset as usize;

            // find \n newline
            // TODO if?
            while eq_end == 0 {
                i += 1;
                eq_sep = sep.simd_eq(simd_data[i]).to_bitmask();
                eq_end = end.simd_eq(simd_data[i]).to_bitmask();
            }
            let offset = eq_end.trailing_zeros();
            eq_end ^= 1 << offset;
            let end_pos = L * i + offset as usize;

            callback(
                data.get_unchecked(start_pos..sep_pos),
                data.get_unchecked(sep_pos + 1..end_pos),
            );
            start_pos = end_pos + 1;
        }
    }
}
Code Snippet 10:
Simd code to search for semicolon and newline characters. commit.
Profiling
Running perf stat -d cargo run -r gives:
     28,367.09 msec task-clock:u             #    1.020 CPUs utilized
              0      context-switches:u      #    0.000 /sec
              0      cpu-migrations:u        #    0.000 /sec
         31,249      page-faults:u           #    1.102 K/sec
 92,838,268,117      cycles:u                #    3.273 GHz
153,099,184,152      instructions:u          #    1.65  insn per cycle
 19,317,651,322      branches:u              #  680.988 M/sec
  1,712,837,337      branch-misses:u         #    8.87% of all branches
 27,760,594,151      L1-dcache-loads:u       #  978.620 M/sec
    339,143,832      L1-dcache-load-misses:u #    1.22% of all L1-dcache accesses
     25,000,151      LLC-loads:u             #  881.308 K/sec
      4,546,946      LLC-load-misses:u       #   18.19% of all LL-cache accesses
Code Snippet 11:
Output of perf stat profiling.
Observe:
The actual clock frequency is only 3.3GHz, whereas it should be 3.6GHz. Not sure why;
it might be waiting for IO.
1.65 instructions per cycle is quite low. It can be up to 4 and is often at
least 2.5.
8.87% of branch misses is also quite high. Usually this is at most 1% and
typically lower. Each branch mispredict causes a stall of 5ns or so, which
is over 1 second total, but I suspect the impact is larger.
18.19% of last-level-cache load misses. Also quite high, but I’m not sure if
this is a problem, since the total number of LLC loads is relatively low.
Revisiting the key function: 23s
Looking at perf report we see that the hottest instruction is a call to
memcpy to read up to name.len() bytes from the &[u8] name to a u64.
Code Snippet 12:
12% of time is spent on casting the name into a u64.
We can avoid this memcpy call entirely by just doing a (possibly out of
bounds) u64 read of the name, and then shifting away bits corresponding to the
out-of-bounds part. We’ll also improve the hash to add the first and last (up
to) 8 characters.
fn to_key(name: &[u8]) -> u64 {
    // Hash the first and last 8 bytes.
    let head: [u8; 8] = unsafe { *name.get_unchecked(..8).split_array_ref().0 };
    let tail: [u8; 8] = unsafe { *name.get_unchecked(name.len() - 8..).split_array_ref().0 };
    let shift = 64usize.saturating_sub(8 * name.len());
    let khead = u64::from_ne_bytes(head) << shift;
    let ktail = u64::from_ne_bytes(tail) >> shift;
    khead + ktail
}
     │ hashbrown::raw::RawTable<T,A>::find:
0.27 │   mov      (%rsp),%rcx
0.16 │   mov      0x8(%rsp),%rax
     │ hashbrown::raw::h2:
0.41 │   mov      %rbp,%rdx
0.56 │   shr      $0x39,%rdx
1.19 │   mov      %rdx,0x158(%rsp)
0.13 │   vmovd    %edx,%xmm0
0.89 │   vpbroadcastb %xmm0,%xmm0
0.20 │   lea      -0x28(%rcx),%rdx
0.16 │   xor      %esi,%esi
0.16 │   mov      %rbp,%r11
     │ hashbrown::raw::RawTableInner::find_inner:
1.41 │586:  and   %rax,%r11
     │ core::intrinsics::copy_nonoverlapping:
3.29 │   vmovdqu  (%rcx,%r11,1),%xmm1
     │ core::core_arch::x86::sse2::_mm_movemask_epi8:
5.60 │   vpcmpeqb %xmm0,%xmm1,%xmm2           ; compare key to stored keys
0.02 │   vpmovmskb %xmm2,%r8d
     │ hashbrown::raw::bitmask::BitMask::lowest_set_bit:
0.31 │   nop
0.97 │5a0:┌─→ test %r8w,%r8w
     │    │ <hashbrown::raw::bitmask::BitMaskIter as core::iter::traits::iterator::Iterator>::next:
0.80 │    │↓  je   5d0
     │    │ hashbrown::raw::bitmask::BitMask::lowest_set_bit:
5.59 │    │   tzcnt %r8d,%r9d                 ; find position of match in bitmask
     │    │ hashbrown::raw::bitmask::BitMask::remove_lowest_bit:
0.03 │    │   blsr %r8d,%r8d
     │    │ hashbrown::raw::RawTableInner::find_inner:
0.61 │    │   add  %r11,%r9
0.53 │    │   and  %rax,%r9
     │    │ core::ptr::mut_ptr::<impl *mut T>::sub:
1.93 │    │   neg  %r9
     │    │ core::ptr::mut_ptr::<impl *mut T>::offset:
0.57 │    │   lea  (%r9,%r9,4),%r9
     │    │ core::cmp::impls::<impl core::cmp::PartialEq for u64>::eq:
8.40 │    ├── cmp  %r14,(%rdx,%r9,8)          ; check equal
     │    │ hashbrown::raw::RawTableInner::find_inner:
0.69 │    └── jne  5a0
0.11 │    ↓   jmp  600
     │ core::core_arch::x86::sse2::_mm_movemask_epi8:
     │   data16 cs nopw 0x0(%rax,%rax,1)
7.55 │5d0:  vpcmpeqb -0x47c8(%rip),%xmm1,%xmm1  ; more equality checking
0.00 │   vpmovmskb %xmm1,%r8d
     │ hashbrown::raw::bitmask::BitMask::any_bit_set:
     │   ┌── test %r8d,%r8d
     │   │ hashbrown::raw::RawTableInner::find_inner:
     │   ├── jne  6f6
Code Snippet 14:
The hashmap takes a lot of time. There are four instructions taking over 5% each, for a total of around 35% of runtime.
Observe:
There is a loop for linear probing.
There are a lot of equality checks to test if a slot corresponds to the
requested key.
Generally, this code is long, complex, and branchy.
It would be much better to use a perfect hash function that we build once. Then
none of these equality checks are needed.
For this, I will use PtrHash, a (minimal) perfect hash function I developed based on PtHash
(PtHash paper). I still have to write a paper on PtrHash, but I do have a long
roundabout blogpost.
Find all city names in the first 100k rows. Since each row has a random city,
all names will occur there.
Build a perfect hash function. For the given dataset, PtrHash outputs a
metadata pilot array of \(63\) bytes.
On each lookup, the u64 hash is mapped to one of the \(63\) buckets. Then
the hash is xored by C * pilots[b] where \(C\) is a random mixing constant.
This is then reduced to an integer less than \(512\), which is the index in the array
of Records we are looking for.
The pilots are constructed such that each hash results in a different index.
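Put together, a lookup computes roughly the following. This is only a sketch of the scheme described above; the pilots array, the constant C, and the multiply-shift reductions stand in for PtrHash’s actual internals.
fn phf_index(hash: u64, pilots: &[u8; 63], c: u64) -> usize {
    // Map the hash to one of the 63 buckets.
    let bucket = ((hash as u128 * 63) >> 64) as usize;
    // Mix the bucket's pilot into the hash.
    let h = hash ^ c.wrapping_mul(pilots[bucket] as u64);
    // Reduce to a slot index < 512 in the array of Records.
    ((h as u128 * 512) >> 64) as usize
}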
The full code is here.
The diff in the hot loop is this.
let callback = |name, value| {
let key = to_key(name);
- let entry = h.entry(key).or_insert((Record::default(), name)).0;
+ let index = phf.index(&key);
+ let entry = unsafe { records.get_unchecked_mut(index) };
entry.add(parse(value));
};
iter_lines(data, callback);
Code Snippet 15:
Using a perfect hash function for lookups. Before, h was a HashMap<u64, (Record, &str)>. After, records is simply a [Record; 512], and phf.index(key) is the perfect hash function.
In assembly code, it looks like this:
0.24 │   movabs $0x517cc1b727220a95,%rsi   // Load the multiplication constant C.
2.22 │   imul   %rsi,%rdx                  // Hash the key by multiplying by C.
0.53 │   mov    0xf8(%rsp),%rax            // Some instructions to compute bucket b < 63.
3.16 │   mulx   %rax,%rax,%rax
0.55 │   mov    0x10(%rsp),%r8
5.67 │   movzbl (%r8,%rax,1),%eax          // Read the pilot for this bucket. This is slow.
0.03 │   mov    0x110(%rsp),%r8
0.57 │   mulx   %r8,%r12,%r12
7.09 │   imul   %rsi,%rax                  // Some instructions to get the slot < 512.
0.81 │   xor    %rdx,%rax
0.05 │   mov    %rax,%rdx
3.87 │   mulx   %rsi,%rdx,%rdx
Code Snippet 16:
Assembly code for the perfect hash function lookup. Just note how short it is compared to the hash table. It's still 20% of the total time though.
The new running time is now 17s!
Larger masks: 15s
Currently we store u32 masks on which we do .trailing_zeros() to find
character offsets. We can also compare two 32-byte SIMD blocks in parallel and combine
their results into a single u64 mask. This gives a small speedup, I think mostly because there
are now slightly fewer branch-misses (593M now vs 675M before): commit.
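As a sketch of how the two comparisons combine (the helper name is made up, and the exact std::simd import paths depend on the nightly toolchain):
#![feature(portable_simd)]
use std::simd::{cmp::SimdPartialEq, u8x32};

/// Build a 64-bit match mask for byte `c` over a 64-byte chunk of input.
fn mask64(chunk: &[u8; 64], c: u8) -> u64 {
    let needle = u8x32::splat(c);
    let lo = needle.simd_eq(u8x32::from_slice(&chunk[..32])).to_bitmask();
    let hi = needle.simd_eq(u8x32::from_slice(&chunk[32..])).to_bitmask();
    // The low block fills bits 0..32, the high block fills bits 32..64.
    lo | (hi << 32)
}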
Reduce pattern matching: 14s
I modified the generator I’m using to always print exactly one decimal. This
saves some branches.
// s = abc.d
let (a, b, c, d) = match s {
[c, b'.', d] => (0, 0, c - b'0', d - b'0'),
[b, c, b'.', d] => (0, b - b'0', c - b'0', d - b'0'),
[a, b, c, b'.', d] => (a - b'0', b - b'0', c - b'0', d - b'0'),
- [c] => (0, 0, 0, c - b'0'),
- [b, c] => (0, b - b'0', c - b'0', 0),
- [a, b, c] => (a - b'0', b - b'0', c - b'0', 0),
+ // [c] => (0, 0, 0, c - b'0'),
+ // [b, c] => (0, b - b'0', c - b'0', 0),
+ // [a, b, c] => (a - b'0', b - b'0', c - b'0', 0),
_ => panic!("Unknown pattern {:?}", to_str(s)),
};
Instead of first reading the file into memory and then processing that, we can
memory map it and transparently read parts as needed. This saves the 2 seconds
spent reading the file at the start.
let filename = &args().nth(1).unwrap_or("measurements.txt".to_string());
- let mut data = vec![];
+ let mut mmap: Mmap;
+ let mut data: &[u8];
{
let mut file = std::fs::File::open(filename).unwrap();
let start = std::time::Instant::now();
- let stat = std::fs::metadata(filename).unwrap();
- data.reserve(stat.len() as usize + 1);
- file.read_to_end(&mut data).unwrap();
+ mmap = unsafe { Mmap::map(&file).unwrap() };
+ data = &*mmap;
eprintln!("{}", format!("{:>5.1?}", start.elapsed()).bold().green());
}
Code Snippet 18:
memory mapping using memmap2 crate.
Parallelization: 2.0s
Parallelizing code is fairly straightforward.
First we split the data into one chunk per thread. Then we fire a thread for
each chunk, each with its own vector to accumulate results. Then at the end each
thread merges its results into the global accumulator.
This gives pretty much exactly \(6\times\) speedup on my 6-core machine, since
accumulating is only a small fraction of the total time.
fn run_parallel(data: &[u8], phf: &PtrHash, num_slots: usize) -> Vec<Record> {
    let slots = std::sync::Mutex::new(vec![Record::default(); num_slots]);

    // Spawn one thread per core.
    let num_threads = std::thread::available_parallelism().unwrap();
    std::thread::scope(|s| {
        let chunks = data.chunks(data.len() / num_threads + 1);
        for chunk in chunks {
            // Borrow the mutex so the `move` closure only captures a shared reference.
            let slots = &slots;
            s.spawn(move || {
                // Each thread has its own accumulator.
                let thread_slots = run(chunk, phf, num_slots);
                // Merge results into the global accumulator.
                let mut slots = slots.lock().unwrap();
                for (thread_slot, slot) in thread_slots.into_iter().zip(slots.iter_mut()) {
                    slot.merge(&thread_slot);
                }
            });
        }
    });

    slots.into_inner().unwrap()
}
Code Snippet 19:
Code to process data in parallel.
Branchless parsing: 1.7s
The match statement on the number of digits in the temperature generated quite
a lot of branches and perf stat cargo run -r was showing 440M branch-misses,
i.e. almost one every other line. That’s about as bad as it can be with half the
numbers having a single integer digit and half the numbers having two integer digits.
I was able to pinpoint it to the branching by running perf record -b -g cargo run -r followed by perf report.
Changing this to a branch-less version is quite a bit faster, and now only
140M branch-misses remain.
// s = abc.d
let a = unsafe { *s.get_unchecked(s.len() - 5) };
let b = unsafe { *s.get_unchecked(s.len() - 4) };
let c = unsafe { *s.get_unchecked(s.len() - 3) };
let d = unsafe { *s.get_unchecked(s.len() - 1) };
let v = a as V * 1000 * (s.len() >= 5) as V
    + b as V * 100 * (s.len() >= 4) as V
    + c as V * 10
    + d as V;
Code Snippet 20:
Branchless float parsing.
Purging all branches: 1.67s
The remaining branch misses are in the while eq_sep == 0 in the scanning for
; and \n characters (Manual SIMD: 29s).
Since cities and temperatures have variable
lengths, iterating over the input will always have to do some branching on whether
to move on to the next chunk of input or not.
We can work around this by doing an independent scan for the next occurrence of
; and \n in each iteration. It turns out the longest line in the input
contains 33 characters including newline. This means that a single 32-character
SIMD comparison is exactly sufficient to determine the next occurrence of each character.
#[inline(always)]
fn iter_lines<'a>(mut data: &'a [u8], mut callback: impl FnMut(&'a [u8], &'a [u8])) {
    let sep = S::splat(b';');
    let end = S::splat(b'\n');

    // Find the next occurrence of the given separator character.
    let mut find = |mut last: usize, sep: S| {
        let simd = S::from_array(unsafe { *data.get_unchecked(last..).as_ptr().cast() });
        let eq = sep.simd_eq(simd).to_bitmask();
        let offset = eq.trailing_zeros() as usize;
        last + offset
    };

    // Pointers to the last match of ; or \n.
    let mut sep_pos = 0;
    let mut start_pos = 0;
    while start_pos < data.len() - 32 {
        // Both start searching from the last semicolon, so that the unaligned SIMD read can be reused.
        sep_pos = find(sep_pos + 1, sep);
        let end_pos = find(sep_pos + 1, end);
        unsafe {
            let name = data.get_unchecked(start_pos + 1..sep_pos);
            let value = data.get_unchecked(sep_pos + 1..end_pos);
            callback(name, value);
        }
        start_pos = end_pos;
    }
}
It turns out this does not actually give a speedup, but we will use this as a
starting point for further improvements. Note also that perf stat changes
considerably:
BEFORE
35,409,579,588 cycles:u # 3.383 GHz
96,408,277,646 instructions:u # 2.72 insn per cycle
4,463,603,931 branches:u # 426.463 M/sec
148,274,976 branch-misses:u # 3.32% of all branches
AFTER
35,217,349,810 cycles:u # 3.383 GHz
87,571,263,997 instructions:u # 2.49 insn per cycle
1,102,455,316 branches:u # 105.904 M/sec
4,148,835 branch-misses:u # 0.38% of all branches
Code Snippet 22:
Selection of perf stat before and after
Note:
The total number of CPU cycles is the same.
The number of instructions has gone down 10%.
The number of branches went from 4.4G (4 per line) to 1.1G (1 per line).
The number of branch-misses went from 150M (once every 7 lines) to 4M (once
every 250 lines).
To illustrate, at this point the main loop looks like this. Note that it is
indeed branchless, and only 87 instructions long.
Code Snippet 23:
Main loop of the program. The first column shows the percentage of time in each line.
Some more attempts
Possible improvements at this point are increasing instruction-level parallelism to get more than
2.49 instructions per cycle, and using SIMD to process
multiple lines at a time.
I quickly hacked something that splits the data: &[u8] for each thread into
two to four chunks that are processed at the same time, hoping multiple
independent code paths would improve parallelism, but that didn’t work out
immediately. Probably I need to interleave all the instructions everywhere, and
manually use SIMD where possible, which is slightly annoying and for a later time.
I also know that the PtrHash perfect hash function contains a few redundant
instructions that are needed in the general case but not here. Removing those
would be nice.
Faster perfect hashing: 1.55s
Turns out I added a function to PtrHash for lookups on small tables, but
wasn’t actually using it. Saves some cycles again :)
Bug time: Back up to 1.71s
I accidentally dropped the - b'0' part when making the floating point parsing branch free.
Adding it back in bumps the time quite a bit, given that it’s only 4
extra instructions.
- let a = unsafe { *s.get_unchecked(s.len().wrapping_sub(5)) };
- let b = unsafe { *s.get_unchecked(s.len().wrapping_sub(4)) };
- let c = unsafe { *s.get_unchecked(s.len().wrapping_sub(3)) };
- let d = unsafe { *s.get_unchecked(s.len().wrapping_sub(1)) };
+ let a = unsafe { *s.get_unchecked(s.len().wrapping_sub(5)) - b'0' };
+ let b = unsafe { *s.get_unchecked(s.len().wrapping_sub(4)) - b'0' };
+ let c = unsafe { *s.get_unchecked(s.len().wrapping_sub(3)) - b'0' };
+ let d = unsafe { *s.get_unchecked(s.len().wrapping_sub(1)) - b'0' };
Code Snippet 24:
Bugfix
Temperatures less than 100: 1.62s
Assuming that temperatures are less than 100 helps quite a bit.
- // s = abc.d
+ // s = bc.d
- let a = unsafe { *s.get_unchecked(s.len().wrapping_sub(5)) - b'0' };
let b = unsafe { *s.get_unchecked(s.len().wrapping_sub(4)) - b'0' };
let c = unsafe { *s.get_unchecked(s.len().wrapping_sub(3)) - b'0' };
let d = unsafe { *s.get_unchecked(s.len().wrapping_sub(1)) - b'0' };
- let v = a as V * 1000 * (s.len() >= 5) as V
+ let v = 0
+ b as V * 100 * (s.len() >= 4) as V
+ c as V * 10
+ d as V;
Computing min as a max: 1.50s
Instead of record.min = min(value, record.min) we can do record.min = max(-value, record.min) and negate the value at the end. This turns out to
generate slightly faster code, because the two max calls can now be done using SIMD.
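A sketch of the trick (field names assumed): min now stores the negated minimum, so both updates are max operations, and the stored value is negated back once when printing.
struct Record {
    min: i32, // actually stores -minimum
    max: i32,
}

impl Record {
    fn add(&mut self, value: i32) {
        self.min = self.min.max(-value);
        self.max = self.max.max(value);
    }
}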
Intermezzo: Hyperthreading: 1.34s
Turns out enabling hyperthreading speeds up the parallel run by around 10%!
Surprisingly, the single-threaded version becomes a bit slower, going from 7s up to 9s.
Here’s a perf stat on 12 threads with hyperthreading:
15,888.56 msec task-clock:u # 9.665 CPUs utilized
54,310,677,591 cycles:u # 3.418 GHz
72,478,697,756 instructions:u # 1.33 insn per cycle
1,095,632,420 branches:u # 68.957 M/sec
1,658,837 branch-misses:u # 0.15% of all branches
1.643962284 seconds time elapsed
15.107088000 seconds user
0.767016000 seconds sys
Instructions per cycle is very low, probably since the hyperthreads are competing for
cycles.
And here’s a perf stat for 6 threads with hyperthreading disabled:
9,059.31 msec task-clock:u # 5.220 CPUs utilized
30,369,716,994 cycles:u # 3.352 GHz
72,476,925,632 instructions:u # 2.39 insn per cycle
1,095,268,277 branches:u # 120.900 M/sec
1,641,589 branch-misses:u # 0.15% of all branches
1.735524152 seconds time elapsed
8.495397000 seconds user
0.556457000 seconds sys
Notice how elapsed time is a bit higher, but instructions/cycle, task-clock, and user time are lower.
In fact, the number of instructions, branches, and branch-misses is pretty much
the same. The hyperthreaded variant just has more contention for the available cycles.
(I’ll now disable hyperthreading again to make numbers easier to interpret.)
Not parsing negative numbers: 1.48s
Having to deal with positive and negative numbers at the same time is kinda
annoying for further parsing optimizations. To fix this, we will create separate
hash map entries for positive and negative numbers. In particular, for cities with a
negative value I will act as if the ; separator was located at the position of
the minus.
That way, the value is always positive, and the city name gets a ; appended
for negative cases.
This now saves some instructions in the parsing where we can assume the number
is positive.
Code Snippet 25:
adjusting the name and value &[u8] slice for negative numbers.
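The snippet itself isn’t reproduced above, but the idea is roughly the following sketch, where start, sep, and end are the offsets of the line start, the ;, and the newline:
fn split_line(data: &[u8], start: usize, sep: usize, end: usize) -> (&[u8], &[u8]) {
    // If the value is negative, pretend the ';' sits on the '-': the name then
    // includes the ';' (a separate hash map entry), and the value bytes always
    // form a positive number.
    let sep = sep + (data[sep + 1] == b'-') as usize;
    (&data[start..sep], &data[sep + 1..end])
}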
More efficient parsing: 1.44s
It turns out subtracting b'0' from each character is quite slow: since each
u8 subtraction could overflow, they all have to be done independently, as can be
seen in the generated assembly.
To fix this, we can do all subtractions on i32. That way, the compiler merges
them into a single subtraction of 111 * b'0' (the three offsets combine as \(100\cdot 48 + 10\cdot 48 + 48 = 111\cdot 48\)).
// s = bc.d
- let b = unsafe { *s.get_unchecked(s.len().wrapping_sub(4)) - b'0' };
- let c = unsafe { *s.get_unchecked(s.len().wrapping_sub(3)) - b'0' };
- let d = unsafe { *s.get_unchecked(s.len().wrapping_sub(1)) - b'0' };
+ let b = unsafe { *s.get_unchecked(s.len().wrapping_sub(4)) as V - b'0' as V };
+ let c = unsafe { *s.get_unchecked(s.len().wrapping_sub(3)) as V - b'0' as V };
+ let d = unsafe { *s.get_unchecked(s.len().wrapping_sub(1)) as V - b'0' as V };
b as V * 100 * (s.len() >= 4) as V + c as V * 10 + d as V
New assembly:
     │ let b = unsafe { *s.get_unchecked(s.len().wrapping_sub(4)) as V - b'0' as V };
1.16 │   movzbl -0x4(%rbp,%rsi,1),%edx
     │ let c = unsafe { *s.get_unchecked(s.len().wrapping_sub(3)) as V - b'0' as V };
0.08 │   movzbl -0x3(%rbp,%rsi,1),%r11d
     │ let d = unsafe { *s.get_unchecked(s.len().wrapping_sub(1)) as V - b'0' as V };
1.16 │   movzbl -0x1(%rbp,%rsi,1),%esi
     │ b as V * 100 * (s.len() >= 4) as V + c as V * 10 + d as V
1.18 │   imul   $0x64,%edx,%edx
1.08 │   lea    (%r11,%r11,4),%r11d
0.57 │   lea    (%rdx,%r11,2),%r11d
     │ one_billion_row_challenge::Record::add:
     │ self.min = self.min.max(-value);
0.98 │   mov    $0x14d0,%ebp            // the constant is 111*48 = 111*b'0'
1.05 │   sub    %r11d,%ebp
1.26 │   sub    %esi,%ebp
     │ one_billion_row_challenge::parse:
     │ b as V * 100 * (s.len() >= 4) as V + c as V * 10 + d as V
1.60 │   lea    (%rsi,%r11,1),%edx
0.23 │   add    $0xffffeb30,%edx        // the constant is -111*48
I’m still confused, because the 111 * 48 constant appears twice, which seems
unnecessary, but the code is quite a bit shorter for sure.
I’m also not quite sure exactly how the * (s.len() >= 4) ends up in here. It
seems that both values are computed and the right one is automatically picked.
But then I would expect to see 11 * 48 as a constant also, but that doesn’t appear.
Fixing undefined behaviour: back to 1.56s
So, it turns out that doing the unsafe equivalent of s[s.len()-4] on a slice
of length 3 is no good. On stable Rust it happens to work, but on nightly it
does not: get_unchecked with an out-of-bounds index is
undefined behaviour.
Rewriting everything to use indices into the larger data: &[u8] slice instead of
small slices ends up giving quite a bit of slowdown. Maybe enough to reconsider
some earlier choices…
Lazily subtracting b'0': 1.52s
So actually we don’t even have to do the - b'0' subtraction in the hot loop at
all for c.d!
Since we count how many entries there are, we can just keep the raw ASCII values and compensate
at the end by subtracting count * 11 * b'0' there.
let b = unsafe { *data.get_unchecked(end - 4) as V - b'0' as V };
- let c = unsafe { *data.get_unchecked(end - 3) as V - b'0' as V };
+ let c = unsafe { *data.get_unchecked(end - 3) as V };
- let d = unsafe { *data.get_unchecked(end - 1) as V - b'0' as V };
+ let d = unsafe { *data.get_unchecked(end - 1) as V };
b as V * 100 * (end - sep >= 5) as V + c as V * 10 + d as V
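As a sketch of where the compensation then happens (the helper name and fields are made up): when producing the final output, subtract the accumulated ASCII offset once per record.
fn corrected_sum(sum: i64, count: i64) -> i64 {
    // Each accumulated value still carries 10 * b'0' + b'0' from c and d.
    sum - count * 11 * b'0' as i64
}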
Instead of doing min and max operations on parsed integers, we can do them
on the raw string bytes directly. We only have to be careful to mask out the ;
in ;2.3, since the ASCII value of ; is larger than the digits.
The commit is kinda ugly. The highlight is a function that reads the bytes in
big-endian order and then shifts away out-of-range bytes.
This is only a bit slower, and I suspect this will allow speedups later if we
can drop the other parse function completely.
Parsing using a single multiplication: doesn’t work
The previous parse_to_raw reads bytes in reverse order (big-endian, while my
system is little-endian) and returns something like 0x3b_3c_??_3d (where ?? is
the value of ., and 3b is the ASCII value of digit b).
Most of this is constant and can be subtracted at the end in bulk.
The interesting bits are 0x0b_0c_00_0d. From this we could almost get
the target value by multiplying with 100 + 10 << 8 + 1 << 24 and then
shifting right by 24:
0b 0c 00 0d
----------------------------------
100b 100c 0 100d | * 100
10b 10c 0 10d | * 10 << 8
b 0 c d | * 1 << 24
b 0 10b+c T 100c 10d 100d | sum
Code Snippet 27:
Long multiplication using base 256.
The problem is that the 100c term could get larger than 256 and interfere with
the target value T. Also, T itself could be up to \(999\) and run into the
10b+c stored one position higher.
Parsing using a single multiplication does work after all! 1.48s
Alternatively, we could read the bytes in little-endian order into 0x0d_00_0c_0b and multiply by
1 + 10<<16 + 100<<24, again followed by a right shift of 24:
0d 00 0c 0b
---------------------------------
d 0 c b | * 1
10d 0 10c 10b | * 10 << 16
100d 0 100c 100b | * 100 << 24
100d 10d 100c T 10b c b | sum
Code Snippet 28:
Long multiplication using base 256, attempt 2.
In this case, we see that the 10b term in the less significant position can
never exceed 90 < 256, so there are no problems there.
On the other hand, we still have the issue that \(T = 100b + 10c + d\) can be up
to \(999\), which is more than \(256\). But it turns out we’re lucky! (Very lucky?!) The
100c term does occupy the next byte, but 100 is divisible by 4, and hence the
low 2 bits of that byte are always 0. And those 2 bits are exactly what we need!
Taking them together with the 8 bits of the target byte, we have a 10-bit value,
which can store anything up to \(1023\), larger than the maximum value of T!
OMG It’s hard to overstate my happiness here! Just think for a moment about
what just happened. I wonder if there is some ‘deeper’ reason this works. We’re
very much using that 2 divides 10, and that the position of the . is such that
only the 100c term is close by.
The new parse function looks like this:
/// data[start..end] is the input slice.
fn parse(data: &[u8], start: usize, end: usize) -> V {
    // Start with a slice b"bc.d" or b"c.d".
    // Read it as little-endian value 0x3d..3c3b or 0x??3d..3c (?? is what comes after d).
    let raw = u32::from_le_bytes(unsafe { *data.get_unchecked(start..).as_ptr().cast() });
    // Shift out the ??, so we now have 0x3d..3c3b or 0x3d..3c00.
    let raw = raw << (8 * (4 - (end - start)));
    // Extract only the relevant bits corresponding to digits.
    // Digit a looks like 0x3a in ASCII, so we simply mask out the high half of each byte.
    let raw = raw & 0x0f000f0f;
    // Long multiplication.
    const C: u64 = 1 + (10 << 16) + (100 << 24);
    // Shift right by 24 and take the low 10 bits.
    let v = ((raw as u64 * C) >> 24) & ((1 << 10) - 1);
    // let v = unsafe { core::arch::x86_64::_bextr_u64(v, 24, 10) };
    v as _
}
Note: I just learned of the bextr instruction that could replace this last
shift and mask sequence, but it doesn’t seem to be faster than the native implementation.
A side note: ASCII
One thing that really annoys me a lot is that the bit representation of ; is
very close to the digits 0..9 (ASCII table). This means that we can’t nicely
mask it away, and are somewhat forced to do the shift in algorithm above to get
0 bits.
Skip parsing using PDEP: 1.42s
The new parsing function is pretty sweet, but there is another idea worth exploring.
Using PDEP and PEXT instructions (parallel bit deposit/extract), we can shuffle bits around in 3 clock
cycles (same as a multiplication).
So we could do:
0b0011bbbb0011cccc........0011dddd // Input string.
0b 1111 1111 1111 // Extract these bits (PEXT)
0b bbbbccccdddd // Result of extraction
0b 1111 1111 1111 // Deposit bits here in u64 (PDEP)
0b bbbb cccc dddd // Result of deposit
^42 ^21 ^0
Code Snippet 29:
playing with PDEP and PEXT.
Now we have a u64 integer containing three 21-bit integers inside it. Since
each digit has a value below \(10\), we can accumulate \(2^{21}/10 \approx 200\)k
temperatures in this before overflow happens. So we could process say 10M rows
(where each city should only appear around 25k times) and then accumulate things into
a main buffer.
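A sketch of that accumulation step (the helper name is made up): unpack the three 21-bit lanes, using the bit layout from the deposit diagram above, and fold them into a full-width sum before any lane can overflow.
fn flush_pdep_sum(pdep_sum: u64) -> u64 {
    // Unpack the three 21-bit partial digit sums (b at bit 42, c at bit 21, d at bit 0).
    let b = (pdep_sum >> 42) & ((1 << 21) - 1);
    let c = (pdep_sum >> 21) & ((1 << 21) - 1);
    let d = pdep_sum & ((1 << 21) - 1);
    // Combine the digit-column sums into the real sum of temperatures (in tenths).
    100 * b + 10 * c + d
}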
Improved
We can even get rid of the PEXT instruction and directly PDEP the bits
to where we want them, since all bits only move left. Basically we deposit some
bloat bits (the 0x3 high nibble of the ASCII digits, and the representation of the .), but
then we simply drop them using an &.
0b0011bbbb0011cccc........0011dddd // Input string.
xxxx yyyyyyyyyyyy // name the trash bits
0b bbbb xxxxcccc yyyyyyyyyyyydddd // Deposit here
0b 1111 1111 1111 // Mask out trash using &
^42 ^21 ^0
Code Snippet 30:
More efficient bit deposit.
fn parse_pdep(data: &[u8], start: usize, end: usize) -> u64 {
    // Start with a slice b"bc.d" or b"c.d".
    // Read it as big-endian value 0xbbcc..dd or 0xcc..dd?? (?? is what comes after d).
    let raw = u32::from_be_bytes(unsafe { *data.get_unchecked(start..).as_ptr().cast() });
    // Shift out the ??, so we now have 0xbbcc..dd or 0x00cc..dd.
    let raw = raw >> (8 * (4 - (end - start)));
    // 0b bbbb xxxxcccc yyyyyyyyyyyydddd  // Deposit here.
    // 0b 1111     1111             1111  // Mask out trash using &.
    let pdep = 0b0000000000000000001111000000000000011111111000001111111111111111u64;
    let mask = 0b0000000000000000001111000000000000000001111000000000000000001111u64;
    let v = unsafe { core::arch::x86_64::_pdep_u64(raw as u64, pdep) };
    v & mask
}
Code Snippet 31:
parsing using PDEP
Note that we have to do a bit of extra logic around the accumulation, but that’s
fairly straightforward.
The length of the assembly code is now down to 62 instructions!
Both the pdep pattern and the mask require a dedicated mov. The
and-mask can also be applied to the u32 value before the pdep, which merges the
mov and and into and $0xf0f000f,%edx.
A further note
Not all systems support PDEP. Also, AMD Zen 1 and Zen 2 processors do
support it, but implement it using microcode taking 18 cycles (wikipedia). In
our case we can emulate the PDEP relatively easily using some mask and shift
operations:
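For example, a sketch of such a fallback, assuming the same big-endian raw layout as parse_pdep above (digit nibbles of b, c, and d in bytes 3, 2, and 0):
fn deposit_fallback(raw: u32) -> u64 {
    let raw = raw as u64;
    let d = raw & 0xf;
    let c = (raw >> 16) & 0xf;
    let b = (raw >> 24) & 0xf;
    // Same layout as the PDEP version: b at bit 42, c at bit 21, d at bit 0.
    (b << 42) | (c << 21) | d
}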
Indeed, branch mispredicts are slow, but it turns out these branches are almost
never taken! In fact, since the data is random, they are taken only about \(\ln(n)\)
times for each city, where \(n\) is the number of times the city appears. Each city
appears 1G/400 = 2.5M times, and so for each city we only expect around 15
updates, for 400*15=6000 branch misses in total (per thread),
which is very small compared to the 1G writes and cmovs we save.
I tried adding an unlikely(..) hint to this, but that didn’t change the
generated assembly code.
No counting: 1.34s
Ok this is where you start saying I’m cheating, but it turns out we don’t need
to count how often each city appears in the input. Since each row has a random
city, the number of times a city appears is \(Bin(R, 1/C)\), where \(r=10^9\) and
\(C=413\). This has expectation \(R/C\sim 15\cdot 10^6\) and standard deviation
\(\approx \sqrt(R/C) \sim 4000\). So the variance around the expected value is
only \(1/4000\). That means that if in the end we simply divide by the expected
number of occurrences instead of the actual number of occurrences, we’re only
off by \(1/4000\) relatively. There are \(400\) cities, so the worst deviation will
be slightly more, maybe \(4/4000=1/1000\) or so. And since there are \(400<1000\)
cities, the probability that one of them has average value within \(1/1000\) from
a rounding-boundary is less than a half.
Indeed, this ends up giving the correct output for the evaluation set.
Edit: There was a bug in my verification script. As it turns out, after this
optimization the temperature for Baghdad is incorrectly reported as 22.7
instead of 22.8.
Arbitrarily long city names: 1.34s
Turns out supporting longer city names doesn’t even come at a cost since the
branch is never taken!
let find = |mut last: usize, sep: S| {
let simd = S::from_array(unsafe { *data.get_unchecked(last..).as_ptr().cast() });
let mut eq = sep.simd_eq(simd).to_bitmask() as u32;
+ if eq == 0 {
+ while eq == 0 {
+ last += 32;
+ let simd = S::from_array(unsafe { *data.get_unchecked(last..).as_ptr().cast() });
+ eq = sep.simd_eq(simd).to_bitmask() as u32;
+ }
+ }
let offset = eq.trailing_zeros() as usize;
last + offset
};
Note that just doing while eq == 0 {..} slows down the code, but wrapping it
in an if that is almost never true avoids the slowdown.
4 entries in parallel: 1.23s
Currently we only process one line at a time. This limits how much we can
exploit ILP (instruction level parallelism). So instead we will process 4 lines
at a time.
We don’t want to process consecutive lines since they depend on each other.
Instead, we chunk the data into 4 parts and process them independently of each
other in parallel.
To improve the generated code, I created a simple macro that interleaves the
instructions. This gets the IPC (instructions per cycle) up to 2.66 (up from 2.20).
let mut state0 = init_state(0);
let mut state1 = init_state(data.len() / 4);
let mut state2 = init_state(2 * data.len() / 4);
let mut state3 = init_state(3 * data.len() / 4);

// Duplicate each line for each input state.
macro_rules! step {
    [$($s:expr),*] => {
        $($s.sep_pos = find_long($s.sep_pos, sep) + 1;)*
        $($s.end_pos = find($s.sep_pos, end) + 1;)*
        $(callback(data, $s.start_pos, $s.sep_pos - 1, $s.end_pos - 1);)*
        $($s.start_pos = $s.end_pos;)*
    }
}

while state3.start_pos < data.len() {
    step!(state0, state1, state2, state3);
}
Mmap per thread
Currently the main thread spends 180ms on unmapping the file at the end.
Instead I tried making a separate Mmap instance per thread, so the unmapping
can happen in parallel. Sadly this doesn’t make a difference: now the 180ms is
simply divided over all threads, and since munmap has a program wide lock, the
drop(mmap) calls in each thread are just waiting for each other.
Some workarounds:
Make each thread do its own memory (un)mapping. This does not yet speed things
up since they all unmap at the same time.
Spawn two or three times as many threads, so that other work can be done
instead of waiting for the kernel lock. This ends up having around 100ms
overhead, which is better than the 180ms we started with but still significant.
For now both these options are slightly too messy and I will just keep things as
they are. I might revisit this later.
Reordering some operations: 1.19s
Turns out doing the increments for the start/end of each slice in a slightly
different order gives quite a bit better performance. :shrug:
The above still processes the callback sequentially for the 4 parallel states.
In particular, each callback spends relatively a lot of time waiting for the
right record to be loaded from L1 cache into registers.
Let’s go deeper and call the callback with 4 states. Ugly commit, around 0.01s slowdown for some reason.
Now we can interleave the record loading instructions:
// If value is negative, extend name by one character.
s0.sep += (data.get_unchecked(s0.sep + 1) == &b'-') as usize;
let name0 = data.get_unchecked(s0.start..s0.sep);
let key0 = to_key(name0);
let index0 = phf.index_single_part(&key0);
-let entry0 = slots.get_unchecked_mut(index0);
let raw0 = parse_to_raw(data, s0.sep + 1, s0.end);
-entry0.add(raw0, raw_to_pdep(raw0));
s1.sep += (data.get_unchecked(s1.sep + 1) == &b'-') as usize;
let name1 = data.get_unchecked(s1.start..s1.sep);
let key1 = to_key(name1);
let index1 = phf.index_single_part(&key1);
-let entry1 = slots.get_unchecked_mut(index1);
let raw1 = parse_to_raw(data, s1.sep + 1, s1.end);
-entry1.add(raw1, raw_to_pdep(raw1));
s2.sep += (data.get_unchecked(s2.sep + 1) == &b'-') as usize;
let name2 = data.get_unchecked(s2.start..s2.sep);
let key2 = to_key(name2);
let index2 = phf.index_single_part(&key2);
-let entry2 = slots.get_unchecked_mut(index2);
let raw2 = parse_to_raw(data, s2.sep + 1, s2.end);
-entry2.add(raw2, raw_to_pdep(raw2));
s3.sep += (data.get_unchecked(s3.sep + 1) == &b'-') as usize;
let name3 = data.get_unchecked(s3.start..s3.sep);
let key3 = to_key(name3);
let index3 = phf.index_single_part(&key3);
-let entry3 = slots.get_unchecked_mut(index3);
let raw3 = parse_to_raw(data, s3.sep + 1, s3.end);
+
+let entry0 = slots.get_unchecked_mut(index0);
+entry0.add(raw0, raw_to_pdep(raw0));
+let entry1 = slots.get_unchecked_mut(index1);
+entry1.add(raw1, raw_to_pdep(raw1));
+let entry2 = slots.get_unchecked_mut(index2);
+entry2.add(raw2, raw_to_pdep(raw2));
+let entry3 = slots.get_unchecked_mut(index3);
entry3.add(raw3, raw_to_pdep(raw3));
This brings us down to 1.11s, quite a big win really!
Comparing perf stat before and after, the number of instructions went from
62.7G to 63.0G (0.5% more; I’m not sure why), but instructions per cycle went up
from 2.55 to 2.80, almost 10% better!
Even more ILP: 1.05s
Of course, we can do better and interleave all the instructions fully. I played
around a bit to find the interleaving that gives the best results.
For some reason keeping the sep += ... and name = ... instructions as-is is
around 0.01s better. Still the overall win is very nice! The number of
instructions remains the same, but instructions per cycle is up to 2.95! For
1.05s in total.
At this point it turns out that hyperthreading actually slows down the code!
Table 2:
Results (in seconds) at 4.6GHz.
threads | HT off | HT on
      1 |   4.93 |  5.05
      6 |   0.90 |  0.97
     12 |   0.90 |  0.92
Compliance 1, OK I’ll count: 1.06s
Ok I’m adding back the record counting. Being non-compliant this way is
annoying, and it’s really quite brittle to rely on the randomness here.
Surprisingly, performance is noticeably (~0.05s) better when I make count: u64 instead of
count: u32. I didn’t bother to figure out why.
TODO
Use AES instructions for hashing
Verify hash on lookup
Rebuild PHF on hash collision / new key.
Simplify perfect hash function
SIMD
Munmap per chunk/thread
Postscript
As things go, my attention shifted to other things before wrapping everything
up, and by now I have forgotten the details of where I left things :")
Anyway, here’s a blogpost comparing non-Java implementations, and here is a long
discussion of possible optimizations.