fix(hermes): add lock for the entire message state

Before this change, there was a lock for each message and it could
cause the updateData for multiple ids have 2 updates (because of the
race with the thread updating the states). This change adds a RwLock
which makes sure that when the entire message state is updating,
no one can read from it while allowing concurrent reads in other
occasions.
This commit is contained in:
Ali Behjati 2023-10-11 12:38:13 +02:00
parent a224199c8f
commit bb922c3a17
3 changed files with 34 additions and 30 deletions

2
hermes/Cargo.lock generated
View File

@ -1898,7 +1898,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermes"
version = "0.3.1"
version = "0.3.2"
dependencies = [
"anyhow",
"async-trait",

View File

@ -1,6 +1,6 @@
[package]
name = "hermes"
version = "0.3.1"
version = "0.3.2"
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
edition = "2021"

View File

@ -12,14 +12,17 @@ use {
anyhow,
Result,
},
dashmap::DashMap,
futures::future::join_all,
pythnet_sdk::messages::{
FeedId,
Message,
MessageType,
},
std::{
collections::BTreeMap,
collections::{
BTreeMap,
HashMap,
},
ops::Bound,
sync::Arc,
},
@ -103,16 +106,16 @@ pub struct Cache {
/// We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
wormhole_merkle_state_cache: Arc<RwLock<BTreeMap<Slot, WormholeMerkleState>>>,
message_cache: Arc<DashMap<MessageStateKey, BTreeMap<MessageStateTime, MessageState>>>,
message_cache: Arc<RwLock<HashMap<MessageStateKey, BTreeMap<MessageStateTime, MessageState>>>>,
cache_size: u64,
}
fn retrieve_message_state(
async fn retrieve_message_state(
cache: &Cache,
key: MessageStateKey,
request_time: RequestTime,
) -> Option<MessageState> {
match cache.message_cache.get(&key) {
match cache.message_cache.read().await.get(&key) {
Some(key_cache) => {
match request_time {
RequestTime::Latest => key_cache.last_key_value().map(|(_, v)| v).cloned(),
@ -154,7 +157,7 @@ fn retrieve_message_state(
impl Cache {
pub fn new(cache_size: u64) -> Self {
Self {
message_cache: Arc::new(DashMap::new()),
message_cache: Arc::new(RwLock::new(HashMap::new())),
accumulator_messages_cache: Arc::new(RwLock::new(BTreeMap::new())),
wormhole_merkle_state_cache: Arc::new(RwLock::new(BTreeMap::new())),
cache_size,
@ -189,20 +192,20 @@ impl AggregateCache for crate::state::State {
async fn message_state_keys(&self) -> Vec<MessageStateKey> {
self.cache
.message_cache
.read()
.await
.iter()
.map(|entry| entry.key().clone())
.map(|entry| entry.0.clone())
.collect::<Vec<_>>()
}
async fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()> {
let mut message_cache = self.cache.message_cache.write().await;
for message_state in message_states {
let key = message_state.key();
let time = message_state.time();
let mut cache = self
.cache
.message_cache
.entry(key)
.or_insert_with(BTreeMap::new);
let cache = message_cache.entry(key).or_insert_with(BTreeMap::new);
cache.insert(time, message_state);
@ -220,24 +223,25 @@ impl AggregateCache for crate::state::State {
request_time: RequestTime,
filter: MessageStateFilter,
) -> Result<Vec<MessageState>> {
ids.into_iter()
.flat_map(|id| {
let request_time = request_time.clone();
let message_types: Vec<MessageType> = match filter {
MessageStateFilter::All => MessageType::iter().collect(),
MessageStateFilter::Only(t) => vec![t],
};
join_all(ids.into_iter().flat_map(|id| {
let request_time = request_time.clone();
let message_types: Vec<MessageType> = match filter {
MessageStateFilter::All => MessageType::iter().collect(),
MessageStateFilter::Only(t) => vec![t],
};
message_types.into_iter().map(move |message_type| {
let key = MessageStateKey {
feed_id: id,
type_: message_type,
};
retrieve_message_state(&self.cache, key, request_time.clone())
.ok_or(anyhow!("Message not found"))
})
message_types.into_iter().map(move |message_type| {
let key = MessageStateKey {
feed_id: id,
type_: message_type,
};
retrieve_message_state(&self.cache, key, request_time.clone())
})
.collect()
}))
.await
.into_iter()
.collect::<Option<Vec<_>>>()
.ok_or(anyhow!("Message not found"))
}
async fn store_accumulator_messages(