Skip to content

Commit

Permalink
Support filter keyword in where clauses. (#74)
Browse files Browse the repository at this point in the history
Doesn't support composite filters.
# Please enter the commit message for your changes. Lines starting
# with '#' will be ignored, and an empty message aborts the commit.
#
# On branch master
# Your branch is up to date with 'origin/master'.
#
# Changes to be committed:
#	modified:   mockfirestore/collection.py
#	modified:   mockfirestore/query.py
#	modified:   tests/test_collection_reference.py
#
  • Loading branch information
thromer committed Apr 13, 2024
1 parent 0de34b1 commit c20149d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
4 changes: 2 additions & 2 deletions mockfirestore/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def add(self, document_data: Dict, document_id: str = None) \
timestamp = Timestamp.from_now()
return timestamp, doc_ref

def where(self, field: str, op: str, value: Any) -> Query:
query = Query(self, field_filters=[(field, op, value)])
def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> Query:
query = Query(self, field_filters=[Query.make_field_filter(field, op, value, filter)])
return query

def order_by(self, key: str, direction: Optional[str] = None) -> Query:
Expand Down
16 changes: 14 additions & 2 deletions mockfirestore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,20 @@ def _add_field_filter(self, field: str, op: str, value: Any):
compare = self._compare_func(op)
self._field_filters.append((field, compare, value))

def where(self, field: str, op: str, value: Any) -> 'Query':
self._add_field_filter(field, op, value)
@staticmethod
def make_field_filter(field: Optional[str], op: Optional[str], value: Any = None, filter=None):
if bool(filter) and (bool(field) or bool(op)):
raise ValueError("Can't pass in both the positional arguments and 'filter' at the same time")
if filter:
classname = filter.__class__.__name__
if not classname.endswith('FieldFilter'):
raise NotImplementedError('composite filters not supported by mockfirestore (got %s)' % classname)
return (filter.field_path, filter.op_string, filter.value)
else:
return (field, op, value)

def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> 'Query':
self._add_field_filter(*self.make_field_filter(field, op, value, filter))
return self

def order_by(self, key: str, direction: Optional[str] = 'ASCENDING') -> 'Query':
Expand Down
30 changes: 30 additions & 0 deletions tests/test_collection_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from mockfirestore import MockFirestore, DocumentReference, DocumentSnapshot, AlreadyExists


class MockFieldFilter:
def __init__(self, field_path, op_string, value=None):
self.field_path = field_path
self.op_string = op_string
self.value = value


class TestCollectionReference(TestCase):
def test_collection_get_returnsDocuments(self):
fs = MockFirestore()
Expand Down Expand Up @@ -193,6 +200,29 @@ def test_collection_whereArrayContainsAny(self):
self.assertEqual({'field': ['val4']}, contains_any_docs[0].to_dict())
self.assertEqual({'field': ['val3', 'val2', 'val1']}, contains_any_docs[1].to_dict())

def test_collection_nestedWhereFieldFilter(self):
fs = MockFirestore()
fs._data = {'foo': {
'first': {'a': 3},
'second': {'a': 4},
'third': {'a': 5}
}}
filters = [MockFieldFilter('a', '>=', 4),
MockFieldFilter('a', '>=', 5)]
ge_4_dicts = [d.to_dict()
for d in list(fs.collection('foo').where(
filter=filters[0]).stream())
]
ge_5_dicts = [d.to_dict()
for d in list(fs.collection('foo').where(
filter=filters[0]).where(
filter=filters[1]).stream())
]
self.assertEqual(len(ge_4_dicts), 2)
self.assertEqual(sorted([o['a'] for o in ge_4_dicts]), [4, 5])
self.assertEqual(len(ge_5_dicts), 1)
self.assertEqual([o['a'] for o in ge_5_dicts], [5])

def test_collection_orderBy(self):
fs = MockFirestore()
fs._data = {'foo': {
Expand Down

0 comments on commit c20149d

Please sign in to comment.