Vectorize tensor.isnan() by using typed predicates.
Making this fast with AVX512 required vectorizing the casts between Packet16f and Packet16b. Casting for other backends could be improved later.
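The two steps described above can be sketched in plain C++ as follows. This is only an illustration of the idea, not the code from this change: the function names isnan_typed, cast_mask_to_u8 and isnan_all are hypothetical, std::uint8_t stands in for bool, and the loops are scalar stand-ins for what the packet code does 16 lanes at a time. The assumption is that a "typed" predicate returns its result in the input's own scalar type (1.0f or 0.0f), so the comparison keeps equal-width lanes, and a separate narrowing cast then produces the boolean output.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

static_assert(std::numeric_limits<float>::has_quiet_NaN,
              "this sketch assumes a float type with NaN support");

// Step 1: typed predicate. The result has the same scalar type as the input
// (1.0f for NaN, 0.0f otherwise), so input and output lanes have equal width.
// NaN is the only value that compares unequal to itself. Flags such as
// -ffast-math may allow the compiler to assume this comparison is always false.
inline void isnan_typed(const float* in, float* mask, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    mask[i] = (in[i] != in[i]) ? 1.0f : 0.0f;
  }
}

// Step 2: narrowing cast from the float-typed mask to one byte per element.
// This stands in for the Packet16f -> Packet16b cast named in the description.
inline void cast_mask_to_u8(const float* mask, std::uint8_t* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = (mask[i] != 0.0f) ? 1 : 0;
  }
}

// Hypothetical convenience wrapper that chains both steps.
inline std::vector<std::uint8_t> isnan_all(const std::vector<float>& in) {
  std::vector<float> mask(in.size());
  std::vector<std::uint8_t> out(in.size());
  isnan_typed(in.data(), mask.data(), in.size());
  cast_mask_to_u8(mask.data(), out.data(), in.size());
  assert(out.size() == in.size());
  return out;
}
```

Calling isnan_all on a vector holding a finite value, a quiet NaN and an infinity would mark only the NaN element. Splitting the work this way keeps each loop a simple element-wise map, which is the shape that is easiest to vectorize; the cost is the extra narrowing pass, which is why the speed of the cast matters.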
Benchmark measurements:
AVX512F:
name                                 old cpu/op    new cpu/op    delta
BM_isNaN_1T/3   [using 1 threads]   10.9ns ± 4%   10.9ns ± 5%     ~     (p=0.083 n=57+54)
BM_isNaN_1T/4   [using 1 threads]   8.07ns ±12%   7.69ns ± 4%   -4.81%  (p=0.000 n=55+54)
BM_isNaN_1T/7   [using 1 threads]   12.0ns ± 6%    9.9ns ± 5%  -17.60%  (p=0.000 n=57+44)
BM_isNaN_1T/8   [using 1 threads]   12.8ns ± 3%    8.7ns ± 6%  -31.61%  (p=0.000 n=53+54)
BM_isNaN_1T/10  [using 1 threads]   19.2ns ± 6%   13.2ns ± 6%  -31.20%  (p=0.000 n=54+47)
BM_isNaN_1T/15  [using 1 threads]   30.9ns ± 8%   15.3ns ± 7%  -50.51%  (p=0.000 n=50+55)
BM_isNaN_1T/16  [using 1 threads]   31.8ns ± 4%   15.4ns ± 6%  -51.42%  (p=0.000 n=50+53)
BM_isNaN_1T/31  [using 1 threads]    104ns ± 5%     53ns ± 6%  -49.41%  (p=0.000 n=55+59)
BM_isNaN_1T/32  [using 1 threads]    109ns ± 3%     56ns ± 6%  -48.48%  (p=0.000 n=54+59)
BM_isNaN_1T/64  [using 1 threads]    420ns ± 4%    221ns ± 4%  -47.45%  (p=0.000 n=57+54)
BM_isNaN_1T/128 [using 1 threads]   1.70µs ± 5%   0.90µs ± 4%  -47.18%  (p=0.000 n=59+55)
BM_isNaN_1T/256 [using 1 threads]   6.79µs ± 6%   3.57µs ± 4%  -47.45%  (p=0.000 n=60+49)
BM_isNaN_1T/512 [using 1 threads]   40.7µs ± 4%   33.1µs ± 6%  -18.71%  (p=0.000 n=48+50)
BM_isNaN_1T/1k  [using 1 threads]    192µs ± 4%    198µs ± 3%   +3.18%  (p=0.000 n=55+54)
BM_isNaN_1T/2k  [using 1 threads]    887µs ±24%    912µs ±24%     ~     (p=0.054 n=43+45)
BM_isNaN_1T/4k  [using 1 threads]   7.37ms ±11%   6.47ms ± 5%  -12.26%  (p=0.000 n=33+32)
BM_isNaN_1T/10k [using 1 threads]   46.3ms ± 7%   40.6ms ± 3%  -12.19%  (p=0.000 n=15+11)
AVX2:
name                                 old cpu/op    new cpu/op    delta
BM_isNaN_1T/3   [using 1 threads]   10.9ns ± 4%   17.3ns ± 7%  +58.34%  (p=0.000 n=58+52)
BM_isNaN_1T/4   [using 1 threads]   8.51ns ± 3%  14.58ns ± 4%  +71.25%  (p=0.000 n=49+53)
BM_isNaN_1T/7   [using 1 threads]   15.4ns ± 5%   19.9ns ± 5%  +29.65%  (p=0.000 n=58+52)
BM_isNaN_1T/8   [using 1 threads]   18.1ns ± 8%   19.8ns ± 4%   +9.67%  (p=0.000 n=54+57)
BM_isNaN_1T/10  [using 1 threads]   27.7ns ± 4%   26.0ns ± 5%   -6.28%  (p=0.000 n=51+51)
BM_isNaN_1T/15  [using 1 threads]   50.2ns ± 5%   37.8ns ± 6%  -24.69%  (p=0.000 n=60+40)
BM_isNaN_1T/16  [using 1 threads]   55.6ns ± 4%   39.5ns ±10%  -28.87%  (p=0.000 n=59+47)
BM_isNaN_1T/31  [using 1 threads]    196ns ± 3%    121ns ± 4%  -38.27%  (p=0.000 n=56+54)
BM_isNaN_1T/32  [using 1 threads]    208ns ± 4%    128ns ± 5%  -38.61%  (p=0.000 n=55+59)
BM_isNaN_1T/64  [using 1 threads]    822ns ± 4%    464ns ± 5%  -43.61%  (p=0.000 n=57+60)
BM_isNaN_1T/128 [using 1 threads]   3.27µs ± 4%   2.09µs ± 6%  -36.14%  (p=0.000 n=50+58)
BM_isNaN_1T/256 [using 1 threads]   13.0µs ± 4%    8.3µs ± 4%  -36.45%  (p=0.000 n=55+57)
BM_isNaN_1T/512 [using 1 threads]   54.4µs ± 6%   43.6µs ± 7%  -19.89%  (p=0.000 n=60+58)
BM_isNaN_1T/1k  [using 1 threads]    226µs ± 5%    198µs ± 4%  -12.26%  (p=0.000 n=52+52)
BM_isNaN_1T/2k  [using 1 threads]   1.09ms ±33%   0.97ms ±26%  -10.63%  (p=0.000 n=41+47)
BM_isNaN_1T/4k  [using 1 threads]   8.22ms ± 7%   7.50ms ±14%   -8.79%  (p=0.000 n=39+36)
BM_isNaN_1T/10k [using 1 threads]   50.9ms ± 8%   47.3ms ± 6%   -7.16%  (p=0.000 n=15+14)
SSE:
name                                 old cpu/op    new cpu/op    delta
BM_isNaN_1T/3   [using 1 threads]   10.6ns ± 4%   17.0ns ± 6%  +60.36%  (p=0.000 n=47+58)
BM_isNaN_1T/4   [using 1 threads]   8.65ns ± 7%  12.44ns ± 5%  +43.82%  (p=0.000 n=54+54)
BM_isNaN_1T/7   [using 1 threads]   16.4ns ± 5%   16.8ns ± 6%   +2.44%  (p=0.000 n=57+53)
BM_isNaN_1T/8   [using 1 threads]   17.5ns ± 7%   17.5ns ± 3%     ~     (p=0.551 n=55+53)
BM_isNaN_1T/10  [using 1 threads]   23.5ns ± 5%   23.3ns ± 6%     ~     (p=0.080 n=50+48)
BM_isNaN_1T/15  [using 1 threads]   39.9ns ± 4%   34.9ns ± 9%  -12.56%  (p=0.000 n=45+46)
BM_isNaN_1T/16  [using 1 threads]   42.8ns ± 4%   36.5ns ± 8%  -14.56%  (p=0.000 n=54+47)
BM_isNaN_1T/31  [using 1 threads]    142ns ± 3%    120ns ± 4%  -15.39%  (p=0.000 n=59+46)
BM_isNaN_1T/32  [using 1 threads]    149ns ± 4%    126ns ± 6%  -15.50%  (p=0.000 n=60+54)
BM_isNaN_1T/64  [using 1 threads]    558ns ± 4%    457ns ± 8%  -18.18%  (p=0.000 n=60+59)
BM_isNaN_1T/128 [using 1 threads]   2.47µs ± 5%   1.89µs ± 5%  -23.34%  (p=0.000 n=54+52)
BM_isNaN_1T/256 [using 1 threads]   9.82µs ± 4%   7.47µs ± 4%  -23.93%  (p=0.000 n=60+59)
BM_isNaN_1T/512 [using 1 threads]   46.8µs ± 7%   42.2µs ± 7%   -9.68%  (p=0.000 n=60+56)
BM_isNaN_1T/1k  [using 1 threads]    203µs ± 6%    195µs ± 6%   -3.66%  (p=0.000 n=53+54)
BM_isNaN_1T/2k  [using 1 threads]   1.01ms ±38%   1.01ms ±43%     ~     (p=0.804 n=49+46)
BM_isNaN_1T/4k  [using 1 threads]   7.55ms ±10%   7.21ms ± 9%   -4.43%  (p=0.001 n=39+29)
BM_isNaN_1T/10k [using 1 threads]   46.6ms ± 5%   44.6ms ± 6%   -4.31%  (p=0.002 n=14+13)
Edited by Rasmus Munk Larsen