Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for bulk_update #148

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions docs/making_queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ await Note.objects.create(text="Send invoices.", completed=True)
You need to pass a list of dictionaries of required fields to create multiple objects:

```python
await Product.objects.bulk_create(
await Note.objects.bulk_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
{"text": "Buy the groceries", "completed": False},
{"text": "Call Mum.", "completed": True},

]
)
Expand Down Expand Up @@ -233,6 +233,18 @@ note = await Note.objects.first()
await note.update(completed=True)
```

### .bulk_update()

You can also bulk update multiple objects at once by passing a list of objects and a list of fields to update.

```python
notes = await Note.objects.all()
for note in notes :
note.completed = True

await Note.objects.bulk_update(notes, fields=["completed"])
```

## Convenience Methods

### .get_or_create()
Expand All @@ -252,7 +264,6 @@ if it doesn't exist, it will use `defaults` argument to create the new instance.
!!! note
Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.


### .update_or_create()

To update an existing instance matching the query, or create a new one.
Expand Down
44 changes: 44 additions & 0 deletions orm/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import enum
import json
import typing

import databases
Expand Down Expand Up @@ -28,6 +30,15 @@ def _update_auto_now_fields(values, fields):
return values


def _convert_value(value):
if isinstance(value, dict):
return json.dumps(value)
elif isinstance(value, enum.Enum):
return value.name
else:
return value


class ModelRegistry:
def __init__(self, database: databases.Database) -> None:
self.database = database
Expand Down Expand Up @@ -454,6 +465,39 @@ async def update(self, **kwargs) -> None:

await self.database.execute(expr)

async def bulk_update(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I should've noticed this earlier, apologies for that.
But maybe a general refactor would be useful here?
There's a lot of nested code here and it's not very readable. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree with you it needs to be more readable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aminalaee Any updates ?

self, objs: typing.List["Model"], fields: typing.List[str]
) -> None:
fields = {
key: field.validator
for key, field in self.model_cls.fields.items()
if key in fields
}
validator = typesystem.Schema(fields=fields)
new_objs = [
{
key: _convert_value(value)
for key, value in obj.__dict__.items()
if key in fields
}
for obj in objs
]
new_objs = [
_update_auto_now_fields(validator.validate(obj), self.model_cls.fields)
for obj in new_objs
]
pk_column = getattr(self.table.c, self.pkname)
expr = self.table.update().where(pk_column == sqlalchemy.bindparam(self.pkname))
kwargs = {
field: sqlalchemy.bindparam(field)
for obj in new_objs
for field in obj.keys()
}
expr = expr.values(kwargs)
pk_list = [{self.pkname: getattr(obj, self.pkname)} for obj in objs]
joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)]
await self.database.execute_many(str(expr), joined_list)

async def get_or_create(
self, defaults: typing.Dict[str, typing.Any], **kwargs
) -> typing.Tuple[typing.Any, bool]:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,43 @@ async def test_bulk_create():
assert products[1].data == {"foo": 456}
assert products[1].value == 456.789
assert products[1].status == StatusEnum.DRAFT


async def test_bulk_update():
await Product.objects.bulk_create(
[
{
"created_day": datetime.date.today(),
"data": {"foo": 123},
"value": 123.456,
"status": StatusEnum.RELEASED,
},
{
"created_day": datetime.date.today(),
"data": {"foo": 456},
"value": 456.789,
"status": StatusEnum.DRAFT,
},
]
)
products = await Product.objects.all()
products[0].created_day = datetime.date.today() - datetime.timedelta(days=1)
products[1].created_day = datetime.date.today() - datetime.timedelta(days=1)
products[0].status = StatusEnum.DRAFT
products[1].status = StatusEnum.RELEASED
products[0].data = {"foo": 1234}
products[1].data = {"foo": 5678}
products[0].value = 345.5
products[1].value = 789.8
await Product.objects.bulk_update(
products, fields=["created_day", "status", "data", "value"]
)
products = await Product.objects.all()
assert products[0].created_day == datetime.date.today() - datetime.timedelta(days=1)
assert products[1].created_day == datetime.date.today() - datetime.timedelta(days=1)
assert products[0].status == StatusEnum.DRAFT
assert products[1].status == StatusEnum.RELEASED
assert products[0].data == {"foo": 1234}
assert products[1].data == {"foo": 5678}
assert products[0].value == 345.5
assert products[1].value == 789.8
19 changes: 19 additions & 0 deletions tests/test_foreignkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,22 @@ async def test_nullable_foreign_key():

assert member.email == "[email protected]"
assert member.team.pk is None


async def test_bulk_update_with_relation():
album = await Album.objects.create(name="foo")
album2 = await Album.objects.create(name="bar")

await Track.objects.bulk_create(
[
{"name": "foo", "album": album, "position": 1, "title": "foo"},
{"name": "bar", "album": album, "position": 2, "title": "bar"},
]
)
tracks = await Track.objects.all()
for track in tracks:
track.album = album2
await Track.objects.bulk_update(tracks, fields=["album"])
tracks = await Track.objects.all()
assert tracks[0].album.pk == album2.pk
assert tracks[1].album.pk == album2.pk