diff --git a/examples/crud.py b/examples/crud.py index 811db02..62408bf 100644 --- a/examples/crud.py +++ b/examples/crud.py @@ -144,7 +144,7 @@ async def get_users_count( if first_name__contains: select_statement = select_statement.where(User.first_name.contains(first_name__contains)) - return await user_crud.count(db, select_statement=select_statement) + return await user_crud.count(db, select_statement=lambda _: select_statement) @app.get("/users/one") diff --git a/src/fastapi_batteries/crud/__init__.py b/src/fastapi_batteries/crud/__init__.py index 0fa57b0..e7849ff 100644 --- a/src/fastapi_batteries/crud/__init__.py +++ b/src/fastapi_batteries/crud/__init__.py @@ -215,7 +215,7 @@ async def get_multi( # --- Return records if pagination: - total = await self.count(db, select_statement=_select_statement) + total = await self.count(db, select_statement=lambda _: _select_statement) return records, total return records @@ -315,8 +315,12 @@ async def delete(self, db: AsyncSession, item_id: int, *, commit: bool = True) - return result.rowcount - # TODO: Use callable for select_statement like other methods - async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelType]] | None = None) -> int: + async def count( + self, + db: AsyncSession, + *, + select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s, + ) -> int: """Count the number of records for given select statement. TIP: If you just want to know if n records exist, use `exist_n` method. @@ -330,7 +334,7 @@ async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelT Number of records """ - count_select_from = select_statement.subquery() if select_statement is not None else self.model + count_select_from = select_statement(select(self.model)).subquery() count_statement = select(func.count()).select_from(count_select_from) result = await db.scalars(count_statement)