diff --git a/firmware/util/math/interpolation.cpp b/firmware/util/math/interpolation.cpp index cb0f95fc48..6f08144fb1 100644 --- a/firmware/util/math/interpolation.cpp +++ b/firmware/util/math/interpolation.cpp @@ -191,30 +191,6 @@ int findIndex(const float array[], int size, float value) { return findIndexMsg("", array, size, value); } -namespace priv -{ -/** - * @brief One-dimensional table lookup with linear interpolation - * - * @see setLinearCurve() - */ -float interpolate2d(const char *msg, float value, const float bin[], const float values[], int size) { - if (isnan(value)) { - // this unfortunately sometimes happens during functional tests on real hardware - warning(CUSTOM_INTERPOLATE_NAN, "NaN in interpolate2d %s", msg); - return NAN; - } - int index = findIndexMsg(msg, bin, size, value); - - if (index == -1) - return values[0]; - if (index == size - 1) - return values[size - 1]; - - return interpolateMsg(msg, bin[index], values[index], bin[index + 1], values[index + 1], value); -} -} - /** * Sets specified value for specified key in a correction curve * see also setLinearCurve() diff --git a/firmware/util/math/interpolation.h b/firmware/util/math/interpolation.h index 8e1267d9e0..b94ddbc756 100644 --- a/firmware/util/math/interpolation.h +++ b/firmware/util/math/interpolation.h @@ -13,6 +13,8 @@ #include "obd_error_codes.h" #include "error_handling.h" +#include + #ifndef DEBUG_INTERPOLATION #define DEBUG_INTERPOLATION FALSE #endif @@ -27,12 +29,87 @@ float interpolateClamped(float x1, float y1, float x2, float y2, float x); float interpolateMsg(const char *msg, float x1, float y1, float x2, float y2, float x); namespace priv { -float interpolate2d(const char *msg, float value, const float bin[], const float values[], int size); +struct BinResult +{ + size_t Idx; + float Frac; +}; + +/** + * @brief Finds the location of a value in the bin array. + * + * @param value The value to find in the bins. + * @return A result containing the index to the left of the value, + * and how far from (idx) to (idx + 1) the value is located. + */ +template +BinResult getBin(float value, const TBin (&bins)[TSize]) { + // Enforce numeric only (int, float, uintx_t, etc) + static_assert(std::is_arithmetic_v, "Table bins must be an arithmetic type"); + + // Enforce that there are enough bins to make sense (what does one bin even mean?) + static_assert(TSize >= 2); + + // Handle NaN + if (cisnan(value)) { + return { 0, 0.0f }; + } + + // Handle off-scale low + if (value <= bins[0]) { + return { 0, 0.0f }; + } + + // Handle off-scale high + if (value >= bins[TSize - 1]) { + return { TSize - 2, 1.0f }; + } + + size_t idx = 0; + + // Find the last index less than the searched value + // Linear search for now, maybe binary search in future + // after collecting real perf data + for (idx = 0; idx < TSize - 1; idx++) { + if (bins[idx + 1] > value) { + break; + } + } + + float low = bins[idx]; + float high = bins[idx + 1]; + + // Compute how far along the bin we are + // (0.0f = left side, 1.0f = right side) + float fraction = (value - low) / (high - low); + + return { idx, fraction }; } -template -float interpolate2d(const char *msg, const float value, const float (&bin)[TSize], const float (&values)[TSize]) { - return priv::interpolate2d(msg, value, bin, values, TSize); +static float linterp(float low, float high, float frac) +{ + return high * frac + low * (1 - frac); +} +} // namespace priv + +template +float interpolate2d(const char *msg, const float value, const TBin (&bin)[TSize], const TValue (&values)[TSize]) { + // Enforce numeric only (int, float, uintx_t, etc) + static_assert(std::is_arithmetic_v, "Table values must be an arithmetic type"); + + auto b = priv::getBin(value, bin); + + // Convert to float as we read it out + float low = static_cast(values[b.Idx]); + float high = static_cast(values[b.Idx + 1]); + float frac = b.Frac; + + return priv::linterp(low, high, frac); +} + +template +float interpolate2d(const float value, const TBin (&bin)[TSize], const TValue (&values)[TSize]) { + return interpolate2d("", value, bin, values); } int needInterpolationLogging(void); diff --git a/unit_tests/test_basic_math/test_find_index.cpp b/unit_tests/test_basic_math/test_find_index.cpp index 6101ee9e8d..25d543008e 100644 --- a/unit_tests/test_basic_math/test_find_index.cpp +++ b/unit_tests/test_basic_math/test_find_index.cpp @@ -110,3 +110,171 @@ TEST(misc, testSetTableValue) { ASSERT_FLOAT_EQ(1.4, config.cltFuelCorr[0]); } + +class TestTable2dSmall : public ::testing::Test +{ +protected: + float bins[2]; + float values[2]; + + void SetUp() override + { + // This test maps [20,30] -> [100,200] + copyArray(bins, { 20.0f, 30.0f }); + copyArray(values, { 100.0f, 200.0f }); + } +}; + +TEST_F(TestTable2dSmall, OffScaleLow) +{ + EXPECT_FLOAT_EQ(interpolate2d(10, bins, values), 100); +} + +TEST_F(TestTable2dSmall, OffScaleHigh) +{ + EXPECT_FLOAT_EQ(interpolate2d(40, bins, values), 200); +} + +TEST_F(TestTable2dSmall, EdgeLeft) +{ + EXPECT_FLOAT_EQ(interpolate2d(20, bins, values), 100); +} + +TEST_F(TestTable2dSmall, EdgeRight) +{ + EXPECT_FLOAT_EQ(interpolate2d(30, bins, values), 200); +} + +TEST_F(TestTable2dSmall, Middle) +{ + EXPECT_FLOAT_EQ(interpolate2d(25, bins, values), 150); +} + +TEST_F(TestTable2dSmall, NanInput) +{ + EXPECT_FLOAT_EQ(interpolate2d(NAN, bins, values), 100); +} + +class Test2dTableMassive : public ::testing::Test +{ + static constexpr int Count = 2500; + +protected: + float bins[Count]; + float values[Count]; + + void SetUp() override + { + float x = 0; + + for (size_t i = 0; i < std::size(bins); i++) + { + x += 0.1f; + bins[i] = x; + values[i] = x * x; + } + } +}; + +TEST_F(Test2dTableMassive, t) +{ + float x = 0; + float maxErr = -1; + + for (size_t i = 0; i < 25000; i++) + { + x += 0.01f; + + float actual = x * x; + float lookup = interpolate2d(x, bins, values); + + float err = std::abs(actual - lookup); + + if (err > maxErr) + { + maxErr = err; + } + } + + EXPECT_LT(maxErr, 0.01); +} + +// Helper for BinResult type +#define EXPECT_BINRESULT(actual, expectedIdx, expectedFrac) \ + { \ + auto ___temp___ = actual; \ + EXPECT_EQ(___temp___.Idx, expectedIdx); \ + EXPECT_NEAR(___temp___.Frac, expectedFrac, expectedFrac / 1e4); \ + } + +// Test with small bins: only two values +static const float smallBins[] = { 10, 20 }; + +TEST(TableBinsSmall, OffScaleLeft) +{ + EXPECT_BINRESULT(priv::getBin(5, smallBins), 0, 0); +} + +TEST(TableBinsSmall, OffScaleRight) +{ + EXPECT_BINRESULT(priv::getBin(25, smallBins), 0, 1); +} + +TEST(TableBinsSmall, EdgeLeft) +{ + EXPECT_BINRESULT(priv::getBin(10, smallBins), 0, 0); +} + +TEST(TableBinsSmall, EdgeRight) +{ + EXPECT_BINRESULT(priv::getBin(10, smallBins), 0, 0); +} + +TEST(TableBinsSmall, Middle) +{ + EXPECT_BINRESULT(priv::getBin(15, smallBins), 0, 0.5f); +} + +TEST(TableBinsSmall, NanInput) +{ + EXPECT_BINRESULT(priv::getBin(NAN, smallBins), 0, 0); +} + +// Test with medium bins, 3 items +static const float bigBins[] = { 10, 20, 30 }; + +TEST(TableBinsBig, OffScaleLow) +{ + EXPECT_BINRESULT(priv::getBin(5, bigBins), 0, 0); +} + +TEST(TableBinsBig, OffScaleHigh) +{ + EXPECT_BINRESULT(priv::getBin(35, bigBins), 1, 1.0f); +} + + +TEST(TableBinsBig, NearMiddleLow) +{ + EXPECT_BINRESULT(priv::getBin(19.99f, bigBins), 0, 0.999f); +} + +TEST(TableBinsBig, NearMiddleExact) +{ + EXPECT_BINRESULT(priv::getBin(20.0f, bigBins), 1, 0); +} + +TEST(TableBinsBig, NearMiddleHigh) +{ + EXPECT_BINRESULT(priv::getBin(20.01f, bigBins), 1, 0.001f); +} + +TEST(TableBinsBig, LeftMiddle) +{ + EXPECT_BINRESULT(priv::getBin(15.0f, bigBins), 0, 0.5f); +} + +TEST(TableBinsBig, RightMiddle) +{ + EXPECT_BINRESULT(priv::getBin(25.0f, bigBins), 1, 0.5f); +}