From 8148a14e725c0f09aae3eb74769b14d9ffca198b Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Tue, 9 Apr 2024 21:03:28 +0200 Subject: [PATCH] return paths --- jaxonloader/datasets/_datasets.py | 20 +++++--------------- jaxonloader/datasets/download.py | 26 ++++++++++++++------------ pyproject.toml | 2 +- 3 files changed, 20 insertions(+), 28 deletions(-) diff --git a/jaxonloader/datasets/_datasets.py b/jaxonloader/datasets/_datasets.py index 32ffac9..3f16693 100644 --- a/jaxonloader/datasets/_datasets.py +++ b/jaxonloader/datasets/_datasets.py @@ -1,4 +1,3 @@ -import pathlib import pickle from collections.abc import Callable from typing import Optional @@ -15,17 +14,12 @@ download_tinyshakespeare, download_titanic, ) -from jaxonloader.utils import ( - get_data_path, - JAXONLOADER_PATH, -) def get_mnist( *, target_path: Optional[str] = None ) -> tuple[JaxonDataset, JaxonDataset]: - download_mnist(target_path=target_path) - data_path = get_data_path("mnist", target_path) + data_path = download_mnist(target_path=target_path) train_df = pl.read_csv(data_path / "mnist_train.csv") test_df = pl.read_csv(data_path / "mnist_test.csv") @@ -38,8 +32,7 @@ def get_mnist( def get_cifar10(target_path: Optional[str] = None) -> tuple[JaxonDataset, JaxonDataset]: - download_cifar10(target_path=target_path) - data_path = pathlib.Path(JAXONLOADER_PATH) / "cifar10" + data_path = download_cifar10(target_path=target_path) n_batches = 5 train_data = [] train_labels = [] @@ -65,8 +58,7 @@ def get_cifar10(target_path: Optional[str] = None) -> tuple[JaxonDataset, JaxonD def get_cifar100( target_path: Optional[str] = None, ) -> tuple[JaxonDataset, JaxonDataset]: - download_cifar100(target_path=target_path) - data_path = get_data_path("cifar100", target_path) + data_path = download_cifar100(target_path=target_path) with open(data_path / "cifar-100-python/train", "rb") as f: train_data = pickle.load(f, encoding="bytes") @@ -128,8 +120,7 @@ def get_tiny_shakespeare( train_dataset, test_dataset, vocab_size, encoder, decoder = get_tiny_shakespeare() ``` """ - download_tinyshakespeare(target_path=target_path) - data_path = get_data_path("tinyshakespeare", target_path) + data_path = download_tinyshakespeare(target_path=target_path) def get_text(): with open(data_path / "input.txt", "r") as f: @@ -172,8 +163,7 @@ def decode(latent: NDArray) -> str: def get_titanic(target_path: Optional[str] = None) -> JaxonDataset: - download_titanic(target_path=target_path) - data_path = pathlib.Path(JAXONLOADER_PATH) / "titanic" + data_path = download_titanic(target_path=target_path) train_df = pl.read_csv(data_path / "train.csv") def _gender_to_int(df: pl.DataFrame) -> pl.DataFrame: diff --git a/jaxonloader/datasets/download.py b/jaxonloader/datasets/download.py index 19fa2d0..b5f67e1 100644 --- a/jaxonloader/datasets/download.py +++ b/jaxonloader/datasets/download.py @@ -1,35 +1,37 @@ +from pathlib import Path from typing import Optional from jaxonloader.utils import ( + get_data_path, jaxonloader_cache, ) @jaxonloader_cache(dataset_name="mnist") -def download_mnist(*, target_path: Optional[str]) -> None: - return +def download_mnist(*, target_path: Optional[str]) -> Path: + return get_data_path("mnist", target_path) @jaxonloader_cache(dataset_name="titanic") -def download_titanic(*, target_path: Optional[str]) -> None: - pass +def download_titanic(*, target_path: Optional[str]) -> Path: + return get_data_path("titanic", target_path) @jaxonloader_cache(dataset_name="hms") -def download_hms(*, target_path: Optional[str]) -> None: - pass +def download_hms(*, target_path: Optional[str]) -> Path: + return get_data_path("hms", target_path) @jaxonloader_cache(dataset_name="cifar10") -def download_cifar10(*, target_path: Optional[str]) -> None: - pass +def download_cifar10(*, target_path: Optional[str]) -> Path: + return get_data_path("cifar10", target_path) @jaxonloader_cache(dataset_name="cifar100") -def download_cifar100(*, target_path: Optional[str]) -> None: - pass +def download_cifar100(*, target_path: Optional[str]) -> Path: + return get_data_path("cifar100", target_path) @jaxonloader_cache(dataset_name="tinyshakespeare") -def download_tinyshakespeare(*, target_path: Optional[str]) -> None: - pass +def download_tinyshakespeare(*, target_path: Optional[str]) -> Path: + return get_data_path("tinyshakespeare", target_path) diff --git a/pyproject.toml b/pyproject.toml index 67ac2d0..bd2791d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "jaxonloader" -version = "0.3.5" +version = "0.3.6" description = "A dataloader, but for JAX" readme = "README.md" requires-python ="~=3.10"