Skip to content

Commit

Permalink
Price Alpha Update: Step function flow improvements
Browse files (browse the repository at this point in the history)
- Fix batch separation
- Skip last wait in inner loop
  • Loading branch information
nico-corthorn committed Oct 17, 2024
1 parent ce52231 commit d944a1d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 26 deletions.
2 changes: 1 addition & 1 deletion db/alpha/prices_alpha.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ CREATE TABLE prices_alpha
high numeric(14,2) NOT NULL,
low numeric(14,2) NOT NULL,
close numeric(14,2) NOT NULL,
adjusted_close numeric(14,2) NULL,
adjusted_close numeric(18,2) NULL,
volume bigint NOT NULL,
dividend_amount numeric(14,1) NULL,
split_coefficient numeric(14,1) NULL,
Expand Down
36 changes: 18 additions & 18 deletions esgtools/get_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def lambda_handler(event, context):
print("event", event)

# Example
# http://127.0.0.1:3000/get-assets?ref_table=prices_alpha&group=100
# 'queryStringParameters': {'ref_table': 'prices_alpha'}
# {"queryStringParameters": {"ref_table": "prices_alpha"}}

Expand All @@ -52,18 +51,17 @@ def lambda_handler(event, context):
inputs = event["queryStringParameters"]
assert "ref_table" in inputs
ref_table = inputs["ref_table"]
validate = utils.str2bool(inputs["validate"]) \
if "validate" in inputs else False
asset_types = inputs["asset_types"].split(",") \
if "asset_types" in inputs else ["Stock"]
#group = int(inputs["group"]) if "group" in inputs else 10
max_assets_in_batch = int(inputs["max_assets_in_batch"]) if "max_assets_in_batch" in inputs else 75*60
n_lists_in_group = int(inputs["n_lists_in_group"]) if "n_lists_in_group" in inputs else 10
validate = utils.str2bool(inputs.get("validate", "false"))
asset_types = inputs.get("asset_types", "Stock").split(",")
max_assets_in_batch = int(inputs.get("max_assets_in_batch", 75*60))
n_lists_in_batch = int(inputs.get("n_lists_in_batch", 10))
size = inputs.get("size", "full")
print(f"ref_table = {ref_table}")
print(f"validate = {validate}")
print(f"asset_types = {asset_types}")
print(f"max_assets_in_batch = {max_assets_in_batch}")
print(f"n_lists_in_group = {n_lists_in_group}")
print(f"n_lists_in_batch = {n_lists_in_batch}")
print(f"size = {size}")

# Decrypts secret using the associated KMS key.
db_credentials = literal_eval(aws.get_secret("prod/awsportfolio/key"))
Expand Down Expand Up @@ -113,18 +111,20 @@ def lambda_handler(event, context):

assets_sublists = []
if assets.shape[0] > 0:
for i in range(0, len(assets), max_assets_in_group):
symbols_group: pd.Series = assets.loc[i:i+max_assets_in_group-1].symbol
partition_group = np.array_split(symbols_group, n_lists_in_group)
assets_sublists.append([{
"symbols": ",".join(list(sublist)),
"size": size
} for sublist in partition_group])
total_batches = len(range(0, len(assets), max_assets_in_batch))
for i, batch_start in enumerate(range(0, len(assets), max_assets_in_batch)):
symbols_batch: pd.Series = assets.loc[batch_start:batch_start+max_assets_in_batch-1].symbol
partition_batch = np.array_split(symbols_batch, n_lists_in_batch)
is_last_batch = (i == total_batches - 1)
assets_sublists.append({
"batch": [{"symbols": ",".join(list(sublist)), "size": size} for sublist in partition_batch],
"is_last_batch": is_last_batch
})

return {
"statusCode": 200,
"body": json.dumps({
"message": f"Returning {len(assets_sublists)} groups with {n_lists_in_group} sublists each.",
"message": f"Returning {len(assets_sublists)} batches with {n_lists_in_batch} sublists each.",
}),
"assets": assets_sublists,
"assets": assets_sublists
}
25 changes: 20 additions & 5 deletions statemachine/update_all_prices.asl.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"Comment": "A state machine that updates prices in hourly batches.",
"Comment": "A state machine that updates prices in hourly batches, skipping the wait for the last batch.",
"StartAt": "Get Assets",
"States": {
"Get Assets": {
Expand All @@ -12,17 +12,17 @@
"ItemsPath": "$.assets",
"MaxConcurrency": 1,
"Iterator": {
"StartAt": "Parallel Process and Wait",
"StartAt": "Parallel Process and Conditional Wait",
"States": {
"Parallel Process and Wait": {
"Parallel Process and Conditional Wait": {
"Type": "Parallel",
"Branches": [
{
"StartAt": "Process Sublists",
"States": {
"Process Sublists": {
"Type": "Map",
"ItemsPath": "$",
"ItemsPath": "$.batch",
"MaxConcurrency": 10,
"Iterator": {
"StartAt": "Update Prices",
Expand Down Expand Up @@ -87,12 +87,27 @@
}
},
{
"StartAt": "Wait One Hour",
"StartAt": "Check If Last Batch",
"States": {
"Check If Last Batch": {
"Type": "Choice",
"Choices": [
{
"Variable": "$.is_last_batch",
"BooleanEquals": true,
"Next": "Skip Wait"
}
],
"Default": "Wait One Hour"
},
"Wait One Hour": {
"Type": "Wait",
"Seconds": 3600,
"End": true
},
"Skip Wait": {
"Type": "Pass",
"End": true
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from esgtools import update_prices
from esgtools import update_prices_alpha


@pytest.fixture()
Expand Down Expand Up @@ -68,7 +68,7 @@ def apigw_event():

def test_prices(apigw_event):

ret = update_prices.lambda_handler(apigw_event, "")
ret = update_prices_alpha.lambda_handler(apigw_event, "")
data = json.loads(ret["body"])

assert ret["statusCode"] == 200
Expand Down

0 comments on commit d944a1d

Please sign in to comment.