token-swap: Add pool token conversion interface (#934)

* Add pool token conversion interface for all curves

This reverts commit 8400bc7bfe4fcc18580d8f81cbb19a4ef5a437ff.

* Improve tests

* Run cargo fmt
This commit is contained in:
Jon Cinque 2020-12-09 11:30:22 +01:00 committed by GitHub
parent c0f5ff182c
commit 3dcb1c5665
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 111 additions and 81 deletions

10
Cargo.lock generated
View File

@ -709,13 +709,12 @@ dependencies = [
[[package]]
name = "curve25519-dalek"
version = "2.1.0"
source = "git+https://github.com/garious/curve25519-dalek?rev=60efef3553d6bf3d7f3b09b5f97acd54d72529ff#60efef3553d6bf3d7f3b09b5f97acd54d72529ff"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d85653f070353a16313d0046f173f70d1aadd5b42600a14de626f0dfb3473a5"
dependencies = [
"borsh",
"byteorder",
"digest 0.8.1",
"rand_core",
"serde",
"subtle 2.2.3",
"zeroize",
]
@ -723,12 +722,13 @@ dependencies = [
[[package]]
name = "curve25519-dalek"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d85653f070353a16313d0046f173f70d1aadd5b42600a14de626f0dfb3473a5"
source = "git+https://github.com/garious/curve25519-dalek?rev=60efef3553d6bf3d7f3b09b5f97acd54d72529ff#60efef3553d6bf3d7f3b09b5f97acd54d72529ff"
dependencies = [
"borsh",
"byteorder",
"digest 0.8.1",
"rand_core",
"serde",
"subtle 2.2.3",
"zeroize",
]

View File

@ -101,8 +101,10 @@ impl SwapCurve {
pub fn trading_tokens_to_pool_tokens(
&self,
source_amount: u128,
swap_source_amount: u128,
swap_token_a_amount: u128,
swap_token_b_amount: u128,
pool_supply: u128,
trade_direction: TradeDirection,
fees: &Fees,
) -> Option<u128> {
// Get the trading fee incurred if the owner fee is swapped for the other side
@ -110,8 +112,10 @@ impl SwapCurve {
let source_amount = source_amount.checked_sub(trade_fee)?;
self.calculator.trading_tokens_to_pool_tokens(
source_amount,
swap_source_amount,
swap_token_a_amount,
swap_token_b_amount,
pool_supply,
trade_direction,
)
}
}

View File

@ -85,12 +85,19 @@ pub trait CurveCalculator: Debug + DynPack {
&self,
pool_tokens: u128,
pool_token_supply: u128,
swap_token_amount: u128,
) -> Option<u128> {
pool_tokens
.checked_mul(swap_token_amount)?
.checked_div(pool_token_supply)
.and_then(map_zero_to_none)
swap_token_a_amount: u128,
swap_token_b_amount: u128,
) -> Option<TradingTokenResult> {
let token_a_amount = pool_tokens
.checked_mul(swap_token_a_amount)?
.checked_div(pool_token_supply)?;
let token_b_amount = pool_tokens
.checked_mul(swap_token_b_amount)?
.checked_div(pool_token_supply)?;
Some(TradingTokenResult {
token_a_amount,
token_b_amount,
})
}
/// Get the amount of pool tokens for the given amount of token A or B
@ -99,9 +106,15 @@ pub trait CurveCalculator: Debug + DynPack {
fn trading_tokens_to_pool_tokens(
&self,
source_amount: u128,
swap_source_amount: u128,
swap_token_a_amount: u128,
swap_token_b_amount: u128,
pool_supply: u128,
trade_direction: TradeDirection,
) -> Option<u128> {
let swap_source_amount = match trade_direction {
TradeDirection::AtoB => swap_token_a_amount,
TradeDirection::BtoA => swap_token_b_amount,
};
pool_supply
.checked_mul(source_amount)?
.checked_div(swap_source_amount)?

View File

@ -101,26 +101,35 @@ mod tests {
assert_eq!(calculator.new_pool_supply(), INITIAL_SWAP_POOL_AMOUNT);
}
fn check_pool_token_rate(token_a: u128, deposit: u128, supply: u128, expected_a: u128) {
fn check_pool_token_rate(
token_a: u128,
token_b: u128,
deposit: u128,
supply: u128,
expected_a: u128,
expected_b: u128,
) {
let calculator = ConstantProductCurve {};
let results = calculator
.pool_tokens_to_trading_tokens(deposit, supply, token_a)
.pool_tokens_to_trading_tokens(deposit, supply, token_a, token_b)
.unwrap();
assert_eq!(results, expected_a);
assert_eq!(results.token_a_amount, expected_a);
assert_eq!(results.token_b_amount, expected_b);
}
#[test]
fn trading_token_conversion() {
check_pool_token_rate(2, 5, 10, 1);
check_pool_token_rate(10, 5, 10, 5);
check_pool_token_rate(5, 5, 10, 2);
check_pool_token_rate(5, 5, 10, 2);
check_pool_token_rate(2, 49, 5, 10, 1, 24);
check_pool_token_rate(100, 202, 5, 101, 4, 10);
check_pool_token_rate(5, 501, 2, 10, 1, 100);
}
#[test]
fn fail_trading_token_conversion() {
let calculator = ConstantProductCurve {};
let results = calculator.pool_tokens_to_trading_tokens(5, 10, u128::MAX);
let results = calculator.pool_tokens_to_trading_tokens(5, 10, u128::MAX, 0);
assert!(results.is_none());
let results = calculator.pool_tokens_to_trading_tokens(5, 10, 0, u128::MAX);
assert!(results.is_none());
}

View File

@ -183,28 +183,37 @@ mod tests {
assert_eq!(calculator.new_pool_supply(), INITIAL_SWAP_POOL_AMOUNT);
}
fn check_pool_token_rate(token_a: u128, deposit: u128, supply: u128, expected_a: u128) {
fn check_pool_token_rate(
token_a: u128,
token_b: u128,
deposit: u128,
supply: u128,
expected_a: u128,
expected_b: u128,
) {
let amp = 1;
let calculator = StableCurve { amp };
let results = calculator
.pool_tokens_to_trading_tokens(deposit, supply, token_a)
.pool_tokens_to_trading_tokens(deposit, supply, token_a, token_b)
.unwrap();
assert_eq!(results, expected_a);
assert_eq!(results.token_a_amount, expected_a);
assert_eq!(results.token_b_amount, expected_b);
}
#[test]
fn trading_token_conversion() {
check_pool_token_rate(2, 5, 10, 1);
check_pool_token_rate(10, 5, 10, 5);
check_pool_token_rate(5, 5, 10, 2);
check_pool_token_rate(5, 5, 10, 2);
check_pool_token_rate(2, 49, 5, 10, 1, 24);
check_pool_token_rate(100, 202, 5, 101, 4, 10);
check_pool_token_rate(5, 501, 2, 10, 1, 100);
}
#[test]
fn fail_trading_token_conversion() {
let amp = 1;
let calculator = StableCurve { amp };
let results = calculator.pool_tokens_to_trading_tokens(5, 10, u128::MAX);
let results = calculator.pool_tokens_to_trading_tokens(5, 10, u128::MAX, 0);
assert!(results.is_none());
let results = calculator.pool_tokens_to_trading_tokens(5, 10, 0, u128::MAX);
assert!(results.is_none());
}

View File

@ -365,12 +365,20 @@ impl Processor {
// mint pool tokens equivalent to the owner fee
let source_account =
Self::unpack_token_account(swap_source_info, &token_swap.token_program_id)?;
let destination_account =
Self::unpack_token_account(swap_destination_info, &token_swap.token_program_id)?;
let (swap_token_a_amount, swap_token_b_amount) = match trade_direction {
TradeDirection::AtoB => (source_account.amount, destination_account.amount),
TradeDirection::BtoA => (destination_account.amount, source_account.amount),
};
let mut pool_token_amount = token_swap
.swap_curve
.trading_tokens_to_pool_tokens(
result.owner_fee,
to_u128(source_account.amount)?,
to_u128(swap_token_a_amount)?,
to_u128(swap_token_b_amount)?,
to_u128(pool_mint.supply)?,
trade_direction,
&token_swap.fees,
)
.ok_or(SwapError::FeeCalculationFailure)?;
@ -470,28 +478,28 @@ impl Processor {
let calculator = token_swap.swap_curve.calculator;
let token_a_amount = calculator
let results = calculator
.pool_tokens_to_trading_tokens(
pool_token_amount,
pool_mint_supply,
to_u128(token_a.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
let token_a_amount = to_u64(token_a_amount)?;
if token_a_amount > maximum_token_a_amount {
return Err(SwapError::ExceededSlippage.into());
}
let token_b_amount = calculator
.pool_tokens_to_trading_tokens(
pool_token_amount,
pool_mint_supply,
to_u128(token_b.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
let token_b_amount = to_u64(token_b_amount)?;
let token_a_amount = to_u64(results.token_a_amount)?;
if token_a_amount > maximum_token_a_amount {
return Err(SwapError::ExceededSlippage.into());
}
if token_a_amount == 0 {
return Err(SwapError::ZeroTradingTokens.into());
}
let token_b_amount = to_u64(results.token_b_amount)?;
if token_b_amount > maximum_token_b_amount {
return Err(SwapError::ExceededSlippage.into());
}
if token_b_amount == 0 {
return Err(SwapError::ZeroTradingTokens.into());
}
Self::token_transfer(
swap_info.key,
@ -591,30 +599,29 @@ impl Processor {
let pool_token_amount = to_u128(pool_token_amount)?
.checked_sub(withdraw_fee)
.ok_or(SwapError::CalculationFailure)?;
let pool_mint_supply = to_u128(pool_mint.supply)?;
let token_a_amount = calculator
let results = calculator
.pool_tokens_to_trading_tokens(
pool_token_amount,
pool_mint_supply,
to_u128(pool_mint.supply)?,
to_u128(token_a.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
let token_a_amount = to_u64(token_a_amount)?;
if token_a_amount < minimum_token_a_amount {
return Err(SwapError::ExceededSlippage.into());
}
let token_b_amount = calculator
.pool_tokens_to_trading_tokens(
pool_token_amount,
pool_mint_supply,
to_u128(token_b.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
let token_b_amount = to_u64(token_b_amount)?;
let token_a_amount = to_u64(results.token_a_amount)?;
if token_a_amount < minimum_token_a_amount {
return Err(SwapError::ExceededSlippage.into());
}
if token_a_amount == 0 {
return Err(SwapError::ZeroTradingTokens.into());
}
let token_b_amount = to_u64(results.token_b_amount)?;
if token_b_amount < minimum_token_b_amount {
return Err(SwapError::ExceededSlippage.into());
}
if token_b_amount == 0 {
return Err(SwapError::ZeroTradingTokens.into());
}
Self::token_transfer(
swap_info.key,
@ -3552,41 +3559,33 @@ mod tests {
let pool_mint =
spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap();
let withdraw_fee = accounts.fees.owner_withdraw_fee(withdraw_amount).unwrap();
let withdraw_token_a_amount = accounts
let results = accounts
.swap_curve
.calculator
.pool_tokens_to_trading_tokens(
withdraw_amount - withdraw_fee,
pool_mint.supply.try_into().unwrap(),
swap_token_a.amount.try_into().unwrap(),
)
.unwrap();
let withdraw_token_b_amount = accounts
.swap_curve
.calculator
.pool_tokens_to_trading_tokens(
withdraw_amount - withdraw_fee,
pool_mint.supply.try_into().unwrap(),
swap_token_b.amount.try_into().unwrap(),
)
.unwrap();
assert_eq!(
swap_token_a.amount,
token_a_amount - to_u64(withdraw_token_a_amount).unwrap()
token_a_amount - to_u64(results.token_a_amount).unwrap()
);
assert_eq!(
swap_token_b.amount,
token_b_amount - to_u64(withdraw_token_b_amount).unwrap()
token_b_amount - to_u64(results.token_b_amount).unwrap()
);
let token_a = spl_token::state::Account::unpack(&token_a_account.data).unwrap();
assert_eq!(
token_a.amount,
initial_a + to_u64(withdraw_token_a_amount).unwrap()
initial_a + to_u64(results.token_a_amount).unwrap()
);
let token_b = spl_token::state::Account::unpack(&token_b_account.data).unwrap();
assert_eq!(
token_b.amount,
initial_b + to_u64(withdraw_token_b_amount).unwrap()
initial_b + to_u64(results.token_b_amount).unwrap()
);
let pool_account = spl_token::state::Account::unpack(&pool_account.data).unwrap();
assert_eq!(
@ -3638,33 +3637,25 @@ mod tests {
spl_token::state::Account::unpack(&accounts.token_b_account.data).unwrap();
let pool_mint =
spl_token::state::Mint::unpack(&accounts.pool_mint_account.data).unwrap();
let token_a_amount = accounts
let results = accounts
.swap_curve
.calculator
.pool_tokens_to_trading_tokens(
pool_fee_amount.try_into().unwrap(),
pool_mint.supply.try_into().unwrap(),
swap_token_a.amount.try_into().unwrap(),
)
.unwrap();
let token_b_amount = accounts
.swap_curve
.calculator
.pool_tokens_to_trading_tokens(
pool_fee_amount.try_into().unwrap(),
pool_mint.supply.try_into().unwrap(),
swap_token_b.amount.try_into().unwrap(),
)
.unwrap();
let token_a = spl_token::state::Account::unpack(&token_a_account.data).unwrap();
assert_eq!(
token_a.amount,
TryInto::<u64>::try_into(token_a_amount).unwrap()
TryInto::<u64>::try_into(results.token_a_amount).unwrap()
);
let token_b = spl_token::state::Account::unpack(&token_b_account.data).unwrap();
assert_eq!(
token_b.amount,
TryInto::<u64>::try_into(token_b_amount).unwrap()
TryInto::<u64>::try_into(results.token_b_amount).unwrap()
);
}
}
@ -3762,7 +3753,9 @@ mod tests {
.trading_tokens_to_pool_tokens(
results.owner_fee,
token_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(),
initial_supply.try_into().unwrap(),
TradeDirection::AtoB,
&fees,
)
.unwrap();
@ -3835,8 +3828,10 @@ mod tests {
let second_fee = swap_curve
.trading_tokens_to_pool_tokens(
results.owner_fee,
token_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(),
initial_supply.try_into().unwrap(),
TradeDirection::BtoA,
&fees,
)
.unwrap();