templated interpolate2d (#2221)

* template table2d

* assertions

* oops

* tests

* test bins too

* more assert

* format

* explicitly handle and test NaN

Co-authored-by: Matthew Kennedy <makenne@microsoft.com>
This commit is contained in:
Matthew Kennedy 2021-01-11 05:00:15 -08:00 committed by GitHub
parent 04e791d99c
commit 604b0cd527
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 249 additions and 28 deletions

View File

@ -191,30 +191,6 @@ int findIndex(const float array[], int size, float value) {
return findIndexMsg("", array, size, value);
}
namespace priv
{
/**
* @brief One-dimensional table lookup with linear interpolation
*
* @see setLinearCurve()
*/
float interpolate2d(const char *msg, float value, const float bin[], const float values[], int size) {
if (isnan(value)) {
// this unfortunately sometimes happens during functional tests on real hardware
warning(CUSTOM_INTERPOLATE_NAN, "NaN in interpolate2d %s", msg);
return NAN;
}
int index = findIndexMsg(msg, bin, size, value);
if (index == -1)
return values[0];
if (index == size - 1)
return values[size - 1];
return interpolateMsg(msg, bin[index], values[index], bin[index + 1], values[index + 1], value);
}
}
/**
* Sets specified value for specified key in a correction curve
* see also setLinearCurve()

View File

@ -13,6 +13,8 @@
#include "obd_error_codes.h"
#include "error_handling.h"
#include <type_traits>
#ifndef DEBUG_INTERPOLATION
#define DEBUG_INTERPOLATION FALSE
#endif
@ -27,12 +29,87 @@ float interpolateClamped(float x1, float y1, float x2, float y2, float x);
float interpolateMsg(const char *msg, float x1, float y1, float x2, float y2, float x);
namespace priv {
float interpolate2d(const char *msg, float value, const float bin[], const float values[], int size);
struct BinResult
{
size_t Idx;
float Frac;
};
/**
* @brief Finds the location of a value in the bin array.
*
* @param value The value to find in the bins.
* @return A result containing the index to the left of the value,
* and how far from (idx) to (idx + 1) the value is located.
*/
template<class TBin, int TSize>
BinResult getBin(float value, const TBin (&bins)[TSize]) {
// Enforce numeric only (int, float, uintx_t, etc)
static_assert(std::is_arithmetic_v<TBin>, "Table bins must be an arithmetic type");
// Enforce that there are enough bins to make sense (what does one bin even mean?)
static_assert(TSize >= 2);
// Handle NaN
if (cisnan(value)) {
return { 0, 0.0f };
}
// Handle off-scale low
if (value <= bins[0]) {
return { 0, 0.0f };
}
// Handle off-scale high
if (value >= bins[TSize - 1]) {
return { TSize - 2, 1.0f };
}
size_t idx = 0;
// Find the last index less than the searched value
// Linear search for now, maybe binary search in future
// after collecting real perf data
for (idx = 0; idx < TSize - 1; idx++) {
if (bins[idx + 1] > value) {
break;
}
}
float low = bins[idx];
float high = bins[idx + 1];
// Compute how far along the bin we are
// (0.0f = left side, 1.0f = right side)
float fraction = (value - low) / (high - low);
return { idx, fraction };
}
template <int TSize>
float interpolate2d(const char *msg, const float value, const float (&bin)[TSize], const float (&values)[TSize]) {
return priv::interpolate2d(msg, value, bin, values, TSize);
static float linterp(float low, float high, float frac)
{
return high * frac + low * (1 - frac);
}
} // namespace priv
template <class TBin, class TValue, int TSize>
float interpolate2d(const char *msg, const float value, const TBin (&bin)[TSize], const TValue (&values)[TSize]) {
// Enforce numeric only (int, float, uintx_t, etc)
static_assert(std::is_arithmetic_v<TBin>, "Table values must be an arithmetic type");
auto b = priv::getBin(value, bin);
// Convert to float as we read it out
float low = static_cast<float>(values[b.Idx]);
float high = static_cast<float>(values[b.Idx + 1]);
float frac = b.Frac;
return priv::linterp(low, high, frac);
}
template <class TBin, class TValue, int TSize>
float interpolate2d(const float value, const TBin (&bin)[TSize], const TValue (&values)[TSize]) {
return interpolate2d("", value, bin, values);
}
int needInterpolationLogging(void);

View File

@ -110,3 +110,171 @@ TEST(misc, testSetTableValue) {
ASSERT_FLOAT_EQ(1.4, config.cltFuelCorr[0]);
}
class TestTable2dSmall : public ::testing::Test
{
protected:
float bins[2];
float values[2];
void SetUp() override
{
// This test maps [20,30] -> [100,200]
copyArray(bins, { 20.0f, 30.0f });
copyArray(values, { 100.0f, 200.0f });
}
};
TEST_F(TestTable2dSmall, OffScaleLow)
{
EXPECT_FLOAT_EQ(interpolate2d(10, bins, values), 100);
}
TEST_F(TestTable2dSmall, OffScaleHigh)
{
EXPECT_FLOAT_EQ(interpolate2d(40, bins, values), 200);
}
TEST_F(TestTable2dSmall, EdgeLeft)
{
EXPECT_FLOAT_EQ(interpolate2d(20, bins, values), 100);
}
TEST_F(TestTable2dSmall, EdgeRight)
{
EXPECT_FLOAT_EQ(interpolate2d(30, bins, values), 200);
}
TEST_F(TestTable2dSmall, Middle)
{
EXPECT_FLOAT_EQ(interpolate2d(25, bins, values), 150);
}
TEST_F(TestTable2dSmall, NanInput)
{
EXPECT_FLOAT_EQ(interpolate2d(NAN, bins, values), 100);
}
class Test2dTableMassive : public ::testing::Test
{
static constexpr int Count = 2500;
protected:
float bins[Count];
float values[Count];
void SetUp() override
{
float x = 0;
for (size_t i = 0; i < std::size(bins); i++)
{
x += 0.1f;
bins[i] = x;
values[i] = x * x;
}
}
};
TEST_F(Test2dTableMassive, t)
{
float x = 0;
float maxErr = -1;
for (size_t i = 0; i < 25000; i++)
{
x += 0.01f;
float actual = x * x;
float lookup = interpolate2d(x, bins, values);
float err = std::abs(actual - lookup);
if (err > maxErr)
{
maxErr = err;
}
}
EXPECT_LT(maxErr, 0.01);
}
// Helper for BinResult type
#define EXPECT_BINRESULT(actual, expectedIdx, expectedFrac) \
{ \
auto ___temp___ = actual; \
EXPECT_EQ(___temp___.Idx, expectedIdx); \
EXPECT_NEAR(___temp___.Frac, expectedFrac, expectedFrac / 1e4); \
}
// Test with small bins: only two values
static const float smallBins[] = { 10, 20 };
TEST(TableBinsSmall, OffScaleLeft)
{
EXPECT_BINRESULT(priv::getBin(5, smallBins), 0, 0);
}
TEST(TableBinsSmall, OffScaleRight)
{
EXPECT_BINRESULT(priv::getBin(25, smallBins), 0, 1);
}
TEST(TableBinsSmall, EdgeLeft)
{
EXPECT_BINRESULT(priv::getBin(10, smallBins), 0, 0);
}
TEST(TableBinsSmall, EdgeRight)
{
EXPECT_BINRESULT(priv::getBin(10, smallBins), 0, 0);
}
TEST(TableBinsSmall, Middle)
{
EXPECT_BINRESULT(priv::getBin(15, smallBins), 0, 0.5f);
}
TEST(TableBinsSmall, NanInput)
{
EXPECT_BINRESULT(priv::getBin(NAN, smallBins), 0, 0);
}
// Test with medium bins, 3 items
static const float bigBins[] = { 10, 20, 30 };
TEST(TableBinsBig, OffScaleLow)
{
EXPECT_BINRESULT(priv::getBin(5, bigBins), 0, 0);
}
TEST(TableBinsBig, OffScaleHigh)
{
EXPECT_BINRESULT(priv::getBin(35, bigBins), 1, 1.0f);
}
TEST(TableBinsBig, NearMiddleLow)
{
EXPECT_BINRESULT(priv::getBin(19.99f, bigBins), 0, 0.999f);
}
TEST(TableBinsBig, NearMiddleExact)
{
EXPECT_BINRESULT(priv::getBin(20.0f, bigBins), 1, 0);
}
TEST(TableBinsBig, NearMiddleHigh)
{
EXPECT_BINRESULT(priv::getBin(20.01f, bigBins), 1, 0.001f);
}
TEST(TableBinsBig, LeftMiddle)
{
EXPECT_BINRESULT(priv::getBin(15.0f, bigBins), 0, 0.5f);
}
TEST(TableBinsBig, RightMiddle)
{
EXPECT_BINRESULT(priv::getBin(25.0f, bigBins), 1, 0.5f);
}