std::extents: reducing indirection

Let’s look at std::extents from <mdspan>, this time at its extent member function. Its original implementation in libstdc++ was:

template<typename _IndexType, array _Extents>
class _ExtentsStorage
{
  constexpr _IndexType
  _M_extent(size_t __r) const noexcept
  {
    auto __se = _Extents[__r];
    if (__se == dynamic_extent)
      return _M_dyn_exts[_S_dynamic_index[__r]];
    else
      return __se;
  }
  // ...
};

Here, _Extents is simply the array of static extents passed to std::extents; e.g. for std::extents<int, 1, dyn, 3, 4> we have _Extents == std::array{1, -1, 3, 4}, writing -1 for a dynamic extent. (Right, dyn is short for std::dynamic_extent.)

The performance question is:

Does the compiler eliminate the branching, if __r is known at compile time?

Now’s unfortunately the time to agree on some notation: E[i] is the i-th extent (static or dynamic); D is the array of dynamic extents; k[i] is the index of the i-th extent within D (meaningful only if the i-th extent is dynamic). Finally, let S[i] denote the i-th static extent (if the i-th extent is dynamic then S[i] == -1).

Okay, and back to more interesting topics. Consider,

int prod(const std::extents<int, 3, dyn>& exts)
{ return exts.extent(0) * exts.extent(1); }

Compile it with -O2, and disassemble:

0000000000000000 <prod>:
   0:  mov    eax,DWORD PTR [rdi]
   2:  lea    eax,[rax+rax*2]
   5:  ret

It seems d + 2*d with d = D[0] is the fast way of implementing 3*D[0]. So the answer is: yes. Or “sometimes”, if you want to be more cautious. Note that not only did the compiler eliminate the branching, it also eliminated the indirection D[k[i]].

Let’s try again:

int prod2(const std::extents<int, 3, dyn>& exts,
          const std::array<int, 2>& a)
{ return exts.extent(0) * a[0] + exts.extent(1) * a[1]; }

which results in:

0000000000000010 <prod2>:
  10:  mov    eax,DWORD PTR [rsi+0x4]
  13:  mov    edx,DWORD PTR [rsi]
  15:  imul   eax,DWORD PTR [rdi]
  18:  lea    edx,[rdx+rdx*2]
  1b:  add    eax,edx
  1d:  ret

Again, it eliminated the branching and indirection. So that seems to work, but what about a similar yet harder question:

If S[i] != -1 happens to be true for all i, does the compiler eliminate the branching?

Let’s adjust the test problem a little:

int prod3(const std::extents<int, 3, 5, 7>& exts,
          const std::array<int, 3>& a)
{
  int ret = 0;
  for(size_t i = 0; i < exts.rank(); ++i)
    ret += exts.extent(i) * a[i];
  return ret;
}

The loop has a trip count that’s easily known at compile time. It’s less easy to see at compile time that S[i] != dyn (the check hidden inside exts.extent(i)) always holds. Here the compiler flags matter, but at -O2 the generated code is:

0000000000000020 <prod3>:
  20:  xor    eax,eax
  22:  xor    ecx,ecx
  24:  mov    rdx,QWORD PTR [rax*8+0x0]
  2c:  cmp    rdx,0xffffffffffffffff
  30:  je     36 <prod3+0x16>
  36:  imul   edx,DWORD PTR [rsi+rax*4]
  3a:  add    rax,0x1
  3e:  add    ecx,edx
  40:  cmp    rax,0x3
  44:  jne    24 <prod3+0x4>
  46:  mov    eax,ecx
  48:  ret

Notice the following lines:

  2c:  cmp    rdx,0xffffffffffffffff
  30:  je     36 <prod3+0x16>
  36:  imul   edx,DWORD PTR [rsi+rax*4]

Clearly this is the check S[i] == -1. What’s interesting is that there’s no code to handle the case where they are equal (because they never are). However, the compiler has not eliminated the check or the jump. More precisely, the branch (je) targets the very next instruction, so execution continues at 36 whether or not it is taken. The picture changes when passing -O3:

0000000000000020 <prod3>:
  20:  mov    eax,DWORD PTR [rsi]
  22:  mov    edx,DWORD PTR [rsi+0x4]
  25:  lea    eax,[rax+rax*2]
  28:  lea    edx,[rdx+rdx*4]
  2b:  add    eax,edx
  2d:  mov    edx,DWORD PTR [rsi+0x8]
  30:  lea    eax,[rax+rdx*8]
  33:  sub    eax,edx
  35:  ret

The generated code makes sense: no comparison or jump and no loading of static extents. It simply computes 3*a[0] + 5*a[1] + 7*a[2] as follows:

  a0 + 2*a0 + 4*a1 + a1 + 8*a2 - a2

with a0 = a[0], a1 = a[1] and a2 = a[2] (the movs). Naturally, there might be an even faster sequence of instructions, but this one doesn’t contain any superfluous ones.

Let’s see if we can make it easier on the optimizer and get the same behaviour on -O2:

template<typename _IndexType, array _Extents>
class _ExtentsStorage
{
  static constexpr bool
  _S_is_dynamic(size_t __r) noexcept
  {
    if constexpr (__all_static<_Extents>())
      return false;
    else
      return _Extents[__r] == dynamic_extent;
  }

  constexpr _IndexType
  _M_extent(size_t __r) const noexcept
  {
    if (_S_is_dynamic(__r))
      return _M_dyn_exts[_S_dynamic_index(__r)];
    else
      return _S_static_extent(__r);
  }
  // ...
};

The point is to see what happens if the condition is made more obviously always true or false. Let’s recompile with -O2:

0000000000000020 <prod3>:
  20:  mov    eax,DWORD PTR [rsi]
  22:  mov    edx,DWORD PTR [rsi+0x4]
  25:  lea    eax,[rax+rax*2]
  28:  lea    edx,[rdx+rdx*4]
  2b:  add    eax,edx
  2d:  mov    edx,DWORD PTR [rsi+0x8]
  30:  lea    eax,[rax+rdx*8]
  33:  sub    eax,edx
  35:  ret

Okay, one last time. Let’s look at the following:

int prod4(const std::extents<int, 3, 5, 7, 11>& exts,
          const std::array<int, 4>& a)
{
  int ret = 0;
  for(size_t i = 0; i < exts.rank(); ++i)
    ret += exts.extent(i) * a[i];
  return ret;
}

It differs from before in that the array is exactly four elements long, which happens to be 128 bits. Let’s again compile with -O2, first the version without the optimization, and disassemble. What we see is essentially unchanged:

0000000000000050 <prod4>:
  50:  xor    eax,eax
  52:  xor    ecx,ecx
  54:  mov    rdx,QWORD PTR [rax*8+0x0]
  5c:  cmp    rdx,0xffffffffffffffff
  60:  je     66 <prod4+0x16>
  66:  imul   edx,DWORD PTR [rsi+rax*4]
  6a:  add    rax,0x1
  6e:  add    ecx,edx
  70:  cmp    rax,0x4
  74:  jne    54 <prod4+0x4>
  76:  mov    eax,ecx
  78:  ret

Let’s compile the optimized version with -O2 and disassemble:

0000000000000040 <prod4>:
  40:  movdqu xmm1,XMMWORD PTR [rsi]
  44:  movdqa xmm2,XMMWORD PTR [rip+0x0]
  4c:  movdqa xmm0,xmm1
  50:  psrlq  xmm1,0x20
  55:  pmuludq xmm0,xmm2
  59:  psrlq  xmm2,0x20
  5e:  pmuludq xmm1,xmm2
  62:  pshufd xmm0,xmm0,0x8
  67:  pshufd xmm1,xmm1,0x8
  6c:  punpckldq xmm0,xmm1
  70:  movdqa xmm1,xmm0
  74:  psrldq xmm1,0x8
  79:  paddd  xmm0,xmm1
  7d:  movdqa xmm1,xmm0
  81:  psrldq xmm1,0x4
  86:  paddd  xmm0,xmm1
  8a:  movd   eax,xmm0
  8e:  ret

Meaning, the change unlocks SIMD vectorization at -O2. Overall, it’s impressive to see how well the compiler figures out rather non-trivial properties like: S[i] != -1 for all i.