std::extents: unrolling loops

Let’s look at std::extents from <mdspan>, in particular its operator==. First, std::extents is a variation of

std::array<int, n> exts{1, ..., n};

that handles compile-time constant extents. For example:

std::extents<int, 1, 2, dyn> exts(3);

defines the extents object [1, 2, 3] (dyn is shorthand for std::dynamic_extent). The first two extents are known at compile time and are called static extents; the third is known only at runtime and is called a dynamic extent. Unlike std::array, std::extents stores only the dynamic extents. The number of extents (n in the above) is the rank.

Let’s fix some notation: E[i] is the i-th extent (static or dynamic); D is the array of dynamic extents; and k[i] is the index of the i-th extent in D (which is meaningful only if the i-th extent is dynamic). Finally, let S[i] denote the i-th static extent (if the i-th extent is dynamic, then S[i] == -1, i.e. std::dynamic_extent).
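To make the notation concrete, here is a minimal toy model. It is not the libstdc++ implementation: the names toy_extents and dyn are made up for illustration (dyn stands in for std::dynamic_extent), and computing k by counting the preceding dynamic extents is an assumption about how such an index could be defined.

```cpp
#include <array>
#include <cstddef>

// Stand-in for std::dynamic_extent, which is size_t(-1).
inline constexpr std::size_t dyn = static_cast<std::size_t>(-1);

// Toy model of the notation above; all names are hypothetical.
template <std::size_t... Exts>
struct toy_extents {
  static constexpr std::size_t rank = sizeof...(Exts);

  // Number of dynamic extents, i.e. the length of D.
  static constexpr std::size_t rank_dynamic =
      (std::size_t{0} + ... + (Exts == dyn ? 1 : 0));

  // S[i]: the static extents; dyn marks the dynamic slots.
  static constexpr std::array<std::size_t, rank> S{Exts...};

  // D: only the dynamic extents are stored in the object.
  std::array<int, rank_dynamic> D;

  // k(i): index of the i-th extent in D, computed as the number of
  // dynamic extents that precede it.
  static constexpr std::size_t k(std::size_t i) {
    std::size_t count = 0;
    for (std::size_t j = 0; j < i; ++j)
      if (S[j] == dyn)
        ++count;
    return count;
  }

  // E[i]: read D for dynamic extents, S otherwise.
  constexpr int extent(std::size_t i) const {
    return S[i] == dyn ? D[k(i)] : static_cast<int>(S[i]);
  }
};
```

For example, toy_extents<1, 2, dyn>{{3}} models the object from the beginning: extent(0) and extent(1) come from S, while extent(2) reads D[0].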

Any two extents objects are comparable: if they have different ranks, they’re always considered different. Otherwise, they are considered equal if and only if E1[i] == E2[i] for all i.

A first implementation in libstdc++ looked like this:

template<typename _OIndexType, size_t... _OExtents>
friend constexpr bool
operator==(const extents& __self,
           const extents<_OIndexType, _OExtents...>& __other) noexcept
{
  if constexpr (!_S_is_compatible_extents<_OExtents...>())
    return false;
  else
  {
    for (size_t __i = 0; __i < __self.rank(); ++__i)
      if (!cmp_equal(__self.extent(__i), __other.extent(__i)))
        return false;
    return true;
  }
}

The upper branch isn’t particularly interesting: it handles the cases where the ranks differ or where the static extents are different.

Note that the trip count is known at compile time and the order in which the extents are checked doesn’t matter. This loop could be unrolled.

Let’s try and get a feeling for what the optimizer will do:

#include <mdspan>
#include <span>

// Shorthand used in all examples.
constexpr auto dyn = std::dynamic_extent;

extern "C" {

bool same1(const std::extents<int, 1, 2, 3>& e1,
           const std::extents<int, 1, 2, 3>& e2)
{ return e1 == e2; }

bool same2(const std::extents<int, 0, 2, 3>& e1,
           const std::extents<int, 1, 2, 3>& e2)
{ return e1 == e2; }

bool same3(const std::extents<int, 0, dyn, 3>& e1,
           const std::extents<int, 1, 2, 3>& e2)
{ return e1 == e2; }

bool same4(const std::extents<int, 1, dyn, 3>& e1,
           const std::extents<int, dyn, 2, 0>& e2)
{ return e1 == e2; }

} // extern "C" (the later examples belong inside this block as well)

The generated code with -O2, after eliminating filler code for alignment, is:

0000000000000000 <same1>:
   0:  mov    eax,0x1
   5:  ret

0000000000000010 <same2>:
  10:  xor    eax,eax
  12:  ret

0000000000000020 <same3>:
  20:  xor    eax,eax
  22:  ret

0000000000000030 <same4>:
  30:  xor    eax,eax
  32:  ret

Good! Oh, wait, the compiler didn’t need to do much, because same2, …, same4 are all determined to be false by the upper constexpr branch. However, for same1 the optimizer had to work. Note that it eliminated everything and just returns true. It might also be interesting to note that it always inlines operator==.

Okay, once more:

bool same5(const std::extents<int, 1, 2, 3>& e1,
           const std::extents<int, 1, dyn, 3>& e2)
{ return e1 == e2; }

This time we get:

0000000000000040 <same5>:
  40:  xor    eax,eax
  42:  mov    rdx,QWORD PTR [rax*8+0x0]
  4a:  mov    ecx,edx
  4c:  cmp    rdx,0xffffffffffffffff
  50:  jne    5d <same5+0x1d>
  52:  mov    rdx,QWORD PTR [rax*8+0x0]
  5a:  mov    ecx,DWORD PTR [rsi+rdx*4]
  5d:  cmp    ecx,DWORD PTR [rax*8+0x0]
  64:  jne    80 <same5+0x40>
  66:  add    rax,0x1
  6a:  cmp    rax,0x3
  6e:  jne    42 <same5+0x2>
  70:  mov    eax,0x1
  75:  ret
  80:  xor    eax,eax
  82:  ret

(Personally, I find the code used to align instructions distracting and will continue to silently delete it.) Back to the topic at hand. Let’s transcribe it to pseudo code:

for (i = 0; i != 3; ++i):
    e2 = S2[i]
    if e2 == -1:
        e2 = D2[k2[i]]
    if S1[i] != e2:
        return false
return true

How does one guess? First, there’s the sequence with a backwards jump:

  66:  add    rax,0x1
  6a:  cmp    rax,0x3
  6e:  jne    42 <same5+0x2>

this smells like a loop. Next, we should track the loads:

  42:  mov    rdx,QWORD PTR [rax*8+0x0]

loads 8 bytes from unknown_offset + i*8, where i is held in rax (zero on the first iteration). The offset is unknown because the code is reading a global variable (and those are only given a location in executables/shared libraries, but not in object files), so that’s likely one of the static arrays S1 or S2.

There’s also

  5a:  mov    ecx,DWORD PTR [rsi+rdx*4]

we know that rsi is the second argument passed to the function, i.e. the reference/pointer e2. Therefore, this is D2[k] (where k is currently in the register rdx). We can see that rdx is loaded in the line above from a global, which makes sense if one knows that k[i] is stored in a static array.

Then, there’s

  4c:  cmp    rdx,0xffffffffffffffff
  50:  jne    5d <same5+0x1d>

this checks if rdx, i.e. S?[i], is equal to -1. If not it jumps forwards. That’s likely an if-condition. Now, one can see the rest.
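Putting these observations together, the loop can be transcribed into compilable C++. This is only a sketch of what the listing appears to do: the tables S1, S2 and k2 are hypothetical stand-ins for the static arrays in libstdc++, and their exact contents (in particular k2) are an assumption. The comments give the matching addresses from the listing.

```cpp
#include <array>
#include <cstddef>

// Hypothetical tables for extents<int, 1, 2, 3> and extents<int, 1, dyn, 3>.
// A value of -1 marks a dynamic slot.
constexpr long S1[3] = {1, 2, 3};
constexpr long S2[3] = {1, -1, 3};
// k2[i] is only meaningful where S2[i] == -1.
constexpr std::size_t k2[3] = {0, 0, 1};

// D2 holds the dynamic extents of e2 (what rsi points at in the listing).
constexpr bool same5_loop(const std::array<int, 1>& D2) {
  for (std::size_t i = 0; i != 3; ++i) {
    long e2 = S2[i];    // 42: load S2[i]
    if (e2 == -1)       // 4c: is the extent dynamic?
      e2 = D2[k2[i]];   // 52, 5a: load k2[i], then D2[k2[i]]
    if (S1[i] != e2)    // 5d: compare with S1[i]
      return false;     // 80
  }
  return true;          // 70
}
```

Only the middle iteration ever reaches the load from D2; the first and last iterations compare two constants, which is exactly the work one would hope to see optimized away.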

What’s interesting is that:

  • it’s optimized away the indirection for E1: since all its extents are static, there’s no need to emit code that can handle D1[k1[i]].

  • it’s not eliminated the loop,

  • it’s not eliminated the trivial iterations at the beginning and end of the loop.

Considering all we want to do is:

all(E1[i] == E2[i] for i in range(n))

the amount of code seems excessive (usually rank <= 3; almost always <= 8, because m**k just grows too fast for k >= 4).

Out of curiosity, what happens if we write loop-less code? How? Probably with some variant of pack expansion. Maybe something like this:

template<typename _OIndexType, size_t... _OExtents>
friend constexpr bool
operator==(const extents& __self,
           const extents<_OIndexType, _OExtents...>& __other) noexcept
{
  auto __impl = [&__self, &__other]<size_t... _Counts>(
      index_sequence<_Counts...>)
    { return (cmp_equal(__self.extent(_Counts),
                        __other.extent(_Counts)) && ...); };
  return __impl(make_index_sequence<__self.rank()>());
}

It’s a bit clumsy, because everything is stuffed into a lambda for the sole purpose of deducing the loop indices, but otherwise it’s a reasonably flexible pattern to create a compile-time for-loop.
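The same compile-time loop can also be written without a lambda, by moving the deduction of the indices into a helper function template. The sketch below applies the pattern to plain std::array rather than std::extents; all_equal and all_equal_impl are made-up names and not part of any library.

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Receives the indices 0, ..., N-1 as a pack and folds over them with &&.
template <std::size_t N, std::size_t... Is>
constexpr bool all_equal_impl(const std::array<int, N>& a,
                              const std::array<int, N>& b,
                              std::index_sequence<Is...>) {
  return ((a[Is] == b[Is]) && ...);
}

// Entry point: creates the index pack and forwards it to the helper.
template <std::size_t N>
constexpr bool all_equal(const std::array<int, N>& a,
                         const std::array<int, N>& b) {
  return all_equal_impl(a, b, std::make_index_sequence<N>{});
}
```

Whether the lambda or the helper reads better is mostly a matter of taste. The helper also works in C++17, while the templated lambda needs C++20.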

Time to compile all examples again (with -O2):

0000000000000000 <same1>:
   0:  mov    eax,0x1
   5:  ret

0000000000000010 <same2>:
  10:  xor    eax,eax
  12:  ret

0000000000000020 <same3>:
  20:  xor    eax,eax
  22:  ret

0000000000000030 <same4>:
  30:  xor    eax,eax
  32:  ret

0000000000000040 <same5>:
  40:  cmp    DWORD PTR [rsi],0x2
  43:  sete   al
  46:  ret

Alright, that’s shorter (enough so to probably be a sensible statement in its own right). This is interesting because:

  • even though we’ve removed the upper constexpr branch, the optimizer correctly handles cases with mixed dynamic/static extents with one mismatching pair of static extents (doesn’t matter if the mismatch is preceded by dynamic extents or not).

  • then there’s same5, which is amazing. There’s nothing left of the lambda, or any of the other static for-loop boilerplate. It’s simply been reduced to E2[1] == 2, i.e. D2[0] == 2, since e2 has a single dynamic extent.

That’s nice enough to warrant one more example:

bool same6(const std::extents<int, dyn,   2, 3, dyn>& e1,
           const std::extents<int, dyn, dyn, 3, 4>& e2)
{ return e1 == e2; }

It’s longer, has no mismatching static extents and the dynamic extents don’t line up nicely. The generated code is:

0000000000000050 <same6>:
  50:  mov    edx,DWORD PTR [rdi]
  52:  xor    eax,eax
  54:  cmp    DWORD PTR [rsi],edx
  56:  je     60 <same6+0x10>
  58:  ret
  60:  cmp    DWORD PTR [rsi+0x4],0x2
  64:  jne    58 <same6+0x8>
  66:  cmp    DWORD PTR [rdi+0x4],0x4
  6a:  sete   al
  6d:  ret

Recall rdi is e1 and rsi is e2. So we see: not only has it eliminated the trivial comparison 3 == 3, it’s also removed the indirection D[k[i]] because i is a compile time constant. In pseudo code:

if (D1[0] != D2[0]) return false
if (    2 != D2[1]) return false
return D1[1] == 4
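As a sanity check, the three comparisons can be written out in C++ and compared against the obvious element-wise definition. This is a hand-written sketch, not compiler output: same6_unrolled, same6_reference and dynamic2 are hypothetical names, and D1 and D2 are modelled as plain arrays of two ints.

```cpp
#include <array>
#include <cstddef>

// D1: dynamic extents of extents<int, dyn, 2, 3, dyn>.
// D2: dynamic extents of extents<int, dyn, dyn, 3, 4>.
using dynamic2 = std::array<int, 2>;

// The three comparisons left in the generated code.
constexpr bool same6_unrolled(const dynamic2& D1, const dynamic2& D2) {
  if (D1[0] != D2[0]) return false;
  if (2 != D2[1]) return false;
  return D1[1] == 4;
}

// Reference: expand both objects to all four extents and compare them.
constexpr bool same6_reference(const dynamic2& D1, const dynamic2& D2) {
  const std::array<int, 4> E1{{D1[0], 2, 3, D1[1]}};
  const std::array<int, 4> E2{{D2[0], D2[1], 3, 4}};
  for (std::size_t i = 0; i != 4; ++i)
    if (E1[i] != E2[i])
      return false;
  return true;
}
```

The comparison 3 == 3 of the third extents never shows up in the unrolled version, which matches the generated code.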

… and that’s it :-)