Skip to content

Commit

Permalink
Convert datasets to use kagglesdk (#636)
Browse files Browse the repository at this point in the history
I have some `TODO`s that won't be done until Swagger is removed.

Ignore everything in `kagglesdk`.

The unit tests were not reformatted earlier.

```bash
$ yapf --version
yapf 0.40.2
```
The version is the same as specified in #634, so I don't know why some
indentation changed.
  • Loading branch information
stevemessick authored Sep 25, 2024
1 parent ded7a52 commit 8e32390
Show file tree
Hide file tree
Showing 12 changed files with 1,649 additions and 1,313 deletions.
330 changes: 194 additions & 136 deletions kaggle/api/kaggle_api_extended.py

Large diffs are not rendered by default.

54 changes: 32 additions & 22 deletions kaggle/models/kaggle_models_extended.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
#!/usr/bin/python
#
# Copyright 2024 Kaggle Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/python
#
# Copyright 2024 Kaggle Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/python
#
# Copyright 2019 Kaggle Inc
Expand Down Expand Up @@ -131,9 +131,14 @@ def __repr__(self):
class File(object):

def __init__(self, init_dict):
parsed_dict = {k: parse(v) for k, v in init_dict.items()}
self.__dict__.update(parsed_dict)
self.size = File.get_size(self.totalBytes)
try: # TODO Remove try-block
parsed_dict = {k: parse(v) for k, v in init_dict.items()}
self.__dict__.update(parsed_dict)
self.size = File.get_size(self.totalBytes)
except AttributeError:
self.name = init_dict.name
self.creation_date = init_dict.creation_date
self.size = File.get_size(init_dict.total_bytes)

def __repr__(self):
return self.name
Expand Down Expand Up @@ -181,13 +186,18 @@ def __repr__(self):
class ListFilesResult(object):

def __init__(self, init_dict):
self.error_message = init_dict['errorMessage']
files = init_dict['datasetFiles']
try: # TODO Remove try-block
self.error_message = init_dict['errorMessage']
files = init_dict['datasetFiles']
token = init_dict['nextPageToken']
except TypeError:
self.error_message = init_dict.error_message
files = init_dict.dataset_files
token = init_dict.next_page_token
if files:
self.files = [File(f) for f in files]
else:
self.files = {}
token = init_dict['nextPageToken']
if token:
self.nextPageToken = token
else:
Expand Down
16 changes: 16 additions & 0 deletions kagglesdk/datasets/types/dataset_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return '*'

class ApiCreateDatasetResponse(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -310,6 +314,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return 'body'

class ApiCreateDatasetVersionRequest(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -373,6 +381,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return 'body'

class ApiCreateDatasetVersionRequestBody(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -2080,6 +2092,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return 'settings'

class ApiUpdateDatasetMetadataResponse(KaggleObject):
r"""
Attributes:
Expand Down
23 changes: 23 additions & 0 deletions kagglesdk/kaggle_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def _get_apikey_creds():
api_key = api_key_data['key']
return username, api_key

def clean_data(data):
if isinstance(data, dict):
return {k: clean_data(v) for k, v in data.items() if v is not None}
if isinstance(data, list):
return [clean_data(v) for v in data if v is not None]
if data is True:
return 'true'
if data is False:
return 'false'
return data

class KaggleHttpClient(object):
_xsrf_cookie_name = 'XSRF-TOKEN'
Expand Down Expand Up @@ -75,6 +85,19 @@ def _prepare_request(self, service_name: str, request_name: str, request: Kaggle
'Accept': 'application/json',
'Content-Type': 'text/plain',
})
elif method == 'POST':
self._session.headers.update({
'Accept': 'application/json',
'Content-Type': 'application/json',
})
if isinstance(data, dict):
fields = request.body_fields()
if fields is not None:
if fields != '*':
data = data[fields]
data = clean_data(data)
data = data.__str__().replace("'", '"')
# TODO Remove quotes from numbers.
http_request = requests.Request(
method=method,
url=request_url,
Expand Down
6 changes: 5 additions & 1 deletion kagglesdk/kaggle_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ class KaggleObject(object):
def endpoint(self):
raise 'Error: endpoint must be defined by the request object'

@staticmethod
def body_fields():
return None

@classmethod
def prepare_from(cls, http_response):
return cls.from_json(http_response.text)
Expand All @@ -229,7 +233,7 @@ def to_dict(self, ignore_defaults=True):
@staticmethod
def to_field_map(self, ignore_defaults=True):
kv_pairs = [(field.field_name, field.get_as_dict_item(self, ignore_defaults)) for field in self._fields]
return {k: v for (k, v) in kv_pairs if not ignore_defaults or v is not None}
return {k: str(v) for (k, v) in kv_pairs if not ignore_defaults or v is not None}

@classmethod
def from_dict(cls, json_dict):
Expand Down
4 changes: 4 additions & 0 deletions kagglesdk/kernels/types/kernels_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return '*'

class ApiSaveKernelResponse(KaggleObject):
r"""
Attributes:
Expand Down
58 changes: 57 additions & 1 deletion kagglesdk/models/types/model_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kagglesdk.datasets.types.dataset_api_service import ApiCategory, ApiDatasetNewFile, ApiUploadDirectoryInfo
from kagglesdk.kaggle_object import *
from kagglesdk.models.types.model_enums import ListModelsOrderBy, ModelFramework, ModelInstanceType
from kagglesdk.models.types.model_types import BaseModelInstanceInformation
from kagglesdk.models.types.model_types import BaseModelInstanceInformation, ModelLink
from typing import Optional, List

class ApiCreateModelInstanceRequest(KaggleObject):
Expand Down Expand Up @@ -69,6 +69,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return 'body'

class ApiCreateModelInstanceRequestBody(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -357,6 +361,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return 'body'

class ApiCreateModelInstanceVersionRequestBody(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -553,6 +561,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return '*'

class ApiCreateModelResponse(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -1280,6 +1292,8 @@ class ApiListModelsRequest(KaggleObject):
Page size.
page_token (str)
Page token used for pagination.
only_vertex_models (bool)
Only list models that have Vertex URLs
"""

def __init__(self):
Expand All @@ -1288,6 +1302,7 @@ def __init__(self):
self._owner = None
self._page_size = None
self._page_token = None
self._only_vertex_models = None
self._freeze()

@property
Expand Down Expand Up @@ -1363,6 +1378,20 @@ def page_token(self, page_token: str):
raise TypeError('page_token must be of type str')
self._page_token = page_token

@property
def only_vertex_models(self) -> bool:
"""Only list models that have Vertex URLs"""
return self._only_vertex_models or False

@only_vertex_models.setter
def only_vertex_models(self, only_vertex_models: bool):
if only_vertex_models is None:
del self.only_vertex_models
return
if not isinstance(only_vertex_models, bool):
raise TypeError('only_vertex_models must be of type bool')
self._only_vertex_models = only_vertex_models


def endpoint(self):
path = '/api/v1/models/list'
Expand Down Expand Up @@ -1441,6 +1470,7 @@ class ApiModel(KaggleObject):
publish_time (datetime)
provenance_sources (str)
url (str)
model_version_links (ModelLink)
"""

def __init__(self):
Expand All @@ -1457,6 +1487,7 @@ def __init__(self):
self._publish_time = None
self._provenance_sources = ""
self._url = ""
self._model_version_links = []
self._freeze()

@property
Expand Down Expand Up @@ -1633,6 +1664,21 @@ def url(self, url: str):
raise TypeError('url must be of type str')
self._url = url

@property
def model_version_links(self) -> Optional[List[Optional['ModelLink']]]:
return self._model_version_links

@model_version_links.setter
def model_version_links(self, model_version_links: Optional[List[Optional['ModelLink']]]):
if model_version_links is None:
del self.model_version_links
return
if not isinstance(model_version_links, list):
raise TypeError('model_version_links must be of type list')
if not all([isinstance(t, ModelLink) for t in model_version_links]):
raise TypeError('model_version_links must contain only items of type ModelLink')
self._model_version_links = model_version_links


class ApiModelFile(KaggleObject):
r"""
Expand Down Expand Up @@ -2139,6 +2185,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return '*'

class ApiUpdateModelRequest(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -2292,6 +2342,10 @@ def endpoint(self):
def method():
return 'POST'

@staticmethod
def body_fields():
return '*'

class ApiUpdateModelResponse(KaggleObject):
r"""
Attributes:
Expand Down Expand Up @@ -2587,6 +2641,7 @@ def create_url(self, create_url: str):
FieldMetadata("owner", "owner", "_owner", str, None, PredefinedSerializer(), optional=True),
FieldMetadata("pageSize", "page_size", "_page_size", int, None, PredefinedSerializer(), optional=True),
FieldMetadata("pageToken", "page_token", "_page_token", str, None, PredefinedSerializer(), optional=True),
FieldMetadata("onlyVertexModels", "only_vertex_models", "_only_vertex_models", bool, None, PredefinedSerializer(), optional=True),
]

ApiListModelsResponse._fields = [
Expand All @@ -2609,6 +2664,7 @@ def create_url(self, create_url: str):
FieldMetadata("publishTime", "publish_time", "_publish_time", datetime, None, DateTimeSerializer()),
FieldMetadata("provenanceSources", "provenance_sources", "_provenance_sources", str, "", PredefinedSerializer()),
FieldMetadata("url", "url", "_url", str, "", PredefinedSerializer()),
FieldMetadata("modelVersionLinks", "model_version_links", "_model_version_links", ModelLink, [], ListSerializer(KaggleObjectSerializer())),
]

ApiModelFile._fields = [
Expand Down
5 changes: 5 additions & 0 deletions kagglesdk/models/types/model_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,8 @@ class ModelInstanceType(enum.Enum):
MODEL_INSTANCE_TYPE_KAGGLE_VARIANT = 2
MODEL_INSTANCE_TYPE_EXTERNAL_VARIANT = 3

class ModelVersionLinkType(enum.Enum):
MODEL_VERSION_LINK_TYPE_UNSPECIFIED = 0
MODEL_VERSION_LINK_TYPE_VERTEX_OPEN = 1
MODEL_VERSION_LINK_TYPE_VERTEX_DEPLOY = 2

Loading

0 comments on commit 8e32390

Please sign in to comment.