From 215fa1bfbccafe96ac87e893cb4da25b4b71030e Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Mon, 11 Dec 2023 17:03:02 +0530 Subject: [PATCH] APIGW: Fix tests to run with non-default account ID (#9690) --- localstack/services/apigateway/helpers.py | 6 +- localstack/services/apigateway/integration.py | 10 +- localstack/utils/testutil.py | 8 +- tests/aws/services/apigateway/conftest.py | 5 + .../apigateway/test_apigateway_basic.py | 157 +++++++++++++----- 5 files changed, 141 insertions(+), 45 deletions(-) diff --git a/localstack/services/apigateway/helpers.py b/localstack/services/apigateway/helpers.py index e9fe7f3eaf4dd..30c100730efd2 100644 --- a/localstack/services/apigateway/helpers.py +++ b/localstack/services/apigateway/helpers.py @@ -585,8 +585,10 @@ def get_stage_variables(context: ApiInvocationContext) -> Optional[Dict[str, str if not context.stage: return {} - _, region_name = get_api_account_id_and_region(context.api_id) - api_gateway_client = connect_to(region_name=region_name).apigateway + account_id, region_name = get_api_account_id_and_region(context.api_id) + api_gateway_client = connect_to( + aws_access_key_id=account_id, region_name=region_name + ).apigateway try: response = api_gateway_client.get_stage(restApiId=context.api_id, stageName=context.stage) return response.get("variables") diff --git a/localstack/services/apigateway/integration.py b/localstack/services/apigateway/integration.py index 34c7833ed3550..91c29e281700c 100644 --- a/localstack/services/apigateway/integration.py +++ b/localstack/services/apigateway/integration.py @@ -168,9 +168,9 @@ def get_source_arn(invocation_context: ApiInvocationContext): def call_lambda( function_arn: str, event: bytes, asynchronous: bool, invocation_context: ApiInvocationContext ) -> str: - region_name = extract_region_from_arn(function_arn) clients = get_service_factory( - region_name=region_name, role_arn=invocation_context.integration.get("credentials") + region_name=extract_region_from_arn(function_arn), + role_arn=invocation_context.integration.get("credentials"), ) inv_result = clients.lambda_.request_metadata( service_principal=ServicePrincipal.apigateway, source_arn=get_source_arn(invocation_context) @@ -754,7 +754,11 @@ def invoke(self, invocation_context: ApiInvocationContext): else: payload = json.loads(invocation_context.data) - client = connect_to().stepfunctions + client = get_service_factory( + region_name=invocation_context.region_name, + role_arn=invocation_context.integration.get("credentials"), + ).stepfunctions + if isinstance(payload.get("input"), dict): payload["input"] = json.dumps(payload["input"]) diff --git a/localstack/utils/testutil.py b/localstack/utils/testutil.py index eca5f9aed8f66..d05826768ace3 100644 --- a/localstack/utils/testutil.py +++ b/localstack/utils/testutil.py @@ -290,6 +290,7 @@ def connect_api_gateway_to_http_with_lambda_proxy( auth_creator_func=None, http_method=None, client=None, + role_arn: str = None, ): if methods is None: methods = [] @@ -304,12 +305,15 @@ def connect_api_gateway_to_http_with_lambda_proxy( for method in methods: int_meth = http_method or method + integration = {"type": "AWS_PROXY", "uri": target_uri, "httpMethod": int_meth} + if role_arn: + integration["credentials"] = role_arn resources[resource_path].append( { "httpMethod": method, "authorizationType": auth_type, "authorizerId": None, - "integrations": [{"type": "AWS_PROXY", "uri": target_uri, "httpMethod": int_meth}], + "integrations": [integration], } ) return resource_utils.create_api_gateway( @@ -332,6 +336,7 @@ def create_lambda_api_gateway_integration( stage_name=None, auth_type=None, auth_creator_func=None, + role_arn: str = None, ): if methods is None: methods = [] @@ -355,6 +360,7 @@ def create_lambda_api_gateway_integration( methods=methods, auth_type=auth_type, auth_creator_func=auth_creator_func, + role_arn=role_arn, ) return result diff --git a/tests/aws/services/apigateway/conftest.py b/tests/aws/services/apigateway/conftest.py index 2e5f4423a85b8..cf42c066b1a7c 100644 --- a/tests/aws/services/apigateway/conftest.py +++ b/tests/aws/services/apigateway/conftest.py @@ -43,6 +43,11 @@ "Statement": [{"Effect": "Allow", "Action": "lambda:*", "Resource": "*"}], } +APIGATEWAY_S3_POLICY = { + "Version": "2012-10-17", + "Statement": [{"Effect": "Allow", "Action": "s3:*", "Resource": "*"}], +} + APIGATEWAY_DYNAMODB_POLICY = { "Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Action": "dynamodb:*", "Resource": "*"}], diff --git a/tests/aws/services/apigateway/test_apigateway_basic.py b/tests/aws/services/apigateway/test_apigateway_basic.py index 5176299e494b3..f235fe441fa18 100644 --- a/tests/aws/services/apigateway/test_apigateway_basic.py +++ b/tests/aws/services/apigateway/test_apigateway_basic.py @@ -59,8 +59,10 @@ ) from tests.aws.services.apigateway.conftest import ( APIGATEWAY_ASSUME_ROLE_POLICY, + APIGATEWAY_DYNAMODB_POLICY, APIGATEWAY_KINESIS_POLICY, APIGATEWAY_LAMBDA_POLICY, + APIGATEWAY_S3_POLICY, APIGATEWAY_STEPFUNCTIONS_POLICY, STEPFUNCTIONS_ASSUME_ROLE_POLICY, ) @@ -115,6 +117,12 @@ ] }""" +API_PATH_LAMBDA_PROXY_BACKEND = "/lambda/foo1" +API_PATH_LAMBDA_PROXY_BACKEND_WITH_PATH_PARAM = "/lambda/{test_param1}" +API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD = "/lambda-any-method/foo1" +API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD_WITH_PATH_PARAM = "/lambda-any-method/{test_param1}" +API_PATH_LAMBDA_PROXY_BACKEND_WITH_IS_BASE64 = "/lambda-is-base64/foo1" + @pytest.fixture def integration_lambda(create_lambda_function): @@ -125,11 +133,6 @@ def integration_lambda(create_lambda_function): class TestAPIGateway: # endpoint paths - API_PATH_LAMBDA_PROXY_BACKEND = "/lambda/foo1" - API_PATH_LAMBDA_PROXY_BACKEND_WITH_PATH_PARAM = "/lambda/{test_param1}" - API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD = "/lambda-any-method/foo1" - API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD_WITH_PATH_PARAM = "/lambda-any-method/{test_param1}" - API_PATH_LAMBDA_PROXY_BACKEND_WITH_IS_BASE64 = "/lambda-is-base64/foo1" TEST_API_GATEWAY_AUTHORIZER = { "name": "test", @@ -380,6 +383,7 @@ def test_invoke_endpoint_cors_headers( ] api_id = self.create_api_gateway_and_deploy( aws_client.apigateway, + aws_client.dynamodb, integration_type="MOCK", integration_responses=responses, stage_name=TEST_STAGE_NAME, @@ -401,20 +405,30 @@ def test_invoke_endpoint_cors_headers( assert "http://test.com" in response.headers["Access-Control-Allow-Origin"] @markers.aws.unknown - def test_api_gateway_lambda_proxy_integration(self, integration_lambda): - self._test_api_gateway_lambda_proxy_integration( - integration_lambda, self.API_PATH_LAMBDA_PROXY_BACKEND + @pytest.mark.parametrize( + "api_path", [API_PATH_LAMBDA_PROXY_BACKEND, API_PATH_LAMBDA_PROXY_BACKEND_WITH_PATH_PARAM] + ) + def test_api_gateway_lambda_proxy_integration( + self, api_path, integration_lambda, aws_client, create_iam_role_with_policy + ): + role_arn = create_iam_role_with_policy( + RoleName=f"role-apigw-lambda-{short_uid()}", + PolicyName=f"policy-apigw-lambda-{short_uid()}", + RoleDefinition=APIGATEWAY_ASSUME_ROLE_POLICY, + PolicyDefinition=APIGATEWAY_KINESIS_POLICY, ) - @markers.aws.unknown - def test_api_gateway_lambda_proxy_integration_with_path_param(self, integration_lambda): self._test_api_gateway_lambda_proxy_integration( integration_lambda, - self.API_PATH_LAMBDA_PROXY_BACKEND_WITH_PATH_PARAM, + api_path, + role_arn, + aws_client.apigateway, ) @markers.aws.unknown - def test_api_gateway_lambda_proxy_integration_with_is_base_64_encoded(self, integration_lambda): + def test_api_gateway_lambda_proxy_integration_with_is_base_64_encoded( + self, integration_lambda, aws_client, create_iam_role_with_policy + ): # Test the case where `isBase64Encoded` is enabled. content = b"hello, please base64 encode me" @@ -422,9 +436,18 @@ def _mutate_data(data) -> None: data["return_is_base_64_encoded"] = True data["return_raw_body"] = base64.b64encode(content).decode("utf8") + role_arn = create_iam_role_with_policy( + RoleName=f"role-apigw-lambda-{short_uid()}", + PolicyName=f"policy-apigw-lambda-{short_uid()}", + RoleDefinition=APIGATEWAY_ASSUME_ROLE_POLICY, + PolicyDefinition=APIGATEWAY_LAMBDA_POLICY, + ) + test_result = self._test_api_gateway_lambda_proxy_integration_no_asserts( integration_lambda, - self.API_PATH_LAMBDA_PROXY_BACKEND_WITH_IS_BASE64, + API_PATH_LAMBDA_PROXY_BACKEND_WITH_IS_BASE64, + role_arn, + aws_client.apigateway, data_mutator_fn=_mutate_data, ) @@ -436,6 +459,8 @@ def _test_api_gateway_lambda_proxy_integration_no_asserts( self, fn_name: str, path: str, + role_arn: str, + apigw_client, data_mutator_fn: Optional[Callable] = None, ) -> ApiGatewayLambdaProxyIntegrationTestResult: """ @@ -451,7 +476,12 @@ def _test_api_gateway_lambda_proxy_integration_no_asserts( target_uri = invocation_uri % (TEST_AWS_REGION_NAME, lambda_uri) result = testutil.connect_api_gateway_to_http_with_lambda_proxy( - "test_gateway2", target_uri, path=path, stage_name=TEST_STAGE_NAME + "test_gateway2", + target_uri, + path=path, + stage_name=TEST_STAGE_NAME, + client=apigw_client, + role_arn=role_arn, ) api_id = result["id"] @@ -489,8 +519,12 @@ def _test_api_gateway_lambda_proxy_integration( self, fn_name: str, path: str, + role_arn: str, + apigw_client, ) -> None: - test_result = self._test_api_gateway_lambda_proxy_integration_no_asserts(fn_name, path) + test_result = self._test_api_gateway_lambda_proxy_integration_no_asserts( + fn_name, path, role_arn, apigw_client + ) data, resource, result, url, path_with_replace = test_result assert result.status_code == 203 @@ -550,7 +584,7 @@ def _test_api_gateway_lambda_proxy_integration( @markers.aws.unknown def test_api_gateway_lambda_proxy_integration_any_method(self, integration_lambda): self._test_api_gateway_lambda_proxy_integration_any_method( - integration_lambda, self.API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD + integration_lambda, API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD ) @markers.aws.unknown @@ -559,7 +593,7 @@ def test_api_gateway_lambda_proxy_integration_any_method_with_path_param( ): self._test_api_gateway_lambda_proxy_integration_any_method( integration_lambda, - self.API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD_WITH_PATH_PARAM, + API_PATH_LAMBDA_PROXY_BACKEND_ANY_METHOD_WITH_PATH_PARAM, ) @markers.aws.unknown @@ -875,7 +909,7 @@ def test_api_account(self, create_rest_apigw, aws_client): @markers.aws.unknown def test_put_integration_dynamodb_proxy_validation_without_request_template(self, aws_client): - api_id = self.create_api_gateway_and_deploy(aws_client.apigateway) + api_id = self.create_api_gateway_and_deploy(aws_client.apigateway, aws_client.dynamodb) url = path_based_url(api_id=api_id, stage_name="staging", path="/") response = requests.put( url, @@ -886,11 +920,21 @@ def test_put_integration_dynamodb_proxy_validation_without_request_template(self @markers.aws.unknown def test_put_integration_dynamodb_proxy_validation_with_request_template( - self, aws_client, dynamodb_create_table + self, + aws_client, + dynamodb_create_table, + create_iam_role_with_policy, ): table = dynamodb_create_table() table_name = table["TableDescription"]["TableName"] + role_arn = create_iam_role_with_policy( + RoleName=f"role-apigw-dynamodb-{short_uid()}", + PolicyName=f"policy-apigw-dynamodb-{short_uid()}", + RoleDefinition=APIGATEWAY_ASSUME_ROLE_POLICY, + PolicyDefinition=APIGATEWAY_DYNAMODB_POLICY, + ) + # create API GW with DynamoDB integration request_templates = { "application/json": json.dumps( @@ -904,7 +948,10 @@ def test_put_integration_dynamodb_proxy_validation_with_request_template( ) } api_id = self.create_api_gateway_and_deploy( - aws_client.apigateway, request_templates=request_templates + aws_client.apigateway, + aws_client.dynamodb, + request_templates=request_templates, + role_arn=role_arn, ) url = path_based_url(api_id=api_id, stage_name="staging", path="/") @@ -921,7 +968,7 @@ def test_put_integration_dynamodb_proxy_validation_with_request_template( assert result["Item"]["data"] == {"S": "foobar123"} @markers.aws.unknown - def test_multiple_api_keys_validate(self, aws_client): + def test_multiple_api_keys_validate(self, aws_client, create_iam_role_with_policy): request_templates = { "application/json": json.dumps( { @@ -934,8 +981,19 @@ def test_multiple_api_keys_validate(self, aws_client): ) } + role_arn = create_iam_role_with_policy( + RoleName=f"role-apigw-dynamodb-{short_uid()}", + PolicyName=f"policy-apigw-dynamodb-{short_uid()}", + RoleDefinition=APIGATEWAY_ASSUME_ROLE_POLICY, + PolicyDefinition=APIGATEWAY_DYNAMODB_POLICY, + ) + api_id = self.create_api_gateway_and_deploy( - aws_client.apigateway, request_templates=request_templates, is_api_key_required=True + aws_client.apigateway, + aws_client.dynamodb, + request_templates=request_templates, + is_api_key_required=True, + role_arn=role_arn, ) url = path_based_url(api_id=api_id, stage_name="staging", path="/") @@ -996,6 +1054,7 @@ def test_apigateway_with_step_function_integration( create_rest_apigw, create_iam_role_with_policy, aws_client, + account_id, snapshot, ): snapshot.add_transformer(snapshot.transform.key_value("executionArn", "executionArn")) @@ -1008,7 +1067,6 @@ def test_apigateway_with_step_function_integration( ) region_name = aws_client.apigateway._client_config.region_name - aws_account_id = aws_client.sts.get_caller_identity()["Account"] # create lambda fn_name = f"lambda-sfn-apigw-{short_uid()}" @@ -1019,10 +1077,8 @@ def test_apigateway_with_step_function_integration( )["CreateFunctionResponse"]["FunctionArn"] # create state machine and permissions for step function to invoke lambda - role_name = f"sfn_role-{short_uid()}" - role_arn = arns.iam_role_arn(role_name, account_id=aws_account_id) - create_iam_role_with_policy( - RoleName=role_name, + role_arn = create_iam_role_with_policy( + RoleName=f"sfn_role-{short_uid()}", PolicyName=f"sfn-role-policy-{short_uid()}", RoleDefinition=STEPFUNCTIONS_ASSUME_ROLE_POLICY, PolicyDefinition=APIGATEWAY_LAMBDA_POLICY, @@ -1236,7 +1292,10 @@ def test_api_mock_integration_response_params(self, aws_client): } ] api_id = self.create_api_gateway_and_deploy( - aws_client.apigateway, integration_type="MOCK", integration_responses=resps + aws_client.apigateway, + aws_client.dynamodb, + integration_type="MOCK", + integration_responses=resps, ) url = path_based_url(api_id=api_id, stage_name=TEST_STAGE_NAME, path="/") @@ -1444,14 +1503,12 @@ def test_apigw_stage_variables( ) fn_name = f"test-{short_uid()}" - create_lambda_function( + response = create_lambda_function( func_name=fn_name, handler_file=TEST_LAMBDA_PYTHON_ECHO, runtime=Runtime.python3_9, ) - lambda_arn = aws_client.lambda_.get_function(FunctionName=fn_name)["Configuration"][ - "FunctionArn" - ] + lambda_arn = response["CreateFunctionResponse"]["FunctionArn"] if stage_name == "dev": uri = f"arn:aws:apigateway:{region_name}:lambda:path/2015-03-31/functions/arn:aws:lambda:{region_name}:{aws_account_id}:function:${{stageVariables.lambdaFunction}}/invocations" @@ -1520,12 +1577,14 @@ def test_apigw_stage_variables( @staticmethod def create_api_gateway_and_deploy( apigw_client, + dynamodb_client, request_templates=None, response_templates=None, is_api_key_required=False, integration_type=None, integration_responses=None, stage_name="staging", + role_arn: str = None, ): response_templates = response_templates or {} request_templates = request_templates or {} @@ -1538,10 +1597,15 @@ def create_api_gateway_and_deploy( kwargs = {} if integration_type == "AWS": - resource_util.create_dynamodb_table("MusicCollection", partition_key="id") + resource_util.create_dynamodb_table( + "MusicCollection", partition_key="id", client=dynamodb_client + ) kwargs[ "uri" - ] = "arn:aws:apigateway:us-east-1:dynamodb:action/PutItem&Table=MusicCollection" + ] = f"arn:aws:apigateway:{apigw_client.meta.region_name}:dynamodb:action/PutItem&Table=MusicCollection" + + if role_arn: + kwargs["credentials"] = role_arn if not integration_responses: integration_responses = [{"httpMethod": "PUT", "statusCode": "200"}] @@ -1768,7 +1832,9 @@ def test_rest_api_multi_region( class TestIntegrations: @markers.aws.unknown - def test_api_gateway_s3_get_integration(self, create_rest_apigw, aws_client): + def test_api_gateway_s3_get_integration( + self, create_rest_apigw, aws_client, create_iam_role_with_policy + ): s3_client = aws_client.s3 bucket_name = f"test-bucket-{short_uid()}" @@ -1788,8 +1854,15 @@ def test_api_gateway_s3_get_integration(self, create_rest_apigw, aws_client): ContentType=object_content_type, ) + role_arn = create_iam_role_with_policy( + RoleName=f"role-apigw-s3-{short_uid()}", + PolicyName=f"policy-apigw-s3-{short_uid()}", + RoleDefinition=APIGATEWAY_ASSUME_ROLE_POLICY, + PolicyDefinition=APIGATEWAY_S3_POLICY, + ) + self.connect_api_gateway_to_s3( - aws_client.apigateway, bucket_name, object_name, api_id, "GET" + aws_client.apigateway, bucket_name, object_name, api_id, "GET", role_arn ) aws_client.apigateway.create_deployment(restApiId=api_id, stageName="test") @@ -2023,14 +2096,20 @@ def test_api_gateway_sqs_integration_with_event_source( # TODO: replace with fixtures, to allow passing aws_client and enable snapshot testing # ================== - def connect_api_gateway_to_s3(self, apigw_client, bucket_name, file_name, api_id, method): + def connect_api_gateway_to_s3( + self, + apigw_client, + bucket_name: str, + file_name: str, + api_id: str, + method: str, + role_arn: str, + ): """Connects the root resource of an api gateway to the given object of an s3 bucket.""" s3_uri = "arn:aws:apigateway:{}:s3:path/{}/{{proxy}}".format( TEST_AWS_REGION_NAME, bucket_name ) - test_role = "test-s3-role" - role_arn = arns.iam_role_arn(role_name=test_role, account_id=TEST_AWS_ACCOUNT_ID) resources = apigw_client.get_resources(restApiId=api_id) # using the root resource '/' directly for this test root_resource_id = resources["items"][0]["id"]