diff --git a/firmware/util/containers/table_helper.h b/firmware/util/containers/table_helper.h index 8a5cf3c534..786d9a1ecd 100644 --- a/firmware/util/containers/table_helper.h +++ b/firmware/util/containers/table_helper.h @@ -29,7 +29,8 @@ template - void init(TValueInit table[TRowNum][TColNum], const TRowInit rowBins[TRowNum], const TColumnInit columnBins[TColNum]) { + void init(TValueInit (&table)[TRowNum][TColNum], + const TRowInit (&rowBins)[TRowNum], const TColumnInit (&columnBins)[TColNum]) { // This splits out here so that we don't need one overload of init per possible combination of table/rows/columns types/dimensions // Overload resolution figures out the correct versions of the functions below to call, some of which have assertions about what's allowed initValues(table); @@ -37,73 +38,60 @@ public: initCols(columnBins); } - float getValue(float xColumn, float yRow) const override { + float getValue(float xColumn, float yRow) const final { if (!m_values) { // not initialized, return 0 return 0; } - - auto row = priv::getBinPtr(yRow * m_rowMult, m_rowBins); - auto col = priv::getBinPtr(xColumn * m_colMult, m_columnBins); - // Orient the table such that (0, 0) is the bottom left corner, - // then the following variable names will make sense - float lowerLeft = getValueAtPosition(row.Idx, col.Idx); - float upperLeft = getValueAtPosition(row.Idx + 1, col.Idx); - float lowerRight = getValueAtPosition(row.Idx, col.Idx + 1); - float upperRight = getValueAtPosition(row.Idx + 1, col.Idx + 1); - - // Interpolate each side by itself - float left = priv::linterp(lowerLeft, upperLeft, row.Frac); - float right = priv::linterp(lowerRight, upperRight, row.Frac); - - // Then interpolate between those - float tableValue = priv::linterp(left, right, col.Frac); - - // Correct by the ratio of table units to "world" units - return tableValue * TValueMultiplier::asFloat(); + return interpolate3d(*m_values, + *m_rowBins, yRow * m_rowMult, + *m_columnBins, xColumn * m_colMult) * + TValueMultiplier::asFloat(); } void setAll(TValue value) { efiAssertVoid(CUSTOM_ERR_6573, m_values, "map not initialized"); - for (size_t i = 0; i < TRowNum * TColNum; i++) { - m_values[i] = value / TValueMultiplier::asFloat(); + for (size_t r = 0; r < TRowNum; r++) { + for (size_t c = 0; c < TColNum; c++) { + (*m_values)[r][c] = value / TValueMultiplier::asFloat(); + } } } private: template - void initValues(scaled_channel table[TRowNum][TColNum]) { + void initValues(scaled_channel (&table)[TRowNum][TColNum]) { static_assert(TValueMultiplier::den == TMult); static_assert(TValueMultiplier::num == 1); - m_values = reinterpret_cast(&table[0][0]); + m_values = reinterpret_cast(&table); } - void initValues(TValue table[TRowNum][TColNum]) { - m_values = &table[0][0]; + void initValues(TValue (&table)[TRowNum][TColNum]) { + m_values = &table; } template - void initRows(const scaled_channel rowBins[TRowNum]) { - m_rowBins = reinterpret_cast(&rowBins[0]); + void initRows(const scaled_channel (&rowBins)[TRowNum]) { + m_rowBins = reinterpret_cast(&rowBins); m_rowMult = TRowMult; } - void initRows(const TRow rowBins[TRowNum]) { - m_rowBins = &rowBins[0]; + void initRows(const TRow (&rowBins)[TRowNum]) { + m_rowBins = &rowBins; m_rowMult = 1; } template - void initCols(const scaled_channel columnBins[TColNum]) { - m_columnBins = reinterpret_cast(&columnBins[0]); + void initCols(const scaled_channel (&columnBins)[TColNum]) { + m_columnBins = reinterpret_cast(&columnBins); m_colMult = TColMult; } - void initCols(const TColumn columnBins[TColNum]) { - m_columnBins = &columnBins[0]; + void initCols(const TColumn (&columnBins)[TColNum]) { + m_columnBins = &columnBins; m_colMult = 1; } @@ -120,10 +108,10 @@ private: } // TODO: should be const - /*const*/ TValue* m_values = nullptr; + /*const*/ TValue (*m_values)[TRowNum][TColNum] = nullptr; - const TRow *m_rowBins = nullptr; - const TColumn *m_columnBins = nullptr; + const TRow (*m_rowBins)[TRowNum] = nullptr; + const TColumn (*m_columnBins)[TColNum] = nullptr; float m_rowMult = 1; float m_colMult = 1; diff --git a/firmware/util/math/interpolation.h b/firmware/util/math/interpolation.h index a08ae811bb..751fcb5aad 100644 --- a/firmware/util/math/interpolation.h +++ b/firmware/util/math/interpolation.h @@ -45,13 +45,13 @@ struct BinResult /** * @brief Finds the location of a value in the bin array. - * + * * @param value The value to find in the bins. * @return A result containing the index to the left of the value, * and how far from (idx) to (idx + 1) the value is located. */ template -BinResult getBinPtr(float value, const TBin* bins) { +BinResult getBin(float value, const TBin (&bins)[TSize]) { // Enforce numeric only (int, float, uintx_t, etc) static_assert(std::is_arithmetic_v, "Table bins must be an arithmetic type"); @@ -90,25 +90,13 @@ BinResult getBinPtr(float value, const TBin* bins) { // Compute how far along the bin we are // (0.0f = left side, 1.0f = right side) float fraction = (value - low) / (high - low); - + return { idx, fraction }; } -template -BinResult getBinPtr(float value, const scaled_channel* bins) { - // Strip off the scaled_channel, and perform the scaling before searching the array - auto binPtrRaw = reinterpret_cast(bins); - return getBinPtr(value * TMult, binPtrRaw); -} - -template -BinResult getBin(float value, const TBin (&bins)[TSize]) { - return getBinPtr(value, &bins[0]); -} - template BinResult getBin(float value, const scaled_channel (&bins)[TSize]) { - return getBinPtr(value, &bins[0]); + return getBin(value * TMult, *reinterpret_cast(&bins)); } static float linterp(float low, float high, float frac) @@ -132,6 +120,29 @@ float interpolate2d(const float value, const TBin (&bin)[TSize], const TValue (& return priv::linterp(low, high, frac); } +template +float interpolate3d(const VType (&table)[RNum][CNum], + const RType (&rowBins)[RNum], float rowValue, + const CType (&colBins)[CNum], float colValue) +{ + auto row = priv::getBin(rowValue, rowBins); + auto col = priv::getBin(colValue, colBins); + + // Orient the table such that (0, 0) is the bottom left corner, + // then the following variable names will make sense + float lowerLeft = table[row.Idx ][col.Idx ]; + float upperLeft = table[row.Idx + 1][col.Idx ]; + float lowerRight = table[row.Idx ][col.Idx + 1]; + float upperRight = table[row.Idx + 1][col.Idx + 1]; + + // Interpolate each side by itself + float left = priv::linterp(lowerLeft, upperLeft, row.Frac); + float right = priv::linterp(lowerRight, upperRight, row.Frac); + + // Then interpolate between those + return priv::linterp(left, right, col.Frac); +} + /** @brief Binary search * @returns the highest index within sorted array such that array[i] is greater than or equal to the parameter * @note If the parameter is smaller than the first element of the array, -1 is returned.