# Implement cross account support to Security Lake integration #657

Open · wants to merge 8 commits into base: master
18 changes: 10 additions & 8 deletions integrations/amazon-security-lake/README.md
@@ -93,14 +93,16 @@ Follow the [official documentation](https://docs.aws.amazon.com/lambda/latest/dg
- Use the [Makefile](./Makefile) to generate the zip package `wazuh_to_amazon_security_lake.zip`, and upload it to the S3 bucket created previously as per [these instructions](https://docs.aws.amazon.com/lambda/latest/dg/gettingstarted-package.html#gettingstarted-package-zip). See [CONTRIBUTING](./CONTRIBUTING.md) for details about the Makefile.
- Configure the Lambda with at least the required _Environment Variables_ below (a sketch of a matching cross-account trust policy follows the table):

| Environment variable | Required | Value |
| -------------------- | -------- | --------------------------------------------------------------------------------------------------------- |
| AWS_BUCKET | True | The name of the Amazon S3 bucket in which Security Lake stores your custom source data |
| SOURCE_LOCATION | True | The _Data source name_ of the _Custom Source_ |
| ACCOUNT_ID | True | Enter the ID that you specified when creating your Amazon Security Lake custom source |
| ROLE_ARN | True | The ARN of the role that the Lambda function assumes to write data to the Amazon Security Lake S3 bucket |
| EXTERNAL_ID | True | The External ID that you specified when creating your Amazon Security Lake custom source |
| REGION | True | AWS Region to which the data is written |
| S3_BUCKET_OCSF | False | S3 bucket to which the mapped events are written |
| OCSF_CLASS | False | The OCSF class to map the events into. Can be "SECURITY_FINDING" (default) or "DETECTION_FINDING". |
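
For the new `ROLE_ARN` and `EXTERNAL_ID` variables, the role in the Security Lake account must trust the account that runs the Lambda and require the configured External ID. The snippet below is a minimal sketch of creating such a role with boto3; the account ID, External ID, and role name are illustrative placeholders, not values defined by this integration.

```python
import json
import boto3

# Illustrative placeholders -- replace with your own values.
LAMBDA_ACCOUNT_ID = "111111111111"   # account that runs the Lambda
EXTERNAL_ID = "wazuh-security-lake"  # must match the Lambda's EXTERNAL_ID variable

# Trust policy allowing the Lambda's account to assume this role with the External ID.
trust_policy = {
    "Version": "2012-10-17",
    "Statement": [{
        "Effect": "Allow",
        "Principal": {"AWS": f"arn:aws:iam::{LAMBDA_ACCOUNT_ID}:root"},
        "Action": "sts:AssumeRole",
        "Condition": {"StringEquals": {"sts:ExternalId": EXTERNAL_ID}},
    }],
}

iam = boto3.client("iam")
response = iam.create_role(
    RoleName="wazuh-security-lake-writer",  # hypothetical role name
    AssumeRolePolicyDocument=json.dumps(trust_policy),
)
print(response["Role"]["Arn"])  # use this value as ROLE_ARN
```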

### Validation

119 changes: 87 additions & 32 deletions integrations/amazon-security-lake/src/lambda_function.py
@@ -12,31 +12,71 @@
logger = logging.getLogger()
logger.setLevel("INFO")


def get_dev_credentials():
    return {
        'AccessKeyId': os.environ['AWS_ACCESS_KEY_ID'],
        'SecretAccessKey': os.environ['AWS_SECRET_ACCESS_KEY'],
        'Region': os.environ['REGION'],
        'AWSEndpoint': os.environ['AWS_ENDPOINT']
    }


def assume_role(arn: str, external_id: str, session_name: str) -> dict:
    """
    Assume a role and return temporary security credentials.
    """
    sts_client = boto3.client('sts')
    try:
        response = sts_client.assume_role(
            RoleArn=arn,
            RoleSessionName=session_name,
            ExternalId=external_id
        )
        credentials = response['Credentials']
        return {
            'AccessKeyId': credentials['AccessKeyId'],
            'SecretAccessKey': credentials['SecretAccessKey'],
            'SessionToken': credentials['SessionToken']
        }
    except ClientError as e:
        logger.error(f"Failed to assume role {arn} with external ID {external_id}: {e}")
        return None


def get_s3_client(credentials: dict = None, is_dev=False) -> boto3.client:
    """
    Return an S3 client using temporary credentials if provided, otherwise use default credentials.
    """
    if not credentials:
        return boto3.client('s3')
    if is_dev:
        return boto3.client(
            service_name='s3',
            aws_access_key_id=credentials['AccessKeyId'],
            aws_secret_access_key=credentials['SecretAccessKey'],
            region_name=credentials['Region'],
            endpoint_url=credentials['AWSEndpoint'],
        )
    return boto3.client(
        service_name='s3',
        aws_access_key_id=credentials['AccessKeyId'],
        aws_secret_access_key=credentials['SecretAccessKey'],
        aws_session_token=credentials['SessionToken']
    )


def get_events(bucket: str, key: str, client: boto3.client) -> list:
    """
    Retrieve events from S3 object.
    """
    logger.info(f"Reading {key}.")
    try:
        response = client.get_object(Bucket=bucket, Key=key)
        data = gzip.decompress(response['Body'].read()).decode('utf-8')
        return data.splitlines()
    except ClientError as e:
        logger.error(f"Failed to read S3 object {key} from bucket {bucket}: {e}")
        return []


@@ -48,18 +88,17 @@ def write_parquet_file(ocsf_events: list, filename: str) -> None:
    pq.write_table(table, filename, compression='ZSTD')


def upload_to_s3(bucket: str, key: str, filename: str, client: boto3.client) -> bool:
    """
    Upload a file to S3 bucket.
    """
    logger.info(f"Uploading data to {bucket}.")
    try:
        with open(filename, 'rb') as data:
            client.put_object(Bucket=bucket, Key=key, Body=data)
        return True
    except ClientError as e:
        logger.error(f"Failed to upload file {filename} to bucket {bucket}: {e}")
        return False


@@ -69,7 +108,7 @@ def exit_on_error(error_message):
    Args:
        error_message (str): Error message to display.
    """
    logger.error(f"Error: {error_message}")
    exit(1)


@@ -92,21 +131,21 @@ def check_environment_variables(variables):
def get_full_key(src_location: str, account_id: str, region: str, key: str, format: str) -> str:
    """
    Constructs a full S3 key path for storing a Parquet file based on event metadata.

    Args:
        src_location (str): Source location identifier.
        account_id (str): AWS account ID associated with the event.
        region (str): AWS region where the event occurred.
        key (str): Event key containing metadata information.
        format (str): File extension.

    Returns:
        str: Full S3 key path for storing the Parquet file.

    Example:
        If key is '20240417_ls.s3.0055f22e-200e-4259-b865-8ccea05812be.2024-04-17T15.45.part29.txt',
        this function will return:
        'ext/src_location/region=region/accountId=account_id/eventDay=20240417/0055f22e200e4259b8658ccea05812be.parquet'
    """
    # Extract event day from the key (first 8 characters)
    event_day = key[:8]
@@ -116,9 +155,7 @@ def get_full_key(src_location: str, account_id: str, region: str, key: str, form
    filename = ''.join(filename_parts[2].split('-'))

    # Construct the full S3 key path for storing the file
    key = f'ext/{src_location}/region={region}/accountId={account_id}/eventDay={event_day}/{filename}.{format}'

    return key
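
As a quick sanity check, running the docstring example through `get_full_key` reproduces the documented output:

```python
# Reproduces the docstring example above.
key = get_full_key(
    src_location='src_location',
    account_id='account_id',
    region='region',
    key='20240417_ls.s3.0055f22e-200e-4259-b865-8ccea05812be.2024-04-17T15.45.part29.txt',
    format='parquet',
)
assert key == (
    'ext/src_location/region=region/accountId=account_id/'
    'eventDay=20240417/0055f22e200e4259b8658ccea05812be.parquet'
)
```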

@@ -137,19 +174,37 @@ def lambda_handler(event, context):
    src_location = os.environ['SOURCE_LOCATION']
    account_id = os.environ['ACCOUNT_ID']
    region = os.environ['REGION']
    role_arn = os.environ.get('ROLE_ARN')
    external_id = os.environ.get('EXTERNAL_ID')
    ocsf_bucket = os.environ.get('S3_BUCKET_OCSF')
    ocsf_class = os.environ.get('OCSF_CLASS', 'SECURITY_FINDING')
    is_dev = os.environ.get('IS_DEV', 'false').lower() == 'true'

    # Extract bucket and key from S3 event
    src_bucket = event['Records'][0]['s3']['bucket']['name']
    key = urllib.parse.unquote_plus(
        event['Records'][0]['s3']['object']['key'], encoding='utf-8')
    logger.info(f"Lambda function invoked due to {key}.")
    logger.info(f"Source bucket name is {src_bucket}. Destination bucket is {dst_bucket}.")

    # Assume role if ARN and External ID are provided
    credentials = None
    if is_dev:
        credentials = get_dev_credentials()
    elif role_arn and external_id:
        credentials = assume_role(role_arn, external_id, 'lake-session')
        if not credentials:
            logger.error("Failed to assume role, cannot proceed.")
            return
    else:
        # Log a warning if cross-account credentials are not used
        logger.warning("Cross-account access is not used. Lambda running with default credentials.")

    # Create the S3 client
    client = get_s3_client(credentials, is_dev)

    # Read events from source S3 bucket
    raw_events = get_events(src_bucket, key, client)
    if not raw_events:
        return

@@ -158,20 +213,20 @@

    # Upload event in OCSF format
    ocsf_upload_success = False
    if ocsf_bucket:
        tmp_filename = '/tmp/tmp.json'
        with open(tmp_filename, "w") as fd:
            fd.write(json.dumps(ocsf_events))
        ocsf_key = get_full_key(src_location, account_id, region, key, 'json')
        ocsf_upload_success = upload_to_s3(ocsf_bucket, ocsf_key, tmp_filename, client)

    # Write OCSF events to Parquet file
    tmp_filename = '/tmp/tmp.parquet'
    write_parquet_file(ocsf_events, tmp_filename)

    # Upload Parquet file to destination S3 bucket
    parquet_key = get_full_key(src_location, account_id, region, key, 'parquet')
    upload_success = upload_to_s3(dst_bucket, parquet_key, tmp_filename, client)

    # Clean up temporary file
    os.remove(tmp_filename)
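
For a quick local smoke test of the handler, it can be fed a minimal S3 notification event such as the sketch below. The bucket names and environment values are placeholders (only the object key is taken from the docstring example above), and real AWS credentials or the `IS_DEV` setup are still required for the call to succeed.

```python
# Minimal local invocation sketch; all names are illustrative placeholders.
import os

os.environ.update({
    'AWS_BUCKET': 'my-security-lake-bucket',
    'SOURCE_LOCATION': 'wazuh',
    'ACCOUNT_ID': '123456789012',
    'REGION': 'us-east-1',
})

sample_event = {
    'Records': [{
        's3': {
            'bucket': {'name': 'wazuh-alerts-bucket'},
            'object': {'key': '20240417_ls.s3.0055f22e-200e-4259-b865-8ccea05812be.2024-04-17T15.45.part29.txt'},
        }
    }]
}

# lambda_handler(sample_event, None)  # uncomment once credentials are in place
```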
46 changes: 46 additions & 0 deletions integrations/amazon-security-lake/tests/README.md
@@ -0,0 +1,46 @@
## Amazon Security Lake Unit Tests

This directory contains unit tests for the Amazon Security Lake integration.

## How to run

1. Start a virtual environment:

```shell
python3 -m venv venv
source venv/bin/activate
```

2. Install the requirements:

```shell
pip install -r requirements.txt
```

3. Run the tests:

```shell
pytest -v
```

Execution example:

```shell
% pytest -v
================================================================= test session starts ==================================================================
platform darwin -- Python 3.13.0, pytest-8.3.4, pluggy-1.5.0 -- /Users/quebim_wz/IdeaProjects/wazuh-indexer/integrations/amazon-security-lake/venv/bin/python3.13
cachedir: .pytest_cache
rootdir: /Users/quebim_wz/IdeaProjects/wazuh-indexer/integrations/amazon-security-lake/tests
configfile: pytest.ini
collected 7 items

test_lambda_function.py::test_lambda_handler PASSED [ 14%]
test_lambda_function.py::test_assume_role PASSED [ 28%]
test_lambda_function.py::test_get_s3_client PASSED [ 42%]
test_lambda_function.py::test_get_events PASSED [ 57%]
test_lambda_function.py::test_write_parquet_file PASSED [ 71%]
test_lambda_function.py::test_upload_to_s3 PASSED [ 85%]
test_lambda_function.py::test_get_full_key PASSED [100%]

================================================================== 7 passed in 0.59s ===================================================================
```
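
Beyond the shipped suite, the cross-account path can be exercised in isolation with moto's `mock_aws`, as in the rough sketch below; the role ARN, External ID, and test names are arbitrary and not part of the actual test files.

```python
# test_assume_role_sketch.py -- illustrative only, not one of the shipped tests.
import os

# Dummy credentials and region so boto3 clients can be created under moto.
os.environ.setdefault('AWS_DEFAULT_REGION', 'us-east-1')
os.environ.setdefault('AWS_ACCESS_KEY_ID', 'testing')
os.environ.setdefault('AWS_SECRET_ACCESS_KEY', 'testing')

from moto import mock_aws

from lambda_function import assume_role, get_s3_client


@mock_aws
def test_assume_role_returns_temporary_credentials():
    creds = assume_role(
        arn='arn:aws:iam::123456789012:role/security-lake-writer',  # arbitrary test ARN
        external_id='test-external-id',
        session_name='test-session',
    )
    assert creds is not None
    assert {'AccessKeyId', 'SecretAccessKey', 'SessionToken'} <= creds.keys()

    # The temporary credentials can then back an S3 client for the cross-account bucket.
    client = get_s3_client(creds)
    assert client.meta.service_model.service_name == 's3'
```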
4 changes: 4 additions & 0 deletions integrations/amazon-security-lake/tests/__init__.py
@@ -0,0 +1,4 @@
import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
3 changes: 3 additions & 0 deletions integrations/amazon-security-lake/tests/pytest.ini
@@ -0,0 +1,3 @@
[pytest]
filterwarnings =
    ignore::DeprecationWarning
7 changes: 7 additions & 0 deletions integrations/amazon-security-lake/tests/requirements.txt
@@ -0,0 +1,7 @@
moto==5.0.27
pytest==8.3.4
requests==2.32.3
pyarrow>=10.0.1
parquet-tools>=0.2.15
pydantic>=2.6.1
boto3==1.34.46