#!/usr/bin/env python
# coding: utf-8
# Usage (replace the args as needed): ./ModelETL.py -o redditAggregatedData-20230502.parquet -c 1000
import modelUtils as mu
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
import boto3
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import udf
import argparse
import sys
import os
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(THIS_DIR, '../'))
import configUtils as cu
# Forcing the timezone to UTC keeps things consistent with running on AWS; without it, timestamps pick up
# additional timezone conversions when writing to parquet. Setting the Spark timezone alone was not enough to fix this.
os.environ['TZ'] = 'UTC'
cfg_file = cu.findConfig()
cfg = cu.parseConfig(cfg_file)
spark = (
    SparkSession
    .builder
    .appName('redditData')
    .config('spark.driver.extraJavaOptions', '-Duser.timezone=GMT')
    .config('spark.executor.extraJavaOptions', '-Duser.timezone=GMT')
    .config('spark.sql.session.timeZone', 'UTC')
    .config("fs.s3a.access.key", cfg['S3_access']['ACCESSKEY'])
    .config("fs.s3a.secret.key", cfg['S3_access']['SECRETKEY'])
    .getOrCreate()
)


class Pipeline:
    def __init__(self, spark=spark):
        self.spark = spark
        # The boto3 resource is the higher-level DynamoDB abstraction (recommended over the low-level client):
        # it exposes fewer methods, but create_table returns a Table object you can operate on directly, and an
        # existing table can be grabbed with Table('name'). A commented example follows this method.
        self.dynamodb_resource = boto3.resource('dynamodb', region_name='us-east-2')
        # initializations - passed between methods
        self.postIdData = None
        self.uniqueHotPostIds = None
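
    # A minimal sketch (commented out) of querying one of these Table resources directly. The partition key
    # name 'loadDateUTC' is an assumption for illustration only; the real query logic lives in
    # modelUtils.queryByRangeOfDates.
    #
    # from boto3.dynamodb.conditions import Key
    # table = boto3.resource('dynamodb', region_name='us-east-2').Table('rising')
    # resp = table.query(
    #     KeyConditionExpression=Key('loadDateUTC').eq('2023-05-02'),
    #     ProjectionExpression='postId',
    # )
    # postIds = [item['postId'] for item in resp['Items']]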

    def extract(self, chunkSize=1000):
        ###################
        # Get Rising Data #
        ###################
        print("Gathering Rising Data...")
        risingTable = self.dynamodb_resource.Table('rising')
        datesToQuery = mu.daysUntilNow()
        print("Dates to query:", datesToQuery)
        postIdQueryResult = mu.queryByRangeOfDates(risingTable, datesToQuery)  # [{'postId': XXXXXX}, {'postId': YYYYYY}, ...]
        postsOfInterest = {res['postId'] for res in postIdQueryResult}
        print("Number of posts found:", len(postsOfInterest))
        # This can take a while due to the read constraints placed on DynamoDB; consider increasing the RCU on the table.
        # It can also be slow because each chunk of postIds is converted to its own Spark DataFrame. That was done so it
        # scales better on a distributed system than keeping all of the data in Python on one node and then moving it
        # to Spark. A commented sketch of this chunked pattern is included just below this block.
        self.postIdData = mu.getPostIdSparkDataFrame(self.spark, risingTable, postsOfInterest, chunkSize=chunkSize)
        pandasTestDf = self.postIdData.limit(5).toPandas()
        print(pandasTestDf.to_string())
        print("Finished gathering Rising Data.")
        ###############
        # Get Targets #
        ###############
        print("Gathering Hot Data...")
        hotTable = self.dynamodb_resource.Table('hot')
        hotPosts = mu.queryByRangeOfDates(hotTable, datesToQuery)
        self.uniqueHotPostIds = {p['postId'] for p in hotPosts}
        # The hot posts are usually not a very long list, and we only need them for creating targets.
        print("unique hot postIds:", self.uniqueHotPostIds)
        print("Finished gathering Hot Data.")

    def transform(self):
        ##################################
        # Apply all data transformations #
        ##################################
        # Pull these into local variables first: if the lambda below referenced self, Spark would try to
        # serialize the whole Pipeline object (including the boto3 resource, which is not picklable) when
        # shipping the UDF to executors, and that raises an error.
        postIdData = self.postIdData
        uniqueHotPostIds = self.uniqueHotPostIds
        print("Applying transformations to Rising Data...")
        aggData = mu.applyDataTransformations(postIdData)
        print("Creating Targets for Rising Data from Hot Data")
        getTargetUDF = udf(lambda x: mu.getTarget(x, uniqueHotPostIds), returnType=IntegerType())
        aggData = aggData.withColumn('target', getTargetUDF(F.col('postId')))
        print("Finished creating data targets.")
        return aggData
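
    # A minimal sketch of the labeling step above (an assumption for illustration; the actual implementation
    # is mu.getTarget in modelUtils):
    #
    # def getTarget(postId, uniqueHotPostIds):
    #     # a rising post is labeled 1 if it later showed up in the hot data, else 0
    #     return 1 if postId in uniqueHotPostIds else 0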

    # Aggregating the data should give at most a 60x reduction, since data can be collected once every minute.
    # The commented inspection below can be slow because it has to run all of the transformations and does not scale well:
    #
    # aggDataPd = aggData.toPandas()
    # print(len(aggDataPd))
    # aggDataPd.head()
    # aggDataPd[aggDataPd['target'] == 1]
    # for pId in aggDataPd[aggDataPd['target'] == 1]['postId']:
    #     print('https://reddit.com/' + pId)
    #
    # At the time of writing, I've only collected about 1.5 days of data, with 7 viral posts (not a large
    # sample, although that was to be expected). Interestingly, of the viral posts I have:
    #
    # - the one with the most upvotes after an hour was actually the least viral,
    # - while the one with the least upvotes was actually the most viral,
    # - but that post with the least upvotes had the second most comments of the viral posts (24 comments),
    #   so maybe it would still be captured by the model.
    #
    # I'm considering extending the collection window to 90-120 minutes. However, the point was to get to a post
    # early, while it still has relatively few comments. That most viral post already had 24 comments after an
    # hour, which is a lot, and any new replies are likely to be buried.

    ###############################
    # Write Spark DataFrame to S3 #
    ###############################
    #
    # This is basically our model data and what we will use to train a model. I used Spark to write to S3 to
    # future-proof this in case the data becomes too large to fit in pandas on the driver.
    #
    # If you get an error here, you probably need to download hadoop-aws-*.jar
    # (ex: [3.2.0](https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-aws/3.2.0)) and aws-java-sdk-bundle-*.jar
    # (ex: [1.11.375](https://mvnrepository.com/artifact/com.amazonaws/aws-java-sdk-bundle/1.11.375)).
    #
    # - hadoop-aws-*.jar should match the version of the other Hadoop jars in $SPARK_HOME/jars/
    # - for aws-java-sdk-bundle-*.jar, check the version dependency of hadoop-aws-*.jar on the Maven website.
    #   Do NOT use an upgraded version; use the version that hadoop-aws was built against.
    #
    # You may not need to add these dependencies to the configs, but you may need to restart the kernel
    # (or Spark session) and rerun things. A commented example of wiring them in follows below.
    #
    # These links may help:
    #
    # - [SO link 1](https://stackoverflow.com/questions/58415928/spark-s3-error-java-lang-classnotfoundexception-class-org-apache-hadoop-f?answertab=scoredesc#tab-top)
    # - [SO link 2](https://stackoverflow.com/questions/44411493/java-lang-noclassdeffounderror-org-apache-hadoop-fs-storagestatistics/44500698#44500698)
    # - [SO link 3](https://stackoverflow.com/questions/64547468/pyspark-s3-error-java-lang-noclassdeffounderror-com-amazonaws-amazonserviceex)
    # - [Tutorial](https://notadatascientist.com/running-apache-spark-and-s3-locally/)
    # - [Hadoop Troubleshooting](https://hadoop.apache.org/docs/stable/hadoop-aws/tools/hadoop-aws/troubleshooting_s3a.html)
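
    # If needed, one way to pull these dependencies in is via spark.jars.packages when building the session.
    # This is a commented sketch only; the versions shown mirror the examples above and should be matched to
    # your local Hadoop build rather than taken as-is:
    #
    # spark = (
    #     SparkSession.builder
    #     .appName('redditData')
    #     .config('spark.jars.packages',
    #             'org.apache.hadoop:hadoop-aws:3.2.0,com.amazonaws:aws-java-sdk-bundle:1.11.375')
    #     .getOrCreate()
    # )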

    def load(self, data, location):
        print("Writing to S3")
        data.write.parquet(location, mode="overwrite")
        print("Finished writing to S3")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output", help="Output file name to save results to", required=True)
    parser.add_argument("-c", "--chunkSize", help="Number of postIds to read in before converting to spark df", default=1000, required=False, type=int)
    args = parser.parse_args()
    o = args.output  # e.g. redditAggregatedData.parquet
    chunkSize = args.chunkSize  # e.g. 1000
    outputFilename = f"s3a://data-kennethmyers/model_data/{o}"

    pipeline = Pipeline()
    pipeline.extract(chunkSize=chunkSize)
    data = pipeline.transform()
    pipeline.load(data, outputFilename)