From bd3f7818a6c1eee43e0626546fa0ae8bd52394a0 Mon Sep 17 00:00:00 2001 From: Ted Romer Date: Sat, 18 Jan 2025 18:12:12 -0800 Subject: [PATCH] Allow optional timeout keyword parameter on methods that accept in real Firestore Python client. --- mockfirestore/client.py | 7 ++++--- mockfirestore/collection.py | 8 ++++---- mockfirestore/document.py | 12 ++++++------ mockfirestore/query.py | 4 ++-- mockfirestore/transaction.py | 7 +++++-- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/mockfirestore/client.py b/mockfirestore/client.py index 75943bd..e3522b4 100644 --- a/mockfirestore/client.py +++ b/mockfirestore/client.py @@ -1,4 +1,4 @@ -from typing import Iterable, Sequence +from typing import Iterable, Sequence, Union from mockfirestore.collection import CollectionReference from mockfirestore.document import DocumentReference, DocumentSnapshot from mockfirestore.transaction import Transaction @@ -44,7 +44,7 @@ def collection(self, path: str) -> CollectionReference: self._data[name] = {} return CollectionReference(self._data, [name]) - def collections(self) -> Sequence[CollectionReference]: + def collections(self, timeout: Union[float, None]=None) -> Sequence[CollectionReference]: return [CollectionReference(self._data, [collection_name]) for collection_name in self._data] def reset(self): @@ -52,7 +52,8 @@ def reset(self): def get_all(self, references: Iterable[DocumentReference], field_paths=None, - transaction=None) -> Iterable[DocumentSnapshot]: + transaction=None, + timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]: for doc_ref in set(references): yield doc_ref.get() diff --git a/mockfirestore/collection.py b/mockfirestore/collection.py index a15ba30..8a7d4fc 100644 --- a/mockfirestore/collection.py +++ b/mockfirestore/collection.py @@ -23,12 +23,12 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: set_by_path(self._data, new_path, {}) return DocumentReference(self._data, new_path, parent=self) - def get(self) -> Iterable[DocumentSnapshot]: + def get(self, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]: warnings.warn('Collection.get is deprecated, please use Collection.stream', category=DeprecationWarning) return self.stream() - def add(self, document_data: Dict, document_id: str = None) \ + def add(self, document_data: Dict, document_id: str = None, timeout: Union[float, None] = None) \ -> Tuple[Timestamp, DocumentReference]: if document_id is None: document_id = document_data.get('id', generate_random_string()) @@ -73,13 +73,13 @@ def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) query = Query(self, end_at=(document_fields_or_snapshot, False)) return query - def list_documents(self, page_size: Optional[int] = None) -> Sequence[DocumentReference]: + def list_documents(self, page_size: Optional[int] = None, timeout: Union[float, None] = None) -> Sequence[DocumentReference]: docs = [] for key in get_by_path(self._data, self._path): docs.append(self.document(key)) return docs - def stream(self, transaction=None) -> Iterable[DocumentSnapshot]: + def stream(self, transaction=None, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]: for key in sorted(get_by_path(self._data, self._path)): doc_snapshot = self.document(key).get() yield doc_snapshot diff --git a/mockfirestore/document.py b/mockfirestore/document.py index 24aa433..a50dbf1 100644 --- a/mockfirestore/document.py +++ b/mockfirestore/document.py @@ -1,7 +1,7 @@ from copy import deepcopy from functools import reduce import operator -from typing import List, Dict, Any +from typing import List, Dict, Any, Union from mockfirestore import NotFound from mockfirestore._helpers import ( Timestamp, Document, Store, get_by_path, set_by_path, delete_by_path @@ -63,13 +63,13 @@ def __init__(self, data: Store, path: List[str], def id(self): return self._path[-1] - def get(self) -> DocumentSnapshot: + def get(self, timeout: Union[float, None]=None) -> DocumentSnapshot: return DocumentSnapshot(self, get_by_path(self._data, self._path)) - def delete(self): + def delete(self, timeout: Union[float, None]=None): delete_by_path(self._data, self._path) - def set(self, data: Dict, merge=False): + def set(self, data: Dict, merge=False, timeout: Union[float, None]=None): if merge: try: self.update(deepcopy(data)) @@ -78,14 +78,14 @@ def set(self, data: Dict, merge=False): else: set_by_path(self._data, self._path, deepcopy(data)) - def update(self, data: Dict[str, Any]): + def update(self, data: Dict[str, Any], timeout: Union[float, None]=None): document = get_by_path(self._data, self._path) if document == {}: raise NotFound('No document to update: {}'.format(self._path)) apply_transformations(document, deepcopy(data)) - def collection(self, name) -> 'CollectionReference': + def collection(self, name, timeout: Union[float, None]=None) -> 'CollectionReference': from mockfirestore.collection import CollectionReference document = get_by_path(self._data, self._path) new_path = self._path + [name] diff --git a/mockfirestore/query.py b/mockfirestore/query.py index c1cb6bb..a5d5bd5 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -24,7 +24,7 @@ def __init__(self, parent: 'CollectionReference', projection=None, for field_filter in field_filters: self._add_field_filter(*field_filter) - def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: + def stream(self, transaction=None, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]: doc_snapshots = self.parent.stream() for field, compare, value in self._field_filters: @@ -52,7 +52,7 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: return iter(doc_snapshots) - def get(self) -> Iterator[DocumentSnapshot]: + def get(self, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]: warnings.warn('Query.get is deprecated, please use Query.stream', category=DeprecationWarning) return self.stream() diff --git a/mockfirestore/transaction.py b/mockfirestore/transaction.py index 7f06d2d..fac07d2 100644 --- a/mockfirestore/transaction.py +++ b/mockfirestore/transaction.py @@ -5,6 +5,8 @@ from mockfirestore.document import DocumentReference, DocumentSnapshot from mockfirestore.query import Query +from typing import Union + MAX_ATTEMPTS = 5 _MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." _CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." @@ -66,10 +68,11 @@ def _commit(self) -> Iterable[WriteResult]: return results def get_all(self, - references: Iterable[DocumentReference]) -> Iterable[DocumentSnapshot]: + references: Iterable[DocumentReference], + timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]: return self._client.get_all(references) - def get(self, ref_or_query) -> Iterable[DocumentSnapshot]: + def get(self, ref_or_query, timeout: Union[float,None]=None) -> Iterable[DocumentSnapshot]: if isinstance(ref_or_query, DocumentReference): return self._client.get_all([ref_or_query]) elif isinstance(ref_or_query, Query):