diff --git a/mockfirestore/collection.py b/mockfirestore/collection.py index 431c074..a15ba30 100644 --- a/mockfirestore/collection.py +++ b/mockfirestore/collection.py @@ -41,8 +41,8 @@ def add(self, document_data: Dict, document_id: str = None) \ timestamp = Timestamp.from_now() return timestamp, doc_ref - def where(self, field: str, op: str, value: Any) -> Query: - query = Query(self, field_filters=[(field, op, value)]) + def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> Query: + query = Query(self, field_filters=[Query.make_field_filter(field, op, value, filter)]) return query def order_by(self, key: str, direction: Optional[str] = None) -> Query: diff --git a/mockfirestore/query.py b/mockfirestore/query.py index 7a4618d..c1cb6bb 100644 --- a/mockfirestore/query.py +++ b/mockfirestore/query.py @@ -61,8 +61,20 @@ def _add_field_filter(self, field: str, op: str, value: Any): compare = self._compare_func(op) self._field_filters.append((field, compare, value)) - def where(self, field: str, op: str, value: Any) -> 'Query': - self._add_field_filter(field, op, value) + @staticmethod + def make_field_filter(field: Optional[str], op: Optional[str], value: Any = None, filter=None): + if bool(filter) and (bool(field) or bool(op)): + raise ValueError("Can't pass in both the positional arguments and 'filter' at the same time") + if filter: + classname = filter.__class__.__name__ + if not classname.endswith('FieldFilter'): + raise NotImplementedError('composite filters not supported by mockfirestore (got %s)' % classname) + return (filter.field_path, filter.op_string, filter.value) + else: + return (field, op, value) + + def where(self, field: Optional[str] = None, op: Optional[str] = None, value: Any = None, filter=None) -> 'Query': + self._add_field_filter(*self.make_field_filter(field, op, value, filter)) return self def order_by(self, key: str, direction: Optional[str] = 'ASCENDING') -> 'Query': diff --git a/tests/test_collection_reference.py b/tests/test_collection_reference.py index 7f1924a..d234d80 100644 --- a/tests/test_collection_reference.py +++ b/tests/test_collection_reference.py @@ -3,6 +3,13 @@ from mockfirestore import MockFirestore, DocumentReference, DocumentSnapshot, AlreadyExists +class MockFieldFilter: + def __init__(self, field_path, op_string, value=None): + self.field_path = field_path + self.op_string = op_string + self.value = value + + class TestCollectionReference(TestCase): def test_collection_get_returnsDocuments(self): fs = MockFirestore() @@ -193,6 +200,29 @@ def test_collection_whereArrayContainsAny(self): self.assertEqual({'field': ['val4']}, contains_any_docs[0].to_dict()) self.assertEqual({'field': ['val3', 'val2', 'val1']}, contains_any_docs[1].to_dict()) + def test_collection_nestedWhereFieldFilter(self): + fs = MockFirestore() + fs._data = {'foo': { + 'first': {'a': 3}, + 'second': {'a': 4}, + 'third': {'a': 5} + }} + filters = [MockFieldFilter('a', '>=', 4), + MockFieldFilter('a', '>=', 5)] + ge_4_dicts = [d.to_dict() + for d in list(fs.collection('foo').where( + filter=filters[0]).stream()) + ] + ge_5_dicts = [d.to_dict() + for d in list(fs.collection('foo').where( + filter=filters[0]).where( + filter=filters[1]).stream()) + ] + self.assertEqual(len(ge_4_dicts), 2) + self.assertEqual(sorted([o['a'] for o in ge_4_dicts]), [4, 5]) + self.assertEqual(len(ge_5_dicts), 1) + self.assertEqual([o['a'] for o in ge_5_dicts], [5]) + def test_collection_orderBy(self): fs = MockFirestore() fs._data = {'foo': {