From 9f4e83d160e4fb3b908ef345bc0b76eee2088ff0 Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Wed, 2 Feb 2022 18:03:01 +0200 Subject: [PATCH] improve readability --- orm/models.py | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/orm/models.py b/orm/models.py index ea9264f..1a1cb1c 100644 --- a/orm/models.py +++ b/orm/models.py @@ -22,8 +22,6 @@ "lte": "__le__", } -MODEL = typing.TypeVar("MODEL", bound="Model") - def _update_auto_now_fields(values, fields): for key, value in fields.items(): @@ -468,7 +466,7 @@ async def update(self, **kwargs) -> None: await self.database.execute(expr) async def bulk_update( - self, objs: typing.List[MODEL], fields: typing.List[str] + self, objs: typing.List["Model"], fields: typing.List[str] ) -> None: fields = { key: field.validator @@ -476,30 +474,26 @@ async def bulk_update( if key in fields } validator = typesystem.Schema(fields=fields) + objs = [ + { + key: _convert_value(value) + for key, value in obj.__dict__.items() + if key in fields + } + for obj in objs + ] new_objs = [ - _update_auto_now_fields(validator.validate(value), self.model_cls.fields) - for value in [ - { - key: _convert_value(value) - for key, value in obj.__dict__.items() - if key in fields - } - for obj in objs - ] + _update_auto_now_fields(validator.validate(obj), self.model_cls.fields) + for obj in objs ] - expr = ( - self.table.update() - .where( - getattr(self.table.c, self.pkname) == sqlalchemy.bindparam(self.pkname) - ) - .values( - { - field: sqlalchemy.bindparam(field) - for obj in new_objs - for field in obj.keys() - } - ) - ) + pk_column = getattr(self.table.c, self.pkname) + expr = self.table.update().where(pk_column == sqlalchemy.bindparam(self.pkname)) + kwargs = { + field: sqlalchemy.bindparam(field) + for obj in new_objs + for field in obj.keys() + } + expr = expr.values(kwargs) pk_list = [{self.pkname: getattr(obj, self.pkname)} for obj in objs] joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] await self.database.execute_many(str(expr), joined_list)