Vectorize tensor.isnan() by using typed predicates.

Making this fast with AVX512 required vectorizing the cast between Packet16f and Packet16b. Casting for the other backends could be improved later.
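
For readers unfamiliar with the term, here is a minimal scalar sketch of what a "typed predicate" means in this context, assuming it refers to a predicate that returns its result in the input's own type (1.0f or 0.0f) and is narrowed to bool in a separate cast step. The names `isnan_typed` and `isnan_mask` are hypothetical and are not Eigen's API; the real change operates on packets such as Packet16f and Packet16b rather than on one element at a time.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical typed predicate: the result has the same type as the input
// (1.0f for NaN, 0.0f otherwise) instead of bool, so a packet of floats
// can map to a packet of floats without changing lane width.
inline float isnan_typed(float x) {
  // NaN is the only floating-point value that compares unequal to itself.
  // Note: this relies on IEEE semantics and can break under -ffast-math.
  return x != x ? 1.0f : 0.0f;
}

// Hypothetical helper: evaluate the typed predicate, then narrow each
// result to a one-byte boolean. The narrowing step plays the role of the
// float-to-bool cast that the description says had to be vectorized.
std::vector<std::uint8_t> isnan_mask(const std::vector<float>& in) {
  std::vector<std::uint8_t> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = static_cast<std::uint8_t>(isnan_typed(in[i]) != 0.0f);
  }
  return out;
}
```

Splitting the work this way keeps the comparison in full-width float lanes and leaves the narrowing to bool as a distinct cast, which is why the speed of that cast matters for the overall result.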

Benchmark measurements:

AVX512F:
name                               old cpu/op  new cpu/op  delta
BM_isNaN_1T/3   [using 1 threads]  10.9ns ± 4%  10.9ns ± 5%     ~     (p=0.083 n=57+54)
BM_isNaN_1T/4   [using 1 threads]  8.07ns ±12%  7.69ns ± 4%   -4.81%  (p=0.000 n=55+54)
BM_isNaN_1T/7   [using 1 threads]  12.0ns ± 6%   9.9ns ± 5%  -17.60%  (p=0.000 n=57+44)
BM_isNaN_1T/8   [using 1 threads]  12.8ns ± 3%   8.7ns ± 6%  -31.61%  (p=0.000 n=53+54)
BM_isNaN_1T/10  [using 1 threads]  19.2ns ± 6%  13.2ns ± 6%  -31.20%  (p=0.000 n=54+47)
BM_isNaN_1T/15  [using 1 threads]  30.9ns ± 8%  15.3ns ± 7%  -50.51%  (p=0.000 n=50+55)
BM_isNaN_1T/16  [using 1 threads]  31.8ns ± 4%  15.4ns ± 6%  -51.42%  (p=0.000 n=50+53)
BM_isNaN_1T/31  [using 1 threads]   104ns ± 5%    53ns ± 6%  -49.41%  (p=0.000 n=55+59)
BM_isNaN_1T/32  [using 1 threads]   109ns ± 3%    56ns ± 6%  -48.48%  (p=0.000 n=54+59)
BM_isNaN_1T/64  [using 1 threads]   420ns ± 4%   221ns ± 4%  -47.45%  (p=0.000 n=57+54)
BM_isNaN_1T/128 [using 1 threads]  1.70µs ± 5%  0.90µs ± 4%  -47.18%  (p=0.000 n=59+55)
BM_isNaN_1T/256 [using 1 threads]  6.79µs ± 6%  3.57µs ± 4%  -47.45%  (p=0.000 n=60+49)
BM_isNaN_1T/512 [using 1 threads]  40.7µs ± 4%  33.1µs ± 6%  -18.71%  (p=0.000 n=48+50)
BM_isNaN_1T/1k  [using 1 threads]   192µs ± 4%   198µs ± 3%   +3.18%  (p=0.000 n=55+54)
BM_isNaN_1T/2k  [using 1 threads]   887µs ±24%   912µs ±24%     ~     (p=0.054 n=43+45)
BM_isNaN_1T/4k  [using 1 threads]  7.37ms ±11%  6.47ms ± 5%  -12.26%  (p=0.000 n=33+32)
BM_isNaN_1T/10k [using 1 threads]  46.3ms ± 7%  40.6ms ± 3%  -12.19%  (p=0.000 n=15+11)

AVX2:
name                               old cpu/op  new cpu/op   delta
BM_isNaN_1T/3   [using 1 threads]  10.9ns ± 4%   17.3ns ± 7%  +58.34%  (p=0.000 n=58+52)
BM_isNaN_1T/4   [using 1 threads]  8.51ns ± 3%  14.58ns ± 4%  +71.25%  (p=0.000 n=49+53)
BM_isNaN_1T/7   [using 1 threads]  15.4ns ± 5%   19.9ns ± 5%  +29.65%  (p=0.000 n=58+52)
BM_isNaN_1T/8   [using 1 threads]  18.1ns ± 8%   19.8ns ± 4%   +9.67%  (p=0.000 n=54+57)
BM_isNaN_1T/10  [using 1 threads]  27.7ns ± 4%   26.0ns ± 5%   -6.28%  (p=0.000 n=51+51)
BM_isNaN_1T/15  [using 1 threads]  50.2ns ± 5%   37.8ns ± 6%  -24.69%  (p=0.000 n=60+40)
BM_isNaN_1T/16  [using 1 threads]  55.6ns ± 4%   39.5ns ±10%  -28.87%  (p=0.000 n=59+47)
BM_isNaN_1T/31  [using 1 threads]   196ns ± 3%    121ns ± 4%  -38.27%  (p=0.000 n=56+54)
BM_isNaN_1T/32  [using 1 threads]   208ns ± 4%    128ns ± 5%  -38.61%  (p=0.000 n=55+59)
BM_isNaN_1T/64  [using 1 threads]   822ns ± 4%    464ns ± 5%  -43.61%  (p=0.000 n=57+60)
BM_isNaN_1T/128 [using 1 threads]  3.27µs ± 4%   2.09µs ± 6%  -36.14%  (p=0.000 n=50+58)
BM_isNaN_1T/256 [using 1 threads]  13.0µs ± 4%    8.3µs ± 4%  -36.45%  (p=0.000 n=55+57)
BM_isNaN_1T/512 [using 1 threads]  54.4µs ± 6%   43.6µs ± 7%  -19.89%  (p=0.000 n=60+58)
BM_isNaN_1T/1k  [using 1 threads]   226µs ± 5%    198µs ± 4%  -12.26%  (p=0.000 n=52+52)
BM_isNaN_1T/2k  [using 1 threads]  1.09ms ±33%   0.97ms ±26%  -10.63%  (p=0.000 n=41+47)
BM_isNaN_1T/4k  [using 1 threads]  8.22ms ± 7%   7.50ms ±14%   -8.79%  (p=0.000 n=39+36)
BM_isNaN_1T/10k [using 1 threads]  50.9ms ± 8%   47.3ms ± 6%   -7.16%  (p=0.000 n=15+14)

SSE:
name                               old cpu/op  new cpu/op   delta
BM_isNaN_1T/3   [using 1 threads]  10.6ns ± 4%   17.0ns ± 6%  +60.36%  (p=0.000 n=47+58)
BM_isNaN_1T/4   [using 1 threads]  8.65ns ± 7%  12.44ns ± 5%  +43.82%  (p=0.000 n=54+54)
BM_isNaN_1T/7   [using 1 threads]  16.4ns ± 5%   16.8ns ± 6%   +2.44%  (p=0.000 n=57+53)
BM_isNaN_1T/8   [using 1 threads]  17.5ns ± 7%   17.5ns ± 3%     ~     (p=0.551 n=55+53)
BM_isNaN_1T/10  [using 1 threads]  23.5ns ± 5%   23.3ns ± 6%     ~     (p=0.080 n=50+48)
BM_isNaN_1T/15  [using 1 threads]  39.9ns ± 4%   34.9ns ± 9%  -12.56%  (p=0.000 n=45+46)
BM_isNaN_1T/16  [using 1 threads]  42.8ns ± 4%   36.5ns ± 8%  -14.56%  (p=0.000 n=54+47)
BM_isNaN_1T/31  [using 1 threads]   142ns ± 3%    120ns ± 4%  -15.39%  (p=0.000 n=59+46)
BM_isNaN_1T/32  [using 1 threads]   149ns ± 4%    126ns ± 6%  -15.50%  (p=0.000 n=60+54)
BM_isNaN_1T/64  [using 1 threads]   558ns ± 4%    457ns ± 8%  -18.18%  (p=0.000 n=60+59)
BM_isNaN_1T/128 [using 1 threads]  2.47µs ± 5%   1.89µs ± 5%  -23.34%  (p=0.000 n=54+52)
BM_isNaN_1T/256 [using 1 threads]  9.82µs ± 4%   7.47µs ± 4%  -23.93%  (p=0.000 n=60+59)
BM_isNaN_1T/512 [using 1 threads]  46.8µs ± 7%   42.2µs ± 7%   -9.68%  (p=0.000 n=60+56)
BM_isNaN_1T/1k  [using 1 threads]   203µs ± 6%    195µs ± 6%   -3.66%  (p=0.000 n=53+54)
BM_isNaN_1T/2k  [using 1 threads]  1.01ms ±38%   1.01ms ±43%     ~     (p=0.804 n=49+46)
BM_isNaN_1T/4k  [using 1 threads]  7.55ms ±10%   7.21ms ± 9%   -4.43%  (p=0.001 n=39+29)
BM_isNaN_1T/10k [using 1 threads]  46.6ms ± 5%   44.6ms ± 6%   -4.31%  (p=0.002 n=14+13)
Edited by Rasmus Munk Larsen