From 240aa9c1b260190e1513e988e845c981cad12b0e Mon Sep 17 00:00:00 2001 From: BrianLusina <12752833+BrianLusina@users.noreply.github.com> Date: Tue, 10 Dec 2024 20:45:56 +0300 Subject: [PATCH] feat(sql): new mixin and updates to base repository Adds a new mixin plus additional updates to the repository. This additionally adds new types with Pydantic and utilities to handle those updates. --- .tool-versions | 2 +- pyproject.toml | 4 + sanctumlabs_dbkit/exceptions.py | 5 + sanctumlabs_dbkit/sql/alembic.py | 36 +++ sanctumlabs_dbkit/sql/mixins.py | 18 +- sanctumlabs_dbkit/sql/models.py | 40 +++ sanctumlabs_dbkit/sql/repository.py | 28 +- sanctumlabs_dbkit/sql/session.py | 6 +- sanctumlabs_dbkit/sql/types.py | 323 ++++++++++++++++++- sanctumlabs_dbkit/sql/utils.py | 4 +- tests/sql/test_types.py | 475 ++++++++++++++++++++++++++++ 11 files changed, 930 insertions(+), 11 deletions(-) create mode 100644 sanctumlabs_dbkit/sql/alembic.py create mode 100644 tests/sql/test_types.py diff --git a/.tool-versions b/.tool-versions index 9850327..ed8f0e9 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1,2 +1,2 @@ -python 3.12.0 +python 3.12.3 pre-commit 3.4.0 diff --git a/pyproject.toml b/pyproject.toml index 8b69c88..5064398 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,10 @@ python = "^3.12.0" sqlalchemy = "^2.0.29" sqlalchemy-utils = "^0.41.1" inflection = "^0.5.1" +wrapt = "^1.16.0" +orjson = "^3.10.3" +pydantic = "^2.0" +uuid = "^1.30" [tool.poetry.group.dev.dependencies] pylint = "^3.1.0" diff --git a/sanctumlabs_dbkit/exceptions.py b/sanctumlabs_dbkit/exceptions.py index 6b36a68..15fc524 100644 --- a/sanctumlabs_dbkit/exceptions.py +++ b/sanctumlabs_dbkit/exceptions.py @@ -5,3 +5,8 @@ class ModelNotFoundError(Exception): """Error indicating a missing model""" + + +class UnsupportedModelOperationError(Exception): + """Error indicating an operation on a model is unsupported""" + pass diff --git a/sanctumlabs_dbkit/sql/alembic.py b/sanctumlabs_dbkit/sql/alembic.py new file mode 100644 index 0000000..3bca637 --- /dev/null +++ b/sanctumlabs_dbkit/sql/alembic.py @@ -0,0 +1,36 @@ +from typing import Any, Literal, Union + +from sanctumlabs_dbkit.sql.types import PydanticModel, PydanticModelList + + +def render_item( + type_: str, obj: Any, autogen_context: Any +) -> Union[str, Literal[False]]: + """ + A custom renderer for the `alembic` migration framework which caters for our custom SQLAlchemy pydantic types. + + These types allow for pydantic models to be serialized and deserialized to/from json. Alembic doesn't generate the + correct migrations for these cases so we need to do some hackery and override here. + + To leverage the custom renderer, you need to configure it on your migration context in your alembic `env.py` + + ``python + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=False, + render_item=render_item, + ) + ``` + + See https://alembic.sqlalchemy.org/en/latest/autogenerate.html#affecting-the-rendering-of-types-themselves + See https://gist.github.com/imankulov/4051b7805ad737ace7d8de3d3f934d6b + """ + + if type_ == "type" and ( + isinstance(obj, PydanticModelList) or isinstance(obj, PydanticModel) + ): + return "sa.JSON()" + + return False diff --git a/sanctumlabs_dbkit/sql/mixins.py b/sanctumlabs_dbkit/sql/mixins.py index cdfca2e..cad6ed1 100644 --- a/sanctumlabs_dbkit/sql/mixins.py +++ b/sanctumlabs_dbkit/sql/mixins.py @@ -5,7 +5,7 @@ from typing import Optional from datetime import datetime, timezone from uuid import UUID, uuid4 -from sqlalchemy import DateTime, func +from sqlalchemy import DateTime, func, BIGINT, Identity from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, declared_attr from sqlalchemy.dialects.postgresql import UUID as UUIDType import inflection @@ -85,3 +85,19 @@ class TableNameMixin: def __tablename__(self) -> str: """Table names are snake case plural, for example shipping_records""" return inflection.pluralize(inflection.underscore(self.__name__)) # type: ignore[attr-defined] + + +class BigIntIdentityMixin: + """ + A mixin to provide an auto-incrementing bigint primary key column. + + NOTE: usage of this mixin for primary key column purposes is discouraged and should only be used for special + cases (e.g. outbox spooler). The UUIDPrimaryKeyMixin is what should typically be used instead (via the + BaseModel class) + """ + + id: Mapped[Optional[int]] = mapped_column( + Identity(start=1, cycle=False), primary_key=True, nullable=False, type_=BIGINT + ) + + pk: str = "id" diff --git a/sanctumlabs_dbkit/sql/models.py b/sanctumlabs_dbkit/sql/models.py index cdc2a04..79fe2b2 100644 --- a/sanctumlabs_dbkit/sql/models.py +++ b/sanctumlabs_dbkit/sql/models.py @@ -1,6 +1,12 @@ """ Contains base database models that can be subclassed to add functionality & attributes for database models in an app """ +from datetime import datetime, timezone +from typing import Optional +from uuid import UUID, uuid4 + +from sqlalchemy import LargeBinary +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, synonym from sanctumlabs_dbkit.sql.mixins import ( AuditedMixin, @@ -9,6 +15,7 @@ TableNameMixin, TimestampColumnsMixin, UUIDPrimaryKeyMixin, + BaseIdentityMixin ) @@ -35,6 +42,38 @@ class BaseModel(UUIDPrimaryKeyMixin, AbstractBaseModel): __abstract__ = True +class BaseOutboxEvent(Base, BaseIdentityMixin, TableNameMixin): + """ + Base model for outbox events. Projects can choose to add additional table args (e.g. custom index) if + needed: + + __table_args__ = ( + Index( + ... + ), + ) + """ + + __abstract__ = True + + uuid: Mapped[UUID] = mapped_column(unique=True, default=uuid4) + + created: Mapped[datetime] = mapped_column( + default=lambda: datetime.now(timezone.utc) + ) + destination: Mapped[str] + event_type: Mapped[str] + correlation_id: Mapped[str] + partition_key: Mapped[str] + payload: Mapped[bytes] = mapped_column(type_=LargeBinary) + sent_time: Mapped[Optional[datetime]] + error_message: Mapped[Optional[str]] + + # mimic AbstractBaseModel to play nicely in the base DAO class + @declared_attr + def created_at(cls) -> Mapped[datetime]: # noqa: N805 + return synonym("created") + __all__ = [ "AbstractBaseModel", @@ -45,4 +84,5 @@ class BaseModel(UUIDPrimaryKeyMixin, AbstractBaseModel): "TableNameMixin", "TimestampColumnsMixin", "UUIDPrimaryKeyMixin", + "BaseOutboxEvent", ] diff --git a/sanctumlabs_dbkit/sql/repository.py b/sanctumlabs_dbkit/sql/repository.py index 9a0e3b7..22f7607 100644 --- a/sanctumlabs_dbkit/sql/repository.py +++ b/sanctumlabs_dbkit/sql/repository.py @@ -3,7 +3,7 @@ """ from datetime import datetime, UTC -from typing import Generic, Any, Optional, Sequence, Type, TypeVar, cast +from typing import Generic, Any, Optional, Sequence, Type, TypeVar, cast, TypeGuard from sqlalchemy import ColumnElement, Select, select from sanctumlabs_dbkit.exceptions import ModelNotFoundError @@ -28,11 +28,29 @@ def __init__(self, model: Type[T], session: Session) -> None: self.model = model self.session = session + @staticmethod + def _supports_soft_deletion(model: Type[T]) -> TypeGuard[Type[AbstractBaseModel]]: + """ + Indicates if the provided model supports soft deletion (has a 'deleted_at' column). This function + takes in an argument due to mypy typeguarding requirements, and is thus static. + """ + return issubclass(model, AbstractBaseModel) + + def create(self, refresh: bool = False, **kwargs: Any) -> T: + model_instance = self.model(**kwargs) + self.session.add(model_instance) + + if refresh: + self.session.flush() + self.session.refresh(model_instance) + + return cast(T, model_instance) + def query(self, include_deleted: bool = False) -> Select: """Returns a select query with the model including deleted records if the include_deleted is set to True""" selectable = select(self.model) - if not include_deleted: + if not include_deleted and self._supports_soft_deletion(self.model): selectable = selectable.where( self.model.deleted_at == self.model.not_deleted_value() ) @@ -67,7 +85,11 @@ def all(self, include_deleted: bool = False) -> Sequence[T]: def delete(self, pk: Any) -> None: """Deletes a given record with the given primary key""" entity = self.find(pk) - + + # Cast here as mypy type narrowing doesn't infer the type of entity + # correctly + entity = cast(AbstractBaseModel, self.find(pk)) + if entity: entity.deleted_at = datetime.now(UTC) diff --git a/sanctumlabs_dbkit/sql/session.py b/sanctumlabs_dbkit/sql/session.py index b75b9b5..d7d46de 100644 --- a/sanctumlabs_dbkit/sql/session.py +++ b/sanctumlabs_dbkit/sql/session.py @@ -28,7 +28,7 @@ def transaction(self, func: FuncT) -> FuncT: Example: ```python - from sanctumlabs_dbkit import SessionLocal + from sanctumlabs_dbkit.sql import SessionLocal session = SessionLocal() @@ -59,8 +59,8 @@ def transaction(func: FuncT) -> FuncT: Example: ```python - from sanctumlabs_dbkit import SessionLocal - from sanctumlabs_dbkit.session import transaction + from sanctumlabs_dbkit.sql import SessionLocal + from sanctumlabs_dbkit.sql.session import transaction class UserService(): def __init__(session: Session): diff --git a/sanctumlabs_dbkit/sql/types.py b/sanctumlabs_dbkit/sql/types.py index f63126b..7841f3b 100644 --- a/sanctumlabs_dbkit/sql/types.py +++ b/sanctumlabs_dbkit/sql/types.py @@ -1,8 +1,329 @@ """ Database Kit Types """ +from __future__ import annotations -from typing import Callable +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + SupportsIndex, + Type, + TypeVar, + Union, + cast, +) from sanctumlabs_dbkit.sql.session import Session +import sqlalchemy as sa +from pydantic import BaseModel +from sqlalchemy import Dialect +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.mutable import Mutable, MutableList +from sqlalchemy.sql.type_api import TypeEngine +from wrapt import ObjectProxy + CommitCallback = Callable[[Session], None] + +_T = TypeVar("_T", bound=BaseModel) + + +class SerializationOptions(BaseModel): + """ + `SerializationOptions` are used when serialization a pydantic model to JSON. + + Example: + ```python + data: Mapped[SettingsData] = mapped_column( + MutablePydanticModel.as_mutable( + SettingsData, + serialization_options=SerializationOptions(exclude_defaults=True), + default=lambda: SettingsData(), + ) + ) + ``` + + Given the above example, when the settings data is serialized and persisted in the database, it will not dump + the default model values. + """ + + exclude_defaults: bool = False + + +class ColumnUsesPydanticModelsMixin(sa.types.TypeDecorator, TypeEngine[_T]): + """ + ColumnUsesPydanticModelsMixin is a mixin class for serializing and deserializing pydantic models to/from + SQLAlchemy columns. + + See https://docs.sqlalchemy.org/en/20/core/custom_types.html#marshal-json-strings + """ + + impl = sa.types.JSON + + def __init__( + self, model: _T, serialization_options: Optional[SerializationOptions] = None + ): + self.model = model + self.serialization_options = serialization_options or SerializationOptions() + + super().__init__() + + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: + # Use JSONB for PostgreSQL and JSON for other databases. + if dialect.name == "postgresql": + return dialect.type_descriptor(JSONB(none_as_null=True)) # type: ignore + else: + return dialect.type_descriptor(sa.JSON(none_as_null=True)) + + def _model_to_dict(self, value: _T) -> Dict[str, Any]: + return value.model_dump( + exclude_defaults=self.serialization_options.exclude_defaults + ) + + +class PydanticModel(ColumnUsesPydanticModelsMixin): + """ + A custom SQLAlchemy column type for declaring a field as a pydantic model. + + It's important to note that changes on the pydantic model won't trigger an attribute change on the SQLAlchemy + model. If you want mutations to be recognised, you'll need to use the `MutablePydanticModel`. + + `cart_summary: Mapped[CartSummary] = mapped_column(PydanticModel(CartSummary), default=lambda: CartSummary())` + """ + + def process_bind_param( + self, value: Optional[BaseModel], dialect: Dialect + ) -> Optional[Dict[str, Any]]: + return self._model_to_dict(value) if value else None + + def process_result_value( + self, value: Optional[Any], dialect: Dialect + ) -> Optional[BaseModel]: + return ( + cast(BaseModel, self.model.model_validate(value)) + if value is not None + else None + ) + + def __repr__(self) -> str: + return f"PydanticModel{self.model.__name__}" + + +class PydanticModelProxy(ObjectProxy): + """ + A proxy class wrapping a pydantic model. + + The proxy class is used by the `MutablePydanticModel` and `MutablePydanticModelList` to detect changes in + the underlying model(s). We hijack all attribute setting and trigger a change on the mutable class. As of writing, + there is no other way to detect a change in the model. + """ + + def __init__(self, wrapped: BaseModel, mutable: Mutable): + super().__init__(wrapped) + + self._self_mutable = mutable + + def __setattr__(self, name: str, value: Any) -> None: + super().__setattr__(name, value) + + if name in self.__wrapped__.model_fields: + self._self_mutable.changed() + + +class MutablePydanticModel(Mutable): + """ + A custom SQLAlchemy column type for declaring a field as a pydantic model. + + Any change to the underlying pydantic model will trigger a change to the SQLAlchemy model triggering an update. + + See https://docs.sqlalchemy.org/en/14/orm/extensions/mutable.html for further docs on mutation tracking. + + `cart_summary: Mapped[CartSummary] = mapped_column(MutablePydanticModel.as_mutable(CartSummary), default=lambda: CartSummary())` + """ + + def __init__(self, model_instance: BaseModel): + self.proxied_model_instance = _create_proxied_pydantic_model( + model_instance, self + ) + + @classmethod + def as_mutable( + cls, sqltype: Union[TypeEngine[_T], _T], **kwargs: Any + ) -> TypeEngine[_T]: + serialization_options = kwargs.get("serialization_options") + + return super().as_mutable( + PydanticModel(sqltype, serialization_options=serialization_options) + ) + + @classmethod + def coerce(cls, key: str, value: Any) -> Optional[Any]: + if not isinstance(value, cls): + if isinstance(value, BaseModel): + return cls(value) + + return Mutable.coerce(key, value) + else: + return value + + def __setattr__(self, name: str, value: Any) -> None: + if name == "proxied_model_instance": + super().__setattr__(name, value) + else: + setattr(self.proxied_model_instance, name, value) + + def __getattr__(self, name: str) -> Any: + if name == "proxied_model_instance": + return super().__getattr__(name) # type: ignore + else: + return getattr(self.proxied_model_instance, name) + + +class PydanticModelList(ColumnUsesPydanticModelsMixin): + """ + A custom SQLAlchemy column type for declaring a field as a list of pydantic models. + + `products: Mapped[List[Product]] = mapped_column(PydanticModelList(Product))` + """ + + def process_bind_param( + self, value: Optional[List[BaseModel]], dialect: Dialect + ) -> Optional[List[Dict[str, Any]]]: + if value: + return [self._model_to_dict(model_instance) for model_instance in value] + else: + return None + + def process_result_value( + self, value: Optional[Any], dialect: Dialect + ) -> Optional[List[BaseModel]]: + if value: + return [self.model.model_validate(model_data) for model_data in value] + else: + return None + + +class MutablePydanticModelList(MutableList): + """ + A custom SQLAlchemy column type for declaring a field as a list of pydantic models. + + Any change to the underlying pydantic models will trigger a change to the SQLAlchemy model triggering an update. + + See https://docs.sqlalchemy.org/en/14/orm/extensions/mutable.html for further docs on mutation tracking. + + `products: Mapped[List[Product]] = mapped_column(MutablePydanticModelList.as_mutable(Product))` + """ + + def __init__(self, values: Optional[Iterable[_T]] = None): + values = values or [] + + proxied_model_instances = [ + _create_proxied_pydantic_model(model_instance, self) + for model_instance in values + ] + + super().__init__(proxied_model_instances) + + @classmethod + def as_mutable( + cls, sqltype: Union[TypeEngine[_T], _T], **kwargs: Any + ) -> TypeEngine[_T]: + serialization_options = kwargs.get("serialization_options") + + return super().as_mutable( + PydanticModelList(sqltype, serialization_options=serialization_options) + ) + + def __setitem__( + self, index: SupportsIndex | slice, value: _T | Iterable[_T] + ) -> None: + if isinstance(value, Iterable): + proxied_model_instances = [ + _create_proxied_pydantic_model(v, self) for v in value + ] + + super().__setitem__(index, proxied_model_instances) + else: + super().__setitem__(index, _create_proxied_pydantic_model(value, self)) + + def append(self, x: _T) -> None: + super().append(_create_proxied_pydantic_model(x, self)) + + def extend(self, x: Iterable[_T]) -> None: + proxied_model_instances = [_create_proxied_pydantic_model(v, self) for v in x] + + super().extend(proxied_model_instances) + + def insert(self, i: SupportsIndex, x: _T) -> None: + super().insert(i, _create_proxied_pydantic_model(x, self)) + + +def _create_proxied_pydantic_model( + model_instance: _T, mutable: Mutable +) -> PydanticModelProxy: + if isinstance(model_instance, PydanticModelProxy): + return model_instance + elif isinstance(model_instance, BaseModel): + return PydanticModelProxy(model_instance, mutable) + else: + raise Exception("The model instance must be a pydantic model") + + +def normalise_mutable_pydantic_model(v: Any) -> Any: + """ + A hook to automatically unwrap a `MutablePydanticModel` to its respective pydantic model. + + Example usage: + + ```python + + + BusinessSocialAccountsType = Annotated[ + BusinessSocialAccounts, + BeforeValidator(normalise_mutable_pydantic_model), + ] + + class BusinessEntity(BaseModel): + social_accounts: BusinessSocialAccountsType = Field( + default_factory=BusinessSocialAccounts + ) + ``` + + In this example, if we were to store the `BusinessSocialAccounts` in a database column, it would be wrapped + within mutable and proxy objects. If we were to later try and construct a `BusinessEntity` using `model_validate()`, it would complain that we're + trying to set a field expecting type `BusinessSocialAccounts` to something which is a `MutablePydanticModel`. + We therefore use the hook to normalise the proxy model to the actual pydantic model. + """ + + if isinstance(v, MutablePydanticModel): + return v.proxied_model_instance.__wrapped__ + + return v + + +def default_to_pydantic_model(model: Type[BaseModel]) -> Callable[[Any], Any]: + """ + A hook to automatically set a field to an instance of a pydantic model if constructing without a value. + + Example usage: + + ```python + BusinessSocialAccountsType = Annotated[ + BusinessSocialAccounts, + BeforeValidator(default_to_pydantic_model(BusinessSocialAccounts)), + ] + + class BusinessEntity(BaseModel): + social_accounts: BusinessSocialAccountsType = Field( + default_factory=BusinessSocialAccounts + ) + ``` + """ + + def _validator(v: Any) -> Any: + return model() if not v else v + + return _validator diff --git a/sanctumlabs_dbkit/sql/utils.py b/sanctumlabs_dbkit/sql/utils.py index 2eec36b..e6c562b 100644 --- a/sanctumlabs_dbkit/sql/utils.py +++ b/sanctumlabs_dbkit/sql/utils.py @@ -21,10 +21,10 @@ def get_changes(entity: AbstractBaseModel) -> Dict[str, Tuple[Any, Any]]: Example: user = get_user_by_id(420) >>> '' - get_model_changes(user) + get_changes(user) >>> {} user.email = 'new_email@who-dis.biz' - get_model_changes(user) + get_changes(user) >>> {'email': ['business_email@gmail.com', 'new_email@who-dis.biz']} """ state: InstanceState = inspect(entity) diff --git a/tests/sql/test_types.py b/tests/sql/test_types.py new file mode 100644 index 0000000..b5121ae --- /dev/null +++ b/tests/sql/test_types.py @@ -0,0 +1,475 @@ +from datetime import date +from decimal import Decimal +from typing import Any, Dict, List, Optional, cast + +from pydantic import BaseModel as PydanticBaseModel +from sqlalchemy import JSON, text +from sqlalchemy.orm import Mapped, mapped_column + +from sanctumlabs_dbkit.sql.repository import DAO +from sanctumlabs_dbkit.sql.models import BaseModel +from sanctumlabs_dbkit.sql.session import Session +from sanctumlabs_dbkit.sql.types import ( + MutablePydanticModel, + MutablePydanticModelList, + PydanticModel, + SerializationOptions, +) + + +class Child(PydanticBaseModel): + first_name: str + last_name: str + date_of_birth: date + favourite_food: List[str] + attributes: Dict[str, Any] + + +class Person(BaseModel): + name: Mapped[str] + children: Mapped[JSON] = mapped_column(type_=JSON) + + +def test_sqlalchemy_uses_pydantic_json_serializer_to_serialize_json( + database_session: Session, +) -> None: + with database_session.begin(): + person = Person( + name="Pam Goslett", + children=[ + Child( + first_name="Matthew", + last_name="Goslett", + date_of_birth=date(1987, 7, 31), + favourite_food=["ribs", "dim sum", "sushi"], + attributes={ + "eye_colour": "hazel", + "hair_colour": "brown", + "star_sign": "leo", + }, + ) + ], + ) + + database_session.add(person) + + people_dao = DAO(Person, database_session) + + people = people_dao.all() + + assert len(people) == 1 + + pam_goslett: Person = people[0] + + assert pam_goslett.name == "Pam Goslett" + + children = cast(List[Dict[str, Any]], pam_goslett.children) + + assert children == [ + { + "first_name": "Matthew", + "last_name": "Goslett", + "date_of_birth": "1987-07-31", + "favourite_food": ["ribs", "dim sum", "sushi"], + "attributes": { + "eye_colour": "hazel", + "hair_colour": "brown", + "star_sign": "leo", + }, + } + ] + + +class Product(PydanticBaseModel): + name: str + brand: str + price: Decimal + is_for_sale: bool = True + + +class CartSummary(PydanticBaseModel): + gross_amount: Decimal = Decimal("0") + discount: Decimal = Decimal("0") + net_amount: Decimal = Decimal("0") + + +class AttributionSource(PydanticBaseModel): + utm_source: Optional[str] + utm_medium: Optional[str] + utm_campaign: Optional[str] + + +class Catalogue(BaseModel): + products: Mapped[List[Product]] = mapped_column( + MutablePydanticModelList.as_mutable( + Product, serialization_options=SerializationOptions(exclude_defaults=True) + ), + nullable=True, + ) + + +class Cart(BaseModel): + shipping_first_name: Mapped[str] + products: Mapped[List[Product]] = mapped_column( + MutablePydanticModelList.as_mutable(Product), nullable=True + ) + cart_summary: Mapped[CartSummary] = mapped_column( + MutablePydanticModel.as_mutable(CartSummary), default=lambda: CartSummary() + ) + attribution_source: Mapped[Optional[AttributionSource]] = mapped_column( + PydanticModel(AttributionSource) + ) + + +class SettingsData(PydanticBaseModel): + invoice_email_address: str = "foo@bar.com" + show_address_on_invoice: bool = True + invoice_footer_text: Optional[str] = None + + +class Settings(BaseModel): + data: Mapped[SettingsData] = mapped_column( + MutablePydanticModel.as_mutable( + SettingsData, + serialization_options=SerializationOptions(exclude_defaults=True), + default=lambda: SettingsData(), + ) + ) + + +def test_sqlalchemy_serializes_value_in_pydantic_model_column( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart( + shipping_first_name="Matthew", + products=[], + cart_summary=CartSummary( + gross_amount=Decimal("100"), + discount=Decimal("45"), + net_amount=Decimal("55"), + ), + attribution_source=AttributionSource( + utm_source="facebook", + utm_medium="social", + utm_campaign="yoco_2022_march_madness", + ), + ) + + database_session.add(cart) + + assert cart.shipping_first_name == "Matthew" + assert cart.cart_summary.gross_amount == Decimal("100") + assert cart.cart_summary.discount == Decimal("45") + assert cart.cart_summary.net_amount == Decimal("55") + + assert isinstance(cart.attribution_source, AttributionSource) + assert cart.attribution_source.utm_source == "facebook" + assert cart.attribution_source.utm_medium == "social" + assert cart.attribution_source.utm_campaign == "yoco_2022_march_madness" + + +def test_sqlalchemy_serializes_value_in_pydantic_model_column_and_excludes_defaults_when_specified( + database_session: Session, +) -> None: + with database_session.begin(): + settings = Settings(data=SettingsData(show_address_on_invoice=False)) + + database_session.add(settings) + + assert settings.data.invoice_email_address == "foo@bar.com" + assert not settings.data.show_address_on_invoice + assert settings.data.invoice_footer_text is None + + row = database_session.execute(text("SELECT * FROM settings")).first() + + assert row.data == {"show_address_on_invoice": False} # type: ignore + + +def test_sqlalchemy_serializes_value_in_pydantic_model_column_and_excludes_defaults_when_specified_and_correctly_handles_case_where_value_is_default( + database_session: Session, +) -> None: + with database_session.begin(): + settings = Settings(data=SettingsData()) + + database_session.add(settings) + + assert settings.data.model_dump() == SettingsData().model_dump() + + row = database_session.execute(text("SELECT * FROM settings")).first() + + assert row.data == {} # type: ignore + + +def test_sqlalchemy_serializes_value_in_pydantic_model_column_when_value_is_null( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart( + shipping_first_name="Matthew", + cart_summary=CartSummary( + gross_amount=Decimal("100"), + discount=Decimal("45"), + net_amount=Decimal("55"), + ), + ) + + database_session.add(cart) + + assert cart.shipping_first_name == "Matthew" + assert cart.cart_summary.gross_amount == Decimal("100") + assert cart.cart_summary.discount == Decimal("45") + assert cart.cart_summary.net_amount == Decimal("55") + assert cart.attribution_source is None + + +def test_sqlalchemy_serializes_value_in_pydantic_model_column_and_sets_default_value_when_not_present( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart(shipping_first_name="Matthew") + + database_session.add(cart) + + assert cart.cart_summary.gross_amount == Decimal("0") + assert cart.cart_summary.discount == Decimal("0") + assert cart.cart_summary.net_amount == Decimal("0") + + +def test_sqlalchemy_updates_value_in_pydantic_model_column( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart( + shipping_first_name="Matthew", + cart_summary=CartSummary( + gross_amount=Decimal("100"), + discount=Decimal("45"), + net_amount=Decimal("55"), + ), + attribution_source=AttributionSource( + utm_source="facebook", + utm_medium="social", + utm_campaign="yoco_2022_march_madness", + ), + ) + database_session.add(cart) + + assert cart.shipping_first_name == "Matthew" + assert cart.cart_summary.gross_amount == Decimal("100") + assert cart.cart_summary.discount == Decimal("45") + assert cart.cart_summary.net_amount == Decimal("55") + + assert isinstance(cart.attribution_source, AttributionSource) + assert cart.attribution_source.utm_source == "facebook" + assert cart.attribution_source.utm_medium == "social" + assert cart.attribution_source.utm_campaign == "yoco_2022_march_madness" + + cart.shipping_first_name = "Bob" + cart.cart_summary.gross_amount = Decimal("45") + cart.cart_summary.discount = Decimal("0") + cart.cart_summary.net_amount = Decimal("45") + cart.attribution_source.utm_source = "google" + + database_session.commit() + + assert cart.shipping_first_name == "Bob" + assert cart.cart_summary.gross_amount == Decimal("45") + assert cart.cart_summary.discount == Decimal("0") + assert cart.cart_summary.net_amount == Decimal("45") + # cart.attribution_source.utm_source should not have changed since the field isn't mutable or monitored + assert cart.attribution_source.utm_source == "facebook" + assert cart.attribution_source.utm_medium == "social" + assert cart.attribution_source.utm_campaign == "yoco_2022_march_madness" + + +def test_sqlalchemy_serializes_value_in_pydantic_model_list_column( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart( + shipping_first_name="Matthew", + products=[ + Product(name="Kit Kat Chunky", brand="Kit Kat", price=Decimal("2.45")), + Product(name="Nike Pegasus 39", brand="Nike", price=Decimal("2499")), + ], + ) + + database_session.add(cart) + + assert len(cart.products) == 2 + + assert cart.products[0].name == "Kit Kat Chunky" + assert cart.products[0].brand == "Kit Kat" + assert cart.products[0].price == Decimal("2.45") + assert cart.products[0].is_for_sale + + assert cart.products[1].name == "Nike Pegasus 39" + assert cart.products[1].brand == "Nike" + assert cart.products[1].price == Decimal("2499") + assert cart.products[1].is_for_sale + + +def test_sqlalchemy_serializes_value_in_pydantic_model_list_column_and_excludes_defaults_when_specified( + database_session: Session, +) -> None: + with database_session.begin(): + catalogue = Catalogue( + products=[ + Product(name="Kit Kat Chunky", brand="Kit Kat", price=Decimal("2.45")), + Product( + name="Nike Pegasus 39", + brand="Nike", + price=Decimal("2499"), + is_for_sale=False, + ), + ], + ) + + database_session.add(catalogue) + + assert len(catalogue.products) == 2 + + assert catalogue.products[0].name == "Kit Kat Chunky" + assert catalogue.products[0].brand == "Kit Kat" + assert catalogue.products[0].price == Decimal("2.45") + assert catalogue.products[0].is_for_sale + + assert catalogue.products[1].name == "Nike Pegasus 39" + assert catalogue.products[1].brand == "Nike" + assert catalogue.products[1].price == Decimal("2499") + assert not catalogue.products[1].is_for_sale + + row = database_session.execute(text("SELECT * FROM catalogues")).first() + + assert row.products == [ # type: ignore + {"name": "Kit Kat Chunky", "brand": "Kit Kat", "price": 2.45}, + { + "name": "Nike Pegasus 39", + "brand": "Nike", + "price": 2499, + "is_for_sale": False, + }, + ] + + +def test_sqlalchemy_serializes_value_in_pydantic_model_list_column_when_value_is_null( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart(shipping_first_name="Matthew") + + database_session.add(cart) + + assert cart.products is None + + +def test_sqlalchemy_updates_value_in_pydantic_model_list_column_when_model_updates( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart( + shipping_first_name="Matthew", + products=[ + Product(name="Kit Kat Chunky", brand="Kit Kat", price=Decimal("2.45")), + Product(name="Nike Pegasus 39", brand="Nike", price=Decimal("2499")), + ], + ) + + database_session.add(cart) + + assert len(cart.products) == 2 + + assert cart.products[0].name == "Kit Kat Chunky" + assert cart.products[0].brand == "Kit Kat" + assert cart.products[0].price == Decimal("2.45") + assert cart.products[0].is_for_sale + + assert cart.products[1].name == "Nike Pegasus 39" + assert cart.products[1].brand == "Nike" + assert cart.products[1].price == Decimal("2499") + assert cart.products[1].is_for_sale + + with database_session.begin(): + cart.products[0].price = Decimal("4.78") + + assert cart.products[0].price == Decimal("4.78") + + +def test_sqlalchemy_updates_value_in_pydantic_model_list_column_when_model_is_added( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart( + shipping_first_name="Matthew", + products=[ + Product(name="Kit Kat Chunky", brand="Kit Kat", price=Decimal("2.45")), + ], + ) + + database_session.add(cart) + + assert len(cart.products) == 1 + + assert cart.products[0].name == "Kit Kat Chunky" + assert cart.products[0].brand == "Kit Kat" + assert cart.products[0].price == Decimal("2.45") + assert cart.products[0].is_for_sale + + with database_session.begin(): + cart.products.append( + Product( + name="G-Star Skinny Jeans - Size 32", + brand="G-Star", + price=Decimal("1699"), + ) + ) + + assert len(cart.products) == 2 + + assert cart.products[0].name == "Kit Kat Chunky" + assert cart.products[0].brand == "Kit Kat" + assert cart.products[0].price == Decimal("2.45") + assert cart.products[0].is_for_sale + + assert cart.products[1].name == "G-Star Skinny Jeans - Size 32" + assert cart.products[1].brand == "G-Star" + assert cart.products[1].price == Decimal("1699") + assert cart.products[1].is_for_sale + + +def test_sqlalchemy_updates_value_in_pydantic_model_list_column_when_model_is_removed( + database_session: Session, +) -> None: + with database_session.begin(): + cart = Cart( + shipping_first_name="Matthew", + products=[ + Product(name="Kit Kat Chunky", brand="Kit Kat", price=Decimal("2.45")), + Product(name="Nike Pegasus 39", brand="Nike", price=Decimal("2499")), + ], + ) + + database_session.add(cart) + + assert len(cart.products) == 2 + + assert cart.products[0].name == "Kit Kat Chunky" + assert cart.products[0].brand == "Kit Kat" + assert cart.products[0].price == Decimal("2.45") + assert cart.products[0].is_for_sale + + assert cart.products[1].name == "Nike Pegasus 39" + assert cart.products[1].brand == "Nike" + assert cart.products[1].price == Decimal("2499") + assert cart.products[1].is_for_sale + + with database_session.begin(): + del cart.products[1] + + assert len(cart.products) == 1 + + assert cart.products[0].name == "Kit Kat Chunky" + assert cart.products[0].brand == "Kit Kat" + assert cart.products[0].price == Decimal("2.45") + assert cart.products[0].is_for_sale