diff --git a/README.md b/README.md index a714d43..ccc25e9 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,7 @@ You can use below variables in jinja2 template - `operation.response` response object - `operation.function_name` function name is created `operationId` or `METHOD` + `Path` - `operation.snake_case_arguments` Snake-cased function arguments + - `operation.security` [Security](https://swagger.io/docs/specification/authentication/) ### default template `main.jinja2` diff --git a/fastapi_code_generator/parser.py b/fastapi_code_generator/parser.py index 681db9d..5208fa8 100644 --- a/fastapi_code_generator/parser.py +++ b/fastapi_code_generator/parser.py @@ -89,6 +89,7 @@ class Operation(CachedPropertyModel): responses: Dict[UsefulStr, Any] = {} requestBody: Dict[str, Any] = {} imports: List[Import] = [] + security: Optional[List[Dict[str, List[str]]]] = None @cached_property def root_path(self) -> UsefulStr: @@ -299,6 +300,7 @@ class Operations(BaseModel): options: Optional[Operation] = None trace: Optional[Operation] = None path: UsefulStr + security: Optional[List[Dict[str, List[str]]]] = [] @root_validator(pre=True) def inject_path_and_type_to_operation(cls, values: Dict[str, Any]) -> Any: @@ -311,20 +313,26 @@ def inject_path_and_type_to_operation(cls, values: Dict[str, Any]) -> Any: }, path=path, parameters=values.get('parameters', []), + security=values.get('security'), ) @root_validator - def inject_parameters_to_operation(cls, values: Dict[str, Any]) -> Any: - if parameters := values.get('parameters'): - for operation_name in OPERATION_NAMES: - if operation := values.get(operation_name): + def inject_parameters_and_security_to_operation(cls, values: Dict[str, Any]) -> Any: + security = values.get('security') + for operation_name in OPERATION_NAMES: + if operation := values.get(operation_name): + if parameters := values.get('parameters'): operation.parameters.extend(parameters) + if security is not None and operation.security is None: + operation.security = security + return values class Path(CachedPropertyModel): path: UsefulStr operations: Optional[Operations] = None + security: Optional[List[Dict[str, List[str]]]] = [] @root_validator(pre=True) def validate_root(cls, values: Dict[str, Any]) -> Any: @@ -332,9 +340,13 @@ def validate_root(cls, values: Dict[str, Any]) -> Any: if isinstance(path, str): if operations := values.get('operations'): if isinstance(operations, dict): + security = values.get('security', []) return { 'path': path, - 'operations': dict(**operations, path=path), + 'operations': dict( + **operations, path=path, security=security + ), + 'security': security, } return values @@ -379,15 +391,21 @@ def __init__( def parse(self) -> ParsedObject: openapi = load_json_or_yaml(self.input_text) - return self.parse_paths(openapi["paths"]) + return self.parse_paths(openapi) + + def parse_security( + self, openapi: Dict[str, Any] + ) -> Optional[List[Dict[str, List[str]]]]: + return openapi.get('security') - def parse_paths(self, paths: Dict[str, Any]) -> ParsedObject: + def parse_paths(self, openapi: Dict[str, Any]) -> ParsedObject: + security = self.parse_security(openapi) return ParsedObject( [ operation - for path_name, operations in paths.items() + for path_name, operations in openapi['paths'].items() for operation in Path( - path=UsefulStr(path_name), operations=operations + path=UsefulStr(path_name), operations=operations, security=security ).exists_operations ] ) diff --git a/pyproject.toml b/pyproject.toml index ebb14db..4fbccca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ freezegun = "^0.3.15" line-length = 88 skip-string-normalization = true target-version = ['py38'] +exclude = '(tests/data|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist|.*\/models\.py.*|.*\/models\/.*)' [tool.isort] multi_line_output = 3 @@ -60,6 +61,7 @@ include_trailing_comma = true force_grid_wrap = 0 use_parentheses = true line_length = 88 +skip = "tests/data" [tool.pydantic-pycharm-plugin.parsable-types] # str field may parse int and float diff --git a/tests/data/custom_template/security/main.jinja2 b/tests/data/custom_template/security/main.jinja2 new file mode 100644 index 0000000..ad228d2 --- /dev/null +++ b/tests/data/custom_template/security/main.jinja2 @@ -0,0 +1,47 @@ + +from __future__ import annotations + +from typing import List, Optional + +from fastapi import Depends, FastAPI, HTTPException, Query +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from pydantic import BaseModel +from starlette import status + +{{ imports }} + +app = FastAPI() + + +DUMMY_CREDENTIALS = 'abcdefg' + + +class User(BaseModel): + username: str + email: str + + +def get_dummy_user(token: str) -> User: + return User(username=token, email='abc@example.com') + + +async def valid_token(auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())) -> str: + if auth.credentials == DUMMY_CREDENTIALS: + return 'dummy' + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def valid_current_user(token: str = Depends(valid_token)) -> User: + return get_dummy_user(token) + + + +{% for operation in operations %} +@app.{{operation.type}}('{{operation.snake_case_path}}', response_model={{operation.response}}) +def {{operation.function_name}}({{operation.snake_case_arguments}}{%- if operation.security -%}{%- if operation.snake_case_arguments -%}, {%- endif -%}user: User = Depends(valid_current_user){%- endif -%}) -> {{operation.response}}: + pass +{% endfor %} \ No newline at end of file diff --git a/tests/data/expected/openapi/custom_template_security/custom_security/main.py b/tests/data/expected/openapi/custom_template_security/custom_security/main.py new file mode 100644 index 0000000..259c504 --- /dev/null +++ b/tests/data/expected/openapi/custom_template_security/custom_security/main.py @@ -0,0 +1,80 @@ +# generated by fastapi-codegen: +# filename: custom_security.yaml +# timestamp: 2020-06-19T00:00:00+00:00 + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import BaseModel + +from fastapi import Depends, FastAPI, HTTPException, Query +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from starlette import status + +from .models import Pet, PetForm + +app = FastAPI() + + +DUMMY_CREDENTIALS = 'abcdefg' + + +class User(BaseModel): + username: str + email: str + + +def get_dummy_user(token: str) -> User: + return User(username=token, email='abc@example.com') + + +async def valid_token( + auth: HTTPAuthorizationCredentials = Depends(HTTPBearer()), +) -> str: + if auth.credentials == DUMMY_CREDENTIALS: + return 'dummy' + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def valid_current_user(token: str = Depends(valid_token)) -> User: + return get_dummy_user(token) + + +@app.get('/food/{food_id}', response_model=None) +def show_food_by_id(food_id: str, user: User = Depends(valid_current_user)) -> None: + pass + + +@app.get('/pets', response_model=List[Pet]) +def list_pets( + limit: Optional[int] = 0, + home_address: Optional[str] = Query('Unknown', alias='HomeAddress'), + kind: Optional[str] = 'dog', +) -> List[Pet]: + pass + + +@app.post('/pets', response_model=None) +def post_pets(body: PetForm, user: User = Depends(valid_current_user)) -> None: + pass + + +@app.get('/pets/{pet_id}', response_model=Pet) +def show_pet_by_id( + pet_id: str = Query(..., alias='petId'), user: User = Depends(valid_current_user) +) -> Pet: + pass + + +@app.put('/pets/{pet_id}', response_model=None) +def put_pets_pet_id( + pet_id: str = Query(..., alias='petId'), + body: PetForm = None, + user: User = Depends(valid_current_user), +) -> None: + pass diff --git a/tests/data/expected/openapi/custom_template_security/custom_security/models.py b/tests/data/expected/openapi/custom_template_security/custom_security/models.py new file mode 100644 index 0000000..3edc4d2 --- /dev/null +++ b/tests/data/expected/openapi/custom_template_security/custom_security/models.py @@ -0,0 +1,23 @@ +# generated by datamodel-codegen: +# filename: custom_security.yaml +# timestamp: 2020-06-19T00:00:00+00:00 + +from typing import Optional + +from pydantic import BaseModel + + +class Pet(BaseModel): + id: int + name: str + tag: Optional[str] = None + + +class Error(BaseModel): + code: int + message: str + + +class PetForm(BaseModel): + name: Optional[str] = None + age: Optional[int] = None diff --git a/tests/data/expected/openapi/body_and_parameters/main.py b/tests/data/expected/openapi/default_template/body_and_parameters/main.py similarity index 100% rename from tests/data/expected/openapi/body_and_parameters/main.py rename to tests/data/expected/openapi/default_template/body_and_parameters/main.py diff --git a/tests/data/expected/openapi/body_and_parameters/models.py b/tests/data/expected/openapi/default_template/body_and_parameters/models.py similarity index 100% rename from tests/data/expected/openapi/body_and_parameters/models.py rename to tests/data/expected/openapi/default_template/body_and_parameters/models.py diff --git a/tests/data/expected/openapi/simple/main.py b/tests/data/expected/openapi/default_template/simple/main.py similarity index 100% rename from tests/data/expected/openapi/simple/main.py rename to tests/data/expected/openapi/default_template/simple/main.py diff --git a/tests/data/expected/openapi/simple/models.py b/tests/data/expected/openapi/default_template/simple/models.py similarity index 100% rename from tests/data/expected/openapi/simple/models.py rename to tests/data/expected/openapi/default_template/simple/models.py diff --git a/tests/data/openapi/custom_template_security/custom_security.yaml b/tests/data/openapi/custom_template_security/custom_security.yaml new file mode 100644 index 0000000..f2674c2 --- /dev/null +++ b/tests/data/openapi/custom_template_security/custom_security.yaml @@ -0,0 +1,193 @@ +openapi: "3.0.0" +info: + version: 1.0.0 + title: Swagger Petstore + license: + name: MIT +servers: + - url: http://petstore.swagger.io/v1 +security: + - BearerAuth: [] +paths: + /pets: + get: + summary: List all pets + operationId: listPets + tags: + - pets + security: [] + parameters: + - name: limit + in: query + description: How many items to return at one time (max 100) + required: false + schema: + default: 0 + type: integer + format: int32 + - name: HomeAddress + in: query + required: false + schema: + default: 'Unknown' + type: string + - name: kind + in: query + required: false + schema: + default: dog + type: string + responses: + '200': + description: A paged array of pets + headers: + x-next: + description: A link to the next page of responses + schema: + type: string + content: + application/json: + schema: + type: array + items: + - $ref: "#/components/schemas/Pet" + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + post: + summary: Create a pet + tags: + - pets + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PetForm' + responses: + '201': + description: Null response + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /pets/{petId}: + get: + summary: Info for a specific pet + operationId: showPetById + tags: + - pets + parameters: + - name: petId + in: path + required: true + description: The id of the pet to retrieve + schema: + type: string + responses: + '200': + description: Expected response to a valid request + content: + application/json: + schema: + $ref: "#/components/schemas/Pet" + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + put: + parameters: + - name: petId + in: path + required: true + description: The id of the pet to retrieve + schema: + type: string + summary: update a pet + tags: + - pets + requestBody: + required: false + content: + application/json: + schema: + $ref: '#/components/schemas/PetForm' + responses: + '201': + description: Null response + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + x-amazon-apigateway-integration: + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${PythonVersionFunction.Arn}/invocations + passthroughBehavior: when_no_templates + httpMethod: POST + type: aws_proxy + /food/{food_id}: + get: + summary: Info for a specific pet + operationId: showFoodById + tags: + - foods + parameters: + - name: food_id + in: path + description: The id of the food to retrieve + schema: + type: string + responses: + '200': + description: Expected response to a valid request + x-amazon-apigateway-integration: + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${PythonVersionFunction.Arn}/invocations + passthroughBehavior: when_no_templates + httpMethod: POST + type: aws_proxy +components: + securitySchemes: + BearerAuth: + type: http + scheme: bearer + schemas: + Pet: + required: + - id + - name + properties: + id: + type: integer + format: int64 + name: + type: string + tag: + type: string + Error: + required: + - code + - message + properties: + code: + type: integer + format: int32 + message: + type: string + PetForm: + title: PetForm + type: object + properties: + name: + type: string + age: + type: integer \ No newline at end of file diff --git a/tests/data/openapi/body_and_parameters.yaml b/tests/data/openapi/default_template/body_and_parameters.yaml similarity index 97% rename from tests/data/openapi/body_and_parameters.yaml rename to tests/data/openapi/default_template/body_and_parameters.yaml index c3106b0..f2674c2 100644 --- a/tests/data/openapi/body_and_parameters.yaml +++ b/tests/data/openapi/default_template/body_and_parameters.yaml @@ -6,6 +6,8 @@ info: name: MIT servers: - url: http://petstore.swagger.io/v1 +security: + - BearerAuth: [] paths: /pets: get: @@ -13,6 +15,7 @@ paths: operationId: listPets tags: - pets + security: [] parameters: - name: limit in: query @@ -153,6 +156,10 @@ paths: httpMethod: POST type: aws_proxy components: + securitySchemes: + BearerAuth: + type: http + scheme: bearer schemas: Pet: required: diff --git a/tests/data/openapi/simple.yaml b/tests/data/openapi/default_template/simple.yaml similarity index 100% rename from tests/data/openapi/simple.yaml rename to tests/data/openapi/default_template/simple.yaml diff --git a/tests/test_generate.py b/tests/test_generate.py index 43fc1d5..d8c4ac9 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -6,18 +6,19 @@ from fastapi_code_generator.__main__ import generate_code -OPEN_API_DIR_NAME = 'openapi' +OPEN_API_DEFAULT_TEMPLATE_DIR_NAME = Path('openapi') / 'default_template' +OPEN_API_SECURITY_TEMPLATE_DIR_NAME = Path('openapi') / 'custom_template_security' DATA_DIR = Path(__file__).parent / 'data' -OPEN_API_DIR = DATA_DIR / OPEN_API_DIR_NAME - EXPECTED_DIR = DATA_DIR / 'expected' -@pytest.mark.parametrize("oas_file", OPEN_API_DIR.glob("*.yaml")) +@pytest.mark.parametrize( + "oas_file", (DATA_DIR / OPEN_API_DEFAULT_TEMPLATE_DIR_NAME).glob("*.yaml") +) @freeze_time("2020-06-19") -def test_generate_simple(oas_file): +def test_generate_default_template(oas_file): with TemporaryDirectory() as tmp_dir: output_dir = Path(tmp_dir) / oas_file.stem generate_code( @@ -26,7 +27,30 @@ def test_generate_simple(oas_file): output_dir=output_dir, template_dir=None, ) - expected_dir = EXPECTED_DIR / OPEN_API_DIR_NAME / oas_file.stem + expected_dir = EXPECTED_DIR / OPEN_API_DEFAULT_TEMPLATE_DIR_NAME / oas_file.stem + output_files = sorted(list(output_dir.glob('*'))) + expected_files = sorted(list(expected_dir.glob('*'))) + assert [f.name for f in output_files] == [f.name for f in expected_files] + for output_file, expected_file in zip(output_files, expected_files): + assert output_file.read_text() == expected_file.read_text() + + +@pytest.mark.parametrize( + "oas_file", (DATA_DIR / OPEN_API_SECURITY_TEMPLATE_DIR_NAME).glob("*.yaml") +) +@freeze_time("2020-06-19") +def test_generate_custom_security_template(oas_file): + with TemporaryDirectory() as tmp_dir: + output_dir = Path(tmp_dir) / oas_file.stem + generate_code( + input_name=oas_file.name, + input_text=oas_file.read_text(), + output_dir=output_dir, + template_dir=DATA_DIR / 'custom_template' / 'security', + ) + expected_dir = ( + EXPECTED_DIR / OPEN_API_SECURITY_TEMPLATE_DIR_NAME / oas_file.stem + ) output_files = sorted(list(output_dir.glob('*'))) expected_files = sorted(list(expected_dir.glob('*'))) assert [f.name for f in output_files] == [f.name for f in expected_files]