Skip to content

Commit

Permalink
WriteTracker and DiskWriteTracker (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
nedtwigg authored Mar 19, 2024
2 parents 6e1e5fe + 69da5d5 commit 8f7528a
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions python/selfie-lib/selfie_lib/WriteTracker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import List, Optional
from typing import List, Optional, Generic, TypeVar, Dict
from selfie_lib.CommentTracker import SnapshotFileLayout
import inspect
from abc import ABC, abstractmethod
import inspect, threading
from functools import total_ordering

T = TypeVar("T")
U = TypeVar("U")


@total_ordering
class CallLocation:
Expand Down Expand Up @@ -56,6 +60,17 @@ def ide_link(self, layout: "SnapshotFileLayout") -> str:
]
return "\n".join(links)

def __eq__(self, other):
if not isinstance(other, CallStack):
return NotImplemented
return (
self.location == other.location
and self.rest_of_stack == other.rest_of_stack
)

def __hash__(self):
return hash((self.location, tuple(self.rest_of_stack)))


def recordCall(callerFileOnly: bool = False) -> CallStack:
stack_frames = inspect.stack()[1:]
Expand All @@ -74,3 +89,46 @@ def recordCall(callerFileOnly: bool = False) -> CallStack:
rest_of_stack = call_locations[1:]

return CallStack(location, rest_of_stack)


class FirstWrite(Generic[U]):
def __init__(self, snapshot: U, call_stack: CallStack):
self.snapshot = snapshot
self.call_stack = call_stack


class WriteTracker(ABC, Generic[T, U]):
def __init__(self):
self.lock = threading.Lock()
self.writes: Dict[T, FirstWrite[U]] = {}

@abstractmethod
def record(self, key: T, snapshot: U, call: CallStack, layout: SnapshotFileLayout):
pass

def recordInternal(
self,
key: T,
snapshot: U,
call: CallStack,
layout: SnapshotFileLayout,
allow_multiple_equivalent_writes: bool = True,
):
with self.lock:
this_write = FirstWrite(snapshot, call)
if key not in self.writes:
self.writes[key] = this_write
return

existing = self.writes[key]
if existing.snapshot != snapshot:
raise ValueError(
f"Snapshot was set to multiple values!\n first time: {existing.call_stack.location.ide_link(layout)}\n this time: {call.location.ide_link(layout)}\n"
)
elif not allow_multiple_equivalent_writes:
raise ValueError("Snapshot was set to the same value multiple times.")


class DiskWriteTracker(WriteTracker[T, U]):
def record(self, key: T, snapshot: U, call: CallStack, layout: SnapshotFileLayout):
super().recordInternal(key, snapshot, call, layout)

0 comments on commit 8f7528a

Please sign in to comment.