A bug in AArch64 UMOPA/UMOPS (4-way) instruction
Host environment
- Operating system: Ubuntu 22.04
- OS/kernel version: Linux lin005 6.5.0-28-generic #29~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Apr 4 14:39:20 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
- Architecture: x86-64
- QEMU flavor: qemu-aarch64
- QEMU version: qemu-aarch64 version 9.0.50 (v9.0.0-1123-g74abb45dac)
- QEMU command line:
./qemu-aarch64 -L /usr/aarch64-linux-gnu/ -cpu max,sme256=on ./test.bin
Emulated/Virtualized environment
- Operating system: N/A (qemu-user)
- OS/kernel version: N/A (qemu-user)
- Architecture: aarch64
Description of problem
umopa computes the multiplication of two matrices held in the source registers and accumulates the result into the destination register. In the 4-way variant of this instruction, a source register's element size is 16 bits, while a destination register's element size is 64 bits. Before the matrix multiplication is performed, each source element should be zero-extended to 64 bits.
However, the current implementation of the helper function fails to convert the element type correctly. Below is the helper function implementation:
// target/arm/tcg/sme_helper.c
#define DEF_IMOP_64(NAME, NTYPE, MTYPE) \
static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
{                                                                           \
    uint64_t sum = 0;                                                       \
    /* Apply P to N as a mask, making the inactive elements 0. */           \
    n &= expand_pred_h(p);                                                  \
    sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                               \
    sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                             \
    sum += (NTYPE)(n >> 32) * (MTYPE)(m >> 32);                             \
    sum += (NTYPE)(n >> 48) * (MTYPE)(m >> 48);                             \
    return neg ? a - sum : a + sum;                                         \
}
DEF_IMOP_64(umopa_d, uint16_t, uint16_t)
When the multiplication is performed, each operand, such as (NTYPE)(n >> 0), is implicitly promoted to int (int32_t here), so the product also has type int32_t. When the product of two large 16-bit values exceeds INT32_MAX, it ends up negative in practice, and it is then sign-extended when it is converted to uint64_t and added to sum. It seems the elements should be cast to uint64_t before the multiplication is performed.
Steps to reproduce
- Write test.c.
#include <stdio.h>
char i_P1[4] = { 0xff, 0xff, 0xff, 0xff };
char i_P5[4] = { 0xff, 0xff, 0xff, 0xff };
char i_Z0[32] = { // Set only the first element as non-zero
0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
};
char i_Z20[32] = { // Set only the first element as non-zero
0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
};
char i_ZA2H[128] = { 0x0, };
char o_ZA2H[128];
void __attribute__ ((noinline)) show_state() {
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 16; j++) {
printf("%02x ", o_ZA2H[16*i+j]);
}
printf("\n");
}
}
void __attribute__ ((noinline)) run() {
__asm__ (
".arch armv9.3-a+sme\n"
"smstart\n"
"adrp x29, i_P1\n"
"add x29, x29, :lo12:i_P1\n"
"ldr p1, [x29]\n"
"adrp x29, i_P5\n"
"add x29, x29, :lo12:i_P5\n"
"ldr p5, [x29]\n"
"adrp x29, i_Z0\n"
"add x29, x29, :lo12:i_Z0\n"
"ldr z0, [x29]\n"
"adrp x29, i_Z20\n"
"add x29, x29, :lo12:i_Z20\n"
"ldr z20, [x29]\n"
"adrp x29, i_ZA2H\n"
"add x29, x29, :lo12:i_ZA2H\n"
"mov x15, 0\n"
"ld1d {za2h.d[w15, 0]}, p1, [x29]\n"
"add x29, x29, 32\n"
"ld1d {za2h.d[w15, 1]}, p1, [x29]\n"
"add x29, x29, 32\n"
"mov x15, 2\n"
"ld1d {za2h.d[w15, 0]}, p1, [x29]\n"
"add x29, x29, 32\n"
"ld1d {za2h.d[w15, 1]}, p1, [x29]\n"
".inst 0xa1f43402\n" // umopa za2.d, p5/m, p1/m, z0.h, z20.h
"adrp x29, o_ZA2H\n"
"add x29, x29, :lo12:o_ZA2H\n"
"mov x15, 0\n"
"st1d {za2h.d[w15, 0]}, p1, [x29]\n"
"add x29, x29, 32\n"
"st1d {za2h.d[w15, 1]}, p1, [x29]\n"
"add x29, x29, 32\n"
"mov x15, 2\n"
"st1d {za2h.d[w15, 0]}, p1, [x29]\n"
"add x29, x29, 32\n"
"st1d {za2h.d[w15, 1]}, p1, [x29]\n"
"smstop\n"
".arch armv8-a\n"
);
}
int main(int argc, char **argv) {
run();
show_state();
return 0;
}
- Compile test.bin using this command: aarch64-linux-gnu-gcc-12 -O2 -no-pie ./test.c -o ./test.bin
- Run QEMU using this command: qemu-aarch64 -L /usr/aarch64-linux-gnu/ -cpu max,sme256=on ./test.bin
- The program, when run on the buggy QEMU, prints the first 8 bytes of ZA2H as 01 00 fe ff ff ff ff ff. It should print 01 00 fe ff 00 00 00 00 after the bug is fixed.