Added algebra SSE, AVX (+FMA) implementatiokns with test

This commit is contained in:
Xavier Arteaga 2017-08-17 10:14:37 +02:00
parent 2d1166e3fa
commit d1709e06af
4 changed files with 737 additions and 40 deletions

View File

@ -29,42 +29,127 @@
#include "srslte/config.h"
#ifdef LV_HAVE_SSE
/*
* Generic Macros
*/
#define RANDOM_CF() (((float)rand())/((float)RAND_MAX) + _Complex_I*((float)rand())/((float)RAND_MAX))
#define _MM_MULJ_PS(X) _mm_permute_ps(_MM_CONJ_PS(X), 0b10110001)
#define _MM_CONJ_PS(X) (_mm_xor_ps(X, (__m128){0.0f, -0.0f, 0.0f, -0.0f}))
/*
* SSE Macros
*/
#ifdef LV_HAVE_SSE
#define _MM_SWAP(X) ((__m128)_mm_shuffle_ps(X, X, _MM_SHUFFLE(2,3,0,1)))
#define _MM_PERM(X) ((__m128)_mm_shuffle_ps(X, X, _MM_SHUFFLE(2,1,3,0)))
#define _MM_MULJ_PS(X) _MM_SWAP(_MM_CONJ_PS(X))
#define _MM_CONJ_PS(X) (_mm_xor_ps(X, _mm_set_ps(-0.0f, 0.0f, -0.0f, 0.0f)))
#define _MM_SQMOD_PS(X) _MM_PERM(_mm_hadd_ps(_mm_mul_ps(X,X), _mm_set_ps(0.0f, 0.0f, 0.0f, 0.0f)))
#define _MM_PROD_PS(a, b) _mm_addsub_ps(_mm_mul_ps(a,_mm_moveldup_ps(b)),_mm_mul_ps(\
_mm_shuffle_ps(a,a,0xB1),_mm_movehdup_ps(b)))
SRSLTE_API void srslte_algebra_2x2_zf_sse(__m128 y0,
__m128 y1,
__m128 h00,
__m128 h01,
__m128 h10,
__m128 h11,
__m128 *x0,
__m128 *x1,
#endif /* LV_HAVE_SSE */
/*
* AVX Macros
*/
#ifdef LV_HAVE_AVX
#define _MM256_MULJ_PS(X) _mm256_permute_ps(_MM256_CONJ_PS(X), 0b10110001)
#define _MM256_CONJ_PS(X) (_mm256_xor_ps(X, _mm256_set_ps(-0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f)))
#ifdef LV_HAVE_FMA
#define _MM256_SQMOD_PS(A, B) _mm256_permute_ps(_mm256_hadd_ps(_mm256_fmadd_ps(A, A, _mm256_mul_ps(B,B)), \
_mm256_set_ps(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)), 0b11011100)
#define _MM256_PROD_PS(a, b) _mm256_fmaddsub_ps(a,_mm256_moveldup_ps(b),\
_mm256_mul_ps(_mm256_shuffle_ps(a,a,0xB1),_mm256_movehdup_ps(b)))
#else
#define _MM256_SQMOD_PS(A, B) _mm256_permute_ps(_mm256_hadd_ps(_mm256_add_ps(_mm256_mul_ps(A,A), _mm256_mul_ps(B,B)), \
_mm256_set_ps(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)), 0b11011100)
#define _MM256_PROD_PS(a, b) _mm256_addsub_ps(_mm256_mul_ps(a,_mm256_moveldup_ps(b)),\
_mm256_mul_ps(_mm256_shuffle_ps(a,a,0xB1),_mm256_movehdup_ps(b)))
#endif /* LV_HAVE_FMA */
#endif /* LV_HAVE_AVX */
/*
* AVX extension with FMA Macros
*/
#ifdef LV_HAVE_FMA
#define _MM256_SQMOD_ADD_PS(A, B, C) _mm256_permute_ps(_mm256_hadd_ps(_mm256_fmadd_ps(A, A, _mm256_fmadd_ps(B, B, C)),\
_mm256_set_ps(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)), 0b11011100)
#define _MM256_PROD_ADD_PS(A, B, C) _mm256_fmaddsub_ps(A,_mm256_moveldup_ps(B),\
_mm256_fmaddsub_ps(_mm256_shuffle_ps(A,A,0xB1),_mm256_movehdup_ps(B), C))
#define _MM256_PROD_SUB_PS(A, B, C) _mm256_fmaddsub_ps(A,_mm256_moveldup_ps(B),\
_mm256_fmsubadd_ps(_mm256_shuffle_ps(A,A,0xB1),_mm256_movehdup_ps(B), C))
#endif /* LV_HAVE_FMA */
/* Generic implementation for complex reciprocal */
SRSLTE_API cf_t srslte_algebra_cf_recip_gen(cf_t a);
/* Generic implementation for 2x2 determinant */
SRSLTE_API cf_t srslte_algebra_2x2_det_gen(cf_t a00, cf_t a01, cf_t a10, cf_t a11);
/* Generic implementation for 2x2 Matrix Inversion */
SRSLTE_API void srslte_algebra_2x2_inv_gen(cf_t a00, cf_t a01, cf_t a10, cf_t a11,
cf_t *r00, cf_t *r01, cf_t *r10, cf_t *r11);
/* Generic implementation for Zero Forcing (ZF) solver */
SRSLTE_API void srslte_algebra_2x2_zf_gen(cf_t y0, cf_t y1,
cf_t h00, cf_t h01, cf_t h10, cf_t h11,
cf_t *x0, cf_t *x1,
float norm);
/* Generic implementation for Minimum Mean Squared Error (MMSE) solver */
SRSLTE_API void srslte_algebra_2x2_mmse_gen(cf_t y0, cf_t y1,
cf_t h00, cf_t h01, cf_t h10, cf_t h11,
cf_t *x0, cf_t *x1,
float noise_estimate,
float norm);
#ifdef LV_HAVE_SSE
/* SSE implementation for complex reciprocal */
SRSLTE_API __m128 srslte_algebra_cf_recip_sse(__m128 a);
/* SSE implementation for 2x2 determinant */
SRSLTE_API __m128 srslte_algebra_2x2_det_sse(__m128 a00, __m128 a01, __m128 a10, __m128 a11);
/* SSE implementation for Zero Forcing (ZF) solver */
SRSLTE_API void srslte_algebra_2x2_zf_sse(__m128 y0, __m128 y1,
__m128 h00, __m128 h01, __m128 h10, __m128 h11,
__m128 *x0, __m128 *x1,
float norm);
/* SSE implementation for Minimum Mean Squared Error (MMSE) solver */
SRSLTE_API void srslte_algebra_2x2_mmse_sse(__m128 y0, __m128 y1,
__m128 h00, __m128 h01, __m128 h10, __m128 h11,
__m128 *x0, __m128 *x1,
float noise_estimate, float norm);
#endif /* LV_HAVE_SSE */
#ifdef LV_HAVE_AVX
#define _MM256_MULJ_PS(X) _mm256_permute_ps(_MM256_CONJ_PS(X), 0b10110001)
#define _MM256_CONJ_PS(X) (_mm256_xor_ps(X, (__m256){0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f, 0.0f, -0.0f}))
#define _MM256_PROD_PS(a, b) _mm256_addsub_ps(_mm256_mul_ps(a,_mm256_moveldup_ps(b)),\
_mm256_mul_ps(_mm256_shuffle_ps(a,a,0xB1),_mm256_movehdup_ps(b)))
/* AVX implementation for complex reciprocal */
SRSLTE_API __m256 srslte_algebra_cf_recip_avx(__m256 a);
SRSLTE_API void srslte_algebra_2x2_zf_avx(__m256 y0,
__m256 y1,
__m256 h00,
__m256 h01,
__m256 h10,
__m256 h11,
__m256 *x0,
__m256 *x1,
/* AVX implementation for 2x2 determinant */
SRSLTE_API __m256 srslte_algebra_2x2_det_avx(__m256 a00, __m256 a01, __m256 a10, __m256 a11);
/* AVX implementation for Zero Forcing (ZF) solver */
SRSLTE_API void srslte_algebra_2x2_zf_avx(__m256 y0, __m256 y1,
__m256 h00, __m256 h01, __m256 h10, __m256 h11,
__m256 *x0, __m256 *x1,
float norm);
/* AVX implementation for Minimum Mean Squared Error (MMSE) solver */
SRSLTE_API void srslte_algebra_2x2_mmse_avx(__m256 y0, __m256 y1,
__m256 h00, __m256 h01, __m256 h10, __m256 h11,
__m256 *x0, __m256 *x1,
float noise_estimate, float norm);
#endif /* LV_HAVE_AVX */
#endif //SRSLTE_ALGEBRA_H

View File

@ -24,49 +24,237 @@
*
*/
#include <complex.h>
#include <immintrin.h>
#include "srslte/phy/utils/algebra.h"
/* Generic implementation for complex reciprocal */
inline cf_t srslte_algebra_cf_recip_gen(cf_t a) {
return conjf(a) / (crealf(a) * crealf(a) + cimagf(a) * cimagf(a));
}
/* Generic implementation for 2x2 determinant */
inline cf_t srslte_algebra_2x2_det_gen(cf_t a00, cf_t a01, cf_t a10, cf_t a11) {
return a00 * a11 - a01 * a10;
}
/* 2x2 Matrix inversion, generic implementation */
inline void srslte_algebra_2x2_inv_gen(cf_t a00, cf_t a01, cf_t a10, cf_t a11,
cf_t *r00, cf_t *r01, cf_t *r10, cf_t *r11) {
cf_t div = srslte_algebra_cf_recip_gen(srslte_algebra_2x2_det_gen(a00, a01, a10, a11));
*r00 = a11 * div;
*r01 = -a01 * div;
*r10 = -a10 * div;
*r11 = a00 * div;
}
/* Generic implementation for Zero Forcing (ZF) solver */
inline void srslte_algebra_2x2_zf_gen(cf_t y0, cf_t y1, cf_t h00, cf_t h01, cf_t h10, cf_t h11,
cf_t *x0, cf_t *x1, float norm) {
cf_t _norm = srslte_algebra_cf_recip_gen(srslte_algebra_2x2_det_gen(h00, h01, h10, h11)) * norm;
*x0 = (y0 * h11 - h01 * y1) * _norm;
*x1 = (y1 * h00 - h10 * y0) * _norm;
}
/* Generic implementation for Minimum Mean Squared Error (MMSE) solver */
inline void srslte_algebra_2x2_mmse_gen(cf_t y0, cf_t y1, cf_t h00, cf_t h01, cf_t h10, cf_t h11,
cf_t *x0, cf_t *x1, float noise_estimate, float norm) {
/* Create conjugated matrix */
cf_t _h00 = conjf(h00);
cf_t _h01 = conjf(h01);
cf_t _h10 = conjf(h10);
cf_t _h11 = conjf(h11);
/* 1. A = H' x H + No*/
cf_t a00 = _h00 * h00 + _h10 * h10 + noise_estimate;
cf_t a01 = _h00 * h01 + _h10 * h11;
cf_t a10 = _h01 * h00 + _h11 * h10;
cf_t a11 = _h01 * h01 + _h11 * h11 + noise_estimate;
/* 2. B = inv(H' x H + No) = inv(A) */
cf_t b00 = a11;
cf_t b01 = -a01;
cf_t b10 = -a10;
cf_t b11 = a00;
cf_t _norm = norm * srslte_algebra_cf_recip_gen(srslte_algebra_2x2_det_gen(a00, a01, a10, a11));
/* 3. W = inv(H' x H + No) x H' = B x H' */
cf_t w00 = b00 * _h00 + b01 * _h01;
cf_t w01 = b00 * _h10 + b01 * _h11;
cf_t w10 = b10 * _h00 + b11 * _h01;
cf_t w11 = b10 * _h10 + b11 * _h11;
/* 4. X = W x Y */
*x0 = (y0 * w00 + y1 * w01) * _norm;
*x1 = (y0 * w10 + y1 * w11) * _norm;
}
#ifdef LV_HAVE_SSE
/* SSE implementation for complex reciprocal */
inline __m128 srslte_algebra_cf_recip_sse(__m128 a) {
__m128 conj = _MM_CONJ_PS(a);
__m128 sqabs = _mm_mul_ps(a, a);
sqabs = _mm_add_ps(_mm_movehdup_ps(sqabs), _mm_moveldup_ps(sqabs));
__m128 recp = _mm_rcp_ps(sqabs);
return _mm_mul_ps(recp, conj);
}
/* SSE implementation for 2x2 determinant */
inline __m128 srslte_algebra_2x2_det_sse(__m128 a00, __m128 a01, __m128 a10, __m128 a11) {
return _mm_sub_ps(_MM_PROD_PS(a00, a11), _MM_PROD_PS(a01, a10));
}
/* SSE implementation for Zero Forcing (ZF) solver */
inline void srslte_algebra_2x2_zf_sse(__m128 y0, __m128 y1, __m128 h00, __m128 h01, __m128 h10, __m128 h11,
__m128 *x0, __m128 *x1, float norm) {
__m128 detmult1 = _MM_PROD_PS(h00, h11);
__m128 detmult2 = _MM_PROD_PS(h01, h10);
__m128 det = _mm_sub_ps(detmult1, detmult2);
__m128 detconj = _MM_CONJ_PS(det);
__m128 detabs2 = _MM_PROD_PS(det, detconj);
__m128 detabs2rec = _mm_rcp_ps(detabs2);
detabs2rec = _mm_moveldup_ps(detabs2rec);
__m128 detrec = _mm_mul_ps(_mm_mul_ps(detconj, detabs2rec),
(__m128) {norm, norm, norm, norm});
__m128 detrec = _mm_mul_ps(srslte_algebra_cf_recip_sse(det), _mm_set1_ps(norm));
*x0 = _MM_PROD_PS(_mm_sub_ps(_MM_PROD_PS(h11, y0), _MM_PROD_PS(h01, y1)), detrec);
*x1 = _MM_PROD_PS(_mm_sub_ps(_MM_PROD_PS(h00, y1), _MM_PROD_PS(h10, y0)), detrec);
}
/* SSE implementation for Minimum Mean Squared Error (MMSE) solver */
inline void srslte_algebra_2x2_mmse_sse(__m128 y0, __m128 y1, __m128 h00, __m128 h01, __m128 h10, __m128 h11,
__m128 *x0, __m128 *x1, float noise_estimate, float norm) {
__m128 _noise_estimate = _mm_set_ps(0.0f, noise_estimate, 0.0f, noise_estimate);
__m128 _norm = _mm_set1_ps(norm);
/* Create conjugated matrix */
__m128 _h00 = _MM_CONJ_PS(h00);
__m128 _h01 = _MM_CONJ_PS(h01);
__m128 _h10 = _MM_CONJ_PS(h10);
__m128 _h11 = _MM_CONJ_PS(h11);
/* 1. A = H' x H + No*/
__m128 a00 = _mm_add_ps(_mm_add_ps(_MM_SQMOD_PS(h00), _MM_SQMOD_PS(h10)), _noise_estimate);
__m128 a01 = _mm_add_ps(_MM_PROD_PS(_h00, h01), _MM_PROD_PS(_h10, h11));
__m128 a10 = _mm_add_ps(_MM_PROD_PS(_h01, h00), _MM_PROD_PS(_h11, h10));
__m128 a11 = _mm_add_ps(_mm_add_ps(_MM_SQMOD_PS(h01), _MM_SQMOD_PS(h11)), _noise_estimate);
/* 2. B = inv(H' x H + No) = inv(A) */
__m128 b00 = a11;
__m128 b01 = _mm_xor_ps(a01, _mm_set1_ps(-0.0f));
__m128 b10 = _mm_xor_ps(a10, _mm_set1_ps(-0.0f));
__m128 b11 = a00;
_norm = _mm_mul_ps(_norm, srslte_algebra_cf_recip_sse(srslte_algebra_2x2_det_sse(a00, a01, a10, a11)));
/* 3. W = inv(H' x H + No) x H' = B x H' */
__m128 w00 = _mm_add_ps(_MM_PROD_PS(b00, _h00), _MM_PROD_PS(b01, _h01));
__m128 w01 = _mm_add_ps(_MM_PROD_PS(b00, _h10), _MM_PROD_PS(b01, _h11));
__m128 w10 = _mm_add_ps(_MM_PROD_PS(b10, _h00), _MM_PROD_PS(b11, _h01));
__m128 w11 = _mm_add_ps(_MM_PROD_PS(b10, _h10), _MM_PROD_PS(b11, _h11));
/* 4. X = W x Y */
*x0 = _MM_PROD_PS(_mm_add_ps(_MM_PROD_PS(y0, w00), _MM_PROD_PS(y1, w01)), _norm);
*x1 = _MM_PROD_PS(_mm_add_ps(_MM_PROD_PS(y0, w10), _MM_PROD_PS(y1, w11)), _norm);
}
#endif /* LV_HAVE_SSE */
#ifdef LV_HAVE_AVX
/* AVX implementation for complex reciprocal */
inline __m256 srslte_algebra_cf_recip_avx(__m256 a) {
__m256 conj = _MM256_CONJ_PS(a);
__m256 sqabs = _mm256_mul_ps(a, a);
sqabs = _mm256_add_ps(_mm256_movehdup_ps(sqabs), _mm256_moveldup_ps(sqabs));
__m256 recp = _mm256_rcp_ps(sqabs);
return _mm256_mul_ps(recp, conj);
}
/* AVX implementation for 2x2 determinant */
inline __m256 srslte_algebra_2x2_det_avx(__m256 a00, __m256 a01, __m256 a10, __m256 a11) {
#ifdef LV_HAVE_FMA
return _MM256_PROD_SUB_PS(a00, a11, _MM256_PROD_PS(a01, a10));
#else
return _mm256_sub_ps(_MM256_PROD_PS(a00, a11), _MM256_PROD_PS(a01, a10));
#endif /* LV_HAVE_FMA */
}
/* AVX implementation for Zero Forcing (ZF) solver */
inline void srslte_algebra_2x2_zf_avx(__m256 y0, __m256 y1, __m256 h00, __m256 h01, __m256 h10, __m256 h11,
__m256 *x0, __m256 *x1, float norm) {
__m256 detmult1 = _MM256_PROD_PS(h00, h11);
__m256 detmult2 = _MM256_PROD_PS(h01, h10);
__m256 *x0, __m256 *x1, float norm) {
__m256 det = _mm256_sub_ps(detmult1, detmult2);
__m256 detconj = _MM256_CONJ_PS(det);
__m256 sqdet = _mm256_mul_ps(det, det);
__m256 detabs2 = _mm256_add_ps(_mm256_movehdup_ps(sqdet), _mm256_moveldup_ps(sqdet));
__m256 detabs2rec = _mm256_rcp_ps(detabs2);
__m256 detrec = _mm256_mul_ps(_mm256_mul_ps(detconj, detabs2rec),
(__m256) {norm, norm, norm, norm, norm, norm, norm, norm});
__m256 det = srslte_algebra_2x2_det_avx(h00, h01, h10, h11);
__m256 detrec = _mm256_mul_ps(srslte_algebra_cf_recip_avx(det), _mm256_set1_ps(norm));
#ifdef LV_HAVE_FMA
*x0 = _MM256_PROD_PS(_MM256_PROD_SUB_PS(h11, y0, _MM256_PROD_PS(h01, y1)), detrec);
*x1 = _MM256_PROD_PS(_MM256_PROD_SUB_PS(h00, y1, _MM256_PROD_PS(h10, y0)), detrec);
#else
*x0 = _MM256_PROD_PS(_mm256_sub_ps(_MM256_PROD_PS(h11, y0), _MM256_PROD_PS(h01, y1)), detrec);
*x1 = _MM256_PROD_PS(_mm256_sub_ps(_MM256_PROD_PS(h00, y1), _MM256_PROD_PS(h10, y0)), detrec);
#endif /* LV_HAVE_FMA */
}
/* AVX implementation for Minimum Mean Squared Error (MMSE) solver */
inline void srslte_algebra_2x2_mmse_avx(__m256 y0, __m256 y1, __m256 h00, __m256 h01, __m256 h10, __m256 h11,
__m256 *x0, __m256 *x1, float noise_estimate, float norm) {
__m256 _noise_estimate = _mm256_set_ps(0.0f, noise_estimate, 0.0f, noise_estimate,
0.0f, noise_estimate, 0.0f, noise_estimate);
__m256 _norm = _mm256_set1_ps(norm);
/* Create conjugated matrix */
__m256 _h00 = _MM256_CONJ_PS(h00);
__m256 _h01 = _MM256_CONJ_PS(h01);
__m256 _h10 = _MM256_CONJ_PS(h10);
__m256 _h11 = _MM256_CONJ_PS(h11);
/* 1. A = H' x H + No*/
#ifdef LV_HAVE_FMA
__m256 a00 = _MM256_SQMOD_ADD_PS(h00, h10, _noise_estimate);
__m256 a01 = _MM256_PROD_ADD_PS(_h00, h01, _MM256_PROD_PS(_h10, h11));
__m256 a10 = _MM256_PROD_ADD_PS(_h01, h00, _MM256_PROD_PS(_h11, h10));
__m256 a11 = _MM256_SQMOD_ADD_PS(h01, h11, _noise_estimate);
#else
__m256 a00 = _mm256_add_ps(_MM256_SQMOD_PS(h00, h10), _noise_estimate);
__m256 a01 = _mm256_add_ps(_MM256_PROD_PS(_h00, h01), _MM256_PROD_PS(_h10, h11));
__m256 a10 = _mm256_add_ps(_MM256_PROD_PS(_h01, h00), _MM256_PROD_PS(_h11, h10));
__m256 a11 = _mm256_add_ps(_MM256_SQMOD_PS(h01, h11), _noise_estimate);
#endif /* LV_HAVE_FMA */
/* 2. B = inv(H' x H + No) = inv(A) */
__m256 b00 = a11;
__m256 b01 = _mm256_xor_ps(a01, _mm256_set1_ps(-0.0f));
__m256 b10 = _mm256_xor_ps(a10, _mm256_set1_ps(-0.0f));
__m256 b11 = a00;
_norm = _mm256_mul_ps(_norm, srslte_algebra_cf_recip_avx(srslte_algebra_2x2_det_avx(a00, a01, a10, a11)));
/* 3. W = inv(H' x H + No) x H' = B x H' */
#ifdef LV_HAVE_FMA
__m256 w00 = _MM256_PROD_ADD_PS(b00, _h00, _MM256_PROD_PS(b01, _h01));
__m256 w01 = _MM256_PROD_ADD_PS(b00, _h10, _MM256_PROD_PS(b01, _h11));
__m256 w10 = _MM256_PROD_ADD_PS(b10, _h00, _MM256_PROD_PS(b11, _h01));
__m256 w11 = _MM256_PROD_ADD_PS(b10, _h10, _MM256_PROD_PS(b11, _h11));
#else
__m256 w00 = _mm256_add_ps(_MM256_PROD_PS(b00, _h00), _MM256_PROD_PS(b01, _h01));
__m256 w01 = _mm256_add_ps(_MM256_PROD_PS(b00, _h10), _MM256_PROD_PS(b01, _h11));
__m256 w10 = _mm256_add_ps(_MM256_PROD_PS(b10, _h00), _MM256_PROD_PS(b11, _h01));
__m256 w11 = _mm256_add_ps(_MM256_PROD_PS(b10, _h10), _MM256_PROD_PS(b11, _h11));
#endif /* LV_HAVE_FMA */
/* 4. X = W x Y */
#ifdef LV_HAVE_FMA
*x0 = _MM256_PROD_PS(_MM256_PROD_ADD_PS(y0, w00, _MM256_PROD_PS(y1, w01)), _norm);
*x1 = _MM256_PROD_PS(_MM256_PROD_ADD_PS(y0, w10, _MM256_PROD_PS(y1, w11)), _norm);
#else
*x0 = _MM256_PROD_PS(_mm256_add_ps(_MM256_PROD_PS(y0, w00), _MM256_PROD_PS(y1, w01)), _norm);
*x1 = _MM256_PROD_PS(_mm256_add_ps(_MM256_PROD_PS(y0, w10), _MM256_PROD_PS(y1, w11)), _norm);
#endif /* LV_HAVE_FMA */
}
#endif /* LV_HAVE_AVX */

View File

@ -33,3 +33,12 @@ add_test(dft_dc dft_test -b -d) # Backwards first & handle dc internally
add_test(dft_odd dft_test -N 255) # Odd-length
add_test(dft_odd_dc dft_test -N 255 -b -d) # Odd-length, backwards first, handle dc
########################################################################
# Algebra TEST
########################################################################
add_executable(algebra_test algebra_test.c)
target_link_libraries(algebra_test srslte_phy)
add_test(algebra_2x2_zf_solver_test algebra_test -z)
add_test(algebra_2x2_mmse_solver_test algebra_test -m)

View File

@ -0,0 +1,415 @@
/**
*
* \section COPYRIGHT
*
* Copyright 2013-2015 Software Radio Systems Limited
*
* \section LICENSE
*
* This file is part of the srsLTE library.
*
* srsLTE is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of
* the License, or (at your option) any later version.
*
* srsLTE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* A copy of the GNU Affero General Public License can be found in
* the LICENSE file in the top-level directory of this distribution
* and at http://www.gnu.org/licenses/.
*
*/
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <complex.h>
#include <stdbool.h>
#include <immintrin.h>
#include <sys/time.h>
#include "srslte/phy/utils/algebra.h"
bool zf_solver = false;
bool mmse_solver = false;
bool verbose = false;
double elapsed_us(struct timeval *ts_start, struct timeval *ts_end) {
if (ts_end->tv_usec > ts_start->tv_usec) {
return ((double) ts_end->tv_sec - (double) ts_start->tv_sec) * 1000000 +
(double) ts_end->tv_usec - (double) ts_start->tv_usec;
} else {
return ((double) ts_end->tv_sec - (double) ts_start->tv_sec - 1) * 1000000 +
((double) ts_end->tv_usec + 1000000) - (double) ts_start->tv_usec;
}
}
#define NOF_REPETITIONS 1000
#define RUN_TEST(FUNCTION) /*TYPE NAME (void)*/ { \
int i;\
struct timeval start, end;\
gettimeofday(&start, NULL); \
bool ret = true; \
for (i = 0; i < NOF_REPETITIONS; i++) {ret &= FUNCTION ();}\
gettimeofday(&end, NULL);\
if (verbose) printf("%32s: %s ... %6.2f us/call\n", #FUNCTION, (ret)?"Pass":"Fail", \
elapsed_us(&start, &end)/NOF_REPETITIONS);\
passed &= ret;\
}
void usage(char *prog) {
printf("Usage: %s [mzvh]\n", prog);
printf("\t-m Test Minimum Mean Squared Error (MMSE) solver\n");
printf("\t-z Test Zero Forcing (ZF) solver\n");
printf("\t-v Verbose\n");
printf("\t-h Show this message\n");
}
void parse_args(int argc, char **argv) {
int opt;
while ((opt = getopt(argc, argv, "mzvh")) != -1) {
switch (opt) {
case 'm':
mmse_solver = true;
break;
case 'z':
zf_solver = true;
break;
case 'v':
verbose = true;
break;
case 'h':
default:
usage(argv[0]);
exit(-1);
}
}
}
bool test_zf_solver_gen(void) {
cf_t x0, x1, cf_error0, cf_error1;
float error;
cf_t x0_gold = RANDOM_CF();
cf_t x1_gold = RANDOM_CF();
cf_t h00 = RANDOM_CF();
cf_t h01 = RANDOM_CF();
cf_t h10 = RANDOM_CF();
cf_t h11 = (1 - h01 * h10) / h00;
cf_t y0 = x0_gold * h00 + x1_gold * h01;
cf_t y1 = x0_gold * h10 + x1_gold * h11;
srslte_algebra_2x2_zf_gen(y0, y1, h00, h01, h10, h11, &x0, &x1, 1.0f);
cf_error0 = x0 - x0_gold;
cf_error1 = x1 - x1_gold;
error = crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
return (error < 1e-6);
}
bool test_mmse_solver_gen(void) {
cf_t x0, x1, cf_error0, cf_error1;
float error;
cf_t x0_gold = RANDOM_CF();
cf_t x1_gold = RANDOM_CF();
cf_t h00 = RANDOM_CF();
cf_t h01 = RANDOM_CF();
cf_t h10 = RANDOM_CF();
cf_t h11 = (1 - h01 * h10) / h00;
cf_t y0 = x0_gold * h00 + x1_gold * h01;
cf_t y1 = x0_gold * h10 + x1_gold * h11;
srslte_algebra_2x2_mmse_gen(y0, y1, h00, h01, h10, h11, &x0, &x1, 0.0f, 1.0f);
cf_error0 = x0 - x0_gold;
cf_error1 = x1 - x1_gold;
error = crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
return (error < 1e-6);
}
#ifdef LV_HAVE_SSE
bool test_zf_solver_sse(void) {
cf_t cf_error0, cf_error1;
float error = 0.0f;
cf_t x0_gold_1 = RANDOM_CF();
cf_t x1_gold_1 = RANDOM_CF();
cf_t h00_1 = RANDOM_CF();
cf_t h01_1 = RANDOM_CF();
cf_t h10_1 = RANDOM_CF();
cf_t h11_1 = (1 - h01_1 * h10_1) / h00_1;
cf_t y0_1 = x0_gold_1 * h00_1 + x1_gold_1 * h01_1;
cf_t y1_1 = x0_gold_1 * h10_1 + x1_gold_1 * h11_1;
cf_t x0_gold_2 = RANDOM_CF();
cf_t x1_gold_2 = RANDOM_CF();
cf_t h00_2 = RANDOM_CF();
cf_t h01_2 = RANDOM_CF();
cf_t h10_2 = RANDOM_CF();
cf_t h11_2 = (1 - h01_2 * h10_2) / h00_2;
cf_t y0_2 = x0_gold_2 * h00_2 + x1_gold_2 * h01_2;
cf_t y1_2 = x0_gold_2 * h10_2 + x1_gold_2 * h11_2;
__m128 _y0 = _mm_set_ps(cimagf(y0_1), crealf(y0_1), cimagf(y0_2), crealf(y0_2));
__m128 _y1 = _mm_set_ps(cimagf(y1_1), crealf(y1_1), cimagf(y1_2), crealf(y1_2));
__m128 _h00 = _mm_set_ps(cimagf(h00_1), crealf(h00_1), cimagf(h00_2), crealf(h00_2));
__m128 _h01 = _mm_set_ps(cimagf(h01_1), crealf(h01_1), cimagf(h01_2), crealf(h01_2));
__m128 _h10 = _mm_set_ps(cimagf(h10_1), crealf(h10_1), cimagf(h10_2), crealf(h10_2));
__m128 _h11 = _mm_set_ps(cimagf(h11_1), crealf(h11_1), cimagf(h11_2), crealf(h11_2));
__m128 _x0, _x1;
srslte_algebra_2x2_zf_sse(_y0, _y1, _h00, _h01, _h10, _h11, &_x0, &_x1, 1.0f);
__attribute__((aligned(128))) cf_t x0[2];
__attribute__((aligned(128))) cf_t x1[2];
_mm_store_ps((float *) x0, _x0);
_mm_store_ps((float *) x1, _x1);
cf_error0 = x0[1] - x0_gold_1;
cf_error1 = x1[1] - x1_gold_1;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
cf_error0 = x0[0] - x0_gold_2;
cf_error1 = x1[0] - x1_gold_2;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
return (error < 1e-3);
}
bool test_mmse_solver_sse(void) {
cf_t cf_error0, cf_error1;
float error = 0.0f;
cf_t x0_gold_1 = RANDOM_CF();
cf_t x1_gold_1 = RANDOM_CF();
cf_t h00_1 = RANDOM_CF();
cf_t h01_1 = RANDOM_CF();
cf_t h10_1 = RANDOM_CF();
cf_t h11_1 = (1 - h01_1 * h10_1) / h00_1;
cf_t y0_1 = x0_gold_1 * h00_1 + x1_gold_1 * h01_1;
cf_t y1_1 = x0_gold_1 * h10_1 + x1_gold_1 * h11_1;
cf_t x0_gold_2 = RANDOM_CF();
cf_t x1_gold_2 = RANDOM_CF();
cf_t h00_2 = RANDOM_CF();
cf_t h01_2 = RANDOM_CF();
cf_t h10_2 = RANDOM_CF();
cf_t h11_2 = (1 - h01_2 * h10_2) / h00_2;
cf_t y0_2 = x0_gold_2 * h00_2 + x1_gold_2 * h01_2;
cf_t y1_2 = x0_gold_2 * h10_2 + x1_gold_2 * h11_2;
__m128 _y0 = _mm_set_ps(cimagf(y0_1), crealf(y0_1), cimagf(y0_2), crealf(y0_2));
__m128 _y1 = _mm_set_ps(cimagf(y1_1), crealf(y1_1), cimagf(y1_2), crealf(y1_2));
__m128 _h00 = _mm_set_ps(cimagf(h00_1), crealf(h00_1), cimagf(h00_2), crealf(h00_2));
__m128 _h01 = _mm_set_ps(cimagf(h01_1), crealf(h01_1), cimagf(h01_2), crealf(h01_2));
__m128 _h10 = _mm_set_ps(cimagf(h10_1), crealf(h10_1), cimagf(h10_2), crealf(h10_2));
__m128 _h11 = _mm_set_ps(cimagf(h11_1), crealf(h11_1), cimagf(h11_2), crealf(h11_2));
__m128 _x0, _x1;
srslte_algebra_2x2_mmse_sse(_y0, _y1, _h00, _h01, _h10, _h11, &_x0, &_x1, 0.0f, 1.0f);
__attribute__((aligned(128))) cf_t x0[2];
__attribute__((aligned(128))) cf_t x1[2];
_mm_store_ps((float *) x0, _x0);
_mm_store_ps((float *) x1, _x1);
cf_error0 = x0[1] - x0_gold_1;
cf_error1 = x1[1] - x1_gold_1;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
cf_error0 = x0[0] - x0_gold_2;
cf_error1 = x1[0] - x1_gold_2;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
return (error < 1e-3);
}
#endif /* LV_HAVE_SSE */
#ifdef LV_HAVE_AVX
bool test_zf_solver_avx(void) {
cf_t cf_error0, cf_error1;
float error = 0.0f;
cf_t x0_gold_1 = RANDOM_CF();
cf_t x1_gold_1 = RANDOM_CF();
cf_t h00_1 = RANDOM_CF();
cf_t h01_1 = RANDOM_CF();
cf_t h10_1 = RANDOM_CF();
cf_t h11_1 = (1 - h01_1 * h10_1) / h00_1;
cf_t y0_1 = x0_gold_1 * h00_1 + x1_gold_1 * h01_1;
cf_t y1_1 = x0_gold_1 * h10_1 + x1_gold_1 * h11_1;
cf_t x0_gold_2 = RANDOM_CF();
cf_t x1_gold_2 = RANDOM_CF();
cf_t h00_2 = RANDOM_CF();
cf_t h01_2 = RANDOM_CF();
cf_t h10_2 = RANDOM_CF();
cf_t h11_2 = (1 - h01_2 * h10_2) / h00_2;
cf_t y0_2 = x0_gold_2 * h00_2 + x1_gold_2 * h01_2;
cf_t y1_2 = x0_gold_2 * h10_2 + x1_gold_2 * h11_2;
__m256 _y0 = _mm256_set_ps(cimagf(y0_1), crealf(y0_1), cimagf(y0_2), crealf(y0_2),
cimagf(y0_1), crealf(y0_1), cimagf(y0_2), crealf(y0_2));
__m256 _y1 = _mm256_set_ps(cimagf(y1_1), crealf(y1_1), cimagf(y1_2), crealf(y1_2),
cimagf(y1_1), crealf(y1_1), cimagf(y1_2), crealf(y1_2));
__m256 _h00 = _mm256_set_ps(cimagf(h00_1), crealf(h00_1), cimagf(h00_2), crealf(h00_2),
cimagf(h00_1), crealf(h00_1), cimagf(h00_2), crealf(h00_2));
__m256 _h01 = _mm256_set_ps(cimagf(h01_1), crealf(h01_1), cimagf(h01_2), crealf(h01_2),
cimagf(h01_1), crealf(h01_1), cimagf(h01_2), crealf(h01_2));
__m256 _h10 = _mm256_set_ps(cimagf(h10_1), crealf(h10_1), cimagf(h10_2), crealf(h10_2),
cimagf(h10_1), crealf(h10_1), cimagf(h10_2), crealf(h10_2));
__m256 _h11 = _mm256_set_ps(cimagf(h11_1), crealf(h11_1), cimagf(h11_2), crealf(h11_2),
cimagf(h11_1), crealf(h11_1), cimagf(h11_2), crealf(h11_2));
__m256 _x0, _x1;
srslte_algebra_2x2_zf_avx(_y0, _y1, _h00, _h01, _h10, _h11, &_x0, &_x1, 1.0f);
__attribute__((aligned(256))) cf_t x0[4];
__attribute__((aligned(256))) cf_t x1[4];
_mm256_store_ps((float *) x0, _x0);
_mm256_store_ps((float *) x1, _x1);
cf_error0 = x0[1] - x0_gold_1;
cf_error1 = x1[1] - x1_gold_1;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
cf_error0 = x0[0] - x0_gold_2;
cf_error1 = x1[0] - x1_gold_2;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
return (error < 1e-3);
}
bool test_mmse_solver_avx(void) {
cf_t cf_error0, cf_error1;
float error = 0.0f;
cf_t x0_gold_1 = RANDOM_CF();
cf_t x1_gold_1 = RANDOM_CF();
cf_t h00_1 = RANDOM_CF();
cf_t h01_1 = RANDOM_CF();
cf_t h10_1 = RANDOM_CF();
cf_t h11_1 = (1 - h01_1 * h10_1) / h00_1;
cf_t y0_1 = x0_gold_1 * h00_1 + x1_gold_1 * h01_1;
cf_t y1_1 = x0_gold_1 * h10_1 + x1_gold_1 * h11_1;
cf_t x0_gold_2 = RANDOM_CF();
cf_t x1_gold_2 = RANDOM_CF();
cf_t h00_2 = RANDOM_CF();
cf_t h01_2 = RANDOM_CF();
cf_t h10_2 = RANDOM_CF();
cf_t h11_2 = (1 - h01_2 * h10_2) / h00_2;
cf_t y0_2 = x0_gold_2 * h00_2 + x1_gold_2 * h01_2;
cf_t y1_2 = x0_gold_2 * h10_2 + x1_gold_2 * h11_2;
__m256 _y0 = _mm256_set_ps(cimagf(y0_1), crealf(y0_1), cimagf(y0_2), crealf(y0_2),
cimagf(y0_1), crealf(y0_1), cimagf(y0_2), crealf(y0_2));
__m256 _y1 = _mm256_set_ps(cimagf(y1_1), crealf(y1_1), cimagf(y1_2), crealf(y1_2),
cimagf(y1_1), crealf(y1_1), cimagf(y1_2), crealf(y1_2));
__m256 _h00 = _mm256_set_ps(cimagf(h00_1), crealf(h00_1), cimagf(h00_2), crealf(h00_2),
cimagf(h00_1), crealf(h00_1), cimagf(h00_2), crealf(h00_2));
__m256 _h01 = _mm256_set_ps(cimagf(h01_1), crealf(h01_1), cimagf(h01_2), crealf(h01_2),
cimagf(h01_1), crealf(h01_1), cimagf(h01_2), crealf(h01_2));
__m256 _h10 = _mm256_set_ps(cimagf(h10_1), crealf(h10_1), cimagf(h10_2), crealf(h10_2),
cimagf(h10_1), crealf(h10_1), cimagf(h10_2), crealf(h10_2));
__m256 _h11 = _mm256_set_ps(cimagf(h11_1), crealf(h11_1), cimagf(h11_2), crealf(h11_2),
cimagf(h11_1), crealf(h11_1), cimagf(h11_2), crealf(h11_2));
__m256 _x0, _x1;
srslte_algebra_2x2_mmse_avx(_y0, _y1, _h00, _h01, _h10, _h11, &_x0, &_x1, 0.0f, 1.0f);
__attribute__((aligned(256))) cf_t x0[4];
__attribute__((aligned(256))) cf_t x1[4];
_mm256_store_ps((float *) x0, _x0);
_mm256_store_ps((float *) x1, _x1);
cf_error0 = x0[1] - x0_gold_1;
cf_error1 = x1[1] - x1_gold_1;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
cf_error0 = x0[0] - x0_gold_2;
cf_error1 = x1[0] - x1_gold_2;
error += crealf(cf_error0) * crealf(cf_error0) + cimagf(cf_error0) * cimagf(cf_error0) +
crealf(cf_error1) * crealf(cf_error1) + cimagf(cf_error1) * cimagf(cf_error1);
return (error < 1e-3);
}
#endif /* LV_HAVE_AVX */
int main(int argc, char **argv) {
bool passed = true;
int ret = 0;
parse_args(argc, argv);
if (zf_solver) {
RUN_TEST(test_zf_solver_gen);
#ifdef LV_HAVE_SSE
RUN_TEST(test_zf_solver_sse);
#endif /* LV_HAVE_SSE */
#ifdef LV_HAVE_AVX
RUN_TEST(test_zf_solver_avx);
#endif /* LV_HAVE_AVX */
}
if (mmse_solver) {
RUN_TEST(test_mmse_solver_gen);
#ifdef LV_HAVE_SSE
RUN_TEST(test_mmse_solver_sse);
#endif /* LV_HAVE_SSE */
#ifdef LV_HAVE_AVX
RUN_TEST(test_mmse_solver_avx);
#endif /* LV_HAVE_AVX */
}
printf("%s!\n", (passed) ? "Ok" : "Failed");
if (!passed) {
exit(ret);
}
exit(ret);
}