Skip to content

Commit

Permalink
Bump sqlmodel, sqlalchemy, pydantic, fastapi
Browse files Browse the repository at this point in the history
We no longer depend on sqlalchemy v1.4.
These all need to be updated together because of cross dependencies.
  • Loading branch information
rroohhh committed Jan 20, 2025
1 parent a0822d2 commit bb83dec
Show file tree
Hide file tree
Showing 32 changed files with 641 additions and 419 deletions.
190 changes: 136 additions & 54 deletions backend/openapi-schema.yml

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ authors = [

dependencies = [
"redis~=5.0.1",
"fastapi~=0.92.0",
"fastapi~=0.115",
"uvicorn[standard]~=0.20.0",
"sqlmodel~=0.0.11",
"sqlmodel~=0.0.22",
"alembic~=1.11.1",
"python-multipart~=0.0.6",
"filetype~=1.2.0",
Expand All @@ -23,6 +23,8 @@ dependencies = [
"python-frontmatter~=1.0.0",
"psycopg2~=2.9.9",
"prometheus-fastapi-instrumentator~=6.1.0",
"pydantic~=2.2",
"pydantic-settings>=2.7.1",
]
requires-python = ">=3.11"
readme = "./README.md"
Expand All @@ -49,9 +51,6 @@ transcribee-migrate = "transcribee_backend.db.run_migrations:main"
transcribee-admin = "transcribee_backend.admin_cli:main"

[tool.uv]
override-dependencies = [
"sqlalchemy==1.4.41"
]
config-settings = { editable_mode = "compat" }

[tool.uv.sources]
Expand Down
22 changes: 15 additions & 7 deletions backend/tests/test_doc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import uuid

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session
from sqlmodel import Session, col, func, select
from transcribee_backend.auth import generate_share_token
from transcribee_backend.config import settings
from transcribee_backend.models import (
Expand All @@ -23,7 +25,7 @@ def document_id(memory_session: Session, logged_in_client: TestClient):
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
document_id = uuid.UUID(req.json()["id"])

memory_session.add(DocumentUpdate(document_id=document_id, change_bytes=b""))
memory_session.commit()
Expand All @@ -50,7 +52,7 @@ def test_doc_delete(
]
counts = {}
for table in checked_tables:
counts[table] = memory_session.query(table).count()
counts[table] = memory_session.exec(select(func.count(col(table.id)))).one()

files = set(str(x) for x in settings.storage_path.glob("*"))

Expand All @@ -60,14 +62,15 @@ def test_doc_delete(
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
document_id = uuid.UUID(req.json()["id"])

req = logged_in_client.get(f"/api/v1/documents/{document_id}/tasks/")
task_id = uuid.UUID(req.json()[0]["id"])
assert req.status_code == 200
assert len(req.json()) >= 1

memory_session.add(DocumentUpdate(document_id=document_id, change_bytes=b""))
memory_session.add(TaskAttempt(task_id=req.json()[0]["id"], attempt_number=1))
memory_session.add(TaskAttempt(task_id=task_id, attempt_number=1))
memory_session.add(
generate_share_token(
document_id=document_id, name="Test Token", valid_until=None, can_write=True
Expand All @@ -76,7 +79,9 @@ def test_doc_delete(
memory_session.commit()

for table in checked_tables:
assert counts[table] < memory_session.query(table).count()
assert (
counts[table] < memory_session.exec(select(func.count(col(table.id)))).one()
)

assert files < set(str(x) for x in settings.storage_path.glob("*"))

Expand All @@ -87,7 +92,10 @@ def test_doc_delete(
assert req.status_code == 200

for table in checked_tables:
assert counts[table] == memory_session.query(table).count()
assert (
counts[table]
== memory_session.exec(select(func.count(col(table.id)))).one()
)

assert files == set(str(x) for x in settings.storage_path.glob("*"))

Expand Down
2 changes: 1 addition & 1 deletion backend/transcribee_backend/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def validate_user_authorization(session: Session, authorization: str):
raise HTTPException(status_code=400, detail="Invalid Token")
user_id, provided_token = token_data.split(":", maxsplit=1)
statement = select(UserToken).where(
UserToken.user_id == user_id, UserToken.valid_until >= now_tz_aware()
UserToken.user_id == uuid.UUID(user_id), UserToken.valid_until >= now_tz_aware()
)
results = session.exec(statement)
for token in results:
Expand Down
31 changes: 17 additions & 14 deletions backend/transcribee_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,32 @@
from typing import Dict, List, Optional

import frontmatter
from pydantic import BaseModel, BaseSettings, parse_file_as, parse_obj_as
from pydantic import BaseModel, TypeAdapter
from pydantic_settings import BaseSettings

pages = None


class Settings(BaseSettings):
    """Application settings, overridable via environment variables.

    NOTE(review): pydantic v2 ``BaseSettings`` (now provided by the
    ``pydantic-settings`` package) requires every field to carry a type
    annotation; unannotated class attributes are silently ignored as
    settings, so every field below is annotated.
    """

    # Directory where uploaded media / document files are stored.
    storage_path: Path = Path("storage/")
    # Key used to sign tokens; must be overridden in production.
    secret_key: str = "insecure-secret-key"
    worker_timeout: int = 60  # in seconds
    media_signature_max_age: int = 3600  # in seconds
    task_attempt_limit: int = 5

    media_url_base: str = "http://localhost:8000/"
    # Where to send users after logout; None disables the redirect.
    logged_out_redirect_url: None | str = None

    # NOTE(review): the `model_` prefix collides with pydantic v2's
    # protected `model_` namespace and emits a warning — renaming would
    # break callers, so it is kept; confirm the warning is acceptable.
    model_config_path: Path = Path(__file__).parent.resolve() / Path(
        "default_models.json"
    )
    pages_dir: Path = Path("data/pages/")

    metrics_username: str = "transcribee"
    metrics_password: str = "transcribee"

    redis_host: str = "localhost"
    redis_port: int = 6379


class ModelConfig(BaseModel):
Expand All @@ -37,20 +38,22 @@ class ModelConfig(BaseModel):

class PublicConfig(BaseModel):
    """Configuration exposed to (unauthenticated) API clients."""

    # Available transcription models, keyed by model id.
    models: Dict[str, ModelConfig]
    # pydantic v2 no longer treats `X | None` as implicitly optional,
    # so the explicit `= None` default is required.
    logged_out_redirect_url: str | None = None


class ShortPageConfig(BaseModel):
    """Page metadata without the body text (used for page listings)."""

    name: str
    # Position of the page link in the footer; None hides it from the
    # footer. pydantic v2 requires the explicit `= None` default for
    # Optional fields.
    footer_position: Optional[int] = None


class PageConfig(ShortPageConfig):
    """Full page configuration: the listing metadata plus the body text."""

    text: str


def get_model_config():
    """Load the model configuration mapping (model id -> ModelConfig).

    Reads the JSON file at ``settings.model_config_path`` and validates it
    with a pydantic ``TypeAdapter`` — the pydantic v2 replacement for the
    removed ``parse_file_as`` helper.
    """
    return TypeAdapter(Dict[str, ModelConfig]).validate_json(
        Path(settings.model_config_path).read_text()
    )


def load_pages_from_disk() -> Dict[str, PageConfig]:
Expand All @@ -75,7 +78,7 @@ def get_page_config():


def get_short_page_config() -> Dict[str, ShortPageConfig]:
    """Return the configured pages reduced to their listing metadata.

    ``TypeAdapter.validate_python`` is the pydantic v2 replacement for the
    removed ``parse_obj_as`` helper.
    """
    return TypeAdapter(Dict[str, ShortPageConfig]).validate_python(get_page_config())


def get_public_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def upgrade() -> None:
op.create_table(
"user",
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("password_hash", sa.LargeBinary(), nullable=False),
sa.Column("password_salt", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
Expand All @@ -35,7 +35,7 @@ def upgrade() -> None:
"worker",
sa.Column("last_seen", sa.DateTime(timezone=True), nullable=True),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
Expand All @@ -47,8 +47,8 @@ def upgrade() -> None:
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("changed_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("user_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
Expand All @@ -61,8 +61,8 @@ def upgrade() -> None:
op.create_table(
"usertoken",
sa.Column("valid_until", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("user_id", sa.Uuid, nullable=False),
sa.Column("token_hash", sa.LargeBinary(), nullable=False),
sa.Column("token_salt", sa.LargeBinary(), nullable=False),
sa.ForeignKeyConstraint(
Expand All @@ -78,8 +78,8 @@ def upgrade() -> None:
"documentmediafile",
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("changed_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("document_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("document_id", sa.Uuid, nullable=False),
sa.Column("file", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("content_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.ForeignKeyConstraint(
Expand All @@ -96,8 +96,8 @@ def upgrade() -> None:
op.create_table(
"documentupdate",
sa.Column("change_bytes", sa.LargeBinary(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("document_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("document_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
Expand All @@ -111,10 +111,10 @@ def upgrade() -> None:
"task",
sa.Column("task_parameters", sa.JSON(), nullable=False),
sa.Column("task_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("document_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("document_id", sa.Uuid, nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("progress", sa.Float(), nullable=True),
sa.Column("assigned_worker_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("assigned_worker_id", sa.Uuid, nullable=True),
sa.Column("assigned_at", sa.DateTime(), nullable=True),
sa.Column("last_keepalive", sa.DateTime(), nullable=True),
sa.Column("is_completed", sa.Boolean(), nullable=False),
Expand All @@ -134,9 +134,9 @@ def upgrade() -> None:

op.create_table(
"documentmediatag",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("tag", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("media_file_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("media_file_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["media_file_id"],
["documentmediafile.id"],
Expand All @@ -150,9 +150,9 @@ def upgrade() -> None:

op.create_table(
"taskdependency",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("dependent_task_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("dependant_on_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("dependent_task_id", sa.Uuid, nullable=False),
sa.Column("dependant_on_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["dependant_on_id"],
["task.id"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def upgrade_with_autocommit() -> None:
TaskAttempt = op.create_table(
"taskattempt",
sa.Column("extra_data", sa.JSON(), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("assigned_worker_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("task_id", sa.Uuid, nullable=False),
sa.Column("assigned_worker_id", sa.Uuid, nullable=True),
sa.Column("attempt_number", sa.Integer(), nullable=False),
sa.Column("started_at", sa.DateTime(), nullable=True),
sa.Column("last_keepalive", sa.DateTime(), nullable=True),
Expand All @@ -51,9 +51,7 @@ def upgrade_with_autocommit() -> None:
batch_op.create_index(batch_op.f("ix_taskattempt_id"), ["id"], unique=False)

with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.add_column(
sa.Column("current_attempt_id", sqlmodel.sql.sqltypes.GUID(), nullable=True)
)
batch_op.add_column(sa.Column("current_attempt_id", sa.Uuid, nullable=True))
batch_op.add_column(
sa.Column(
"attempt_counter", sa.Integer(), nullable=True, server_default="0"
Expand Down Expand Up @@ -90,20 +88,20 @@ def upgrade_with_autocommit() -> None:

Task = sa.table(
"task",
sa.column("id", sqlmodel.sql.sqltypes.GUID()),
sa.column("id", sa.Uuid),
sa.column("assigned_at", sa.DateTime()),
sa.column("last_keepalive", sa.DateTime()),
sa.column("completed_at", sa.DateTime()),
sa.column("is_completed", sa.Boolean()),
sa.column("completion_data", sa.JSON()),
sa.column("assigned_worker_id", sqlmodel.sql.sqltypes.GUID()),
sa.column("assigned_worker_id", sa.Uuid),
sa.column("state_changed_at", sa.DateTime()),
sa.column(
"state",
sa.Enum("NEW", "ASSIGNED", "COMPLETED", "FAILED", name="taskstate"),
),
sa.column("remaining_attempts", sa.Integer()),
sa.column("current_attempt_id", sqlmodel.sql.sqltypes.GUID()),
sa.column("current_attempt_id", sa.Uuid),
sa.column("progress", sa.Float()),
sa.column("attempt_counter", sa.Integer()),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"apitoken",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Fix accidentally nullable fields.

Revision ID: c88376bf4844
Revises: 417eece003cb
Create Date: 2025-01-14 16:23:54.386629
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "c88376bf4844"
down_revision = "417eece003cb"
branch_labels = None
depends_on = None

def upgrade() -> None:
    """Make foreign-key columns that were accidentally nullable NOT NULL.

    Tightens ``task.document_id`` and both foreign keys of
    ``taskdependency``. ``batch_alter_table`` is used so the ALTER also
    works on backends (e.g. SQLite) without native ALTER COLUMN support.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("task", schema=None) as batch_op:
        batch_op.alter_column("document_id", existing_type=sa.UUID(), nullable=False)

    with op.batch_alter_table("taskdependency", schema=None) as batch_op:
        batch_op.alter_column(
            "dependent_task_id", existing_type=sa.UUID(), nullable=False
        )
        batch_op.alter_column(
            "dependant_on_id", existing_type=sa.UUID(), nullable=False
        )

    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert the NOT NULL constraints added by this revision's upgrade.

    Relaxes the same three columns back to nullable, processing the
    tables in the reverse order of the upgrade.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("taskdependency", schema=None) as batch_op:
        batch_op.alter_column("dependant_on_id", existing_type=sa.UUID(), nullable=True)
        batch_op.alter_column(
            "dependent_task_id", existing_type=sa.UUID(), nullable=True
        )

    with op.batch_alter_table("task", schema=None) as batch_op:
        batch_op.alter_column("document_id", existing_type=sa.UUID(), nullable=True)

    # ### end Alembic commands ###
Loading

0 comments on commit bb83dec

Please sign in to comment.