# std::extents: unrolling loops
Let’s look at `std::extents` from `<mdspan>`, in particular its `operator==`.
First, `std::extents` is a variation of `std::array<int, n> exts{1, ..., n};` that handles compile-time constant extents. For example, `std::extents<int, 1, 2, dyn> exts(3);` defines an extents object `[1, 2, 3]`: the first two are known at compile time and are called static extents, while the third is only known at runtime and is called a dynamic extent. Unlike `std::array`, `std::extents` only stores the dynamic extents. The number of extents (`n` in the above) is the rank.
Let’s fix some notation: `E[i]` is the i-th extent (static or dynamic); `D` is the array of dynamic extents; and `k[i]` is the index of the i-th extent in `D` (which is meaningful only if the i-th extent is dynamic). Finally, let `S[i]` denote the i-th static extent (if the i-th extent is dynamic then `S[i] == -1`).
Any two extents objects are comparable: if they have different ranks, they’re always considered different. Otherwise, they are considered equal if and only if `E1[i] == E2[i]` for all `i`.
A first implementation in libstdc++ looked like this:

```cpp
template<typename _OIndexType, size_t... _OExtents>
  friend constexpr bool
  operator==(const extents& __self,
             const extents<_OIndexType, _OExtents...>& __other) noexcept
  {
    if constexpr (!_S_is_compatible_extents<_OExtents...>())
      return false;
    else
      {
        for (size_t __i = 0; __i < __self.rank(); ++__i)
          if (!cmp_equal(__self.extent(__i), __other.extent(__i)))
            return false;
        return true;
      }
  }
```
The upper branch isn’t particularly interesting: it handles the cases where the ranks differ or where the static extents are different. Note that the trip count of the loop is known at compile time and the order in which the extents are checked doesn’t matter, so this loop could be unrolled.
Let’s try and get a feeling for what the optimizer will do:
```cpp
#include <mdspan>

constexpr auto dyn = std::dynamic_extent;

extern "C" {
bool same1(const std::extents<int, 1, 2, 3>& e1,
           const std::extents<int, 1, 2, 3>& e2)
{ return e1 == e2; }

bool same2(const std::extents<int, 0, 2, 3>& e1,
           const std::extents<int, 1, 2, 3>& e2)
{ return e1 == e2; }

bool same3(const std::extents<int, 0, dyn, 3>& e1,
           const std::extents<int, 1, 2, 3>& e2)
{ return e1 == e2; }

bool same4(const std::extents<int, 1, dyn, 3>& e1,
           const std::extents<int, dyn, 2, 0>& e2)
{ return e1 == e2; }
}
```
The generated code with `-O2`, after eliminating filler code for alignment, is:
0000000000000000 <same1>:
0: mov eax,0x1
5: ret
0000000000000010 <same2>:
10: xor eax,eax
12: ret
0000000000000020 <same3>:
20: xor eax,eax
22: ret
0000000000000030 <same4>:
30: xor eax,eax
32: ret
Good! Oh, wait, the compiler didn’t need to do much, because `same2`, …, `same4` are all determined to be false by the upper constexpr branch. However, for `same1` the optimizer had to work. Note that it eliminated everything and just returns `true`. It might also be interesting to note that it always inlines `operator==`.
Okay, once more:
bool same5(const std::extents<int, 1, 2, 3>& e1,
const std::extents<int, 1, dyn, 3>& e2)
{ return e1 == e2; }
This time we get:
0000000000000040 <same5>:
40: xor eax,eax
42: mov rdx,QWORD PTR [rax*8+0x0]
4a: mov ecx,edx
4c: cmp rdx,0xffffffffffffffff
50: jne 5d <same5+0x1d>
52: mov rdx,QWORD PTR [rax*8+0x0]
5a: mov ecx,DWORD PTR [rsi+rdx*4]
5d: cmp ecx,DWORD PTR [rax*8+0x0]
64: jne 80 <same5+0x40>
66: add rax,0x1
6a: cmp rax,0x3
6e: jne 42 <same5+0x2>
70: mov eax,0x1
75: ret
80: xor eax,eax
82: ret
(Personally, I find the code used to align instructions distracting and will continue to silently delete it.) Back to the topic at hand. Let’s transcribe it to pseudo code:
for (i = 0; i != 3; ++i)
    e2 = S2[i]
    if S2[i] == -1:
        e2 = D2[k2[i]]
    if S1[i] != e2:
        return false
return true
How does one arrive at such a guess? First, there’s the sequence with a backwards jump:
66: add rax,0x1
6a: cmp rax,0x3
6e: jne 42 <same5+0x2>
this smells like a loop. Next, we should track the loads:
42: mov rdx,QWORD PTR [rax*8+0x0]

loads 8 bytes from `unknown_offset + i*8` (with `i` in `rax`). The offset is unknown because it’s reading a global variable (those are only given a location in executables/shared libraries, not in object files), so that’s likely one of the static arrays `S1` or `S2`.
There’s also

5a: mov ecx,DWORD PTR [rsi+rdx*4]

We know that `rsi` is the second argument passed to the function, i.e. the reference/pointer `e2`. Therefore, this is `D2[k]` (where `k` is currently in the register `rdx`). We can see that `rdx` is loaded in the line above from a global, which makes sense if one knows that `k[i]` is stored in a static array.
Then, there’s

4c: cmp rdx,0xffffffffffffffff
50: jne 5d <same5+0x1d>

which checks whether `rdx`, i.e. `S?[i]`, is equal to `-1`. If not, it jumps forwards. That’s likely an `if`-condition. Now one can see the rest.
What’s interesting is that:

- it has optimized away the indirection for `E1`: since all its extents are static, there’s no need to emit code that can handle `D1[k1[i]]`;
- it has not eliminated the loop;
- it has not eliminated the trivial iterations at the beginning and end of the loop.
Considering that all we want to do is

all(E1[i] == E2[i] for i in range(n))

the amount of code seems excessive (usually `rank <= 3`; almost always `<= 8`, because `m**k` just grows too fast for `k >= 4`).
Out of curiosity, what happens if we write loop-less code? How? Probably some variant of pack expansion, maybe something like this:
```cpp
template<typename _OIndexType, size_t... _OExtents>
  friend constexpr bool
  operator==(const extents& __self,
             const extents<_OIndexType, _OExtents...>& __other) noexcept
  {
    auto __impl = [&__self, &__other]<size_t... _Counts>(
        index_sequence<_Counts...>)
    { return (cmp_equal(__self.extent(_Counts),
                        __other.extent(_Counts)) && ...); };
    return __impl(make_index_sequence<__self.rank()>());
  }
```
It’s a bit clumsy, because everything is stuffed into a lambda for the sole purpose of deducing the pack of loop indices, but otherwise it’s a reasonably flexible pattern for creating a compile-time for-loop.
Time to compile all examples again (with `-O2`):
0000000000000000 <same1>:
0: mov eax,0x1
5: ret
0000000000000010 <same2>:
10: xor eax,eax
12: ret
0000000000000020 <same3>:
20: xor eax,eax
22: ret
0000000000000030 <same4>:
30: xor eax,eax
32: ret
0000000000000040 <same5>:
40: cmp DWORD PTR [rsi],0x2
43: sete al
46: ret
Alright, that’s shorter (enough so that the difference is probably worth a statement of its own). This is interesting because:
- Even though we’ve removed the upper constexpr branch, the optimizer correctly handles the cases with mixed dynamic/static extents and one mismatching pair of static extents (it doesn’t matter whether the mismatch is preceded by dynamic extents or not).
- Then there’s `same5`, which is amazing. There’s nothing left of the lambda, or any of the other static for-loop boilerplate. It’s simply been reduced down to `D2[1] == 2`.
That’s nice enough to warrant one more example:
bool same6(const std::extents<int, dyn, 2, 3, dyn>& e1,
const std::extents<int, dyn, dyn, 3, 4>& e2)
{ return e1 == e2; }
It’s longer, has no mismatching static extents and the dynamic extents don’t line up nicely. The generated code is:
0000000000000050 <same6>:
50: mov edx,DWORD PTR [rdi]
52: xor eax,eax
54: cmp DWORD PTR [rsi],edx
56: je 60 <same6+0x10>
58: ret
60: cmp DWORD PTR [rsi+0x4],0x2
64: jne 58 <same6+0x8>
66: cmp DWORD PTR [rdi+0x4],0x4
6a: sete al
6d: ret
Recall that `rdi` is `e1` and `rsi` is `e2`. So we see: not only has it eliminated the trivial comparison `3 == 3`, it has also removed the indirection `D[k[i]]`, because `i` is a compile-time constant. In pseudo code:
if (D1[0] != D2[0]) return false
if ( 2 != D2[1]) return false
return D1[1] == 4
… and that’s it :-)