Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
prijendev committed Nov 27, 2024
1 parent 9b56a81 commit 283d558
Show file tree
Hide file tree
Showing 15 changed files with 159 additions and 55 deletions.
2 changes: 1 addition & 1 deletion tap_square/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main():
if args.discover:
write_catalog(catalog)
else:
sync(args.config, args.state, catalog)
sync(args.config, args.config_path, args.state, catalog)

if __name__ == '__main__':
main()
90 changes: 70 additions & 20 deletions tap_square/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
import urllib.parse
import os
import requests
import json
from square.client import Client
from singer import utils
import singer
Expand Down Expand Up @@ -39,43 +40,92 @@ def log_backoff(details):
LOGGER.warning('Error receiving data from square. Sleeping %.1f seconds before trying again', details['wait'])


def write_config(config, config_path, data):
"""
Updates the provided filepath with json format of the `data` object
"""
config.update(data)
with open(config_path, "w") as tap_config:
json.dump(config, tap_config, indent=2)
return config


def require_new_access_token(access_token, environment):
"""
Checks if the access token needs to be refreshed
"""
# If there is no access token, we need to generate a new one
if not access_token:
return True

if environment == "sandbox":
url = "https://connect.squareupsandbox.com/v2/locations"
else:
url = "https://connect.squareup.com/v2/locations"

headers = {"Authorization": f"Bearer {access_token}"}
try:
response = requests.request("GET", url, headers=headers)
except Exception as e:
# If there is an error, we should generate a new access token
LOGGER.error(f"Error while validating access token: {e}")
return True
# If the response is a 401, we need to generate a new access token
return response.status_code == 401


class RetryableError(Exception):
pass


class SquareClient():
def __init__(self, config):
def __init__(self, config, config_path):
self._refresh_token = config['refresh_token']
self._client_id = config['client_id']
self._client_secret = config['client_secret']

self._environment = 'sandbox' if config.get('sandbox') == 'true' else 'production'

self._access_token = self._get_access_token()
self._access_token = self._get_access_token(config, config_path)
self._client = Client(access_token=self._access_token, environment=self._environment)

def _get_access_token(self):
if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys():
LOGGER.info("Using access token from environment, not creating the new one")
return os.environ["TAP_SQUARE_ACCESS_TOKEN"]
def _get_access_token(self, config, config_path):
"""
Retrieves the access token from the config file. If the access token is expired, it will refresh it.
Otherwise, it will return the cached access token.
"""
access_token = config.get("access_token")

# Check if the access token needs to be refreshed
if require_new_access_token(access_token, self._environment):
LOGGER.info("Refreshing access token...")
body = {
'client_id': self._client_id,
'client_secret': self._client_secret,
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token
}

body = {
'client_id': self._client_id,
'client_secret': self._client_secret,
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token
}
client = Client(environment=self._environment)

client = Client(environment=self._environment)
with singer.http_request_timer('GET access token'):
result = client.o_auth.obtain_token(body)

with singer.http_request_timer('GET access token'):
result = client.o_auth.obtain_token(body)
if result.is_error():
error_message = result.errors if result.errors else result.body
raise RuntimeError(error_message)

if result.is_error():
error_message = result.errors if result.errors else result.body
raise RuntimeError(error_message)
access_token = result.body['access_token']
write_config(
config,
config_path,
{
"access_token": access_token,
"refresh_token": result.body["refresh_token"],
},
)

return result.body['access_token']
return access_token

@staticmethod
@backoff.on_exception(
Expand Down
4 changes: 2 additions & 2 deletions tap_square/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

LOGGER = singer.get_logger()

def sync(config, state, catalog): # pylint: disable=too-many-statements
client = SquareClient(config)
def sync(config, config_path, state, catalog): # pylint: disable=too-many-statements
client = SquareClient(config, config_path)

with Transformer() as transformer:
for stream in catalog.get_selected_streams(state):
Expand Down
26 changes: 26 additions & 0 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,24 @@ def set_environment(self, env):
Requires re-instatiating TestClient and setting env var.
"""
os.environ['TAP_SQUARE_ENVIRONMENT'] = env
self.set_access_token_in_env()
self.client = TestClient(env=env)
self.SQUARE_ENVIRONMENT = env

def set_access_token_in_env(self):
"""
Fetch the access token from the existing connection and set it in the env.
This is used to avoid rate limiting issues when running tests.
"""
existing_connections = connections.fetch_existing_connections(self)
conn_with_creds = connections.fetch_existing_connection_with_creds(existing_connections[0]['id'])
access_token = conn_with_creds['credentials'].get('access_token')
if not access_token:
LOGGER.info("No access token found in env")
else:
LOGGER.info("Found access token in env")
os.environ['TAP_SQUARE_ACCESS_TOKEN'] = access_token

@staticmethod
def get_environment():
return os.environ['TAP_SQUARE_ENVIRONMENT']
Expand All @@ -115,12 +130,23 @@ def get_credentials():
'refresh_token': os.getenv('TAP_SQUARE_REFRESH_TOKEN') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_REFRESH_TOKEN'),
'client_id': os.getenv('TAP_SQUARE_APPLICATION_ID') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_ID'),
'client_secret': os.getenv('TAP_SQUARE_APPLICATION_SECRET') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_SECRET'),
'access_token': os.environ["TAP_SQUARE_ACCESS_TOKEN"]
}
else:
raise Exception("Square Environment: {} is not supported.".format(environment))

return creds

@staticmethod
def preserve_access_token(existing_conns, payload):
"""This method is used get the access token from an existing refresh token"""
if not existing_conns:
return payload

conn_with_creds = connections.fetch_existing_connection_with_creds(existing_conns[0]['id'])
payload['properties']['access_token'] = conn_with_creds['credentials'].get('access_token')
return payload

def expected_check_streams(self):
return set(self.expected_metadata().keys()).difference(set())

Expand Down
2 changes: 1 addition & 1 deletion tests/test_all_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def all_fields_test(self, environment, data_type):
expected_records = self.create_test_data(self.TESTABLE_STREAMS, self.START_DATE, force_create_records=True)

# instantiate connection
conn_id = connections.ensure_connection(self)
conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)

# run check mode
found_catalogs = self.run_and_verify_check_mode(conn_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_automatic_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def auto_fields_test(self, environment, data_type):
)

# instantiate connection
conn_id = connections.ensure_connection(self)
conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)

# run check mode
found_catalogs = self.run_and_verify_check_mode(conn_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bookmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def bookmarks_test(self, testable_streams):
expected_records_first_sync = self.create_test_data(testable_streams, self.START_DATE, force_create_records=True)

# Instantiate connection with default start
conn_id = connections.ensure_connection(self)
conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)

# run in check mode
check_job_name = runner.run_check_mode(self, conn_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bookmarks_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def bookmarks_test(self, testable_streams):
]

# Create connection but do not use default start date
conn_id = connections.ensure_connection(self, original_properties=False)
conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token)

# run check mode
found_catalogs = self.run_and_verify_check_mode(conn_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bookmarks_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_run(self):
expected_records_first_sync = self.create_test_data(self.testable_streams_static(), self.START_DATE)

# Instantiate connection with default start
conn_id = connections.ensure_connection(self)
conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)

# run in check mode
check_job_name = runner.run_check_mode(self, conn_id)
Expand Down
70 changes: 49 additions & 21 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ def get_batch_token_from_headers(headers):
return None


def require_new_access_token(access_token, environment):
"""
Checks if the access token needs to be refreshed
"""
# If there is no access token, we need to generate a new one
if not access_token:
return True

if environment == "sandbox":
url = "https://connect.squareupsandbox.com/v2/locations"
else:
url = "https://connect.squareup.com/v2/locations"

headers = {"Authorization": f"Bearer {access_token}"}
try:
response = requests.request("GET", url, headers=headers)
except Exception as e:
# If there is an error, we should generate a new access token
LOGGER.error(f"Error while validating access token: {e}")
return True
# If the response is a 401, we need to generate a new access token
return response.status_code == 401

def log_backoff(details):
'''
Logs a backoff retry message
Expand Down Expand Up @@ -100,34 +123,39 @@ def __init__(self, env):

self._environment = 'sandbox' if config.get('sandbox') == 'true' else 'production'

self._access_token = self._get_access_token()
self._access_token = self._get_access_token(env)
self._client = Client(access_token=self._access_token, environment=self._environment)

def _get_access_token(self):
if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys():
LOGGER.info("Using access token from environment, not creating the new")
return os.environ["TAP_SQUARE_ACCESS_TOKEN"]
def _get_access_token(self, env):
"""
Retrieves the access token from the env. If the access token is expired, it will refresh it.
Otherwise, it will return the cached access token.
"""
access_token = os.environ.get("TAP_SQUARE_ACCESS_TOKEN")

body = {
'client_id': self._client_id,
'client_secret': self._client_secret,
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token
}
# Check if the access token needs to be refreshed
if require_new_access_token(access_token, env):
body = {
'client_id': self._client_id,
'client_secret': self._client_secret,
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token
}

client = Client(environment=self._environment)
client = Client(environment=self._environment)

with singer.http_request_timer('GET access token'):
result = client.o_auth.obtain_token(body)
with singer.http_request_timer('GET access token'):
result = client.o_auth.obtain_token(body)

if result.is_error():
error_message = result.errors if result.errors else result.body
LOGGER.info("error_message :-----------: %s",error_message)
raise RuntimeError(error_message)
if result.is_error():
error_message = result.errors if result.errors else result.body
LOGGER.info("error_message :-----------: %s",error_message)
raise RuntimeError(error_message)

LOGGER.info("Generating new the access token to set in environment....")
os.environ["TAP_SQUARE_ACCESS_TOKEN"] = access_token = result.body['access_token']

LOGGER.info("Setting the access token in environment....")
os.environ["TAP_SQUARE_ACCESS_TOKEN"] = result.body['access_token']
return result.body['access_token']
return access_token

##########################################################################
### V1 INFO
Expand Down
2 changes: 1 addition & 1 deletion tests/test_default_start_date.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def run_standard_sync(self, environment, data_type, select_all_fields=True):
Select all fields or no fields based on the select_all_fields param.
Run a sync.
"""
conn_id = connections.ensure_connection(self, original_properties=False)
conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token)

found_catalogs = self.run_and_verify_check_mode(conn_id)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def discovery_test(self):
• verify that all other fields have inclusion of available (metadata and schema)
"""
conn_id = connections.ensure_connection(self)
check_job_name = runner.run_check_mode(self, conn_id)
check_job_name = runner.run_check_mode(self, conn_id, payload_hook=self.preserve_access_token)

#verify check exit codes
exit_status = menagerie.get_exit_status(conn_id, check_job_name)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def pagination_test(self):
"{} does not have sufficient data in expecatations.\n ".format(stream))

# Create connection but do not use default start date
conn_id = connections.ensure_connection(self, original_properties=False)
conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token)

# run check mode
found_catalogs = self.run_and_verify_check_mode(conn_id)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_start_date.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def start_date_test(self, environment, data_type):
##########################################################################

# instantiate connection
conn_id = connections.ensure_connection(self)
conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)

# run check mode
found_catalogs = self.run_and_verify_check_mode(conn_id)
Expand Down Expand Up @@ -107,7 +107,7 @@ def start_date_test(self, environment, data_type):
##########################################################################

# create a new connection with the new start_date
conn_id = connections.ensure_connection(self, original_properties=False)
conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token)

# run check mode
found_catalogs = self.run_and_verify_check_mode(conn_id)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sync_canary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run_standard_sync(self, environment, data_type, select_all_fields=True):
Select all fields or no fields based on the select_all_fields param.
Run a sync.
"""
conn_id = connections.ensure_connection(self)
conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token)

found_catalogs = self.run_and_verify_check_mode(conn_id)

Expand Down

0 comments on commit 283d558

Please sign in to comment.