Skip to content

Commit

Permalink
Optimize and normalize the code generator (#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Sep 24, 2024
1 parent 1d3b0e7 commit ab4495c
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 194 deletions.
8 changes: 7 additions & 1 deletion backend/app/generator/api/v1/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def get_all_businesses() -> ResponseModel:

@router.get('/businesses/{pk}', summary='获取代码生成业务详情', dependencies=[DependsJwtAuth])
async def get_business(pk: Annotated[int, Path(...)]) -> ResponseModel:
business = await gen_service.get_business_with_model(pk=pk)
business = await gen_business_service.get(pk=pk)
data = GetGenBusinessListDetails(**select_as_dict(business))
return response_base.success(data=data)

Expand Down Expand Up @@ -89,6 +89,12 @@ async def delete_business(pk: Annotated[int, Path(...)]) -> ResponseModel:
return response_base.fail()


@router.get('/models/types', summary='获取代码生成模型列类型', dependencies=[DependsJwtAuth])
async def get_model_types() -> ResponseModel:
model_types = await gen_model_service.get_types()
return response_base.success(data=model_types)


@router.get('/models/{pk}', summary='获取代码生成模型详情', dependencies=[DependsJwtAuth])
async def get_model(pk: Annotated[int, Path(...)]) -> ResponseModel:
model = await gen_model_service.get(pk=pk)
Expand Down
12 changes: 1 addition & 11 deletions backend/app/generator/crud/crud_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,11 @@
# -*- coding: utf-8 -*-
from typing import Sequence

from sqlalchemy import Row, select, text
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from backend.app.generator.model import GenBusiness


class CRUDGen:
@staticmethod
async def get_business_with_model(db: AsyncSession, business_id: int) -> GenBusiness:
stmt = select(GenBusiness).options(selectinload(GenBusiness.gen_model)).where(GenBusiness.id == business_id)
result = await db.execute(stmt)
data = result.scalars().first()
return data

@staticmethod
async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]:
stmt = text(
Expand Down
7 changes: 6 additions & 1 deletion backend/app/generator/crud/crud_gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from typing import Sequence

from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus

Expand Down Expand Up @@ -34,6 +35,7 @@ async def create(self, db: AsyncSession, obj_in: CreateGenModelParam, pd_type: s
:param db:
:param obj_in:
:param pd_type:
:return:
"""
await self.create_model(db, obj_in, pd_type=pd_type)
Expand All @@ -45,9 +47,12 @@ async def update(self, db: AsyncSession, pk: int, obj_in: UpdateGenModelParam, p
:param db:
:param pk:
:param obj_in:
:param pd_type:
:return:
"""
return await self.update_model_by_column(db, obj_in, id=pk, pd_type=pd_type)
stmt = update(self.model).where(self.model.id == pk).values(**obj_in.model_dump(), pd_type=pd_type)
result = await db.execute(stmt)
return result.rowcount

async def delete(self, db: AsyncSession, pk: int) -> int:
"""
Expand Down
3 changes: 1 addition & 2 deletions backend/app/generator/schema/gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
# -*- coding: utf-8 -*-
from pydantic import ConfigDict, Field, field_validator

from backend.common.enums import GenModelColumnType
from backend.common.schema import SchemaBase
from backend.utils.type_conversion import sql_type_to_sqlalchemy


class GenModelSchemaBase(SchemaBase):
name: str
comment: str | None = None
type: GenModelColumnType = Field(GenModelColumnType.VARCHAR)
type: str
default: str | None = None
sort: int
length: int
Expand Down
7 changes: 7 additions & 0 deletions backend/app/generator/service/gen_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from backend.app.generator.crud.crud_gen_model import gen_model_dao
from backend.app.generator.model import GenModel
from backend.app.generator.schema.gen_model import CreateGenModelParam, UpdateGenModelParam
from backend.common.enums import GenModelMySQLColumnType
from backend.common.exception import errors
from backend.database.db_mysql import async_db_session
from backend.utils.type_conversion import sql_type_to_pydantic
Expand All @@ -17,6 +18,12 @@ async def get(*, pk: int) -> GenModel:
gen_model = await gen_model_dao.get(db, pk)
return gen_model

@staticmethod
async def get_types() -> list[str]:
types = GenModelMySQLColumnType.get_member_keys()
types.sort()
return types

@staticmethod
async def get_by_business(*, business_id: int) -> Sequence[GenModel]:
async with async_db_session() as db:
Expand Down
15 changes: 4 additions & 11 deletions backend/app/generator/service/gen_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,14 @@
from backend.app.generator.schema.gen_business import CreateGenBusinessParam
from backend.app.generator.schema.gen_model import CreateGenModelParam
from backend.app.generator.service.gen_model_service import gen_model_service
from backend.common.enums import GenModelColumnType
from backend.common.exception import errors
from backend.core.path_conf import BasePath
from backend.database.db_mysql import async_db_session
from backend.utils.gen_template import gen_template
from backend.utils.type_conversion import sql_type_to_pydantic


class GenService:
@staticmethod
async def get_business_with_model(*, pk: int) -> GenBusiness:
async with async_db_session() as db:
business = await gen_dao.get_business_with_model(db, pk)
return business

@staticmethod
async def get_tables(*, table_schema: str) -> Sequence[str]:
async with async_db_session() as db:
Expand Down Expand Up @@ -60,19 +54,18 @@ async def import_business_and_model(*, app: str, table_schema: str, table_name:
column_info = await gen_dao.get_all_columns(db, table_schema, table_name)
for column in column_info:
column_type = column[-1].split('(')[0].upper()
pd_type = sql_type_to_pydantic(column_type)
model_data = {
'name': column[0],
'comment': column[-2],
'type': column_type,
'sort': column[-3],
'length': column[-1].split('(')[1][:-1]
if column_type == GenModelColumnType.CHAR or column_type == GenModelColumnType.VARCHAR
else 0,
'length': column[-1].split('(')[1][:-1] if pd_type == 'str' and '(' in column[-1] else 0,
'is_pk': column[1],
'is_nullable': column[2],
'gen_business_id': new_business.id,
}
await gen_model_dao.create(db, CreateGenModelParam(**model_data))
await gen_model_dao.create(db, CreateGenModelParam(**model_data), pd_type=pd_type)

@staticmethod
async def render_tpl_code(*, business: GenBusiness) -> dict:
Expand Down
195 changes: 148 additions & 47 deletions backend/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,50 +90,151 @@ class UserSocialType(StrEnum):
linuxdo = 'LinuxDo'


class GenModelColumnType(StrEnum):
"""代码生成模型列类型"""

BIGINT = 'BIGINT'
BINARY = 'BINARY'
BIT = 'BIT'
BLOB = 'BLOB'
BOOL = 'BOOL'
BOOLEAN = 'BOOLEAN'
CHAR = 'CHAR'
DATE = 'DATE'
DATETIME = 'DATETIME'
DECIMAL = 'DECIMAL'
DOUBLE = 'DOUBLE'
DOUBLE_PRECISION = 'DOUBLE PRECISION'
ENUM = 'ENUM'
FLOAT = 'FLOAT'
GEOMETRY = 'GEOMETRY'
GEOMETRYCOLLECTION = 'GEOMETRYCOLLECTION'
INT = 'INT'
INTEGER = 'INTEGER'
JSON = 'JSON'
LINESTRING = 'LINESTRING'
LONGBLOB = 'LONGBLOB'
LONGTEXT = 'LONGTEXT'
MEDIUMBLOB = 'MEDIUMBLOB'
MEDIUMINT = 'MEDIUMINT'
MEDIUMTEXT = 'MEDIUMTEXT'
MULTILINESTRING = 'MULTILINESTRING'
MULTIPOINT = 'MULTIPOINT'
MULTIPOLYGON = 'MULTIPOLYGON'
NUMERIC = 'NUMERIC'
POINT = 'POINT'
POLYGON = 'POLYGON'
REAL = 'REAL'
SERIAL = 'SERIAL'
SET = 'SET'
SMALLINT = 'SMALLINT'
TEXT = 'TEXT'
TIME = 'TIME'
TIMESTAMP = 'TIMESTAMP'
TINYBLOB = 'TINYBLOB'
TINYINT = 'TINYINT'
TINYTEXT = 'TINYTEXT'
VARBINARY = 'VARBINARY'
VARCHAR = 'VARCHAR'
YEAR = 'YEAR'
class GenModelMySQLColumnType(StrEnum):
"""代码生成模型列类型(MySQL)"""

# Python 类型映射
BIGINT = 'int'
BigInteger = 'int' # BIGINT
BINARY = 'bytes'
BLOB = 'bytes'
BOOLEAN = 'bool' # BOOL
Boolean = 'bool' # BOOL
CHAR = 'str'
CLOB = 'str'
DATE = 'date'
Date = 'date' # DATE
DATETIME = 'datetime'
DateTime = 'datetime' # DATETIME
DECIMAL = 'Decimal'
DOUBLE = 'float'
Double = 'float' # DOUBLE
DOUBLE_PRECISION = 'float'
Enum = 'Enum' # Enum()
FLOAT = 'float'
Float = 'float' # FLOAT
INT = 'int' # INTEGER
INTEGER = 'int'
Integer = 'int' # INTEGER
Interval = 'timedelta' # DATETIME
JSON = 'dict'
LargeBinary = 'bytes' # BLOB
NCHAR = 'str'
NUMERIC = 'Decimal'
Numeric = 'Decimal' # NUMERIC
NVARCHAR = 'str' # String
PickleType = 'bytes' # BLOB
REAL = 'float'
SMALLINT = 'int'
SmallInteger = 'int' # SMALLINT
String = 'str' # String
TEXT = 'str'
Text = 'str' # TEXT
TIME = 'time'
Time = 'time' # TIME
TIMESTAMP = 'datetime'
Unicode = 'str' # String
UnicodeText = 'str' # TEXT
UUID = 'str | UUID'
Uuid = 'str' # CHAR(32)
VARBINARY = 'bytes'
VARCHAR = 'str' # String

# sa.dialects.mysql 导入
BIT = 'bool'
ENUM = 'Enum'
LONGBLOB = 'bytes'
LONGTEXT = 'str'
MEDIUMBLOB = 'bytes'
MEDIUMINT = 'int'
MEDIUMTEXT = 'str'
SET = 'list[str]'
TINYBLOB = 'bytes'
TINYINT = 'int'
TINYTEXT = 'str'
YEAR = 'int'


class GenModelPostgreSQLColumnType(StrEnum):
"""代码生成模型列类型(PostgreSQL),仅作为数据保留,并未实施"""

# Python 类型映射
BIGINT = 'int'
BigInteger = 'int' # BIGINT
BINARY = 'bytes'
BLOB = 'bytes'
BOOLEAN = 'bool'
Boolean = 'bool' # BOOLEAN
CHAR = 'str'
CLOB = 'str'
DATE = 'date'
Date = 'date' # DATE
DATETIME = 'datetime'
DateTime = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
DECIMAL = 'Decimal'
DOUBLE = 'float'
Double = 'float' # DOUBLE PRECISION
DOUBLE_PRECISION = 'float' # DOUBLE PRECISION
Enum = 'Enum' # Enum(name='enum')
FLOAT = 'float'
Float = 'float' # FLOAT
INT = 'int' # INTEGER
INTEGER = 'int'
Integer = 'int' # INTEGER
Interval = 'timedelta' # INTERVAL
JSON = 'dict'
LargeBinary = 'bytes' # BYTEA
NCHAR = 'str'
NUMERIC = 'Decimal'
Numeric = 'Decimal' # NUMERIC
NVARCHAR = 'str' # String
PickleType = 'bytes' # BYTEA
REAL = 'float'
SMALLINT = 'int'
SmallInteger = 'int' # SMALLINT
String = 'str' # String
TEXT = 'str'
Text = 'str' # TEXT
TIME = 'time' # TIME WITHOUT TIME ZONE
Time = 'time' # TIME WITHOUT TIME ZONE
TIMESTAMP = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
Unicode = 'str' # String
UnicodeText = 'str' # TEXT
UUID = 'str | UUID'
Uuid = 'str'
VARBINARY = 'bytes'
VARCHAR = 'str' # String

# sa.dialects.postgresql 导入
ARRAY = 'list'
BIT = 'bool'
BYTEA = 'bytes'
CIDR = 'str'
CITEXT = 'str'
DATEMULTIRANGE = 'list[date]'
DATERANGE = 'tuple[date, date]'
DOMAIN = 'str'
ENUM = 'Enum'
HSTORE = 'dict'
INET = 'str'
INT4MULTIRANGE = 'list[int]'
INT4RANGE = 'tuple[int, int]'
INT8MULTIRANGE = 'list[int]'
INT8RANGE = 'tuple[int, int]'
INTERVAL = 'timedelta'
JSONB = 'dict'
JSONPATH = 'str'
MACADDR = 'str'
MACADDR8 = 'str'
MONEY = 'Decimal'
NUMMULTIRANGE = 'list[Decimal]'
NUMRANGE = 'tuple[Decimal, Decimal]'
OID = 'int'
REGCLASS = 'str'
REGCONFIG = 'str'
TSMULTIRANGE = 'list[datetime]'
TSQUERY = 'str'
TSRANGE = 'tuple[datetime, datetime]'
TSTZMULTIRANGE = 'list[datetime]'
TSTZRANGE = 'tuple[datetime, datetime]'
TSVECTOR = 'str'
Loading

0 comments on commit ab4495c

Please sign in to comment.