The Intel Advanced Vector Extensions (AVX) offer no dot product instruction in the 256-bit version (YMM registers) for double-precision floating-point values. The "Why?" question has been treated very briefly in another forum (here) and on Stack Overflow (here). The question I am facing is how to replace this missing instruction with other AVX instructions in an efficient way.
A 256-bit dot product does exist for single-precision floating-point values (reference here):
__m256 _mm256_dp_ps(__m256 m1, __m256 m2, const int mask);
The idea is to find an efficient equivalent for this missing instruction:
__m256d _mm256_dp_pd(__m256d m1, __m256d m2, const int mask);
To be more specific, the code I would like to transform from __m128 (four floats) to __m256d (four doubles) uses the following instructions:
__m128 val0 = ...; // Four float values
__m128 val1 = ...; //
__m128 val2 = ...; //
__m128 val3 = ...; //
__m128 val4 = ...; //
__m128 res = _mm_or_ps( _mm_dp_ps(val1, val0, 0xF1),
             _mm_or_ps( _mm_dp_ps(val2, val0, 0xF2),
             _mm_or_ps( _mm_dp_ps(val3, val0, 0xF4),
                        _mm_dp_ps(val4, val0, 0xF8) )));
The result of this code is a __m128 vector of four floats containing the results of the dot products between val1 and val0, val2 and val0, val3 and val0, and val4 and val0.
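To pin down what any replacement has to compute, here is a scalar reference written for doubles. It is only a sketch: the name dot4_reference and the flat row-major layout of vals (val1 to val4 stored back to back) are assumptions made for illustration, not part of the original code.

```c
#include <stddef.h>

/* Scalar reference (hypothetical helper, not from the original post).
   vals holds val1..val4 back to back, four doubles each (row-major).
   out[i] receives the dot product of row i of vals with val0. */
static void dot4_reference(const double *val0, const double *vals, double *out)
{
    for (size_t i = 0; i < 4; ++i) {
        double sum = 0.0;
        for (size_t k = 0; k < 4; ++k)
            sum += vals[4 * i + k] * val0[k];
        out[i] = sum;
    }
}
```

Any vectorized version can be checked against this for small integer-valued inputs, where the summation order does not affect the result.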
Maybe this can give hints for the suggestions?
I would use a 4*double multiplication, then a hadd (which unfortunately only adds adjacent pairs of doubles within the lower and the upper 128-bit half), extract the upper half (a shuffle should work equally well, maybe faster) and add it to the lower half.
The result is in the low 64 bits of dotproduct.
__m256d xy = _mm256_mul_pd( x, y );
__m256d temp = _mm256_hadd_pd( xy, xy );
__m128d hi128 = _mm256_extractf128_pd( temp, 1 );
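// _mm256_castpd256_pd128 takes the low 128 bits; a plain C cast between vector types of different sizes is not portable
__m128d dotproduct = _mm_add_pd( _mm256_castpd256_pd128( temp ), hi128 );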
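Wrapped up as a function, the sequence above might look like the following sketch. The name dot4_avx, the pointer-based interface and the DOT_AVX_TARGET macro are assumptions added for illustration; the macro uses the GCC/Clang target attribute so the function compiles without -mavx, and it can be dropped if the whole file is built with AVX enabled.

```c
#include <immintrin.h>

#if defined(__GNUC__) || defined(__clang__)
#define DOT_AVX_TARGET __attribute__((target("avx")))
#else
#define DOT_AVX_TARGET
#endif

/* Hypothetical helper: dot product of two arrays of four doubles.
   Unaligned loads are used, so x and y need no special alignment. */
DOT_AVX_TARGET
static double dot4_avx(const double *x, const double *y)
{
    __m256d vx = _mm256_loadu_pd(x);
    __m256d vy = _mm256_loadu_pd(y);
    __m256d xy = _mm256_mul_pd(vx, vy);              /* x0*y0 x1*y1 x2*y2 x3*y3 */
    __m256d temp = _mm256_hadd_pd(xy, xy);           /* pair sums within each 128-bit half */
    __m128d lo128 = _mm256_castpd256_pd128(temp);    /* lower half, no instruction needed */
    __m128d hi128 = _mm256_extractf128_pd(temp, 1);  /* upper half */
    __m128d sum = _mm_add_pd(lo128, hi128);          /* full sum in both lanes */
    return _mm_cvtsd_f64(sum);                       /* return the low 64 bits */
}
```

For x = {1, 2, 3, 4} and y = {5, 6, 7, 8} this returns 70. It still requires a CPU with AVX support at run time.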
Edit:
Following an idea from Norbert P., I extended this version to compute four dot products at once.
__m256d xy0 = _mm256_mul_pd( x[0], y[0] );
__m256d xy1 = _mm256_mul_pd( x[1], y[1] );
__m256d xy2 = _mm256_mul_pd( x[2], y[2] );
__m256d xy3 = _mm256_mul_pd( x[3], y[3] );
// low to high: xy00+xy01 xy10+xy11 xy02+xy03 xy12+xy13
__m256d temp01 = _mm256_hadd_pd( xy0, xy1 );
// low to high: xy20+xy21 xy30+xy31 xy22+xy23 xy32+xy33
__m256d temp23 = _mm256_hadd_pd( xy2, xy3 );
// low to high: xy02+xy03 xy12+xy13 xy20+xy21 xy30+xy31
__m256d swapped = _mm256_permute2f128_pd( temp01, temp23, 0x21 );
// low to high: xy00+xy01 xy10+xy11 xy22+xy23 xy32+xy33
__m256d blended = _mm256_blend_pd( temp01, temp23, 0xC ); // 0xC == 0b1100: upper two elements come from temp23
// low to high: dot(x[0],y[0]) dot(x[1],y[1]) dot(x[2],y[2]) dot(x[3],y[3])
__m256d dotproduct = _mm256_add_pd( swapped, blended );
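Applied to the case in the question (val1 to val4 each dotted with the same val0), the four-at-once version could be sketched as below. The function name dot4_common_avx, the flat row-major layout of vals and the DOT_AVX_TARGET macro are assumptions for illustration; the macro relies on the GCC/Clang target attribute and is unnecessary when compiling with -mavx.

```c
#include <immintrin.h>

#if defined(__GNUC__) || defined(__clang__)
#define DOT_AVX_TARGET __attribute__((target("avx")))
#else
#define DOT_AVX_TARGET
#endif

/* Hypothetical helper: vals holds val1..val4 back to back (4 doubles each).
   out[i] receives the dot product of row i of vals with val0. */
DOT_AVX_TARGET
static void dot4_common_avx(const double *val0, const double *vals, double *out)
{
    __m256d v0 = _mm256_loadu_pd(val0);
    __m256d xy0 = _mm256_mul_pd(_mm256_loadu_pd(vals + 0), v0);
    __m256d xy1 = _mm256_mul_pd(_mm256_loadu_pd(vals + 4), v0);
    __m256d xy2 = _mm256_mul_pd(_mm256_loadu_pd(vals + 8), v0);
    __m256d xy3 = _mm256_mul_pd(_mm256_loadu_pd(vals + 12), v0);

    /* low to high: xy00+xy01 xy10+xy11 xy02+xy03 xy12+xy13 */
    __m256d temp01 = _mm256_hadd_pd(xy0, xy1);
    /* low to high: xy20+xy21 xy30+xy31 xy22+xy23 xy32+xy33 */
    __m256d temp23 = _mm256_hadd_pd(xy2, xy3);
    /* upper half of temp01 followed by lower half of temp23 */
    __m256d swapped = _mm256_permute2f128_pd(temp01, temp23, 0x21);
    /* lower half of temp01 followed by upper half of temp23 */
    __m256d blended = _mm256_blend_pd(temp01, temp23, 0xC);

    /* each lane now holds one complete dot product */
    _mm256_storeu_pd(out, _mm256_add_pd(swapped, blended));
}
```

Loading val0 once and reusing it for all four multiplications is the only change from the generic x[i]/y[i] version; the reduction steps are identical.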