From 69da5d59ed6e73f7c9d111ff33208f9964ef2f68 Mon Sep 17 00:00:00 2001 From: Harvir Sahota Date: Sat, 16 Mar 2024 11:34:30 -0700 Subject: [PATCH] Implement DiskWriteTracker --- python/selfie-lib/selfie_lib/WriteTracker.py | 51 +++++++++++++++++++- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/python/selfie-lib/selfie_lib/WriteTracker.py b/python/selfie-lib/selfie_lib/WriteTracker.py index 1f899b47..37e2cd9b 100644 --- a/python/selfie-lib/selfie_lib/WriteTracker.py +++ b/python/selfie-lib/selfie_lib/WriteTracker.py @@ -1,8 +1,12 @@ -from typing import List, Optional +from typing import List, Optional, Generic, TypeVar, Dict from selfie_lib.CommentTracker import SnapshotFileLayout -import inspect +from abc import ABC, abstractmethod +import inspect, threading from functools import total_ordering +T = TypeVar("T") +U = TypeVar("U") + @total_ordering class CallLocation: @@ -85,3 +89,46 @@ def recordCall(callerFileOnly: bool = False) -> CallStack: rest_of_stack = call_locations[1:] return CallStack(location, rest_of_stack) + + +class FirstWrite(Generic[U]): + def __init__(self, snapshot: U, call_stack: CallStack): + self.snapshot = snapshot + self.call_stack = call_stack + + +class WriteTracker(ABC, Generic[T, U]): + def __init__(self): + self.lock = threading.Lock() + self.writes: Dict[T, FirstWrite[U]] = {} + + @abstractmethod + def record(self, key: T, snapshot: U, call: CallStack, layout: SnapshotFileLayout): + pass + + def recordInternal( + self, + key: T, + snapshot: U, + call: CallStack, + layout: SnapshotFileLayout, + allow_multiple_equivalent_writes: bool = True, + ): + with self.lock: + this_write = FirstWrite(snapshot, call) + if key not in self.writes: + self.writes[key] = this_write + return + + existing = self.writes[key] + if existing.snapshot != snapshot: + raise ValueError( + f"Snapshot was set to multiple values!\n first time: {existing.call_stack.location.ide_link(layout)}\n this time: {call.location.ide_link(layout)}\n" + ) + elif not allow_multiple_equivalent_writes: + raise ValueError("Snapshot was set to the same value multiple times.") + + +class DiskWriteTracker(WriteTracker[T, U]): + def record(self, key: T, snapshot: U, call: CallStack, layout: SnapshotFileLayout): + super().recordInternal(key, snapshot, call, layout)