From 2518e95fb07c3d1885f27da735ded2227bc904de Mon Sep 17 00:00:00 2001 From: Jack May Date: Wed, 17 Apr 2019 11:28:26 -0700 Subject: [PATCH] Add bench-exchange (#3826) --- Cargo.lock | 58 ++ Cargo.toml | 1 + bench-exchange/Cargo.toml | 38 ++ bench-exchange/src/bench.rs | 1099 ++++++++++++++++++++++++++++++ bench-exchange/src/cli.rs | 173 +++++ bench-exchange/src/main.rs | 73 ++ bench-exchange/src/order_book.rs | 138 ++++ 7 files changed, 1580 insertions(+) create mode 100644 bench-exchange/Cargo.toml create mode 100644 bench-exchange/src/bench.rs create mode 100644 bench-exchange/src/cli.rs create mode 100644 bench-exchange/src/main.rs create mode 100644 bench-exchange/src/order_book.rs diff --git a/Cargo.lock b/Cargo.lock index 6e95ac0ac..255ed363c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1435,6 +1435,16 @@ dependencies = [ "version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "num-derive" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 0.4.27 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 0.6.11 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 0.15.29 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "num-integer" version = "0.1.39" @@ -2198,6 +2208,35 @@ dependencies = [ "untrusted 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "solana-bench-exchange" +version = "0.14.0" +dependencies = [ + "bincode 1.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "bs58 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", + "clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)", + "env_logger 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", + "itertools 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", + "num-derive 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.90 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_derive 1.0.90 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)", + "solana 0.14.0", + "solana-client 0.14.0", + "solana-drone 0.14.0", + "solana-exchange-api 0.14.0", + "solana-exchange-program 0.14.0", + "solana-logger 0.14.0", + "solana-metrics 0.14.0", + "solana-netutil 0.14.0", + "solana-sdk 0.14.0", + "untrusted 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", + "ws 0.7.9 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "solana-bench-streamer" version = "0.14.0" @@ -3275,6 +3314,23 @@ dependencies = [ "winapi-util 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "ws" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "byteorder 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)", + "httparse 1.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", + "mio 0.6.16 (registry+https://github.com/rust-lang/crates.io-index)", + "mio-extras 2.0.5 (registry+https://github.com/rust-lang/crates.io-index)", + "rand 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", + "sha1 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", + "slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", + "url 1.7.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "ws" version = "0.8.0" @@ -3477,6 +3533,7 @@ dependencies = [ "checksum nix 0.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46f0f3210768d796e8fa79ec70ee6af172dacbe7147f5e69be5240a47778302b" "checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945" "checksum nom 4.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "22293d25d3f33a8567cc8a1dc20f40c7eeb761ce83d0fcca059858580790cac3" +"checksum num-derive 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "d9fe8fcafd1b86a37ce8a1cfa15ae504817e0c8c2e7ad42767371461ac1d316d" "checksum num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)" = "e83d528d2677f0518c570baf2b7abdcf0cd2d248860b68507bdcb3e91d4c0cea" "checksum num-traits 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31" "checksum num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3a5d7cc97d6d30d8b9bc8fa19bf45349ffe46241e8816f50f62f6d6aaabee1" @@ -3624,6 +3681,7 @@ dependencies = [ "checksum winapi-util 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7168bab6e1daee33b4557efd0e95d5ca70a03706d39fa5f3fe7a236f584b03c9" "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" "checksum wincolor 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "561ed901ae465d6185fa7864d63fbd5720d0ef718366c9a4dc83cf6170d7e9ba" +"checksum ws 0.7.9 (registry+https://github.com/rust-lang/crates.io-index)" = "329d3e6dd450a9c5c73024e1047f0be7e24121a68484eb0b5368977bee3cf8c3" "checksum ws 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcacc3ba9c1ee43e3fd0846a25489ff22f8906e90775d51b6edbae4b95d71f4" "checksum ws2_32-sys 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d59cefebd0c892fa2dd6de581e937301d8552cb44489cdff035c6187cb63fa5e" "checksum xattr 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "244c3741f4240ef46274860397c7c74e50eb23624996930e484c16679633a54c" diff --git a/Cargo.toml b/Cargo.toml index 9abf44a54..a0e961b03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ ".", + "bench-exchange", "bench-streamer", "bench-tps", "drone", diff --git a/bench-exchange/Cargo.toml b/bench-exchange/Cargo.toml new file mode 100644 index 000000000..3e84634bd --- /dev/null +++ b/bench-exchange/Cargo.toml @@ -0,0 +1,38 @@ +[package] +authors = ["Solana Maintainers "] +edition = "2018" +name = "solana-bench-exchange" +version = "0.14.0" +repository = "https://github.com/solana-labs/solana" +license = "Apache-2.0" +homepage = "https://solana.com/" + +[dependencies] +bs58 = "0.2.0" +clap = "2.32.0" +bincode = "1.1.2" +env_logger = "0.6.0" +itertools = "0.8.0" +log = "0.4.6" +num-traits = "0.2" +num-derive = "0.2" +rayon = "1.0.3" +serde = "1.0.87" +serde_derive = "1.0.87" +serde_json = "1.0.38" +# solana-runtime = { path = "../solana/runtime"} +solana = { path = "../core", version = "0.14.0" } +solana-client = { path = "../client", version = "0.14.0" } +solana-drone = { path = "../drone", version = "0.14.0" } +solana-exchange-api = { path = "../instruction-processors/exchange_api", version = "0.14.0" } +solana-exchange-program = { path = "../instruction-processors/exchange_program", version = "0.14.0" } +solana-logger = { path = "../logger", version = "0.14.0" } +solana-metrics = { path = "../metrics", version = "0.14.0" } +solana-netutil = { path = "../netutil", version = "0.14.0" } +solana-sdk = { path = "../sdk", version = "0.14.0" } +ws = "0.7.9" +untrusted = "0.6.2" + +[features] +cuda = ["solana/cuda"] +erasure = [] \ No newline at end of file diff --git a/bench-exchange/src/bench.rs b/bench-exchange/src/bench.rs new file mode 100644 index 000000000..ad58d8d70 --- /dev/null +++ b/bench-exchange/src/bench.rs @@ -0,0 +1,1099 @@ +#![allow(clippy::useless_attribute)] + +use crate::order_book::*; +use itertools::izip; +use log::*; +use rayon::prelude::*; +use solana::gen_keys::GenKeys; +use solana_drone::drone::request_airdrop_transaction; +use solana_exchange_api::exchange_instruction; +use solana_exchange_api::exchange_state::*; +use solana_exchange_api::id; +use solana_sdk::client::Client; +use solana_sdk::client::{AsyncClient, SyncClient}; +use solana_sdk::pubkey::Pubkey; +use solana_sdk::signature::{Keypair, KeypairUtil}; +use solana_sdk::system_instruction; +use solana_sdk::transaction::Transaction; +use std::cmp; +use std::collections::VecDeque; +use std::net::SocketAddr; +use std::process::exit; +use std::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering}; +use std::sync::mpsc::{channel, Receiver, Sender}; +use std::sync::{Arc, RwLock}; +use std::thread::sleep; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use std::{mem, thread}; + +// TODO Chunk length as specified results in a bunch of failures, divide by 10 helps... +// Assume 4MB network buffers, and 512 byte packets +const CHUNK_LEN: usize = 4 * 1024 * 1024 / 512 / 10; + +// Maximum system transfers per transaction +const MAX_TRANSFERS_PER_TX: u64 = 4; + +// Interval between fetching a new blockhash +const BLOCKHASH_RENEW_PERIOD_S: u64 = 30; + +pub type SharedTransactions = Arc>>>; + +pub struct Config { + pub identity: Keypair, + pub threads: usize, + pub duration: Duration, + pub trade_delay: u64, + pub fund_amount: u64, + pub batch_size: usize, + pub account_groups: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + identity: Keypair::new(), + threads: 4, + duration: Duration::new(u64::max_value(), 0), + trade_delay: 0, + fund_amount: 100_000, + batch_size: 10, + account_groups: 100, + } + } +} + +#[derive(Default)] +pub struct SampleStats { + /// Maximum TPS reported by this node + pub tps: f64, + /// Total time taken for those txs + pub tx_time: Duration, + /// Total transactions reported by this node + pub tx_count: u64, +} + +pub fn do_bench_exchange(client_ctors: Vec, config: Config) +where + F: Fn() -> T, + F: 'static + std::marker::Sync + std::marker::Send, + T: Client, +{ + let Config { + identity, + threads, + duration, + trade_delay, + fund_amount, + batch_size, + account_groups, + } = config; + let accounts_in_groups = batch_size * account_groups; + let exit_signal = Arc::new(AtomicBool::new(false)); + let client_ctors: Vec<_> = client_ctors.into_iter().map(Arc::new).collect(); + let client = client_ctors[0](); + + let total_keys = accounts_in_groups as u64 * 5; + info!("Generating {:?} keys", total_keys); + let mut keypairs = generate_keypairs(total_keys); + let trader_signers: Vec<_> = keypairs + .drain(0..accounts_in_groups) + .map(Arc::new) + .collect(); + let swapper_signers: Vec<_> = keypairs + .drain(0..accounts_in_groups) + .map(Arc::new) + .collect(); + let src_pubkeys: Vec<_> = keypairs + .drain(0..accounts_in_groups) + .map(|keypair| keypair.pubkey()) + .collect(); + let dst_pubkeys: Vec<_> = keypairs + .drain(0..accounts_in_groups) + .map(|keypair| keypair.pubkey()) + .collect(); + let profit_pubkeys: Vec<_> = keypairs + .drain(0..accounts_in_groups) + .map(|keypair| keypair.pubkey()) + .collect(); + + info!("Fund trader accounts"); + fund_keys(&client, &identity, &trader_signers, fund_amount); + info!("Fund swapper accounts"); + fund_keys(&client, &identity, &swapper_signers, fund_amount); + + info!("Create {:?} source token accounts", src_pubkeys.len()); + create_token_accounts(&client, &trader_signers, &src_pubkeys); + info!("Create {:?} destination token accounts", dst_pubkeys.len()); + create_token_accounts(&client, &trader_signers, &dst_pubkeys); + info!("Create {:?} profit token accounts", profit_pubkeys.len()); + create_token_accounts(&client, &swapper_signers, &profit_pubkeys); + + // Collect the max transaction rate and total tx count seen (single node only) + let sample_stats = Arc::new(RwLock::new(Vec::new())); + let sample_period = 1; // in seconds + info!("Sampling clients for tps every {} s", sample_period); + + let sample_threads: Vec<_> = client_ctors + .iter() + .map(|ctor| { + let exit_signal = exit_signal.clone(); + let sample_stats = sample_stats.clone(); + let client_ctor = ctor.clone(); + thread::spawn(move || { + sample_tx_count(&exit_signal, &sample_stats, sample_period, &client_ctor) + }) + }) + .collect(); + + let shared_txs: SharedTransactions = Arc::new(RwLock::new(VecDeque::new())); + let shared_tx_active_thread_count = Arc::new(AtomicIsize::new(0)); + let total_tx_sent_count = Arc::new(AtomicUsize::new(0)); + let s_threads: Vec<_> = (0..threads) + .map(|_| { + let exit_signal = exit_signal.clone(); + let shared_txs = shared_txs.clone(); + let shared_tx_active_thread_count = shared_tx_active_thread_count.clone(); + let total_tx_sent_count = total_tx_sent_count.clone(); + let client_ctor = client_ctors[0].clone(); + thread::spawn(move || { + do_tx_transfers( + &exit_signal, + &shared_txs, + &shared_tx_active_thread_count, + &total_tx_sent_count, + &client_ctor, + ) + }) + }) + .collect(); + + trace!("Start swapper thread"); + let (swapper_sender, swapper_receiver) = channel(); + let swapper_thread = { + let exit_signal = exit_signal.clone(); + let shared_txs = shared_txs.clone(); + let shared_tx_active_thread_count = shared_tx_active_thread_count.clone(); + let client_ctor = client_ctors[0].clone(); + thread::spawn(move || { + swapper( + &exit_signal, + &swapper_receiver, + &shared_txs, + &shared_tx_active_thread_count, + &swapper_signers, + &profit_pubkeys, + batch_size, + account_groups, + &client_ctor, + ) + }) + }; + + trace!("Start trader thread"); + let trader_thread = { + let exit_signal = exit_signal.clone(); + let shared_txs = shared_txs.clone(); + let shared_tx_active_thread_count = shared_tx_active_thread_count.clone(); + let client_ctor = client_ctors[0].clone(); + thread::spawn(move || { + trader( + &exit_signal, + &swapper_sender, + &shared_txs, + &shared_tx_active_thread_count, + &trader_signers, + &src_pubkeys, + &dst_pubkeys, + trade_delay, + batch_size, + account_groups, + &client_ctor, + ) + }) + }; + + info!("Requesting and swapping trades"); + sleep(duration); + + exit_signal.store(true, Ordering::Relaxed); + let _ = trader_thread.join(); + let _ = swapper_thread.join(); + for t in s_threads { + let _ = t.join(); + } + for t in sample_threads { + let _ = t.join(); + } + + compute_and_report_stats(&sample_stats, total_tx_sent_count.load(Ordering::Relaxed)); +} + +fn sample_tx_count( + exit_signal: &Arc, + sample_stats: &Arc>>, + sample_period: u64, + client_ctor: &Arc, +) where + F: Fn() -> T, + T: Client, +{ + let client = client_ctor(); + let mut max_tps = 0.0; + let mut total_tx_time; + let mut total_tx_count; + let mut now = Instant::now(); + let start_time = now; + let mut initial_tx_count = client.get_transaction_count().expect("transaction count"); + let first_tx_count = initial_tx_count; + + loop { + let tx_count = client.get_transaction_count().expect("transaction count"); + let duration = now.elapsed(); + now = Instant::now(); + assert!( + tx_count >= initial_tx_count, + "expected tx_count({}) >= initial_tx_count({})", + tx_count, + initial_tx_count + ); + let sample = tx_count - initial_tx_count; + initial_tx_count = tx_count; + + let ns = duration.as_secs() * 1_000_000_000 + u64::from(duration.subsec_nanos()); + let tps = (sample * 1_000_000_000) as f64 / ns as f64; + if tps > max_tps { + max_tps = tps; + } + total_tx_time = start_time.elapsed(); + total_tx_count = tx_count - first_tx_count; + trace!( + "Sampler {:9.2} TPS, Transactions: {:6}, Total transactions: {} over {} s", + tps, + sample, + total_tx_count, + total_tx_time.as_secs(), + ); + + if exit_signal.load(Ordering::Relaxed) { + let stats = SampleStats { + tps: max_tps, + tx_time: total_tx_time, + tx_count: total_tx_count, + }; + sample_stats.write().unwrap().push(stats); + break; + } + sleep(Duration::from_secs(sample_period)); + } +} + +fn do_tx_transfers( + exit_signal: &Arc, + shared_txs: &SharedTransactions, + shared_tx_thread_count: &Arc, + total_tx_sent_count: &Arc, + client_ctor: &Arc, +) where + F: Fn() -> T, + T: Client, +{ + let client = client_ctor(); + let async_client: &AsyncClient = &client; + let mut stats = Stats::default(); + loop { + let txs; + { + let mut shared_txs_wl = shared_txs.write().unwrap(); + txs = shared_txs_wl.pop_front(); + } + match txs { + Some(txs0) => { + let n = txs0.len(); + + shared_tx_thread_count.fetch_add(1, Ordering::Relaxed); + let now = Instant::now(); + for tx in txs0 { + async_client.async_send_transaction(tx).expect("Transfer"); + } + let duration = now.elapsed(); + shared_tx_thread_count.fetch_add(-1, Ordering::Relaxed); + + total_tx_sent_count.fetch_add(n, Ordering::Relaxed); + stats.total += n as u64; + let sent_ns = + duration.as_secs() * 1_000_000_000 + u64::from(duration.subsec_nanos()); + stats.sent_ns += sent_ns; + let rate = (n as f64 / sent_ns as f64) * 1_000_000_000_f64; + if rate > stats.sent_peak_rate { + stats.sent_peak_rate = rate; + } + trace!(" tx {:?} sent {:.2}/s", n, rate); + } + None => { + if exit_signal.load(Ordering::Relaxed) { + info!( + " Thread Transferred {} Txs, avg {:.2}/s peak {:.2}/s", + stats.total, + (stats.total as f64 / stats.sent_ns as f64) * 1_000_000_000_f64, + stats.sent_peak_rate, + ); + break; + } + } + } + } +} + +#[derive(Default)] +struct Stats { + total: u64, + keygen_ns: u64, + keygen_peak_rate: f64, + sign_ns: u64, + sign_peak_rate: f64, + sent_ns: u64, + sent_peak_rate: f64, +} + +struct TradeInfo { + trade_account: Pubkey, + order_info: TradeOrderInfo, +} +#[allow(clippy::too_many_arguments)] +fn swapper( + exit_signal: &Arc, + receiver: &Receiver>, + shared_txs: &SharedTransactions, + shared_tx_active_thread_count: &Arc, + signers: &[Arc], + profit_pubkeys: &[Pubkey], + batch_size: usize, + account_groups: usize, + client_ctor: &Arc, +) where + F: Fn() -> T, + T: Client, +{ + let client = client_ctor(); + let mut stats = Stats::default(); + let mut order_book = OrderBook::default(); + let mut account_group: usize = 0; + let mut one_more_time = true; + let mut blockhash = client + .get_recent_blockhash() + .expect("Failed to get blockhash"); + let mut blockhash_now = UNIX_EPOCH; + 'outer: loop { + if let Ok(trade_infos) = receiver.try_recv() { + let mut tries = 0; + while client + .get_balance(&trade_infos[0].trade_account) + .unwrap_or(0) + == 0 + { + tries += 1; + if tries > 10 { + debug!("Give up waiting, dump batch"); + continue 'outer; + } + debug!("{} waiting for trades batch to clear", tries); + sleep(Duration::from_millis(100)); + } + + trade_infos.iter().for_each(|info| { + order_book + .push(info.trade_account, info.order_info) + .expect("Failed to push to order_book"); + }); + let mut swaps = Vec::new(); + while let Some((to, from)) = order_book.pop() { + swaps.push((to, from)); + if swaps.len() >= batch_size { + break; + } + } + let swaps_size = swaps.len(); + stats.total += swaps_size as u64; + + let now = Instant::now(); + let swap_keys = generate_keypairs(swaps_size as u64); + + let mut to_swap = vec![]; + let start = account_group * swaps_size as usize; + let end = account_group * swaps_size as usize + batch_size as usize; + for (signer, swap, swap_key, profit) in izip!( + signers[start..end].iter(), + swaps, + swap_keys, + profit_pubkeys[start..end].iter(), + ) { + to_swap.push((signer, swap_key, swap, profit)); + } + account_group = (account_group + 1) % account_groups as usize; + let duration = now.elapsed(); + let keypair_ns = + duration.as_secs() * 1_000_000_000 + u64::from(duration.subsec_nanos()); + let rate = (swaps_size as f64 / keypair_ns as f64) * 1_000_000_000_f64; + stats.keygen_ns += keypair_ns; + if rate > stats.keygen_peak_rate { + stats.keygen_peak_rate = rate; + } + trace!("sw {:?} keypairs {:.2} /s", swaps_size, rate); + + let now = Instant::now(); + + // Don't get a blockhash every time + if SystemTime::now() + .duration_since(blockhash_now) + .unwrap() + .as_secs() + > BLOCKHASH_RENEW_PERIOD_S + { + blockhash = client + .get_recent_blockhash() + .expect("Failed to get blockhash"); + blockhash_now = SystemTime::now(); + } + + let to_swap_txs: Vec<_> = to_swap + .par_iter() + .map(|(signer, swap_key, swap, profit)| { + let s: &Keypair = &signer; + let owner = &signer.pubkey(); + let space = mem::size_of::() as u64; + Transaction::new_signed_instructions( + &[s], + vec![ + system_instruction::create_account( + owner, + &swap_key.pubkey(), + 1, + space, + &id(), + ), + exchange_instruction::swap_request( + owner, + &swap_key.pubkey(), + &swap.0.pubkey, + &swap.1.pubkey, + &swap.0.info.dst_account, + &swap.1.info.dst_account, + &profit, + ), + ], + blockhash, + ) + }) + .collect(); + let duration = now.elapsed(); + let sign_ns = duration.as_secs() * 1_000_000_000 + u64::from(duration.subsec_nanos()); + let n = to_swap_txs.len(); + let rate = (n as f64 / sign_ns as f64) * 1_000_000_000_f64; + stats.sign_ns += sign_ns; + if rate > stats.sign_peak_rate { + stats.sign_peak_rate = rate; + } + trace!(" sw {:?} signed {:.2} /s ", n, rate); + + let chunks: Vec<_> = to_swap_txs.chunks(CHUNK_LEN).collect(); + { + let mut shared_txs_wl = shared_txs.write().unwrap(); + for chunk in chunks { + shared_txs_wl.push_back(chunk.to_vec()); + } + } + } + + while shared_tx_active_thread_count.load(Ordering::Relaxed) > 0 { + sleep(Duration::from_millis(100)); + } + + if exit_signal.load(Ordering::Relaxed) { + if !one_more_time { + info!("{} Swaps with batch size {}", stats.total, batch_size); + info!( + " Keygen avg {:.2}/s peak {:.2}/s", + (stats.total as f64 / stats.keygen_ns as f64) * 1_000_000_000_f64, + stats.keygen_peak_rate + ); + info!( + " Signed avg {:.2}/s peak {:.2}/s", + (stats.total as f64 / stats.sign_ns as f64) * 1_000_000_000_f64, + stats.sign_peak_rate + ); + assert_eq!( + order_book.get_num_outstanding().0 + order_book.get_num_outstanding().1, + 0 + ); + break; + } + // Grab any outstanding trades + sleep(Duration::from_secs(2)); + one_more_time = false; + } + } +} + +#[allow(clippy::too_many_arguments)] +fn trader( + exit_signal: &Arc, + sender: &Sender>, + shared_txs: &SharedTransactions, + shared_tx_active_thread_count: &Arc, + signers: &[Arc], + srcs: &[Pubkey], + dsts: &[Pubkey], + delay: u64, + batch_size: usize, + account_groups: usize, + client_ctor: &Arc, +) where + F: Fn() -> T, + T: Client, +{ + let client = client_ctor(); + let mut stats = Stats::default(); + + // TODO Hard coded for now + let pair = TokenPair::AB; + let tokens = 1; + let price = 1000; + let mut account_group: usize = 0; + let mut blockhash = client + .get_recent_blockhash() + .expect("Failed to get blockhash"); + let mut blockhash_now = UNIX_EPOCH; + + loop { + let now = Instant::now(); + let trade_keys = generate_keypairs(batch_size as u64); + + stats.total += batch_size as u64; + + let mut trades = vec![]; + let mut trade_infos = vec![]; + let start = account_group * batch_size as usize; + let end = account_group * batch_size as usize + batch_size as usize; + let mut direction = Direction::To; + for (signer, trade, src, dst) in izip!( + signers[start..end].iter(), + trade_keys, + srcs[start..end].iter(), + dsts[start..end].iter() + ) { + direction = if direction == Direction::To { + Direction::From + } else { + Direction::To + }; + let order_info = TradeOrderInfo { + /// Owner of the trade order + owner: Pubkey::default(), // don't care + direction, + pair, + tokens, + price, + src_account: Pubkey::default(), // don't care + dst_account: *dst, + }; + trade_infos.push(TradeInfo { + trade_account: trade.pubkey(), + order_info, + }); + trades.push((signer, trade.pubkey(), direction, src, dst)); + } + account_group = (account_group + 1) % account_groups as usize; + let duration = now.elapsed(); + let keypair_ns = duration.as_secs() * 1_000_000_000 + u64::from(duration.subsec_nanos()); + let rate = (batch_size as f64 / keypair_ns as f64) * 1_000_000_000_f64; + stats.keygen_ns += keypair_ns; + if rate > stats.keygen_peak_rate { + stats.keygen_peak_rate = rate; + } + trace!("sw {:?} keypairs {:.2} /s", batch_size, rate); + + trades.chunks(CHUNK_LEN).for_each(|chunk| { + let now = Instant::now(); + + // Don't get a blockhash every time + if SystemTime::now() + .duration_since(blockhash_now) + .unwrap() + .as_secs() + > BLOCKHASH_RENEW_PERIOD_S + { + blockhash = client + .get_recent_blockhash() + .expect("Failed to get blockhash"); + blockhash_now = SystemTime::now(); + } + + let trades_txs: Vec<_> = chunk + .par_iter() + .map(|(signer, trade, direction, src, dst)| { + let s: &Keypair = &signer; + let owner = &signer.pubkey(); + let space = mem::size_of::() as u64; + Transaction::new_signed_instructions( + &[s], + vec![ + system_instruction::create_account(owner, trade, 1, space, &id()), + exchange_instruction::trade_request( + owner, trade, *direction, pair, tokens, price, src, dst, + ), + ], + blockhash, + ) + }) + .collect(); + let duration = now.elapsed(); + let sign_ns = duration.as_secs() * 1_000_000_000 + u64::from(duration.subsec_nanos()); + let n = trades_txs.len(); + let rate = (n as f64 / sign_ns as f64) * 1_000_000_000_f64; + stats.sign_ns += sign_ns; + if rate > stats.sign_peak_rate { + stats.sign_peak_rate = rate; + } + trace!(" sw {:?} signed {:.2} /s ", n, rate); + + let chunks: Vec<_> = trades_txs.chunks(CHUNK_LEN).collect(); + { + let mut shared_txs_wl = shared_txs + .write() + .expect("Failed to send tx to transfer threads"); + for chunk in chunks { + shared_txs_wl.push_back(chunk.to_vec()); + } + } + + if delay > 0 { + sleep(Duration::from_millis(delay)); + } + }); + + sender + .send(trade_infos) + .expect("Failed to send trades to swapper"); + + while shared_tx_active_thread_count.load(Ordering::Relaxed) > 0 { + sleep(Duration::from_millis(100)); + } + + if exit_signal.load(Ordering::Relaxed) { + info!("{} Trades with batch size {}", stats.total, batch_size); + info!( + " Keygen avg {:.2}/s peak {:.2}/s", + (stats.total as f64 / stats.keygen_ns as f64) * 1_000_000_000_f64, + stats.keygen_peak_rate + ); + info!( + " Signed avg {:.2}/s peak {:.2}/s", + (stats.total as f64 / stats.sign_ns as f64) * 1_000_000_000_f64, + stats.sign_peak_rate + ); + break; + } + } +} + +fn verify_transfer(sync_client: &T, tx: &Transaction) -> bool +where + T: SyncClient + ?Sized, +{ + for s in &tx.signatures { + if let Ok(Some(_)) = sync_client.get_signature_status(s) { + return true; + } + } + false +} + +pub fn fund_keys(client: &Client, source: &Keypair, dests: &[Arc], lamports: u64) { + let total = lamports * (dests.len() as u64 + 1); + let mut funded: Vec<(&Keypair, u64)> = vec![(source, total)]; + let mut notfunded: Vec<&Arc> = dests.iter().collect(); + + info!( + " Funding {} keys with {} lamports each", + dests.len(), + lamports + ); + while !notfunded.is_empty() { + if funded.is_empty() { + panic!("No funded accounts left to fund remaining"); + } + let mut new_funded: Vec<(&Keypair, u64)> = vec![]; + let mut to_fund = vec![]; + debug!(" Creating from... {}", funded.len()); + for f in &mut funded { + let max_units = cmp::min( + cmp::min(notfunded.len() as u64, MAX_TRANSFERS_PER_TX), + (f.1 - lamports) / lamports, + ); + if max_units == 0 { + continue; + } + let per_unit = ((f.1 - lamports) / lamports / max_units) * lamports; + f.1 -= per_unit * max_units; + let start = notfunded.len() - max_units as usize; + let moves: Vec<_> = notfunded[start..] + .iter() + .map(|k| (k.pubkey(), per_unit)) + .collect(); + notfunded[start..] + .iter() + .for_each(|k| new_funded.push((k, per_unit))); + notfunded.truncate(start); + if !moves.is_empty() { + to_fund.push((f.0, moves)); + } + } + + to_fund.chunks(CHUNK_LEN).for_each(|chunk| { + #[allow(clippy::clone_double_ref)] // sigh + let mut to_fund_txs: Vec<_> = chunk + .par_iter() + .map(|(k, m)| { + ( + k.clone(), + Transaction::new_unsigned_instructions(system_instruction::transfer_many( + &k.pubkey(), + &m, + )), + ) + }) + .collect(); + + let mut retries = 0; + while !to_fund_txs.is_empty() { + let receivers = to_fund_txs + .iter() + .fold(0, |len, (_, tx)| len + tx.message().instructions.len()); + + debug!( + " {} to {} in {} txs", + if retries == 0 { + " Transferring" + } else { + " Retrying" + }, + receivers, + to_fund_txs.len(), + ); + + let blockhash = client.get_recent_blockhash().expect("blockhash"); + to_fund_txs.par_iter_mut().for_each(|(k, tx)| { + tx.sign(&[*k], blockhash); + }); + to_fund_txs.iter().for_each(|(_, tx)| { + client.async_send_transaction(tx.clone()).expect("transfer"); + }); + + let mut waits = 0; + loop { + sleep(Duration::from_millis(50)); + to_fund_txs.retain(|(_, tx)| !verify_transfer(client, &tx)); + if to_fund_txs.is_empty() { + break; + } + debug!( + " {} transactions outstanding, {:?} waits", + to_fund_txs.len(), + waits + ); + waits += 1; + if waits >= 5 { + break; + } + } + if !to_fund_txs.is_empty() { + retries += 1; + debug!(" Retry {:?}", retries); + if retries >= 10 { + error!(" Too many retries, give up"); + exit(1); + } + } + } + }); + funded.append(&mut new_funded); + funded.retain(|(k, b)| { + client.get_balance(&k.pubkey()).unwrap_or(0) > lamports && *b > lamports + }); + debug!(" Funded: {} left: {}", funded.len(), notfunded.len()); + } +} + +pub fn create_token_accounts(client: &Client, signers: &[Arc], accounts: &[Pubkey]) { + let mut notfunded: Vec<(&Arc, &Pubkey)> = signers.iter().zip(accounts).collect(); + + while !notfunded.is_empty() { + notfunded.chunks(CHUNK_LEN).for_each(|chunk| { + let mut to_create_txs: Vec<_> = chunk + .par_iter() + .map(|(signer, new)| { + let owner_id = &signer.pubkey(); + let space = mem::size_of::() as u64; + let create_ix = + system_instruction::create_account(owner_id, new, 1, space, &id()); + let request_ix = exchange_instruction::account_request(owner_id, new); + ( + signer, + Transaction::new_unsigned_instructions(vec![create_ix, request_ix]), + ) + }) + .collect(); + + let accounts = to_create_txs + .iter() + .fold(0, |len, (_, tx)| len + tx.message().instructions.len() / 2); + + debug!( + " Creating {} accounts in {} txs", + accounts, + to_create_txs.len(), + ); + + let mut retries = 0; + while !to_create_txs.is_empty() { + let blockhash = client + .get_recent_blockhash() + .expect("Failed to get blockhash"); + to_create_txs.par_iter_mut().for_each(|(k, tx)| { + let kp: &Keypair = k; + tx.sign(&[kp], blockhash); + }); + to_create_txs.iter().for_each(|(_, tx)| { + client.async_send_transaction(tx.clone()).expect("transfer"); + }); + + let mut waits = 0; + while !to_create_txs.is_empty() { + sleep(Duration::from_millis(50)); + to_create_txs.retain(|(_, tx)| !verify_transfer(client, &tx)); + if to_create_txs.is_empty() { + break; + } + debug!( + " {} transactions outstanding, waits {:?}", + to_create_txs.len(), + waits + ); + waits += 1; + if waits >= 5 { + break; + } + } + + if !to_create_txs.is_empty() { + retries += 1; + debug!(" Retry {:?}", retries); + if retries >= 10 { + error!(" Too many retries, give up"); + exit(1); + } + } + } + }); + + let mut new_notfunded: Vec<(&Arc, &Pubkey)> = vec![]; + for f in ¬funded { + if client.get_balance(&f.1).unwrap_or(0) == 0 { + new_notfunded.push(*f) + } + } + notfunded = new_notfunded; + debug!(" Left: {}", notfunded.len()); + } +} + +fn compute_and_report_stats(maxes: &Arc>>, total_tx_send_count: usize) { + let mut max_tx_count = 0; + let mut max_tx_time = Duration::new(0, 0); + info!("| Max TPS | Total Transactions"); + info!("+---------------+--------------------"); + + for stats in maxes.read().unwrap().iter() { + let maybe_flag = match stats.tx_count { + 0 => "!!!!!", + _ => "", + }; + + info!("| {:13.2} | {} {}", stats.tps, stats.tx_count, maybe_flag); + + if stats.tx_time > max_tx_time { + max_tx_time = stats.tx_time; + } + if stats.tx_count > max_tx_count { + max_tx_count = stats.tx_count; + } + } + info!("+---------------+--------------------"); + + if max_tx_count > total_tx_send_count as u64 { + error!( + "{} more transactions sampled ({}) then were sent ({})", + max_tx_count - total_tx_send_count as u64, + max_tx_count, + total_tx_send_count + ); + } else { + info!( + "{} txs dropped ({:.2}%)", + total_tx_send_count as u64 - max_tx_count, + (total_tx_send_count as u64 - max_tx_count) as f64 / total_tx_send_count as f64 + * 100_f64 + ); + } + info!( + "\tAverage TPS: {}", + max_tx_count as f32 / max_tx_time.as_secs() as f32 + ); +} + +fn generate_keypairs(num: u64) -> Vec { + let mut seed = [0_u8; 32]; + seed.copy_from_slice(&Keypair::new().pubkey().as_ref()); + let mut rnd = GenKeys::new(seed); + rnd.gen_n_keypairs(num) +} + +pub fn airdrop_lamports(client: &Client, drone_addr: &SocketAddr, id: &Keypair, amount: u64) { + let balance = client.get_balance(&id.pubkey()); + let balance = balance.unwrap_or(0); + if balance > amount { + return; + } + + let amount_to_drop = amount - balance; + + info!( + "Airdropping {:?} lamports from {} for {}", + amount_to_drop, + drone_addr, + id.pubkey(), + ); + + let mut tries = 0; + loop { + let blockhash = client + .get_recent_blockhash() + .expect("Failed to get blockhash"); + match request_airdrop_transaction(&drone_addr, &id.pubkey(), amount_to_drop, blockhash) { + Ok(transaction) => { + let signature = client.async_send_transaction(transaction).unwrap(); + + if let Ok(Some(_)) = client.get_signature_status(&signature) { + break; + } + } + Err(err) => { + panic!( + "Error requesting airdrop: {:?} to addr: {:?} amount: {}", + err, drone_addr, amount + ); + } + }; + debug!(" Retry..."); + tries += 1; + if tries > 50 { + error!("Too many retries, give up"); + exit(1); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use solana::cluster_info::FULLNODE_PORT_RANGE; + use solana::fullnode::FullnodeConfig; + use solana::gossip_service::discover_nodes; + use solana::local_cluster::{ClusterConfig, LocalCluster}; + use solana_client::thin_client::create_client; + use solana_client::thin_client::ThinClient; + use solana_drone::drone::run_local_drone; + use std::sync::mpsc::channel; + + #[test] + #[ignore] // TODO Issue #3825 + fn test_exchange_local_cluster() { + solana_logger::setup(); + + const NUM_NODES: usize = 1; + let fullnode_config = FullnodeConfig::default(); + + let mut config = Config::default(); + config.identity = Keypair::new(); + config.threads = 4; + config.duration = Duration::from_secs(5); + config.fund_amount = 100_000; + config.trade_delay = 0; + config.batch_size = 10; + config.account_groups = 100; + let Config { + fund_amount, + batch_size, + account_groups, + .. + } = config; + let accounts_in_groups = batch_size * account_groups; + + let cluster = LocalCluster::new(&ClusterConfig { + node_stakes: vec![100_000; NUM_NODES], + cluster_lamports: 100_000_000_000_000, + fullnode_config, + native_instruction_processors: [( + "solana_exchange_program".to_string(), + solana_exchange_api::id(), + )] + .to_vec(), + ..ClusterConfig::default() + }); + + let drone_keypair = Keypair::new(); + cluster.transfer( + &cluster.funding_keypair, + &drone_keypair.pubkey(), + 2_000_000_000_000, + ); + + let (addr_sender, addr_receiver) = channel(); + run_local_drone(drone_keypair, addr_sender, Some(1_000_000_000_000)); + let drone_addr = addr_receiver.recv_timeout(Duration::from_secs(2)).unwrap(); + + info!("Connecting to the cluster"); + let nodes = + discover_nodes(&cluster.entry_point_info.gossip, NUM_NODES).unwrap_or_else(|err| { + error!("Failed to discover {} nodes: {:?}", NUM_NODES, err); + exit(1); + }); + if nodes.len() < NUM_NODES { + error!( + "Error: Insufficient nodes discovered. Expecting {} or more", + NUM_NODES + ); + exit(1); + } + let client_ctors: Vec<_> = nodes + .iter() + .map(|node| { + let cluster_entrypoint = node.clone(); + let cluster_addrs = cluster_entrypoint.client_facing_addr(); + let client_ctor = + move || -> ThinClient { create_client(cluster_addrs, FULLNODE_PORT_RANGE) }; + client_ctor + }) + .collect(); + + let client = client_ctors[0](); + airdrop_lamports( + &client, + &drone_addr, + &config.identity, + fund_amount * (accounts_in_groups + 1) as u64 * 2, + ); + + do_bench_exchange(client_ctors, config); + } +} diff --git a/bench-exchange/src/cli.rs b/bench-exchange/src/cli.rs new file mode 100644 index 000000000..6347492f8 --- /dev/null +++ b/bench-exchange/src/cli.rs @@ -0,0 +1,173 @@ +use std::net::SocketAddr; +use std::time::Duration; + +use clap::{crate_description, crate_name, crate_version, value_t, App, Arg, ArgMatches}; +use solana_drone::drone::DRONE_PORT; +use solana_sdk::signature::{read_keypair, Keypair, KeypairUtil}; +use untrusted::Input; + +pub struct Config { + pub network_addr: SocketAddr, + pub drone_addr: SocketAddr, + pub identity: Keypair, + pub threads: usize, + pub num_nodes: usize, + pub duration: Duration, + pub trade_delay: u64, + pub fund_amount: u64, + pub batch_size: usize, + pub account_groups: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + network_addr: SocketAddr::from(([127, 0, 0, 1], 8001)), + drone_addr: SocketAddr::from(([127, 0, 0, 1], DRONE_PORT)), + identity: Keypair::new(), + num_nodes: 1, + threads: 4, + duration: Duration::new(u64::max_value(), 0), + trade_delay: 0, + fund_amount: 100_000, + batch_size: 100, + account_groups: 100, + } + } +} + +pub fn build_args<'a, 'b>() -> App<'a, 'b> { + App::new(crate_name!()) + .about(crate_description!()) + .version(crate_version!()) + .arg( + Arg::with_name("network") + .short("n") + .long("network") + .value_name("HOST:PORT") + .takes_value(true) + .required(false) + .default_value("127.0.0.1:8001") + .help("Network's gossip entry point; defaults to 127.0.0.1:8001"), + ) + .arg( + Arg::with_name("drone") + .short("d") + .long("drone") + .value_name("HOST:PORT") + .takes_value(true) + .required(false) + .default_value("127.0.0.1:9900") + .help("Location of the drone; defaults to 127.0.0.1:9900"), + ) + .arg( + Arg::with_name("identity") + .short("i") + .long("identity") + .value_name("PATH") + .takes_value(true) + .help("File containing a client identity (keypair)"), + ) + .arg( + Arg::with_name("threads") + .long("threads") + .value_name("") + .takes_value(true) + .required(false) + .default_value("4") + .help("Number of threads submitting transactions"), + ) + .arg( + Arg::with_name("num-nodes") + .long("num-nodes") + .value_name("NUM") + .takes_value(true) + .required(false) + .default_value("1") + .help("Wait for NUM nodes to converge"), + ) + .arg( + Arg::with_name("duration") + .long("duration") + .value_name("SECS") + .takes_value(true) + .default_value("60") + .help("Seconds to run benchmark, then exit; default is forever"), + ) + .arg( + Arg::with_name("trade-delay") + .long("trade-delay") + .value_name("") + .takes_value(true) + .required(false) + .default_value("0") + .help("Delay between trade requests in milliseconds"), + ) + .arg( + Arg::with_name("fund-amount") + .long("fund-amount") + .value_name("") + .takes_value(true) + .required(false) + .default_value("100000") + .help("Number of lamports to fund to each signer"), + ) + .arg( + Arg::with_name("batch-size") + .long("batch-size") + .value_name("") + .takes_value(true) + .required(false) + .default_value("1000") + .help("Number of bulk trades to submit between trade delays"), + ) + .arg( + Arg::with_name("account-groups") + .long("account-groups") + .value_name("") + .takes_value(true) + .required(false) + .default_value("100") + .help("Number of account groups to cycle for each batch"), + ) +} + +pub fn extract_args<'a>(matches: &ArgMatches<'a>) -> Config { + let mut args = Config::default(); + + args.network_addr = matches + .value_of("network") + .unwrap() + .parse() + .expect("Failed to parse network"); + args.drone_addr = matches + .value_of("drone") + .unwrap() + .parse() + .expect("Failed to parse drone address"); + + if matches.is_present("identity") { + args.identity = read_keypair(matches.value_of("identity").unwrap()) + .expect("can't read client identity"); + } else { + args.identity = { + let seed = [42_u8; 32]; + Keypair::from_seed_unchecked(Input::from(&seed)).unwrap() + }; + } + args.threads = value_t!(matches.value_of("threads"), usize).expect("Failed to parse threads"); + args.num_nodes = + value_t!(matches.value_of("num-nodes"), usize).expect("Failed to parse num-nodes"); + let duration = value_t!(matches.value_of("duration"), u64).expect("Failed to parse duration"); + args.duration = Duration::from_secs(duration); + args.trade_delay = + value_t!(matches.value_of("trade-delay"), u64).expect("Failed to parse trade-delay"); + args.fund_amount = + value_t!(matches.value_of("fund-amount"), u64).expect("Failed to parse fund-amount"); + args.batch_size = + value_t!(matches.value_of("batch-size"), usize).expect("Failed to parse batch-size"); + args.account_groups = value_t!(matches.value_of("account-groups"), usize) + .expect("Failed to parse account-groups"); + + args +} diff --git a/bench-exchange/src/main.rs b/bench-exchange/src/main.rs new file mode 100644 index 000000000..6f314abbe --- /dev/null +++ b/bench-exchange/src/main.rs @@ -0,0 +1,73 @@ +pub mod bench; +mod cli; +pub mod order_book; + +use crate::bench::{airdrop_lamports, do_bench_exchange, Config}; +use log::*; +use solana::cluster_info::FULLNODE_PORT_RANGE; +use solana::gossip_service::discover_nodes; +use solana_client::thin_client::create_client; +use solana_client::thin_client::ThinClient; +use solana_sdk::signature::KeypairUtil; + +fn main() { + solana_logger::setup(); + + let matches = cli::build_args().get_matches(); + let cli_config = cli::extract_args(&matches); + + let cli::Config { + network_addr, + drone_addr, + identity, + threads, + num_nodes, + duration, + trade_delay, + fund_amount, + batch_size, + account_groups, + .. + } = cli_config; + + info!("Connecting to the cluster"); + let nodes = discover_nodes(&network_addr, num_nodes).unwrap_or_else(|_| { + panic!("Failed to discover nodes"); + }); + info!("{} nodes found", nodes.len()); + if nodes.len() < num_nodes { + panic!("Error: Insufficient nodes discovered"); + } + + let client_ctors: Vec<_> = nodes + .iter() + .map(|node| { + let cluster_entrypoint = node.clone(); + let cluster_addrs = cluster_entrypoint.client_facing_addr(); + move || -> ThinClient { create_client(cluster_addrs, FULLNODE_PORT_RANGE) } + }) + .collect(); + + info!("Funding keypair: {}", identity.pubkey()); + + let client = client_ctors[0](); + let accounts_in_groups = batch_size * account_groups; + airdrop_lamports( + &client, + &drone_addr, + &identity, + fund_amount * (accounts_in_groups + 1) as u64 * 2, + ); + + let config = Config { + identity, + threads, + duration, + trade_delay, + fund_amount, + batch_size, + account_groups, + }; + + do_bench_exchange(client_ctors, config); +} diff --git a/bench-exchange/src/order_book.rs b/bench-exchange/src/order_book.rs new file mode 100644 index 000000000..7af79f5f3 --- /dev/null +++ b/bench-exchange/src/order_book.rs @@ -0,0 +1,138 @@ +use itertools::EitherOrBoth::{Both, Left, Right}; +use itertools::Itertools; +use log::*; +use solana_exchange_api::exchange_state::*; +use solana_sdk::pubkey::Pubkey; +use std::cmp::Ordering; +use std::collections::BinaryHeap; +use std::{error, fmt}; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ToOrder { + pub pubkey: Pubkey, + pub info: TradeOrderInfo, +} + +impl Ord for ToOrder { + fn cmp(&self, other: &Self) -> Ordering { + other.info.price.cmp(&self.info.price) + } +} +impl PartialOrd for ToOrder { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct FromOrder { + pub pubkey: Pubkey, + pub info: TradeOrderInfo, +} + +impl Ord for FromOrder { + fn cmp(&self, other: &Self) -> Ordering { + self.info.price.cmp(&other.info.price) + } +} +impl PartialOrd for FromOrder { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Default)] +pub struct OrderBook { + // TODO scale to x token types + to_ab: BinaryHeap, + from_ab: BinaryHeap, +} +impl fmt::Display for OrderBook { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "+-Order Book--------------------------+-------------------------------------+" + )?; + for (i, it) in self + .to_ab + .iter() + .zip_longest(self.from_ab.iter()) + .enumerate() + { + match it { + Both(to, from) => writeln!( + f, + "| T AB {:8} for {:8}/{:8} | F AB {:8} for {:8}/{:8} |{}", + to.info.tokens, + SCALER, + to.info.price, + from.info.tokens, + SCALER, + from.info.price, + i + )?, + Left(to) => writeln!( + f, + "| T AB {:8} for {:8}/{:8} | |{}", + to.info.tokens, SCALER, to.info.price, i + )?, + Right(from) => writeln!( + f, + "| | F AB {:8} for {:8}/{:8} |{}", + from.info.tokens, SCALER, from.info.price, i + )?, + } + } + write!( + f, + "+-------------------------------------+-------------------------------------+" + )?; + Ok(()) + } +} + +impl OrderBook { + // TODO + // pub fn cancel(&mut self, pubkey: Pubkey) -> Result<(), Box> { + // Ok(()) + // } + pub fn push( + &mut self, + pubkey: Pubkey, + info: TradeOrderInfo, + ) -> Result<(), Box> { + check_trade(info.direction, info.tokens, info.price)?; + match info.direction { + Direction::To => { + self.to_ab.push(ToOrder { pubkey, info }); + } + Direction::From => { + self.from_ab.push(FromOrder { pubkey, info }); + } + } + Ok(()) + } + pub fn pop(&mut self) -> Option<(ToOrder, FromOrder)> { + if let Some(pair) = Self::pop_pair(&mut self.to_ab, &mut self.from_ab) { + return Some(pair); + } + None + } + pub fn get_num_outstanding(&self) -> (usize, usize) { + (self.to_ab.len(), self.from_ab.len()) + } + + fn pop_pair( + to_ab: &mut BinaryHeap, + from_ab: &mut BinaryHeap, + ) -> Option<(ToOrder, FromOrder)> { + let to = to_ab.peek()?; + let from = from_ab.peek()?; + if from.info.price < to.info.price { + debug!("Trade not viable"); + return None; + } + let to = to_ab.pop()?; + let from = from_ab.pop()?; + Some((to, from)) + } +}