diff --git a/src/storage/kv/oracle.rs b/src/storage/kv/oracle.rs index 651cd7a7..3a7d742c 100644 --- a/src/storage/kv/oracle.rs +++ b/src/storage/kv/oracle.rs @@ -178,6 +178,24 @@ impl SnapshotIsolation { } } + // Check write conflicts + for key in txn.write_set.keys() { + if let Some(last_entry) = txn.write_set.get(key).and_then(|entries| entries.last()) { + match current_snapshot.get(&key[..].into()) { + Ok((_, version)) => { + // Detect if another transaction has written to this key + if version > last_entry.version { + return Err(Error::TransactionReadConflict); + } + } + Err(Error::KeyNotFound) => { + continue; + } + Err(e) => return Err(e), + } + } + } + let ts = self.next_tx_id.load(Ordering::SeqCst); self.increment_ts(); Ok(ts) @@ -236,32 +254,44 @@ impl CommitTracker { /// Checks if a transaction has conflicts with committed transactions. /// It acquires a lock on the read set and checks if there are any conflict keys in the read set. fn has_conflict(&self, txn: &Transaction) -> bool { - if txn.read_set.is_empty() { - false - } else { - // For each object in the changeset of the already committed transactions, check if it lies - // within the predicates of the current transaction. - for committed_tx in self.committed_transactions.iter() { - if committed_tx.ts > txn.read_ts { - for conflict_key in committed_tx.conflict_keys.iter() { - for rs_entry in txn.read_key_ranges.iter() { - if key_in_range(conflict_key, &rs_entry.start, &rs_entry.end) { - return true; - } - } - } - } - } - + // For each object in the changeset of the already committed transactions, check if it lies + // within the predicates of the current transaction. + let predicate_conflict = if !txn.read_key_ranges.is_empty() { self.committed_transactions .iter() - .filter(|committed_txn| committed_txn.ts > txn.read_ts) - .any(|committed_txn| { - txn.read_set - .iter() - .any(|read| committed_txn.conflict_keys.contains(&read.key)) + .filter(|committed_tx| committed_tx.ts > txn.read_ts) + .any(|committed_tx| { + committed_tx.conflict_keys.iter().any(|conflict_key| { + txn.read_key_ranges.iter().any(|rs_entry| { + key_in_range(conflict_key, &rs_entry.start, &rs_entry.end) + }) + }) }) - } + } else { + false + }; + + let read_set_conflict = self + .committed_transactions + .iter() + .filter(|committed_tx| committed_tx.ts > txn.read_ts) + .any(|committed_tx| { + txn.read_set + .iter() + .any(|read| committed_tx.conflict_keys.contains(&read.key)) + }); + + let write_set_conflict = self + .committed_transactions + .iter() + .filter(|committed_tx| committed_tx.ts > txn.read_ts) + .any(|committed_tx| { + txn.write_set + .keys() + .any(|write_key| committed_tx.conflict_keys.contains(write_key)) + }); + + predicate_conflict || read_set_conflict || write_set_conflict } } diff --git a/src/storage/kv/transaction.rs b/src/storage/kv/transaction.rs index b0333df6..f4ee4502 100644 --- a/src/storage/kv/transaction.rs +++ b/src/storage/kv/transaction.rs @@ -85,14 +85,16 @@ pub(crate) struct WriteSetEntry { pub(crate) e: Entry, savepoint_no: u32, seqno: u32, + pub(crate) version: u64, } impl WriteSetEntry { - pub(crate) fn new(e: Entry, savepoint_no: u32, seqno: u32) -> Self { + pub(crate) fn new(e: Entry, savepoint_no: u32, seqno: u32, version: u64) -> Self { Self { e, savepoint_no, seqno, + version, } } } @@ -346,7 +348,7 @@ impl Transaction { // Set the transaction's latest savepoint number and add it to the write set. let key = e.key.clone(); let write_seqno = self.next_write_seqno(); - let ws_entry = WriteSetEntry::new(e, self.savepoints, write_seqno); + let ws_entry = WriteSetEntry::new(e, self.savepoints, write_seqno, self.read_ts); match self.write_set.entry(key) { HashEntry::Occupied(mut oe) => { let entries = oe.get_mut(); @@ -1039,7 +1041,7 @@ mod tests { }); } - // blind writes should succeed if key wasn't read first + // blind writes should not succeed { let mut txn1 = store.begin().unwrap(); let mut txn2 = store.begin().unwrap(); @@ -1048,11 +1050,14 @@ mod tests { txn2.set(&key1, &value2).unwrap(); txn1.commit().await.unwrap(); - txn2.commit().await.unwrap(); - - let mut txn3 = store.begin().unwrap(); - let val = txn3.get(&key1).unwrap().unwrap(); - assert_eq!(val, value2.as_ref()); + assert!(match txn2.commit().await { + Err(err) => { + matches!(err, Error::TransactionReadConflict) + } + _ => { + false + } + }); } // read conflict when the read key was updated by another transaction @@ -1254,7 +1259,7 @@ mod tests { } } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn mvcc_serialized_snapshot_isolation_scan() { mvcc_with_scan_tests(true).await; }