diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 41a7cd6..e82071d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.6', '3.7', '3.8', '3.9', '3.10' ] + python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11', '3.12' ] steps: - uses: actions/checkout@v2 diff --git a/README.md b/README.md index 40a4ba9..7c0ad68 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ To install: `pip install mock-firestore` -Python 3.6+ is required for it to work. +Python 3.9+ is required for it to work. ## Usage @@ -105,7 +105,7 @@ transaction.commit() ``` ## Running the tests -* Create and activate a virtualenv with a Python version of at least 3.6 +* Create and activate a virtualenv with a Python version of at least 3.9 * Install dependencies with `pip install -r requirements-dev-minimal.txt` * Run tests with `python -m unittest discover tests -t /` @@ -130,3 +130,4 @@ transaction.commit() * [William Li](https://github.com/wli) * [Ugo Marchand](https://github.com/UgoM) * [Bryce Thornton](https://github.com/brycethornton) +* [Ted Romer](https://github.com/thromer) diff --git a/mockfirestore/client.py b/mockfirestore/client.py index 75943bd..e3522b4 100644 --- a/mockfirestore/client.py +++ b/mockfirestore/client.py @@ -1,4 +1,4 @@ -from typing import Iterable, Sequence +from typing import Iterable, Sequence, Union from mockfirestore.collection import CollectionReference from mockfirestore.document import DocumentReference, DocumentSnapshot from mockfirestore.transaction import Transaction @@ -44,7 +44,7 @@ def collection(self, path: str) -> CollectionReference: self._data[name] = {} return CollectionReference(self._data, [name]) - def collections(self) -> Sequence[CollectionReference]: + def collections(self, timeout: Union[float, None]=None) -> Sequence[CollectionReference]: return [CollectionReference(self._data, [collection_name]) for collection_name in self._data] def reset(self): @@ -52,7 +52,8 @@ def reset(self): def get_all(self, references: Iterable[DocumentReference], field_paths=None, - transaction=None) -> Iterable[DocumentSnapshot]: + transaction=None, + timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]: for doc_ref in set(references): yield doc_ref.get() diff --git a/mockfirestore/collection.py b/mockfirestore/collection.py index 431c074..8a7d4fc 100644 --- a/mockfirestore/collection.py +++ b/mockfirestore/collection.py @@ -23,12 +23,12 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: set_by_path(self._data, new_path, {}) return DocumentReference(self._data, new_path, parent=self) - def get(self) -> Iterable[DocumentSnapshot]: + def get(self, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]: warnings.warn('Collection.get is deprecated, please use Collection.stream', category=DeprecationWarning) return self.stream() - def add(self, document_data: Dict, document_id: str = None) \ + def add(self, document_data: Dict, document_id: str = None, timeout: Union[float, None] = None) \ -> Tuple[Timestamp, DocumentReference]: if document_id is None: document_id = document_data.get('id', generate_random_string()) @@ -41,8 +41,8 @@ def add(self, document_data: Dict, document_id: str = None) \ timestamp = Timestamp.from_now() return timestamp, doc_ref - def where(self, field: str, op: str, value: Any) -> Query: - query = Query(self, field_filters=[(field, op, value)]) + def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> Query: + query = Query(self, field_filters=[Query.make_field_filter(field, op, value, filter)]) return query def order_by(self, key: str, direction: Optional[str] = None) -> Query: @@ -73,13 +73,13 @@ def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot]) query = Query(self, end_at=(document_fields_or_snapshot, False)) return query - def list_documents(self, page_size: Optional[int] = None) -> Sequence[DocumentReference]: + def list_documents(self, page_size: Optional[int] = None, timeout: Union[float, None] = None) -> Sequence[DocumentReference]: docs = [] for key in get_by_path(self._data, self._path): docs.append(self.document(key)) return docs - def stream(self, transaction=None) -> Iterable[DocumentSnapshot]: + def stream(self, transaction=None, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]: for key in sorted(get_by_path(self._data, self._path)): doc_snapshot = self.document(key).get() yield doc_snapshot diff --git a/mockfirestore/document.py b/mockfirestore/document.py index 24aa433..a50dbf1 100644 --- a/mockfirestore/document.py +++ b/mockfirestore/document.py @@ -1,7 +1,7 @@ from copy import deepcopy from functools import reduce import operator -from typing import List, Dict, Any +from typing import List, Dict, Any, Union from mockfirestore import NotFound from mockfirestore._helpers import ( Timestamp, Document, Store, get_by_path, set_by_path, delete_by_path @@ -63,13 +63,13 @@ def __init__(self, data: Store, path: List[str], def id(self): return self._path[-1] - def get(self) -> DocumentSnapshot: + def get(self, timeout: Union[float, None]=None) -> DocumentSnapshot: return DocumentSnapshot(self, get_by_path(self._data, self._path)) - def delete(self): + def delete(self, timeout: Union[float, None]=None): delete_by_path(self._data, self._path) - def set(self, data: Dict, merge=False): + def set(self, data: Dict, merge=False, timeout: Union[float, None]=None): if merge: try: self.update(deepcopy(data)) @@ -78,14 +78,14 @@ def set(self, data: Dict, merge=False): else: set_by_path(self._data, self._path, deepcopy(data)) - def update(self, data: Dict[str, Any]): + def update(self, data: Dict[str, Any], timeout: Union[float, None]=None): document = get_by_path(self._data, self._path) if document == {}: raise NotFound('No document to update: {}'.format(self._path)) apply_transformations(document, deepcopy(data)) - def collection(self, name) -> 'CollectionReference': + def collection(self, name, timeout: Union[float, None]=None) -> 'CollectionReference': from mockfirestore.collection import CollectionReference document = get_by_path(self._data, self._path) new_path = self._path + [name] diff --git a/mockfirestore/query.py b/mockfirestore/query.py index 7a4618d..a5d5bd5 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -24,7 +24,7 @@ def __init__(self, parent: 'CollectionReference', projection=None, for field_filter in field_filters: self._add_field_filter(*field_filter) - def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: + def stream(self, transaction=None, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]: doc_snapshots = self.parent.stream() for field, compare, value in self._field_filters: @@ -52,7 +52,7 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]: return iter(doc_snapshots) - def get(self) -> Iterator[DocumentSnapshot]: + def get(self, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]: warnings.warn('Query.get is deprecated, please use Query.stream', category=DeprecationWarning) return self.stream() @@ -61,8 +61,20 @@ def _add_field_filter(self, field: str, op: str, value: Any): compare = self._compare_func(op) self._field_filters.append((field, compare, value)) - def where(self, field: str, op: str, value: Any) -> 'Query': - self._add_field_filter(field, op, value) + @staticmethod + def make_field_filter(field: Optional[str], op: Optional[str], value: Any = None, filter=None): + if bool(filter) and (bool(field) or bool(op)): + raise ValueError("Can't pass in both the positional arguments and 'filter' at the same time") + if filter: + classname = filter.__class__.__name__ + if not classname.endswith('FieldFilter'): + raise NotImplementedError('composite filters not supported by mockfirestore (got %s)' % classname) + return (filter.field_path, filter.op_string, filter.value) + else: + return (field, op, value) + + def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> 'Query': + self._add_field_filter(*self.make_field_filter(field, op, value, filter)) return self def order_by(self, key: str, direction: Optional[str] = 'ASCENDING') -> 'Query': diff --git a/mockfirestore/transaction.py b/mockfirestore/transaction.py index 7f06d2d..fac07d2 100644 --- a/mockfirestore/transaction.py +++ b/mockfirestore/transaction.py @@ -5,6 +5,8 @@ from mockfirestore.document import DocumentReference, DocumentSnapshot from mockfirestore.query import Query +from typing import Union + MAX_ATTEMPTS = 5 _MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." _CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." @@ -66,10 +68,11 @@ def _commit(self) -> Iterable[WriteResult]: return results def get_all(self, - references: Iterable[DocumentReference]) -> Iterable[DocumentSnapshot]: + references: Iterable[DocumentReference], + timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]: return self._client.get_all(references) - def get(self, ref_or_query) -> Iterable[DocumentSnapshot]: + def get(self, ref_or_query, timeout: Union[float,None]=None) -> Iterable[DocumentSnapshot]: if isinstance(ref_or_query, DocumentReference): return self._client.get_all([ref_or_query]) elif isinstance(ref_or_query, Query): diff --git a/setup.py b/setup.py index f55cb88..5140d2b 100644 --- a/setup.py +++ b/setup.py @@ -5,20 +5,22 @@ setuptools.setup( name="mock-firestore", - version="0.11.0", + version="0.11.1", author="Matt Dowds", description="In-memory implementation of Google Cloud Firestore for use in tests", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/mdowds/mock-firestore", - packages=setuptools.find_packages(), + packages=setuptools.find_packages(exclude=["tests"]), + python_requires=">=3.9", test_suite='', classifiers=[ - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', "License :: OSI Approved :: MIT License", ], -) \ No newline at end of file +) diff --git a/tests/test_collection_reference.py b/tests/test_collection_reference.py index 7f1924a..d234d80 100644 --- a/tests/test_collection_reference.py +++ b/tests/test_collection_reference.py @@ -3,6 +3,13 @@ from mockfirestore import MockFirestore, DocumentReference, DocumentSnapshot, AlreadyExists +class MockFieldFilter: + def __init__(self, field_path, op_string, value=None): + self.field_path = field_path + self.op_string = op_string + self.value = value + + class TestCollectionReference(TestCase): def test_collection_get_returnsDocuments(self): fs = MockFirestore() @@ -193,6 +200,29 @@ def test_collection_whereArrayContainsAny(self): self.assertEqual({'field': ['val4']}, contains_any_docs[0].to_dict()) self.assertEqual({'field': ['val3', 'val2', 'val1']}, contains_any_docs[1].to_dict()) + def test_collection_nestedWhereFieldFilter(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'a': 3}, + 'second': {'a': 4}, + 'third': {'a': 5} + }} + filters = [MockFieldFilter('a', '>=', 4), + MockFieldFilter('a', '>=', 5)] + ge_4_dicts = [d.to_dict() + for d in list(fs.collection('foo').where( + filter=filters[0]).stream()) + ] + ge_5_dicts = [d.to_dict() + for d in list(fs.collection('foo').where( + filter=filters[0]).where( + filter=filters[1]).stream()) + ] + self.assertEqual(len(ge_4_dicts), 2) + self.assertEqual(sorted([o['a'] for o in ge_4_dicts]), [4, 5]) + self.assertEqual(len(ge_5_dicts), 1) + self.assertEqual([o['a'] for o in ge_5_dicts], [5]) + def test_collection_orderBy(self): fs = MockFirestore() fs._data = {'foo': {