Skip to content

Even better CompareByte for x64.

Rika requested to merge runewalsh/source:comparebyte-x64 into main

I’m terribly sorry about rewriting this function just after !369 (merged), but I accidentally found a better idea about CompareByte for x64 on StackOverflow. TL;DR: even with unaligned accesses, you can often over-read safely after checking explicitly if all bytes reside on the same memory page.

Here, if both tails can be over-read, I read them with MOVDQUs and throw away garbage bytes, otherwise compare them as before: uint32 by uint32, then byte by byte.

Assuming addresses are distributed randomly, chances that one tail can be over-read are 1 − 16 / 4096 ≈ 99.6%, and that two tails can be over-read are (1 − 16 / 4096)² ≈ 99.2%.

I tried to remove uint32 branch, this makes the function simpler than before but slows down bad cases by always handling them bytewise (on my computer, 3→9 ns for 8-byte, 6→16 ns for 15-byte). While the probability of bad cases is around 1%, something important can reside there, so I left it.

Benchmark: CompareByteBenchmark.pas.

My results:

Different byte #0 of 1
System.CompareByte:           2.3 ns/call
CompareByteAsmV2:             2.5 ns/call

Different byte #7 of 8
System.CompareByte:           3.4 ns/call
CompareByteAsmV2:             2.5 ns/call

Different byte #14 of 15
System.CompareByte:           6.3 ns/call
CompareByteAsmV2:             2.5 ns/call

Different byte #15 of 16
System.CompareByte:           2.0 ns/call
CompareByteAsmV2:             2.0 ns/call

Different byte #23 of 24
System.CompareByte:           4.6 ns/call
CompareByteAsmV2:             2.9 ns/call

Different byte #24 of 25
System.CompareByte:           5.4 ns/call
CompareByteAsmV2:             2.9 ns/call

Different byte #46 of 47
System.CompareByte:           8.1 ns/call
CompareByteAsmV2:             3.5 ns/call

Different byte #1 of 100
System.CompareByte:           2.0 ns/call
CompareByteAsmV2:             2.0 ns/call

Different byte #99 of 100
System.CompareByte:           7.9 ns/call
CompareByteAsmV2:             6.1 ns/call

Different byte #999 of 1000
System.CompareByte:           37 ns/call
CompareByteAsmV2:             39 ns/call

Looks that small but non-trivial cases that aren’t multiples of 16 can be 2× faster thanks to forcing them into a SIMD unit.

C++ source:

#include <immintrin.h>
#include <cstdint>

typedef unsigned char uchar;
const size_t PAGE_SIZE = 4096;
const uintptr_t XMM_MASK = 16 - 1;

__attribute__((__ms_abi__))
// __attribute__((__sysv_abi__))
ptrdiff_t compare_byte(uchar* a, uchar* b, ptrdiff_t n)
{
	uchar *ae = a + n, *aepart;
	if (n >= 4)
	{
		aepart = a + (n & ~XMM_MASK);
		while (a != aepart)
		{
			__m128i sample_a = _mm_loadu_si128((__m128i*)a);
			__m128i sample_b = _mm_loadu_si128((__m128i*)b);
			int cmp_mask = ~_mm_movemask_epi8(_mm_cmpeq_epi8(sample_a, sample_b)) & 65535;
			if (cmp_mask != 0)
			{
				size_t idiff = __builtin_ctz(cmp_mask);
				return (ptrdiff_t)a[idiff] - (ptrdiff_t)b[idiff];
			}
			a += 1 + XMM_MASK, b += 1 + XMM_MASK;
		}
		if (a == ae) return 0;

		// If both tails can be over-read to XMMs, compare them as XMMs.
		// (a ^ (a + b)) >= PageSize iff 'a + b' resides on different page than 'a'.
		if
		(
			((uint32_t)(uintptr_t)a ^ ((uint32_t)(uintptr_t)a + XMM_MASK)) < PAGE_SIZE &&
			((uint32_t)(uintptr_t)b ^ ((uint32_t)(uintptr_t)b + XMM_MASK)) < PAGE_SIZE
		)
		{
			__m128i sample_a = _mm_loadu_si128((__m128i*)a);
			__m128i sample_b = _mm_loadu_si128((__m128i*)b);
			int cmp_mask = ~_mm_movemask_epi8(_mm_cmpeq_epi8(sample_a, sample_b)) & 65535;
			if (cmp_mask == 0) return 0;
			size_t idiff = __builtin_ctz(cmp_mask);
			if (a + idiff >= ae) return 0;
			return (ptrdiff_t)a[idiff] - (ptrdiff_t)b[idiff];
		}

		// Compare uint32 by uint32, can be replaced with 'goto bytewise_body'.
		aepart = a + ((ae - a) & ~uintptr_t{sizeof(uint32_t) - 1});
		while (a != aepart)
		{
			if (*(uint32_t*)a != *(uint32_t*)b)
				return (ptrdiff_t)__builtin_bswap32(*(uint32_t*)a) - (ptrdiff_t)__builtin_bswap32(*(uint32_t*)b);
			a += sizeof(uint32_t); b += sizeof(uint32_t);
		}
	}
	while (a != ae)
	{ // bytewise_body:
		ptrdiff_t diff = (ptrdiff_t)*a - (ptrdiff_t)*b;
		if (diff) return diff;
		a++, b++;
	}
	return 0;
}

Merge request reports