Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
stevemessick committed Jan 13, 2025
1 parent 1259588 commit 502b82f
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 123 deletions.
197 changes: 131 additions & 66 deletions kaggle/api/kaggle_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
#!/usr/bin/python
#
# Copyright 2025 Kaggle Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/python
#
# Copyright 2023 Kaggle Inc
Expand Down Expand Up @@ -90,7 +106,8 @@ def competitions_list(self,
If the method is called asynchronously,
returns the request thread.
"""
return self.api_client.competitions_list(group, category, sort_by, page, search)
return self.api_client.competitions_list(group, category, sort_by, page,
search)

def competitions_submissions_list(self, id, page=0): # noqa: E501
"""List competition submissions # noqa: E501
Expand All @@ -101,7 +118,8 @@ def competitions_submissions_list(self, id, page=0): # noqa: E501
"""
return self.api_client.competition_submissions(id, page)

def competitions_submissions_submit(self, blob_file_tokens, submission_description, id): # noqa: E501
def competitions_submissions_submit(self, blob_file_tokens,
submission_description, id): # noqa: E501
"""Submit to competition # noqa: E501
:param str blob_file_tokens: Token identifying location of uploaded submission file (required)
Expand All @@ -113,7 +131,8 @@ def competitions_submissions_submit(self, blob_file_tokens, submission_descripti
print("use kaggle_api_extended.KaggleApi.competition_submit() instead")
raise NotImplementedError()

def competitions_submissions_upload(self, file, guid, content_length, last_modified_date_utc): # noqa: E501
def competitions_submissions_upload(self, file, guid, content_length,
last_modified_date_utc): # noqa: E501
"""Upload competition submission file # noqa: E501
:param file file: Competition submission file (required)
Expand All @@ -123,10 +142,13 @@ def competitions_submissions_upload(self, file, guid, content_length, last_modif
:return: Result
"""
print("competitions_submissions_upload() not implemented")
print("use kaggle_api_extended.KaggleApi.competitions_submissions_submit() instead")
print(
"use kaggle_api_extended.KaggleApi.competitions_submissions_submit() instead"
)
raise NotImplementedError()

def competitions_submissions_url(self, id, content_length, last_modified_date_utc): # noqa: E501
def competitions_submissions_url(self, id, content_length,
last_modified_date_utc): # noqa: E501
"""Generate competition submission URL # noqa: E501
:param str id: Competition name, as it appears in the competition's URL (required)
Expand Down Expand Up @@ -159,7 +181,8 @@ def datasets_create_new(self, request): # noqa: E501
print("use kaggle_api_extended.KaggleApi.dataset_create_new() instead")
raise NotImplementedError()

def datasets_create_version_by_id(self, id, dataset_new_version_request, **kwargs): # noqa: E501
def datasets_create_version_by_id(self, id, dataset_new_version_request,
**kwargs): # noqa: E501
"""Create a new dataset version by id # noqa: E501
:param int id: Dataset ID (required)
Expand All @@ -172,7 +195,10 @@ def datasets_create_version_by_id(self, id, dataset_new_version_request, **kwarg
print("use kaggle_api_extended.KaggleApi.dataset_create_version() instead")
raise NotImplementedError()

def datasets_download(self, owner_slug, dataset_slug, dataset_version_number=None): # noqa: E501
def datasets_download(self,
owner_slug,
dataset_slug,
dataset_version_number=None): # noqa: E501
"""Download dataset file # noqa: E501
:param str owner_slug: Dataset owner (required)
Expand All @@ -185,7 +211,11 @@ def datasets_download(self, owner_slug, dataset_slug, dataset_version_number=Non
dataset += f'/{dataset_version_number}'
return self.api_client.dataset_download_files(dataset)

def datasets_download_file(self, owner_slug, dataset_slug, file_name, dataset_version_number=None): # noqa: E501
def datasets_download_file(self,
owner_slug,
dataset_slug,
file_name,
dataset_version_number=None): # noqa: E501
"""Download dataset file # noqa: E501
:param str owner_slug: Dataset owner (required)
Expand Down Expand Up @@ -230,17 +260,18 @@ def datasets_list(self,
If the method is called asynchronously,
returns the request thread.
"""
return self.api_client.dataset_list(sort_by=sort_by,
size=size,
file_type=file_type,
license_name=license_name,
tag_ids=tag_ids,
search=search,
user=user,
mine=mine,
page=page,
max_size=max_size,
min_size=min_size)
return self.api_client.dataset_list(
sort_by=sort_by,
size=size,
file_type=file_type,
license_name=license_name,
tag_ids=tag_ids,
search=search,
user=user,
mine=mine,
page=page,
max_size=max_size,
min_size=min_size)

def datasets_status(self, owner_slug, dataset_slug, **kwargs): # noqa: E501
"""Get dataset creation status # noqa: E501
Expand All @@ -262,7 +293,8 @@ def delete_model(self, owner_slug, model_slug): # noqa: E501
"""
return self.api_client.model_delete(f'{owner_slug}/{model_slug})')

def delete_model_instance(self, owner_slug, model_slug, framework, instance_slug): # noqa: E501
def delete_model_instance(self, owner_slug, model_slug, framework,
instance_slug): # noqa: E501
"""Delete a model instance # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -271,9 +303,12 @@ def delete_model_instance(self, owner_slug, model_slug, framework, instance_slug
:param str instance_slug: Model instance slug (required)
:return: Result
"""
return self.api_client.model_instance_delete(f'{owner_slug}/{model_slug}/{framework}/{instance_slug}', yes=True)
return self.api_client.model_instance_delete(
f'{owner_slug}/{model_slug}/{framework}/{instance_slug}', yes=True)

def delete_model_instance_version(self, owner_slug, model_slug, framework, instance_slug, version_number): # noqa: E501
def delete_model_instance_version(self, owner_slug, model_slug, framework,
instance_slug,
version_number): # noqa: E501
"""Delete a model instance version # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -283,7 +318,9 @@ def delete_model_instance_version(self, owner_slug, model_slug, framework, insta
:param str version_number: Model instance version number (required)
:return: Result
"""
return self.api_client.model_instance_version_delete(f'{owner_slug}/{model_slug}/{framework}/{instance_slug}/{version_number}', yes=True)
return self.api_client.model_instance_version_delete(
f'{owner_slug}/{model_slug}/{framework}/{instance_slug}/{version_number}',
yes=True)

def get_model(self, owner_slug, model_slug): # noqa: E501
"""Get a model # noqa: E501
Expand All @@ -294,7 +331,8 @@ def get_model(self, owner_slug, model_slug): # noqa: E501
"""
return self.api_client.model_get(f'{owner_slug}/{model_slug}')

def get_model_instance(self, owner_slug, model_slug, framework, instance_slug): # noqa: E501
def get_model_instance(self, owner_slug, model_slug, framework,
instance_slug): # noqa: E501
"""Get a model instance # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -303,7 +341,8 @@ def get_model_instance(self, owner_slug, model_slug, framework, instance_slug):
:param str instance_slug: Model instance slug (required)
:return: Result
"""
return self.api_client.model_instance_get(f'{owner_slug}/{model_slug}/{framework}/{instance_slug}')
return self.api_client.model_instance_get(
f'{owner_slug}/{model_slug}/{framework}/{instance_slug}')

def kernel_output(self, user_name, kernel_slug): # noqa: E501
"""Download the latest output from a kernel # noqa: E501
Expand All @@ -312,7 +351,8 @@ def kernel_output(self, user_name, kernel_slug): # noqa: E501
:param str kernel_slug: Kernel name (required)
:return: Result
"""
return self.api_client.kernels_output(f'{user_name}/{kernel_slug}', path=None, force=True)
return self.api_client.kernels_output(
f'{user_name}/{kernel_slug}', path=None, force=True)

def kernel_pull(self, user_name, kernel_slug): # noqa: E501
"""Pull the latest code from a kernel # noqa: E501
Expand All @@ -331,7 +371,7 @@ def kernel_push(self, kernel_push_request): # noqa: E501
"""
with tempfile.TemporaryDirectory() as tmpdir:
meta_file = os.path.join(tmpdir, 'kernel-metadata.json')
(fd,code_file) = tempfile.mkstemp('code','py', tmpdir, text=True)
(fd, code_file) = tempfile.mkstemp('code', 'py', tmpdir, text=True)
fd.write(json.dumps(kernel_push_request.code))
os.close(fd)
with open(meta_file, 'w') as f:
Expand All @@ -342,7 +382,9 @@ def kernel_push(self, kernel_push_request): # noqa: E501
params['kernel_sources'] = params.get('kernel_data_sources')
params['model_sources'] = params.get('model_data_sources')
params['title'] = params.get('new_title')
entries_to_remove = ('competition_data_sources', 'dataset_data_sources', 'kernel_data_sources', 'model_data_sources', 'new_title')
entries_to_remove = ('competition_data_sources', 'dataset_data_sources',
'kernel_data_sources', 'model_data_sources',
'new_title')
for k in entries_to_remove:
params.pop(k, None)
f.write(json.dumps(params))
Expand Down Expand Up @@ -386,18 +428,19 @@ def kernels_list(self,
:param str parent_kernel: Display kernels that have forked the specified kernel
:return: Result
"""
return self.api_client.kernels_list(page=page,
page_size=page_size,
dataset=dataset,
competition=competition,
parent_kernel=parent_kernel,
search=search,
mine=group != 'everyone',
user=user,
language=language,
kernel_type=kernel_type,
output_type=output_type,
sort_by=sort_by)
return self.api_client.kernels_list(
page=page,
page_size=page_size,
dataset=dataset,
competition=competition,
parent_kernel=parent_kernel,
search=search,
mine=group != 'everyone',
user=user,
language=language,
kernel_type=kernel_type,
output_type=output_type,
sort_by=sort_by)

def metadata_get(self, owner_slug, dataset_slug): # noqa: E501
"""Get the metadata for a dataset # noqa: E501
Expand All @@ -406,9 +449,11 @@ def metadata_get(self, owner_slug, dataset_slug): # noqa: E501
:param str dataset_slug: Dataset name (required)
:return: Result
"""
return self.api_client.dataset_metadata(f'{owner_slug}/{dataset_slug}', None)
return self.api_client.dataset_metadata(f'{owner_slug}/{dataset_slug}',
None)

def metadata_post(self, owner_slug, dataset_slug, settings, request): # noqa: E501
def metadata_post(self, owner_slug, dataset_slug, settings,
request): # noqa: E501
"""Update the metadata for a dataset # noqa: E501
:param str owner_slug: Dataset owner (required)
Expand All @@ -423,9 +468,12 @@ def metadata_post(self, owner_slug, dataset_slug, settings, request): # noqa: E
params['isPrivate'] = params.get('is_private')
params.pop('is_private', None)
f.write(json.dumps(params))
return self.api_client.dataset_metadata_update(f'{owner_slug}/{dataset_slug}', meta_file)
return self.api_client.dataset_metadata_update(
f'{owner_slug}/{dataset_slug}', meta_file)

def model_instance_versions_download(self, owner_slug, model_slug, framework, instance_slug, version_number): # noqa: E501
def model_instance_versions_download(self, owner_slug, model_slug, framework,
instance_slug,
version_number): # noqa: E501
"""Download model instance version files # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -438,7 +486,8 @@ def model_instance_versions_download(self, owner_slug, model_slug, framework, in
v = f'{owner_slug}/{model_slug}/{framework}/{instance_slug}/{version_number}'
return self.api_client.model_instance_version_download(v)

def models_create_instance(self, owner_slug, model_slug, model_new_instance_request): # noqa: E501
def models_create_instance(self, owner_slug, model_slug,
model_new_instance_request): # noqa: E501
"""Create a new model instance # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -450,7 +499,9 @@ def models_create_instance(self, owner_slug, model_slug, model_new_instance_requ
print("use kaggle_api_extended.KaggleApi.model_instance_create() instead")
raise NotImplementedError()

def models_create_instance_version(self, owner_slug, model_slug, framework, instance_slug, model_instance_new_version_request): # noqa: E501
def models_create_instance_version(
self, owner_slug, model_slug, framework, instance_slug,
model_instance_new_version_request): # noqa: E501
"""Create a new model instance version # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -461,7 +512,9 @@ def models_create_instance_version(self, owner_slug, model_slug, framework, inst
:return: Result
"""
print("models_create_instance_version() not implemented")
print("use kaggle_api_extended.KaggleApi.model_instance_version_create() instead")
print(
"use kaggle_api_extended.KaggleApi.model_instance_version_create() instead"
)
raise NotImplementedError()

def models_create_new(self, model_new_request): # noqa: E501
Expand All @@ -478,7 +531,8 @@ def models_create_new(self, model_new_request): # noqa: E501
params['isPrivate'] = params.get('is_private')
params['publishTime'] = params.get('publish_time')
params['provenanceSources'] = params.get('provenance_sources')
entries_to_remove = ('owner_slug', 'is_private', 'publish_time', 'provenance_sources')
entries_to_remove = ('owner_slug', 'is_private', 'publish_time',
'provenance_sources')
for k in entries_to_remove:
params.pop(k, None)
f.write(json.dumps(params))
Expand All @@ -490,21 +544,27 @@ def models_list(self,
owner=None,
page_size=20,
page_token=None): # noqa: E501
"""Lists models # noqa: E501
:param str search: Search terms
:param str sort_by: Sort the results
:param str owner: Display models by a specific user or organization
:param int page_size: Page size
:param str page_token: Page token for pagination
:return: Result
"""
return self.api_client.model_list(sort_by=sort_by,
search=search,
owner=owner,
page_size=page_size,
page_token=page_token)

def update_model(self, owner_slug, model_slug, model_update_request,): # noqa: E501
"""Lists models # noqa: E501
:param str search: Search terms
:param str sort_by: Sort the results
:param str owner: Display models by a specific user or organization
:param int page_size: Page size
:param str page_token: Page token for pagination
:return: Result
"""
return self.api_client.model_list(
sort_by=sort_by,
search=search,
owner=owner,
page_size=page_size,
page_token=page_token)

def update_model(
self,
owner_slug,
model_slug,
model_update_request,
): # noqa: E501
"""Update a model # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -521,13 +581,16 @@ def update_model(self, owner_slug, model_slug, model_update_request,): # noqa:
params['publishTime'] = params.get('publish_time')
params['provenanceSources'] = params.get('provenance_sources')
params['updateMask'] = params.get('update_mask')
entries_to_remove = ('owner_slug', 'is_private', 'publish_time', 'provenance_sources', 'update_mask')
entries_to_remove = ('owner_slug', 'is_private', 'publish_time',
'provenance_sources', 'update_mask')
for k in entries_to_remove:
params.pop(k, None)
f.write(json.dumps(params))
return self.api_client.model_update(tmpdir)

def update_model_instance(self, owner_slug, model_slug, framework, instance_slug, model_instance_update_request): # noqa: E501
def update_model_instance(self, owner_slug, model_slug, framework,
instance_slug,
model_instance_update_request): # noqa: E501
"""Update a model # noqa: E501
:param str owner_slug: Model owner (required)
Expand All @@ -548,7 +611,9 @@ def update_model_instance(self, owner_slug, model_slug, framework, instance_slug
params['baseModelInstance'] = params.get('base_model_instance')
params['externalBaseModelUrl'] = params.get('external_base_model_url')
params['updateMask'] = params.get('update_mask')
entries_to_remove = ('license_name', 'fine_tunable', 'training_data', 'model_instance_type', 'base_model_instance', 'external_base_model_url', 'update_mask')
entries_to_remove = ('license_name', 'fine_tunable', 'training_data',
'model_instance_type', 'base_model_instance',
'external_base_model_url', 'update_mask')
for k in entries_to_remove:
params.pop(k, None)
f.write(json.dumps(params))
Expand Down
Loading

0 comments on commit 502b82f

Please sign in to comment.