Skip to content

Commit

Permalink
Allow optional timeout keyword parameter on methods that accept in re…
Browse files Browse the repository at this point in the history
…al Firestore Python client.
  • Loading branch information
thromer committed Jan 19, 2025
1 parent b8a8a3f commit bd3f781
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
7 changes: 4 additions & 3 deletions mockfirestore/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Sequence
from typing import Iterable, Sequence, Union
from mockfirestore.collection import CollectionReference
from mockfirestore.document import DocumentReference, DocumentSnapshot
from mockfirestore.transaction import Transaction
Expand Down Expand Up @@ -44,15 +44,16 @@ def collection(self, path: str) -> CollectionReference:
self._data[name] = {}
return CollectionReference(self._data, [name])

def collections(self) -> Sequence[CollectionReference]:
def collections(self, timeout: Union[float, None]=None) -> Sequence[CollectionReference]:
return [CollectionReference(self._data, [collection_name]) for collection_name in self._data]

def reset(self):
self._data = {}

def get_all(self, references: Iterable[DocumentReference],
field_paths=None,
transaction=None) -> Iterable[DocumentSnapshot]:
transaction=None,
timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]:
for doc_ref in set(references):
yield doc_ref.get()

Expand Down
8 changes: 4 additions & 4 deletions mockfirestore/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference:
set_by_path(self._data, new_path, {})
return DocumentReference(self._data, new_path, parent=self)

def get(self) -> Iterable[DocumentSnapshot]:
def get(self, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]:
warnings.warn('Collection.get is deprecated, please use Collection.stream',
category=DeprecationWarning)
return self.stream()

def add(self, document_data: Dict, document_id: str = None) \
def add(self, document_data: Dict, document_id: str = None, timeout: Union[float, None] = None) \
-> Tuple[Timestamp, DocumentReference]:
if document_id is None:
document_id = document_data.get('id', generate_random_string())
Expand Down Expand Up @@ -73,13 +73,13 @@ def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot])
query = Query(self, end_at=(document_fields_or_snapshot, False))
return query

def list_documents(self, page_size: Optional[int] = None) -> Sequence[DocumentReference]:
def list_documents(self, page_size: Optional[int] = None, timeout: Union[float, None] = None) -> Sequence[DocumentReference]:
docs = []
for key in get_by_path(self._data, self._path):
docs.append(self.document(key))
return docs

def stream(self, transaction=None) -> Iterable[DocumentSnapshot]:
def stream(self, transaction=None, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]:
for key in sorted(get_by_path(self._data, self._path)):
doc_snapshot = self.document(key).get()
yield doc_snapshot
12 changes: 6 additions & 6 deletions mockfirestore/document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from copy import deepcopy
from functools import reduce
import operator
from typing import List, Dict, Any
from typing import List, Dict, Any, Union
from mockfirestore import NotFound
from mockfirestore._helpers import (
Timestamp, Document, Store, get_by_path, set_by_path, delete_by_path
Expand Down Expand Up @@ -63,13 +63,13 @@ def __init__(self, data: Store, path: List[str],
def id(self):
return self._path[-1]

def get(self) -> DocumentSnapshot:
def get(self, timeout: Union[float, None]=None) -> DocumentSnapshot:
return DocumentSnapshot(self, get_by_path(self._data, self._path))

def delete(self):
def delete(self, timeout: Union[float, None]=None):
delete_by_path(self._data, self._path)

def set(self, data: Dict, merge=False):
def set(self, data: Dict, merge=False, timeout: Union[float, None]=None):
if merge:
try:
self.update(deepcopy(data))
Expand All @@ -78,14 +78,14 @@ def set(self, data: Dict, merge=False):
else:
set_by_path(self._data, self._path, deepcopy(data))

def update(self, data: Dict[str, Any]):
def update(self, data: Dict[str, Any], timeout: Union[float, None]=None):
document = get_by_path(self._data, self._path)
if document == {}:
raise NotFound('No document to update: {}'.format(self._path))

apply_transformations(document, deepcopy(data))

def collection(self, name) -> 'CollectionReference':
def collection(self, name, timeout: Union[float, None]=None) -> 'CollectionReference':
from mockfirestore.collection import CollectionReference
document = get_by_path(self._data, self._path)
new_path = self._path + [name]
Expand Down
4 changes: 2 additions & 2 deletions mockfirestore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, parent: 'CollectionReference', projection=None,
for field_filter in field_filters:
self._add_field_filter(*field_filter)

def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:
def stream(self, transaction=None, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]:
doc_snapshots = self.parent.stream()

for field, compare, value in self._field_filters:
Expand Down Expand Up @@ -52,7 +52,7 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:

return iter(doc_snapshots)

def get(self) -> Iterator[DocumentSnapshot]:
def get(self, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]:
warnings.warn('Query.get is deprecated, please use Query.stream',
category=DeprecationWarning)
return self.stream()
Expand Down
7 changes: 5 additions & 2 deletions mockfirestore/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from mockfirestore.document import DocumentReference, DocumentSnapshot
from mockfirestore.query import Query

from typing import Union

MAX_ATTEMPTS = 5
_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}."
_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}."
Expand Down Expand Up @@ -66,10 +68,11 @@ def _commit(self) -> Iterable[WriteResult]:
return results

def get_all(self,
references: Iterable[DocumentReference]) -> Iterable[DocumentSnapshot]:
references: Iterable[DocumentReference],
timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]:
return self._client.get_all(references)

def get(self, ref_or_query) -> Iterable[DocumentSnapshot]:
def get(self, ref_or_query, timeout: Union[float,None]=None) -> Iterable[DocumentSnapshot]:
if isinstance(ref_or_query, DocumentReference):
return self._client.get_all([ref_or_query])
elif isinstance(ref_or_query, Query):
Expand Down

0 comments on commit bd3f781

Please sign in to comment.