Skip to content

Commit

Permalink
refactor(crud): align count with other method to accept `select_state…
Browse files Browse the repository at this point in the history
…ment` as callable
  • Loading branch information
jd-solanki committed Dec 17, 2024
1 parent 11d8e67 commit 1eeb632
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def get_users_count(
if first_name__contains:
select_statement = select_statement.where(User.first_name.contains(first_name__contains))

return await user_crud.count(db, select_statement=select_statement)
return await user_crud.count(db, select_statement=lambda _: select_statement)


@app.get("/users/one")
Expand Down
12 changes: 8 additions & 4 deletions src/fastapi_batteries/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ async def get_multi(

# --- Return records
if pagination:
total = await self.count(db, select_statement=_select_statement)
total = await self.count(db, select_statement=lambda _: _select_statement)
return records, total
return records

Expand Down Expand Up @@ -315,8 +315,12 @@ async def delete(self, db: AsyncSession, item_id: int, *, commit: bool = True) -

return result.rowcount

# TODO: Use callable for select_statement like other methods
async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelType]] | None = None) -> int:
async def count(
self,
db: AsyncSession,
*,
select_statement: Callable[[Select[tuple[ModelType]]], Select[tuple[ModelType]]] = lambda s: s,
) -> int:
"""Count the number of records for given select statement.
TIP: If you just want to know if n records exist, use `exist_n` method.
Expand All @@ -330,7 +334,7 @@ async def count(self, db: AsyncSession, *, select_statement: Select[tuple[ModelT
Number of records
"""
count_select_from = select_statement.subquery() if select_statement is not None else self.model
count_select_from = select_statement(select(self.model)).subquery()
count_statement = select(func.count()).select_from(count_select_from)

result = await db.scalars(count_statement)
Expand Down

0 comments on commit 1eeb632

Please sign in to comment.