From 6002ef01d1dc9e3b96738835da31789fa40473a7 Mon Sep 17 00:00:00 2001 From: Daniel Reuter Date: Mon, 2 Dec 2024 07:51:32 +0100 Subject: [PATCH] feat: multi region support --- .github/workflows/ci.yml | 1 + Dockerfile | 6 +- README.md | 15 +- config/config.go | 12 + config/specification.go | 1 + e2e/integration_test.go | 2 +- extaz/availability_zone_attack_blackhole.go | 25 ++- ...y_zone_attack_blackhole_localstack_test.go | 5 +- ...availability_zone_attack_blackhole_test.go | 25 ++- extaz/availablity_zone_discovery.go | 10 +- extaz/availablity_zone_discovery_test.go | 8 +- extec2/instance_attack_state.go | 22 +- extec2/instance_attack_state_test.go | 52 ++++- extec2/instance_discovery.go | 6 +- extec2/instance_discovery_test.go | 12 +- extecs/service_attack_scale.go | 16 +- extecs/service_attack_scale_test.go | 14 +- extecs/service_description_poller.go | 129 ++++++----- .../service_description_poller_mock_test.go | 16 +- extecs/service_description_poller_test.go | 39 ++-- extecs/service_discovery.go | 4 +- extecs/service_event_log.go | 9 +- extecs/service_event_log_test.go | 17 +- extecs/service_task_count_check.go | 15 +- extecs/service_task_count_check_test.go | 6 +- extecs/task_attack_ssm.go | 18 +- extecs/task_attack_ssm_test.go | 4 +- extecs/task_attack_stop.go | 12 +- extecs/task_attack_stop_test.go | 9 +- extecs/task_discovery.go | 6 +- extecs/task_discovery_test.go | 6 +- extelasticache/common_elasticache.go | 7 +- extelasticache/node_group_attack_failover.go | 5 +- .../node_group_attack_failover_test.go | 6 +- extelasticache/node_group_discovery.go | 4 +- extelb/alb_attack_static_response.go | 16 +- extelb/alb_attack_static_response_test.go | 16 +- extelb/alb_discovery.go | 6 +- extelb/alb_discovery_test.go | 8 +- extfis/start_experiment.go | 32 +-- extfis/start_experiment_test.go | 14 +- extfis/template_discovery.go | 4 +- extlambda/attack.go | 14 +- extlambda/attack_test.go | 10 +- extlambda/discovery.go | 4 +- extmsk/cluster_discovery.go | 4 +- extmsk/common.go | 7 +- extmsk/reboot_broker_attack.go | 5 +- extmsk/reboot_broker_attack_test.go | 7 +- extrds/cluster_attack_failover.go | 4 +- extrds/cluster_attack_failover_test.go | 8 +- extrds/cluster_discovery.go | 4 +- extrds/common_cluster.go | 8 +- extrds/common_instance.go | 23 +- extrds/common_instance_test.go | 4 +- extrds/instance_attack_reboot.go | 4 +- extrds/instance_attack_reboot_test.go | 48 +--- extrds/instance_attack_stop.go | 4 +- extrds/instance_attack_stop_test.go | 48 +--- extrds/instance_discovery.go | 6 +- extrds/instance_discovery_test.go | 8 +- main.go | 2 +- main_test.go | 9 - utils/aws_access.go | 155 +++++++++++++ utils/aws_access_test.go | 201 +++++++++++++++++ utils/aws_accounts.go | 85 ------- utils/aws_accounts_test.go | 210 ------------------ utils/aws_zones.go | 30 +-- utils/aws_zones_test.go | 16 +- utils/init.go | 116 ---------- utils/sdk_logging.go | 36 +++ 71 files changed, 901 insertions(+), 819 deletions(-) create mode 100644 utils/aws_access.go create mode 100644 utils/aws_access_test.go delete mode 100644 utils/aws_accounts.go delete mode 100644 utils/aws_accounts_test.go delete mode 100644 utils/init.go create mode 100644 utils/sdk_logging.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21bdfb8e..7b999db7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,7 @@ jobs: go_version: '1.23' build_linux_packages: true VERSION_BUMPER_APPID: ${{ vars.GH_APP_STEADYBIT_APP_ID }} + force_push_docker_image: true secrets: SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} PAT_TOKEN_EXTENSION_DEPLOYER: ${{ secrets.PAT_TOKEN_EXTENSION_DEPLOYER }} diff --git a/Dockerfile b/Dockerfile index a6d7eae0..470fbe88 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,8 @@ ## FROM --platform=$BUILDPLATFORM goreleaser/goreleaser:v2.4.8 AS build -ARG TARGETOS TARGETARCH +ARG TARGETOS +ARG TARGETARCH ARG BUILD_WITH_COVERAGE ARG BUILD_SNAPSHOT=true ARG SKIP_LICENSES_REPORT=false @@ -14,6 +15,7 @@ WORKDIR /app COPY . . RUN GOOS=$TARGETOS GOARCH=$TARGETARCH goreleaser build --snapshot="${BUILD_SNAPSHOT}" --single-target -o extension + ## ## Runtime ## @@ -21,6 +23,8 @@ FROM alpine:3.20 LABEL "steadybit.com.discovery-disabled"="true" +RUN apk --no-cache add aws-cli + ARG USERNAME=steadybit ARG USER_UID=10000 diff --git a/README.md b/README.md index 9f4b2dd4..9f4c98a4 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ our [Reliability Hub](https://hub.steadybit.com/extension/com.steadybit.extensio |-----------------------------------------------------------------|-------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------------------------------| | `STEADYBIT_EXTENSION_WORKER_THREADS` | | How many parallel workers should call aws apis (only used if `STEADYBIT_EXTENSION_ASSUME_ROLES` is used) | no | 1 | | `STEADYBIT_EXTENSION_ASSUME_ROLES` | `aws.assumeRoles` | See detailed description below | no | | +| `STEADYBIT_EXTENSION_REGIONS` | | See detailed description below | no | | | `STEADYBIT_EXTENSION_DISCOVERY_DISABLED_EC2` | `aws.discovery.disabled.ec2` | Disable EC2-Discovery and all EC2 related definitions | no | false | | `STEADYBIT_EXTENSION_DISCOVERY_INTERVAL_EC2` | | Discovery-Interval in seconds | no | 30 | | `STEADYBIT_EXTENSION_DISCOVERY_DISABLED_ECS` | `aws.discovery.disabled.ecs` | Disable ECS-Discovery and all ECS related definitions | no | false | @@ -100,7 +101,8 @@ by tweaking the `Resource` clause. "ec2:DescribeTags", "ec2:StopInstances", "ec2:RebootInstances", - "ec2:TerminateInstances" + "ec2:TerminateInstances", + "ec2:StartInstances", ], "Resource": "*" } @@ -518,6 +520,17 @@ steps: } ``` +### Multi Region Support + +By default, the extension will discover targets only in the AWS Region that is provided by the current authentication (environment variable `AWS_REGION`). + +If you want to discover targets in multiple regions, you can set the `STEADYBIT_EXTENSION_REGIONS` environment variable to a comma-separated list of regions. Example: + +```sh +STEADYBIT_EXTENSION_REGIONS='us-east-1,us-west-2' +``` + + ### Agent Lockout - Requirements In order to prevent the agent or the extension of beeing locked out by their own attacks, we implemented some security diff --git a/config/config.go b/config/config.go index 3b682947..07ba93d2 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,7 @@ package config import ( "github.com/kelseyhightower/envconfig" "github.com/rs/zerolog/log" + "strings" ) var ( @@ -14,7 +15,18 @@ var ( func ParseConfiguration() { err := envconfig.Process("steadybit_extension", &Config) + Config.AssumeRoles = trimSpaces(Config.AssumeRoles) + Config.Regions = trimSpaces(Config.Regions) + Config.EnrichEc2DataForTargetTypes = trimSpaces(Config.EnrichEc2DataForTargetTypes) if err != nil { log.Fatal().Err(err).Msgf("Failed to parse configuration from environment.") } } + +func trimSpaces(orig []string) []string { + var trimmed []string + for _, s := range orig { + trimmed = append(trimmed, strings.TrimSpace(s)) + } + return trimmed +} diff --git a/config/specification.go b/config/specification.go index f66ed4e9..bc5f6592 100644 --- a/config/specification.go +++ b/config/specification.go @@ -5,6 +5,7 @@ package config type Specification struct { AssumeRoles []string `json:"assumeRoles" split_words:"true" required:"false"` + Regions []string `json:"regions" split_words:"true" required:"false"` WorkerThreads int `json:"workerThreads" split_words:"true" required:"false" default:"1"` AwsEndpointOverride string `json:"awsEndpointOverride" split_words:"true" required:"false"` DiscoveryDisabledEc2 bool `json:"discoveryDisabledEc2" split_words:"true" required:"false" default:"false"` diff --git a/e2e/integration_test.go b/e2e/integration_test.go index 84e115ea..37091150 100644 --- a/e2e/integration_test.go +++ b/e2e/integration_test.go @@ -23,7 +23,7 @@ func TestWithMinikube(t *testing.T) { "--set", "logging.level=INFO", "--set", "extraEnv[0].name=STEADYBIT_EXTENSION_AWS_ENDPOINT_OVERRIDE", "--set", "extraEnv[0].value=http://localstack.default.svc.cluster.local:4566", - "--set", "extraEnv[1].name=AWS_DEFAULT_REGION", + "--set", "extraEnv[1].name=AWS_REGION", "--set", "extraEnv[1].value=us-east-1", "--set", "extraEnv[2].name=AWS_ACCESS_KEY_ID", "--set", "extraEnv[2].value=test", diff --git a/extaz/availability_zone_attack_blackhole.go b/extaz/availability_zone_attack_blackhole.go index d2cb7ae2..5870a3d7 100644 --- a/extaz/availability_zone_attack_blackhole.go +++ b/extaz/availability_zone_attack_blackhole.go @@ -23,7 +23,7 @@ import ( ) type azBlackholeAction struct { - clientProvider func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) + clientProvider func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) extensionRootAccountNumber string } @@ -35,6 +35,7 @@ type BlackholeState struct { AgentAWSAccount string ExtensionAwsAccount string TargetZone string + TargetRegion string NetworkAclIds []string OldNetworkAclIds map[string]string // map[NewAssociationId] = oldNetworkAclId TargetSubnets map[string][]string // map[vpcId] = [subnetIds] @@ -57,7 +58,7 @@ type azBlackholeImdsApi interface { func NewAzBlackholeAction() action_kit_sdk.Action[BlackholeState] { return &azBlackholeAction{ clientProvider: defaultClientProvider, - extensionRootAccountNumber: utils.Accounts.GetRootAccount().AccountNumber, + extensionRootAccountNumber: utils.GetRootAccountNumber(), } } @@ -108,9 +109,10 @@ func (e *azBlackholeAction) Describe() action_kit_api.ActionDescription { func (e *azBlackholeAction) Prepare(ctx context.Context, state *BlackholeState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { targetAccount := extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] targetZone := extutil.MustHaveValue(request.Target.Attributes, "aws.zone")[0] + targetRegion := extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] // Get AWS Clients - clientEc2, clientImds, err := e.clientProvider(targetAccount) + clientEc2, clientImds, err := e.clientProvider(targetAccount, targetRegion) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize AWS clients for AWS targetAccount %s", targetAccount), err) } @@ -145,17 +147,18 @@ func (e *azBlackholeAction) Prepare(ctx context.Context, state *BlackholeState, state.AgentAWSAccount = agentAwsAccountId state.ExtensionAwsAccount = targetAccount state.TargetZone = targetZone + state.TargetRegion = targetRegion state.TargetSubnets = targetSubnets state.AttackExecutionId = request.ExecutionId return nil, nil } func (e *azBlackholeAction) Start(ctx context.Context, state *BlackholeState) (*action_kit_api.StartResult, error) { - clientEc2, _, err := e.clientProvider(state.ExtensionAwsAccount) + clientEc2, _, err := e.clientProvider(state.ExtensionAwsAccount, state.TargetRegion) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize EC2 client for AWS account %s", state.ExtensionAwsAccount), err) } - log.Info().Msgf("Starting AZ Blackhole attack against AWS account %s", state.ExtensionAwsAccount) + log.Info().Msgf("Starting AZ Blackhole attack against AWS account %s and region %s", state.ExtensionAwsAccount, state.TargetRegion) log.Debug().Msgf("Attack state: %+v", state) state.OldNetworkAclIds = make(map[string]string) @@ -371,9 +374,9 @@ func getNetworkAclAssociations(ctx context.Context, clientEc2 azBlackholeEC2Api, } func (e *azBlackholeAction) Stop(ctx context.Context, state *BlackholeState) (*action_kit_api.StopResult, error) { - clientEc2, _, err := e.clientProvider(state.ExtensionAwsAccount) + clientEc2, _, err := e.clientProvider(state.ExtensionAwsAccount, state.TargetRegion) if err != nil { - return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize EC2 client for AWS account %s", state.ExtensionAwsAccount), err) + return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize EC2 client for AWS account %s and region %s", state.ExtensionAwsAccount, state.TargetRegion), err) } return nil, rollbackBlackholeViaTags(ctx, state, clientEc2) @@ -465,13 +468,13 @@ func getAllNACLsCreatedBySteadybit(clientEc2 azBlackholeEC2Api, ctx context.Cont return &result, nil } -func defaultClientProvider(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultClientProvider(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, nil, err } - clientEc2 := ec2.NewFromConfig(awsAccount.AwsConfig) - clientImds := imds.NewFromConfig(awsAccount.AwsConfig) + clientEc2 := ec2.NewFromConfig(awsAccess.AwsConfig) + clientImds := imds.NewFromConfig(awsAccess.AwsConfig) if err != nil { return nil, nil, err } diff --git a/extaz/availability_zone_attack_blackhole_localstack_test.go b/extaz/availability_zone_attack_blackhole_localstack_test.go index 0647ec4e..6eb87886 100644 --- a/extaz/availability_zone_attack_blackhole_localstack_test.go +++ b/extaz/availability_zone_attack_blackhole_localstack_test.go @@ -55,6 +55,7 @@ func testPrepareAndStartAndStopBlackholeLocalStack(t *testing.T, clientEc2 *ec2. assert.Equal(t, "41", state.AgentAWSAccount) assert.Equal(t, "42", state.ExtensionAwsAccount) assert.Equal(t, "eu-west-1a", state.TargetZone) + assert.Equal(t, "eu-west-1", state.TargetRegion) assert.Len(t, state.TargetSubnets, 2) //default vpc and the one we created assert.Len(t, state.TargetSubnets[defaultVpcId], 1) //default vpc has 1 subnet assert.Len(t, state.TargetSubnets[createdVpcId], 2) //our vpc with 2 subnets @@ -75,6 +76,7 @@ func testPrepareAndStartAndStopBlackholeLocalStack(t *testing.T, clientEc2 *ec2. assert.Equal(t, "41", state.AgentAWSAccount) assert.Equal(t, "42", state.ExtensionAwsAccount) assert.Equal(t, "eu-west-1a", state.TargetZone) + assert.Equal(t, "eu-west-1", state.TargetRegion) assert.Len(t, state.NetworkAclIds, 2) //one per vpc newAssociationIds := reflect.ValueOf(state.OldNetworkAclIds).MapKeys() assert.NotEqual(t, "", state.OldNetworkAclIds[newAssociationIds[0].String()]) @@ -261,7 +263,7 @@ func testApiThrottlingDuringStopWhileDeletingNACLs(t *testing.T, clientEc2 *ec2. func prepareActionCall(clientEc2 *ec2.Client, clientImds *imds.Client) (azBlackholeAction, BlackholeState, action_kit_api.PrepareActionRequestBody) { action := azBlackholeAction{ extensionRootAccountNumber: "41", - clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return clientEc2, clientImds, nil }} state := action.NewEmptyState() @@ -273,6 +275,7 @@ func prepareActionCall(clientEc2 *ec2.Client, clientImds *imds.Client) (azBlackh Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.zone": {"eu-west-1a"}, + "aws.region": {"eu-west-1"}, "aws.account": {"42"}, }, }), diff --git a/extaz/availability_zone_attack_blackhole_test.go b/extaz/availability_zone_attack_blackhole_test.go index 77803af0..a99f9b79 100644 --- a/extaz/availability_zone_attack_blackhole_test.go +++ b/extaz/availability_zone_attack_blackhole_test.go @@ -110,7 +110,7 @@ func TestPrepareBlackhole(t *testing.T) { ctx := context.Background() action := azBlackholeAction{ extensionRootAccountNumber: "", - clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return clientEc2, clientImds, nil }} state := action.NewEmptyState() @@ -122,6 +122,7 @@ func TestPrepareBlackhole(t *testing.T) { Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.zone": {"eu-west-1a"}, + "aws.region": {"eu-west-1"}, "aws.account": {"42"}, }, }), @@ -138,6 +139,7 @@ func TestPrepareBlackhole(t *testing.T) { assert.Equal(t, "41", state.AgentAWSAccount) assert.Equal(t, "42", state.ExtensionAwsAccount) assert.Equal(t, "eu-west-1a", state.TargetZone) + assert.Equal(t, "eu-west-1", state.TargetRegion) assert.Equal(t, []string{"subnet-1", "subnet-2"}, state.TargetSubnets["vpcId-1"]) assert.NotNil(t, state.AttackExecutionId) clientEc2.AssertExpectations(t) @@ -156,7 +158,7 @@ func TestShouldNotAttackWhenExtensionIsInTargetAccountId(t *testing.T) { ctx := context.Background() action := azBlackholeAction{ extensionRootAccountNumber: "42", - clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return nil, clientImds, nil }} state := action.NewEmptyState() @@ -168,6 +170,7 @@ func TestShouldNotAttackWhenExtensionIsInTargetAccountId(t *testing.T) { Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.zone": {"eu-west-1a"}, + "aws.region": {"eu-west-1"}, "aws.account": {"42"}, }, }), @@ -191,7 +194,7 @@ func TestShouldNotAttackWhenExtensionIsInTargetAccountIdViaStsClient(t *testing. ctx := context.Background() action := azBlackholeAction{ extensionRootAccountNumber: "42", - clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return nil, clientImds, nil }} state := action.NewEmptyState() @@ -203,6 +206,7 @@ func TestShouldNotAttackWhenExtensionIsInTargetAccountIdViaStsClient(t *testing. Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.zone": {"eu-west-1a"}, + "aws.region": {"eu-west-1"}, "aws.account": {"42"}, }, }), @@ -226,7 +230,7 @@ func TestShouldNotAttackWhenExtensionAccountIsUnknown(t *testing.T) { ctx := context.Background() action := azBlackholeAction{ extensionRootAccountNumber: "", - clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return nil, clientImds, nil }} state := action.NewEmptyState() @@ -238,6 +242,7 @@ func TestShouldNotAttackWhenExtensionAccountIsUnknown(t *testing.T) { Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.zone": {"eu-west-1a"}, + "aws.region": {"eu-west-1"}, "aws.account": {"42"}, }, }), @@ -261,7 +266,7 @@ func TestShouldNotAttackWhenAgentAccountIsUnknown(t *testing.T) { ctx := context.Background() action := azBlackholeAction{ extensionRootAccountNumber: "", - clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return nil, clientImds, nil }} state := action.NewEmptyState() @@ -273,6 +278,7 @@ func TestShouldNotAttackWhenAgentAccountIsUnknown(t *testing.T) { Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.zone": {"eu-west-1a"}, + "aws.region": {"eu-west-1"}, "aws.account": {"42"}, }, }), @@ -300,7 +306,7 @@ func TestShouldNotAttackWhenAgentIsInTargetAccountId(t *testing.T) { ctx := context.Background() action := azBlackholeAction{ extensionRootAccountNumber: "", - clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return nil, clientImds, nil }} state := action.NewEmptyState() @@ -312,6 +318,7 @@ func TestShouldNotAttackWhenAgentIsInTargetAccountId(t *testing.T) { Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.zone": {"eu-west-1a"}, + "aws.region": {"eu-west-1"}, "aws.account": {"42"}, }, }), @@ -395,7 +402,7 @@ func TestStartBlackhole(t *testing.T) { }), nil) ctx := context.Background() - action := azBlackholeAction{clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + action := azBlackholeAction{clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return clientEc2, nil, nil }} @@ -403,6 +410,7 @@ func TestStartBlackhole(t *testing.T) { AgentAWSAccount: "41", ExtensionAwsAccount: "43", TargetZone: "eu-west-1a", + TargetRegion: "eu-west-1", TargetSubnets: map[string][]string{ "vpcId-1": {"subnet-1", "subnet-2"}, }, @@ -482,7 +490,7 @@ func TestStopBlackhole(t *testing.T) { }), mock.Anything).Return(extutil.Ptr(ec2.DeleteNetworkAclOutput{}), nil) ctx := context.Background() - action := azBlackholeAction{clientProvider: func(account string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { + action := azBlackholeAction{clientProvider: func(account string, region string) (azBlackholeEC2Api, azBlackholeImdsApi, error) { return clientEc2, nil, nil }} @@ -490,6 +498,7 @@ func TestStopBlackhole(t *testing.T) { AgentAWSAccount: "41", ExtensionAwsAccount: "43", TargetZone: "eu-west-1a", + TargetRegion: "eu-west-1", TargetSubnets: map[string][]string{ "vpcId-1": {"subnet-1", "subnet-2"}, }, diff --git a/extaz/availablity_zone_discovery.go b/extaz/availablity_zone_discovery.go index 36e968fc..44ccb4a1 100644 --- a/extaz/availablity_zone_discovery.go +++ b/extaz/availablity_zone_discovery.go @@ -66,16 +66,16 @@ func (a *azDiscovery) DescribeTarget() discovery_kit_api.TargetDescription { } func (a *azDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getTargetsForAccount, ctx, "availability zone") + return utils.ForEveryConfiguredAwsAccess(getTargetsForAccount, ctx, "availability zone") } -func getTargetsForAccount(account *utils.AwsAccount, _ context.Context) ([]discovery_kit_api.Target, error) { - return getAllAvailabilityZones(utils.Zones, account.AccountNumber), nil +func getTargetsForAccount(account *utils.AwsAccess, _ context.Context) ([]discovery_kit_api.Target, error) { + return getAllAvailabilityZones(utils.Zones, account.AccountNumber, account.Region), nil } -func getAllAvailabilityZones(zones utils.GetZonesUtil, awsAccountNumber string) []discovery_kit_api.Target { +func getAllAvailabilityZones(zones utils.GetZonesUtil, awsAccountNumber string, region string) []discovery_kit_api.Target { result := make([]discovery_kit_api.Target, 0, 20) - for _, availabilityZone := range zones.GetZones(awsAccountNumber) { + for _, availabilityZone := range zones.GetZones(awsAccountNumber, region) { result = append(result, toTarget(availabilityZone, awsAccountNumber)) } return discovery_kit_commons.ApplyAttributeExcludes(result, config.Config.DiscoveryAttributesExcludesZone) diff --git a/extaz/availablity_zone_discovery_test.go b/extaz/availablity_zone_discovery_test.go index 832b512b..b4092f6e 100644 --- a/extaz/availablity_zone_discovery_test.go +++ b/extaz/availablity_zone_discovery_test.go @@ -16,8 +16,8 @@ type zoneMock struct { mock.Mock } -func (m *zoneMock) GetZones(awsAccountNumber string) []types.AvailabilityZone { - args := m.Called(awsAccountNumber) +func (m *zoneMock) GetZones(awsAccountNumber string, region string) []types.AvailabilityZone { + args := m.Called(awsAccountNumber, region) return args.Get(0).([]types.AvailabilityZone) } @@ -31,10 +31,10 @@ func TestGetAllAvailabilityZones(t *testing.T) { ZoneId: discovery_kit_api.Ptr("euc1-az3"), }, } - mockedApi.On("GetZones", mock.Anything).Return(mockedReturnValue) + mockedApi.On("GetZones", mock.Anything, mock.Anything).Return(mockedReturnValue) // When - targets := getAllAvailabilityZones(mockedApi, "42") + targets := getAllAvailabilityZones(mockedApi, "42", "eu-central-1") // Then assert.Equal(t, 1, len(targets)) diff --git a/extec2/instance_attack_state.go b/extec2/instance_attack_state.go index b60fe8a7..04288d0f 100644 --- a/extec2/instance_attack_state.go +++ b/extec2/instance_attack_state.go @@ -16,7 +16,7 @@ import ( ) type ec2InstanceStateAction struct { - clientProvider func(account string) (ec2InstanceStateChangeApi, error) + clientProvider func(account string, region string) (ec2InstanceStateChangeApi, error) } // Make sure lambdaAction implements all required interfaces @@ -24,6 +24,7 @@ var _ action_kit_sdk.Action[InstanceStateChangeState] = (*ec2InstanceStateAction type InstanceStateChangeState struct { Account string + Region string InstanceId string Action string } @@ -32,6 +33,7 @@ type ec2InstanceStateChangeApi interface { StopInstances(ctx context.Context, params *ec2.StopInstancesInput, optFns ...func(*ec2.Options)) (*ec2.StopInstancesOutput, error) TerminateInstances(ctx context.Context, params *ec2.TerminateInstancesInput, optFns ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) RebootInstances(ctx context.Context, params *ec2.RebootInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RebootInstancesOutput, error) + StartInstances(ctx context.Context, params *ec2.StartInstancesInput, optFns ...func(*ec2.Options)) (*ec2.StartInstancesOutput, error) } func NewEc2InstanceStateAction() action_kit_sdk.Action[InstanceStateChangeState] { @@ -92,6 +94,10 @@ func (e *ec2InstanceStateAction) Describe() action_kit_api.ActionDescription { Label: "Terminate", Value: "terminate", }, + action_kit_api.ExplicitParameterOption{ + Label: "Start", + Value: "start", + }, }), }, }, @@ -105,13 +111,14 @@ func (e *ec2InstanceStateAction) Prepare(_ context.Context, state *InstanceState } state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.InstanceId = extutil.MustHaveValue(request.Target.Attributes, "aws-ec2.instance.id")[0] state.Action = action.(string) return nil, nil } func (e *ec2InstanceStateAction) Start(ctx context.Context, state *InstanceStateChangeState) (*action_kit_api.StartResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize EC2 client for AWS account %s", state.Account), err) } @@ -140,6 +147,11 @@ func (e *ec2InstanceStateAction) Start(ctx context.Context, state *InstanceState InstanceIds: instanceIds, } _, err = client.TerminateInstances(ctx, &in) + } else if state.Action == "start" { + in := ec2.StartInstancesInput{ + InstanceIds: instanceIds, + } + _, err = client.StartInstances(ctx, &in) } if err != nil { @@ -149,10 +161,10 @@ func (e *ec2InstanceStateAction) Start(ctx context.Context, state *InstanceState return nil, nil } -func defaultClientProvider(account string) (ec2InstanceStateChangeApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultClientProvider(account string, region string) (ec2InstanceStateChangeApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return ec2.NewFromConfig(awsAccount.AwsConfig), nil + return ec2.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extec2/instance_attack_state_test.go b/extec2/instance_attack_state_test.go index 0f298473..5d3a9023 100644 --- a/extec2/instance_attack_state_test.go +++ b/extec2/instance_attack_state_test.go @@ -35,6 +35,7 @@ func TestEc2InstanceStateAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-ec2.instance.id": {"my-instance"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -42,6 +43,7 @@ func TestEc2InstanceStateAction_Prepare(t *testing.T) { wantedState: &InstanceStateChangeState{ Account: "42", Action: "stop", + Region: "us-west-1", InstanceId: "my-instance", }, }, @@ -53,6 +55,7 @@ func TestEc2InstanceStateAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-ec2.instance.id": {"my-instance"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -75,6 +78,7 @@ func TestEc2InstanceStateAction_Prepare(t *testing.T) { if tt.wantedState != nil { assert.NoError(t, err) assert.Equal(t, tt.wantedState.Account, state.Account) + assert.Equal(t, tt.wantedState.Region, state.Region) assert.Equal(t, tt.wantedState.InstanceId, state.InstanceId) assert.EqualValues(t, tt.wantedState.Action, state.Action) } @@ -101,7 +105,12 @@ func (m *ec2ClientApiMock) RebootInstances(ctx context.Context, params *ec2.Rebo return nil, args.Error(1) } -func TestEc2InstanceStateAction_Start(t *testing.T) { +func (m *ec2ClientApiMock) StartInstances(ctx context.Context, params *ec2.StartInstancesInput, optFns ...func(*ec2.Options)) (*ec2.StartInstancesOutput, error) { + args := m.Called(ctx, params) + return nil, args.Error(1) +} + +func TestEc2InstanceStateAction_Stop(t *testing.T) { // Given api := new(ec2ClientApiMock) api.On("StopInstances", mock.Anything, mock.MatchedBy(func(params *ec2.StopInstancesInput) bool { @@ -110,13 +119,14 @@ func TestEc2InstanceStateAction_Start(t *testing.T) { return true })).Return(nil, nil) - action := ec2InstanceStateAction{clientProvider: func(account string) (ec2InstanceStateChangeApi, error) { + action := ec2InstanceStateAction{clientProvider: func(account string, region string) (ec2InstanceStateChangeApi, error) { return api, nil }} // When result, err := action.Start(context.Background(), &InstanceStateChangeState{ Account: "42", + Region: "us-west-1", InstanceId: "dev-worker-1", Action: "stop", }) @@ -136,13 +146,14 @@ func TestEc2InstanceStateAction_Hibernate(t *testing.T) { require.Equal(t, true, *params.Hibernate) return true })).Return(nil, nil) - action := ec2InstanceStateAction{clientProvider: func(account string) (ec2InstanceStateChangeApi, error) { + action := ec2InstanceStateAction{clientProvider: func(account string, region string) (ec2InstanceStateChangeApi, error) { return api, nil }} // When result, err := action.Start(context.Background(), &InstanceStateChangeState{ Account: "42", + Region: "us-west-1", InstanceId: "dev-worker-1", Action: "hibernate", }) @@ -161,13 +172,14 @@ func TestEc2InstanceStateAction_Terminate(t *testing.T) { require.Equal(t, "dev-worker-1", params.InstanceIds[0]) return true })).Return(nil, nil) - action := ec2InstanceStateAction{clientProvider: func(account string) (ec2InstanceStateChangeApi, error) { + action := ec2InstanceStateAction{clientProvider: func(account string, region string) (ec2InstanceStateChangeApi, error) { return api, nil }} // When result, err := action.Start(context.Background(), &InstanceStateChangeState{ Account: "42", + Region: "us-west-1", InstanceId: "dev-worker-1", Action: "terminate", }) @@ -186,13 +198,14 @@ func TestEc2InstanceStateAction_Reboot(t *testing.T) { require.Equal(t, "dev-worker-1", params.InstanceIds[0]) return true })).Return(nil, nil) - action := ec2InstanceStateAction{clientProvider: func(account string) (ec2InstanceStateChangeApi, error) { + action := ec2InstanceStateAction{clientProvider: func(account string, region string) (ec2InstanceStateChangeApi, error) { return api, nil }} // When result, err := action.Start(context.Background(), &InstanceStateChangeState{ Account: "42", + Region: "us-west-1", InstanceId: "dev-worker-1", Action: "reboot", }) @@ -204,6 +217,32 @@ func TestEc2InstanceStateAction_Reboot(t *testing.T) { api.AssertExpectations(t) } +func TestEc2InstanceStateAction_Start(t *testing.T) { + // Given + api := new(ec2ClientApiMock) + api.On("StartInstances", mock.Anything, mock.MatchedBy(func(params *ec2.StartInstancesInput) bool { + require.Equal(t, "dev-worker-1", params.InstanceIds[0]) + return true + })).Return(nil, nil) + action := ec2InstanceStateAction{clientProvider: func(account string, region string) (ec2InstanceStateChangeApi, error) { + return api, nil + }} + + // When + result, err := action.Start(context.Background(), &InstanceStateChangeState{ + Account: "42", + Region: "us-west-1", + InstanceId: "dev-worker-1", + Action: "start", + }) + + // Then + assert.NoError(t, err) + assert.Nil(t, result) + + api.AssertExpectations(t) +} + func TestStartInstanceStateChangeForwardsError(t *testing.T) { // Given api := new(ec2ClientApiMock) @@ -211,13 +250,14 @@ func TestStartInstanceStateChangeForwardsError(t *testing.T) { require.Equal(t, "dev-worker-1", params.InstanceIds[0]) return true })).Return(nil, errors.New("expected")) - action := ec2InstanceStateAction{clientProvider: func(account string) (ec2InstanceStateChangeApi, error) { + action := ec2InstanceStateAction{clientProvider: func(account string, region string) (ec2InstanceStateChangeApi, error) { return api, nil }} // When result, err := action.Start(context.Background(), &InstanceStateChangeState{ Account: "42", + Region: "us-west-1", InstanceId: "dev-worker-1", Action: "reboot", }) diff --git a/extec2/instance_discovery.go b/extec2/instance_discovery.go index 70713337..d9a6de0f 100644 --- a/extec2/instance_discovery.go +++ b/extec2/instance_discovery.go @@ -249,10 +249,10 @@ func (e *ec2Discovery) DescribeAttributes() []discovery_kit_api.AttributeDescrip } func (e *ec2Discovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getTargetsForAccount, ctx, "ec2-instance") + return utils.ForEveryConfiguredAwsAccess(getTargetsForAccount, ctx, "ec2-instance") } -func getTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := ec2.NewFromConfig(account.AwsConfig) result, err := GetAllEc2Instances(ctx, client, utils.Zones, account.AccountNumber, account.AwsConfig.Region) if err != nil { @@ -303,7 +303,7 @@ func toTarget(ec2Instance types.Instance, zoneUtil utils.GetZoneUtil, awsAccount label = label + " / " + *name } availabilityZoneName := aws.ToString(ec2Instance.Placement.AvailabilityZone) - availabilityZoneApi := zoneUtil.GetZone(awsAccountNumber, availabilityZoneName) + availabilityZoneApi := zoneUtil.GetZone(awsAccountNumber, availabilityZoneName, awsRegion) attributes := make(map[string][]string) attributes["aws.account"] = []string{awsAccountNumber} diff --git a/extec2/instance_discovery_test.go b/extec2/instance_discovery_test.go index 6ca3a0eb..9f3b04cf 100644 --- a/extec2/instance_discovery_test.go +++ b/extec2/instance_discovery_test.go @@ -32,8 +32,8 @@ type zoneMock struct { mock.Mock } -func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string) *types.AvailabilityZone { - args := m.Called(awsAccountNumber, awsZone) +func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string, region string) *types.AvailabilityZone { + args := m.Called(awsAccountNumber, awsZone, region) return args.Get(0).(*types.AvailabilityZone) } @@ -76,7 +76,7 @@ func TestGetAllEc2Instances(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1b-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) // When targets, err := GetAllEc2Instances(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") @@ -126,7 +126,7 @@ func TestGetAllEc2InstancesWithFilteredAttributes(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1b-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) // When targets, err := GetAllEc2Instances(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") @@ -179,7 +179,7 @@ func TestNameNotSet(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1b-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) // When targets, err := GetAllEc2Instances(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") @@ -204,7 +204,7 @@ func TestGetAllEc2InstancesError(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1b-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) // When _, err := GetAllEc2Instances(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") diff --git a/extecs/service_attack_scale.go b/extecs/service_attack_scale.go index d8f0cc43..2549bebf 100644 --- a/extecs/service_attack_scale.go +++ b/extecs/service_attack_scale.go @@ -16,7 +16,7 @@ import ( ) type ecsServiceScaleAction struct { - clientProvider func(account string) (ecsServiceScaleApi, error) + clientProvider func(account string, region string) (ecsServiceScaleApi, error) } // Make sure action implements all required interfaces @@ -25,6 +25,7 @@ var _ action_kit_sdk.ActionWithStop[ServiceScaleState] = (*ecsServiceScaleAction type ServiceScaleState struct { Account string + Region string ServiceName string ClusterArn string DesiredCount int32 @@ -88,11 +89,12 @@ func (e *ecsServiceScaleAction) Describe() action_kit_api.ActionDescription { func (e *ecsServiceScaleAction) Prepare(ctx context.Context, state *ServiceScaleState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.ClusterArn = extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.cluster.arn")[0] state.ServiceName = extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.service.name")[0] state.DesiredCount = extutil.ToInt32(request.Config["desiredCount"]) - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize ECS client for AWS account %s", state.Account), err) } @@ -109,7 +111,7 @@ func (e *ecsServiceScaleAction) Prepare(ctx context.Context, state *ServiceScale } func (e *ecsServiceScaleAction) Start(ctx context.Context, state *ServiceScaleState) (*action_kit_api.StartResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize ECS client for AWS account %s", state.Account), err) } @@ -127,7 +129,7 @@ func (e *ecsServiceScaleAction) Start(ctx context.Context, state *ServiceScaleSt } func (e *ecsServiceScaleAction) Stop(ctx context.Context, state *ServiceScaleState) (*action_kit_api.StopResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize ECS client for AWS account %s", state.Account), err) } @@ -142,10 +144,10 @@ func (e *ecsServiceScaleAction) Stop(ctx context.Context, state *ServiceScaleSta return nil, nil } -func defaultClientProviderService(account string) (ecsServiceScaleApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultClientProviderService(account string, region string) (ecsServiceScaleApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return ecs.NewFromConfig(awsAccount.AwsConfig), nil + return ecs.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extecs/service_attack_scale_test.go b/extecs/service_attack_scale_test.go index 658dfe8f..21b890aa 100644 --- a/extecs/service_attack_scale_test.go +++ b/extecs/service_attack_scale_test.go @@ -23,7 +23,7 @@ func TestEcsServiceScaleAction_Prepare(t *testing.T) { Services: []types.Service{{DesiredCount: 2}}, }, nil) - action := ecsServiceScaleAction{clientProvider: func(account string) (ecsServiceScaleApi, error) { + action := ecsServiceScaleAction{clientProvider: func(account string, region string) (ecsServiceScaleApi, error) { return api, nil }} @@ -46,12 +46,14 @@ func TestEcsServiceScaleAction_Prepare(t *testing.T) { "aws-ecs.service.arn": {"my-service-arn"}, "aws-ecs.service.name": {"my-service-name"}, "aws.account": {"42"}, + "aws.region": {"us-east-1"}, }, }), }), wantedState: &ServiceScaleState{ Account: "42", + Region: "us-east-1", ClusterArn: "my-cluster-arn", ServiceName: "my-service-name", DesiredCount: 5, @@ -75,6 +77,7 @@ func TestEcsServiceScaleAction_Prepare(t *testing.T) { if tt.wantedState != nil { assert.NoError(t, err) assert.Equal(t, tt.wantedState.Account, state.Account) + assert.Equal(t, tt.wantedState.Region, state.Region) assert.Equal(t, tt.wantedState.ClusterArn, state.ClusterArn) assert.EqualValues(t, tt.wantedState.ServiceName, state.ServiceName) assert.EqualValues(t, tt.wantedState.DesiredCount, state.DesiredCount) @@ -107,13 +110,14 @@ func TestEcsServiceScaleAction_Start(t *testing.T) { return true })).Return(nil, nil) - action := ecsServiceScaleAction{clientProvider: func(account string) (ecsServiceScaleApi, error) { + action := ecsServiceScaleAction{clientProvider: func(account string, region string) (ecsServiceScaleApi, error) { return api, nil }} // When state := &ServiceScaleState{ Account: "42", + Region: "us-east-1", ClusterArn: "my-cluster-arn", ServiceName: "my-service-name", InitialDesiredCount: int32(2), @@ -138,13 +142,14 @@ func TestEcsServiceScaleAction_Stop(t *testing.T) { return true })).Return(nil, nil) - action := ecsServiceScaleAction{clientProvider: func(account string) (ecsServiceScaleApi, error) { + action := ecsServiceScaleAction{clientProvider: func(account string, region string) (ecsServiceScaleApi, error) { return api, nil }} // When state := &ServiceScaleState{ Account: "42", + Region: "us-east-1", ClusterArn: "my-cluster-arn", ServiceName: "my-service-name", DesiredCount: int32(5), @@ -169,13 +174,14 @@ func TestEcsServiceScaleActionForwardsError(t *testing.T) { require.Equal(t, int32(5), *params.DesiredCount) return true })).Return(nil, errors.New("expected")) - action := ecsServiceScaleAction{clientProvider: func(account string) (ecsServiceScaleApi, error) { + action := ecsServiceScaleAction{clientProvider: func(account string, region string) (ecsServiceScaleApi, error) { return api, nil }} // When result, err := action.Start(context.Background(), &ServiceScaleState{ Account: "42", + Region: "us-east-1", ClusterArn: "my-cluster-arn", ServiceName: "my-service-name", DesiredCount: int32(5), diff --git a/extecs/service_description_poller.go b/extecs/service_description_poller.go index c5310646..71de716d 100644 --- a/extecs/service_description_poller.go +++ b/extecs/service_description_poller.go @@ -21,10 +21,10 @@ type ecsDescribeServicesApi interface { type ServiceDescriptionPoller interface { Start(ctx context.Context) - Register(account string, cluster string, service string) - Unregister(account string, cluster string, service string) - Latest(account string, cluster string, service string) *PollService - AwaitLatest(account string, cluster string, service string) *PollService + Register(account string, region string, cluster string, service string) + Unregister(account string, region string, cluster string, service string) + Latest(account string, region string, cluster string, service string) *PollService + AwaitLatest(account string, region string, cluster string, service string) *PollService } type PollService struct { @@ -37,10 +37,11 @@ type pollRecord struct { } type pollServices map[string]*pollRecord type pollClusters map[string]pollServices -type pollAccounts map[string]pollClusters +type pollRegions map[string]pollClusters +type pollAccounts map[string]pollRegions type EcsServiceDescriptionPoller struct { - apiClientProvider func(account string) (ecsDescribeServicesApi, error) + apiClientProvider func(account string, region string) (ecsDescribeServicesApi, error) ticker *time.Ticker m *sync.RWMutex c *sync.Cond @@ -73,16 +74,22 @@ func (p EcsServiceDescriptionPoller) Start(ctx context.Context) { }() } -func (p EcsServiceDescriptionPoller) Register(account string, cluster string, service string) { +func (p EcsServiceDescriptionPoller) Register(account string, region string, cluster string, service string) { var ok bool p.m.Lock() defer p.m.Unlock() log.Debug().Msgf("register service %s", service) + var regions pollRegions + if regions, ok = p.polls[account]; !ok { + regions = make(pollRegions) + p.polls[account] = regions + } + var clusters pollClusters - if clusters, ok = p.polls[account]; !ok { + if clusters, ok = regions[region]; !ok { clusters = make(pollClusters) - p.polls[account] = clusters + regions[region] = clusters } var services pollServices @@ -99,48 +106,56 @@ func (p EcsServiceDescriptionPoller) Register(account string, cluster string, se } } -func (p EcsServiceDescriptionPoller) Unregister(account string, cluster string, service string) { +func (p EcsServiceDescriptionPoller) Unregister(account string, region string, cluster string, service string) { p.m.Lock() defer p.m.Unlock() log.Debug().Msgf("unregister service %s", service) - if clusters, ok := p.polls[account]; ok { - if services, ok := clusters[cluster]; ok { - if record, ok := services[service]; ok { - if record.count > 0 { - record.count = record.count - 1 - } else { - delete(services, service) - if len(services) == 0 { - delete(clusters, cluster) + if regions, ok := p.polls[account]; ok { + if clusters, ok := regions[region]; ok { + if services, ok := clusters[cluster]; ok { + if record, ok := services[service]; ok { + if record.count > 0 { + record.count = record.count - 1 + } else { + delete(services, service) + if len(services) == 0 { + delete(clusters, cluster) + } } } } - } - if len(clusters) == 0 { - delete(p.polls, account) + if len(clusters) == 0 { + delete(p.polls, account) + } } } p.c.Broadcast() } -func (p EcsServiceDescriptionPoller) Latest(account string, cluster string, service string) *PollService { +func (p EcsServiceDescriptionPoller) Latest(account string, region string, cluster string, service string) *PollService { p.m.RLock() defer p.m.RUnlock() - if clusters, ok := p.polls[account]; ok { - if services, ok := clusters[cluster]; ok { - if record, ok := services[service]; ok { - return record.value + if regions, ok := p.polls[account]; ok { + if clusters, ok := regions[region]; ok { + if services, ok := clusters[cluster]; ok { + if record, ok := services[service]; ok { + return record.value + } } } } return nil } -func (p EcsServiceDescriptionPoller) AwaitLatest(account string, cluster string, service string) *PollService { +func (p EcsServiceDescriptionPoller) AwaitLatest(account string, region string, cluster string, service string) *PollService { p.m.Lock() defer p.m.Unlock() for { - clusters, ok := p.polls[account] + regions, ok := p.polls[account] + if !ok { + return nil + } + clusters, ok := regions[region] if !ok { return nil } @@ -161,33 +176,35 @@ func (p EcsServiceDescriptionPoller) pollAll(ctx context.Context) { defer p.m.Unlock() startTime := time.Now() - for account, clusters := range p.polls { - client, err := p.apiClientProvider(account) - if err != nil { - log.Warn().TimeDiff("duration", time.Now(), startTime).Err(err).Msg("could not create api client") - continue - } + for account, regions := range p.polls { + for region, clusters := range regions { + client, err := p.apiClientProvider(account, region) + if err != nil { + log.Warn().TimeDiff("duration", time.Now(), startTime).Err(err).Msg("could not create api client") + continue + } - for cluster, services := range clusters { - servicesPages := utils.SplitIntoPages(maps.Keys(services), maxServicePageSize) - for _, servicePage := range servicesPages { - descriptions, err := client.DescribeServices(ctx, &ecs.DescribeServicesInput{ - Services: servicePage, - Cluster: extutil.Ptr(cluster), - }) - if err != nil { - log.Warn().TimeDiff("duration", time.Now(), startTime).Err(err).Msg("api call failed") - continue - } + for cluster, services := range clusters { + servicesPages := utils.SplitIntoPages(maps.Keys(services), maxServicePageSize) + for _, servicePage := range servicesPages { + descriptions, err := client.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Services: servicePage, + Cluster: extutil.Ptr(cluster), + }) + if err != nil { + log.Warn().TimeDiff("duration", time.Now(), startTime).Err(err).Msg("api call failed") + continue + } - for _, service := range descriptions.Services { - services[aws.ToString(service.ServiceArn)].value = &PollService{ - service: &service, + for _, service := range descriptions.Services { + services[aws.ToString(service.ServiceArn)].value = &PollService{ + service: &service, + } } - } - for _, failure := range descriptions.Failures { - services[aws.ToString(failure.Arn)].value = &PollService{ - failure: &failure, + for _, failure := range descriptions.Failures { + services[aws.ToString(failure.Arn)].value = &PollService{ + failure: &failure, + } } } } @@ -196,10 +213,10 @@ func (p EcsServiceDescriptionPoller) pollAll(ctx context.Context) { p.c.Broadcast() } -func defaultDescribeServiceProvider(account string) (ecsDescribeServicesApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultDescribeServiceProvider(account string, region string) (ecsDescribeServicesApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return ecs.NewFromConfig(awsAccount.AwsConfig), nil + return ecs.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extecs/service_description_poller_mock_test.go b/extecs/service_description_poller_mock_test.go index c083d13f..af75de49 100644 --- a/extecs/service_description_poller_mock_test.go +++ b/extecs/service_description_poller_mock_test.go @@ -13,24 +13,24 @@ func (m *ServiceDescriptionPollerMock) Start(ctx context.Context) { m.Called(ctx) } -func (m *ServiceDescriptionPollerMock) Register(account string, cluster string, service string) { - m.Called(account, cluster, service) +func (m *ServiceDescriptionPollerMock) Register(account string, region string, cluster string, service string) { + m.Called(account, region, cluster, service) } -func (m *ServiceDescriptionPollerMock) Unregister(account string, cluster string, service string) { - m.Called(account, cluster, service) +func (m *ServiceDescriptionPollerMock) Unregister(account string, region string, cluster string, service string) { + m.Called(account, region, cluster, service) } -func (m *ServiceDescriptionPollerMock) Latest(account string, cluster string, service string) *PollService { - args := m.Called(account, cluster, service) +func (m *ServiceDescriptionPollerMock) Latest(account string, region string, cluster string, service string) *PollService { + args := m.Called(account, region, cluster, service) if args.Get(0) == nil { return nil } return args.Get(0).(*PollService) } -func (m *ServiceDescriptionPollerMock) AwaitLatest(account string, cluster string, service string) *PollService { - args := m.Called(account, cluster, service) +func (m *ServiceDescriptionPollerMock) AwaitLatest(account string, region string, cluster string, service string) *PollService { + args := m.Called(account, region, cluster, service) if args.Get(0) == nil { return nil } diff --git a/extecs/service_description_poller_test.go b/extecs/service_description_poller_test.go index bdf53abe..5dd9cb21 100644 --- a/extecs/service_description_poller_test.go +++ b/extecs/service_description_poller_test.go @@ -25,6 +25,7 @@ func (m *ecsDescribeServicesApiMock) DescribeServices(ctx context.Context, param func TestServiceDescriptionPoller_awaits_first_response(t *testing.T) { account := "awsAccount" + region := "region" cluster := "clusterArn" service := "serviceArn" @@ -33,7 +34,7 @@ func TestServiceDescriptionPoller_awaits_first_response(t *testing.T) { poller := NewServiceDescriptionPoller() poller.ticker = time.NewTicker(1 * time.Millisecond) - poller.apiClientProvider = func(account string) (ecsDescribeServicesApi, error) { + poller.apiClientProvider = func(account string, region string) (ecsDescribeServicesApi, error) { mockedApi := new(ecsDescribeServicesApiMock) mockedApi.On("DescribeServices", mock.Anything, mock.Anything).Return(&ecs.DescribeServicesOutput{ Services: []types.Service{{ @@ -45,12 +46,12 @@ func TestServiceDescriptionPoller_awaits_first_response(t *testing.T) { return mockedApi, nil } - poller.Register(account, cluster, service) - latest := poller.Latest(account, cluster, service) + poller.Register(account, region, cluster, service) + latest := poller.Latest(account, region, cluster, service) assert.Nil(t, latest) poller.Start(ctx) - latest = poller.AwaitLatest(account, cluster, service) + latest = poller.AwaitLatest(account, region, cluster, service) assert.NotNil(t, latest) assert.NotNil(t, latest.service) @@ -59,37 +60,39 @@ func TestServiceDescriptionPoller_awaits_first_response(t *testing.T) { func TestServiceDescriptionPoller_registers_and_unregisters_services(t *testing.T) { p := NewServiceDescriptionPoller() - p.Register("a", "b", "c") - p.Register("a", "b", "e") + p.Register("a", "b", "c", "d") + p.Register("a", "b", "c", "e") assert.Len(t, p.polls, 1) assert.Len(t, p.polls["a"], 1) - assert.Len(t, p.polls["a"]["b"], 2) - assert.NotNil(t, p.polls["a"]["b"]["c"]) - assert.NotNil(t, p.polls["a"]["b"]["e"]) - - p.Unregister("a", "b", "c") assert.Len(t, p.polls["a"]["b"], 1) + assert.Len(t, p.polls["a"]["b"]["c"], 2) + assert.NotNil(t, p.polls["a"]["b"]["c"]["d"]) + assert.NotNil(t, p.polls["a"]["b"]["c"]["e"]) + + p.Unregister("a", "b", "c", "d") + assert.Len(t, p.polls["a"]["b"]["c"], 1) - p.Unregister("a", "b", "e") + p.Unregister("a", "b", "c", "e") assert.Len(t, p.polls, 0) } func TestServiceDescriptionPoller_registers_and_unregisters_service_multiple_times(t *testing.T) { p := NewServiceDescriptionPoller() - p.Register("a", "b", "c") - p.Register("a", "b", "c") + p.Register("a", "b", "c", "d") + p.Register("a", "b", "c", "d") assert.Len(t, p.polls, 1) assert.Len(t, p.polls["a"], 1) assert.Len(t, p.polls["a"]["b"], 1) + assert.Len(t, p.polls["a"]["b"]["c"], 1) record := &pollRecord{ count: 1, } - assert.Equal(t, record, p.polls["a"]["b"]["c"]) + assert.Equal(t, record, p.polls["a"]["b"]["c"]["d"]) - p.Unregister("a", "b", "c") - assert.Len(t, p.polls["a"]["b"], 1) + p.Unregister("a", "b", "c", "d") + assert.Len(t, p.polls["a"]["b"]["c"], 1) - p.Unregister("a", "b", "c") + p.Unregister("a", "b", "c", "d") assert.Len(t, p.polls, 0) } diff --git a/extecs/service_discovery.go b/extecs/service_discovery.go index 098a744a..c7c92928 100644 --- a/extecs/service_discovery.go +++ b/extecs/service_discovery.go @@ -104,10 +104,10 @@ func (e *ecsServiceDiscovery) DescribeAttributes() []discovery_kit_api.Attribute } func (e *ecsServiceDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getEcsServicesForAccount, ctx, "ecs-service") + return utils.ForEveryConfiguredAwsAccess(getEcsServicesForAccount, ctx, "ecs-service") } -func getEcsServicesForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getEcsServicesForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := ecs.NewFromConfig(account.AwsConfig) result, err := GetAllEcsServices(account.AwsConfig.Region, account.AccountNumber, client, ctx) if err != nil { diff --git a/extecs/service_event_log.go b/extecs/service_event_log.go index fc183b86..66e1cdad 100644 --- a/extecs/service_event_log.go +++ b/extecs/service_event_log.go @@ -26,6 +26,7 @@ type EcsServiceEventLogState struct { ServiceArn string ClusterArn string AwsAccount string + Region string } func NewEcsServiceEventLogAction(poller ServiceDescriptionPoller) action_kit_sdk.Action[EcsServiceEventLogState] { @@ -93,13 +94,15 @@ func (f EcsServiceEventLogAction) Describe() action_kit_api.ActionDescription { func (f EcsServiceEventLogAction) Prepare(_ context.Context, state *EcsServiceEventLogState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { awsAccount := extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + region := extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] clusterArn := extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.cluster.arn")[0] serviceArn := extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.service.arn")[0] - f.poller.Register(awsAccount, clusterArn, serviceArn) + f.poller.Register(awsAccount, region, clusterArn, serviceArn) state.LatestEventTimestamp = time.Now().In(time.UTC) state.AwsAccount = awsAccount + state.Region = region state.ClusterArn = clusterArn state.ServiceArn = serviceArn return nil, nil @@ -118,14 +121,14 @@ func (f EcsServiceEventLogAction) Status(_ context.Context, state *EcsServiceEve } func (f EcsServiceEventLogAction) Stop(_ context.Context, state *EcsServiceEventLogState) (*action_kit_api.StopResult, error) { - defer f.poller.Unregister(state.AwsAccount, state.ClusterArn, state.ServiceArn) + defer f.poller.Unregister(state.AwsAccount, state.Region, state.ClusterArn, state.ServiceArn) return &action_kit_api.StopResult{ Messages: f.newMessages(state), }, nil } func (f EcsServiceEventLogAction) newMessages(state *EcsServiceEventLogState) *action_kit_api.Messages { - latest := f.poller.Latest(state.AwsAccount, state.ClusterArn, state.ServiceArn) + latest := f.poller.Latest(state.AwsAccount, state.Region, state.ClusterArn, state.ServiceArn) newEvents, newLatestEventTimestamp := filterEventsAfter(latest, state.LatestEventTimestamp) state.LatestEventTimestamp = newLatestEventTimestamp if len(newEvents) > 0 { diff --git a/extecs/service_event_log_test.go b/extecs/service_event_log_test.go index 702d28e4..bc2acd42 100644 --- a/extecs/service_event_log_test.go +++ b/extecs/service_event_log_test.go @@ -12,6 +12,7 @@ import ( func TestServiceEventLog_Lifecycle(t *testing.T) { const account = "awsAccount" + const region = "region" const cluster = "cluster" const service = "service" @@ -19,9 +20,9 @@ func TestServiceEventLog_Lifecycle(t *testing.T) { defer cancel() pollerMock := new(ServiceDescriptionPollerMock) - pollerMock.On("Register", account, cluster, service) - pollerMock.On("Unregister", account, cluster, service) - pollerMock.On("Latest", account, cluster, service).Return(nil, nil) + pollerMock.On("Register", account, region, cluster, service) + pollerMock.On("Unregister", account, region, cluster, service) + pollerMock.On("Latest", account, region, cluster, service).Return(nil, nil) action := EcsServiceEventLogAction{ poller: pollerMock, } @@ -30,6 +31,7 @@ func TestServiceEventLog_Lifecycle(t *testing.T) { Target: &action_kit_api.Target{ Attributes: map[string][]string{ "aws.account": {account}, + "aws.region": {region}, "aws-ecs.cluster.arn": {cluster}, "aws-ecs.service.arn": {service}, }, @@ -40,9 +42,10 @@ func TestServiceEventLog_Lifecycle(t *testing.T) { assert.NoError(t, err) assert.Nil(t, prepare) assert.Equal(t, state.AwsAccount, account) + assert.Equal(t, state.Region, region) assert.Equal(t, state.ClusterArn, cluster) assert.Equal(t, state.ServiceArn, service) - pollerMock.AssertCalled(t, "Register", account, cluster, service) + pollerMock.AssertCalled(t, "Register", account, region, cluster, service) start, err := action.Start(ctx, state) assert.NoError(t, err) @@ -51,7 +54,7 @@ func TestServiceEventLog_Lifecycle(t *testing.T) { stop, err := action.Stop(ctx, state) assert.NoError(t, err) assert.NotNil(t, stop) - pollerMock.AssertCalled(t, "Unregister", account, cluster, service) + pollerMock.AssertCalled(t, "Unregister", account, region, cluster, service) } func TestServiceEventLog_Status(t *testing.T) { @@ -236,7 +239,7 @@ func TestServiceEventLog_Status(t *testing.T) { action := EcsServiceEventLogAction{} if len(test.responses) == 0 { pollerMock := new(ServiceDescriptionPollerMock) - pollerMock.On("Latest", test.state.AwsAccount, test.state.ClusterArn, test.state.ServiceArn).Return(nil, nil) + pollerMock.On("Latest", test.state.AwsAccount, test.state.Region, test.state.ClusterArn, test.state.ServiceArn).Return(nil, nil) action.poller = pollerMock runWithPoller(t, action, test.state, test.wanted, 0) } @@ -244,7 +247,7 @@ func TestServiceEventLog_Status(t *testing.T) { // Setting different return values for multiple calls with the same parameters does not seem to work. // This little workaround sets a new poller mock for every response, resulting in the same behavior. pollerMock := new(ServiceDescriptionPollerMock) - pollerMock.On("Latest", test.state.AwsAccount, test.state.ClusterArn, test.state.ServiceArn).Return(test.responses[i], nil) + pollerMock.On("Latest", test.state.AwsAccount, test.state.Region, test.state.ClusterArn, test.state.ServiceArn).Return(test.responses[i], nil) action.poller = pollerMock runWithPoller(t, action, test.state, test.wanted, i) } diff --git a/extecs/service_task_count_check.go b/extecs/service_task_count_check.go index ff69cf9c..cc570784 100644 --- a/extecs/service_task_count_check.go +++ b/extecs/service_task_count_check.go @@ -30,6 +30,7 @@ type EcsServiceTaskCountCheckState struct { ServiceArn string ClusterArn string AwsAccount string + Region string InitialRunningCount int } @@ -140,11 +141,12 @@ func (f EcsServiceTaskCountCheckAction) Prepare(_ context.Context, state *EcsSer } awsAccount := extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + region := extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] clusterArn := extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.cluster.arn")[0] serviceArn := extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.service.arn")[0] - f.poller.Register(awsAccount, clusterArn, serviceArn) - counts, err := f.initialRunningCount(awsAccount, clusterArn, serviceArn) + f.poller.Register(awsAccount, region, clusterArn, serviceArn) + counts, err := f.initialRunningCount(awsAccount, region, clusterArn, serviceArn) if err != nil { return nil, err } @@ -152,6 +154,7 @@ func (f EcsServiceTaskCountCheckAction) Prepare(_ context.Context, state *EcsSer state.Timeout = time.Now().Add(time.Millisecond * time.Duration(config.Duration)) state.RunningCountCheckMode = config.RunningCountCheckMode state.AwsAccount = awsAccount + state.Region = region state.ClusterArn = clusterArn state.ServiceArn = serviceArn state.InitialRunningCount = counts.running @@ -159,8 +162,8 @@ func (f EcsServiceTaskCountCheckAction) Prepare(_ context.Context, state *EcsSer return nil, nil } -func (f EcsServiceTaskCountCheckAction) initialRunningCount(awsAccount string, clusterArn string, serviceArn string) (*escServiceTaskCounts, error) { - latest := f.poller.AwaitLatest(awsAccount, clusterArn, serviceArn) +func (f EcsServiceTaskCountCheckAction) initialRunningCount(awsAccount string, region string, clusterArn string, serviceArn string) (*escServiceTaskCounts, error) { + latest := f.poller.AwaitLatest(awsAccount, region, clusterArn, serviceArn) if latest != nil { if latest.service != nil { return toServiceTaskCounts(latest.service), nil @@ -178,12 +181,12 @@ func (f EcsServiceTaskCountCheckAction) Start(_ context.Context, _ *EcsServiceTa } func (f EcsServiceTaskCountCheckAction) Stop(_ context.Context, state *EcsServiceTaskCountCheckState) (*action_kit_api.StopResult, error) { - f.poller.Unregister(state.AwsAccount, state.ClusterArn, state.ServiceArn) + f.poller.Unregister(state.AwsAccount, state.Region, state.ClusterArn, state.ServiceArn) return nil, nil } func (f EcsServiceTaskCountCheckAction) Status(_ context.Context, state *EcsServiceTaskCountCheckState) (*action_kit_api.StatusResult, error) { - latest := f.poller.Latest(state.AwsAccount, state.ClusterArn, state.ServiceArn) + latest := f.poller.Latest(state.AwsAccount, state.Region, state.ClusterArn, state.ServiceArn) var checkError *action_kit_api.ActionKitError if latest != nil { diff --git a/extecs/service_task_count_check_test.go b/extecs/service_task_count_check_test.go index dec9d6fd..4630272d 100644 --- a/extecs/service_task_count_check_test.go +++ b/extecs/service_task_count_check_test.go @@ -22,7 +22,7 @@ func TestServiceTaskCountCheck_prepare_saves_initial_state(t *testing.T) { poller.ticker = time.NewTicker(1 * time.Millisecond) // Mock the api calls in ServiceDescriptionPoller to check the interactions of ServiceTaskCountCheck with it. - poller.apiClientProvider = func(account string) (ecsDescribeServicesApi, error) { + poller.apiClientProvider = func(account string, region string) (ecsDescribeServicesApi, error) { mockedApi := new(ecsDescribeServicesApiMock) mockedApi.On("DescribeServices", mock.Anything, mock.Anything).Return(&ecs.DescribeServicesOutput{ Services: []types.Service{{ @@ -43,6 +43,7 @@ func TestServiceTaskCountCheck_prepare_saves_initial_state(t *testing.T) { Target: extutil.Ptr(action_kit_api.Target{ Attributes: map[string][]string{ "aws.account": {"42"}, + "aws.region": {"eu-west-1"}, "aws-ecs.service.arn": {"service-arn"}, "aws-ecs.cluster.arn": {"cluster-arn"}, }, @@ -62,6 +63,7 @@ func TestServiceTaskCountCheck_prepare_saves_initial_state(t *testing.T) { assert.NoError(t, err) assert.LessOrEqual(t, state.Timeout, time.Now().Add(time.Second*100)) assert.Equal(t, state.AwsAccount, "42") + assert.Equal(t, state.Region, "eu-west-1") assert.Equal(t, state.ClusterArn, "cluster-arn") assert.Equal(t, state.ServiceArn, "service-arn") assert.Equal(t, state.InitialRunningCount, 2) @@ -211,7 +213,7 @@ func TestServiceTaskCountCheck_status_checks_running_count(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) pollerMock := new(ServiceDescriptionPollerMock) - pollerMock.On("Latest", test.state.AwsAccount, test.state.ClusterArn, test.state.ServiceArn).Return(&test.response, nil) + pollerMock.On("Latest", test.state.AwsAccount, test.state.Region, test.state.ClusterArn, test.state.ServiceArn).Return(&test.response, nil) action := EcsServiceTaskCountCheckAction{ poller: pollerMock, diff --git a/extecs/task_attack_ssm.go b/extecs/task_attack_ssm.go index 04da5aa9..5d641831 100644 --- a/extecs/task_attack_ssm.go +++ b/extecs/task_attack_ssm.go @@ -25,7 +25,7 @@ import ( ) type ecsTaskSsmAction struct { - clientProvider func(account string) (ecsTaskSsmApi, error) + clientProvider func(account string, region string) (ecsTaskSsmApi, error) description action_kit_api.ActionDescription ssmCommandInvocation ssmCommandInvocation } @@ -42,6 +42,7 @@ var ( type TaskSsmActionState struct { Account string + Region string TaskArn string ManagedInstanceId string CommandId string @@ -102,6 +103,7 @@ func (e *ecsTaskSsmAction) Describe() action_kit_api.ActionDescription { func (e *ecsTaskSsmAction) Prepare(ctx context.Context, state *TaskSsmActionState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.TaskArn = extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.task.arn")[0] if parameters, err := e.ssmCommandInvocation.getParameters(request); err == nil { @@ -116,7 +118,7 @@ func (e *ecsTaskSsmAction) Prepare(ctx context.Context, state *TaskSsmActionStat state.Comment = "Steadybit Experiment" } - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, err } @@ -140,7 +142,7 @@ func (e *ecsTaskSsmAction) Prepare(ctx context.Context, state *TaskSsmActionStat } func (e *ecsTaskSsmAction) Start(ctx context.Context, state *TaskSsmActionState) (*action_kit_api.StartResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, err } @@ -176,7 +178,7 @@ func shorten(s string, i int) string { } func (e *ecsTaskSsmAction) Status(ctx context.Context, state *TaskSsmActionState) (*action_kit_api.StatusResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, err } @@ -213,7 +215,7 @@ func (e *ecsTaskSsmAction) Stop(ctx context.Context, state *TaskSsmActionState) return nil, nil } - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, err } @@ -317,10 +319,10 @@ func (e *ecsTaskSsmAction) findManagedInstance(ctx context.Context, client ecsTa } } -func defaultTaskSsmClientProvider(account string) (ecsTaskSsmApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultTaskSsmClientProvider(account string, region string) (ecsTaskSsmApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return ssm.NewFromConfig(awsAccount.AwsConfig), nil + return ssm.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extecs/task_attack_ssm_test.go b/extecs/task_attack_ssm_test.go index 87846d05..75802da6 100644 --- a/extecs/task_attack_ssm_test.go +++ b/extecs/task_attack_ssm_test.go @@ -22,7 +22,7 @@ var ( mockApi = mockEcsTaskSsmApi{} testSsmAction = &ecsTaskSsmAction{ - clientProvider: func(account string) (ecsTaskSsmApi, error) { + clientProvider: func(account string, region string) (ecsTaskSsmApi, error) { return &mockApi, nil }, ssmCommandInvocation: ssmCommandInvocation{ @@ -36,6 +36,7 @@ var ( testTarget = &action_kit_api.Target{ Attributes: map[string][]string{ "aws.account": {"account"}, + "aws.region": {"region"}, "aws-ecs.task.arn": {"task"}, }, } @@ -72,6 +73,7 @@ func Test_ecsTaskSsmAction_Prepare(t *testing.T) { wantErr: assert.NoError, wantState: TaskSsmActionState{ Account: "account", + Region: "region", TaskArn: "task", ManagedInstanceId: "mi-0", Parameters: map[string][]string{ diff --git a/extecs/task_attack_stop.go b/extecs/task_attack_stop.go index e7f11226..c889b21b 100644 --- a/extecs/task_attack_stop.go +++ b/extecs/task_attack_stop.go @@ -17,7 +17,7 @@ import ( ) type ecsTaskStopAction struct { - clientProvider func(account string) (ecsTaskStopApi, error) + clientProvider func(account string, region string) (ecsTaskStopApi, error) } // Make sure lambdaAction implements all required interfaces @@ -25,6 +25,7 @@ var _ action_kit_sdk.Action[TaskStopState] = (*ecsTaskStopAction)(nil) type TaskStopState struct { Account string + Region string TaskArn string ClusterArn string } @@ -69,13 +70,14 @@ func (e *ecsTaskStopAction) Describe() action_kit_api.ActionDescription { func (e *ecsTaskStopAction) Prepare(_ context.Context, state *TaskStopState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.ClusterArn = extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.cluster.arn")[0] state.TaskArn = extutil.MustHaveValue(request.Target.Attributes, "aws-ecs.task.arn")[0] return nil, nil } func (e *ecsTaskStopAction) Start(ctx context.Context, state *TaskStopState) (*action_kit_api.StartResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize ECS client for AWS account %s", state.Account), err) } @@ -109,10 +111,10 @@ func (e *ecsTaskStopAction) Start(ctx context.Context, state *TaskStopState) (*a return nil, nil } -func defaultTaskStopClientProvider(account string) (ecsTaskStopApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultTaskStopClientProvider(account string, region string) (ecsTaskStopApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return ecs.NewFromConfig(awsAccount.AwsConfig), nil + return ecs.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extecs/task_attack_stop_test.go b/extecs/task_attack_stop_test.go index 2e3c4a27..c01cfe68 100644 --- a/extecs/task_attack_stop_test.go +++ b/extecs/task_attack_stop_test.go @@ -34,12 +34,14 @@ func TestEcsTaskStopAction_Prepare(t *testing.T) { "aws-ecs.cluster.arn": {"my-cluster-arn"}, "aws-ecs.task.arn": {"my-task-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), wantedState: &TaskStopState{ Account: "42", + Region: "us-west-1", ClusterArn: "my-cluster-arn", TaskArn: "my-task-arn", }, @@ -61,6 +63,7 @@ func TestEcsTaskStopAction_Prepare(t *testing.T) { if tt.wantedState != nil { assert.NoError(t, err) assert.Equal(t, tt.wantedState.Account, state.Account) + assert.Equal(t, tt.wantedState.Region, state.Region) assert.Equal(t, tt.wantedState.ClusterArn, state.ClusterArn) assert.EqualValues(t, tt.wantedState.TaskArn, state.TaskArn) } @@ -98,7 +101,7 @@ func TestEcsTaskStopAction_Start(t *testing.T) { }, }) - action := ecsTaskStopAction{clientProvider: func(account string) (ecsTaskStopApi, error) { + action := ecsTaskStopAction{clientProvider: func(account string, region string) (ecsTaskStopApi, error) { return api, nil }} @@ -127,7 +130,7 @@ func TestEcsTaskStopAction_Start_already_stopped_task(t *testing.T) { }, }) - action := ecsTaskStopAction{clientProvider: func(account string) (ecsTaskStopApi, error) { + action := ecsTaskStopAction{clientProvider: func(account string, region string) (ecsTaskStopApi, error) { return api, nil }} @@ -161,7 +164,7 @@ func TestEcsTaskStopActionForwardsError(t *testing.T) { require.Equal(t, "my-cluster-arn", *params.Cluster) return true })).Return(&ecs.StopTaskOutput{}, errors.New("expected")) - action := ecsTaskStopAction{clientProvider: func(account string) (ecsTaskStopApi, error) { + action := ecsTaskStopAction{clientProvider: func(account string, region string) (ecsTaskStopApi, error) { return api, nil }} diff --git a/extecs/task_discovery.go b/extecs/task_discovery.go index 9844e244..18284b30 100644 --- a/extecs/task_discovery.go +++ b/extecs/task_discovery.go @@ -107,10 +107,10 @@ func (e *ecsTaskDiscovery) DescribeAttributes() []discovery_kit_api.AttributeDes } func (e *ecsTaskDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getTargetsForAccount, ctx, "ecs-task") + return utils.ForEveryConfiguredAwsAccess(getTargetsForAccount, ctx, "ecs-task") } -func getTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := ecs.NewFromConfig(account.AwsConfig) result, err := GetAllEcsTasks(ctx, client, utils.Zones, account.AccountNumber, account.AwsConfig.Region) if err != nil { @@ -183,7 +183,7 @@ func toTarget(task types.Task, zoneUtil utils.GetZoneUtil, awsAccountNumber stri arn := aws.ToString(task.TaskArn) availabilityZoneName := aws.ToString(task.AvailabilityZone) - availabilityZoneApi := zoneUtil.GetZone(awsAccountNumber, availabilityZoneName) + availabilityZoneApi := zoneUtil.GetZone(awsAccountNumber, availabilityZoneName, awsRegion) attributes := make(map[string][]string) attributes["aws.account"] = []string{awsAccountNumber} diff --git a/extecs/task_discovery_test.go b/extecs/task_discovery_test.go index 11256211..ac52897a 100644 --- a/extecs/task_discovery_test.go +++ b/extecs/task_discovery_test.go @@ -47,8 +47,8 @@ type zoneMock struct { mock.Mock } -func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string) *ec2types.AvailabilityZone { - args := m.Called(awsAccountNumber, awsZone) +func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string, region string) *ec2types.AvailabilityZone { + args := m.Called(awsAccountNumber, awsZone, region) return args.Get(0).(*ec2types.AvailabilityZone) } @@ -104,7 +104,7 @@ func TestGetAllEcsTasks(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1b-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) // When targets, err := GetAllEcsTasks(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") diff --git a/extelasticache/common_elasticache.go b/extelasticache/common_elasticache.go index b89514dc..0f8d3854 100644 --- a/extelasticache/common_elasticache.go +++ b/extelasticache/common_elasticache.go @@ -21,6 +21,7 @@ type ElasticacheClusterAttackState struct { ReplicationGroupID string NodeGroupID string Account string + Region string } type ElasticacheApi interface { @@ -28,10 +29,10 @@ type ElasticacheApi interface { DescribeReplicationGroups(ctx context.Context, params *elasticache.DescribeReplicationGroupsInput, optFns ...func(*elasticache.Options)) (*elasticache.DescribeReplicationGroupsOutput, error) } -func defaultElasticacheClientProvider(account string) (ElasticacheApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultElasticacheClientProvider(account string, region string) (ElasticacheApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return elasticache.NewFromConfig(awsAccount.AwsConfig), nil + return elasticache.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extelasticache/node_group_attack_failover.go b/extelasticache/node_group_attack_failover.go index 6d5f7c25..af55d4f7 100644 --- a/extelasticache/node_group_attack_failover.go +++ b/extelasticache/node_group_attack_failover.go @@ -15,7 +15,7 @@ import ( ) type elasticacheNodeGroupFailoverAttack struct { - clientProvider func(account string) (ElasticacheApi, error) + clientProvider func(account string, region string) (ElasticacheApi, error) } var _ action_kit_sdk.Action[ElasticacheClusterAttackState] = (*elasticacheNodeGroupFailoverAttack)(nil) @@ -55,13 +55,14 @@ func (f elasticacheNodeGroupFailoverAttack) Describe() action_kit_api.ActionDesc func (f elasticacheNodeGroupFailoverAttack) Prepare(_ context.Context, state *ElasticacheClusterAttackState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.ReplicationGroupID = extutil.MustHaveValue(request.Target.Attributes, "aws.elasticache.replication-group.id")[0] state.NodeGroupID = extutil.MustHaveValue(request.Target.Attributes, "aws.elasticache.replication-group.node-group.id")[0] return nil, nil } func (f elasticacheNodeGroupFailoverAttack) Start(ctx context.Context, state *ElasticacheClusterAttackState) (*action_kit_api.StartResult, error) { - client, err := f.clientProvider(state.Account) + client, err := f.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize Elasticache client for AWS account %s", state.Account), err) } diff --git a/extelasticache/node_group_attack_failover_test.go b/extelasticache/node_group_attack_failover_test.go index 91bff04a..fd4df2af 100644 --- a/extelasticache/node_group_attack_failover_test.go +++ b/extelasticache/node_group_attack_failover_test.go @@ -23,6 +23,7 @@ func TestTestFailover(t *testing.T) { "aws.elasticache.replication-group.node-group.id": {"0001"}, "aws.elasticache.replication-group.id": {"redis-steadybit-dev"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }) @@ -37,6 +38,7 @@ func TestTestFailover(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "0001", state.NodeGroupID) assert.Equal(t, "42", state.Account) + assert.Equal(t, "us-west-1", state.Region) assert.Equal(t, "redis-steadybit-dev", state.ReplicationGroupID) } @@ -53,7 +55,7 @@ func TestStartClusterFailover(t *testing.T) { NodeGroupID: "0001", Account: "42", } - action := elasticacheNodeGroupFailoverAttack{clientProvider: func(account string) (ElasticacheApi, error) { + action := elasticacheNodeGroupFailoverAttack{clientProvider: func(account string, region string) (ElasticacheApi, error) { return api, nil }} @@ -72,7 +74,7 @@ func TestStartClusterFailoverForwardFailoverError(t *testing.T) { state := ElasticacheClusterAttackState{ ReplicationGroupID: "redis-steadybit-dev", } - action := elasticacheNodeGroupFailoverAttack{clientProvider: func(account string) (ElasticacheApi, error) { + action := elasticacheNodeGroupFailoverAttack{clientProvider: func(account string, region string) (ElasticacheApi, error) { return api, nil }} diff --git a/extelasticache/node_group_discovery.go b/extelasticache/node_group_discovery.go index 088505da..853818e3 100644 --- a/extelasticache/node_group_discovery.go +++ b/extelasticache/node_group_discovery.go @@ -126,10 +126,10 @@ func (r *elasticacheReplicationGroupDiscovery) DescribeAttributes() []discovery_ } func (r *elasticacheReplicationGroupDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getClusterTargetsForAccount, ctx, "replication-group") + return utils.ForEveryConfiguredAwsAccess(getClusterTargetsForAccount, ctx, "replication-group") } -func getClusterTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getClusterTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := elasticache.NewFromConfig(account.AwsConfig) result, err := getAllElasticacheReplicationGroups(ctx, client, account.AccountNumber, account.AwsConfig.Region) if err != nil { diff --git a/extelb/alb_attack_static_response.go b/extelb/alb_attack_static_response.go index 0bcb595e..190145b6 100644 --- a/extelb/alb_attack_static_response.go +++ b/extelb/alb_attack_static_response.go @@ -20,7 +20,7 @@ import ( ) type albStaticResponseAction struct { - clientProvider func(account string) (albStaticResponseApi, error) + clientProvider func(account string, region string) (albStaticResponseApi, error) } // Make sure action implements all required interfaces @@ -29,6 +29,7 @@ var _ action_kit_sdk.ActionWithStop[AlbStaticResponseState] = (*albStaticRespons type AlbStaticResponseState struct { Account string + Region string LoadbalancerArn string ListenerArn string ResponseStatusCode int @@ -206,6 +207,7 @@ func (e *albStaticResponseAction) Describe() action_kit_api.ActionDescription { func (e *albStaticResponseAction) Prepare(ctx context.Context, state *AlbStaticResponseState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.LoadbalancerArn = extutil.MustHaveValue(request.Target.Attributes, "aws-elb.alb.arn")[0] state.TargetExecutionId = request.ExecutionId if request.ExecutionContext != nil { @@ -213,7 +215,7 @@ func (e *albStaticResponseAction) Prepare(ctx context.Context, state *AlbStaticR state.ExperimentKey = *request.ExecutionContext.ExperimentKey } - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize elb client for AWS account %s", state.Account), err) } @@ -278,7 +280,7 @@ func (e *albStaticResponseAction) Prepare(ctx context.Context, state *AlbStaticR } func (e *albStaticResponseAction) Start(ctx context.Context, state *AlbStaticResponseState) (*action_kit_api.StartResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize elb client for AWS account %s", state.Account), err) } @@ -467,7 +469,7 @@ func getNewPriorityPairs(rules *elasticloadbalancingv2.DescribeRulesOutput) []ty } func (e *albStaticResponseAction) Stop(ctx context.Context, state *AlbStaticResponseState) (*action_kit_api.StopResult, error) { - client, err := e.clientProvider(state.Account) + client, err := e.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize ECS client for AWS account %s", state.Account), err) } @@ -593,10 +595,10 @@ func restoreOldPriorities(ctx context.Context, client *albStaticResponseApi, sta return nil } -func defaultClientProviderService(account string) (albStaticResponseApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultClientProviderService(account string, region string) (albStaticResponseApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return elasticloadbalancingv2.NewFromConfig(awsAccount.AwsConfig), nil + return elasticloadbalancingv2.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extelb/alb_attack_static_response_test.go b/extelb/alb_attack_static_response_test.go index ba51c55d..57b98d32 100644 --- a/extelb/alb_attack_static_response_test.go +++ b/extelb/alb_attack_static_response_test.go @@ -65,7 +65,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { }}, }, nil) - action := albStaticResponseAction{clientProvider: func(account string) (albStaticResponseApi, error) { + action := albStaticResponseAction{clientProvider: func(account string, region string) (albStaticResponseApi, error) { return api, nil }} @@ -97,6 +97,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-elb.alb.arn": {"my-loadbalancer-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), ExecutionContext: extutil.Ptr(action_kit_api.ExecutionContext{ @@ -108,6 +109,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { wantedState: &AlbStaticResponseState{ Account: "42", + Region: "us-west-1", ListenerArn: "my-listener-arn", LoadbalancerArn: "my-loadbalancer-arn", ResponseBody: "Steadybit killed your request", @@ -141,6 +143,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-elb.alb.arn": {"my-loadbalancer-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -158,6 +161,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-elb.alb.arn": {"my-loadbalancer-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -175,6 +179,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-elb.alb.arn": {"my-loadbalancer-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -199,6 +204,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-elb.alb.arn": {"my-loadbalancer-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -216,6 +222,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-elb.alb.arn": {"my-loadbalancer-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -236,6 +243,7 @@ func TestAlbStaticResponseAction_Prepare(t *testing.T) { Attributes: map[string][]string{ "aws-elb.alb.arn": {"my-loadbalancer-arn"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }), @@ -314,13 +322,14 @@ func TestAlbStaticResponseAction_Start(t *testing.T) { }, }, nil) - action := albStaticResponseAction{clientProvider: func(account string) (albStaticResponseApi, error) { + action := albStaticResponseAction{clientProvider: func(account string, region string) (albStaticResponseApi, error) { return api, nil }} // When state := &AlbStaticResponseState{ Account: "42", + Region: "us-west-1", ListenerArn: "my-listener-arn", LoadbalancerArn: "my-loadbalancer-arn", ResponseBody: "Steadybit killed your request", @@ -440,13 +449,14 @@ func TestEcsServiceScaleAction_Stop(t *testing.T) { return true })).Return(&elasticloadbalancingv2.SetRulePrioritiesOutput{}, nil) - action := albStaticResponseAction{clientProvider: func(account string) (albStaticResponseApi, error) { + action := albStaticResponseAction{clientProvider: func(account string, region string) (albStaticResponseApi, error) { return api, nil }} // When state := &AlbStaticResponseState{ Account: "42", + Region: "us-west-1", ListenerArn: "my-listener-arn", LoadbalancerArn: "my-loadbalancer-arn", TargetExecutionId: targetExecutionId, diff --git a/extelb/alb_discovery.go b/extelb/alb_discovery.go index 08289ba4..7ca1a440 100644 --- a/extelb/alb_discovery.go +++ b/extelb/alb_discovery.go @@ -99,10 +99,10 @@ func (e *albDiscovery) DescribeAttributes() []discovery_kit_api.AttributeDescrip } func (e *albDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getTargetsForAccount, ctx, "ecs-task") + return utils.ForEveryConfiguredAwsAccess(getTargetsForAccount, ctx, "ecs-task") } -func getTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := elasticloadbalancingv2.NewFromConfig(account.AwsConfig) result, err := GetAlbs(ctx, client, utils.Zones, account.AccountNumber, account.AwsConfig.Region) if err != nil { @@ -179,7 +179,7 @@ func toTarget(lb *types.LoadBalancer, tags []types.Tag, listeners []types.Listen zoneIds := make([]string, 0, len(lb.AvailabilityZones)) for _, zone := range lb.AvailabilityZones { zones = append(zones, aws.ToString(zone.ZoneName)) - zoneApi := zoneUtil.GetZone(awsAccountNumber, aws.ToString(zone.ZoneName)) + zoneApi := zoneUtil.GetZone(awsAccountNumber, aws.ToString(zone.ZoneName), awsRegion) if zoneApi != nil { zoneIds = append(zoneIds, *zoneApi.ZoneId) } diff --git a/extelb/alb_discovery_test.go b/extelb/alb_discovery_test.go index ee988bca..0f48cc64 100644 --- a/extelb/alb_discovery_test.go +++ b/extelb/alb_discovery_test.go @@ -48,8 +48,8 @@ type zoneMock struct { mock.Mock } -func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string) *ec2types.AvailabilityZone { - args := m.Called(awsAccountNumber, awsZone) +func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string, region string) *ec2types.AvailabilityZone { + args := m.Called(awsAccountNumber, awsZone, region) return args.Get(0).(*ec2types.AvailabilityZone) } @@ -146,10 +146,10 @@ func TestGetAllAlbTargets(t *testing.T) { } mockedZoneUtil.On("GetZone", mock.Anything, mock.MatchedBy(func(params string) bool { return params == "us-east-1a" - })).Return(&mockedZone1a) + }), mock.Anything).Return(&mockedZone1a) mockedZoneUtil.On("GetZone", mock.Anything, mock.MatchedBy(func(params string) bool { return params == "us-east-1b" - })).Return(&mockedZone1b) + }), mock.Anything).Return(&mockedZone1b) // When targets, err := GetAlbs(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") diff --git a/extfis/start_experiment.go b/extfis/start_experiment.go index 6485dfaa..380f2dd9 100644 --- a/extfis/start_experiment.go +++ b/extfis/start_experiment.go @@ -24,6 +24,7 @@ type FisExperimentAction struct { type FisExperimentState struct { Account string + Region string ExperimentId string TemplateId string LastSummary string @@ -92,17 +93,18 @@ func (f FisExperimentAction) Describe() action_kit_api.ActionDescription { func (f FisExperimentAction) Prepare(_ context.Context, state *FisExperimentState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { state.TemplateId = extutil.MustHaveValue(request.Target.Attributes, "aws.fis.experiment.template.id")[0] state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.ExecutionId = request.ExecutionId return nil, nil } func (f FisExperimentAction) Start(ctx context.Context, state *FisExperimentState) (*action_kit_api.StartResult, error) { - return startExperiment(ctx, state, func(account string) (FisStartExperimentClient, error) { - awsAccount, err := utils.Accounts.GetAccount(account) + return startExperiment(ctx, state, func(account string, region string) (FisStartExperimentClient, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return fis.NewFromConfig(awsAccount.AwsConfig), nil + return fis.NewFromConfig(awsAccess.AwsConfig), nil }) } @@ -110,8 +112,8 @@ type FisStartExperimentClient interface { StartExperiment(ctx context.Context, params *fis.StartExperimentInput, optFns ...func(*fis.Options)) (*fis.StartExperimentOutput, error) } -func startExperiment(ctx context.Context, state *FisExperimentState, clientProvider func(account string) (FisStartExperimentClient, error)) (*action_kit_api.StartResult, error) { - client, err := clientProvider(state.Account) +func startExperiment(ctx context.Context, state *FisExperimentState, clientProvider func(account string, region string) (FisStartExperimentClient, error)) (*action_kit_api.StartResult, error) { + client, err := clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize FIS client for AWS account %s", state.Account), err) } @@ -136,12 +138,12 @@ func startExperiment(ctx context.Context, state *FisExperimentState, clientProvi } func (f FisExperimentAction) Status(ctx context.Context, state *FisExperimentState) (*action_kit_api.StatusResult, error) { - return statusExperiment(ctx, state, func(account string) (FisStatusExperimentClient, error) { - awsAccount, err := utils.Accounts.GetAccount(account) + return statusExperiment(ctx, state, func(account string, region string) (FisStatusExperimentClient, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return fis.NewFromConfig(awsAccount.AwsConfig), nil + return fis.NewFromConfig(awsAccess.AwsConfig), nil }) } @@ -149,8 +151,8 @@ type FisStatusExperimentClient interface { GetExperiment(ctx context.Context, params *fis.GetExperimentInput, optFns ...func(*fis.Options)) (*fis.GetExperimentOutput, error) } -func statusExperiment(ctx context.Context, state *FisExperimentState, clientProvider func(account string) (FisStatusExperimentClient, error)) (*action_kit_api.StatusResult, error) { - client, err := clientProvider(state.Account) +func statusExperiment(ctx context.Context, state *FisExperimentState, clientProvider func(account string, region string) (FisStatusExperimentClient, error)) (*action_kit_api.StatusResult, error) { + client, err := clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError("Failed to initialize FIS client for AWS account %s", err) } @@ -215,17 +217,17 @@ type FisStopExperimentClient interface { } func (f FisExperimentAction) Stop(ctx context.Context, state *FisExperimentState) (*action_kit_api.StopResult, error) { - return stopExperiment(ctx, state, func(account string) (FisStopExperimentClient, error) { - awsAccount, err := utils.Accounts.GetAccount(account) + return stopExperiment(ctx, state, func(account string, region string) (FisStopExperimentClient, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return fis.NewFromConfig(awsAccount.AwsConfig), nil + return fis.NewFromConfig(awsAccess.AwsConfig), nil }) } -func stopExperiment(ctx context.Context, state *FisExperimentState, clientProvider func(account string) (FisStopExperimentClient, error)) (*action_kit_api.StopResult, error) { - client, err := clientProvider(state.Account) +func stopExperiment(ctx context.Context, state *FisExperimentState, clientProvider func(account string, region string) (FisStopExperimentClient, error)) (*action_kit_api.StopResult, error) { + client, err := clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError("Failed to initialize FIS client for AWS account %s", err) } diff --git a/extfis/start_experiment_test.go b/extfis/start_experiment_test.go index bdaf6310..f903cd74 100644 --- a/extfis/start_experiment_test.go +++ b/extfis/start_experiment_test.go @@ -24,6 +24,7 @@ func TestPrepareInstanceReboot(t *testing.T) { Attributes: map[string][]string{ "aws.fis.experiment.template.id": {"template-123"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), ExecutionId: executionId, @@ -38,6 +39,7 @@ func TestPrepareInstanceReboot(t *testing.T) { assert.Nil(t, err) assert.Nil(t, result) assert.Equal(t, "42", state.Account) + assert.Equal(t, "us-west-1", state.Region) assert.Equal(t, "template-123", state.TemplateId) assert.Equal(t, executionId, state.ExecutionId) } @@ -78,11 +80,13 @@ func TestStartExperiment(t *testing.T) { state := action.NewEmptyState() state.TemplateId = "template-123" state.Account = "42" + state.Region = "us-west-1" state.ExecutionId = executionId // When - result, err := startExperiment(context.Background(), &state, func(account string) (FisStartExperimentClient, error) { + result, err := startExperiment(context.Background(), &state, func(account string, region string) (FisStartExperimentClient, error) { assert.Equal(t, "42", account) + assert.Equal(t, "us-west-1", region) return mockedApi, nil }) @@ -131,12 +135,14 @@ func TestStatusExperiment(t *testing.T) { state := action.NewEmptyState() state.TemplateId = "template-123" state.Account = "42" + state.Region = "us-west-1" state.ExecutionId = executionId state.ExperimentId = "EXP-123" // When - result, err := statusExperiment(context.Background(), &state, func(account string) (FisStatusExperimentClient, error) { + result, err := statusExperiment(context.Background(), &state, func(account string, region string) (FisStatusExperimentClient, error) { assert.Equal(t, "42", account) + assert.Equal(t, "us-west-1", region) return mockedApi, nil }) @@ -176,12 +182,14 @@ func TestStopExperiment(t *testing.T) { state := action.NewEmptyState() state.TemplateId = "template-123" state.Account = "42" + state.Region = "us-west-1" state.ExecutionId = executionId state.ExperimentId = "EXP-123" // When - _, extKitErr := stopExperiment(context.Background(), &state, func(account string) (FisStopExperimentClient, error) { + _, extKitErr := stopExperiment(context.Background(), &state, func(account string, region string) (FisStopExperimentClient, error) { assert.Equal(t, "42", account) + assert.Equal(t, "us-west-1", region) return mockedApi, nil }) diff --git a/extfis/template_discovery.go b/extfis/template_discovery.go index 5dfee3d8..576d592d 100644 --- a/extfis/template_discovery.go +++ b/extfis/template_discovery.go @@ -97,9 +97,9 @@ func (f *fisTemplateDiscovery) DescribeAttributes() []discovery_kit_api.Attribut } } func (f *fisTemplateDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getTargetsForAccount, ctx, "fis-template") + return utils.ForEveryConfiguredAwsAccess(getTargetsForAccount, ctx, "fis-template") } -func getTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := fis.NewFromConfig(account.AwsConfig) result, err := GetAllFisTemplates(ctx, client, account.AccountNumber, account.AwsConfig.Region) if err != nil { diff --git a/extlambda/attack.go b/extlambda/attack.go index 7fc2b477..3a5293dd 100644 --- a/extlambda/attack.go +++ b/extlambda/attack.go @@ -21,7 +21,7 @@ import ( type lambdaAction struct { description action_kit_api.ActionDescription configProvider func(request action_kit_api.PrepareActionRequestBody) (*FailureInjectionConfig, error) - clientProvider func(account string) (ssmApi, error) + clientProvider func(account string, region string) (ssmApi, error) } type ssmApi interface { @@ -48,6 +48,7 @@ type FailureInjectionConfig struct { type LambdaActionState struct { Account string `json:"account"` + Region string `json:"region"` Param string `json:"param"` Config *FailureInjectionConfig `json:"config"` ExperimentKey *string `json:"experimentKey"` @@ -74,6 +75,7 @@ func (a *lambdaAction) Prepare(_ context.Context, state *LambdaActionState, requ } state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.Param = failureInjectionParam[0] state.ExperimentKey = request.ExecutionContext.ExperimentKey state.ExecutionId = request.ExecutionContext.ExecutionId @@ -87,7 +89,7 @@ func (a *lambdaAction) Start(ctx context.Context, state *LambdaActionState) (*ac return nil, extension_kit.ToError("Failed to convert ssm parameter", err) } - client, err := a.clientProvider(state.Account) + client, err := a.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize lambda client for AWS account %s", state.Account), err) } @@ -117,7 +119,7 @@ func (a *lambdaAction) Start(ctx context.Context, state *LambdaActionState) (*ac } func (a *lambdaAction) Stop(ctx context.Context, state *LambdaActionState) (*action_kit_api.StopResult, error) { - client, err := a.clientProvider(state.Account) + client, err := a.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError("Failed to create ssm client", err) } @@ -135,11 +137,11 @@ func (a *lambdaAction) Stop(ctx context.Context, state *LambdaActionState) (*act return nil, nil } -func defaultClientProvider(account string) (ssmApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultClientProvider(account string, region string) (ssmApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - client := ssm.NewFromConfig(awsAccount.AwsConfig) + client := ssm.NewFromConfig(awsAccess.AwsConfig) return client, nil } diff --git a/extlambda/attack_test.go b/extlambda/attack_test.go index 4e8ae15c..c3bbb646 100644 --- a/extlambda/attack_test.go +++ b/extlambda/attack_test.go @@ -32,10 +32,12 @@ func TestLambdaAction_Prepare(t *testing.T) { name: "Should return config", attributes: map[string][]string{ "aws.account": {"123456789012"}, + "aws.region": {"us-west-1"}, "aws.lambda.failure-injection-param": {"PARAM"}, }, wantedState: &LambdaActionState{ Account: "123456789012", + Region: "us-west-1", Param: "PARAM", Config: &config, }, @@ -44,6 +46,7 @@ func TestLambdaAction_Prepare(t *testing.T) { name: "Should return error if failure-injection-param is missing", attributes: map[string][]string{ "aws.account": {"123456789012"}, + "aws.region": {"us-west-1"}, }, wantedError: extension_kit.ToError("Target is missing the 'aws.lambda.failure-injection-param' attribute. Did you wrap the lambda with https://github.com/steadybit/failure-lambda ?", nil), }, @@ -73,6 +76,7 @@ func TestLambdaAction_Prepare(t *testing.T) { if tt.wantedState != nil { assert.NoError(t, err) assert.Equal(t, tt.wantedState.Account, state.Account) + assert.Equal(t, tt.wantedState.Region, state.Region) assert.Equal(t, tt.wantedState.Param, state.Param) assert.EqualValues(t, tt.wantedState.Config, state.Config) } @@ -97,12 +101,13 @@ func TestLambdaAction_Start(t *testing.T) { }, mock.Anything).Return(&ssm.AddTagsToResourceOutput{}, nil) action := lambdaAction{ - clientProvider: func(account string) (ssmApi, error) { + clientProvider: func(account string, region string) (ssmApi, error) { return api, nil }, } state := action.NewEmptyState() state.Account = "123456789012" + state.Region = "us-west-1" state.Param = "PARAM" state.ExperimentKey = extutil.Ptr("TEST-1") state.ExecutionId = extutil.Ptr(42) @@ -126,12 +131,13 @@ func TestLambdaAction_Stop(t *testing.T) { }, mock.Anything).Return(&ssm.DeleteParameterOutput{}, nil) action := lambdaAction{ - clientProvider: func(account string) (ssmApi, error) { + clientProvider: func(account string, region string) (ssmApi, error) { return api, nil }, } state := action.NewEmptyState() state.Account = "123456789012" + state.Region = "us-west-1" state.Param = "PARAM" result, err := action.Stop(context.Background(), &state) diff --git a/extlambda/discovery.go b/extlambda/discovery.go index 3588d67a..c6feb7ef 100644 --- a/extlambda/discovery.go +++ b/extlambda/discovery.go @@ -173,10 +173,10 @@ func (*lambdaDiscovery) DescribeAttributes() []discovery_kit_api.AttributeDescri } func (l *lambdaDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getTargetsForAccount, ctx, "lambda") + return utils.ForEveryConfiguredAwsAccess(getTargetsForAccount, ctx, "lambda") } -func getTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := lambda.NewFromConfig(account.AwsConfig) result, err := getAllAwsLambdaFunctions(ctx, client, account.AccountNumber, account.AwsConfig.Region) if err != nil { diff --git a/extmsk/cluster_discovery.go b/extmsk/cluster_discovery.go index 710c745d..8153dd55 100644 --- a/extmsk/cluster_discovery.go +++ b/extmsk/cluster_discovery.go @@ -151,10 +151,10 @@ func (r *mskClusterDiscovery) DescribeAttributes() []discovery_kit_api.Attribute } func (r *mskClusterDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getClusterTargetsForAccount, ctx, "msk-cluster") + return utils.ForEveryConfiguredAwsAccess(getClusterTargetsForAccount, ctx, "msk-cluster") } -func getClusterTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getClusterTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := kafka.NewFromConfig(account.AwsConfig) result, err := getAllMskClusters(ctx, client, account.AccountNumber, account.AwsConfig.Region) if err != nil { diff --git a/extmsk/common.go b/extmsk/common.go index 2b1ae3cc..f8acc024 100644 --- a/extmsk/common.go +++ b/extmsk/common.go @@ -20,6 +20,7 @@ type KafkaAttackState struct { ClusterARN string ClusterName string Account string + Region string } type MskApi interface { @@ -28,10 +29,10 @@ type MskApi interface { ListNodes(ctx context.Context, params *kafka.ListNodesInput, optFns ...func(*kafka.Options)) (*kafka.ListNodesOutput, error) } -func defaultMskClientProvider(account string) (MskApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultMskClientProvider(account string, region string) (MskApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return kafka.NewFromConfig(awsAccount.AwsConfig), nil + return kafka.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extmsk/reboot_broker_attack.go b/extmsk/reboot_broker_attack.go index 67a208b7..da801308 100644 --- a/extmsk/reboot_broker_attack.go +++ b/extmsk/reboot_broker_attack.go @@ -15,7 +15,7 @@ import ( ) type mskRebootBrokerAttack struct { - clientProvider func(account string) (MskApi, error) + clientProvider func(account string, region string) (MskApi, error) } var _ action_kit_sdk.Action[KafkaAttackState] = (*mskRebootBrokerAttack)(nil) @@ -55,6 +55,7 @@ func (f mskRebootBrokerAttack) Describe() action_kit_api.ActionDescription { func (f mskRebootBrokerAttack) Prepare(_ context.Context, state *KafkaAttackState, request action_kit_api.PrepareActionRequestBody) (*action_kit_api.PrepareResult, error) { state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.ClusterARN = extutil.MustHaveValue(request.Target.Attributes, "aws.msk.cluster.arn")[0] state.ClusterName = extutil.MustHaveValue(request.Target.Attributes, "aws.msk.cluster.name")[0] state.BrokerID = extutil.MustHaveValue(request.Target.Attributes, "aws.msk.cluster.broker.id")[0] @@ -62,7 +63,7 @@ func (f mskRebootBrokerAttack) Prepare(_ context.Context, state *KafkaAttackStat } func (f mskRebootBrokerAttack) Start(ctx context.Context, state *KafkaAttackState) (*action_kit_api.StartResult, error) { - client, err := f.clientProvider(state.Account) + client, err := f.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize Msk client for AWS account %s", state.Account), err) } diff --git a/extmsk/reboot_broker_attack_test.go b/extmsk/reboot_broker_attack_test.go index 09daa12b..e54de5c0 100644 --- a/extmsk/reboot_broker_attack_test.go +++ b/extmsk/reboot_broker_attack_test.go @@ -24,6 +24,7 @@ func TestRebootBroker(t *testing.T) { "aws.msk.cluster.broker.id": {"1"}, "aws.msk.cluster.name": {"test"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }) @@ -39,6 +40,7 @@ func TestRebootBroker(t *testing.T) { assert.Equal(t, "arn", state.ClusterARN) assert.Equal(t, "test", state.ClusterName) assert.Equal(t, "42", state.Account) + assert.Equal(t, "us-west-1", state.Region) assert.Equal(t, "1", state.BrokerID) } @@ -56,8 +58,9 @@ func TestStartRebootBroker(t *testing.T) { BrokerID: "1", BrokerARN: "broker-arn", Account: "42", + Region: "us-west-1", } - action := mskRebootBrokerAttack{clientProvider: func(account string) (MskApi, error) { + action := mskRebootBrokerAttack{clientProvider: func(account string, region string) (MskApi, error) { return api, nil }} @@ -76,7 +79,7 @@ func TestStartClusterFailoverForwardFailoverError(t *testing.T) { state := KafkaAttackState{ ClusterARN: "arn", } - action := mskRebootBrokerAttack{clientProvider: func(account string) (MskApi, error) { + action := mskRebootBrokerAttack{clientProvider: func(account string, region string) (MskApi, error) { return api, nil }} diff --git a/extrds/cluster_attack_failover.go b/extrds/cluster_attack_failover.go index 8e5012a9..fc066650 100644 --- a/extrds/cluster_attack_failover.go +++ b/extrds/cluster_attack_failover.go @@ -15,7 +15,7 @@ import ( ) type rdsClusterFailoverAttack struct { - clientProvider func(account string) (rdsDBClusterApi, error) + clientProvider func(account string, region string) (rdsDBClusterApi, error) } var _ action_kit_sdk.Action[RdsClusterAttackState] = (*rdsClusterFailoverAttack)(nil) @@ -57,7 +57,7 @@ func (f rdsClusterFailoverAttack) Prepare(_ context.Context, state *RdsClusterAt } func (f rdsClusterFailoverAttack) Start(ctx context.Context, state *RdsClusterAttackState) (*action_kit_api.StartResult, error) { - client, err := f.clientProvider(state.Account) + client, err := f.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize RDS client for AWS account %s", state.Account), err) } diff --git a/extrds/cluster_attack_failover_test.go b/extrds/cluster_attack_failover_test.go index 45d320c8..694ab2c7 100644 --- a/extrds/cluster_attack_failover_test.go +++ b/extrds/cluster_attack_failover_test.go @@ -22,6 +22,7 @@ func TestPrepareClusterFailover(t *testing.T) { Attributes: map[string][]string{ "aws.rds.cluster.id": {"my-cluster"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }) @@ -35,6 +36,8 @@ func TestPrepareClusterFailover(t *testing.T) { // Then assert.NoError(t, err) assert.Equal(t, "my-cluster", state.DBClusterIdentifier) + assert.Equal(t, "42", state.Account) + assert.Equal(t, "us-west-1", state.Region) } func TestStartClusterFailover(t *testing.T) { @@ -47,8 +50,9 @@ func TestStartClusterFailover(t *testing.T) { state := RdsClusterAttackState{ DBClusterIdentifier: "dev-db", Account: "42", + Region: "us-west-1", } - action := rdsClusterFailoverAttack{clientProvider: func(account string) (rdsDBClusterApi, error) { + action := rdsClusterFailoverAttack{clientProvider: func(account string, region string) (rdsDBClusterApi, error) { return api, nil }} @@ -67,7 +71,7 @@ func TestStartClusterFailoverForwardFailoverError(t *testing.T) { state := RdsClusterAttackState{ DBClusterIdentifier: "dev-db", } - action := rdsClusterFailoverAttack{clientProvider: func(account string) (rdsDBClusterApi, error) { + action := rdsClusterFailoverAttack{clientProvider: func(account string, region string) (rdsDBClusterApi, error) { return api, nil }} diff --git a/extrds/cluster_discovery.go b/extrds/cluster_discovery.go index 651b38aa..aa5f7cbf 100644 --- a/extrds/cluster_discovery.go +++ b/extrds/cluster_discovery.go @@ -121,10 +121,10 @@ func (r *rdsClusterDiscovery) DescribeAttributes() []discovery_kit_api.Attribute } func (r *rdsClusterDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getClusterTargetsForAccount, ctx, "rds-cluster") + return utils.ForEveryConfiguredAwsAccess(getClusterTargetsForAccount, ctx, "rds-cluster") } -func getClusterTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getClusterTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := rds.NewFromConfig(account.AwsConfig) result, err := getAllRdsClusters(ctx, client, account.AccountNumber, account.AwsConfig.Region) if err != nil { diff --git a/extrds/common_cluster.go b/extrds/common_cluster.go index 26e42852..a47ca214 100644 --- a/extrds/common_cluster.go +++ b/extrds/common_cluster.go @@ -18,6 +18,7 @@ const ( type RdsClusterAttackState struct { DBClusterIdentifier string Account string + Region string } type rdsDBClusterApi interface { @@ -27,13 +28,14 @@ type rdsDBClusterApi interface { func convertClusterAttackState(request action_kit_api.PrepareActionRequestBody, state *RdsClusterAttackState) error { state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] state.DBClusterIdentifier = extutil.MustHaveValue(request.Target.Attributes, "aws.rds.cluster.id")[0] return nil } -func defaultClusterClientProvider(account string) (rdsDBClusterApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultClusterClientProvider(account string, region string) (rdsDBClusterApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return rds.NewFromConfig(awsAccount.AwsConfig), nil + return rds.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extrds/common_instance.go b/extrds/common_instance.go index b672c7aa..14a68d66 100644 --- a/extrds/common_instance.go +++ b/extrds/common_instance.go @@ -8,7 +8,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/steadybit/action-kit/go/action_kit_api/v2" "github.com/steadybit/extension-aws/utils" - extension_kit "github.com/steadybit/extension-kit" "github.com/steadybit/extension-kit/extutil" ) @@ -19,6 +18,7 @@ const ( type RdsInstanceAttackState struct { DBInstanceIdentifier string Account string + Region string ForceFailover bool } @@ -30,26 +30,17 @@ type rdsDBInstanceApi interface { } func convertInstanceAttackState(request action_kit_api.PrepareActionRequestBody, state *RdsInstanceAttackState) error { - instanceId := request.Target.Attributes["aws.rds.instance.id"] - if len(instanceId) == 0 { - return extension_kit.ToError("Target is missing the 'aws.rds.instance.id' target attribute.", nil) - } - - account := request.Target.Attributes["aws.account"] - if len(account) == 0 { - return extension_kit.ToError("Target is missing the 'aws.account' target attribute.", nil) - } - - state.Account = account[0] - state.DBInstanceIdentifier = instanceId[0] + state.Account = extutil.MustHaveValue(request.Target.Attributes, "aws.account")[0] + state.Region = extutil.MustHaveValue(request.Target.Attributes, "aws.region")[0] + state.DBInstanceIdentifier = extutil.MustHaveValue(request.Target.Attributes, "aws.rds.instance.id")[0] state.ForceFailover = extutil.ToBool(request.Config["force-failover"]) return nil } -func defaultInstanceClientProvider(account string) (rdsDBInstanceApi, error) { - awsAccount, err := utils.Accounts.GetAccount(account) +func defaultInstanceClientProvider(account string, region string) (rdsDBInstanceApi, error) { + awsAccess, err := utils.GetAwsAccess(account, region) if err != nil { return nil, err } - return rds.NewFromConfig(awsAccount.AwsConfig), nil + return rds.NewFromConfig(awsAccess.AwsConfig), nil } diff --git a/extrds/common_instance_test.go b/extrds/common_instance_test.go index b353c49c..2f8d51ee 100644 --- a/extrds/common_instance_test.go +++ b/extrds/common_instance_test.go @@ -40,7 +40,7 @@ type zoneMock struct { mock.Mock } -func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string) *types.AvailabilityZone { - args := m.Called(awsAccountNumber, awsZone) +func (m *zoneMock) GetZone(awsAccountNumber string, awsZone string, region string) *types.AvailabilityZone { + args := m.Called(awsAccountNumber, awsZone, region) return args.Get(0).(*types.AvailabilityZone) } diff --git a/extrds/instance_attack_reboot.go b/extrds/instance_attack_reboot.go index a08780d0..c76374a8 100644 --- a/extrds/instance_attack_reboot.go +++ b/extrds/instance_attack_reboot.go @@ -15,7 +15,7 @@ import ( ) type rdsInstanceRebootAttack struct { - clientProvider func(account string) (rdsDBInstanceApi, error) + clientProvider func(account string, region string) (rdsDBInstanceApi, error) } var _ action_kit_sdk.Action[RdsInstanceAttackState] = (*rdsInstanceRebootAttack)(nil) @@ -70,7 +70,7 @@ func (f rdsInstanceRebootAttack) Prepare(_ context.Context, state *RdsInstanceAt } func (f rdsInstanceRebootAttack) Start(ctx context.Context, state *RdsInstanceAttackState) (*action_kit_api.StartResult, error) { - client, err := f.clientProvider(state.Account) + client, err := f.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize RDS client for AWS account %s", state.Account), err) } diff --git a/extrds/instance_attack_reboot_test.go b/extrds/instance_attack_reboot_test.go index bf90fb9b..3d1a3391 100644 --- a/extrds/instance_attack_reboot_test.go +++ b/extrds/instance_attack_reboot_test.go @@ -25,6 +25,7 @@ func TestPrepareInstanceReboot(t *testing.T) { Attributes: map[string][]string{ "aws.rds.instance.id": {"my-instance"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }) @@ -38,49 +39,11 @@ func TestPrepareInstanceReboot(t *testing.T) { // Then assert.NoError(t, err) assert.Equal(t, "my-instance", state.DBInstanceIdentifier) + assert.Equal(t, "42", state.Account) + assert.Equal(t, "us-west-1", state.Region) assert.Equal(t, true, state.ForceFailover) } -func TestPrepareInstanceRebootMustRequireAnInstanceId(t *testing.T) { - // Given - requestBody := extutil.JsonMangle(action_kit_api.PrepareActionRequestBody{ - Target: extutil.Ptr(action_kit_api.Target{ - Attributes: map[string][]string{ - "aws.account": {"42"}, - }, - }), - }) - - attack := rdsInstanceRebootAttack{} - state := attack.NewEmptyState() - - // When - _, err := attack.Prepare(context.Background(), &state, requestBody) - - // Then - assert.ErrorContains(t, err, "aws.rds.instance.id") -} - -func TestPrepareInstanceRebootMustRequireAnAccountId(t *testing.T) { - // Given - requestBody := extutil.JsonMangle(action_kit_api.PrepareActionRequestBody{ - Target: extutil.Ptr(action_kit_api.Target{ - Attributes: map[string][]string{ - "aws.rds.instance.id": {"my-instance"}, - }, - }), - }) - - attack := rdsInstanceRebootAttack{} - state := attack.NewEmptyState() - - // When - _, err := attack.Prepare(context.Background(), &state, requestBody) - - // Then - assert.ErrorContains(t, err, "aws.account") -} - func TestStartInstanceReboot(t *testing.T) { // Given api := new(rdsDBInstanceApiMock) @@ -92,9 +55,10 @@ func TestStartInstanceReboot(t *testing.T) { state := RdsInstanceAttackState{ DBInstanceIdentifier: "dev-db", Account: "42", + Region: "us-west-1", ForceFailover: true, } - action := rdsInstanceRebootAttack{clientProvider: func(account string) (rdsDBInstanceApi, error) { + action := rdsInstanceRebootAttack{clientProvider: func(account string, region string) (rdsDBInstanceApi, error) { return api, nil }} @@ -113,7 +77,7 @@ func TestStartInstanceRebootForwardRebootError(t *testing.T) { state := RdsInstanceAttackState{ DBInstanceIdentifier: "dev-db", } - action := rdsInstanceRebootAttack{clientProvider: func(account string) (rdsDBInstanceApi, error) { + action := rdsInstanceRebootAttack{clientProvider: func(account string, region string) (rdsDBInstanceApi, error) { return api, nil }} diff --git a/extrds/instance_attack_stop.go b/extrds/instance_attack_stop.go index 25a5f929..c07e9bd4 100644 --- a/extrds/instance_attack_stop.go +++ b/extrds/instance_attack_stop.go @@ -15,7 +15,7 @@ import ( ) type rdsInstanceStopAttack struct { - clientProvider func(account string) (rdsDBInstanceApi, error) + clientProvider func(account string, region string) (rdsDBInstanceApi, error) } var ( @@ -63,7 +63,7 @@ func (f rdsInstanceStopAttack) Prepare(_ context.Context, state *RdsInstanceAtta } func (f rdsInstanceStopAttack) Start(ctx context.Context, state *RdsInstanceAttackState) (*action_kit_api.StartResult, error) { - client, err := f.clientProvider(state.Account) + client, err := f.clientProvider(state.Account, state.Region) if err != nil { return nil, extension_kit.ToError(fmt.Sprintf("Failed to initialize RDS client for AWS account %s", state.Account), err) } diff --git a/extrds/instance_attack_stop_test.go b/extrds/instance_attack_stop_test.go index 5f3807ac..489e9885 100644 --- a/extrds/instance_attack_stop_test.go +++ b/extrds/instance_attack_stop_test.go @@ -22,6 +22,7 @@ func TestPrepareInstanceStop(t *testing.T) { Attributes: map[string][]string{ "aws.rds.instance.id": {"my-instance"}, "aws.account": {"42"}, + "aws.region": {"us-west-1"}, }, }), }) @@ -35,46 +36,8 @@ func TestPrepareInstanceStop(t *testing.T) { // Then assert.NoError(t, err) assert.Equal(t, "my-instance", state.DBInstanceIdentifier) -} - -func TestPrepareInstanceStopMustRequireAnInstanceId(t *testing.T) { - // Given - requestBody := extutil.JsonMangle(action_kit_api.PrepareActionRequestBody{ - Target: extutil.Ptr(action_kit_api.Target{ - Attributes: map[string][]string{ - "aws.account": {"42"}, - }, - }), - }) - - attack := rdsInstanceStopAttack{} - state := attack.NewEmptyState() - - // When - _, err := attack.Prepare(context.Background(), &state, requestBody) - - // Then - assert.ErrorContains(t, err, "aws.rds.instance.id") -} - -func TestPrepareInstanceStopMustRequireAnAccountId(t *testing.T) { - // Given - requestBody := extutil.JsonMangle(action_kit_api.PrepareActionRequestBody{ - Target: extutil.Ptr(action_kit_api.Target{ - Attributes: map[string][]string{ - "aws.rds.instance.id": {"my-instance"}, - }, - }), - }) - - attack := rdsInstanceStopAttack{} - state := attack.NewEmptyState() - - // When - _, err := attack.Prepare(context.Background(), &state, requestBody) - - // Then - assert.ErrorContains(t, err, "aws.account") + assert.Equal(t, "42", state.Account) + assert.Equal(t, "us-west-1", state.Region) } func TestStartInstanceStop(t *testing.T) { @@ -87,8 +50,9 @@ func TestStartInstanceStop(t *testing.T) { state := RdsInstanceAttackState{ DBInstanceIdentifier: "dev-db", Account: "42", + Region: "us-west-1", } - action := rdsInstanceStopAttack{clientProvider: func(account string) (rdsDBInstanceApi, error) { + action := rdsInstanceStopAttack{clientProvider: func(account string, region string) (rdsDBInstanceApi, error) { return api, nil }} @@ -107,7 +71,7 @@ func TestStartInstanceStopForwardStopError(t *testing.T) { state := RdsInstanceAttackState{ DBInstanceIdentifier: "dev-db", } - action := rdsInstanceStopAttack{clientProvider: func(account string) (rdsDBInstanceApi, error) { + action := rdsInstanceStopAttack{clientProvider: func(account string, region string) (rdsDBInstanceApi, error) { return api, nil }} diff --git a/extrds/instance_discovery.go b/extrds/instance_discovery.go index ca84744a..de665551 100644 --- a/extrds/instance_discovery.go +++ b/extrds/instance_discovery.go @@ -94,10 +94,10 @@ func (r *rdsInstanceDiscovery) DescribeAttributes() []discovery_kit_api.Attribut } func (r *rdsInstanceDiscovery) DiscoverTargets(ctx context.Context) ([]discovery_kit_api.Target, error) { - return utils.ForEveryAccount(utils.Accounts, getInstanceTargetsForAccount, ctx, "rds-instance") + return utils.ForEveryConfiguredAwsAccess(getInstanceTargetsForAccount, ctx, "rds-instance") } -func getInstanceTargetsForAccount(account *utils.AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { +func getInstanceTargetsForAccount(account *utils.AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { client := rds.NewFromConfig(account.AwsConfig) result, err := getAllRdsInstances(ctx, client, utils.Zones, account.AccountNumber, account.AwsConfig.Region) if err != nil { @@ -133,7 +133,7 @@ func toInstanceTarget(dbInstance types.DBInstance, zoneUtil utils.GetZoneUtil, a arn := aws.ToString(dbInstance.DBInstanceArn) label := aws.ToString(dbInstance.DBInstanceIdentifier) availabilityZoneName := aws.ToString(dbInstance.AvailabilityZone) - availabilityZoneApi := zoneUtil.GetZone(awsAccountNumber, availabilityZoneName) + availabilityZoneApi := zoneUtil.GetZone(awsAccountNumber, availabilityZoneName, awsRegion) attributes := make(map[string][]string) attributes["aws.account"] = []string{awsAccountNumber} diff --git a/extrds/instance_discovery_test.go b/extrds/instance_discovery_test.go index 5661cf86..7f4bc9d6 100644 --- a/extrds/instance_discovery_test.go +++ b/extrds/instance_discovery_test.go @@ -43,7 +43,7 @@ func TestGetAllRdsInstances(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1a-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) // When targets, err := getAllRdsInstances(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") @@ -91,7 +91,7 @@ func TestGetAllRdsInstancesWithoutCluster(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1a-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) // When targets, err := getAllRdsInstances(context.Background(), mockedApi, mockedZoneUtil, "42", "us-east-1") @@ -118,7 +118,7 @@ func TestGetAllRdsInstancesWithPagination(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1a-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) withMarker := mock.MatchedBy(func(arg *rds.DescribeDBInstancesInput) bool { return arg.Marker != nil @@ -171,7 +171,7 @@ func TestGetAllRdsInstancesError(t *testing.T) { RegionName: discovery_kit_api.Ptr("us-east-1"), ZoneId: discovery_kit_api.Ptr("us-east-1a-id"), } - mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything).Return(&mockedZone) + mockedZoneUtil.On("GetZone", mock.Anything, mock.Anything, mock.Anything).Return(&mockedZone) mockedApi.On("DescribeDBInstances", mock.Anything, mock.Anything).Return(nil, errors.New("expected")) diff --git a/main.go b/main.go index b5cf9c00..04650a38 100644 --- a/main.go +++ b/main.go @@ -43,7 +43,7 @@ func main() { exthealth.SetReady(false) config.ParseConfiguration() - utils.InitializeAwsAccountAccess(config.Config) + utils.InitializeAwsAccess(config.Config) utils.InitializeAwsZones() ctx, cancel := SignalCanceledContext() diff --git a/main_test.go b/main_test.go index 72e38a5b..7c205299 100644 --- a/main_test.go +++ b/main_test.go @@ -2,11 +2,9 @@ package main import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/steadybit/action-kit/go/action_kit_sdk" "github.com/steadybit/discovery-kit/go/discovery_kit_sdk" "github.com/steadybit/extension-aws/config" - "github.com/steadybit/extension-aws/utils" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" "net/http" @@ -155,13 +153,6 @@ func Test_getExtensionList(t *testing.T) { }, }, } - utils.Accounts = &utils.AwsAccounts{ - RootAccount: utils.AwsAccount{ - AccountNumber: "123456789012", - AwsConfig: aws.Config{}, - }, - Accounts: make(map[string]utils.AwsAccount), - } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/utils/aws_access.go b/utils/aws_access.go new file mode 100644 index 00000000..88cae4b8 --- /dev/null +++ b/utils/aws_access.go @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 Steadybit GmbH + +package utils + +import ( + "context" + "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go/middleware" + "github.com/rs/zerolog/log" + "github.com/steadybit/discovery-kit/go/discovery_kit_api" + extConfig "github.com/steadybit/extension-aws/config" +) + +type AwsAccess struct { + AccountNumber string + Region string + AwsConfig aws.Config +} + +type Regions map[string]AwsAccess + +var ( + rootAccountNumber string + accounts map[string]Regions +) + +func InitializeAwsAccess(specification extConfig.Specification) { + ctx := context.Background() + awsConfigForRootAccount, err := config.LoadDefaultConfig(ctx) + if err != nil { + log.Fatal().Err(err).Msgf("Failed to load AWS configuration") + } + + log.Info().Msgf("Starting in region %s", awsConfigForRootAccount.Region) + awsConfigForRootAccount.Logger = logForwarder{} + awsConfigForRootAccount.ClientLogMode = aws.LogRequest + awsConfigForRootAccount.APIOptions = append(awsConfigForRootAccount.APIOptions, func(stack *middleware.Stack) error { + return stack.Initialize.Add(customLoggerMiddleware, middleware.After) + }) + + if specification.AwsEndpointOverride != "" { + log.Warn().Msgf("Overriding AWS base endpoint with '%s'", specification.AwsEndpointOverride) + awsConfigForRootAccount.BaseEndpoint = &specification.AwsEndpointOverride + } + + stsClientForRootAccount := sts.NewFromConfig(awsConfigForRootAccount) + identityOutputRoot, err := stsClientForRootAccount.GetCallerIdentity(ctx, nil) + if err != nil { + log.Fatal().Err(err).Msgf("Failed to identify AWS account number") + } + + rootAccountNumber = aws.ToString(identityOutputRoot.Account) + accounts = make(map[string]Regions) + + regions := []string{awsConfigForRootAccount.Region} + if len(specification.Regions) > 0 { + regions = specification.Regions + } + + if len(specification.AssumeRoles) > 0 { + log.Debug().Msgf("Executing role assumption in other AWS Accounts.") + for _, roleArn := range specification.AssumeRoles { + awsConfig := awsConfigForRootAccount.Copy() + awsConfig.Credentials = aws.NewCredentialsCache(stscreds.NewAssumeRoleProvider(stsClientForRootAccount, roleArn, func(o *stscreds.AssumeRoleOptions) { + o.RoleSessionName = "steadybit-extension-aws" + })) + + stsClient := sts.NewFromConfig(awsConfig) + identityOutput, err := stsClient.GetCallerIdentity(context.Background(), nil) + if err != nil { + log.Error().Err(err).Msgf("Failed to identify AWS account number for account assumed via role '%s'. The roleArn will be ignored until the next restart of the extension.", roleArn) + continue + } + assumedAccount := aws.ToString(identityOutput.Account) + log.Info().Msgf("Successfully assumed role '%s' in account '%s'", roleArn, assumedAccount) + prepareRegionConfigs(regions, awsConfig, assumedAccount) + } + } else { + prepareRegionConfigs(regions, awsConfigForRootAccount, aws.ToString(identityOutputRoot.Account)) + } +} + +func prepareRegionConfigs(regions []string, awsConfig aws.Config, account string) { + if _, ok := accounts[account]; !ok { + accounts[account] = make(map[string]AwsAccess) + } + for _, region := range regions { + regionalConfig := awsConfig.Copy() + regionalConfig.Region = region + accounts[account][region] = AwsAccess{ + AccountNumber: account, + AwsConfig: regionalConfig, + Region: region, + } + } +} + +func GetRootAccountNumber() string { + return rootAccountNumber +} + +func GetAwsAccess(accountNumber string, region string) (*AwsAccess, error) { + account, ok := accounts[accountNumber] + if ok { + if regionAccount, ok := account[region]; ok { + return ®ionAccount, nil + } + } + return nil, fmt.Errorf("AWS Config for account '%s' and region '%s' not found", accountNumber, region) +} + +func ForEveryConfiguredAwsAccess(supplier func(account *AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error), ctx context.Context, discovery string) ([]discovery_kit_api.Target, error) { + count := 0 + for _, regions := range accounts { + for range regions { + count++ + } + } + if count > 0 { + accountsChannel := make(chan AwsAccess, count) + resultsChannel := make(chan []discovery_kit_api.Target, count) + for w := 1; w <= extConfig.Config.WorkerThreads; w++ { + go func(w int, accounts <-chan AwsAccess, result <-chan []discovery_kit_api.Target) { + for account := range accounts { + log.Trace().Int("worker", w).Msgf("Collecting %s for account %s in region %s", discovery, account.AccountNumber, account.Region) + eachResult, eachErr := supplier(&account, ctx) + if eachErr != nil { + log.Err(eachErr).Msgf("Failed to collect %s for account %s in region %s", discovery, account.AccountNumber, account.Region) + } + resultsChannel <- eachResult + } + }(w, accountsChannel, resultsChannel) + } + for _, regions := range accounts { + for _, account := range regions { + accountsChannel <- account + } + } + close(accountsChannel) + resultTargets := make([]discovery_kit_api.Target, 0) + for a := 1; a <= count; a++ { + targets := <-resultsChannel + if targets != nil { + resultTargets = append(resultTargets, targets...) + } + } + return resultTargets, nil + } + return []discovery_kit_api.Target{}, nil +} diff --git a/utils/aws_access_test.go b/utils/aws_access_test.go new file mode 100644 index 00000000..85cf833c --- /dev/null +++ b/utils/aws_access_test.go @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2023 Steadybit GmbH + +package utils + +import ( + "context" + "errors" + "github.com/steadybit/discovery-kit/go/discovery_kit_api" + "github.com/steadybit/extension-aws/config" + "github.com/steadybit/extension-kit/extutil" + "github.com/stretchr/testify/require" + "sort" + "testing" + "time" +) + +func TestGetAccountSupportsRootAccount(t *testing.T) { + accounts = getTestAccountsWithoutRoleAssumption() + + account, err := GetAwsAccess("root", "us-east-1") + + require.NoError(t, err) + require.Equal(t, "root", account.AccountNumber) + require.Equal(t, "us-east-1", account.Region) +} + +func TestGetAccountSupportsAssumedAccount(t *testing.T) { + accounts = getTestAccountsWithRoleAssumption() + + account, err := GetAwsAccess("assumed2", "eu-central-1") + + require.NoError(t, err) + require.Equal(t, "assumed2", account.AccountNumber) + require.Equal(t, "eu-central-1", account.Region) +} + +func TestGetAccountReportsErrorWhenMissing(t *testing.T) { + accounts = getTestAccountsWithRoleAssumption() + + account, err := GetAwsAccess("unknown-account", "eu-central-1") + + require.ErrorContains(t, err, "unknown-account") + require.Nil(t, account) +} + +func TestForEachAccountWithoutRoleAssumption(t *testing.T) { + config.Config.WorkerThreads = 1 + accounts = getTestAccountsWithoutRoleAssumption() + + result, err := ForEveryConfiguredAwsAccess(getTestFunction(nil, nil), context.Background(), "discovery") + + require.NoError(t, err) + var values []string + for _, target := range result { + values = append(values, target.Attributes["aws.account"][0]+"@"+target.Attributes["aws.region"][0]) + } + require.Equal(t, []string{"root@us-east-1"}, values) +} + +func TestForEachAccountWithRoleAssumptionAndSingleWorker(t *testing.T) { + config.Config.WorkerThreads = 1 + accounts = getTestAccountsWithRoleAssumption() + + result, err := ForEveryConfiguredAwsAccess(getTestFunction(nil, nil), context.Background(), "discovery") + + require.NoError(t, err) + // for stable test execution + var values []string + for _, target := range result { + values = append(values, target.Attributes["aws.account"][0]+"@"+target.Attributes["aws.region"][0]) + } + sort.Strings(values) + require.Equal(t, []string{"assumed1@eu-central-1", "assumed1@us-east-1", "assumed2@eu-central-1", "assumed2@us-east-1", "assumed3@eu-central-1", "assumed3@us-east-1", "assumed4@eu-central-1", "assumed4@us-east-1"}, values) +} + +func TestForEachAccountWithRoleAssumptionAndMultipleWorkers(t *testing.T) { + config.Config.WorkerThreads = 4 + accounts = getTestAccountsWithRoleAssumption() + + result, err := ForEveryConfiguredAwsAccess(getTestFunction(nil, nil), context.Background(), "discovery") + + require.NoError(t, err) + // for stable test execution + var values []string + for _, target := range result { + values = append(values, target.Attributes["aws.account"][0]+"@"+target.Attributes["aws.region"][0]) + } + sort.Strings(values) + require.Equal(t, []string{"assumed1@eu-central-1", "assumed1@us-east-1", "assumed2@eu-central-1", "assumed2@us-east-1", "assumed3@eu-central-1", "assumed3@us-east-1", "assumed4@eu-central-1", "assumed4@us-east-1"}, values) +} + +func TestForEachAccountWithRoleAssumptionAndError(t *testing.T) { + config.Config.WorkerThreads = 4 + accounts = getTestAccountsWithRoleAssumption() + + result, err := ForEveryConfiguredAwsAccess(getTestFunction(extutil.Ptr("assumed2"), nil), context.Background(), "discovery") + + require.NoError(t, err) + // for stable test execution + var values []string + for _, target := range result { + values = append(values, target.Attributes["aws.account"][0]+"@"+target.Attributes["aws.region"][0]) + } + sort.Strings(values) + require.Equal(t, []string{"assumed1@eu-central-1", "assumed1@us-east-1", "assumed3@eu-central-1", "assumed3@us-east-1", "assumed4@eu-central-1", "assumed4@us-east-1"}, values) +} + +func TestForEachAccountWithRoleAssumptionAndEmptyLists(t *testing.T) { + config.Config.WorkerThreads = 4 + accounts = getTestAccountsWithRoleAssumption() + + result, err := ForEveryConfiguredAwsAccess(getTestFunction(nil, extutil.Ptr("assumed2")), context.Background(), "discovery") + + require.NoError(t, err) + // for stable test execution + var values []string + for _, target := range result { + values = append(values, target.Attributes["aws.account"][0]+"@"+target.Attributes["aws.region"][0]) + } + sort.Strings(values) + require.Equal(t, []string{"assumed1@eu-central-1", "assumed1@us-east-1", "assumed3@eu-central-1", "assumed3@us-east-1", "assumed4@eu-central-1", "assumed4@us-east-1"}, values) +} + +func getTestFunction(errorForAccount *string, emptyForAccount *string) func(account *AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { + return func(account *AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { + if (errorForAccount != nil) && (*errorForAccount == account.AccountNumber) { + return nil, errors.New("damn broken discovery") + } + if (emptyForAccount != nil) && (*emptyForAccount == account.AccountNumber) { + return []discovery_kit_api.Target{}, nil + } + var targets []discovery_kit_api.Target + targets = append(targets, discovery_kit_api.Target{ + TargetType: "example", + Label: "label", + Attributes: map[string][]string{ + "aws.account": {account.AccountNumber}, + "aws.region": {account.Region}, + }, + }) + time.Sleep(100 * time.Millisecond) + return targets, nil + } +} + +func getTestAccountsWithRoleAssumption() map[string]Regions { + return map[string]Regions{ + "assumed1": { + "us-east-1": { + AccountNumber: "assumed1", + Region: "us-east-1", + }, + "eu-central-1": { + AccountNumber: "assumed1", + Region: "eu-central-1", + }, + }, + "assumed2": { + "us-east-1": { + AccountNumber: "assumed2", + Region: "us-east-1", + }, + "eu-central-1": { + AccountNumber: "assumed2", + Region: "eu-central-1", + }, + }, + "assumed3": { + "us-east-1": { + AccountNumber: "assumed3", + Region: "us-east-1", + }, + "eu-central-1": { + AccountNumber: "assumed3", + Region: "eu-central-1", + }, + }, + "assumed4": { + "us-east-1": { + AccountNumber: "assumed4", + Region: "us-east-1", + }, + "eu-central-1": { + AccountNumber: "assumed4", + Region: "eu-central-1", + }, + }, + } +} + +func getTestAccountsWithoutRoleAssumption() map[string]Regions { + return map[string]Regions{ + "root": { + "us-east-1": { + AccountNumber: "root", + Region: "us-east-1", + }, + }, + } +} diff --git a/utils/aws_accounts.go b/utils/aws_accounts.go deleted file mode 100644 index d545527a..00000000 --- a/utils/aws_accounts.go +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2023 Steadybit GmbH - -package utils - -import ( - "context" - "fmt" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/rs/zerolog/log" - "github.com/steadybit/discovery-kit/go/discovery_kit_api" - "github.com/steadybit/extension-aws/config" -) - -type AwsAccount struct { - AccountNumber string - AwsConfig aws.Config -} - -type AwsAccounts struct { - RootAccount AwsAccount - - // accounts is a map of AWS account numbers to AwsAccount for which roles are to be assumed. - Accounts map[string]AwsAccount -} - -type GetAccountApi interface { - GetAccount(accountNumber string) (*AwsAccount, error) -} - -func (accounts *AwsAccounts) GetRootAccount() *AwsAccount { - return &accounts.RootAccount -} - -func (accounts *AwsAccounts) GetAccount(accountNumber string) (*AwsAccount, error) { - account, ok := accounts.Accounts[accountNumber] - if ok { - return &account, nil - } - - if accountNumber == accounts.RootAccount.AccountNumber { - return &accounts.RootAccount, nil - } - - return nil, fmt.Errorf("AWS account '%s' not found", accountNumber) -} - -func ForEveryAccount( - accounts *AwsAccounts, - supplier func(account *AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error), - ctx context.Context, - discovery string, -) ([]discovery_kit_api.Target, error) { - numAccounts := len(accounts.Accounts) - if numAccounts > 0 { - accountsChannel := make(chan AwsAccount, numAccounts) - resultsChannel := make(chan []discovery_kit_api.Target, numAccounts) - for w := 1; w <= config.Config.WorkerThreads; w++ { - go func(w int, accounts <-chan AwsAccount, result <-chan []discovery_kit_api.Target) { - for account := range accounts { - log.Trace().Int("worker", w).Msgf("Collecting %s for account %s", discovery, account.AccountNumber) - eachResult, eachErr := supplier(&account, ctx) - if eachErr != nil { - log.Err(eachErr).Msgf("Failed to collect %s for account %s", discovery, account.AccountNumber) - } - resultsChannel <- eachResult - } - }(w, accountsChannel, resultsChannel) - } - for _, account := range accounts.Accounts { - accountsChannel <- account - } - close(accountsChannel) - var resultTargets []discovery_kit_api.Target - for a := 1; a <= numAccounts; a++ { - targets := <-resultsChannel - if targets != nil { - resultTargets = append(resultTargets, targets...) - } - } - return resultTargets, nil - } else { - return supplier(&accounts.RootAccount, ctx) - } -} diff --git a/utils/aws_accounts_test.go b/utils/aws_accounts_test.go deleted file mode 100644 index 488c9194..00000000 --- a/utils/aws_accounts_test.go +++ /dev/null @@ -1,210 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2023 Steadybit GmbH - -package utils - -import ( - "context" - "errors" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/steadybit/discovery-kit/go/discovery_kit_api" - "github.com/steadybit/extension-aws/config" - "github.com/steadybit/extension-kit/extutil" - "github.com/stretchr/testify/require" - "sort" - "testing" - "time" -) - -func TestGetAccountSupportsRootAccount(t *testing.T) { - accounts := getTestAccountsWithRoleAssumption() - - account, err := accounts.GetAccount("root") - - require.NoError(t, err) - require.Equal(t, "root", account.AccountNumber) -} - -func TestGetAccountSupportsAssumedAccount(t *testing.T) { - accounts := getTestAccountsWithRoleAssumption() - - account, err := accounts.GetAccount("assumed2") - - require.NoError(t, err) - require.Equal(t, "assumed2", account.AccountNumber) -} - -func TestMustPreferAssumedAccount(t *testing.T) { - accounts := AwsAccounts{ - RootAccount: AwsAccount{ - AccountNumber: "root", - }, - Accounts: map[string]AwsAccount{ - "assumed1": { - AccountNumber: "assumed1", - }, - "root": { - AccountNumber: "root", - AwsConfig: aws.Config{}, - }, - }, - } - - account, err := accounts.GetAccount("root") - - require.NoError(t, err) - require.NotNil(t, account.AwsConfig) -} - -func TestGetAccountReportsErrorWhenMissing(t *testing.T) { - accounts := getTestAccountsWithRoleAssumption() - - account, err := accounts.GetAccount("unknown-account") - - require.ErrorContains(t, err, "unknown-account") - require.Nil(t, account) -} - -func TestForEachAccountWithoutRoleAssumption(t *testing.T) { - accounts := getTestAccountsWithoutRoleAssumption() - - result, err := ForEveryAccount(&accounts, getTestFunction(nil, nil), context.Background(), "discovery") - - require.NoError(t, err) - var values []string - for _, target := range result { - values = append(values, target.Attributes["aws.account"][0]) - } - require.Equal(t, []string{"root"}, values) -} - -func TestForEachAccountWithRoleAssumptionAndSingleWorker(t *testing.T) { - config.Config.WorkerThreads = 1 - accounts := getTestAccountsWithRoleAssumption() - - result, err := ForEveryAccount(&accounts, getTestFunction(nil, nil), context.Background(), "discovery") - - require.NoError(t, err) - // for stable test execution - var values []string - for _, target := range result { - values = append(values, target.Attributes["aws.account"][0]) - } - sort.Strings(values) - require.Equal(t, []string{"assumed1", "assumed2", "assumed3", "assumed4", "assumed5", "assumed6", "assumed7", "assumed8", "assumed9"}, values) -} - -func TestForEachAccountWithRoleAssumptionAndMultipleWorkers(t *testing.T) { - config.Config.WorkerThreads = 4 - accounts := getTestAccountsWithRoleAssumption() - - result, err := ForEveryAccount(&accounts, getTestFunction(nil, nil), context.Background(), "discovery") - - require.NoError(t, err) - // for stable test execution - var values []string - for _, target := range result { - values = append(values, target.Attributes["aws.account"][0]) - } - sort.Strings(values) - require.Equal(t, []string{"assumed1", "assumed2", "assumed3", "assumed4", "assumed5", "assumed6", "assumed7", "assumed8", "assumed9"}, values) -} - -func TestForEachAccountWithRoleAssumptionAndError(t *testing.T) { - config.Config.WorkerThreads = 4 - accounts := getTestAccountsWithRoleAssumption() - - result, err := ForEveryAccount(&accounts, getTestFunction(extutil.Ptr("assumed2"), nil), context.Background(), "discovery") - - require.NoError(t, err) - // for stable test execution - var values []string - for _, target := range result { - values = append(values, target.Attributes["aws.account"][0]) - } - sort.Strings(values) - require.Equal(t, []string{"assumed1", "assumed3", "assumed4", "assumed5", "assumed6", "assumed7", "assumed8", "assumed9"}, values) -} - -func TestForEachAccountWithRoleAssumptionAndEmptyLists(t *testing.T) { - config.Config.WorkerThreads = 4 - accounts := getTestAccountsWithRoleAssumption() - - result, err := ForEveryAccount(&accounts, getTestFunction(nil, extutil.Ptr("assumed2")), context.Background(), "discovery") - - require.NoError(t, err) - // for stable test execution - var values []string - for _, target := range result { - values = append(values, target.Attributes["aws.account"][0]) - } - sort.Strings(values) - require.Equal(t, []string{"assumed1", "assumed3", "assumed4", "assumed5", "assumed6", "assumed7", "assumed8", "assumed9"}, values) -} - -func getTestFunction(errorForAccount *string, emptyForAccount *string) func(account *AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { - return func(account *AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { - if (errorForAccount != nil) && (*errorForAccount == account.AccountNumber) { - return nil, errors.New("damn broken discovery") - } - if (emptyForAccount != nil) && (*emptyForAccount == account.AccountNumber) { - return []discovery_kit_api.Target{}, nil - } - var targets []discovery_kit_api.Target - targets = append(targets, discovery_kit_api.Target{ - TargetType: "example", - Label: "label", - Attributes: map[string][]string{ - "aws.account": {account.AccountNumber}, - }, - }) - time.Sleep(100 * time.Millisecond) - return targets, nil - } -} - -func getTestAccountsWithRoleAssumption() AwsAccounts { - return AwsAccounts{ - RootAccount: AwsAccount{ - AccountNumber: "root", - }, - Accounts: map[string]AwsAccount{ - "assumed1": { - AccountNumber: "assumed1", - }, - "assumed2": { - AccountNumber: "assumed2", - }, - "assumed3": { - AccountNumber: "assumed3", - }, - "assumed4": { - AccountNumber: "assumed4", - }, - "assumed5": { - AccountNumber: "assumed5", - }, - "assumed6": { - AccountNumber: "assumed6", - }, - "assumed7": { - AccountNumber: "assumed7", - }, - "assumed8": { - AccountNumber: "assumed8", - }, - "assumed9": { - AccountNumber: "assumed9", - }, - }, - } -} - -func getTestAccountsWithoutRoleAssumption() AwsAccounts { - return AwsAccounts{ - RootAccount: AwsAccount{ - AccountNumber: "root", - }, - Accounts: map[string]AwsAccount{}, - } -} diff --git a/utils/aws_zones.go b/utils/aws_zones.go index c03b8e10..1e826df1 100644 --- a/utils/aws_zones.go +++ b/utils/aws_zones.go @@ -24,16 +24,16 @@ func InitializeAwsZones() { Zones = &AwsZones{ zones: sync.Map{}, } - _, _ = ForEveryAccount(Accounts, initAwsZonesForAccount, context.Background(), "availability zone") + _, _ = ForEveryConfiguredAwsAccess(initAwsZonesForAccount, context.Background(), "availability zone") } -func initAwsZonesForAccount(account *AwsAccount, ctx context.Context) ([]discovery_kit_api.Target, error) { - return initAwsZonesForAccountWithClient(ec2.NewFromConfig(account.AwsConfig), account.AccountNumber, ctx) +func initAwsZonesForAccount(account *AwsAccess, ctx context.Context) ([]discovery_kit_api.Target, error) { + return initAwsZonesForAccountWithClient(ec2.NewFromConfig(account.AwsConfig), account.AccountNumber, account.Region, ctx) } -func initAwsZonesForAccountWithClient(client AZDescribeAvailabilityZonesApi, awsAccountNumber string, ctx context.Context) ([]discovery_kit_api.Target, error) { - result := getAllAvailabilityZones(ctx, client, awsAccountNumber) - Zones.zones.Store(awsAccountNumber, result) +func initAwsZonesForAccountWithClient(client AZDescribeAvailabilityZonesApi, awsAccountNumber string, region string, ctx context.Context) ([]discovery_kit_api.Target, error) { + result := getAllAvailabilityZones(ctx, client, awsAccountNumber, region) + Zones.zones.Store(awsAccountNumber+"-"+region, result) return nil, nil } @@ -41,39 +41,39 @@ type AZDescribeAvailabilityZonesApi interface { DescribeAvailabilityZones(ctx context.Context, params *ec2.DescribeAvailabilityZonesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeAvailabilityZonesOutput, error) } -func getAllAvailabilityZones(ctx context.Context, ec2Api AZDescribeAvailabilityZonesApi, awsAccountNumber string) []types.AvailabilityZone { +func getAllAvailabilityZones(ctx context.Context, ec2Api AZDescribeAvailabilityZonesApi, awsAccountNumber string, region string) []types.AvailabilityZone { output, err := ec2Api.DescribeAvailabilityZones(ctx, &ec2.DescribeAvailabilityZonesInput{ AllAvailabilityZones: aws.Bool(false), }) if err != nil { var re *awshttp.ResponseError if errors.As(err, &re) && re.HTTPStatusCode() == 403 { - log.Error().Msgf("Not Authorized to discover availability zones for account %s. If this is intended, you can disable the discovery by setting STEADYBIT_EXTENSION_DISCOVERY_DISABLED_ZONE=true. Details: %s", awsAccountNumber, re.Error()) + log.Error().Msgf("Not Authorized to discover availability zones for account %s and region %s. If this is intended, you can disable the discovery by setting STEADYBIT_EXTENSION_DISCOVERY_DISABLED_ZONE=true. Details: %s", awsAccountNumber, region, re.Error()) return []types.AvailabilityZone{} } - log.Error().Err(err).Msgf("Failed to load availability zones for account %s.", awsAccountNumber) + log.Error().Err(err).Msgf("Failed to load availability zones for account %s and region %s.", awsAccountNumber, region) return []types.AvailabilityZone{} } return output.AvailabilityZones } type GetZoneUtil interface { - GetZone(awsAccountNumber string, awsZone string) *types.AvailabilityZone + GetZone(awsAccountNumber string, region string, awsZone string) *types.AvailabilityZone } type GetZonesUtil interface { - GetZones(awsAccountNumber string) []types.AvailabilityZone + GetZones(awsAccountNumber string, region string) []types.AvailabilityZone } -func (zones *AwsZones) GetZones(awsAccountNumber string) []types.AvailabilityZone { - value, ok := zones.zones.Load(awsAccountNumber) +func (zones *AwsZones) GetZones(awsAccountNumber string, region string) []types.AvailabilityZone { + value, ok := zones.zones.Load(awsAccountNumber + "-" + region) if !ok { return []types.AvailabilityZone{} } return value.([]types.AvailabilityZone) } -func (zones *AwsZones) GetZone(awsAccountNumber string, awsZone string) *types.AvailabilityZone { - value, ok := zones.zones.Load(awsAccountNumber) +func (zones *AwsZones) GetZone(awsAccountNumber string, region string, awsZone string) *types.AvailabilityZone { + value, ok := zones.zones.Load(awsAccountNumber + "-" + region) if !ok { return nil } diff --git a/utils/aws_zones_test.go b/utils/aws_zones_test.go index dedebb42..7b82d513 100644 --- a/utils/aws_zones_test.go +++ b/utils/aws_zones_test.go @@ -52,19 +52,19 @@ func TestAwsZones(t *testing.T) { mockedApi4711.On("DescribeAvailabilityZones", mock.Anything, mock.Anything, mock.Anything).Return(&mockedReturnValue4711, nil) // When - result, err := initAwsZonesForAccountWithClient(mockedApi42, "42", context.Background()) + result, err := initAwsZonesForAccountWithClient(mockedApi42, "42", "eu-central-1", context.Background()) assert.Nil(t, result) assert.Nil(t, err) - result, err = initAwsZonesForAccountWithClient(mockedApi4711, "4711", context.Background()) + result, err = initAwsZonesForAccountWithClient(mockedApi4711, "4711", "eu-central-1", context.Background()) assert.Nil(t, result) assert.Nil(t, err) // Then - assert.Equal(t, &mockedReturnValue42.AvailabilityZones[0], Zones.GetZone("42", "eu-central-1a")) - assert.Nil(t, Zones.GetZone("42", "eu-central-1c")) - assert.Nil(t, Zones.GetZone("4711", "eu-central-1a")) + assert.Equal(t, &mockedReturnValue42.AvailabilityZones[0], Zones.GetZone("42", "eu-central-1", "eu-central-1a")) + assert.Nil(t, Zones.GetZone("42", "eu-central-1", "eu-central-1c")) + assert.Nil(t, Zones.GetZone("4711", "eu-central-1", "eu-central-1a")) - assert.Equal(t, mockedReturnValue42.AvailabilityZones, Zones.GetZones("42")) - assert.Equal(t, []types.AvailabilityZone{}, Zones.GetZones("4711")) - assert.Equal(t, []types.AvailabilityZone{}, Zones.GetZones("0815")) + assert.Equal(t, mockedReturnValue42.AvailabilityZones, Zones.GetZones("42", "eu-central-1")) + assert.Equal(t, []types.AvailabilityZone{}, Zones.GetZones("4711", "eu-central-1")) + assert.Equal(t, []types.AvailabilityZone{}, Zones.GetZones("0815", "eu-central-1")) } diff --git a/utils/init.go b/utils/init.go deleted file mode 100644 index e564636d..00000000 --- a/utils/init.go +++ /dev/null @@ -1,116 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: 2023 Steadybit GmbH - -package utils - -import ( - "context" - "github.com/aws/aws-sdk-go-v2/aws" - middleware2 "github.com/aws/aws-sdk-go-v2/aws/middleware" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/smithy-go/logging" - "github.com/aws/smithy-go/middleware" - "github.com/rs/zerolog/log" - extConfig "github.com/steadybit/extension-aws/config" - "strings" -) - -var ( - Accounts *AwsAccounts -) - -func InitializeAwsAccountAccess(specification extConfig.Specification) { - ctx := context.Background() - awsConfigForRootAccount, err := config.LoadDefaultConfig(ctx) - if err != nil { - log.Fatal().Err(err).Msgf("Failed to load AWS configuration") - } - - log.Info().Msgf("Starting in region %s", awsConfigForRootAccount.Region) - awsConfigForRootAccount.Logger = logForwarder{} - awsConfigForRootAccount.ClientLogMode = aws.LogRequest - awsConfigForRootAccount.APIOptions = append(awsConfigForRootAccount.APIOptions, func(stack *middleware.Stack) error { - return stack.Initialize.Add(customLoggerMiddleware, middleware.After) - }) - - if specification.AwsEndpointOverride != "" { - log.Warn().Msgf("Overriding AWS base endpoint with '%s'", specification.AwsEndpointOverride) - awsConfigForRootAccount.BaseEndpoint = &specification.AwsEndpointOverride - } - - stsClientForRootAccount := sts.NewFromConfig(awsConfigForRootAccount) - identityOutput, err := stsClientForRootAccount.GetCallerIdentity(ctx, nil) - if err != nil { - log.Fatal().Err(err).Msgf("Failed to identify AWS account number") - } - - Accounts = &AwsAccounts{ - RootAccount: AwsAccount{ - AccountNumber: aws.ToString(identityOutput.Account), - AwsConfig: awsConfigForRootAccount, - }, - Accounts: make(map[string]AwsAccount), - } - - if len(specification.AssumeRoles) > 0 { - log.Debug().Msgf("Executing role assumption in other AWS Accounts.") - - for _, roleArn := range specification.AssumeRoles { - assumedAccount := initializeRoleAssumption(stsClientForRootAccount, roleArn, Accounts.RootAccount) - if assumedAccount != nil { - Accounts.Accounts[assumedAccount.AccountNumber] = *assumedAccount - } - } - } -} - -type logForwarder struct { -} - -func (logger logForwarder) Logf(classification logging.Classification, format string, v ...interface{}) { - switch classification { - case logging.Debug: - log.Trace().Msgf(format, v...) - case logging.Warn: - log.Warn().Msgf(format, v...) - } -} - -var customLoggerMiddleware = middleware.InitializeMiddlewareFunc("customLoggerMiddleware", - func(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) (out middleware.InitializeOutput, metadata middleware.Metadata, err error) { - operationName := middleware2.GetOperationName(ctx) - if strings.HasPrefix(operationName, "List") || - strings.HasPrefix(operationName, "Describe") || - strings.HasPrefix(operationName, "Get") || - strings.HasPrefix(operationName, "Assume") { - log.Trace().Msgf("AWS-Call: %s - %s", middleware2.GetServiceID(ctx), operationName) - } else { - log.Info().Msgf("AWS-Call: %s - %s", middleware2.GetServiceID(ctx), operationName) - } - return next.HandleInitialize(ctx, in) - }) - -func initializeRoleAssumption(stsServiceForRootAccount *sts.Client, roleArn string, rootAccount AwsAccount) *AwsAccount { - awsConfig := rootAccount.AwsConfig.Copy() - awsConfig.Credentials = aws.NewCredentialsCache(stscreds.NewAssumeRoleProvider(stsServiceForRootAccount, roleArn, setSessionName)) - - stsClient := sts.NewFromConfig(awsConfig) - identityOutput, err := stsClient.GetCallerIdentity(context.Background(), nil) - if err != nil { - log.Error().Err(err).Msgf("Failed to identify AWS account number for account assumed via role '%s'. The roleArn will be ignored until the next restart of the extension.", roleArn) - return nil - } - - log.Info().Msgf("Successfully assumed role '%s' in account '%s' (region '%s')", roleArn, aws.ToString(identityOutput.Account), awsConfig.Region) - - return &AwsAccount{ - AccountNumber: aws.ToString(identityOutput.Account), - AwsConfig: awsConfig, - } -} - -func setSessionName(o *stscreds.AssumeRoleOptions) { - o.RoleSessionName = "steadybit-extension-aws" -} diff --git a/utils/sdk_logging.go b/utils/sdk_logging.go new file mode 100644 index 00000000..d3a4e52a --- /dev/null +++ b/utils/sdk_logging.go @@ -0,0 +1,36 @@ +package utils + +import ( + "context" + middleware2 "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + "github.com/rs/zerolog/log" + "strings" +) + +type logForwarder struct { +} + +func (logger logForwarder) Logf(classification logging.Classification, format string, v ...interface{}) { + switch classification { + case logging.Debug: + log.Trace().Msgf(format, v...) + case logging.Warn: + log.Warn().Msgf(format, v...) + } +} + +var customLoggerMiddleware = middleware.InitializeMiddlewareFunc("customLoggerMiddleware", + func(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) (out middleware.InitializeOutput, metadata middleware.Metadata, err error) { + operationName := middleware2.GetOperationName(ctx) + if strings.HasPrefix(operationName, "List") || + strings.HasPrefix(operationName, "Describe") || + strings.HasPrefix(operationName, "Assume") || + strings.HasPrefix(operationName, "Get") { + log.Trace().Msgf("AWS-Call: %s - %s - %s", middleware2.GetRegion(ctx), middleware2.GetServiceID(ctx), operationName) + } else { + log.Info().Msgf("AWS-Call: %s - %s - %s", middleware2.GetRegion(ctx), middleware2.GetServiceID(ctx), operationName) + } + return next.HandleInitialize(ctx, in) + })