Optimized SIMD includes and solved AVX512 bugs

This commit is contained in:
Xavier Arteaga 2017-09-29 16:42:46 +02:00
parent 9e5f999666
commit 94a06867a3
7 changed files with 79 additions and 43 deletions

View File

@ -283,8 +283,8 @@ if(CMAKE_C_COMPILER_ID MATCHES "GNU" OR CMAKE_C_COMPILER_ID MATCHES "Clang")
endif (HAVE_AVX2)
if (HAVE_AVX512)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f -DLV_HAVE_AVX512")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -DLV_HAVE_AVX512")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f -mavx512cd -DLV_HAVE_AVX512")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mavx512cd -DLV_HAVE_AVX512")
endif(HAVE_AVX512)
if(NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")

View File

@ -60,7 +60,6 @@ SRSLTE_API float srslte_mat_2x2_cn(cf_t h00,
#ifdef LV_HAVE_SSE
#include <smmintrin.h>
/* SSE implementation for complex reciprocal */
SRSLTE_API __m128 srslte_mat_cf_recip_sse(__m128 a);
@ -84,8 +83,6 @@ SRSLTE_API void srslte_mat_2x2_mmse_sse(__m128 y0, __m128 y1,
#ifdef LV_HAVE_AVX
#include <immintrin.h>
/* AVX implementation for complex reciprocal */
SRSLTE_API __m256 srslte_mat_cf_recip_avx(__m256 a);

View File

@ -27,7 +27,12 @@
#ifndef SRSLTE_SIMD_H_H
#define SRSLTE_SIMD_H_H
#ifdef LV_HAVE_SSE /* AVX, AVX2, FMA, AVX512 are in this group */
#ifndef __OPTIMIZE__
#define __OPTIMIZE__
#endif
#include <immintrin.h>
#endif /* LV_HAVE_SSE */
/*
* SSE Macros
@ -233,7 +238,7 @@ static inline simd_f_t srslte_simd_f_mul(simd_f_t a, simd_f_t b) {
static inline simd_f_t srslte_simd_f_rcp(simd_f_t a) {
#ifdef LV_HAVE_AVX512
return _mm512_rcp_ps(a);
return _mm512_rcp14_ps(a);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
return _mm256_rcp_ps(a);
@ -372,10 +377,16 @@ typedef struct {
static inline simd_cf_t srslte_simd_cfi_load(cf_t *ptr) {
simd_cf_t ret;
#ifdef LV_HAVE_AVX512
__m512 in1 = _mm512_permute_ps(_mm512_load_ps((float*)(ptr)), 0b11011000);
__m512 in2 = _mm512_permute_ps(_mm512_load_ps((float*)(ptr + 8)), 0b11011000);
ret.re = _mm512_unpacklo_ps(in1, in2);
ret.im = _mm512_unpackhi_ps(in1, in2);
__m512 in1 = _mm512_load_ps((float*)(ptr));
__m512 in2 = _mm512_load_ps((float*)(ptr + SRSLTE_SIMD_CF_SIZE/2));
ret.re = _mm512_permutex2var_ps(in1, _mm512_setr_epi32(0x00, 0x02, 0x04, 0x06,
0x08, 0x0A, 0x0C, 0x0E,
0x10, 0x12, 0x14, 0x16,
0x18, 0x1A, 0x1C, 0x1E), in2);
ret.im = _mm512_permutex2var_ps(in1, _mm512_setr_epi32(0x01, 0x03, 0x05, 0x07,
0x09, 0x0B, 0x0D, 0x0F,
0x11, 0x13, 0x15, 0x17,
0x19, 0x1B, 0x1D, 0x1F), in2);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
__m256 in1 = _mm256_permute_ps(_mm256_load_ps((float*)(ptr)), 0b11011000);
@ -398,10 +409,16 @@ static inline simd_cf_t srslte_simd_cfi_load(cf_t *ptr) {
static inline simd_cf_t srslte_simd_cfi_loadu(cf_t *ptr) {
simd_cf_t ret;
#ifdef LV_HAVE_AVX512
__m512 in1 = _mm512_permute_ps(_mm512_loadu_ps((float*)(ptr)), 0b11011000);
__m512 in2 = _mm512_permute_ps(_mm512_loadu_ps((float*)(ptr + 8)), 0b11011000);
ret.re = _mm512_unpacklo_ps(in1, in2);
ret.im = _mm512_unpackhi_ps(in1, in2);
__m512 in1 = _mm512_loadu_ps((float*)(ptr));
__m512 in2 = _mm512_loadu_ps((float*)(ptr + SRSLTE_SIMD_CF_SIZE/2));
ret.re = _mm512_permutex2var_ps(in1, _mm512_setr_epi32(0x00, 0x02, 0x04, 0x06,
0x08, 0x0A, 0x0C, 0x0E,
0x10, 0x12, 0x14, 0x16,
0x18, 0x1A, 0x1C, 0x1E), in2);
ret.im = _mm512_permutex2var_ps(in1, _mm512_setr_epi32(0x01, 0x03, 0x05, 0x07,
0x09, 0x0B, 0x0D, 0x0F,
0x11, 0x13, 0x15, 0x17,
0x19, 0x1B, 0x1D, 0x1F), in2);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
__m256 in1 = _mm256_permute_ps(_mm256_loadu_ps((float*)(ptr)), 0b11011000);
@ -460,10 +477,16 @@ static inline simd_cf_t srslte_simd_cf_loadu(float *re, float *im) {
static inline void srslte_simd_cfi_store(cf_t *ptr, simd_cf_t simdreg) {
#ifdef LV_HAVE_AVX512
__m512 out1 = _mm512_permute_ps(simdreg.re, 0b11011000);
__m512 out2 = _mm512_permute_ps(simdreg.im, 0b11011000);
_mm512_store_ps((float*)(ptr), _mm512_unpacklo_ps(out1, out2));
_mm512_store_ps((float*)(ptr + 8), _mm512_unpackhi_ps(out1, out2));
__m512 s1 = _mm512_permutex2var_ps(simdreg.re, _mm512_setr_epi32(0x00, 0x10, 0x01, 0x11,
0x02, 0x12, 0x03, 0x13,
0x04, 0x14, 0x05, 0x15,
0x06, 0x16, 0x07, 0x17), simdreg.im);
__m512 s2 = _mm512_permutex2var_ps(simdreg.re, _mm512_setr_epi32(0x08, 0x18, 0x09, 0x19,
0x0A, 0x1A, 0x0B, 0x1B,
0x0C, 0x1C, 0x0D, 0x1D,
0x0E, 0x1E, 0x0F, 0x1F), simdreg.im);
_mm512_store_ps((float*)(ptr), s1);
_mm512_store_ps((float*)(ptr + 8), s2);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
__m256 out1 = _mm256_permute_ps(simdreg.re, 0b11011000);
@ -481,10 +504,16 @@ static inline void srslte_simd_cfi_store(cf_t *ptr, simd_cf_t simdreg) {
static inline void srslte_simd_cfi_storeu(cf_t *ptr, simd_cf_t simdreg) {
#ifdef LV_HAVE_AVX512
__m512 out1 = _mm512_permute_ps(simdreg.re, 0b11011000);
__m512 out2 = _mm512_permute_ps(simdreg.im, 0b11011000);
_mm512_storeu_ps((float*)(ptr), _mm512_unpacklo_ps(out1, out2));
_mm512_storeu_ps((float*)(ptr + 8), _mm512_unpackhi_ps(out1, out2));
__m512 s1 = _mm512_permutex2var_ps(simdreg.re, _mm512_setr_epi32(0x00, 0x10, 0x01, 0x11,
0x02, 0x12, 0x03, 0x13,
0x04, 0x14, 0x05, 0x15,
0x06, 0x16, 0x07, 0x17), simdreg.im);
__m512 s2 = _mm512_permutex2var_ps(simdreg.re, _mm512_setr_epi32(0x08, 0x18, 0x09, 0x19,
0x0A, 0x1A, 0x0B, 0x1B,
0x0C, 0x1C, 0x0D, 0x1D,
0x0E, 0x1E, 0x0F, 0x1F), simdreg.im);
_mm512_storeu_ps((float*)(ptr), s1);
_mm512_storeu_ps((float*)(ptr + 8), s2);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
__m256 out1 = _mm256_permute_ps(simdreg.re, 0b11011000);
@ -625,7 +654,6 @@ static inline simd_cf_t srslte_simd_cf_add (simd_cf_t a, simd_cf_t b) {
static inline simd_cf_t srslte_simd_cf_mul (simd_cf_t a, simd_f_t b) {
simd_cf_t ret;
#ifdef LV_HAVE_AVX512
b = _mm512_permutexvar_ps(b, _mm512_setr_epi32(0,4,1,5,2,6,3,7,8,12,9,13,10,14,11,15));
ret.re = _mm512_mul_ps(a.re, b);
ret.im = _mm512_mul_ps(a.im, b);
#else /* LV_HAVE_AVX512 */
@ -649,7 +677,7 @@ static inline simd_cf_t srslte_simd_cf_rcp (simd_cf_t a) {
simd_f_t a2re = _mm512_mul_ps(a.re, a.re);
simd_f_t a2im = _mm512_mul_ps(a.im, a.im);
simd_f_t mod2 = _mm512_add_ps(a2re, a2im);
simd_f_t rcp = _mm512_rcp_ps(mod2);
simd_f_t rcp = _mm512_rcp14_ps(mod2);
simd_f_t neg_a_im = _mm512_xor_ps(_mm512_set1_ps(-0.0f), a.im);
ret.re = _mm512_mul_ps(a.re, rcp);
ret.im = _mm512_mul_ps(neg_a_im, rcp);
@ -702,12 +730,15 @@ static inline simd_cf_t srslte_simd_cf_zero (void) {
#ifdef LV_HAVE_AVX512
typedef __m512i simd_i_t;
typedef __mmask16 simd_sel_t;
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
typedef __m256i simd_i_t;
typedef __m256 simd_sel_t;
#else /* LV_HAVE_AVX2 */
#ifdef LV_HAVE_SSE
typedef __m128i simd_i_t;
typedef __m128i simd_sel_t;
#endif /* LV_HAVE_SSE */
#endif /* LV_HAVE_AVX2 */
#endif /* LV_HAVE_AVX512 */
@ -768,12 +799,12 @@ static inline simd_i_t srslte_simd_i_add(simd_i_t a, simd_i_t b) {
#endif /* LV_HAVE_AVX512 */
}
static inline simd_i_t srslte_simd_f_max(simd_f_t a, simd_f_t b) {
static inline simd_sel_t srslte_simd_f_max(simd_f_t a, simd_f_t b) {
#ifdef LV_HAVE_AVX512
return (simd_i_t) _mm512_cmp_ps_mask(a, b, _CMP_GT_OS);
return _mm512_cmp_ps_mask(a, b, _CMP_GT_OS);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
return (simd_i_t) _mm256_cmp_ps(a, b, _CMP_GT_OS);
return _mm256_cmp_ps(a, b, _CMP_GT_OS);
#else /* LV_HAVE_AVX2 */
#ifdef LV_HAVE_SSE
return (simd_i_t) _mm_cmpgt_ps(a, b);
@ -782,15 +813,15 @@ static inline simd_i_t srslte_simd_f_max(simd_f_t a, simd_f_t b) {
#endif /* LV_HAVE_AVX512 */
}
static inline simd_i_t srslte_simd_i_select(simd_i_t a, simd_i_t b, simd_i_t selector) {
static inline simd_i_t srslte_simd_i_select(simd_i_t a, simd_i_t b, simd_sel_t selector) {
#ifdef LV_HAVE_AVX512
return (__m512i) _mm512_blendv_ps((__m512)a, (__m512) b, (__m512) selector);
return (__m512i) _mm512_mask_blend_ps( selector, (__m512)a, (__m512) b);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
return (__m256i) _mm256_blendv_ps((__m256) a,(__m256) b,(__m256) selector);
return (__m256i) _mm256_blendv_ps((__m256) a,(__m256) b, selector);
#else
#ifdef LV_HAVE_SSE
return (__m128i) _mm_blendv_ps((__m128)a, (__m128)b, (__m128)selector);
return (__m128i) _mm_blendv_ps((__m128)a, (__m128)b, selector);
#endif /* LV_HAVE_SSE */
#endif /* LV_HAVE_AVX2 */
#endif /* LV_HAVE_AVX512 */
@ -1127,6 +1158,19 @@ static inline simd_c16_t srslte_simd_c16_zero (void) {
#if SRSLTE_SIMD_F_SIZE && SRSLTE_SIMD_S_SIZE
static inline simd_s_t srslte_simd_convert_2f_s(simd_f_t a, simd_f_t b) {
#ifdef LV_HAVE_AVX512
__m512 aa = _mm512_permutex2var_ps(a, _mm512_setr_epi32(0x00, 0x01, 0x02, 0x03,
0x08, 0x09, 0x0A, 0x0B,
0x10, 0x11, 0x12, 0x13,
0x18, 0x19, 0x1A, 0x1B), b);
__m512 bb = _mm512_permutex2var_ps(a, _mm512_setr_epi32(0x04, 0x05, 0x06, 0x07,
0x0C, 0x0D, 0x0E, 0x0F,
0x14, 0x15, 0x16, 0x17,
0x1C, 0x1D, 0x1E, 0x1F), b);
__m512i ai = _mm512_cvttps_epi32(aa);
__m512i bi = _mm512_cvttps_epi32(bb);
return _mm512_packs_epi32(ai, bi);
#else /* LV_HAVE_AVX512 */
#ifdef LV_HAVE_AVX2
__m256 aa = _mm256_permute2f128_ps(a, b, 0x20);
__m256 bb = _mm256_permute2f128_ps(a, b, 0x31);
@ -1140,6 +1184,7 @@ static inline simd_s_t srslte_simd_convert_2f_s(simd_f_t a, simd_f_t b) {
return _mm_packs_epi32(ai, bi);
#endif /* LV_HAVE_SSE */
#endif /* LV_HAVE_AVX2 */
#endif /* LV_HAVE_AVX512 */
}
#endif /* SRSLTE_SIMD_F_SIZE && SRSLTE_SIMD_C16_SIZE */

View File

@ -33,20 +33,17 @@
#include "srslte/phy/mimo/precoding.h"
#include "srslte/phy/utils/vector.h"
#include "srslte/phy/utils/debug.h"
#include "srslte/phy/utils/mat.h"
#ifdef LV_HAVE_SSE
#include <immintrin.h>
int srslte_predecoding_single_sse(cf_t *y[SRSLTE_MAX_PORTS], cf_t *h[SRSLTE_MAX_PORTS], cf_t *x, int nof_rxant, int nof_symbols, float noise_estimate);
int srslte_predecoding_diversity2_sse(cf_t *y[SRSLTE_MAX_PORTS], cf_t *h[SRSLTE_MAX_PORTS][SRSLTE_MAX_PORTS], cf_t *x[SRSLTE_MAX_LAYERS], int nof_rxant, int nof_symbols);
#endif
#ifdef LV_HAVE_AVX
#include <immintrin.h>
int srslte_predecoding_single_avx(cf_t *y[SRSLTE_MAX_PORTS], cf_t *h[SRSLTE_MAX_PORTS], cf_t *x, int nof_rxant, int nof_symbols, float noise_estimate);
#endif
#include "srslte/phy/utils/mat.h"
static srslte_mimo_decoder_t mimo_decoder = SRSLTE_MIMO_DECODER_MMSE;
/************************************************

View File

@ -29,7 +29,6 @@
#include <unistd.h>
#include <complex.h>
#include <stdbool.h>
#include <immintrin.h>
#include <sys/time.h>
#include "srslte/phy/utils/mat.h"

View File

@ -29,9 +29,7 @@
#include <unistd.h>
#include <complex.h>
#include <stdbool.h>
#include <immintrin.h>
#include <sys/time.h>
#include <srslte/phy/utils/vector_simd.h>
#include <memory.h>
#include <math.h>

View File

@ -556,7 +556,7 @@ void srslte_vec_prod_cfc_simd(cf_t *x, float *y, cf_t *z, int len) {
for (; i < len - SRSLTE_SIMD_F_SIZE + 1; i += SRSLTE_SIMD_F_SIZE) {
simd_f_t s = srslte_simd_f_loadu(&y[i]);
simd_cf_t a = srslte_simd_cfi_load(&x[i]);
simd_cf_t a = srslte_simd_cfi_loadu(&x[i]);
simd_cf_t r = srslte_simd_cf_mul(a, s);
srslte_simd_cfi_storeu(&z[i], r);
}
@ -1036,7 +1036,7 @@ uint32_t srslte_vec_max_fi_simd(float *x, int len) {
for (; i < len - SRSLTE_SIMD_I_SIZE + 1; i += SRSLTE_SIMD_I_SIZE) {
simd_f_t a = srslte_simd_f_load(&x[i]);
simd_i_t res = srslte_simd_f_max(a, simd_max_values);
simd_sel_t res = srslte_simd_f_max(a, simd_max_values);
simd_max_indexes = srslte_simd_i_select(simd_max_indexes, simd_indexes, res);
simd_max_values = (simd_f_t) srslte_simd_i_select((simd_i_t) simd_max_values, (simd_i_t) a, res);
@ -1046,7 +1046,7 @@ uint32_t srslte_vec_max_fi_simd(float *x, int len) {
for (; i < len - SRSLTE_SIMD_I_SIZE + 1; i += SRSLTE_SIMD_I_SIZE) {
simd_f_t a = srslte_simd_f_loadu(&x[i]);
simd_i_t res = srslte_simd_f_max(a, simd_max_values);
simd_sel_t res = srslte_simd_f_max(a, simd_max_values);
simd_max_indexes = srslte_simd_i_select(simd_max_indexes, simd_indexes, res);
simd_max_values = (simd_f_t) srslte_simd_i_select((simd_i_t) simd_max_values, (simd_i_t) a, res);
@ -1102,7 +1102,7 @@ uint32_t srslte_vec_max_ci_simd(cf_t *x, int len) {
simd_f_t z1 = srslte_simd_f_hadd(mul1, mul2);
simd_i_t res = srslte_simd_f_max(z1, simd_max_values);
simd_sel_t res = srslte_simd_f_max(z1, simd_max_values);
simd_max_indexes = srslte_simd_i_select(simd_max_indexes, simd_indexes, res);
simd_max_values = (simd_f_t) srslte_simd_i_select((simd_i_t) simd_max_values, (simd_i_t) z1, res);
@ -1118,7 +1118,7 @@ uint32_t srslte_vec_max_ci_simd(cf_t *x, int len) {
simd_f_t z1 = srslte_simd_f_hadd(mul1, mul2);
simd_i_t res = srslte_simd_f_max(z1, simd_max_values);
simd_sel_t res = srslte_simd_f_max(z1, simd_max_values);
simd_max_indexes = srslte_simd_i_select(simd_max_indexes, simd_indexes, res);
simd_max_values = (simd_f_t) srslte_simd_i_select((simd_i_t) simd_max_values, (simd_i_t) z1, res);