diff --git a/src/data_store.rs b/src/data_store.rs index 70abfcc3f..1c080e6d7 100644 --- a/src/data_store.rs +++ b/src/data_store.rs @@ -5,7 +5,7 @@ // http://opensource.org/licenses/MIT>, at your option. You may not use this file except in // accordance with one or both of these licenses. -use std::collections::{hash_map, HashMap}; +use std::collections::HashMap; use std::ops::Deref; use std::sync::{Arc, Mutex}; @@ -83,34 +83,38 @@ where pub(crate) async fn insert_or_update(&self, object: SO) -> Result { let _guard = self.mutation_lock.lock().await; - let (updated, data_to_persist) = { - let mut locked_objects = self.objects.lock().expect("lock"); - match locked_objects.entry(object.id()) { - hash_map::Entry::Occupied(mut e) => { - let update = object.to_update(); - let updated = e.get_mut().update(update); - let data_to_persist = - if updated { Some(Self::encode_object(e.get())) } else { None }; - (updated, data_to_persist) - }, - hash_map::Entry::Vacant(e) => { - let data_to_persist = Self::encode_object(&object); - e.insert(object); - (true, Some(data_to_persist)) - }, + + let id = object.id(); + let data_to_persist = { + let locked_objects = self.objects.lock().expect("lock"); + if let Some(existing_object) = locked_objects.get(&id) { + let mut updated_object = existing_object.clone(); + let updated = updated_object.update(object.to_update()); + if updated { + Some(updated_object) + } else { + None + } + } else { + Some(object) } }; - if let Some((store_key, data)) = data_to_persist { - self.persist_encoded(store_key, data).await?; + match data_to_persist { + Some(updated_object) => { + self.persist(&updated_object).await?; + let mut locked_objects = self.objects.lock().expect("lock"); + locked_objects.insert(id, updated_object); + Ok(true) + }, + None => Ok(false), } - Ok(updated) } pub(crate) async fn remove(&self, id: &SO::Id) -> Result<(), Error> { let _guard = self.mutation_lock.lock().await; - let removed = { self.objects.lock().expect("lock").remove(id).is_some() }; - if removed { + let should_remove = { self.objects.lock().expect("lock").contains_key(id) }; + if should_remove { let store_key = id.encode_to_hex_str(); KVStore::remove( &*self.kv_store, @@ -131,6 +135,7 @@ where ); Error::PersistenceFailed })?; + self.objects.lock().expect("lock").remove(id); } Ok(()) } @@ -138,38 +143,38 @@ where /// Returns the current in-memory object for `id`. /// /// The async mutation lock serializes writers, but this synchronous reader cannot wait on it. - /// Until store reads are async, callers may temporarily see in-memory state that is either - /// still being persisted or has not yet caught up to a write in progress. + /// Until store reads are async, callers may temporarily see in-memory state that has not yet + /// caught up to a write in progress. pub(crate) fn get(&self, id: &SO::Id) -> Option { self.objects.lock().expect("lock").get(id).cloned() } pub(crate) async fn update(&self, update: SO::Update) -> Result { let _guard = self.mutation_lock.lock().await; - let (res, data_to_persist) = { - let mut locked_objects = self.objects.lock().expect("lock"); - if let Some(object) = locked_objects.get_mut(&update.id()) { - let updated = object.update(update); - if updated { - (DataStoreUpdateResult::Updated, Some(Self::encode_object(object))) - } else { - (DataStoreUpdateResult::Unchanged, None) - } - } else { - (DataStoreUpdateResult::NotFound, None) + let id = update.id(); + let updated_object = { + let locked_objects = self.objects.lock().expect("lock"); + let Some(object) = locked_objects.get(&id) else { + return Ok(DataStoreUpdateResult::NotFound); + }; + let mut updated_object = object.clone(); + if !updated_object.update(update) { + return Ok(DataStoreUpdateResult::Unchanged); } + updated_object }; - if let Some((store_key, data)) = data_to_persist { - self.persist_encoded(store_key, data).await?; - } - Ok(res) + + self.persist(&updated_object).await?; + let mut locked_objects = self.objects.lock().expect("lock"); + locked_objects.insert(id, updated_object); + Ok(DataStoreUpdateResult::Updated) } /// Returns in-memory objects matching `f`. /// /// The async mutation lock serializes writers, but this synchronous reader cannot wait on it. - /// Until store reads are async, callers may temporarily see in-memory state that is either - /// still being persisted or has not yet caught up to a write in progress. + /// Until store reads are async, callers may temporarily see in-memory state that has not yet + /// caught up to a write in progress. pub(crate) fn list_filter bool>(&self, f: F) -> Vec { self.objects.lock().expect("lock").values().filter(f).cloned().collect::>() } @@ -209,8 +214,8 @@ where /// Returns whether the in-memory store contains `id`. /// /// The async mutation lock serializes writers, but this synchronous reader cannot wait on it. - /// Until store reads are async, callers may temporarily see in-memory state that is either - /// still being persisted or has not yet caught up to a write in progress. + /// Until store reads are async, callers may temporarily see in-memory state that has not yet + /// caught up to a write in progress. pub(crate) fn contains_key(&self, id: &SO::Id) -> bool { self.objects.lock().expect("lock").contains_key(id) } @@ -219,6 +224,7 @@ where #[cfg(test)] mod tests { use lightning::impl_writeable_tlv_based; + use lightning::io; use lightning::util::test_utils::TestLogger; use super::*; @@ -281,6 +287,46 @@ mod tests { (2, data, required), }); + struct FailingStore; + + impl KVStore for FailingStore { + fn read( + &self, _primary_namespace: &str, _secondary_namespace: &str, _key: &str, + ) -> impl std::future::Future, io::Error>> + 'static + Send { + async { Err(io::Error::new(io::ErrorKind::Other, "read failed")) } + } + + fn write( + &self, _primary_namespace: &str, _secondary_namespace: &str, _key: &str, _buf: Vec, + ) -> impl std::future::Future> + 'static + Send { + async { Err(io::Error::new(io::ErrorKind::Other, "write failed")) } + } + + fn remove( + &self, _primary_namespace: &str, _secondary_namespace: &str, _key: &str, _lazy: bool, + ) -> impl std::future::Future> + 'static + Send { + async { Err(io::Error::new(io::ErrorKind::Other, "remove failed")) } + } + + fn list( + &self, _primary_namespace: &str, _secondary_namespace: &str, + ) -> impl std::future::Future, io::Error>> + 'static + Send { + async { Err(io::Error::new(io::ErrorKind::Other, "list failed")) } + } + } + + fn new_failing_data_store(objects: Vec) -> DataStore> { + let store: Arc = Arc::new(DynStoreWrapper(FailingStore)); + let logger = Arc::new(TestLogger::new()); + DataStore::new( + objects, + "datastore_test_primary".to_string(), + "datastore_test_secondary".to_string(), + store, + logger, + ) + } + #[tokio::test] async fn data_is_persisted() { let store: Arc = Arc::new(DynStoreWrapper(InMemoryStore::new())); @@ -346,4 +392,54 @@ mod tests { new_iou_object.data[0] += 1; assert_eq!(Ok(true), data_store.insert_or_update(new_iou_object).await); } + + #[tokio::test] + async fn insert_does_not_mutate_memory_if_persist_fails() { + let id = TestObjectId { id: [42u8; 4] }; + let object = TestObject { id, data: [23u8; 3] }; + let data_store = new_failing_data_store(vec![]); + + assert_eq!(Err(Error::PersistenceFailed), data_store.insert(object).await); + assert!(data_store.get(&id).is_none()); + } + + #[tokio::test] + async fn update_does_not_mutate_memory_if_persist_fails() { + let id = TestObjectId { id: [42u8; 4] }; + let object = TestObject { id, data: [23u8; 3] }; + let data_store = new_failing_data_store(vec![object]); + + let update = TestObjectUpdate { id, data: [24u8; 3] }; + assert_eq!(Err(Error::PersistenceFailed), data_store.update(update).await); + assert_eq!(Some(object), data_store.get(&id)); + } + + #[tokio::test] + async fn insert_or_update_does_not_mutate_memory_if_persist_fails() { + let existing_id = TestObjectId { id: [42u8; 4] }; + let existing_object = TestObject { id: existing_id, data: [23u8; 3] }; + let data_store = new_failing_data_store(vec![existing_object]); + + let updated_object = TestObject { id: existing_id, data: [24u8; 3] }; + assert_eq!( + Err(Error::PersistenceFailed), + data_store.insert_or_update(updated_object).await + ); + assert_eq!(Some(existing_object), data_store.get(&existing_id)); + + let new_id = TestObjectId { id: [55u8; 4] }; + let new_object = TestObject { id: new_id, data: [34u8; 3] }; + assert_eq!(Err(Error::PersistenceFailed), data_store.insert_or_update(new_object).await); + assert!(data_store.get(&new_id).is_none()); + } + + #[tokio::test] + async fn remove_does_not_mutate_memory_if_persist_fails() { + let id = TestObjectId { id: [42u8; 4] }; + let object = TestObject { id, data: [23u8; 3] }; + let data_store = new_failing_data_store(vec![object]); + + assert_eq!(Err(Error::PersistenceFailed), data_store.remove(&id).await); + assert_eq!(Some(object), data_store.get(&id)); + } }