diff --git a/tap_square/__init__.py b/tap_square/__init__.py index 4a8311c..e721373 100644 --- a/tap_square/__init__.py +++ b/tap_square/__init__.py @@ -20,7 +20,7 @@ def main(): if args.discover: write_catalog(catalog) else: - sync(args.config, args.state, catalog) + sync(args.config, args.config_path, args.state, catalog) if __name__ == '__main__': main() diff --git a/tap_square/client.py b/tap_square/client.py index 5fd5499..4fbbacc 100644 --- a/tap_square/client.py +++ b/tap_square/client.py @@ -1,6 +1,7 @@ from datetime import timedelta import urllib.parse -import os +import requests +import json from square.client import Client from singer import utils import singer @@ -39,43 +40,92 @@ def log_backoff(details): LOGGER.warning('Error receiving data from square. Sleeping %.1f seconds before trying again', details['wait']) +def write_config(config, config_path, data): + """ + Updates the provided filepath with json format of the `data` object + """ + config.update(data) + with open(config_path, "w") as tap_config: + json.dump(config, tap_config, indent=2) + return config + + +def require_new_access_token(access_token, environment): + """ + Checks if the access token needs to be refreshed + """ + # If there is no access token, we need to generate a new one + if not access_token: + return True + + if environment == "sandbox": + url = "https://connect.squareupsandbox.com/v2/locations" + else: + url = "https://connect.squareup.com/v2/locations" + + headers = {"Authorization": f"Bearer {access_token}"} + try: + response = requests.request("GET", url, headers=headers) + except Exception as e: + # If there is an error, we should generate a new access token + LOGGER.error(f"Error while validating access token: {e}") + return True + # If the response is a 401, we need to generate a new access token + return response.status_code == 401 + + class RetryableError(Exception): pass class SquareClient(): - def __init__(self, config): + def __init__(self, config, config_path): self._refresh_token = config['refresh_token'] self._client_id = config['client_id'] self._client_secret = config['client_secret'] self._environment = 'sandbox' if config.get('sandbox') == 'true' else 'production' - self._access_token = self._get_access_token() + self._access_token = self._get_access_token(config, config_path) self._client = Client(access_token=self._access_token, environment=self._environment) - def _get_access_token(self): - if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys(): - LOGGER.info("Using access token from environment, not creating the new one") - return os.environ["TAP_SQUARE_ACCESS_TOKEN"] + def _get_access_token(self, config, config_path): + """ + Retrieves the access token from the config file. If the access token is expired, it will refresh it. + Otherwise, it will return the cached access token. + """ + access_token = config.get("access_token") + + # Check if the access token needs to be refreshed + if require_new_access_token(access_token, self._environment): + LOGGER.info("Refreshing access token...") + body = { + 'client_id': self._client_id, + 'client_secret': self._client_secret, + 'grant_type': 'refresh_token', + 'refresh_token': self._refresh_token + } - body = { - 'client_id': self._client_id, - 'client_secret': self._client_secret, - 'grant_type': 'refresh_token', - 'refresh_token': self._refresh_token - } + client = Client(environment=self._environment) - client = Client(environment=self._environment) + with singer.http_request_timer('GET access token'): + result = client.o_auth.obtain_token(body) - with singer.http_request_timer('GET access token'): - result = client.o_auth.obtain_token(body) + if result.is_error(): + error_message = result.errors if result.errors else result.body + raise RuntimeError(error_message) - if result.is_error(): - error_message = result.errors if result.errors else result.body - raise RuntimeError(error_message) + access_token = result.body['access_token'] + write_config( + config, + config_path, + { + "access_token": access_token, + "refresh_token": result.body["refresh_token"], + }, + ) - return result.body['access_token'] + return access_token @staticmethod @backoff.on_exception( diff --git a/tap_square/sync.py b/tap_square/sync.py index b69a1a3..b2480e7 100644 --- a/tap_square/sync.py +++ b/tap_square/sync.py @@ -6,8 +6,8 @@ LOGGER = singer.get_logger() -def sync(config, state, catalog): # pylint: disable=too-many-statements - client = SquareClient(config) +def sync(config, config_path, state, catalog): # pylint: disable=too-many-statements + client = SquareClient(config, config_path) with Transformer() as transformer: for stream in catalog.get_selected_streams(state): diff --git a/tests/base.py b/tests/base.py index 1d6d33f..0804d38 100644 --- a/tests/base.py +++ b/tests/base.py @@ -88,9 +88,24 @@ def set_environment(self, env): Requires re-instatiating TestClient and setting env var. """ os.environ['TAP_SQUARE_ENVIRONMENT'] = env + self.set_access_token_in_env() self.client = TestClient(env=env) self.SQUARE_ENVIRONMENT = env + def set_access_token_in_env(self): + """ + Fetch the access token from the existing connection and set it in the env. + This is used to avoid rate limiting issues when running tests. + """ + existing_connections = connections.fetch_existing_connections(self) + conn_with_creds = connections.fetch_existing_connection_with_creds(existing_connections[0]['id']) + access_token = conn_with_creds['credentials'].get('access_token') + if not access_token: + LOGGER.info("No access token found in env") + else: + LOGGER.info("Found access token in env") + os.environ['TAP_SQUARE_ACCESS_TOKEN'] = access_token + @staticmethod def get_environment(): return os.environ['TAP_SQUARE_ENVIRONMENT'] @@ -115,12 +130,23 @@ def get_credentials(): 'refresh_token': os.getenv('TAP_SQUARE_REFRESH_TOKEN') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_REFRESH_TOKEN'), 'client_id': os.getenv('TAP_SQUARE_APPLICATION_ID') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_ID'), 'client_secret': os.getenv('TAP_SQUARE_APPLICATION_SECRET') if environment == 'sandbox' else os.getenv('TAP_SQUARE_PROD_APPLICATION_SECRET'), + 'access_token': os.environ["TAP_SQUARE_ACCESS_TOKEN"] } else: raise Exception("Square Environment: {} is not supported.".format(environment)) return creds + @staticmethod + def preserve_access_token(existing_conns, payload): + """This method is used get the access token from an existing refresh token""" + if not existing_conns: + return payload + + conn_with_creds = connections.fetch_existing_connection_with_creds(existing_conns[0]['id']) + payload['properties']['access_token'] = conn_with_creds['credentials'].get('access_token') + return payload + def expected_check_streams(self): return set(self.expected_metadata().keys()).difference(set()) diff --git a/tests/test_all_fields.py b/tests/test_all_fields.py index 0d11756..fb90f6d 100644 --- a/tests/test_all_fields.py +++ b/tests/test_all_fields.py @@ -131,7 +131,7 @@ def all_fields_test(self, environment, data_type): expected_records = self.create_test_data(self.TESTABLE_STREAMS, self.START_DATE, force_create_records=True) # instantiate connection - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) diff --git a/tests/test_automatic_fields.py b/tests/test_automatic_fields.py index 62150f2..e96800f 100644 --- a/tests/test_automatic_fields.py +++ b/tests/test_automatic_fields.py @@ -62,7 +62,7 @@ def auto_fields_test(self, environment, data_type): ) # instantiate connection - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) diff --git a/tests/test_bookmarks.py b/tests/test_bookmarks.py index c08c46f..b84c30b 100644 --- a/tests/test_bookmarks.py +++ b/tests/test_bookmarks.py @@ -90,7 +90,7 @@ def bookmarks_test(self, testable_streams): expected_records_first_sync = self.create_test_data(testable_streams, self.START_DATE, force_create_records=True) # Instantiate connection with default start - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run in check mode check_job_name = runner.run_check_mode(self, conn_id) diff --git a/tests/test_bookmarks_cursor.py b/tests/test_bookmarks_cursor.py index 842570c..ff591a9 100644 --- a/tests/test_bookmarks_cursor.py +++ b/tests/test_bookmarks_cursor.py @@ -84,7 +84,7 @@ def bookmarks_test(self, testable_streams): ] # Create connection but do not use default start date - conn_id = connections.ensure_connection(self, original_properties=False) + conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) diff --git a/tests/test_bookmarks_static.py b/tests/test_bookmarks_static.py index 3603994..e911865 100644 --- a/tests/test_bookmarks_static.py +++ b/tests/test_bookmarks_static.py @@ -68,7 +68,7 @@ def test_run(self): expected_records_first_sync = self.create_test_data(self.testable_streams_static(), self.START_DATE) # Instantiate connection with default start - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run in check mode check_job_name = runner.run_check_mode(self, conn_id) diff --git a/tests/test_client.py b/tests/test_client.py index 9aed368..6e639a4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -38,6 +38,29 @@ def get_batch_token_from_headers(headers): return None +def require_new_access_token(access_token, environment): + """ + Checks if the access token needs to be refreshed + """ + # If there is no access token, we need to generate a new one + if not access_token: + return True + + if environment == "sandbox": + url = "https://connect.squareupsandbox.com/v2/locations" + else: + url = "https://connect.squareup.com/v2/locations" + + headers = {"Authorization": f"Bearer {access_token}"} + try: + response = requests.request("GET", url, headers=headers) + except Exception as e: + # If there is an error, we should generate a new access token + LOGGER.error(f"Error while validating access token: {e}") + return True + # If the response is a 401, we need to generate a new access token + return response.status_code == 401 + def log_backoff(details): ''' Logs a backoff retry message @@ -100,34 +123,39 @@ def __init__(self, env): self._environment = 'sandbox' if config.get('sandbox') == 'true' else 'production' - self._access_token = self._get_access_token() + self._access_token = self._get_access_token(env) self._client = Client(access_token=self._access_token, environment=self._environment) - def _get_access_token(self): - if "TAP_SQUARE_ACCESS_TOKEN" in os.environ.keys(): - LOGGER.info("Using access token from environment, not creating the new") - return os.environ["TAP_SQUARE_ACCESS_TOKEN"] + def _get_access_token(self, env): + """ + Retrieves the access token from the env. If the access token is expired, it will refresh it. + Otherwise, it will return the cached access token. + """ + access_token = os.environ.get("TAP_SQUARE_ACCESS_TOKEN") - body = { - 'client_id': self._client_id, - 'client_secret': self._client_secret, - 'grant_type': 'refresh_token', - 'refresh_token': self._refresh_token - } + # Check if the access token needs to be refreshed + if require_new_access_token(access_token, env): + body = { + 'client_id': self._client_id, + 'client_secret': self._client_secret, + 'grant_type': 'refresh_token', + 'refresh_token': self._refresh_token + } - client = Client(environment=self._environment) + client = Client(environment=self._environment) - with singer.http_request_timer('GET access token'): - result = client.o_auth.obtain_token(body) + with singer.http_request_timer('GET access token'): + result = client.o_auth.obtain_token(body) - if result.is_error(): - error_message = result.errors if result.errors else result.body - LOGGER.info("error_message :-----------: %s",error_message) - raise RuntimeError(error_message) + if result.is_error(): + error_message = result.errors if result.errors else result.body + LOGGER.info("error_message :-----------: %s",error_message) + raise RuntimeError(error_message) + + LOGGER.info("Generating new the access token to set in environment....") + os.environ["TAP_SQUARE_ACCESS_TOKEN"] = access_token = result.body['access_token'] - LOGGER.info("Setting the access token in environment....") - os.environ["TAP_SQUARE_ACCESS_TOKEN"] = result.body['access_token'] - return result.body['access_token'] + return access_token ########################################################################## ### V1 INFO diff --git a/tests/test_default_start_date.py b/tests/test_default_start_date.py index 929be3f..8cb5112 100644 --- a/tests/test_default_start_date.py +++ b/tests/test_default_start_date.py @@ -29,7 +29,7 @@ def run_standard_sync(self, environment, data_type, select_all_fields=True): Select all fields or no fields based on the select_all_fields param. Run a sync. """ - conn_id = connections.ensure_connection(self, original_properties=False) + conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token) found_catalogs = self.run_and_verify_check_mode(conn_id) diff --git a/tests/test_discovery.py b/tests/test_discovery.py index abee04b..200e308 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -57,7 +57,7 @@ def discovery_test(self): • verify that all other fields have inclusion of available (metadata and schema) """ conn_id = connections.ensure_connection(self) - check_job_name = runner.run_check_mode(self, conn_id) + check_job_name = runner.run_check_mode(self, conn_id, payload_hook=self.preserve_access_token) #verify check exit codes exit_status = menagerie.get_exit_status(conn_id, check_job_name) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 578ee33..95f62c7 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -92,7 +92,7 @@ def pagination_test(self): "{} does not have sufficient data in expecatations.\n ".format(stream)) # Create connection but do not use default start date - conn_id = connections.ensure_connection(self, original_properties=False) + conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) diff --git a/tests/test_start_date.py b/tests/test_start_date.py index 99f0efa..61d0fe6 100644 --- a/tests/test_start_date.py +++ b/tests/test_start_date.py @@ -76,7 +76,7 @@ def start_date_test(self, environment, data_type): ########################################################################## # instantiate connection - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) @@ -107,7 +107,7 @@ def start_date_test(self, environment, data_type): ########################################################################## # create a new connection with the new start_date - conn_id = connections.ensure_connection(self, original_properties=False) + conn_id = connections.ensure_connection(self, original_properties=False, payload_hook=self.preserve_access_token) # run check mode found_catalogs = self.run_and_verify_check_mode(conn_id) diff --git a/tests/test_sync_canary.py b/tests/test_sync_canary.py index 36e63ac..a87f8e4 100644 --- a/tests/test_sync_canary.py +++ b/tests/test_sync_canary.py @@ -26,7 +26,7 @@ def run_standard_sync(self, environment, data_type, select_all_fields=True): Select all fields or no fields based on the select_all_fields param. Run a sync. """ - conn_id = connections.ensure_connection(self) + conn_id = connections.ensure_connection(self, payload_hook=self.preserve_access_token) found_catalogs = self.run_and_verify_check_mode(conn_id)