Code Generation and Merge Sort
April 24, 2019
I was reading a few pages of Knuth's The Art of Computer Programming, Volume 4A, about "branchless computation" (p. 180), in which he demonstrates how to get rid of branches by using conditional instructions. As an instructive example he considers the inner part of merge sort, in which we are to merge two sorted lists of numbers into one bigger sorted list. The description as given by Knuth is as follows:
If $x_i < y_j$ set $z_k \gets x_i$, $i \gets i+1$, and go to x_done if $i = i_{max}$.
Otherwise set $z_k \gets y_j$, $j \gets j+1$, and go to y_done if $j = j_{max}$.
Then set $k \gets k+1$ and go to z_done if $k = k_{max}$.
$x$ and $y$ are the input lists and $z$ is the output merged list. $i$, $j$, and $k$ are loop indices for the three respective lists, and the $_{max}$ variants are the lists' lengths.
I got curious and decided to see how a standard optimizing compiler would
handle this case, and whether writing the assembly yourself would provide any
gain in performance. After all, this is just slightly more complicated than the
trivial examples used to show off good codegen, so it would not be unreasonable
for the compiler to manage to fix a bad implementation of this. In addition, it
would serve as a great excuse to finally learn how to write x86 assembly.
Basics
Here’s the inner loop in C code:
```c
void branching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
               uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    while (k < zmax) {
        if (xs[i] < ys[j]) {
            zs[k++] = xs[i++];
            if (i == xmax) { // x_done
                memcpy(zs + k, ys + j, 8 * (zmax - k));
                return;
            }
        } else {
            zs[k++] = ys[j++];
            if (j == ymax) { // y_done
                memcpy(zs + k, xs + i, 8 * (zmax - k));
                return;
            }
        }
    }
    // z_done
}
```
This seems to be a more or less straightforward textbook implementation of the
procedure, so it will do fine as a benchmark. As a quick check before going any
deeper into this we can use godbolt.org to see whether this experiment is even
worth doing. Godbolt's x86-64 gcc 8.3 with `-O3` spits out this (annotations
are by me):
```asm
branching(unsigned long*, unsigned long, unsigned long*, unsigned long, unsigned long*, unsigned long):
        test r9, r9                        ; if (r9 == 0)
        je .L15                            ;   goto .L15
        push r13
        xor eax, eax                       ; k = 0
        xor r11d, r11d                     ; j = 0
        xor r10d, r10d                     ; i = 0
        push r12
        push rbp
        push rbx
        jmp .L2
.L17:   add r10, 1                         ; i++
        mov QWORD PTR [r8-8+rax*8], rbp    ; zs[k-1] = xi
        cmp r10, rsi                       ; if (i == xmax)
        je .L16                            ;   goto .L16
.L6:    cmp r9, rax                        ; if (k == zmax)
        je .L1                             ;   goto .L1
.L2:    lea r12, [rdi+r10*8]               ; calculate xs + i
        lea r13, [rdx+r11*8]               ; calculate ys + j
        add rax, 1                         ; k++
        mov rbp, QWORD PTR [r12]           ; xi = xs[i]
        mov rbx, QWORD PTR [r13+0]         ; yj = ys[j]
        cmp rbp, rbx                       ; if (xi < yj)
        jb .L17                            ;   goto .L17
        add r11, 1                         ; j++
        mov QWORD PTR [r8-8+rax*8], rbx    ; zs[k-1] = yj
        cmp r11, rcx                       ; if (j != ymax)
        jne .L6                            ;   goto .L6
        sub r9, rax                        ; y_done
        pop rbx
        mov rsi, r12
        pop rbp
        lea rdi, [r8+rax*8]
        pop r12
        lea rdx, [0+r9*8]
        pop r13
        jmp memcpy
.L1:    pop rbx                            ; z_done
        pop rbp
        pop r12
        pop r13
        ret
.L16:   sub r9, rax                        ; x_done
        pop rbx
        mov rsi, r13
        pop rbp
        lea rdi, [r8+rax*8]
        pop r12
        lea rdx, [0+r9*8]
        pop r13
        jmp memcpy
.L15:   ret
```
Plenty of branches!
Now, maybe it turns out that it doesn't matter whether we're branching or not and that the compiler knows best. We could guess that the reason we're still getting branches is that that's really the best way to go here. After all, "you can't beat the compiler" seems to be the consensus in many programming circles. Let's try to write a version in C without excessive use of branching. Then perhaps the compiler will generate different code, and we can see what that difference amounts to in terms of running time. We can adopt Knuth's branchless version:
```c
void nonbranching_but_branching(uint64_t *xs, size_t xmax, uint64_t *ys,
                                size_t ymax, uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    uint64_t xi = xs[i], yj = ys[j];
    while ((i < xmax) && (j < ymax) && (k < zmax)) {
        int64_t t = one_if_lt(xi - yj);
        yj = min(xi, yj);
        zs[k] = yj;
        i += t;
        xi = xs[i];
        t ^= 1;
        j += t;
        yj = ys[j];
        k += 1;
    }
    if (i == xmax) memcpy(zs + k, ys + j, 8 * (zmax - k));
    if (j == ymax) memcpy(zs + k, xs + i, 8 * (zmax - k));
}
```
What is going on, you might ask? The general idea is to first get `min(xi, yj)`, and then have a number `t` that's `1` if `xi < yj` and `0` otherwise: we can add `t` to `i`, since `t=1` if we just wrote `xi` to `zs[k]`. Then we can `xor` it with `1`, effectively flipping `1` to `0` and `0` to `1`, and then add `t^1` to `j`; this causes either `i` or `j` to be incremented, but not both. We used two convenience functions here, `one_if_lt` and `min`, both implemented straightforwardly with branching, hoping that the compiler will figure this out for us, now that the branches are much smaller.
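The article does not show the two helpers, so the bodies below are an assumption: a minimal sketch of what "implemented straightforwardly with branching" could look like.

```c
#include <stdint.h>

// Hypothetical helper: called with the already-computed difference xi - yj,
// it returns 1 when that difference came from xi < yj, i.e. when it is
// negative reinterpreted as signed (valid while the operands fit in 63 bits).
static inline uint64_t one_if_lt(uint64_t diff) {
    return ((int64_t) diff < 0) ? 1 : 0;
}

// Hypothetical helper: plain branching minimum of two unsigned 64-bit values.
static inline uint64_t min(uint64_t a, uint64_t b) {
    return (a < b) ? a : b;
}
```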
Next, if we cheat a little and assume that the highest bit in the numbers is never set, we can get rid of those branches:
```c
void nonbranching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                  uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    uint64_t xi = xs[i], yj = ys[j];
    while ((i < xmax) && (j < ymax) && (k < zmax)) {
        uint64_t neg = (xi - yj) >> 63;
        yj = neg * xi + (1 - neg) * yj;
        zs[k] = yj;
        i += neg;
        xi = xs[i];
        neg ^= 1;
        j += neg;
        yj = ys[j];
        k += 1;
    }
    if (i == xmax) memcpy(zs + k, ys + j, 8 * (zmax - k));
    if (j == ymax) memcpy(zs + k, xs + i, 8 * (zmax - k));
}
```
What is up with `(xi - yj) >> 63`, you may ask? The difference is negative if
`xi < yj`, so the subtraction wraps around and the most significant bit of the
result is set. Then we shift down logically (since we're using unsigned
integers), so the bits that are filled in are all zeroes. Since the width is
64, we effectively move the upper bit to the lowest position while setting all
other bits to zero.
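The trick can be checked in isolation; the function name below is made up for illustration, and as noted it assumes neither operand has its highest bit set.

```c
#include <stdint.h>

// Standalone sketch of the comparison trick described above (hypothetical
// name). a - b wraps modulo 2^64 when a < b, setting bit 63; the logical
// shift by 63 then drops that bit into position 0, yielding 1, else 0.
static inline uint64_t lt_by_shift(uint64_t a, uint64_t b) {
    return (a - b) >> 63;
}
```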
Knuth has another quirk, namely that his array pointers usually point to the
end of the array, and his indices are negative, going from `-xmax` up to `0`
instead of the more standard `0` up to `xmax`. One consequence of this is that
the termination check can be done with one test instead of three comparisons,
by `and`ing together the three indices: since they are negative, they all have
their most significant bit set, unless zero. Here are both of the previous
versions with this reversal trick:
```c
void nonbranching_but_branching_reverse(uint64_t *xs, size_t xmax,
                                        uint64_t *ys, size_t ymax,
                                        uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t t = one_if_lt(xi - yj);
        yj = min(xi, yj);
        zse[k] = yj;
        i += t;
        xi = xse[i];
        t ^= 1;
        j += t;
        yj = yse[j];
        k += 1;
    }
    if (i == 0) memcpy(zse + k, yse + j, -8 * k);
    if (j == 0) memcpy(zse + k, xse + i, -8 * k);
}

void nonbranching_reverse(uint64_t *xs, size_t xmax, uint64_t *ys,
                          size_t ymax, uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t neg = (xi - yj) >> 63;
        yj = neg * xi + (1 - neg) * yj;
        zse[k] = yj;
        i += neg;
        xi = xse[i];
        neg ^= 1;
        j += neg;
        yj = yse[j];
        k += 1;
    }
    if (i == 0) memcpy(zse + k, yse + j, -8 * k);
    if (j == 0) memcpy(zse + k, xse + i, -8 * k);
}
```
Technically, I suppose we do assume that the lengths of the arrays are not
`>2**63`, so that they fit in an `ssize_t`, but considering that the address
space of x86-64 is not 64 bits but merely 48 bits, this is not a problem, even
in theory.
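The combined termination test is worth seeing on its own. A small sketch (the function is mine, factored out for illustration): while all three indices are negative they all have the sign bit set, so their AND is nonzero; as soon as any index reaches zero, ANDing with it yields zero and the loop stops.

```c
#include <sys/types.h>   // for ssize_t (POSIX)

// One test instead of three comparisons: the indices run from -max up
// to 0, so (i & j & k) has the sign bit set, and is thus nonzero, for
// exactly as long as all three are still negative.
static int still_merging(ssize_t i, ssize_t j, ssize_t k) {
    return (i & j & k) != 0;
}
```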
Writing the ASM ourselves
Lastly, we can try to write the assembly ourselves. When translating the
branch-free routine by Knuth into x86 there are a number of things to do.
First we need to figure out how to get `-1/0/+1` by comparing two variables,
as MMIX's `CMP` instruction does. However, instead of trying to translate this
line by line, which would leave us with more instructions than needed, we
should rather look more closely at what we're doing, so that we really
understand the minimal amount of work we have to do.
We only need to do two things: compare $x_i$ and $y_j$ and load the smaller
into a register, and increment either `i` or `j`. The former can be done using
`cmovl`, and the latter can be done in a similar fashion to how Knuth does it,
which is basically what we've been doing up to this point in C.
This is the version I ended up with (here in inline-GCC asm format):
```asm
1:      mov %[minxy], %[yj]
        cmp %[xi], %[yj]                        ; minxy = min(xi, yj)
        cmovl %[minxy], %[xi]
        mov QWORD PTR [%[zse]+8*%[k]], %[minxy] ; zs[k] = minxy
        mov %[t], 0                             ; t = 0
        cmovl %[t], %[one]                      ; if xi < yj: t = 1
        add %[i], %[t]                          ; i += t
        mov %[xi], QWORD PTR [%[xse]+8*%[i]]    ; xi = xs[i]
        xor %[t], 1                             ; t ^= 1
        add %[j], %[t]                          ; j += t
        mov %[yj], QWORD PTR [%[yse]+8*%[j]]    ; yj = ys[j]
        add %[k], 1                             ; k += 1
        mov %[u], %[i]
        and %[u], %[j]
        test %[u], %[k]                         ; if ((i & j & k) != 0)
        jnz 1b                                  ;   goto 1
```
There are a few quirks here, like having a couple of `mov` instructions in
between the second conditional move and the instruction it conditions on, and
the fact that `cmovl` couldn't take an immediate value, so I had to set up a
register holding just the value `1`. A sneaky detail to keep in mind is that
when we set `t = 0` we cannot use the trick of `xor`ing `t` with itself, since
this would change the flags, causing the subsequent `cmovl` to be wrong.
Now we can take a look at the assembly generated for some of the other
functions by using `objdump -d`. Our own programs are compiled with
`-O3 -march=native`. Here is the inner loop of `nonbranching_reverse`:
```asm
<nonbranching_reverse>:
    1ef0:  mov   rax,rdi
    1ef3:  sub   rax,rsi
    1ef6:  shr   rax,0x3f
    1efa:  mov   rdx,r8
    1efd:  sub   rdx,rax
    1f00:  imul  rdx,rsi
    1f04:  imul  rdi,rax
    1f08:  add   rbp,rax
    1f0b:  xor   rax,0x1
    1f0f:  add   rdi,rdx
    1f12:  mov   QWORD PTR [r13+r12*8+0x0],rdi
    1f17:  add   rcx,rax
    1f1a:  inc   r12
    1f1d:  mov   rax,rbp
    1f20:  and   rax,r12
    1f23:  mov   rdi,QWORD PTR [rbx+rbp*8]
    1f27:  mov   rsi,QWORD PTR [r10+rcx*8]
    1f2b:  test  rax,rcx
    1f2e:  jne   1ef0 <nonbranching_reverse+0x40>
```
Sure looks a lot better than `branching`!
This seems more or less reasonable, but we can see that the multiplication
trickery that we used to avoid the `min` branch takes up some space here;
presumably it also takes some time. Maybe one little branch isn't too bad,
though, and perhaps the compiler is more willing to use conditional
instructions if we use the ternary operator, like this:
```c
void nonbranching_reverse_ternary(uint64_t *xs, size_t xmax, uint64_t *ys,
                                  size_t ymax, uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t ybig = (xi - yj) >> 63;
        yj = ybig ? xi : yj;
        zse[k] = yj;
        i += ybig;
        xi = xse[i];
        ybig ^= 1;
        j += ybig;
        yj = yse[j];
        k += 1;
    }
    if (i == 0) memcpy(zse + k, yse + j, -8 * k);
    if (j == 0) memcpy(zse + k, xse + i, -8 * k);
}
```
This time, if we look at the assembly, we can see that the compiler is finally getting it: `cmove`!
```asm
    2080:  mov   rax,yj
    2083:  sub   rax,xi
    2086:  shr   rax,0x3f                  ; t = (yj - xi) >> 63
    208a:  cmove yj,xi                     ; yj = t == 0 ? xi : yj
    208e:  add   j,rax                     ; j += t
    2091:  mov   QWORD PTR [zs+k*8],yj     ; z[k] = yj
    2096:  xor   rax,0x1                   ; t ^= 1
    209a:  inc   k                         ; k++
    209d:  add   i,rax                     ; i += t
    20a0:  mov   rax,k
    20a3:  and   rax,j                     ; t = k & j
    20a6:  mov   yj,QWORD PTR [ys+j*8]     ; yj = ys[j]
    20aa:  mov   xi,QWORD PTR [xs+i*8]     ; xi = xs[i]
    20ae:  test  rax,i                     ; if ((i & j & k) != 0)
    20b1:  jne   2080                      ;   goto 2080
```
So we see it's really the same! Curiously, the compiler turned our code around
to have `t` be `1` if `xi` was the bigger, whereas our `ybig` was `1` if `yj`
was the bigger.
Results
And now for the results! We fill two arrays with random elements and run
`branching` on them, such that we get the merged array back. This is used as
the ground truth against which all other variations are checked, in case we
have messed up. Then we use `clock_gettime` to measure the wall clock time we
spend, per method. The following is running time in milliseconds where both
lists are `2**25` elements long, averaged over 100 runs: 10 iterations per
seed and 10 different seeds (`srand(i)` for each iteration).
These are the numbers I got on an Intel i7-7500U (avg +/- var):
```
branching:                            30.998 +/- 0.001
nonbranching_but_branching:           27.330 +/- 0.002
nonbranching:                         24.770 +/- 0.000
nonbranching_but_branching_reverse:   19.387 +/- 0.000
nonbranching_reverse:                 20.015 +/- 0.000
nonbranching_reverse_ternary:         19.038 +/- 0.000
asm_nb_rev:                           18.987 +/- 0.001
```
I also ran the suite on another machine with an Intel i5-8250U, in order to see if there would be any significant difference:
```
branching:                            31.405 +/- 0.034
nonbranching_but_branching:           27.646 +/- 0.097
nonbranching:                         27.894 +/- 0.021
nonbranching_but_branching_reverse:   22.760 +/- 0.040
nonbranching_reverse:                 21.284 +/- 0.050
nonbranching_reverse_ternary:         19.299 +/- 0.002
asm_nb_rev:                           19.793 +/- 0.009
```
Interestingly, on this CPU our assembly is slightly slower than the ternary
version; I guess this is due to us using a `cmovl` where the
compiler-generated version used the shifting trick.
Bonus: Sorting
We can’t possibly have done all this merging without making a proper mergesort
in the end! Luckily for us, the merge
part is really the
only difficult part of the routine:
```c
void merge_sort(uint64_t *xs, size_t n, uint64_t *buf) {
    if (n < 2) return;
    size_t h = n / 2;
    merge_sort(xs, h, buf);
    merge_sort(xs + h, n - h, buf + h);
    merge(xs, h, xs + h, n - h, buf, n);
    memcpy(xs, buf, 8 * n);
}
```
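To see the routine run on its own, here is a self-contained sketch: the `merge` plugged in below is a simple branching merge of my own, standing in for whichever of the article's merge variants is benchmarked, not the article's exact code.

```c
#include <stdint.h>
#include <stddef.h>
#include <string.h>

// Stand-in merge: copies the smaller head element until zs is full,
// falling back to the other list when one input runs out.
static void merge(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                  uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    while (k < zmax) {
        if (i < xmax && (j == ymax || xs[i] < ys[j]))
            zs[k++] = xs[i++];
        else
            zs[k++] = ys[j++];
    }
}

// The merge sort from the text: sort each half in place, merge the two
// halves into buf, then copy the merged result back into xs.
static void merge_sort(uint64_t *xs, size_t n, uint64_t *buf) {
    if (n < 2) return;
    size_t h = n / 2;
    merge_sort(xs, h, buf);
    merge_sort(xs + h, n - h, buf + h);
    merge(xs, h, xs + h, n - h, buf, n);
    memcpy(xs, buf, 8 * n);
}
```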
Unfortunately we have to merge into a buffer and then `memcpy` it back.
Perhaps this is fixable: we can make the sorting routine put the result either
in `xs` or in `buf`, and by having the recursive calls report which one, we
can merge into the other, assuming both recursive calls agree(!!). That is, if
the recursive calls say that the sorted subarrays are in `xs`, we merge into
`buf` and tell our caller that our result is in `buf`. At the end, we just
need to make sure that the final sorted numbers end up in `xs`.
```c
void _sort_asm(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {
    if (n < 2) {
        *into_buf = 0;
        return;
    }
    size_t h = n / 2;
    int res_in_buf;
    // WARNING: `res_in_buf` for the two calls need not
    // be the same in the real world!
    _sort_asm(xs, h, buf, &res_in_buf);
    _sort_asm(xs + h, n - h, buf + h, &res_in_buf);
    *into_buf = res_in_buf ^ 1;
    if (res_in_buf)
        asm_nb_rev(buf, h, buf + h, n - h, xs, n);
    else
        asm_nb_rev(xs, h, xs + h, n - h, buf, n);
}

void sort_asm(uint64_t *xs, size_t n, uint64_t *buf) {
    int res_in_buf;
    _sort_asm(xs, n, buf, &res_in_buf);
    if (res_in_buf) {
        memcpy(xs, buf, 8 * n);
    }
}
```
and similar, for the other variants.
You might see the branch and wonder if we can remove it. I tried, by making an
array `{xs, buf}` and indexing it with `res_in_buf`, but it caused a minor
slowdown: maybe some branching is fine after all.
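For concreteness, the attempted branch removal could look something like this sketch (the helper and its name are mine, not the article's code): the two buffers go into a small array, and `res_in_buf` indexes out the merge source while its complement picks the destination.

```c
#include <stdint.h>

// Hypothetical branch-free buffer selection: res_in_buf == 1 means the
// sorted halves live in buf, so we merge from buf into xs, and vice
// versa. (In the author's tests this indexing was slightly slower than
// the plain if/else.)
static void select_buffers(uint64_t *xs, uint64_t *buf, int res_in_buf,
                           uint64_t **src, uint64_t **dst) {
    uint64_t *bufs[2] = {xs, buf};
    *src = bufs[res_in_buf];        // merge from here...
    *dst = bufs[res_in_buf ^ 1];    // ...into the other buffer
}
```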
Here are the running times:
```
                                          i7-7500U            i5-8250U
sort_branching:                           369.479 +/- 0.047   393.762 +/- 0.082
sort_nonbranching_but_branching:          324.337 +/- 0.014   337.120 +/- 0.099
sort_nonbranching:                        325.658 +/- 0.028   352.802 +/- 0.120
sort_nonbranching_but_branching_reverse:  279.237 +/- 0.164   287.799 +/- 0.154
sort_nonbranching_reverse:                283.927 +/- 0.033   299.277 +/- 0.929
sort_nonbranching_reverse_ternary:        270.668 +/- 0.009   278.644 +/- 1.677
sort_asm_nb_rev:                          270.228 +/- 0.009   281.657 +/- 0.360
```
If you would like to run the suite yourself, the git repo is available here.
Thanks for reading.