diff --git a/kaggle/api/kaggle_api_extended.py b/kaggle/api/kaggle_api_extended.py index deb19e8..9990d1c 100644 --- a/kaggle/api/kaggle_api_extended.py +++ b/kaggle/api/kaggle_api_extended.py @@ -1,19 +1,19 @@ -#!/usr/bin/python -# -# Copyright 2024 Kaggle Inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +#!/usr/bin/python +# +# Copyright 2024 Kaggle Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + #!/usr/bin/python # # Copyright 2019 Kaggle Inc @@ -46,7 +46,9 @@ import zipfile import tempfile -from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest +from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, ApiListDatasetFilesRequest, \ + ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \ + ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, ApiDatasetNewFile from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, DatasetSortBy from ..api_client import ApiClient from kaggle.configuration import Configuration @@ -371,6 +373,7 @@ class KaggleApi(KaggleApi): 'ref', 'title', 'size', 'lastUpdated', 'downloadCount', 'voteCount', 'usabilityRating' ] + dataset_file_fields = ['name', 'size', 'creationDate'] # Hack for https://github.com/Kaggle/kaggle-api/issues/22 / b/78194015 if six.PY2: @@ -1383,26 +1386,28 @@ def dataset_metadata_cli(self, dataset, path, update, dataset_opt=None): def dataset_list_files(self, dataset, page_token=None, page_size=20): """ List files for a dataset. 
- Parameters - ========== - dataset: the string identified of the dataset - should be in format [owner]/[dataset-name] - page_token: the page token for pagination - page_size: the number of items per page - """ + + Parameters + ========== + dataset: the string identified of the dataset + should be in format [owner]/[dataset-name] + page_token: the page token for pagination + page_size: the number of items per page + """ if dataset is None: raise ValueError('A dataset must be specified') owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string( - dataset) + dataset) - dataset_list_files_result = self.process_response( - self.datasets_list_files_with_http_info( - owner_slug=owner_slug, - dataset_slug=dataset_slug, - dataset_version_number=dataset_version_number, - page_token=page_token, - page_size=page_size)) - return ListFilesResult(dataset_list_files_result) + with self.build_kaggle_client() as kaggle: + request = ApiListDatasetFilesRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + request.dataset_version_number = dataset_version_number + request.page_token = page_token + request.page_size = page_size + response = kaggle.datasets.dataset_api_client.list_dataset_files(request) + return ListFilesResult(response) def dataset_list_files_cli(self, dataset, @@ -1441,11 +1446,11 @@ def dataset_list_files_cli(self, def dataset_status(self, dataset): """ Call to get the status of a dataset from the API. - Parameters - ========== - dataset: the string identifier of the dataset - should be in format [owner]/[dataset-name] - """ + Parameters + ========== + dataset: the string identifier of the dataset + should be in format [owner]/[dataset-name] + """ if dataset is None: raise ValueError('A dataset must be specified') if '/' in dataset: @@ -1456,10 +1461,13 @@ def dataset_status(self, dataset): else: owner_slug = self.get_config_value(self.CONFIG_NAME_USER) dataset_slug = dataset - dataset_status_result = self.process_response( - self.datasets_status_with_http_info( - owner_slug=owner_slug, dataset_slug=dataset_slug)) - return dataset_status_result + + with self.build_kaggle_client() as kaggle: + request = ApiGetDatasetStatusRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + response = kaggle.datasets.dataset_api_client.get_dataset_status(request) + return response.status.name.lower() def dataset_status_cli(self, dataset, dataset_opt=None): """ A wrapper for client for dataset_status, with additional @@ -1480,43 +1488,44 @@ def dataset_download_file(self, licenses=[]): """ Download a single file for a dataset. - Parameters - ========== - dataset: the string identified of the dataset - should be in format [owner]/[dataset-name] - file_name: the dataset configuration file - path: if defined, download to this location - force: force the download if the file already exists (default False) - quiet: suppress verbose output (default is True) - licenses: a list of license names, e.g. ['CC0-1.0'] - """ + Parameters + ========== + dataset: the string identified of the dataset + should be in format [owner]/[dataset-name] + file_name: the dataset configuration file + path: if defined, download to this location + force: force the download if the file already exists (default False) + quiet: suppress verbose output (default is True) + licenses: a list of license names, e.g. 
['CC0-1.0'] + """ if '/' in dataset: self.validate_dataset_string(dataset) owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string( - dataset) + dataset) else: owner_slug = self.get_config_value(self.CONFIG_NAME_USER) dataset_slug = dataset dataset_version_number = None if path is None: - effective_path = self.get_default_download_dir('datasets', owner_slug, - dataset_slug) + effective_path = self.get_default_download_dir( + 'datasets', owner_slug, dataset_slug) else: effective_path = path self._print_dataset_url_and_license(owner_slug, dataset_slug, dataset_version_number, licenses) - response = self.process_response( - self.datasets_download_file_with_http_info( - owner_slug=owner_slug, - dataset_slug=dataset_slug, - dataset_version_number=dataset_version_number, - file_name=file_name, - _preload_content=False)) - url = response.retries.history[0].redirect_location.split('?')[0] - outfile = os.path.join(effective_path, url.split('/')[-1]) + with self.build_kaggle_client() as kaggle: + request = ApiDownloadDatasetRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + request.dataset_version_number = dataset_version_number + request.file_name = file_name + response = kaggle.datasets.dataset_api_client.download_dataset(request) + url = response.history[0].url + outfile = os.path.join(effective_path, url.split('?')[0].split('/')[-1]) + if force or self.download_needed(response, outfile, quiet): self.download_file(response, outfile, quiet, not force) return True @@ -1532,35 +1541,35 @@ def dataset_download_files(self, licenses=[]): """ Download all files for a dataset. - Parameters - ========== - dataset: the string identified of the dataset - should be in format [owner]/[dataset-name] - path: the path to download the dataset to - force: force the download if the file already exists (default False) - quiet: suppress verbose output (default is True) - unzip: if True, unzip files upon download (default is False) - licenses: a list of license names, e.g. ['CC0-1.0'] - """ + Parameters + ========== + dataset: the string identified of the dataset + should be in format [owner]/[dataset-name] + path: the path to download the dataset to + force: force the download if the file already exists (default False) + quiet: suppress verbose output (default is True) + unzip: if True, unzip files upon download (default is False) + licenses: a list of license names, e.g. 
['CC0-1.0'] + """ if dataset is None: raise ValueError('A dataset must be specified') owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string( - dataset) + dataset) if path is None: - effective_path = self.get_default_download_dir('datasets', owner_slug, - dataset_slug) + effective_path = self.get_default_download_dir( + 'datasets', owner_slug, dataset_slug) else: effective_path = path self._print_dataset_url_and_license(owner_slug, dataset_slug, dataset_version_number, licenses) - response = self.process_response( - self.datasets_download_with_http_info( - owner_slug=owner_slug, - dataset_slug=dataset_slug, - dataset_version_number=dataset_version_number, - _preload_content=False)) + with self.build_kaggle_client() as kaggle: + request = ApiDownloadDatasetRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + request.dataset_version_number = dataset_version_number + response = kaggle.datasets.dataset_api_client.download_dataset(request) outfile = os.path.join(effective_path, dataset_slug + '.zip') if force or self.download_needed(response, outfile, quiet): @@ -1577,18 +1586,18 @@ def dataset_download_files(self, z.extractall(effective_path) except zipfile.BadZipFile as e: raise ValueError( - f"The file {outfile} is corrupted or not a valid zip file. " - "Please report this issue at https://www.github.com/kaggle/kaggle-api" + f"The file {outfile} is corrupted or not a valid zip file. " + "Please report this issue at https://www.github.com/kaggle/kaggle-api" ) except FileNotFoundError: raise FileNotFoundError( - f"The file {outfile} was not found. " - "Please report this issue at https://www.github.com/kaggle/kaggle-api" + f"The file {outfile} was not found. " + "Please report this issue at https://www.github.com/kaggle/kaggle-api" ) except Exception as e: raise RuntimeError( - f"An unexpected error occurred: {e}. " - "Please report this issue at https://www.github.com/kaggle/kaggle-api" + f"An unexpected error occurred: {e}. " + "Please report this issue at https://www.github.com/kaggle/kaggle-api" ) try: @@ -1616,9 +1625,9 @@ def dataset_download_cli(self, unzip=False, force=False, quiet=False): - """ client wrapper for dataset_download_files and download dataset file, + """ Client wrapper for dataset_download_files and download dataset file, either for a specific file (when file_name is provided), - or all files for a dataset (plural) + or all files for a dataset (plural). Parameters ========== @@ -1668,7 +1677,7 @@ def dataset_download_cli(self, licenses=licenses) def _upload_blob(self, path, quiet, blob_type, upload_context): - """ upload a file + """ Upload a file. Parameters ========== @@ -1722,7 +1731,7 @@ def dataset_create_version(self, convert_to_csv=True, delete_old_versions=False, dir_mode='skip'): - """ create a version of a dataset + """ Create a version of a dataset. 
Parameters ========== @@ -1745,6 +1754,8 @@ def dataset_create_version(self, id_no = self.get_or_default(meta_data, 'id_no', None) if not ref and not id_no: raise ValueError('ID or slug must be specified in the metadata') + elif ref and ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': + raise ValueError('Default slug detected, please change values before uploading') subtitle = meta_data.get('subtitle') if subtitle and (len(subtitle) < 20 or len(subtitle) > 80): @@ -1756,14 +1767,35 @@ def dataset_create_version(self, description = meta_data.get('description') keywords = self.get_or_default(meta_data, 'keywords', []) - request = DatasetNewVersionRequest( - version_notes=version_notes, - subtitle=subtitle, - description=description, - files=[], - convert_to_csv=convert_to_csv, - category_ids=keywords, - delete_old_versions=delete_old_versions) + body = ApiCreateDatasetVersionRequestBody() + body.version_notes=version_notes + body.subtitle=subtitle + body.description=description + body.files=[] + body.category_ids=keywords + body.delete_old_versions=delete_old_versions + + with self.build_kaggle_client() as kaggle: + if id_no: + request = ApiCreateDatasetVersionByIdRequest() + request.id = id_no + message = kaggle.datasets.dataset_api_client.create_dataset_version_by_id + else: + self.validate_dataset_string(ref) + ref_list = ref.split('/') + owner_slug = ref_list[0] + dataset_slug = ref_list[1] + request = ApiCreateDatasetVersionRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + message = kaggle.datasets.dataset_api_client.create_dataset_version + request.body = body + with ResumableUploadContext() as upload_context: + self.upload_files(body, resources, folder, ApiBlobType.DATASET, + upload_context, quiet, dir_mode) + request.body.files = [self._api_dataset_new_file(file) for file in request.body.files] + response = self.with_retry(message)(request) + return response with ResumableUploadContext() as upload_context: self.upload_files(request, resources, folder, ApiBlobType.DATASET, @@ -1776,10 +1808,8 @@ def dataset_create_version(self, self.datasets_create_version_by_id_with_http_info)( id_no, request))) else: - if ref == self.config_values[ - self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': - raise ValueError('Default slug detected, please change values before ' - 'uploading') + if ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': + raise ValueError('Default slug detected, please change values before uploading') self.validate_dataset_string(ref) ref_list = ref.split('/') owner_slug = ref_list[0] @@ -1791,6 +1821,12 @@ def dataset_create_version(self, return result + def _api_dataset_new_file(self, file): + # TODO Eliminate the need for this conversion + f = ApiDatasetNewFile() + f.token = file.token + return f + def dataset_create_version_cli(self, folder, version_notes, @@ -1861,16 +1897,17 @@ def dataset_create_new(self, quiet=False, convert_to_csv=True, dir_mode='skip'): - """ create a new dataset, meaning the same as creating a version but - with extra metadata like license and user/owner. - Parameters - ========== - folder: the folder to get the metadata file from - public: should the dataset be public? 
- quiet: suppress verbose output (default is False) - convert_to_csv: if True, convert data to comma separated value - dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload - """ + """ Create a new dataset, meaning the same as creating a version but + with extra metadata like license and user/owner. + + Parameters + ========== + folder: the folder to get the metadata file from + public: should the dataset be public? + quiet: suppress verbose output (default is False) + convert_to_csv: if True, convert data to comma separated value + dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload + """ if not os.path.isdir(folder): raise ValueError('Invalid folder: ' + folder) @@ -1887,18 +1924,22 @@ def dataset_create_new(self, dataset_slug = ref_list[1] # validations - if ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': + if ref == self.config_values[ + self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': raise ValueError( - 'Default slug detected, please change values before uploading') + 'Default slug detected, please change values before uploading') if title == 'INSERT_TITLE_HERE': raise ValueError( - 'Default title detected, please change values before uploading') + 'Default title detected, please change values before uploading' + ) if len(licenses) != 1: raise ValueError('Please specify exactly one license') if len(dataset_slug) < 6 or len(dataset_slug) > 50: - raise ValueError('The dataset slug must be between 6 and 50 characters') + raise ValueError( + 'The dataset slug must be between 6 and 50 characters') if len(title) < 6 or len(title) > 50: - raise ValueError('The dataset title must be between 6 and 50 characters') + raise ValueError( + 'The dataset title must be between 6 and 50 characters') resources = meta_data.get('resources') if resources: self.validate_resources(folder, resources) @@ -1909,27 +1950,44 @@ def dataset_create_new(self, subtitle = meta_data.get('subtitle') if subtitle and (len(subtitle) < 20 or len(subtitle) > 80): - raise ValueError('Subtitle length must be between 20 and 80 characters') - - request = DatasetNewRequest( - title=title, - slug=dataset_slug, - owner_slug=owner_slug, - license_name=license_name, - subtitle=subtitle, - description=description, - files=[], - is_private=not public, - convert_to_csv=convert_to_csv, - category_ids=keywords) + raise ValueError( + 'Subtitle length must be between 20 and 80 characters') + + request = DatasetNewRequest(title=title, + slug=dataset_slug, + owner_slug=owner_slug, + license_name=license_name, + subtitle=subtitle, + description=description, + files=[], + is_private=not public, + convert_to_csv=convert_to_csv, + category_ids=keywords) with ResumableUploadContext() as upload_context: + # TODO Change upload_files() to use ApiCreateDatasetRequest self.upload_files(request, resources, folder, ApiBlobType.DATASET, upload_context, quiet, dir_mode) + + with self.build_kaggle_client() as kaggle: + retry_request = ApiCreateDatasetRequest() + retry_request.title=title + retry_request.slug=dataset_slug + retry_request.owner_slug=owner_slug + retry_request.license_name=license_name + retry_request.subtitle=subtitle + retry_request.description=description + retry_request.files=[] + retry_request.is_private=not public + retry_request.category_ids=keywords + response = self.with_retry( + kaggle.datasets.dataset_api_client.create_dataset)(retry_request) + return response + result = DatasetNewResponse( - self.process_response( - self.with_retry( - 
self.datasets_create_new_with_http_info)(request))) + self.process_response( + self.with_retry( + self.datasets_create_new_with_http_info)(request))) return result diff --git a/kaggle/models/kaggle_models_extended.py b/kaggle/models/kaggle_models_extended.py index 8a39271..19ae9e3 100644 --- a/kaggle/models/kaggle_models_extended.py +++ b/kaggle/models/kaggle_models_extended.py @@ -1,19 +1,19 @@ -#!/usr/bin/python -# -# Copyright 2024 Kaggle Inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +#!/usr/bin/python +# +# Copyright 2024 Kaggle Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + #!/usr/bin/python # # Copyright 2019 Kaggle Inc @@ -131,9 +131,14 @@ def __repr__(self): class File(object): def __init__(self, init_dict): - parsed_dict = {k: parse(v) for k, v in init_dict.items()} - self.__dict__.update(parsed_dict) - self.size = File.get_size(self.totalBytes) + try: # TODO Remove try-block + parsed_dict = {k: parse(v) for k, v in init_dict.items()} + self.__dict__.update(parsed_dict) + self.size = File.get_size(self.totalBytes) + except AttributeError: + self.name = init_dict.name + self.creation_date = init_dict.creation_date + self.size = File.get_size(init_dict.total_bytes) def __repr__(self): return self.name @@ -181,13 +186,18 @@ def __repr__(self): class ListFilesResult(object): def __init__(self, init_dict): - self.error_message = init_dict['errorMessage'] - files = init_dict['datasetFiles'] + try: # TODO Remove try-block + self.error_message = init_dict['errorMessage'] + files = init_dict['datasetFiles'] + token = init_dict['nextPageToken'] + except TypeError: + self.error_message = init_dict.error_message + files = init_dict.dataset_files + token = init_dict.next_page_token if files: self.files = [File(f) for f in files] else: self.files = {} - token = init_dict['nextPageToken'] if token: self.nextPageToken = token else: diff --git a/kagglesdk/datasets/types/dataset_api_service.py b/kagglesdk/datasets/types/dataset_api_service.py index b190b52..c284871 100644 --- a/kagglesdk/datasets/types/dataset_api_service.py +++ b/kagglesdk/datasets/types/dataset_api_service.py @@ -176,6 +176,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return '*' + class ApiCreateDatasetResponse(KaggleObject): r""" Attributes: @@ -310,6 +314,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return 'body' + class ApiCreateDatasetVersionRequest(KaggleObject): r""" Attributes: @@ -373,6 +381,10 @@ def endpoint(self): def 
method(): return 'POST' + @staticmethod + def body_fields(): + return 'body' + class ApiCreateDatasetVersionRequestBody(KaggleObject): r""" Attributes: @@ -2080,6 +2092,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return 'settings' + class ApiUpdateDatasetMetadataResponse(KaggleObject): r""" Attributes: diff --git a/kagglesdk/kaggle_http_client.py b/kagglesdk/kaggle_http_client.py index e2fcd22..b4a7069 100644 --- a/kagglesdk/kaggle_http_client.py +++ b/kagglesdk/kaggle_http_client.py @@ -34,6 +34,16 @@ def _get_apikey_creds(): api_key = api_key_data['key'] return username, api_key +def clean_data(data): + if isinstance(data, dict): + return {k: clean_data(v) for k, v in data.items() if v is not None} + if isinstance(data, list): + return [clean_data(v) for v in data if v is not None] + if data is True: + return 'true' + if data is False: + return 'false' + return data class KaggleHttpClient(object): _xsrf_cookie_name = 'XSRF-TOKEN' @@ -75,6 +85,19 @@ def _prepare_request(self, service_name: str, request_name: str, request: Kaggle 'Accept': 'application/json', 'Content-Type': 'text/plain', }) + elif method == 'POST': + self._session.headers.update({ + 'Accept': 'application/json', + 'Content-Type': 'application/json', + }) + if isinstance(data, dict): + fields = request.body_fields() + if fields is not None: + if fields != '*': + data = data[fields] + data = clean_data(data) + data = data.__str__().replace("'", '"') + # TODO Remove quotes from numbers. http_request = requests.Request( method=method, url=request_url, diff --git a/kagglesdk/kaggle_object.py b/kagglesdk/kaggle_object.py index 04f61b5..ac3235e 100644 --- a/kagglesdk/kaggle_object.py +++ b/kagglesdk/kaggle_object.py @@ -205,6 +205,10 @@ class KaggleObject(object): def endpoint(self): raise 'Error: endpoint must be defined by the request object' + @staticmethod + def body_fields(): + return None + @classmethod def prepare_from(cls, http_response): return cls.from_json(http_response.text) @@ -229,7 +233,7 @@ def to_dict(self, ignore_defaults=True): @staticmethod def to_field_map(self, ignore_defaults=True): kv_pairs = [(field.field_name, field.get_as_dict_item(self, ignore_defaults)) for field in self._fields] - return {k: v for (k, v) in kv_pairs if not ignore_defaults or v is not None} + return {k: str(v) for (k, v) in kv_pairs if not ignore_defaults or v is not None} @classmethod def from_dict(cls, json_dict): diff --git a/kagglesdk/kernels/types/kernels_api_service.py b/kagglesdk/kernels/types/kernels_api_service.py index e0854a9..224a17e 100644 --- a/kagglesdk/kernels/types/kernels_api_service.py +++ b/kagglesdk/kernels/types/kernels_api_service.py @@ -1365,6 +1365,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return '*' + class ApiSaveKernelResponse(KaggleObject): r""" Attributes: diff --git a/kagglesdk/models/types/model_api_service.py b/kagglesdk/models/types/model_api_service.py index ce9aa1d..d1b39bc 100644 --- a/kagglesdk/models/types/model_api_service.py +++ b/kagglesdk/models/types/model_api_service.py @@ -3,7 +3,7 @@ from kagglesdk.datasets.types.dataset_api_service import ApiCategory, ApiDatasetNewFile, ApiUploadDirectoryInfo from kagglesdk.kaggle_object import * from kagglesdk.models.types.model_enums import ListModelsOrderBy, ModelFramework, ModelInstanceType -from kagglesdk.models.types.model_types import BaseModelInstanceInformation +from kagglesdk.models.types.model_types import BaseModelInstanceInformation, 
ModelLink from typing import Optional, List class ApiCreateModelInstanceRequest(KaggleObject): @@ -69,6 +69,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return 'body' + class ApiCreateModelInstanceRequestBody(KaggleObject): r""" Attributes: @@ -357,6 +361,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return 'body' + class ApiCreateModelInstanceVersionRequestBody(KaggleObject): r""" Attributes: @@ -553,6 +561,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return '*' + class ApiCreateModelResponse(KaggleObject): r""" Attributes: @@ -1280,6 +1292,8 @@ class ApiListModelsRequest(KaggleObject): Page size. page_token (str) Page token used for pagination. + only_vertex_models (bool) + Only list models that have Vertex URLs """ def __init__(self): @@ -1288,6 +1302,7 @@ def __init__(self): self._owner = None self._page_size = None self._page_token = None + self._only_vertex_models = None self._freeze() @property @@ -1363,6 +1378,20 @@ def page_token(self, page_token: str): raise TypeError('page_token must be of type str') self._page_token = page_token + @property + def only_vertex_models(self) -> bool: + """Only list models that have Vertex URLs""" + return self._only_vertex_models or False + + @only_vertex_models.setter + def only_vertex_models(self, only_vertex_models: bool): + if only_vertex_models is None: + del self.only_vertex_models + return + if not isinstance(only_vertex_models, bool): + raise TypeError('only_vertex_models must be of type bool') + self._only_vertex_models = only_vertex_models + def endpoint(self): path = '/api/v1/models/list' @@ -1441,6 +1470,7 @@ class ApiModel(KaggleObject): publish_time (datetime) provenance_sources (str) url (str) + model_version_links (ModelLink) """ def __init__(self): @@ -1457,6 +1487,7 @@ def __init__(self): self._publish_time = None self._provenance_sources = "" self._url = "" + self._model_version_links = [] self._freeze() @property @@ -1633,6 +1664,21 @@ def url(self, url: str): raise TypeError('url must be of type str') self._url = url + @property + def model_version_links(self) -> Optional[List[Optional['ModelLink']]]: + return self._model_version_links + + @model_version_links.setter + def model_version_links(self, model_version_links: Optional[List[Optional['ModelLink']]]): + if model_version_links is None: + del self.model_version_links + return + if not isinstance(model_version_links, list): + raise TypeError('model_version_links must be of type list') + if not all([isinstance(t, ModelLink) for t in model_version_links]): + raise TypeError('model_version_links must contain only items of type ModelLink') + self._model_version_links = model_version_links + class ApiModelFile(KaggleObject): r""" @@ -2139,6 +2185,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return '*' + class ApiUpdateModelRequest(KaggleObject): r""" Attributes: @@ -2292,6 +2342,10 @@ def endpoint(self): def method(): return 'POST' + @staticmethod + def body_fields(): + return '*' + class ApiUpdateModelResponse(KaggleObject): r""" Attributes: @@ -2587,6 +2641,7 @@ def create_url(self, create_url: str): FieldMetadata("owner", "owner", "_owner", str, None, PredefinedSerializer(), optional=True), FieldMetadata("pageSize", "page_size", "_page_size", int, None, PredefinedSerializer(), optional=True), FieldMetadata("pageToken", "page_token", "_page_token", str, None, PredefinedSerializer(), 
optional=True), + FieldMetadata("onlyVertexModels", "only_vertex_models", "_only_vertex_models", bool, None, PredefinedSerializer(), optional=True), ] ApiListModelsResponse._fields = [ @@ -2609,6 +2664,7 @@ def create_url(self, create_url: str): FieldMetadata("publishTime", "publish_time", "_publish_time", datetime, None, DateTimeSerializer()), FieldMetadata("provenanceSources", "provenance_sources", "_provenance_sources", str, "", PredefinedSerializer()), FieldMetadata("url", "url", "_url", str, "", PredefinedSerializer()), + FieldMetadata("modelVersionLinks", "model_version_links", "_model_version_links", ModelLink, [], ListSerializer(KaggleObjectSerializer())), ] ApiModelFile._fields = [ diff --git a/kagglesdk/models/types/model_enums.py b/kagglesdk/models/types/model_enums.py index a7f488e..6f0f2c2 100644 --- a/kagglesdk/models/types/model_enums.py +++ b/kagglesdk/models/types/model_enums.py @@ -42,3 +42,8 @@ class ModelInstanceType(enum.Enum): MODEL_INSTANCE_TYPE_KAGGLE_VARIANT = 2 MODEL_INSTANCE_TYPE_EXTERNAL_VARIANT = 3 +class ModelVersionLinkType(enum.Enum): + MODEL_VERSION_LINK_TYPE_UNSPECIFIED = 0 + MODEL_VERSION_LINK_TYPE_VERTEX_OPEN = 1 + MODEL_VERSION_LINK_TYPE_VERTEX_DEPLOY = 2 + diff --git a/kagglesdk/models/types/model_types.py b/kagglesdk/models/types/model_types.py index 96a4931..bdfeedc 100644 --- a/kagglesdk/models/types/model_types.py +++ b/kagglesdk/models/types/model_types.py @@ -1,5 +1,5 @@ from kagglesdk.kaggle_object import * -from kagglesdk.models.types.model_enums import ModelFramework +from kagglesdk.models.types.model_enums import ModelFramework, ModelVersionLinkType from kagglesdk.users.types.users_enums import UserAchievementTier from typing import Optional @@ -87,6 +87,45 @@ def framework(self, framework: 'ModelFramework'): self._framework = framework +class ModelLink(KaggleObject): + r""" + Attributes: + type (ModelVersionLinkType) + url (str) + """ + + def __init__(self): + self._type = ModelVersionLinkType.MODEL_VERSION_LINK_TYPE_UNSPECIFIED + self._url = "" + self._freeze() + + @property + def type(self) -> 'ModelVersionLinkType': + return self._type + + @type.setter + def type(self, type: 'ModelVersionLinkType'): + if type is None: + del self.type + return + if not isinstance(type, ModelVersionLinkType): + raise TypeError('type must be of type ModelVersionLinkType') + self._type = type + + @property + def url(self) -> str: + return self._url + + @url.setter + def url(self, url: str): + if url is None: + del self.url + return + if not isinstance(url, str): + raise TypeError('url must be of type str') + self._url = url + + class Owner(KaggleObject): r""" Based off Datasets OwnerDto as the permission model is the same @@ -214,6 +253,11 @@ def user_tier(self, user_tier: 'UserAchievementTier'): FieldMetadata("framework", "framework", "_framework", ModelFramework, ModelFramework.MODEL_FRAMEWORK_UNSPECIFIED, EnumSerializer()), ] +ModelLink._fields = [ + FieldMetadata("type", "type", "_type", ModelVersionLinkType, ModelVersionLinkType.MODEL_VERSION_LINK_TYPE_UNSPECIFIED, EnumSerializer()), + FieldMetadata("url", "url", "_url", str, "", PredefinedSerializer()), +] + Owner._fields = [ FieldMetadata("id", "id", "_id", int, 0, PredefinedSerializer()), FieldMetadata("imageUrl", "image_url", "_image_url", str, None, PredefinedSerializer(), optional=True), diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index 9d60b23..698af11 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ 
-30,7 +30,9 @@ import zipfile import tempfile -from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest +from kagglesdk.datasets.types.dataset_api_service import ApiListDatasetsRequest, ApiListDatasetFilesRequest, \ + ApiGetDatasetStatusRequest, ApiDownloadDatasetRequest, ApiCreateDatasetRequest, ApiCreateDatasetVersionRequestBody, \ + ApiCreateDatasetVersionByIdRequest, ApiCreateDatasetVersionRequest, ApiDatasetNewFile from kagglesdk.datasets.types.dataset_enums import DatasetSelectionGroup, DatasetSortBy from ..api_client import ApiClient from kaggle.configuration import Configuration @@ -100,7 +102,7 @@ def __enter__(self): self._temp_dir = tempfile.mkdtemp() _, dir_name = os.path.split(self._fullpath) self.path = shutil.make_archive( - os.path.join(self._temp_dir, dir_name), self._format, self._fullpath) + os.path.join(self._temp_dir, dir_name), self._format, self._fullpath) _, self.name = os.path.split(self.path) return self @@ -133,8 +135,8 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def get_upload_info_file_path(self, path): return os.path.join( - self._temp_dir, - '%s.json' % path.replace(os.path.sep, '_').replace(':', '_')) + self._temp_dir, + '%s.json' % path.replace(os.path.sep, '_').replace(':', '_')) def new_resumable_file_upload(self, path, start_blob_upload_request): file_upload = ResumableFileUpload(path, start_blob_upload_request, self) @@ -192,8 +194,8 @@ def _load_previous_if_any(self): def _is_previous_valid(self, previous): return previous.path == self.path and \ - previous.start_blob_upload_request == self.start_blob_upload_request and \ - previous.timestamp > time.time() - ResumableFileUpload.RESUMABLE_UPLOAD_EXPIRY_SECONDS + previous.start_blob_upload_request == self.start_blob_upload_request and \ + previous.timestamp > time.time() - ResumableFileUpload.RESUMABLE_UPLOAD_EXPIRY_SECONDS def upload_initiated(self, start_blob_upload_response): if self.context.no_resume: @@ -225,28 +227,28 @@ def cleanup(self): def to_dict(self): return { - 'path': - self.path, - 'start_blob_upload_request': - self.start_blob_upload_request.to_dict(), - 'timestamp': - self.timestamp, - 'start_blob_upload_response': - self.start_blob_upload_response.to_dict() - if self.start_blob_upload_response is not None else None, - 'upload_complete': - self.upload_complete, + 'path': + self.path, + 'start_blob_upload_request': + self.start_blob_upload_request.to_dict(), + 'timestamp': + self.timestamp, + 'start_blob_upload_response': + self.start_blob_upload_response.to_dict() + if self.start_blob_upload_response is not None else None, + 'upload_complete': + self.upload_complete, } def from_dict(other, context): new = ResumableFileUpload( - other['path'], - StartBlobUploadRequest(**other['start_blob_upload_request']), context) + other['path'], + StartBlobUploadRequest(**other['start_blob_upload_request']), context) new.timestamp = other.get('timestamp') start_blob_upload_response = other.get('start_blob_upload_response') if start_blob_upload_response is not None: new.start_blob_upload_response = StartBlobUploadResponse( - **start_blob_upload_response) + **start_blob_upload_response) new.upload_complete = other.get('upload_complete') or False return new @@ -304,31 +306,31 @@ class KaggleApi(KaggleApi): valid_list_kernel_types = ['all', 'script', 'notebook'] valid_list_output_types = ['all', 'visualization', 'data'] valid_list_sort_by = [ - 'hotness', 'commentCount', 'dateCreated', 'dateRun', 'relevance', - 'scoreAscending', 'scoreDescending', 'viewCount', 
'voteCount' + 'hotness', 'commentCount', 'dateCreated', 'dateRun', 'relevance', + 'scoreAscending', 'scoreDescending', 'viewCount', 'voteCount' ] # Competitions valid types valid_competition_groups = ['general', 'entered', 'inClass'] valid_competition_categories = [ - 'all', 'featured', 'research', 'recruitment', 'gettingStarted', 'masters', - 'playground' + 'all', 'featured', 'research', 'recruitment', 'gettingStarted', 'masters', + 'playground' ] valid_competition_sort_by = [ - 'grouped', 'prize', 'earliestDeadline', 'latestDeadline', 'numberOfTeams', - 'recentlyCreated' + 'grouped', 'prize', 'earliestDeadline', 'latestDeadline', 'numberOfTeams', + 'recentlyCreated' ] # Datasets valid types valid_dataset_file_types = ['all', 'csv', 'sqlite', 'json', 'bigQuery'] valid_dataset_license_names = ['all', 'cc', 'gpl', 'odb', 'other'] valid_dataset_sort_bys = [ - 'hottest', 'votes', 'updated', 'active', 'published' + 'hottest', 'votes', 'updated', 'active', 'published' ] # Models valid types valid_model_sort_bys = [ - 'hotness', 'downloadCount', 'voteCount', 'notebookCount', 'createTime' + 'hotness', 'downloadCount', 'voteCount', 'notebookCount', 'createTime' ] # Command prefixes that are valid without authentication. @@ -337,24 +339,25 @@ class KaggleApi(KaggleApi): # Attributes competition_fields = [ - 'ref', 'deadline', 'category', 'reward', 'teamCount', 'userHasEntered' + 'ref', 'deadline', 'category', 'reward', 'teamCount', 'userHasEntered' ] submission_fields = [ - 'fileName', 'date', 'description', 'status', 'publicScore', 'privateScore' + 'fileName', 'date', 'description', 'status', 'publicScore', 'privateScore' ] competition_file_fields = ['name', 'totalBytes', 'creationDate'] competition_file_labels = ['name', 'size', 'creationDate'] competition_leaderboard_fields = [ - 'teamId', 'teamName', 'submissionDate', 'score' + 'teamId', 'teamName', 'submissionDate', 'score' ] dataset_fields = [ - 'ref', 'title', 'totalBytes', 'lastUpdated', 'downloadCount', 'voteCount', - 'usabilityRating' + 'ref', 'title', 'totalBytes', 'lastUpdated', 'downloadCount', 'voteCount', + 'usabilityRating' ] dataset_labels = [ - 'ref', 'title', 'size', 'lastUpdated', 'downloadCount', 'voteCount', - 'usabilityRating' + 'ref', 'title', 'size', 'lastUpdated', 'downloadCount', 'voteCount', + 'usabilityRating' ] + dataset_file_fields = ['name', 'size', 'creationDate'] # Hack for https://github.com/Kaggle/kaggle-api/issues/22 / b/78194015 if six.PY2: @@ -363,11 +366,11 @@ class KaggleApi(KaggleApi): def _is_retriable(self, e): return issubclass(type(e), ConnectionError) or \ - issubclass(type(e), urllib3_exceptions.ConnectionError) or \ - issubclass(type(e), urllib3_exceptions.ConnectTimeoutError) or \ - issubclass(type(e), urllib3_exceptions.ProtocolError) or \ - issubclass(type(e), requests.exceptions.ConnectionError) or \ - issubclass(type(e), requests.exceptions.ConnectTimeout) + issubclass(type(e), urllib3_exceptions.ConnectionError) or \ + issubclass(type(e), urllib3_exceptions.ConnectTimeoutError) or \ + issubclass(type(e), urllib3_exceptions.ProtocolError) or \ + issubclass(type(e), requests.exceptions.ConnectionError) or \ + issubclass(type(e), requests.exceptions.ConnectTimeout) def _calculate_backoff_delay(self, attempt, initial_delay_millis, retry_multiplier, randomness_factor): @@ -418,12 +421,12 @@ def authenticate(self): # Step 2: if credentials were not in env read in configuration file if self.CONFIG_NAME_USER not in config_data \ - or self.CONFIG_NAME_KEY not in config_data: + or self.CONFIG_NAME_KEY 
not in config_data: if os.path.exists(self.config): config_data = self.read_config_file(config_data) elif self._is_help_or_version_command(api_command) or (len( sys.argv) > 2 and api_command.startswith( - self.command_prefixes_allowing_anonymous_access)): + self.command_prefixes_allowing_anonymous_access)): # Some API commands should be allowed without authentication. return else: @@ -431,7 +434,7 @@ def authenticate(self): ' {}. Or use the environment method. See setup' ' instructions at' ' https://github.com/Kaggle/kaggle-api/'.format( - self.config_file, self.config_dir)) + self.config_file, self.config_dir)) # Step 3: load into configuration! self._load_config(config_data) @@ -519,10 +522,10 @@ def _load_config(self, config_data): ' is not valid, please check your proxy settings') else: raise ValueError( - 'Unauthorized: you must download an API key or export ' - 'credentials to the environment. Please see\n ' + - 'https://github.com/Kaggle/kaggle-api#api-credentials ' + - 'for instructions.') + 'Unauthorized: you must download an API key or export ' + 'credentials to the environment. Please see\n ' + + 'https://github.com/Kaggle/kaggle-api#api-credentials ' + + 'for instructions.') def read_config_file(self, config_data=None, quiet=False): """read_config_file is the first effort to get a username @@ -690,16 +693,16 @@ def print_config_values(self, prefix='- '): def build_kaggle_client(self): env = KaggleEnv.STAGING if '--staging' in self.args \ - else KaggleEnv.ADMIN if '--admin' in self.args \ - else KaggleEnv.LOCAL if '--local' in self.args \ - else KaggleEnv.PROD + else KaggleEnv.ADMIN if '--admin' in self.args \ + else KaggleEnv.LOCAL if '--local' in self.args \ + else KaggleEnv.PROD verbose = '--verbose' in self.args or '-v' in self.args config = self.api_client.configuration return KaggleClient( - env=env, - verbose=verbose, - username=config.username, - password=config.password) + env=env, + verbose=verbose, + username=config.username, + password=config.password) def camel_to_snake(self, name): """ @@ -749,7 +752,7 @@ def competitions_list(self, request.search = search request.sort_by = sort_by response = kaggle.competitions.competition_api_client.list_competitions( - request) + request) return response.competitions def competitions_list_cli(self, @@ -771,11 +774,11 @@ def competitions_list_cli(self, csv_display: if True, print comma separated values """ competitions = self.competitions_list( - group=group, - category=category, - sort_by=sort_by, - page=page, - search=search) + group=group, + category=category, + sort_by=sort_by, + page=page, + search=search) if competitions: if csv_display: self.print_csv(competitions, self.competition_fields) @@ -809,7 +812,7 @@ def competition_submit(self, file_name, message, competition, quiet=False): request.content_length = os.path.getsize(file_name) request.last_modified_epoch_seconds = int(os.path.getmtime(file_name)) response = kaggle.competitions.competition_api_client.start_submission_upload( - request) + request) upload_status = self.upload_complete(file_name, response.create_url, quiet) if upload_status != ResumableUploadResult.COMPLETE: @@ -822,7 +825,7 @@ def competition_submit(self, file_name, message, competition, quiet=False): submit_request.blob_file_tokens = response.token submit_request.submission_description = message submit_response = kaggle.competitions.competition_api_client.create_submission( - submit_request) + submit_request) return submit_response def competition_submit_cli(self, @@ -880,7 +883,7 @@ def 
competition_submissions( request.group = group request.sort_by = sort response = kaggle.competitions.competition_api_client.list_submissions( - request) + request) return response.submissions def competition_submissions_cli(self, @@ -913,7 +916,7 @@ def competition_submissions_cli(self, raise ValueError('No competition specified') else: submissions = self.competition_submissions( - competition, page_token=page_token, page_size=page_size) + competition, page_token=page_token, page_size=page_size) if submissions: if csv_display: self.print_csv(submissions, submission_fields) @@ -936,7 +939,7 @@ def competition_list_files(self, competition, page_token=None, page_size=20): request.page_token = page_token request.page_size = page_size response = kaggle.competitions.competition_api_client.list_data_files( - request) + request) return response def competition_list_files_cli(self, @@ -1008,7 +1011,7 @@ def competition_download_file(self, request.competition_name = competition request.file_name = file_name response = kaggle.competitions.competition_api_client.download_data_file( - request) + request) url = response.history[0].url outfile = os.path.join(effective_path, url.split('?')[0].split('/')[-1]) @@ -1039,7 +1042,7 @@ def competition_download_files(self, request = ApiDownloadDataFilesRequest() request.competition_name = competition response = kaggle.competitions.competition_api_client.download_data_files( - request) + request) url = response.url.split('?')[0] outfile = os.path.join(effective_path, competition + '.' + url.split('.')[-1]) @@ -1095,7 +1098,7 @@ def competition_leaderboard_download(self, competition, path, quiet=True): request = ApiDownloadLeaderboardRequest() request.competition_name = competition response = kaggle.competitions.competition_api_client.download_leaderboard( - request) + request) if path is None: effective_path = self.get_default_download_dir('competitions', competition) @@ -1117,7 +1120,7 @@ def competition_leaderboard_view(self, competition): request = ApiGetLeaderboardRequest() request.competition_name = competition response = kaggle.competitions.competition_api_client.get_leaderboard( - request) + request) return response.submissions def competition_leaderboard_cli(self, @@ -1203,8 +1206,8 @@ def dataset_list(self, if size: raise ValueError( - 'The --size parameter has been deprecated. ' + - 'Please use --max-size and --min-size to filter dataset sizes.') + 'The --size parameter has been deprecated. ' + + 'Please use --max-size and --min-size to filter dataset sizes.') if file_type and file_type not in self.valid_dataset_file_types: raise ValueError('Invalid file type specified. 
Valid options are ' + @@ -1317,20 +1320,20 @@ def dataset_metadata_update(self, dataset, path): with open(meta_file, 'r') as f: metadata = json.load(f) updateSettingsRequest = DatasetUpdateSettingsRequest( - title=metadata['title'], - subtitle=metadata['subtitle'], - description=metadata['description'], - is_private=metadata['isPrivate'], - licenses=[License(name=l['name']) for l in metadata['licenses']], - keywords=metadata['keywords'], - collaborators=[ - Collaborator(username=c['username'], role=c['role']) - for c in metadata['collaborators'] - ], - data=metadata['data']) + title=metadata['title'], + subtitle=metadata['subtitle'], + description=metadata['description'], + is_private=metadata['isPrivate'], + licenses=[License(name=l['name']) for l in metadata['licenses']], + keywords=metadata['keywords'], + collaborators=[ + Collaborator(username=c['username'], role=c['role']) + for c in metadata['collaborators'] + ], + data=metadata['data']) result = self.process_response( - self.metadata_post_with_http_info(owner_slug, dataset_slug, - updateSettingsRequest)) + self.metadata_post_with_http_info(owner_slug, dataset_slug, + updateSettingsRequest)) if (len(result['errors']) > 0): [print(e['message']) for e in result['errors']] exit(1) @@ -1343,7 +1346,7 @@ def dataset_metadata(self, dataset, path): os.makedirs(effective_path) result = self.process_response( - self.metadata_get_with_http_info(owner_slug, dataset_slug)) + self.metadata_get_with_http_info(owner_slug, dataset_slug)) if (result['errorMessage']): raise Exception(result['errorMessage']) @@ -1367,26 +1370,28 @@ def dataset_metadata_cli(self, dataset, path, update, dataset_opt=None): def dataset_list_files(self, dataset, page_token=None, page_size=20): """ List files for a dataset. - Parameters - ========== - dataset: the string identified of the dataset - should be in format [owner]/[dataset-name] - page_token: the page token for pagination - page_size: the number of items per page - """ + + Parameters + ========== + dataset: the string identified of the dataset + should be in format [owner]/[dataset-name] + page_token: the page token for pagination + page_size: the number of items per page + """ if dataset is None: raise ValueError('A dataset must be specified') owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string( - dataset) + dataset) - dataset_list_files_result = self.process_response( - self.datasets_list_files_with_http_info( - owner_slug=owner_slug, - dataset_slug=dataset_slug, - dataset_version_number=dataset_version_number, - page_token=page_token, - page_size=page_size)) - return ListFilesResult(dataset_list_files_result) + with self.build_kaggle_client() as kaggle: + request = ApiListDatasetFilesRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + request.dataset_version_number = dataset_version_number + request.page_token = page_token + request.page_size = page_size + response = kaggle.datasets.dataset_api_client.list_dataset_files(request) + return ListFilesResult(response) def dataset_list_files_cli(self, dataset, @@ -1425,11 +1430,11 @@ def dataset_list_files_cli(self, def dataset_status(self, dataset): """ Call to get the status of a dataset from the API. 
- Parameters - ========== - dataset: the string identifier of the dataset - should be in format [owner]/[dataset-name] - """ + Parameters + ========== + dataset: the string identifier of the dataset + should be in format [owner]/[dataset-name] + """ if dataset is None: raise ValueError('A dataset must be specified') if '/' in dataset: @@ -1440,10 +1445,13 @@ def dataset_status(self, dataset): else: owner_slug = self.get_config_value(self.CONFIG_NAME_USER) dataset_slug = dataset - dataset_status_result = self.process_response( - self.datasets_status_with_http_info( - owner_slug=owner_slug, dataset_slug=dataset_slug)) - return dataset_status_result + + with self.build_kaggle_client() as kaggle: + request = ApiGetDatasetStatusRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + response = kaggle.datasets.dataset_api_client.get_dataset_status(request) + return response.status.name.lower() def dataset_status_cli(self, dataset, dataset_opt=None): """ A wrapper for client for dataset_status, with additional @@ -1464,43 +1472,44 @@ def dataset_download_file(self, licenses=[]): """ Download a single file for a dataset. - Parameters - ========== - dataset: the string identified of the dataset - should be in format [owner]/[dataset-name] - file_name: the dataset configuration file - path: if defined, download to this location - force: force the download if the file already exists (default False) - quiet: suppress verbose output (default is True) - licenses: a list of license names, e.g. ['CC0-1.0'] - """ + Parameters + ========== + dataset: the string identified of the dataset + should be in format [owner]/[dataset-name] + file_name: the dataset configuration file + path: if defined, download to this location + force: force the download if the file already exists (default False) + quiet: suppress verbose output (default is True) + licenses: a list of license names, e.g. 
['CC0-1.0'] + """ if '/' in dataset: self.validate_dataset_string(dataset) owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string( - dataset) + dataset) else: owner_slug = self.get_config_value(self.CONFIG_NAME_USER) dataset_slug = dataset dataset_version_number = None if path is None: - effective_path = self.get_default_download_dir('datasets', owner_slug, - dataset_slug) + effective_path = self.get_default_download_dir( + 'datasets', owner_slug, dataset_slug) else: effective_path = path self._print_dataset_url_and_license(owner_slug, dataset_slug, dataset_version_number, licenses) - response = self.process_response( - self.datasets_download_file_with_http_info( - owner_slug=owner_slug, - dataset_slug=dataset_slug, - dataset_version_number=dataset_version_number, - file_name=file_name, - _preload_content=False)) - url = response.retries.history[0].redirect_location.split('?')[0] - outfile = os.path.join(effective_path, url.split('/')[-1]) + with self.build_kaggle_client() as kaggle: + request = ApiDownloadDatasetRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + request.dataset_version_number = dataset_version_number + request.file_name = file_name + response = kaggle.datasets.dataset_api_client.download_dataset(request) + url = response.history[0].url + outfile = os.path.join(effective_path, url.split('?')[0].split('/')[-1]) + if force or self.download_needed(response, outfile, quiet): self.download_file(response, outfile, quiet, not force) return True @@ -1516,35 +1525,35 @@ def dataset_download_files(self, licenses=[]): """ Download all files for a dataset. - Parameters - ========== - dataset: the string identified of the dataset - should be in format [owner]/[dataset-name] - path: the path to download the dataset to - force: force the download if the file already exists (default False) - quiet: suppress verbose output (default is True) - unzip: if True, unzip files upon download (default is False) - licenses: a list of license names, e.g. ['CC0-1.0'] - """ + Parameters + ========== + dataset: the string identified of the dataset + should be in format [owner]/[dataset-name] + path: the path to download the dataset to + force: force the download if the file already exists (default False) + quiet: suppress verbose output (default is True) + unzip: if True, unzip files upon download (default is False) + licenses: a list of license names, e.g. 
['CC0-1.0'] + """ if dataset is None: raise ValueError('A dataset must be specified') owner_slug, dataset_slug, dataset_version_number = self.split_dataset_string( - dataset) + dataset) if path is None: - effective_path = self.get_default_download_dir('datasets', owner_slug, - dataset_slug) + effective_path = self.get_default_download_dir( + 'datasets', owner_slug, dataset_slug) else: effective_path = path self._print_dataset_url_and_license(owner_slug, dataset_slug, dataset_version_number, licenses) - response = self.process_response( - self.datasets_download_with_http_info( - owner_slug=owner_slug, - dataset_slug=dataset_slug, - dataset_version_number=dataset_version_number, - _preload_content=False)) + with self.build_kaggle_client() as kaggle: + request = ApiDownloadDatasetRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + request.dataset_version_number = dataset_version_number + response = kaggle.datasets.dataset_api_client.download_dataset(request) outfile = os.path.join(effective_path, dataset_slug + '.zip') if force or self.download_needed(response, outfile, quiet): @@ -1561,18 +1570,18 @@ def dataset_download_files(self, z.extractall(effective_path) except zipfile.BadZipFile as e: raise ValueError( - f"The file {outfile} is corrupted or not a valid zip file. " - "Please report this issue at https://www.github.com/kaggle/kaggle-api" + f"The file {outfile} is corrupted or not a valid zip file. " + "Please report this issue at https://www.github.com/kaggle/kaggle-api" ) except FileNotFoundError: raise FileNotFoundError( - f"The file {outfile} was not found. " - "Please report this issue at https://www.github.com/kaggle/kaggle-api" + f"The file {outfile} was not found. " + "Please report this issue at https://www.github.com/kaggle/kaggle-api" ) except Exception as e: raise RuntimeError( - f"An unexpected error occurred: {e}. " - "Please report this issue at https://www.github.com/kaggle/kaggle-api" + f"An unexpected error occurred: {e}. " + "Please report this issue at https://www.github.com/kaggle/kaggle-api" ) try: @@ -1600,9 +1609,9 @@ def dataset_download_cli(self, unzip=False, force=False, quiet=False): - """ client wrapper for dataset_download_files and download dataset file, + """ Client wrapper for dataset_download_files and download dataset file, either for a specific file (when file_name is provided), - or all files for a dataset (plural) + or all files for a dataset (plural). Parameters ========== @@ -1619,40 +1628,40 @@ def dataset_download_cli(self, owner_slug, dataset_slug, _ = self.split_dataset_string(dataset) metadata = self.process_response( - self.metadata_get_with_http_info(owner_slug, dataset_slug)) + self.metadata_get_with_http_info(owner_slug, dataset_slug)) if 'info' in metadata and 'licenses' in metadata['info']: # license_objs format is like: [{ 'name': 'CC0-1.0' }] license_objs = metadata['info']['licenses'] licenses = [ - license_obj['name'] - for license_obj in license_objs - if 'name' in license_obj + license_obj['name'] + for license_obj in license_objs + if 'name' in license_obj ] else: licenses = [ - 'Error retrieving license. Please visit the Dataset URL to view license information.' + 'Error retrieving license. Please visit the Dataset URL to view license information.' 
] if file_name is None: self.dataset_download_files( - dataset, - path=path, - unzip=unzip, - force=force, - quiet=quiet, - licenses=licenses) + dataset, + path=path, + unzip=unzip, + force=force, + quiet=quiet, + licenses=licenses) else: self.dataset_download_file( - dataset, - file_name, - path=path, - force=force, - quiet=quiet, - licenses=licenses) + dataset, + file_name, + path=path, + force=force, + quiet=quiet, + licenses=licenses) def _upload_blob(self, path, quiet, blob_type, upload_context): - """ upload a file + """ Upload a file. Parameters ========== @@ -1666,13 +1675,13 @@ def _upload_blob(self, path, quiet, blob_type, upload_context): last_modified_epoch_seconds = int(os.path.getmtime(path)) start_blob_upload_request = StartBlobUploadRequest( - blob_type, - file_name, - content_length, - last_modified_epoch_seconds=last_modified_epoch_seconds) + blob_type, + file_name, + content_length, + last_modified_epoch_seconds=last_modified_epoch_seconds) file_upload = upload_context.new_resumable_file_upload( - path, start_blob_upload_request) + path, start_blob_upload_request) for i in range(0, self.MAX_UPLOAD_RESUME_ATTEMPTS): if file_upload.upload_complete: @@ -1681,15 +1690,15 @@ def _upload_blob(self, path, quiet, blob_type, upload_context): if not file_upload.can_resume: # Initiate upload on Kaggle backend to get the url and token. start_blob_upload_response = self.process_response( - self.with_retry(self.upload_file_with_http_info)( - file_upload.start_blob_upload_request)) + self.with_retry(self.upload_file_with_http_info)( + file_upload.start_blob_upload_request)) file_upload.upload_initiated(start_blob_upload_response) upload_result = self.upload_complete( - path, - file_upload.start_blob_upload_response.create_url, - quiet, - resume=file_upload.can_resume) + path, + file_upload.start_blob_upload_response.create_url, + quiet, + resume=file_upload.can_resume) if upload_result == ResumableUploadResult.INCOMPLETE: continue # Continue (i.e., retry/resume) only if the upload is incomplete. @@ -1706,7 +1715,7 @@ def dataset_create_version(self, convert_to_csv=True, delete_old_versions=False, dir_mode='skip'): - """ create a version of a dataset + """ Create a version of a dataset. 
Parameters ========== @@ -1729,6 +1738,8 @@ def dataset_create_version(self, id_no = self.get_or_default(meta_data, 'id_no', None) if not ref and not id_no: raise ValueError('ID or slug must be specified in the metadata') + elif ref and ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': + raise ValueError('Default slug detected, please change values before uploading') subtitle = meta_data.get('subtitle') if subtitle and (len(subtitle) < 20 or len(subtitle) > 80): @@ -1740,14 +1751,35 @@ def dataset_create_version(self, description = meta_data.get('description') keywords = self.get_or_default(meta_data, 'keywords', []) - request = DatasetNewVersionRequest( - version_notes=version_notes, - subtitle=subtitle, - description=description, - files=[], - convert_to_csv=convert_to_csv, - category_ids=keywords, - delete_old_versions=delete_old_versions) + body = ApiCreateDatasetVersionRequestBody() + body.version_notes=version_notes + body.subtitle=subtitle + body.description=description + body.files=[] + body.category_ids=keywords + body.delete_old_versions=delete_old_versions + + with self.build_kaggle_client() as kaggle: + if id_no: + request = ApiCreateDatasetVersionByIdRequest() + request.id = id_no + message = kaggle.datasets.dataset_api_client.create_dataset_version_by_id + else: + self.validate_dataset_string(ref) + ref_list = ref.split('/') + owner_slug = ref_list[0] + dataset_slug = ref_list[1] + request = ApiCreateDatasetVersionRequest() + request.owner_slug = owner_slug + request.dataset_slug = dataset_slug + message = kaggle.datasets.dataset_api_client.create_dataset_version + request.body = body + with ResumableUploadContext() as upload_context: + self.upload_files(body, resources, folder, ApiBlobType.DATASET, + upload_context, quiet, dir_mode) + request.body.files = [self._api_dataset_new_file(file) for file in request.body.files] + response = self.with_retry(message)(request) + return response with ResumableUploadContext() as upload_context: self.upload_files(request, resources, folder, ApiBlobType.DATASET, @@ -1755,26 +1787,30 @@ def dataset_create_version(self, if id_no: result = DatasetNewVersionResponse( - self.process_response( - self.with_retry( - self.datasets_create_version_by_id_with_http_info)( - id_no, request))) + self.process_response( + self.with_retry( + self.datasets_create_version_by_id_with_http_info)( + id_no, request))) else: - if ref == self.config_values[ - self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': - raise ValueError('Default slug detected, please change values before ' - 'uploading') + if ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': + raise ValueError('Default slug detected, please change values before uploading') self.validate_dataset_string(ref) ref_list = ref.split('/') owner_slug = ref_list[0] dataset_slug = ref_list[1] result = DatasetNewVersionResponse( - self.process_response( - self.datasets_create_version_with_http_info( - owner_slug, dataset_slug, request))) + self.process_response( + self.datasets_create_version_with_http_info( + owner_slug, dataset_slug, request))) return result + def _api_dataset_new_file(self, file): + # TODO Eliminate the need for this conversion + f = ApiDatasetNewFile() + f.token = file.token + return f + def dataset_create_version_cli(self, folder, version_notes, @@ -1794,12 +1830,12 @@ def dataset_create_version_cli(self, """ folder = folder or os.getcwd() result = self.dataset_create_version( - folder, - version_notes, - quiet=quiet, - convert_to_csv=convert_to_csv, - 
delete_old_versions=delete_old_versions, - dir_mode=dir_mode) + folder, + version_notes, + quiet=quiet, + convert_to_csv=convert_to_csv, + delete_old_versions=delete_old_versions, + dir_mode=dir_mode) if result is None: print('Dataset version creation error: See previous output') @@ -1845,16 +1881,17 @@ def dataset_create_new(self, quiet=False, convert_to_csv=True, dir_mode='skip'): - """ create a new dataset, meaning the same as creating a version but - with extra metadata like license and user/owner. - Parameters - ========== - folder: the folder to get the metadata file from - public: should the dataset be public? - quiet: suppress verbose output (default is False) - convert_to_csv: if True, convert data to comma separated value - dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload - """ + """ Create a new dataset, meaning the same as creating a version but + with extra metadata like license and user/owner. + + Parameters + ========== + folder: the folder to get the metadata file from + public: should the dataset be public? + quiet: suppress verbose output (default is False) + convert_to_csv: if True, convert data to comma separated value + dir_mode: What to do with directories: "skip" - ignore; "zip" - compress and upload + """ if not os.path.isdir(folder): raise ValueError('Invalid folder: ' + folder) @@ -1871,18 +1908,22 @@ def dataset_create_new(self, dataset_slug = ref_list[1] # validations - if ref == self.config_values[self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': + if ref == self.config_values[ + self.CONFIG_NAME_USER] + '/INSERT_SLUG_HERE': raise ValueError( - 'Default slug detected, please change values before uploading') + 'Default slug detected, please change values before uploading') if title == 'INSERT_TITLE_HERE': raise ValueError( - 'Default title detected, please change values before uploading') + 'Default title detected, please change values before uploading' + ) if len(licenses) != 1: raise ValueError('Please specify exactly one license') if len(dataset_slug) < 6 or len(dataset_slug) > 50: - raise ValueError('The dataset slug must be between 6 and 50 characters') + raise ValueError( + 'The dataset slug must be between 6 and 50 characters') if len(title) < 6 or len(title) > 50: - raise ValueError('The dataset title must be between 6 and 50 characters') + raise ValueError( + 'The dataset title must be between 6 and 50 characters') resources = meta_data.get('resources') if resources: self.validate_resources(folder, resources) @@ -1893,27 +1934,44 @@ def dataset_create_new(self, subtitle = meta_data.get('subtitle') if subtitle and (len(subtitle) < 20 or len(subtitle) > 80): - raise ValueError('Subtitle length must be between 20 and 80 characters') - - request = DatasetNewRequest( - title=title, - slug=dataset_slug, - owner_slug=owner_slug, - license_name=license_name, - subtitle=subtitle, - description=description, - files=[], - is_private=not public, - convert_to_csv=convert_to_csv, - category_ids=keywords) + raise ValueError( + 'Subtitle length must be between 20 and 80 characters') + + request = DatasetNewRequest(title=title, + slug=dataset_slug, + owner_slug=owner_slug, + license_name=license_name, + subtitle=subtitle, + description=description, + files=[], + is_private=not public, + convert_to_csv=convert_to_csv, + category_ids=keywords) with ResumableUploadContext() as upload_context: + # TODO Change upload_files() to use ApiCreateDatasetRequest self.upload_files(request, resources, folder, ApiBlobType.DATASET, upload_context, quiet, 
dir_mode) + + with self.build_kaggle_client() as kaggle: + retry_request = ApiCreateDatasetRequest() + retry_request.title=title + retry_request.slug=dataset_slug + retry_request.owner_slug=owner_slug + retry_request.license_name=license_name + retry_request.subtitle=subtitle + retry_request.description=description + retry_request.files=[] + retry_request.is_private=not public + retry_request.category_ids=keywords + response = self.with_retry( + kaggle.datasets.dataset_api_client.create_dataset)(retry_request) + return response + result = DatasetNewResponse( - self.process_response( - self.with_retry( - self.datasets_create_new_with_http_info)(request))) + self.process_response( + self.with_retry( + self.datasets_create_new_with_http_info)(request))) return result @@ -1984,7 +2042,7 @@ def download_file(self, file_exists = os.path.isfile(outfile) resumable = 'Accept-Ranges' in response.headers and response.headers[ - 'Accept-Ranges'] == 'bytes' + 'Accept-Ranges'] == 'bytes' if resume and resumable and file_exists: size_read = os.path.getsize(outfile) @@ -1992,16 +2050,16 @@ def download_file(self, if not quiet: print("... resuming from %d bytes (%d bytes left) ..." % ( - size_read, - size - size_read, + size_read, + size - size_read, )) request_history = response.retries.history[0] response = self.api_client.request( - request_history.method, - request_history.redirect_location, - headers={'Range': 'bytes=%d-' % (size_read,)}, - _preload_content=False) + request_history.method, + request_history.redirect_location, + headers={'Range': 'bytes=%d-' % (size_read,)}, + _preload_content=False) with tqdm( total=size, @@ -2019,8 +2077,8 @@ def download_file(self, break out.write(data) os.utime( - outfile, - times=(remote_date_timestamp - 1, remote_date_timestamp - 1)) + outfile, + times=(remote_date_timestamp - 1, remote_date_timestamp - 1)) size_read = min(size, size_read + chunk_size) pbar.update(len(data)) else: @@ -2029,8 +2087,8 @@ def download_file(self, break out.write(data) os.utime( - outfile, - times=(remote_date_timestamp - 1, remote_date_timestamp - 1)) + outfile, + times=(remote_date_timestamp - 1, remote_date_timestamp - 1)) size_read = min(size, size_read + chunk_size) pbar.update(len(data)) if not quiet: @@ -2104,19 +2162,19 @@ def kernels_list(self, group = 'profile' kernels_list_result = self.process_response( - self.kernels_list_with_http_info( - page=page, - page_size=page_size, - group=group, - user=user or '', - language=language or 'all', - kernel_type=kernel_type or 'all', - output_type=output_type or 'all', - sort_by=sort_by or 'hotness', - dataset=dataset or '', - competition=competition or '', - parent_kernel=parent_kernel or '', - search=search or '')) + self.kernels_list_with_http_info( + page=page, + page_size=page_size, + group=group, + user=user or '', + language=language or 'all', + kernel_type=kernel_type or 'all', + output_type=output_type or 'all', + sort_by=sort_by or 'hotness', + dataset=dataset or '', + competition=competition or '', + parent_kernel=parent_kernel or '', + search=search or '')) return [Kernel(k) for k in kernels_list_result] def kernels_list_cli(self, @@ -2140,18 +2198,18 @@ def kernels_list_cli(self, csv_display: if True, print comma separated values instead of table """ kernels = self.kernels_list( - page=page, - page_size=page_size, - search=search, - mine=mine, - dataset=dataset, - competition=competition, - parent_kernel=parent, - user=user, - language=language, - kernel_type=kernel_type, - output_type=output_type, - sort_by=sort_by) + 
page=page, + page_size=page_size, + search=search, + mine=mine, + dataset=dataset, + competition=competition, + parent_kernel=parent, + user=user, + language=language, + kernel_type=kernel_type, + output_type=output_type, + sort_by=sort_by) fields = ['ref', 'title', 'author', 'lastRunTime', 'totalVotes'] if kernels: if csv_display: @@ -2173,14 +2231,14 @@ def kernels_list_files(self, kernel, page_token=None, page_size=20): if kernel is None: raise ValueError('A kernel must be specified') user_name, kernel_slug, kernel_version_number = self.split_dataset_string( - kernel) + kernel) kernels_list_files_result = self.process_response( - self.kernels_list_files_with_http_info( - kernel_slug=kernel_slug, - user_name=user_name, - page_token=page_token, - page_size=page_size)) + self.kernels_list_files_with_http_info( + kernel_slug=kernel_slug, + user_name=user_name, + page_token=page_token, + page_size=page_size)) return FileList(kernels_list_files_result) def kernels_list_files_cli(self, @@ -2236,30 +2294,30 @@ def kernels_initialize(self, folder): username = self.get_config_value(self.CONFIG_NAME_USER) meta_data = { - 'id': - username + '/INSERT_KERNEL_SLUG_HERE', - 'title': - 'INSERT_TITLE_HERE', - 'code_file': - 'INSERT_CODE_FILE_PATH_HERE', - 'language': - 'Pick one of: {' + - ','.join(x for x in self.valid_push_language_types) + '}', - 'kernel_type': - 'Pick one of: {' + - ','.join(x for x in self.valid_push_kernel_types) + '}', - 'is_private': - 'true', - 'enable_gpu': - 'false', - 'enable_tpu': - 'false', - 'enable_internet': - 'true', - 'dataset_sources': [], - 'competition_sources': [], - 'kernel_sources': [], - 'model_sources': [], + 'id': + username + '/INSERT_KERNEL_SLUG_HERE', + 'title': + 'INSERT_TITLE_HERE', + 'code_file': + 'INSERT_CODE_FILE_PATH_HERE', + 'language': + 'Pick one of: {' + + ','.join(x for x in self.valid_push_language_types) + '}', + 'kernel_type': + 'Pick one of: {' + + ','.join(x for x in self.valid_push_kernel_types) + '}', + 'is_private': + 'true', + 'enable_gpu': + 'false', + 'enable_tpu': + 'false', + 'enable_internet': + 'true', + 'dataset_sources': [], + 'competition_sources': [], + 'kernel_sources': [], + 'model_sources': [], } meta_file = os.path.join(folder, self.KERNEL_METADATA_FILE) with open(meta_file, 'w') as f: @@ -2331,14 +2389,14 @@ def kernels_push(self, folder): language = self.get_or_default(meta_data, 'language', '') if language not in self.valid_push_language_types: raise ValueError( - 'A valid language must be specified in the metadata. Valid ' - 'options are ' + str(self.valid_push_language_types)) + 'A valid language must be specified in the metadata. Valid ' + 'options are ' + str(self.valid_push_language_types)) kernel_type = self.get_or_default(meta_data, 'kernel_type', '') if kernel_type not in self.valid_push_kernel_types: raise ValueError( - 'A valid kernel type must be specified in the metadata. Valid ' - 'options are ' + str(self.valid_push_kernel_types)) + 'A valid kernel type must be specified in the metadata. 
Valid ' + 'options are ' + str(self.valid_push_kernel_types)) if kernel_type == 'notebook' and language == 'rmarkdown': language = 'r' @@ -2378,28 +2436,28 @@ def kernels_push(self, folder): script_body = json.dumps(json_body) kernel_push_request = KernelPushRequest( - id=id_no, - slug=slug, - new_title=self.get_or_default(meta_data, 'title', None), - text=script_body, - language=language, - kernel_type=kernel_type, - is_private=self.get_or_default(meta_data, 'is_private', None), - enable_gpu=self.get_or_default(meta_data, 'enable_gpu', None), - enable_tpu=self.get_or_default(meta_data, 'enable_tpu', None), - enable_internet=self.get_or_default(meta_data, 'enable_internet', None), - dataset_data_sources=dataset_sources, - competition_data_sources=self.get_or_default(meta_data, - 'competition_sources', []), - kernel_data_sources=kernel_sources, - model_data_sources=model_sources, - category_ids=self.get_or_default(meta_data, 'keywords', []), - docker_image_pinning_type=docker_pinning_type) + id=id_no, + slug=slug, + new_title=self.get_or_default(meta_data, 'title', None), + text=script_body, + language=language, + kernel_type=kernel_type, + is_private=self.get_or_default(meta_data, 'is_private', None), + enable_gpu=self.get_or_default(meta_data, 'enable_gpu', None), + enable_tpu=self.get_or_default(meta_data, 'enable_tpu', None), + enable_internet=self.get_or_default(meta_data, 'enable_internet', None), + dataset_data_sources=dataset_sources, + competition_data_sources=self.get_or_default(meta_data, + 'competition_sources', []), + kernel_data_sources=kernel_sources, + model_data_sources=model_sources, + category_ids=self.get_or_default(meta_data, 'keywords', []), + docker_image_pinning_type=docker_pinning_type) result = KernelPushResponse( - self.process_response( - self.kernel_push_with_http_info( - kernel_push_request=kernel_push_request))) + self.process_response( + self.kernel_push_with_http_info( + kernel_push_request=kernel_push_request))) return result def kernels_push_cli(self, folder): @@ -2480,7 +2538,7 @@ def kernels_pull(self, kernel, path, metadata=False, quiet=True): os.makedirs(effective_path) response = self.process_response( - self.kernel_pull_with_http_info(owner_slug, kernel_slug)) + self.kernel_pull_with_http_info(owner_slug, kernel_slug)) blob = response['blob'] if os.path.isfile(effective_path): @@ -2576,7 +2634,7 @@ def kernels_pull_cli(self, """ kernel = kernel or kernel_opt effective_path = self.kernels_pull( - kernel, path=path, metadata=metadata, quiet=False) + kernel, path=path, metadata=metadata, quiet=False) if metadata: print('Source code and metadata downloaded to ' + effective_path) else: @@ -2615,7 +2673,7 @@ def kernels_output(self, kernel, path, force=False, quiet=True): raise ValueError('You must specify a directory for the kernels output') response = self.process_response( - self.kernel_output_with_http_info(owner_slug, kernel_slug)) + self.kernel_output_with_http_info(owner_slug, kernel_slug)) outfiles = [] for item in response['files']: outfile = os.path.join(target_dir, item['fileName']) @@ -2671,7 +2729,7 @@ def kernels_status(self, kernel): owner_slug = self.get_config_value(self.CONFIG_NAME_USER) kernel_slug = kernel response = self.process_response( - self.kernel_status_with_http_info(owner_slug, kernel_slug)) + self.kernel_status_with_http_info(owner_slug, kernel_slug)) return response def kernels_status_cli(self, kernel, kernel_opt=None): @@ -2700,7 +2758,7 @@ def model_get(self, model): owner_slug, model_slug = self.split_model_string(model) 
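        # split_model_string() yields the owner and model slugs from a
        # '{owner}/{model-slug}' identifier; both feed the
        # get_model_with_http_info call below.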
model_get_result = self.process_response( - self.get_model_with_http_info(owner_slug, model_slug)) + self.get_model_with_http_info(owner_slug, model_slug)) return model_get_result def model_get_cli(self, model, folder=None): @@ -2757,12 +2815,12 @@ def model_list(self, raise ValueError('Page size must be >= 1') models_list_result = self.process_response( - self.models_list_with_http_info( - sort_by=sort_by or 'hotness', - search=search or '', - owner=owner or '', - page_size=page_size, - page_token=page_token)) + self.models_list_with_http_info( + sort_by=sort_by or 'hotness', + search=search or '', + owner=owner or '', + page_size=page_size, + page_token=page_token)) next_page_token = models_list_result['nextPageToken'] if next_page_token: @@ -2809,18 +2867,18 @@ def model_initialize(self, folder): raise ValueError('Invalid folder: ' + folder) meta_data = { - 'ownerSlug': - 'INSERT_OWNER_SLUG_HERE', - 'title': - 'INSERT_TITLE_HERE', - 'slug': - 'INSERT_SLUG_HERE', - 'subtitle': - '', - 'isPrivate': - True, - 'description': - '''# Model Summary + 'ownerSlug': + 'INSERT_OWNER_SLUG_HERE', + 'title': + 'INSERT_TITLE_HERE', + 'slug': + 'INSERT_SLUG_HERE', + 'subtitle': + '', + 'isPrivate': + True, + 'description': + '''# Model Summary # Model Characteristics @@ -2828,10 +2886,10 @@ def model_initialize(self, folder): # Evaluation Results ''', - 'publishTime': - '', - 'provenanceSources': - '' + 'publishTime': + '', + 'provenanceSources': + '' } meta_file = os.path.join(folder, self.MODEL_METADATA_FILE) with open(meta_file, 'w') as f: @@ -2864,36 +2922,36 @@ def model_create_new(self, folder): subtitle = meta_data.get('subtitle') is_private = self.get_or_fail(meta_data, 'isPrivate') description = self.sanitize_markdown( - self.get_or_fail(meta_data, 'description')) + self.get_or_fail(meta_data, 'description')) publish_time = meta_data.get('publishTime') provenance_sources = meta_data.get('provenanceSources') # validations if owner_slug == 'INSERT_OWNER_SLUG_HERE': raise ValueError( - 'Default ownerSlug detected, please change values before uploading') + 'Default ownerSlug detected, please change values before uploading') if title == 'INSERT_TITLE_HERE': raise ValueError( - 'Default title detected, please change values before uploading') + 'Default title detected, please change values before uploading') if slug == 'INSERT_SLUG_HERE': raise ValueError( - 'Default slug detected, please change values before uploading') + 'Default slug detected, please change values before uploading') if not isinstance(is_private, bool): raise ValueError('model.isPrivate must be a boolean') if publish_time: self.validate_date(publish_time) request = ModelNewRequest( - owner_slug=owner_slug, - slug=slug, - title=title, - subtitle=subtitle, - is_private=is_private, - description=description, - publish_time=publish_time, - provenance_sources=provenance_sources) + owner_slug=owner_slug, + slug=slug, + title=title, + subtitle=subtitle, + is_private=is_private, + description=description, + publish_time=publish_time, + provenance_sources=provenance_sources) result = ModelNewResponse( - self.process_response(self.models_create_new_with_http_info(request))) + self.process_response(self.models_create_new_with_http_info(request))) return result @@ -2908,7 +2966,7 @@ def model_create_new_cli(self, folder=None): if result.hasId: print('Your model was created. Id={}. 
Url={}'.format( - result.id, result.url)) + result.id, result.url)) else: print('Model creation error: ' + result.error) @@ -2928,8 +2986,8 @@ def model_delete(self, model, yes): exit(0) res = ModelDeleteResponse( - self.process_response( - self.delete_model_with_http_info(owner_slug, model_slug))) + self.process_response( + self.delete_model_with_http_info(owner_slug, model_slug))) return res def model_delete_cli(self, model, yes): @@ -2974,10 +3032,10 @@ def model_update(self, folder): # validations if owner_slug == 'INSERT_OWNER_SLUG_HERE': raise ValueError( - 'Default ownerSlug detected, please change values before uploading') + 'Default ownerSlug detected, please change values before uploading') if slug == 'INSERT_SLUG_HERE': raise ValueError( - 'Default slug detected, please change values before uploading') + 'Default slug detected, please change values before uploading') if is_private != None and not isinstance(is_private, bool): raise ValueError('model.isPrivate must be a boolean') if publish_time: @@ -3002,16 +3060,16 @@ def model_update(self, folder): update_mask['paths'].append('provenance_sources') request = ModelUpdateRequest( - title=title, - subtitle=subtitle, - is_private=is_private, - description=description, - publish_time=publish_time, - provenance_sources=provenance_sources, - update_mask=update_mask) + title=title, + subtitle=subtitle, + is_private=is_private, + description=description, + publish_time=publish_time, + provenance_sources=provenance_sources, + update_mask=update_mask) result = ModelNewResponse( - self.process_response( - self.update_model_with_http_info(owner_slug, slug, request))) + self.process_response( + self.update_model_with_http_info(owner_slug, slug, request))) return result @@ -3026,7 +3084,7 @@ def model_update_cli(self, folder=None): if result.hasId: print('Your model was updated. Id={}. 
Url={}'.format( - result.id, result.url)) + result.id, result.url)) else: print('Model update error: ' + result.error) @@ -3040,11 +3098,11 @@ def model_instance_get(self, model_instance): if model_instance is None: raise ValueError('A model instance must be specified') owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string( - model_instance) + model_instance) mi = self.process_response( - self.get_model_instance_with_http_info(owner_slug, model_slug, - framework, instance_slug)) + self.get_model_instance_with_http_info(owner_slug, model_slug, + framework, instance_slug)) return mi def model_instance_get_cli(self, model_instance, folder=None): @@ -3062,7 +3120,7 @@ def model_instance_get_cli(self, model_instance, folder=None): meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE) owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string( - model_instance) + model_instance) data = {} data['id'] = mi['id'] @@ -3080,10 +3138,10 @@ def model_instance_get_cli(self, model_instance, folder=None): data['modelInstanceType'] = mi['modelInstanceType'] if mi['baseModelInstanceInformation'] is not None: data['baseModelInstance'] = '{}/{}/{}/{}'.format( - mi['baseModelInstanceInformation']['owner']['slug'], - mi['baseModelInstanceInformation']['modelSlug'], - mi['baseModelInstanceInformation']['framework'], - mi['baseModelInstanceInformation']['instanceSlug']) + mi['baseModelInstanceInformation']['owner']['slug'], + mi['baseModelInstanceInformation']['modelSlug'], + mi['baseModelInstanceInformation']['framework'], + mi['baseModelInstanceInformation']['instanceSlug']) data['externalBaseModelUrl'] = mi['externalBaseModelUrl'] with open(meta_file, 'w') as f: @@ -3100,18 +3158,18 @@ def model_instance_initialize(self, folder): raise ValueError('Invalid folder: ' + folder) meta_data = { - 'ownerSlug': - 'INSERT_OWNER_SLUG_HERE', - 'modelSlug': - 'INSERT_EXISTING_MODEL_SLUG_HERE', - 'instanceSlug': - 'INSERT_INSTANCE_SLUG_HERE', - 'framework': - 'INSERT_FRAMEWORK_HERE', - 'overview': - '', - 'usage': - '''# Model Format + 'ownerSlug': + 'INSERT_OWNER_SLUG_HERE', + 'modelSlug': + 'INSERT_EXISTING_MODEL_SLUG_HERE', + 'instanceSlug': + 'INSERT_INSTANCE_SLUG_HERE', + 'framework': + 'INSERT_FRAMEWORK_HERE', + 'overview': + '', + 'usage': + '''# Model Format # Training Data @@ -3125,17 +3183,17 @@ def model_instance_initialize(self, folder): # Changelog ''', - 'licenseName': - 'Apache 2.0', - 'fineTunable': - False, - 'trainingData': [], - 'modelInstanceType': - 'Unspecified', - 'baseModelInstanceId': - 0, - 'externalBaseModelUrl': - '' + 'licenseName': + 'Apache 2.0', + 'fineTunable': + False, + 'trainingData': [], + 'modelInstanceType': + 'Unspecified', + 'baseModelInstanceId': + 0, + 'externalBaseModelUrl': + '' } meta_file = os.path.join(folder, self.MODEL_INSTANCE_METADATA_FILE) with open(meta_file, 'w') as f: @@ -3169,7 +3227,7 @@ def model_instance_create(self, folder, quiet=False, dir_mode='skip'): instance_slug = self.get_or_fail(meta_data, 'instanceSlug') framework = self.get_or_fail(meta_data, 'framework') overview = self.sanitize_markdown( - self.get_or_default(meta_data, 'overview', '')) + self.get_or_default(meta_data, 'overview', '')) usage = self.sanitize_markdown(self.get_or_default(meta_data, 'usage', '')) license_name = self.get_or_fail(meta_data, 'licenseName') fine_tunable = self.get_or_default(meta_data, 'fineTunable', False) @@ -3184,17 +3242,17 @@ def model_instance_create(self, folder, quiet=False, dir_mode='skip'): # 
validations if owner_slug == 'INSERT_OWNER_SLUG_HERE': raise ValueError( - 'Default ownerSlug detected, please change values before uploading') + 'Default ownerSlug detected, please change values before uploading') if model_slug == 'INSERT_EXISTING_MODEL_SLUG_HERE': raise ValueError( - 'Default modelSlug detected, please change values before uploading') + 'Default modelSlug detected, please change values before uploading') if instance_slug == 'INSERT_INSTANCE_SLUG_HERE': raise ValueError( - 'Default instanceSlug detected, please change values before uploading' + 'Default instanceSlug detected, please change values before uploading' ) if framework == 'INSERT_FRAMEWORK_HERE': raise ValueError( - 'Default framework detected, please change values before uploading') + 'Default framework detected, please change values before uploading') if license_name == '': raise ValueError('Please specify a license') if not isinstance(fine_tunable, bool): @@ -3203,25 +3261,25 @@ def model_instance_create(self, folder, quiet=False, dir_mode='skip'): raise ValueError('modelInstance.trainingData must be a list') request = ModelNewInstanceRequest( - instance_slug=instance_slug, - framework=framework, - overview=overview, - usage=usage, - license_name=license_name, - fine_tunable=fine_tunable, - training_data=training_data, - model_instance_type=model_instance_type, - base_model_instance=base_model_instance, - external_base_model_url=external_base_model_url, - files=[]) + instance_slug=instance_slug, + framework=framework, + overview=overview, + usage=usage, + license_name=license_name, + fine_tunable=fine_tunable, + training_data=training_data, + model_instance_type=model_instance_type, + base_model_instance=base_model_instance, + external_base_model_url=external_base_model_url, + files=[]) with ResumableUploadContext() as upload_context: self.upload_files(request, None, folder, ApiBlobType.MODEL, upload_context, quiet, dir_mode) result = ModelNewResponse( - self.process_response( - self.with_retry(self.models_create_instance_with_http_info)( - owner_slug, model_slug, request))) + self.process_response( + self.with_retry(self.models_create_instance_with_http_info)( + owner_slug, model_slug, request))) return result @@ -3238,7 +3296,7 @@ def model_instance_create_cli(self, folder, quiet=False, dir_mode='skip'): if result.hasId: print('Your model instance was created. Id={}. 
Url={}'.format( - result.id, result.url)) + result.id, result.url)) else: print('Model instance creation error: ' + result.error) @@ -3253,7 +3311,7 @@ def model_instance_delete(self, model_instance, yes): if model_instance is None: raise ValueError('A model instance must be specified') owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string( - model_instance) + model_instance) if not yes: if not self.confirmation(): @@ -3261,10 +3319,10 @@ def model_instance_delete(self, model_instance, yes): exit(0) res = ModelDeleteResponse( - self.process_response( - self.delete_model_instance_with_http_info(owner_slug, model_slug, - framework, - instance_slug))) + self.process_response( + self.delete_model_instance_with_http_info(owner_slug, model_slug, + framework, + instance_slug))) return res def model_instance_delete_cli(self, model_instance, yes): @@ -3305,14 +3363,14 @@ def model_instance_files(self, [owner_slug, model_slug, framework, instance_slug] = urls response = self.process_response( - self.model_instance_files_with_http_info( - owner_slug=owner_slug, - model_slug=model_slug, - framework=framework, - instance_slug=instance_slug, - page_size=page_size, - page_token=page_token, - _preload_content=True)) + self.model_instance_files_with_http_info( + owner_slug=owner_slug, + model_slug=model_slug, + framework=framework, + instance_slug=instance_slug, + page_size=page_size, + page_token=page_token, + _preload_content=True)) if response: next_page_token = response['nextPageToken'] @@ -3338,10 +3396,10 @@ def model_instance_files_cli(self, csv_display: if True, print comma separated values instead of table """ result = self.model_instance_files( - model_instance, - page_token=page_token, - page_size=page_size, - csv_display=csv_display) + model_instance, + page_token=page_token, + page_size=page_size, + csv_display=csv_display) if result and result.files is not None: fields = ['name', 'size', 'creationDate'] if csv_display: @@ -3382,17 +3440,17 @@ def model_instance_update(self, folder): # validations if owner_slug == 'INSERT_OWNER_SLUG_HERE': raise ValueError( - 'Default ownerSlug detected, please change values before uploading') + 'Default ownerSlug detected, please change values before uploading') if model_slug == 'INSERT_SLUG_HERE': raise ValueError( - 'Default model slug detected, please change values before uploading') + 'Default model slug detected, please change values before uploading') if instance_slug == 'INSERT_INSTANCE_SLUG_HERE': raise ValueError( - 'Default instance slug detected, please change values before uploading' + 'Default instance slug detected, please change values before uploading' ) if framework == 'INSERT_FRAMEWORK_HERE': raise ValueError( - 'Default framework detected, please change values before uploading') + 'Default framework detected, please change values before uploading') if fine_tunable != None and not isinstance(fine_tunable, bool): raise ValueError('modelInstance.fineTunable must be a boolean') if training_data != None and not isinstance(training_data, list): @@ -3422,20 +3480,20 @@ def model_instance_update(self, folder): update_mask['paths'].append('external_base_model_url') request = ModelInstanceUpdateRequest( - overview=overview, - usage=usage, - license_name=license_name, - fine_tunable=fine_tunable, - training_data=training_data, - model_instance_type=model_instance_type, - base_model_instance=base_model_instance, - external_base_model_url=external_base_model_url, - update_mask=update_mask) + overview=overview, + 
usage=usage, + license_name=license_name, + fine_tunable=fine_tunable, + training_data=training_data, + model_instance_type=model_instance_type, + base_model_instance=base_model_instance, + external_base_model_url=external_base_model_url, + update_mask=update_mask) result = ModelNewResponse( - self.process_response( - self.update_model_instance_with_http_info(owner_slug, model_slug, - framework, instance_slug, - request))) + self.process_response( + self.update_model_instance_with_http_info(owner_slug, model_slug, + framework, instance_slug, + request))) return result @@ -3450,7 +3508,7 @@ def model_instance_update_cli(self, folder=None): if result.hasId: print('Your model instance was updated. Id={}. Url={}'.format( - result.id, result.url)) + result.id, result.url)) else: print('Model update error: ' + result.error) @@ -3471,20 +3529,20 @@ def model_instance_version_create(self, dir_mode: what to do with directories: "skip" - ignore; "zip" - compress and upload """ owner_slug, model_slug, framework, instance_slug = self.split_model_instance_string( - model_instance) + model_instance) request = ModelInstanceNewVersionRequest( - version_notes=version_notes, files=[]) + version_notes=version_notes, files=[]) with ResumableUploadContext() as upload_context: self.upload_files(request, None, folder, ApiBlobType.MODEL, upload_context, quiet, dir_mode) result = ModelNewResponse( - self.process_response( - self.with_retry( - self.models_create_instance_version_with_http_info)( - owner_slug, model_slug, framework, instance_slug, - request))) + self.process_response( + self.with_retry( + self.models_create_instance_version_with_http_info)( + owner_slug, model_slug, framework, instance_slug, + request))) return result @@ -3509,7 +3567,7 @@ def model_instance_version_create_cli(self, if result.hasId: print('Your model instance version was created. 
Url={}'.format( - result.url)) + result.url)) else: print('Model instance version creation error: ' + result.error) @@ -3550,13 +3608,13 @@ def model_instance_version_download(self, effective_path = path response = self.process_response( - self.model_instance_versions_download_with_http_info( - owner_slug=owner_slug, - model_slug=model_slug, - framework=framework, - instance_slug=instance_slug, - version_number=version_number, - _preload_content=False)) + self.model_instance_versions_download_with_http_info( + owner_slug=owner_slug, + model_slug=model_slug, + framework=framework, + instance_slug=instance_slug, + version_number=version_number, + _preload_content=False)) outfile = os.path.join(effective_path, model_slug + '.tar.gz') if force or self.download_needed(response, outfile, quiet): @@ -3572,8 +3630,8 @@ def model_instance_version_download(self, t.extractall(effective_path) except Exception as e: raise ValueError( - 'Error extracting the tar.gz file, please report on ' - 'www.github.com/kaggle/kaggle-api', e) + 'Error extracting the tar.gz file, please report on ' + 'www.github.com/kaggle/kaggle-api', e) try: os.remove(outfile) @@ -3599,11 +3657,11 @@ def model_instance_version_download_cli(self, untar: if True, untar files upon download (default is False) """ return self.model_instance_version_download( - model_instance_version, - path=path, - untar=untar, - force=force, - quiet=quiet) + model_instance_version, + path=path, + untar=untar, + force=force, + quiet=quiet) def model_instance_version_files(self, model_instance_version, @@ -3628,15 +3686,15 @@ def model_instance_version_files(self, [owner_slug, model_slug, framework, instance_slug, version_number] = urls response = self.process_response( - self.model_instance_version_files_with_http_info( - owner_slug=owner_slug, - model_slug=model_slug, - framework=framework, - instance_slug=instance_slug, - version_number=version_number, - page_size=page_size, - page_token=page_token, - _preload_content=True)) + self.model_instance_version_files_with_http_info( + owner_slug=owner_slug, + model_slug=model_slug, + framework=framework, + instance_slug=instance_slug, + version_number=version_number, + page_size=page_size, + page_token=page_token, + _preload_content=True)) if response: next_page_token = response['nextPageToken'] @@ -3662,10 +3720,10 @@ def model_instance_version_files_cli(self, csv_display: if True, print comma separated values instead of table """ result = self.model_instance_version_files( - model_instance_version, - page_token=page_token, - page_size=page_size, - csv_display=csv_display) + model_instance_version, + page_token=page_token, + page_size=page_size, + csv_display=csv_display) if result and result.files is not None: fields = ['name', 'size', 'creationDate'] if csv_display: @@ -3698,10 +3756,10 @@ def model_instance_version_delete(self, model_instance_version, yes): exit(0) res = ModelDeleteResponse( - self.process_response( - self.delete_model_instance_version_with_http_info( - owner_slug, model_slug, framework, instance_slug, - version_number))) + self.process_response( + self.delete_model_instance_version_with_http_info( + owner_slug, model_slug, framework, instance_slug, + version_number))) return res def model_instance_version_delete_cli(self, model_instance_version, yes): @@ -3735,12 +3793,12 @@ def files_upload_cli(self, local_paths, inbox_path, no_resume, no_compress): continue create_inbox_file_request = CreateInboxFileRequest( - virtual_directory=inbox_path, blob_file_token=upload_file.token) + 
virtual_directory=inbox_path, blob_file_token=upload_file.token) files_to_create.append((create_inbox_file_request, file_name)) for (create_inbox_file_request, file_name) in files_to_create: self.process_response( - self.with_retry(self.create_inbox_file)(create_inbox_file_request)) + self.with_retry(self.create_inbox_file)(create_inbox_file_request)) print('Inbox file created:', file_name) def file_upload_cli(self, local_path, inbox_path, no_compress, @@ -3785,9 +3843,9 @@ def download_needed(self, response, outfile, quiet=True): if remote_date <= local_date: if not quiet: print( - os.path.basename(outfile) + - ': Skipping, found more recently modified local ' - 'copy (use --force to force download)') + os.path.basename(outfile) + + ': Skipping, found more recently modified local ' + 'copy (use --force to force download)') return False except: pass @@ -3810,14 +3868,14 @@ def print_table(self, items, fields, labels=None): return for f in fields: length = max( - len(f), - max([ - len(self.string(getattr(i, self.camel_to_snake(f)))) - for i in items - ])) + len(f), + max([ + len(self.string(getattr(i, self.camel_to_snake(f)))) + for i in items + ])) justify = '>' if isinstance( - getattr(items[0], self.camel_to_snake(f)), - int) or f == 'size' or f == 'reward' else '<' + getattr(items[0], self.camel_to_snake(f)), + int) or f == 'size' or f == 'reward' else '<' formats.append('{:' + justify + self.string(length + 2) + '}') borders.append('-' * length + ' ') row_format = u''.join(formats) @@ -3826,7 +3884,7 @@ def print_table(self, items, fields, labels=None): print(row_format.format(*borders)) for i in items: i_fields = [ - self.string(getattr(i, self.camel_to_snake(f))) + ' ' for f in fields + self.string(getattr(i, self.camel_to_snake(f))) + ' ' for f in fields ] try: print(row_format.format(*i_fields)) @@ -3848,7 +3906,7 @@ def print_csv(self, items, fields, labels=None): writer.writerow(labels) for i in items: i_fields = [ - self.string(getattr(i, self.camel_to_snake(f))) for f in fields + self.string(getattr(i, self.camel_to_snake(f))) for f in fields ] writer.writerow(i_fields) @@ -3967,9 +4025,9 @@ def upload_files(self, """ for file_name in os.listdir(folder): if (file_name in [ - self.DATASET_METADATA_FILE, self.OLD_DATASET_METADATA_FILE, - self.KERNEL_METADATA_FILE, self.MODEL_METADATA_FILE, - self.MODEL_INSTANCE_METADATA_FILE + self.DATASET_METADATA_FILE, self.OLD_DATASET_METADATA_FILE, + self.KERNEL_METADATA_FILE, self.MODEL_METADATA_FILE, + self.MODEL_INSTANCE_METADATA_FILE ]): continue upload_file = self._upload_file_or_folder(folder, file_name, blob_type, @@ -4055,8 +4113,8 @@ def process_column(self, column): column: a list of values in a column to be processed """ processed_column = DatasetColumn( - name=self.get_or_fail(column, 'name'), - description=self.get_or_default(column, 'description', '')) + name=self.get_or_fail(column, 'name'), + description=self.get_or_default(column, 'description', '')) if 'type' in column: original_type = column['type'].lower() processed_column.original_type = original_type @@ -4109,10 +4167,10 @@ def upload_complete(self, path, url, quiet, resume=False): if start_at > 0: fp.seek(start_at) session.headers.update({ - 'Content-Length': - '%d' % upload_size, - 'Content-Range': - 'bytes %d-%d/%d' % (start_at, file_size - 1, file_size) + 'Content-Length': + '%d' % upload_size, + 'Content-Range': + 'bytes %d-%d/%d' % (start_at, file_size - 1, file_size) }) reader = TqdmBufferedReader(fp, progress_bar) retries = Retry(total=10, backoff_factor=0.5) @@ 
-4136,8 +4194,8 @@ def _resume_upload(self, url, content_length, quiet): # Documentation: https://developers.google.com/drive/api/guides/manage-uploads#resume-upload session = requests.Session() session.headers.update({ - 'Content-Length': '0', - 'Content-Range': 'bytes */%d' % content_length, + 'Content-Length': '0', + 'Content-Range': 'bytes */%d' % content_length, }) response = session.put(url) @@ -4284,13 +4342,13 @@ def validate_model_instance_version_string(self, model_instance_version): if model_instance_version: if model_instance_version.count('/') != 4: raise ValueError( - 'Model instance version must be specified in the form of ' - '\'{owner}/{model-slug}/{framework}/{instance-slug}/{version-number}\'' + 'Model instance version must be specified in the form of ' + '\'{owner}/{model-slug}/{framework}/{instance-slug}/{version-number}\'' ) split = model_instance_version.split('/') if not split[0] or not split[1] or not split[2] or not split[ - 3] or not split[4]: + 3] or not split[4]: raise ValueError('Invalid model instance version specification ' + model_instance_version) @@ -4298,7 +4356,7 @@ def validate_model_instance_version_string(self, model_instance_version): version_number = int(split[4]) except: raise ValueError( - 'Model instance version\'s version-number must be an integer') + 'Model instance version\'s version-number must be an integer') def validate_kernel_string(self, kernel): """ determine if a kernel string is valid, meaning it is in the format @@ -4330,8 +4388,8 @@ def validate_model_string(self, model): if model: if '/' not in model: raise ValueError( - 'Model must be specified in the form of ' - '\'{username}/{model-slug}/{framework}/{variation-slug}/{version-number}\'' + 'Model must be specified in the form of ' + '\'{username}/{model-slug}/{framework}/{variation-slug}/{version-number}\'' ) split = model.split('/') @@ -4377,7 +4435,7 @@ def validate_no_duplicate_paths(self, resources): file_name = item.get('path') if file_name in paths: raise ValueError( - '%s path was specified more than once in the metadata' % file_name) + '%s path was specified more than once in the metadata' % file_name) paths.add(file_name) def convert_to_dataset_file_metadata(self, file_data, path): @@ -4389,17 +4447,17 @@ def convert_to_dataset_file_metadata(self, file_data, path): path: the path to write the metadata to """ as_metadata = { - 'path': os.path.join(path, file_data['name']), - 'description': file_data['description'] + 'path': os.path.join(path, file_data['name']), + 'description': file_data['description'] } schema = {} fields = [] for column in file_data['columns']: field = { - 'name': column['name'], - 'title': column['description'], - 'type': column['type'] + 'name': column['name'], + 'title': column['description'], + 'type': column['type'] } fields.append(field) schema['fields'] = fields diff --git a/src/kaggle/models/kaggle_models_extended.py b/src/kaggle/models/kaggle_models_extended.py index 19fc3a0..d4edae7 100644 --- a/src/kaggle/models/kaggle_models_extended.py +++ b/src/kaggle/models/kaggle_models_extended.py @@ -115,9 +115,14 @@ def __repr__(self): class File(object): def __init__(self, init_dict): - parsed_dict = {k: parse(v) for k, v in init_dict.items()} - self.__dict__.update(parsed_dict) - self.size = File.get_size(self.totalBytes) + try: # TODO Remove try-block + parsed_dict = {k: parse(v) for k, v in init_dict.items()} + self.__dict__.update(parsed_dict) + self.size = File.get_size(self.totalBytes) + except AttributeError: + self.name = init_dict.name + 
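      # kagglesdk response objects expose snake_case attributes
      # (name, creation_date, total_bytes) rather than the camelCase
      # dict keys parsed in the try-branch above.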
self.creation_date = init_dict.creation_date + self.size = File.get_size(init_dict.total_bytes) def __repr__(self): return self.name @@ -165,13 +170,18 @@ def __repr__(self): class ListFilesResult(object): def __init__(self, init_dict): - self.error_message = init_dict['errorMessage'] - files = init_dict['datasetFiles'] + try: # TODO Remove try-block + self.error_message = init_dict['errorMessage'] + files = init_dict['datasetFiles'] + token = init_dict['nextPageToken'] + except TypeError: + self.error_message = init_dict.error_message + files = init_dict.dataset_files + token = init_dict.next_page_token if files: self.files = [File(f) for f in files] else: self.files = {} - token = init_dict['nextPageToken'] if token: self.nextPageToken = token else: @@ -223,8 +233,8 @@ def __repr__(self): def parse(string): time_formats = [ - '%Y-%m-%dT%H:%M:%S', '%Y-%m-%dT%H:%M:%SZ', '%Y-%m-%dT%H:%M:%S.%f', - '%Y-%m-%dT%H:%M:%S.%fZ' + '%Y-%m-%dT%H:%M:%S', '%Y-%m-%dT%H:%M:%SZ', '%Y-%m-%dT%H:%M:%S.%f', + '%Y-%m-%dT%H:%M:%S.%fZ' ] for t in time_formats: try: diff --git a/tests/unit_tests.py b/tests/unit_tests.py index 9fbc3e9..ff8314f 100644 --- a/tests/unit_tests.py +++ b/tests/unit_tests.py @@ -8,14 +8,14 @@ from requests import HTTPError from kaggle.rest import ApiException +from kagglesdk.datasets.types.dataset_api_service import ApiDownloadDatasetRequest -sys.path.insert(0,'..') +sys.path.insert(0, '..') -sys.path.insert(0,'..') +sys.path.insert(0, '..') from kaggle import api - # Unit test names include a letter to sort them in run order. # That seemed easier and more obvious than defining a test suite. @@ -53,609 +53,657 @@ # Max retries to get kernel status max_status_tries = 10 + def tearDownModule(): - file = os.path.join(dataset_directory, api.DATASET_METADATA_FILE) - if os.path.exists(file): - os.remove(file) - file = os.path.join(kernel_directory, api.KERNEL_METADATA_FILE) - if os.path.exists(file): - os.remove(file) - file = os.path.join(model_directory, api.MODEL_METADATA_FILE) - if os.path.exists(file): - os.remove(file) - file = os.path.join(model_inst_directory, api.MODEL_INSTANCE_METADATA_FILE) - if os.path.exists(file): - os.remove(file) + file = os.path.join(dataset_directory, api.DATASET_METADATA_FILE) + if os.path.exists(file): + os.remove(file) + file = os.path.join(kernel_directory, api.KERNEL_METADATA_FILE) + if os.path.exists(file): + os.remove(file) + file = os.path.join(model_directory, api.MODEL_METADATA_FILE) + if os.path.exists(file): + os.remove(file) + file = os.path.join(model_inst_directory, api.MODEL_INSTANCE_METADATA_FILE) + if os.path.exists(file): + os.remove(file) def update_kernel_metadata_file(metadata_file, k_name): - with open(metadata_file) as f: - meta_data = json.load(f) - meta_id = meta_data['id'] - if 'INSERT_KERNEL_SLUG_HERE' in meta_id: - meta_id = meta_id.replace('INSERT_KERNEL_SLUG_HERE', k_name) - meta_title = meta_data['title'] - if 'INSERT_TITLE_HERE' == meta_title: - meta_title = k_name - meta_path = meta_data['code_file'] - if 'INSERT_CODE_FILE_PATH_HERE' == meta_path: - meta_path = f'{k_name}.ipynb' - meta_data['id'] = meta_id - meta_data['title'] = meta_title - meta_data['code_file'] = meta_path - meta_data['language'] = 'python' - meta_data['kernel_type'] ='notebook' - with open(metadata_file, 'w') as f: - json.dump(meta_data, f, indent=2) - return meta_data + with open(metadata_file) as f: + meta_data = json.load(f) + meta_id = meta_data['id'] + if 'INSERT_KERNEL_SLUG_HERE' in meta_id: + meta_id = meta_id.replace('INSERT_KERNEL_SLUG_HERE', 
k_name) + meta_title = meta_data['title'] + if 'INSERT_TITLE_HERE' == meta_title: + meta_title = k_name + meta_path = meta_data['code_file'] + if 'INSERT_CODE_FILE_PATH_HERE' == meta_path: + meta_path = f'{k_name}.ipynb' + meta_data['id'] = meta_id + meta_data['title'] = meta_title + meta_data['code_file'] = meta_path + meta_data['language'] = 'python' + meta_data['kernel_type'] = 'notebook' + with open(metadata_file, 'w') as f: + json.dump(meta_data, f, indent=2) + return meta_data + def initialize_dataset_metadata_file(dataset_dir): - metadata_file = os.path.join(dataset_dir, api.DATASET_METADATA_FILE) - try: - with open(metadata_file) as f: - original = json.load(f) - if 'versionNumber' in original: - version_num = int(original['versionNumber']) - else: - version_num = 0 - except FileNotFoundError: + metadata_file = os.path.join(dataset_dir, api.DATASET_METADATA_FILE) + try: + with open(metadata_file) as f: + original = json.load(f) + if 'versionNumber' in original: + version_num = int(original['versionNumber']) + else: version_num = 0 - version_num += 1 - return version_num, metadata_file + except FileNotFoundError: + version_num = 0 + version_num += 1 + return version_num, metadata_file + def update_dataset_metadata_file(metadata_file, data_name, version_num): - with open(metadata_file) as f: - meta_data = json.load(f) - meta_id = meta_data['id'] - if 'INSERT_SLUG_HERE' in meta_id: - meta_id = meta_id.replace('INSERT_SLUG_HERE', data_name) - meta_title = meta_data['title'] - if 'INSERT_TITLE_HERE' == meta_title: - meta_title = data_name - meta_data['id'] = meta_id - meta_data['title'] = meta_title - meta_data['versionNumber'] = version_num - if not 'resources' in meta_data: - resource_list = [ - { - "path": "data.csv", - "description": "Description", - "schema": { - "fields": [ - { - "name": "NumberField", - "description": "id", - "type": "number" - }, - { - "name": "StringField", - "description": "label", - "type": "string" - } - ] - } - } - ] - meta_data.update({'resources': resource_list}) - with open(metadata_file, 'w') as f: - json.dump(meta_data, f) + with open(metadata_file) as f: + meta_data = json.load(f) + meta_id = meta_data['id'] + if 'INSERT_SLUG_HERE' in meta_id: + meta_id = meta_id.replace('INSERT_SLUG_HERE', data_name) + meta_title = meta_data['title'] + if 'INSERT_TITLE_HERE' == meta_title: + meta_title = data_name + meta_data['id'] = meta_id + meta_data['title'] = meta_title + meta_data['versionNumber'] = version_num + if not 'resources' in meta_data: + resource_list = [{ + "path": "data.csv", + "description": "Description", + "schema": { + "fields": [{ + "name": "NumberField", + "description": "id", + "type": "number" + }, { + "name": "StringField", + "description": "label", + "type": "string" + }] + } + }] + meta_data.update({'resources': resource_list}) + with open(metadata_file, 'w') as f: + json.dump(meta_data, f) + def update_model_metadata(metadata_file, owner, title, slug): - with open(metadata_file) as f: - meta_data = json.load(f) - meta_id = meta_data['ownerSlug'] - if 'INSERT_OWNER_SLUG_HERE' == meta_id: - meta_id = owner - meta_title = meta_data['title'] - if 'INSERT_TITLE_HERE' == meta_title: - meta_title = title - meta_path = meta_data['slug'] - if 'INSERT_SLUG_HERE' == meta_path: - meta_path = slug - meta_data['ownerSlug'] = meta_id - meta_data['title'] = meta_title - meta_data['slug'] = meta_path - with open(metadata_file, 'w') as f: - json.dump(meta_data, f, indent=2) - return meta_data - -def update_model_instance_metadata(metadata_file, owner, 
model_slug, instance_slug, framework): - with open(metadata_file) as f: - meta_data = json.load(f) - meta_owner = meta_data['ownerSlug'] - if 'INSERT_OWNER_SLUG_HERE' == meta_owner: - meta_owner = owner - meta_framework = meta_data['framework'] - if 'INSERT_FRAMEWORK_HERE' == meta_framework: - meta_framework = framework - meta_instance = meta_data['instanceSlug'] - if 'INSERT_INSTANCE_SLUG_HERE' == meta_instance: - meta_instance = instance_slug - meta_model = meta_data['modelSlug'] - if 'INSERT_EXISTING_MODEL_SLUG_HERE' == meta_model: - meta_model = model_slug - meta_data['ownerSlug'] = meta_owner - meta_data['modelSlug'] = meta_model - meta_data['framework'] = meta_framework - meta_data['instanceSlug'] = meta_instance - with open(metadata_file, 'w') as f: - json.dump(meta_data, f, indent=2) - return meta_data - -def print_fields(instance,fields): # For debugging. - for f in fields: - if not hasattr(instance, api.camel_to_snake(f)): - print(f"Missing field: {f} named: {api.camel_to_snake(f)}") + with open(metadata_file) as f: + meta_data = json.load(f) + meta_id = meta_data['ownerSlug'] + if 'INSERT_OWNER_SLUG_HERE' == meta_id: + meta_id = owner + meta_title = meta_data['title'] + if 'INSERT_TITLE_HERE' == meta_title: + meta_title = title + meta_path = meta_data['slug'] + if 'INSERT_SLUG_HERE' == meta_path: + meta_path = slug + meta_data['ownerSlug'] = meta_id + meta_data['title'] = meta_title + meta_data['slug'] = meta_path + with open(metadata_file, 'w') as f: + json.dump(meta_data, f, indent=2) + return meta_data + + +def update_model_instance_metadata(metadata_file, owner, model_slug, + instance_slug, framework): + with open(metadata_file) as f: + meta_data = json.load(f) + meta_owner = meta_data['ownerSlug'] + if 'INSERT_OWNER_SLUG_HERE' == meta_owner: + meta_owner = owner + meta_framework = meta_data['framework'] + if 'INSERT_FRAMEWORK_HERE' == meta_framework: + meta_framework = framework + meta_instance = meta_data['instanceSlug'] + if 'INSERT_INSTANCE_SLUG_HERE' == meta_instance: + meta_instance = instance_slug + meta_model = meta_data['modelSlug'] + if 'INSERT_EXISTING_MODEL_SLUG_HERE' == meta_model: + meta_model = model_slug + meta_data['ownerSlug'] = meta_owner + meta_data['modelSlug'] = meta_model + meta_data['framework'] = meta_framework + meta_data['instanceSlug'] = meta_instance + with open(metadata_file, 'w') as f: + json.dump(meta_data, f, indent=2) + return meta_data + + +def print_fields(instance, fields): # For debugging. + for f in fields: + if not hasattr(instance, api.camel_to_snake(f)): + print(f"Missing field: {f} named: {api.camel_to_snake(f)}") + class TestKaggleApi(unittest.TestCase): - version_number, meta_file = initialize_dataset_metadata_file(dataset_directory) - - # Initialized from Response objects. 
- competition_file = None - kernel_slug = '' - kernel_metadata_path = '' - dataset = '' - dataset_file = None - model_instance = '' - model_meta_data = None - model_metadata_file = '' - instance_metadata_file = '' - - # Kernels - - def test_kernels_a_list(self): - try: - kernels = api.kernels_list() - self.assertGreater(len(kernels), 0) # Assuming there should be some kernels - except ApiException as e: - self.fail(f"kernels_list failed: {e}") - - def test_kernels_b_initialize(self): - try: - self.kernel_metadata_path = api.kernels_initialize(kernel_directory) - self.assertTrue(os.path.exists(self.kernel_metadata_path)) - except ApiException as e: - self.fail(f"kernels_initialize failed: {e}") - - def test_kernels_c_push(self): - if self.kernel_metadata_path == '': - self.test_kernels_b_initialize() - try: - md = update_kernel_metadata_file(self.kernel_metadata_path, kernel_name) - push_result = api.kernels_push(kernel_directory) - self.assertIsNotNone(push_result.ref) - self.assertIsNotNone(push_result.versionNumber) - self.kernel_slug = md['id'] - except ApiException as e: - self.fail(f"kernels_push failed: {e}") - - def test_kernels_d_status(self): - if self.kernel_slug == '': - self.test_kernels_c_push() - try: - status_result = api.kernels_status(self.kernel_slug) - start_time = time.time() - # If this loop is stuck because the kernel stays queued, go to the Kaggle website - # on localhost and cancel the active event. That will exit the loop, but you may - # need to clean up other active kernels to get it to run again. - count = 0 - while status_result['status'] == 'running' or status_result['status'] == 'queued' or count >= max_status_tries: - time.sleep(5) - status_result = api.kernels_status(self.kernel_slug) - print(status_result['status']) - end_time = time.time() - print(f'kernels_status ready in {end_time-start_time}s') - except ApiException as e: - self.fail(f"kernels_status failed: {e}") - - def test_kernels_e_list_files(self): - if self.kernel_slug == '': - self.test_kernels_c_push() - try: - fs = api.kernels_list_files(self.kernel_slug) - self.assertGreaterEqual(len(fs.files), 0) # Adjust expectation if needed - except ApiException as e: - self.fail(f"kernels_list_files failed: {e}") - - def test_kernels_f_output(self): - fs = [] - if self.kernel_slug == '': - self.test_kernels_c_push() - try: - fs = api.kernels_output(self.kernel_slug, 'kernel/tmp') - self.assertIsInstance(fs, list) # Assuming it returns a list of files, but may be empty - except ApiException as e: - self.fail(f"kernels_output failed: {e}") - finally: - for file in fs: - if os.path.exists(file): - os.remove(file) - if os.path.exists('kernel/tmp'): - os.rmdir('kernel/tmp') - - def test_kernels_g_pull(self): - if self.kernel_metadata_path == '': - self.test_kernels_b_initialize() - fs = '' - try: - fs = api.kernels_pull(f'{test_user}/testing', 'kernel/tmp', metadata=True) - self.assertTrue(os.path.exists(fs)) - except ApiException as e: - self.fail(f"kernels_pull failed: {e}") - finally: - for file in [f'{fs}/{self.kernel_metadata_path.split("/")[1]}', f'{fs}/{kernel_name}.ipynb']: - if os.path.exists(file): - os.remove(file) - if os.path.exists(fs): - os.rmdir(fs) - - # Competitions - - def test_competition_a_list(self): - try: - competitions = api.competitions_list() - self.assertGreater(len(competitions), 0) # Assuming there should be some competitions - [self.assertTrue(hasattr(competitions[0], api.camel_to_snake(f))) for f in api.competition_fields] - except ApiException as e: - 
self.fail(f"competitions_list failed: {e}") - - def test_competition_b_submit(self): - try: - api.competition_submit(up_file, description, competition) - except HTTPError: - # Handle submission limit reached gracefully (potentially skip the test) - print('Competition submission limit reached for the day') - pass - except ApiException as e: - self.fail(f"competition_submit failed: {e}") - - def test_competition_c_submissions(self): - try: - submissions = api.competition_submissions(competition) - self.assertIsInstance(submissions, list) # Assuming it returns a list of submissions - self.assertGreater(len(submissions), 0) - [self.assertTrue(hasattr(submissions[0], api.camel_to_snake(f))) for f in api.submission_fields] - except ApiException as e: - self.fail(f"competition_submissions failed: {e}") - - def test_competition_d_list_files(self): - try: - competition_files = api.competition_list_files(competition).files - self.assertIsInstance(competition_files, list) - self.assertGreater(len(competition_files), 0) - self.competition_file = competition_files[0] - [self.assertTrue(hasattr(competition_files[0], api.camel_to_snake(f))) for f in api.competition_file_fields] - except ApiException as e: - self.fail(f"competition_list_files failed: {e}") - - def test_competition_e_download_file(self): - if self.competition_file is None: - self.test_competition_d_list_files() - try: - api.competition_download_file(competition, self.competition_file.ref, force=True) - self.assertTrue(os.path.exists(self.competition_file.ref)) - except ApiException as e: - self.fail(f"competition_download_file failed: {e}") - finally: - if os.path.exists(self.competition_file.ref): - os.remove(self.competition_file.ref) - - def test_competition_f_download_files(self): - try: - api.competition_download_files(competition) - self.assertTrue(os.path.exists(f'{competition}.zip')) - self.assertTrue(os.path.getsize(f'{competition}.zip') > 0) - except ApiException as e: - self.fail(f"competition_download_files failed: {e}") - finally: - if os.path.exists(f'{competition}.zip'): - os.remove(f'{competition}.zip') - - def test_competition_g_leaderboard_view(self): - try: - result = api.competition_leaderboard_view(competition) - self.assertIsInstance(result, list) - self.assertGreater(len(result), 0) - [self.assertTrue(hasattr(result[0], api.camel_to_snake(f))) for f in api.competition_leaderboard_fields] - except ApiException as e: - self.fail(f"competition_leaderboard_view failed: {e}") - - def test_competition_h_leaderboard_download(self): - try: - api.competition_leaderboard_download(competition, 'tmp') - self.assertTrue(os.path.exists(f'tmp/{competition}.zip')) - except ApiException as e: - self.fail(f"competition_leaderboard_download failed: {e}") - finally: - if os.path.exists(f'tmp/{competition}.zip'): - os.remove(f'tmp/{competition}.zip') - if os.path.exists('tmp'): - os.rmdir('tmp') - - # Datasets - - def test_dataset_a_list(self): - try: - datasets = api.dataset_list(sort_by='votes') - self.assertGreater(len(datasets), 0) # Assuming there should be some datasets - self.dataset = str(datasets[0].ref) - [self.assertTrue(hasattr(datasets[0], api.camel_to_snake(f))) for f in api.dataset_fields] - except ApiException as e: - self.fail(f"dataset_list failed: {e}") - - def test_dataset_b_metadata(self): - if self.dataset == '': - self.test_dataset_a_list() - m = '' - try: - m = api.dataset_metadata(self.dataset, dataset_directory) - self.assertTrue(os.path.exists(m)) - except ApiException as e: - self.fail(f"dataset_metadata 
failed: {e}") - - def test_dataset_c_metadata_update(self): - if self.dataset == '': - self.test_dataset_a_list() - if not os.path.exists(os.path.join(dataset_directory, api.DATASET_METADATA_FILE)): - self.test_dataset_b_metadata() - try: - api.dataset_metadata_update(self.dataset, dataset_directory) - # TODO Make the API method return something, and not exit when it fails. - except ApiException as e: - self.fail(f"dataset_metadata_update failed: {e}") - - def test_dataset_d_list_files(self): - if self.dataset == '': - self.test_dataset_a_list() - try: - dataset_files = api.dataset_list_files(self.dataset) - self.assertIsInstance(dataset_files.files, list) - self.assertGreater(len(dataset_files.files), 0) - self.dataset_file = dataset_files.files[0] - except ApiException as e: - self.fail(f"dataset_list_files failed: {e}") - - def test_dataset_e_status(self): - if self.dataset == '': - self.test_dataset_a_list() - try: - status = api.dataset_status(self.dataset) - self.assertIn(status, ['ready', 'pending', 'error']) - except ApiException as e: - self.fail(f"dataset_status failed: {e}") - - def test_dataset_f_download_file(self): - if self.dataset_file is None: - self.test_dataset_d_list_files() - try: - api.dataset_download_file(self.dataset, self.dataset_file.name, 'tmp') - self.assertTrue(os.path.exists(f'tmp/{self.dataset_file.name}')) - except ApiException as e: - self.fail(f"dataset_download_file failed: {e}") - finally: - if os.path.exists(f'tmp/{self.dataset_file.name}'): - os.remove(f'tmp/{self.dataset_file.name}') - if os.path.exists('tmp'): - os.rmdir('tmp') - - def test_dataset_g_download_files(self): - if self.dataset == '': - self.test_dataset_a_list() - ds = ['a', 'b'] - try: - api.dataset_download_files(self.dataset) - ds = self.dataset.split('/') - self.assertTrue(os.path.exists(f'{ds[1]}.zip')) - except ApiException as e: - self.fail(f"dataset_download_files failed: {e}") - finally: - if os.path.exists(f'{ds[1]}.zip'): - os.remove(f'{ds[1]}.zip') - - def test_dataset_h_initialize(self): - try: - api.dataset_initialize('dataset') - self.assertTrue(os.path.exists(os.path.join(dataset_directory, api.DATASET_METADATA_FILE))) - except ApiException as e: - self.fail(f"dataset_initialize failed: {e}") - - def test_dataset_i_create_new(self): - if not os.path.exists(os.path.join(dataset_directory, api.DATASET_METADATA_FILE)): - self.test_dataset_h_initialize() - try: - update_dataset_metadata_file(self.meta_file, dataset_name, self.version_number) - new_dataset = api.dataset_create_new(dataset_directory) - self.assertIsNotNone(new_dataset) - if new_dataset.hasError: - print(new_dataset.error) # This is likely to happen, and that's OK. 
- except ApiException as e: - self.fail(f"dataset_create_new failed: {e}") - - def test_dataset_j_create_version(self): - try: - new_version = api.dataset_create_version(dataset_directory, "Notes") - self.assertIsNotNone(new_version) - self.assertFalse(new_version.hasError) - self.assertTrue(new_version.hasRef) - except ApiException as e: - self.fail(f"dataset_create_version failed: {e}") - - # Models - - def test_model_a_list(self): - try: - ms = api.model_list() - self.assertIsInstance(ms, list) - self.assertGreater(len(ms), 0) - except ApiException as e: - self.fail(f"models_list failed: {e}") - - def test_model_b_initialize(self): - try: - self.model_metadata_file = api.model_initialize(model_directory) - self.assertTrue(os.path.exists(self.model_metadata_file)) - self.model_meta_data = update_model_metadata(self.model_metadata_file, test_user, model_title, model_title) - self.model_instance = f'{test_user}/{self.model_meta_data["slug"]}/{framework_name}/{instance_name}' - except ApiException as e: - self.fail(f"model_initialize failed: {e}") - - def test_model_c_create_new(self): - if self.model_metadata_file == '': - self.test_model_b_initialize() - try: - model = api.model_create_new(model_directory) - if model.hasError: - self.fail(model.error) - else: - self.assertIsNotNone(model.ref) - self.assertGreater(len(model.ref), 0) - except ApiException as e: - self.fail(f"model_create_new failed: {e}") - - def test_model_d_get(self): - try: - model_data = api.model_get(f'{test_user}/{model_title}') - self.assertIsNotNone(model_data['ref']) - self.assertGreater(len(model_data['ref']), 0) - self.assertEquals(model_data['title'], model_title) - except ApiException as e: - self.fail(f"model_get failed: {e}") - - def test_model_e_update(self): - try: - update_response = api.model_update(model_directory) - self.assertIsNotNone(update_response.ref) - self.assertGreater(len(update_response.ref), 0) - except ApiException as e: - self.fail(f"model_update failed: {e}") - - # Model instances - - def test_model_instance_a_initialize(self): - try: - self.instance_metadata_file = api.model_instance_initialize(model_inst_directory) - self.assertTrue(os.path.exists(self.instance_metadata_file)) - except ApiException as e: - self.fail(f"model_instance_initialize failed: {e}") - - def test_model_instance_b_create(self): - if self.model_meta_data is None: - self.test_model_b_initialize() - if self.instance_metadata_file == '': - self.test_model_instance_a_initialize() - try: - update_model_instance_metadata( - self.instance_metadata_file, test_user, self.model_meta_data['slug'], instance_name, framework_name) - inst_create_resp = api.model_instance_create(model_inst_directory) - self.assertIsNotNone(inst_create_resp.ref) - self.assertGreater(len(inst_create_resp.ref), 0) - except ApiException as e: - self.fail(f"model_instance_create failed: {e}") - - def test_model_instance_b_wait_after_create(self): - # When running all tests sequentially, give the new model some time to stabilize. - time.sleep(10) # TODO: Find a better way to detect model stability. 
- - def test_model_instance_c_get(self): - if self.model_instance == '': - self.test_model_b_initialize() - try: - inst_get_resp = api.model_instance_get(self.model_instance) - self.assertIsNotNone(inst_get_resp['url']) - self.assertGreater(len(inst_get_resp['url']), 0) - except ApiException as e: - self.fail(f"model_instance_get failed: {e}") - - def test_model_instance_d_files(self): - if self.model_instance == '': - self.test_model_b_initialize() - try: - inst_files_resp = api.model_instance_files(self.model_instance) - self.assertIsInstance(inst_files_resp.files, list) - self.assertGreater(len(inst_files_resp.files), 0) - except ApiException as e: - self.fail(f"model_instance_files failed: {e}") - - def test_model_instance_e_update(self): - if self.model_instance == '': - self.test_model_b_initialize() - try: - inst_update_resp = api.model_instance_update(model_inst_directory) - self.assertIsNotNone(inst_update_resp) - self.assertIsNotNone(inst_update_resp.ref) - self.assertGreater(len(inst_update_resp.ref), 0) - except ApiException as e: - self.fail(f"model_instance_update failed: {e}") - - # Model instance versions - - def test_model_instance_version_a_create(self): - if self.model_instance == '': - self.test_model_b_initialize() - try: - version_metadata_resp = api.model_instance_version_create(self.model_instance, model_inst_vers_directory) - self.assertIsNotNone(version_metadata_resp.ref) - except ApiException as e: - self.fail(f"model_instance_version_create failed: {e}") - - def test_model_instance_version_b_files(self): - if self.model_instance == '': - self.test_model_b_initialize() - try: - r = api.model_instance_version_files(f'{self.model_instance}/1') - self.assertIsInstance(r.files, list) - self.assertGreater(len(r.files), 0) - except ApiException as e: - self.fail(f"model_instance_version_files failed: {e}") - - def test_model_instance_version_c_download(self): - if self.model_instance == '': - self.test_model_b_initialize() - version_file = '' - try: - version_file = api.model_instance_version_download(f'{self.model_instance}/1', 'tmp') - self.assertTrue(os.path.exists(version_file)) - except KeyError: - pass # TODO Create a version that has content. - except ApiException as e: - self.fail(f"model_instance_version_download failed: {e}") - finally: - if os.path.exists(version_file): - os.remove(version_file) - if os.path.exists('tmp'): - os.rmdir('tmp') - - # Model deletion - - def test_model_instance_version_d_delete(self): - if self.model_instance == '': - self.test_model_b_initialize() - try: - version_delete_resp = api.model_instance_version_delete(f'{self.model_instance}/1', True) - self.assertFalse(version_delete_resp.hasError) - except ApiException as e: - self.fail(f"model_instance_version_delete failed: {e}") - - def test_model_instance_x_delete(self): - if self.model_instance == '': - self.test_model_b_initialize() - try: - inst_update_resp = api.model_instance_delete(self.model_instance, True) - self.assertIsNotNone(inst_update_resp) - except ApiException as e: - self.fail(f"model_instance_delete failed: {e}") - - def test_model_z_delete(self): - try: - delete_response = api.model_delete(f'{test_user}/{model_title}', True) - if delete_response.hasError: - self.fail(delete_response.error) - else: - pass - except ApiException as e: - self.fail(f"model_delete failed: {e}") + version_number, meta_file = initialize_dataset_metadata_file( + dataset_directory) + + # Initialized from Response objects. 
+  competition_file = None
+  kernel_slug = ''
+  kernel_metadata_path = ''
+  dataset = ''
+  dataset_file = None
+  model_instance = ''
+  model_meta_data = None
+  model_metadata_file = ''
+  instance_metadata_file = ''
+
+  # Kernels
+
+  def test_kernels_a_list(self):
+    try:
+      kernels = api.kernels_list()
+      self.assertGreater(len(kernels),
+                         0)  # Assuming there should be some kernels
+    except ApiException as e:
+      self.fail(f"kernels_list failed: {e}")
+
+  def test_kernels_b_initialize(self):
+    try:
+      self.kernel_metadata_path = api.kernels_initialize(kernel_directory)
+      self.assertTrue(os.path.exists(self.kernel_metadata_path))
+    except ApiException as e:
+      self.fail(f"kernels_initialize failed: {e}")
+
+  def test_kernels_c_push(self):
+    if self.kernel_metadata_path == '':
+      self.test_kernels_b_initialize()
+    try:
+      md = update_kernel_metadata_file(self.kernel_metadata_path, kernel_name)
+      push_result = api.kernels_push(kernel_directory)
+      self.assertIsNotNone(push_result.ref)
+      self.assertIsNotNone(push_result.versionNumber)
+      self.kernel_slug = md['id']
+    except ApiException as e:
+      self.fail(f"kernels_push failed: {e}")
+
+  def test_kernels_d_status(self):
+    if self.kernel_slug == '':
+      self.test_kernels_c_push()
+    try:
+      status_result = api.kernels_status(self.kernel_slug)
+      start_time = time.time()
+      # If this loop is stuck because the kernel stays queued, go to the Kaggle website
+      # on localhost and cancel the active event. That will exit the loop, but you may
+      # need to clean up other active kernels to get it to run again.
+      # Give up after max_status_tries polls.
+      count = 0
+      while (status_result['status'] in ('running', 'queued') and
+             count < max_status_tries):
+        time.sleep(5)
+        count += 1
+        status_result = api.kernels_status(self.kernel_slug)
+        print(status_result['status'])
+      end_time = time.time()
+      print(f'kernels_status ready in {end_time-start_time}s')
+    except ApiException as e:
+      self.fail(f"kernels_status failed: {e}")
+
+  def test_kernels_e_list_files(self):
+    if self.kernel_slug == '':
+      self.test_kernels_c_push()
+    try:
+      fs = api.kernels_list_files(self.kernel_slug)
+      self.assertGreaterEqual(len(fs.files), 0)  # Adjust expectation if needed
+    except ApiException as e:
+      self.fail(f"kernels_list_files failed: {e}")
+
+  def test_kernels_f_output(self):
+    fs = []
+    if self.kernel_slug == '':
+      self.test_kernels_c_push()
+    try:
+      fs = api.kernels_output(self.kernel_slug, 'kernel/tmp')
+      self.assertIsInstance(
+          fs, list)  # Assuming it returns a list of files, but may be empty
+    except ApiException as e:
+      self.fail(f"kernels_output failed: {e}")
+    finally:
+      for file in fs:
+        if os.path.exists(file):
+          os.remove(file)
+      if os.path.exists('kernel/tmp'):
+        os.rmdir('kernel/tmp')
+
+  def test_kernels_g_pull(self):
+    if self.kernel_metadata_path == '':
+      self.test_kernels_b_initialize()
+    fs = ''
+    try:
+      fs = api.kernels_pull(f'{test_user}/testing', 'kernel/tmp', metadata=True)
+      self.assertTrue(os.path.exists(fs))
+    except ApiException as e:
+      self.fail(f"kernels_pull failed: {e}")
+    finally:
+      for file in [
+          f'{fs}/{self.kernel_metadata_path.split("/")[1]}',
+          f'{fs}/{kernel_name}.ipynb'
+      ]:
+        if os.path.exists(file):
+          os.remove(file)
+      if os.path.exists(fs):
+        os.rmdir(fs)
+
+  # Competitions
+
+  def test_competition_a_list(self):
+    try:
+      competitions = api.competitions_list()
+      self.assertGreater(len(competitions),
+                         0)  # Assuming there should be some competitions
+      [
+          self.assertTrue(hasattr(competitions[0], api.camel_to_snake(f)))
+          for f in api.competition_fields
+      ]
+    except ApiException as e:
+ self.fail(f"competitions_list failed: {e}") + + def test_competition_b_submit(self): + try: + api.competition_submit(up_file, description, competition) + except HTTPError: + # Handle submission limit reached gracefully (potentially skip the test) + print('Competition submission limit reached for the day') + pass + except ApiException as e: + self.fail(f"competition_submit failed: {e}") + + def test_competition_c_submissions(self): + try: + submissions = api.competition_submissions(competition) + self.assertIsInstance(submissions, + list) # Assuming it returns a list of submissions + self.assertGreater(len(submissions), 0) + [ + self.assertTrue(hasattr(submissions[0], api.camel_to_snake(f))) + for f in api.submission_fields + ] + except ApiException as e: + self.fail(f"competition_submissions failed: {e}") + + def test_competition_d_list_files(self): + try: + competition_files = api.competition_list_files(competition).files + self.assertIsInstance(competition_files, list) + self.assertGreater(len(competition_files), 0) + self.competition_file = competition_files[0] + [ + self.assertTrue(hasattr(competition_files[0], api.camel_to_snake(f))) + for f in api.competition_file_fields + ] + except ApiException as e: + self.fail(f"competition_list_files failed: {e}") + + def test_competition_e_download_file(self): + if self.competition_file is None: + self.test_competition_d_list_files() + try: + api.competition_download_file( + competition, self.competition_file.ref, force=True) + self.assertTrue(os.path.exists(self.competition_file.ref)) + except ApiException as e: + self.fail(f"competition_download_file failed: {e}") + finally: + if os.path.exists(self.competition_file.ref): + os.remove(self.competition_file.ref) + + def test_competition_f_download_files(self): + try: + api.competition_download_files(competition) + self.assertTrue(os.path.exists(f'{competition}.zip')) + self.assertTrue(os.path.getsize(f'{competition}.zip') > 0) + except ApiException as e: + self.fail(f"competition_download_files failed: {e}") + finally: + if os.path.exists(f'{competition}.zip'): + os.remove(f'{competition}.zip') + + def test_competition_g_leaderboard_view(self): + try: + result = api.competition_leaderboard_view(competition) + self.assertIsInstance(result, list) + self.assertGreater(len(result), 0) + [ + self.assertTrue(hasattr(result[0], api.camel_to_snake(f))) + for f in api.competition_leaderboard_fields + ] + except ApiException as e: + self.fail(f"competition_leaderboard_view failed: {e}") + + def test_competition_h_leaderboard_download(self): + try: + api.competition_leaderboard_download(competition, 'tmp') + self.assertTrue(os.path.exists(f'tmp/{competition}.zip')) + except ApiException as e: + self.fail(f"competition_leaderboard_download failed: {e}") + finally: + if os.path.exists(f'tmp/{competition}.zip'): + os.remove(f'tmp/{competition}.zip') + if os.path.exists('tmp'): + os.rmdir('tmp') + + # Datasets + + def test_dataset_a_list(self): + try: + datasets = api.dataset_list(sort_by='votes') + self.assertGreater(len(datasets), + 0) # Assuming there should be some datasets + self.dataset = str(datasets[0].ref) + [ + self.assertTrue(hasattr(datasets[0], api.camel_to_snake(f))) + for f in api.dataset_fields + ] + except ApiException as e: + self.fail(f"dataset_list failed: {e}") + + def test_dataset_b_metadata(self): + if self.dataset == '': + self.test_dataset_a_list() + m = '' + try: + m = api.dataset_metadata(self.dataset, dataset_directory) + self.assertTrue(os.path.exists(m)) + except ApiException 
as e: + self.fail(f"dataset_metadata failed: {e}") + + def test_dataset_c_metadata_update(self): + if self.dataset == '': + self.test_dataset_a_list() + if not os.path.exists( + os.path.join(dataset_directory, api.DATASET_METADATA_FILE)): + self.test_dataset_b_metadata() + try: + api.dataset_metadata_update(self.dataset, dataset_directory) + # TODO Make the API method return something, and not exit when it fails. + except ApiException as e: + self.fail(f"dataset_metadata_update failed: {e}") + + def test_dataset_d_list_files(self): + if self.dataset == '': + self.test_dataset_a_list() + try: + dataset_files = api.dataset_list_files(self.dataset) + self.assertIsInstance(dataset_files.files, list) + self.assertGreater(len(dataset_files.files), 0) + self.dataset_file = dataset_files.files[0] + [ + self.assertTrue(hasattr(self.dataset_file, api.camel_to_snake(f))) + for f in api.dataset_file_fields + ] + except ApiException as e: + self.fail(f"dataset_list_files failed: {e}") + + def test_dataset_e_status(self): + if self.dataset == '': + self.test_dataset_a_list() + try: + status = api.dataset_status(self.dataset) + self.assertIn(status, ['ready', 'pending', 'error']) + except ApiException as e: + self.fail(f"dataset_status failed: {e}") + + def test_dataset_f_download_file(self): + if self.dataset_file is None: + self.test_dataset_d_list_files() + try: + api.dataset_download_file(self.dataset, self.dataset_file.name, 'tmp') + self.assertTrue(os.path.exists(f'tmp/{self.dataset_file.name}')) + except ApiException as e: + self.fail(f"dataset_download_file failed: {e}") + finally: + if os.path.exists(f'tmp/{self.dataset_file.name}'): + os.remove(f'tmp/{self.dataset_file.name}') + if os.path.exists('tmp'): + os.rmdir('tmp') + + def test_dataset_g_download_files(self): + if self.dataset == '': + self.test_dataset_a_list() + ds = ['a', 'b'] + try: + api.dataset_download_files(self.dataset) + ds = self.dataset.split('/') + self.assertTrue(os.path.exists(f'{ds[1]}.zip')) + except ApiException as e: + self.fail(f"dataset_download_files failed: {e}") + finally: + if os.path.exists(f'{ds[1]}.zip'): + os.remove(f'{ds[1]}.zip') + + def test_dataset_h_initialize(self): + try: + api.dataset_initialize('dataset') + self.assertTrue( + os.path.exists( + os.path.join(dataset_directory, api.DATASET_METADATA_FILE))) + except ApiException as e: + self.fail(f"dataset_initialize failed: {e}") + + def test_dataset_i_create_new(self): + if not os.path.exists( + os.path.join(dataset_directory, api.DATASET_METADATA_FILE)): + self.test_dataset_h_initialize() + try: + update_dataset_metadata_file(self.meta_file, dataset_name, + self.version_number) + new_dataset = api.dataset_create_new(dataset_directory) + self.assertIsNotNone(new_dataset) + if new_dataset.error is not None: + print(new_dataset.error) # This is likely to happen, and that's OK. 
+    except ApiException as e:
+      self.fail(f"dataset_create_new failed: {e}")
+
+  def test_dataset_j_create_version(self):
+    if not os.path.exists(
+        os.path.join(dataset_directory, api.DATASET_METADATA_FILE)):
+      self.test_dataset_i_create_new()
+    try:
+      new_version = api.dataset_create_version(dataset_directory, "Notes")
+      self.assertIsNotNone(new_version)
+      self.assertTrue(new_version.error == '')
+      self.assertFalse(new_version.ref == '')
+    except ApiException as e:
+      self.fail(f"dataset_create_version failed: {e}")
+
+  # Models
+
+  def test_model_a_list(self):
+    try:
+      ms = api.model_list()
+      self.assertIsInstance(ms, list)
+      self.assertGreater(len(ms), 0)
+    except ApiException as e:
+      self.fail(f"models_list failed: {e}")
+
+  def test_model_b_initialize(self):
+    try:
+      self.model_metadata_file = api.model_initialize(model_directory)
+      self.assertTrue(os.path.exists(self.model_metadata_file))
+      self.model_meta_data = update_model_metadata(self.model_metadata_file,
+                                                   test_user, model_title,
+                                                   model_title)
+      self.model_instance = f'{test_user}/{self.model_meta_data["slug"]}/{framework_name}/{instance_name}'
+    except ApiException as e:
+      self.fail(f"model_initialize failed: {e}")
+
+  def test_model_c_create_new(self):
+    if self.model_metadata_file == '':
+      self.test_model_b_initialize()
+    try:
+      model = api.model_create_new(model_directory)
+      if model.hasError:
+        self.fail(model.error)
+      else:
+        self.assertIsNotNone(model.ref)
+        self.assertGreater(len(model.ref), 0)
+    except ApiException as e:
+      self.fail(f"model_create_new failed: {e}")
+
+  def test_model_d_get(self):
+    try:
+      model_data = api.model_get(f'{test_user}/{model_title}')
+      self.assertIsNotNone(model_data['ref'])
+      self.assertGreater(len(model_data['ref']), 0)
+      self.assertEqual(model_data['title'], model_title)
+    except ApiException as e:
+      self.fail(f"model_get failed: {e}")
+
+  def test_model_e_update(self):
+    try:
+      update_response = api.model_update(model_directory)
+      self.assertIsNotNone(update_response.ref)
+      self.assertGreater(len(update_response.ref), 0)
+    except ApiException as e:
+      self.fail(f"model_update failed: {e}")
+
+  # Model instances
+
+  def test_model_instance_a_initialize(self):
+    try:
+      self.instance_metadata_file = api.model_instance_initialize(
+          model_inst_directory)
+      self.assertTrue(os.path.exists(self.instance_metadata_file))
+    except ApiException as e:
+      self.fail(f"model_instance_initialize failed: {e}")
+
+  def test_model_instance_b_create(self):
+    if self.model_meta_data is None:
+      self.test_model_b_initialize()
+    if self.instance_metadata_file == '':
+      self.test_model_instance_a_initialize()
+    try:
+      update_model_instance_metadata(self.instance_metadata_file, test_user,
+                                     self.model_meta_data['slug'],
+                                     instance_name, framework_name)
+      inst_create_resp = api.model_instance_create(model_inst_directory)
+      self.assertIsNotNone(inst_create_resp.ref)
+      self.assertGreater(len(inst_create_resp.ref), 0)
+    except ApiException as e:
+      self.fail(f"model_instance_create failed: {e}")
+
+  def test_model_instance_b_wait_after_create(self):
+    # When running all tests sequentially, give the new model some time to stabilize.
+    time.sleep(10)  # TODO: Find a better way to detect model stability.
+ + def test_model_instance_c_get(self): + if self.model_instance == '': + self.test_model_b_initialize() + try: + inst_get_resp = api.model_instance_get(self.model_instance) + self.assertIsNotNone(inst_get_resp['url']) + self.assertGreater(len(inst_get_resp['url']), 0) + except ApiException as e: + self.fail(f"model_instance_get failed: {e}") + + def test_model_instance_d_files(self): + if self.model_instance == '': + self.test_model_b_initialize() + try: + inst_files_resp = api.model_instance_files(self.model_instance) + self.assertIsInstance(inst_files_resp.files, list) + self.assertGreater(len(inst_files_resp.files), 0) + except ApiException as e: + self.fail(f"model_instance_files failed: {e}") + + def test_model_instance_e_update(self): + if self.model_instance == '': + self.test_model_b_initialize() + try: + inst_update_resp = api.model_instance_update(model_inst_directory) + self.assertIsNotNone(inst_update_resp) + self.assertIsNotNone(inst_update_resp.ref) + self.assertGreater(len(inst_update_resp.ref), 0) + except ApiException as e: + self.fail(f"model_instance_update failed: {e}") + + # Model instance versions + + def test_model_instance_version_a_create(self): + if self.model_instance == '': + self.test_model_b_initialize() + try: + version_metadata_resp = api.model_instance_version_create( + self.model_instance, model_inst_vers_directory) + self.assertIsNotNone(version_metadata_resp.ref) + except ApiException as e: + self.fail(f"model_instance_version_create failed: {e}") + + def test_model_instance_version_b_files(self): + if self.model_instance == '': + self.test_model_b_initialize() + try: + r = api.model_instance_version_files(f'{self.model_instance}/1') + self.assertIsInstance(r.files, list) + self.assertGreater(len(r.files), 0) + except ApiException as e: + self.fail(f"model_instance_version_files failed: {e}") + + def test_model_instance_version_c_download(self): + if self.model_instance == '': + self.test_model_b_initialize() + version_file = '' + try: + version_file = api.model_instance_version_download( + f'{self.model_instance}/1', 'tmp') + self.assertTrue(os.path.exists(version_file)) + except KeyError: + pass # TODO Create a version that has content. + except ApiException as e: + self.fail(f"model_instance_version_download failed: {e}") + finally: + if os.path.exists(version_file): + os.remove(version_file) + if os.path.exists('tmp'): + os.rmdir('tmp') + + # Model deletion + + def test_model_instance_version_d_delete(self): + if self.model_instance == '': + self.test_model_b_initialize() + try: + version_delete_resp = api.model_instance_version_delete( + f'{self.model_instance}/1', True) + self.assertFalse(version_delete_resp.hasError) + except ApiException as e: + self.fail(f"model_instance_version_delete failed: {e}") + + def test_model_instance_x_delete(self): + if self.model_instance == '': + self.test_model_b_initialize() + try: + inst_update_resp = api.model_instance_delete(self.model_instance, True) + self.assertIsNotNone(inst_update_resp) + except ApiException as e: + self.fail(f"model_instance_delete failed: {e}") + + def test_model_z_delete(self): + try: + delete_response = api.model_delete(f'{test_user}/{model_title}', True) + if delete_response.hasError: + self.fail(delete_response.error) + else: + pass + except ApiException as e: + self.fail(f"model_delete failed: {e}") if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()