Skip to content

Add support for Power10 (AltiVec) MMA instructions for bfloat16.

Hello everyone, long time no see! I hope to find everyone safe and sound 😸

What?

This merge request add MMA support for bfloat16 on Power 10 machines. As Power 10 has bfloat16 support this is way faster comparing to what we had before using only VSX instructions that falls back for float32 to any computation.

How

Briefly, Power10 MMA instructions for 16 bits types has a Rank-2 operation xvbf16ger2pp that is able to do two rank-1 updates simultaneously using 2 rows/columns. It takes a 4x2 against a 2x4 matrix block and do a rank 2 update. Below there's a scheme of how my MMA register needs to be and the result. One thing worth mentioning is, result is a 4x4 float32 matrix, there's a "type upgrade" on this operation.

A B C D E F G H
I J K L M N O P
A*I + B*J A*K + B*L A*M + B*N A*O + B*P
C*I + D*J C*K + D*L C*M + D*N C*O + D*P
E*I + F*J E*K + F*L E*M + F*N E*O + F*P
G*I + H*J G*K + H*L G*M + H*N G*O + H*P

In short, what gemmMMAbfloat16 is doing, it's loading 4x2 and 2x4 blocks from LHS/RHS, organizing them at the registers and running rank-2 update. As standard packing wasn't created with this situation in mind, it can be a little confusing how I'm acessing memory. If you think a detailed explanation is necessary I don't mind drawing something to make myself clear as possible 😄

Out of curiosity, I did try to change packing to make code more friendly but I couldn't make my custom packing work for triangular so I went back.

Code

Temporary float32 result

Talking further about the result being a float32 matrix to avoid converting back and forth from float32 <-> bfloat16 on GEMM I created a temporary float32 matrix to hold result.

  float** result = new float*[cols];
  for(int i = 0; i < cols; i++) result[i] = new float[rows];

I didn't see any code using new so if that's a problem I'm open to suggestions. 😉

Long and ugly switch statement

pgerMMAbfloat16 is basically running rank-2 update instructions. There's a mask feature for this set of instructions that I'm able to ignore some parts of the result matrix. This is useful when we are running that last section of a matrix that is unable to fit whole 4 elements and/or don't have two rows/columns. Using masks I'm able to ignore result for those non-existent values.

Now comes the ugly part, I don't know masks at compile time and so I can't write something like: __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), maskX, maskY, 0b11); I don't have exact compiler error at this moment but it was something that mentions literals. I bet it's because, after compilation, these masked rank updates are different instructions (instead of a instruction with masks as arguments).

Testing

For testing I've changed files below (not submitted):

  • test/product_syrk.cpp
  • test/product_large.cpp
  • test/product_symm.cpp
  • test/triangular.cpp

I had a problem creating those tests because bfloat16 doesn't support int scaling (i.e. k * A) so there's some tests that had like 2*m1 that doesn't work for bfloat16. To work I've changed:

This:

VERIFY_IS_APPROX(res, 2*(square + m1 * m2.transpose()));

To: VERIFY_IS_APPROX(res, (square + m1 * m2.transpose()) + ( square + m1 * m2.transpose()) );

There's also this code section that don't work for bfloat16 and I don't have any idea why:

  if(!MatrixType::IsRowMajor)
  {
    typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
    MatrixX buffer(2*rows,2*rows);
    Map<RowSquareMatrixType,0,Stride<Dynamic,2> > map1(buffer.data(),rows,rows,Stride<Dynamic,2>(2*rows,2));
    buffer.setZero();
    VERIFY_IS_APPROX(map1 = m1 * m2.transpose(), (m1 * m2.transpose()).eval());
    buffer.setZero();
    VERIFY_IS_APPROX(map1.noalias() = m1 * m2.transpose(), (m1 * m2.transpose()).eval());
    buffer.setZero();
    VERIFY_IS_APPROX(map1.noalias() += m1 * m2.transpose(), (m1 * m2.transpose()).eval());
  }

If you people think it's important to update tests files to have my bfloat16 tests I don't mind doing it. Honestly I didn't give much tought about this matter but maybe I can specialize product function on product.h for a Matrix of bfloat16. Suggestions are also appreciated here 😄

Last considerations

As this is a lot of changes I can imagine this will go back and forth. Any suggestion/consideration will be much appreciated! Thanks a lot for your time.

PS: This is a collaboration between me, Chip Kerchner and Rafael Souza. (co-authors on commit message)

Merge request reports

Loading