diff --git a/types/decimal.go b/types/decimal.go index c9b751d9c..03934e214 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -318,6 +318,32 @@ func (d Dec) QuoInt64(i int64) Dec { return Dec{mul} } +// ApproxSqrt returns an approximate sqrt estimation using Newton's method to +// compute square roots x=√d for d > 0. The algorithm starts with some guess and +// computes the sequence of improved guesses until an answer converges to an +// approximate answer. It returns -(sqrt(abs(d)) if input is negative. +func (d Dec) ApproxSqrt() Dec { + if d.IsNegative() { + return d.MulInt64(-1).ApproxSqrt().MulInt64(-1) + } + + if d.IsZero() { + return ZeroDec() + } + + z := OneDec() + // first guess + z = z.Sub((z.Mul(z).Sub(d)).Quo(z.MulInt64(2))) + + // iterate until change is very small + for zNew, delta := z, z; delta.GT(SmallestDec()); z = zNew { + zNew = zNew.Sub((zNew.Mul(zNew).Sub(d)).Quo(zNew.MulInt64(2))) + delta = z.Sub(zNew) + } + + return z +} + // is integer, e.g. decimals are zero func (d Dec) IsInteger() bool { return new(big.Int).Rem(d.Int, precisionReuse).Sign() == 0 diff --git a/types/decimal_test.go b/types/decimal_test.go index 99f9a1066..e51c48fc7 100644 --- a/types/decimal_test.go +++ b/types/decimal_test.go @@ -424,3 +424,22 @@ func TestDecCeil(t *testing.T) { require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input) } } + +func TestApproxSqrt(t *testing.T) { + testCases := []struct { + input Dec + expected Dec + }{ + {OneDec(), OneDec()}, // 1.0 => 1.0 + {NewDecWithPrec(25, 2), NewDecWithPrec(5, 1)}, // 0.25 => 0.5 + {NewDecWithPrec(4, 2), NewDecWithPrec(2, 1)}, // 0.09 => 0.3 + {NewDecFromInt(NewInt(9)), NewDecFromInt(NewInt(3))}, // 9 => 3 + {NewDecFromInt(NewInt(-9)), NewDecFromInt(NewInt(-3))}, // -9 => -3 + {NewDecFromInt(NewInt(2)), NewDecWithPrec(1414213562373095049, 18)}, // 2 => 1.414213562373095049 + } + + for i, tc := range testCases { + res := tc.input.ApproxSqrt() + require.Equal(t, tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input) + } +}