Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support filter keyword in where clauses. (#74) #82

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.6', '3.7', '3.8', '3.9', '3.10' ]
python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11', '3.12' ]

steps:
- uses: actions/checkout@v2
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ To install:

`pip install mock-firestore`

Python 3.6+ is required for it to work.
Python 3.9+ is required for it to work.

## Usage

Expand Down Expand Up @@ -105,7 +105,7 @@ transaction.commit()
```

## Running the tests
* Create and activate a virtualenv with a Python version of at least 3.6
* Create and activate a virtualenv with a Python version of at least 3.9
* Install dependencies with `pip install -r requirements-dev-minimal.txt`
* Run tests with `python -m unittest discover tests -t /`

Expand All @@ -130,3 +130,4 @@ transaction.commit()
* [William Li](https://github.com/wli)
* [Ugo Marchand](https://github.com/UgoM)
* [Bryce Thornton](https://github.com/brycethornton)
* [Ted Romer](https://github.com/thromer)
7 changes: 4 additions & 3 deletions mockfirestore/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Sequence
from typing import Iterable, Sequence, Union
from mockfirestore.collection import CollectionReference
from mockfirestore.document import DocumentReference, DocumentSnapshot
from mockfirestore.transaction import Transaction
Expand Down Expand Up @@ -44,15 +44,16 @@ def collection(self, path: str) -> CollectionReference:
self._data[name] = {}
return CollectionReference(self._data, [name])

def collections(self) -> Sequence[CollectionReference]:
def collections(self, timeout: Union[float, None]=None) -> Sequence[CollectionReference]:
return [CollectionReference(self._data, [collection_name]) for collection_name in self._data]

def reset(self):
self._data = {}

def get_all(self, references: Iterable[DocumentReference],
field_paths=None,
transaction=None) -> Iterable[DocumentSnapshot]:
transaction=None,
timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]:
for doc_ref in set(references):
yield doc_ref.get()

Expand Down
12 changes: 6 additions & 6 deletions mockfirestore/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference:
set_by_path(self._data, new_path, {})
return DocumentReference(self._data, new_path, parent=self)

def get(self) -> Iterable[DocumentSnapshot]:
def get(self, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]:
warnings.warn('Collection.get is deprecated, please use Collection.stream',
category=DeprecationWarning)
return self.stream()

def add(self, document_data: Dict, document_id: str = None) \
def add(self, document_data: Dict, document_id: str = None, timeout: Union[float, None] = None) \
-> Tuple[Timestamp, DocumentReference]:
if document_id is None:
document_id = document_data.get('id', generate_random_string())
Expand All @@ -41,8 +41,8 @@ def add(self, document_data: Dict, document_id: str = None) \
timestamp = Timestamp.from_now()
return timestamp, doc_ref

def where(self, field: str, op: str, value: Any) -> Query:
query = Query(self, field_filters=[(field, op, value)])
def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> Query:
query = Query(self, field_filters=[Query.make_field_filter(field, op, value, filter)])
return query

def order_by(self, key: str, direction: Optional[str] = None) -> Query:
Expand Down Expand Up @@ -73,13 +73,13 @@ def end_before(self, document_fields_or_snapshot: Union[dict, DocumentSnapshot])
query = Query(self, end_at=(document_fields_or_snapshot, False))
return query

def list_documents(self, page_size: Optional[int] = None) -> Sequence[DocumentReference]:
def list_documents(self, page_size: Optional[int] = None, timeout: Union[float, None] = None) -> Sequence[DocumentReference]:
docs = []
for key in get_by_path(self._data, self._path):
docs.append(self.document(key))
return docs

def stream(self, transaction=None) -> Iterable[DocumentSnapshot]:
def stream(self, transaction=None, timeout: Union[float, None] = None) -> Iterable[DocumentSnapshot]:
for key in sorted(get_by_path(self._data, self._path)):
doc_snapshot = self.document(key).get()
yield doc_snapshot
12 changes: 6 additions & 6 deletions mockfirestore/document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from copy import deepcopy
from functools import reduce
import operator
from typing import List, Dict, Any
from typing import List, Dict, Any, Union
from mockfirestore import NotFound
from mockfirestore._helpers import (
Timestamp, Document, Store, get_by_path, set_by_path, delete_by_path
Expand Down Expand Up @@ -63,13 +63,13 @@ def __init__(self, data: Store, path: List[str],
def id(self):
return self._path[-1]

def get(self) -> DocumentSnapshot:
def get(self, timeout: Union[float, None]=None) -> DocumentSnapshot:
return DocumentSnapshot(self, get_by_path(self._data, self._path))

def delete(self):
def delete(self, timeout: Union[float, None]=None):
delete_by_path(self._data, self._path)

def set(self, data: Dict, merge=False):
def set(self, data: Dict, merge=False, timeout: Union[float, None]=None):
if merge:
try:
self.update(deepcopy(data))
Expand All @@ -78,14 +78,14 @@ def set(self, data: Dict, merge=False):
else:
set_by_path(self._data, self._path, deepcopy(data))

def update(self, data: Dict[str, Any]):
def update(self, data: Dict[str, Any], timeout: Union[float, None]=None):
document = get_by_path(self._data, self._path)
if document == {}:
raise NotFound('No document to update: {}'.format(self._path))

apply_transformations(document, deepcopy(data))

def collection(self, name) -> 'CollectionReference':
def collection(self, name, timeout: Union[float, None]=None) -> 'CollectionReference':
from mockfirestore.collection import CollectionReference
document = get_by_path(self._data, self._path)
new_path = self._path + [name]
Expand Down
20 changes: 16 additions & 4 deletions mockfirestore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, parent: 'CollectionReference', projection=None,
for field_filter in field_filters:
self._add_field_filter(*field_filter)

def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:
def stream(self, transaction=None, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]:
doc_snapshots = self.parent.stream()

for field, compare, value in self._field_filters:
Expand Down Expand Up @@ -52,7 +52,7 @@ def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:

return iter(doc_snapshots)

def get(self) -> Iterator[DocumentSnapshot]:
def get(self, timeout: Union[float, None]=None) -> Iterator[DocumentSnapshot]:
warnings.warn('Query.get is deprecated, please use Query.stream',
category=DeprecationWarning)
return self.stream()
Expand All @@ -61,8 +61,20 @@ def _add_field_filter(self, field: str, op: str, value: Any):
compare = self._compare_func(op)
self._field_filters.append((field, compare, value))

def where(self, field: str, op: str, value: Any) -> 'Query':
self._add_field_filter(field, op, value)
@staticmethod
def make_field_filter(field: Optional[str], op: Optional[str], value: Any = None, filter=None):
if bool(filter) and (bool(field) or bool(op)):
raise ValueError("Can't pass in both the positional arguments and 'filter' at the same time")
if filter:
classname = filter.__class__.__name__
if not classname.endswith('FieldFilter'):
raise NotImplementedError('composite filters not supported by mockfirestore (got %s)' % classname)
return (filter.field_path, filter.op_string, filter.value)
else:
return (field, op, value)

def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> 'Query':
self._add_field_filter(*self.make_field_filter(field, op, value, filter))
return self

def order_by(self, key: str, direction: Optional[str] = 'ASCENDING') -> 'Query':
Expand Down
7 changes: 5 additions & 2 deletions mockfirestore/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from mockfirestore.document import DocumentReference, DocumentSnapshot
from mockfirestore.query import Query

from typing import Union

MAX_ATTEMPTS = 5
_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}."
_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}."
Expand Down Expand Up @@ -66,10 +68,11 @@ def _commit(self) -> Iterable[WriteResult]:
return results

def get_all(self,
references: Iterable[DocumentReference]) -> Iterable[DocumentSnapshot]:
references: Iterable[DocumentReference],
timeout: Union[float, None]=None) -> Iterable[DocumentSnapshot]:
return self._client.get_all(references)

def get(self, ref_or_query) -> Iterable[DocumentSnapshot]:
def get(self, ref_or_query, timeout: Union[float,None]=None) -> Iterable[DocumentSnapshot]:
if isinstance(ref_or_query, DocumentReference):
return self._client.get_all([ref_or_query])
elif isinstance(ref_or_query, Query):
Expand Down
14 changes: 8 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@

setuptools.setup(
name="mock-firestore",
version="0.11.0",
version="0.11.1",
author="Matt Dowds",
description="In-memory implementation of Google Cloud Firestore for use in tests",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/mdowds/mock-firestore",
packages=setuptools.find_packages(),
packages=setuptools.find_packages(exclude=["tests"]),
python_requires=">=3.9",
test_suite='',
classifiers=[
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: 3.13',
"License :: OSI Approved :: MIT License",
],
)
)
30 changes: 30 additions & 0 deletions tests/test_collection_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from mockfirestore import MockFirestore, DocumentReference, DocumentSnapshot, AlreadyExists


class MockFieldFilter:
def __init__(self, field_path, op_string, value=None):
self.field_path = field_path
self.op_string = op_string
self.value = value


class TestCollectionReference(TestCase):
def test_collection_get_returnsDocuments(self):
fs = MockFirestore()
Expand Down Expand Up @@ -193,6 +200,29 @@ def test_collection_whereArrayContainsAny(self):
self.assertEqual({'field': ['val4']}, contains_any_docs[0].to_dict())
self.assertEqual({'field': ['val3', 'val2', 'val1']}, contains_any_docs[1].to_dict())

def test_collection_nestedWhereFieldFilter(self):
fs = MockFirestore()
fs._data = {'foo': {
'first': {'a': 3},
'second': {'a': 4},
'third': {'a': 5}
}}
filters = [MockFieldFilter('a', '>=', 4),
MockFieldFilter('a', '>=', 5)]
ge_4_dicts = [d.to_dict()
for d in list(fs.collection('foo').where(
filter=filters[0]).stream())
]
ge_5_dicts = [d.to_dict()
for d in list(fs.collection('foo').where(
filter=filters[0]).where(
filter=filters[1]).stream())
]
self.assertEqual(len(ge_4_dicts), 2)
self.assertEqual(sorted([o['a'] for o in ge_4_dicts]), [4, 5])
self.assertEqual(len(ge_5_dicts), 1)
self.assertEqual([o['a'] for o in ge_5_dicts], [5])

def test_collection_orderBy(self):
fs = MockFirestore()
fs._data = {'foo': {
Expand Down
Loading