Merge pull request #328 from nuttycom/note_id_enum

Use an enum to distinguish between sent and received notes in sqlite backend.
This commit is contained in:
str4d 2021-02-03 09:17:16 +13:00 committed by GitHub
commit b5ee057e03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 52 deletions

View File

@ -137,21 +137,15 @@ pub trait WalletRead {
anchor_height: BlockHeight,
) -> Result<Amount, Self::Error>;
/// Returns the memo for a received note, if it is known and a valid UTF-8 string.
/// Returns the memo for a note, if it is known and a valid UTF-8 string.
///
/// This will return `Ok(None)` if the note identifier does not appear in the
/// database as a known note ID.
fn get_received_memo_as_utf8(
fn get_memo_as_utf8(
&self,
id_note: Self::NoteRef,
) -> Result<Option<String>, Self::Error>;
/// Returns the memo for a sent note, if it is known and a valid UTF-8 string.
///
/// This will return `Ok(None)` if the note identifier does not appear in the
/// database as a known note ID.
fn get_sent_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result<Option<String>, Self::Error>;
/// Returns the note commitment tree at the specified block height.
fn get_commitment_tree(
&self,
@ -459,14 +453,7 @@ pub mod testing {
Ok(Amount::zero())
}
fn get_received_memo_as_utf8(
&self,
_id_note: Self::NoteRef,
) -> Result<Option<String>, Self::Error> {
Ok(None)
}
fn get_sent_memo_as_utf8(
fn get_memo_as_utf8(
&self,
_id_note: Self::NoteRef,
) -> Result<Option<String>, Self::Error> {

View File

@ -347,10 +347,10 @@ where
.note
.nf(&extfvk.fvk.vk, output.witness.position() as u64);
let note_id = up.put_received_note(&output, &Some(nf), tx_row)?;
let received_note_id = up.put_received_note(&output, &Some(nf), tx_row)?;
// Save witness for note.
witnesses.push((note_id, output.witness));
witnesses.push((received_note_id, output.witness));
// Cache nullifier for note (to detect subsequent spends in this scan).
nullifiers.push((output.account, nf));
@ -359,8 +359,8 @@ where
}
// Insert current witnesses into the database.
for (note_id, witness) in witnesses.iter() {
up.insert_witness(*note_id, witness, last_height)?;
for (received_note_id, witness) in witnesses.iter() {
up.insert_witness(*received_note_id, witness, last_height)?;
}
// Prune the stored witnesses (we only expect rollbacks of at most 100 blocks).

View File

@ -16,15 +16,19 @@ pub enum SqliteClientError {
/// The rcm value for a note cannot be decoded to a valid JubJub point.
InvalidNote,
/// The note id associated with a witness being stored corresponds to a
/// sent note, not a received note.
InvalidNoteId,
/// Illegal attempt to reinitialize an already-initialized wallet database.
TableNotEmpty,
/// Bech32 decoding error
Bech32(bech32::Error),
/// Base58 decoding error
Base58(bs58::decode::Error),
/// Illegal attempt to reinitialize an already-initialized wallet database.
TableNotEmpty,
/// Wrapper for rusqlite errors.
DbError(rusqlite::Error),
@ -58,6 +62,7 @@ impl fmt::Display for SqliteClientError {
}
SqliteClientError::IncorrectHRPExtFVK => write!(f, "Incorrect HRP for extfvk"),
SqliteClientError::InvalidNote => write!(f, "Invalid note"),
SqliteClientError::InvalidNoteId => write!(f, "The note ID associated with an inserted witness must correspond to a received note."),
SqliteClientError::Bech32(e) => write!(f, "{}", e),
SqliteClientError::Base58(e) => write!(f, "{}", e),
SqliteClientError::TableNotEmpty => write!(f, "Table is not empty"),

View File

@ -59,11 +59,17 @@ pub mod wallet;
/// A newtype wrapper for sqlite primary key values for the notes
/// table.
#[derive(Debug, Copy, Clone)]
pub struct NoteId(pub i64);
pub enum NoteId {
SentNoteId(i64),
ReceivedNoteId(i64),
}
impl fmt::Display for NoteId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Note {}", self.0)
match self {
NoteId::SentNoteId(id) => write!(f, "Sent Note {}", id),
NoteId::ReceivedNoteId(id) => write!(f, "Received Note {}", id),
}
}
}
@ -200,15 +206,11 @@ impl<P: consensus::Parameters> WalletRead for WalletDB<P> {
wallet::get_balance_at(self, account, anchor_height)
}
fn get_received_memo_as_utf8(
&self,
id_note: Self::NoteRef,
) -> Result<Option<String>, Self::Error> {
wallet::get_received_memo_as_utf8(self, id_note)
}
fn get_sent_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result<Option<String>, Self::Error> {
wallet::get_sent_memo_as_utf8(self, id_note)
fn get_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result<Option<String>, Self::Error> {
match id_note {
NoteId::SentNoteId(id_note) => wallet::get_sent_memo_as_utf8(self, id_note),
NoteId::ReceivedNoteId(id_note) => wallet::get_received_memo_as_utf8(self, id_note),
}
}
fn get_commitment_tree(
@ -315,15 +317,11 @@ impl<'a, P: consensus::Parameters> WalletRead for DataConnStmtCache<'a, P> {
self.wallet_db.get_balance_at(account, anchor_height)
}
fn get_received_memo_as_utf8(
fn get_memo_as_utf8(
&self,
id_note: Self::NoteRef,
) -> Result<Option<String>, Self::Error> {
self.wallet_db.get_received_memo_as_utf8(id_note)
}
fn get_sent_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result<Option<String>, Self::Error> {
self.wallet_db.get_sent_memo_as_utf8(id_note)
self.wallet_db.get_memo_as_utf8(id_note)
}
fn get_commitment_tree(
@ -442,7 +440,11 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
witness: &IncrementalWitness<Node>,
height: BlockHeight,
) -> Result<(), Self::Error> {
wallet::insert_witness(self, note_id, witness, height)
if let NoteId::ReceivedNoteId(rnid) = note_id {
wallet::insert_witness(self, rnid, witness, height)
} else {
Err(SqliteClientError::InvalidNoteId)
}
}
fn prune_witnesses(&mut self, below_height: BlockHeight) -> Result<(), Self::Error> {

View File

@ -214,16 +214,16 @@ pub fn get_balance_at<P>(
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let memo = get_received_memo_as_utf8(&db, NoteId(27));
/// let memo = get_received_memo_as_utf8(&db, 27);
/// ```
pub fn get_received_memo_as_utf8<P>(
wdb: &WalletDB<P>,
id_note: NoteId,
id_note: i64,
) -> Result<Option<String>, SqliteClientError> {
let memo: Vec<_> = wdb.conn.query_row(
"SELECT memo FROM received_notes
WHERE id_note = ?",
&[id_note.0],
&[id_note],
|row| row.get(0),
)?;
@ -255,16 +255,16 @@ pub fn get_received_memo_as_utf8<P>(
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let memo = get_sent_memo_as_utf8(&db, NoteId(12));
/// let memo = get_sent_memo_as_utf8(&db, 12);
/// ```
pub fn get_sent_memo_as_utf8<P>(
wdb: &WalletDB<P>,
id_note: NoteId,
id_note: i64,
) -> Result<Option<String>, SqliteClientError> {
let memo: Vec<_> = wdb.conn.query_row(
"SELECT memo FROM sent_notes
WHERE id_note = ?",
&[id_note.0],
&[id_note],
|row| row.get(0),
)?;
@ -410,7 +410,7 @@ pub fn get_witnesses<P>(
.prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
let witnesses = stmt_fetch_witnesses
.query_map(&[u32::from(block_height)], |row| {
let id_note = NoteId(row.get(0)?);
let id_note = NoteId::ReceivedNoteId(row.get(0)?);
let wdb: Vec<u8> = row.get(1)?;
Ok(IncrementalWitness::read(&wdb[..]).map(|witness| (id_note, witness)))
})
@ -565,13 +565,13 @@ pub fn put_received_note<'a, P, T: ShieldedOutput>(
// It isn't there, so insert our note into the database.
stmts.stmt_insert_received_note.execute_named(&sql_args)?;
Ok(NoteId(stmts.wallet_db.conn.last_insert_rowid()))
Ok(NoteId::ReceivedNoteId(stmts.wallet_db.conn.last_insert_rowid()))
} else {
// It was there, so grab its row number.
stmts
.stmt_select_received_note
.query_row(params![tx_ref, (output.index() as i64)], |row| {
row.get(0).map(NoteId)
row.get(0).map(NoteId::ReceivedNoteId)
})
.map_err(SqliteClientError::from)
}
@ -579,7 +579,7 @@ pub fn put_received_note<'a, P, T: ShieldedOutput>(
pub fn insert_witness<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
note_id: NoteId,
note_id: i64,
witness: &IncrementalWitness<Node>,
height: BlockHeight,
) -> Result<(), SqliteClientError> {
@ -588,7 +588,7 @@ pub fn insert_witness<'a, P>(
stmts
.stmt_insert_witness
.execute(params![note_id.0, u32::from(height), encoded])?;
.execute(params![note_id, u32::from(height), encoded])?;
Ok(())
}