Skip to content

Commit

Permalink
return paths
Browse files: browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Apr 9, 2024
1 parent e15a80b commit 8148a14
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 28 deletions.
20 changes: 5 additions & 15 deletions jaxonloader/datasets/_datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pathlib
import pickle
from collections.abc import Callable
from typing import Optional
@@ -15,17 +14,12 @@
download_tinyshakespeare,
download_titanic,
)
from jaxonloader.utils import (
get_data_path,
JAXONLOADER_PATH,
)


def get_mnist(
*, target_path: Optional[str] = None
) -> tuple[JaxonDataset, JaxonDataset]:
download_mnist(target_path=target_path)
data_path = get_data_path("mnist", target_path)
data_path = download_mnist(target_path=target_path)
train_df = pl.read_csv(data_path / "mnist_train.csv")
test_df = pl.read_csv(data_path / "mnist_test.csv")

@@ -38,8 +32,7 @@ def get_mnist(


def get_cifar10(target_path: Optional[str] = None) -> tuple[JaxonDataset, JaxonDataset]:
download_cifar10(target_path=target_path)
data_path = pathlib.Path(JAXONLOADER_PATH) / "cifar10"
data_path = download_cifar10(target_path=target_path)
n_batches = 5
train_data = []
train_labels = []
@@ -65,8 +58,7 @@ def get_cifar10(target_path: Optional[str] = None) -> tuple[JaxonDataset, JaxonD
def get_cifar100(
target_path: Optional[str] = None,
) -> tuple[JaxonDataset, JaxonDataset]:
download_cifar100(target_path=target_path)
data_path = get_data_path("cifar100", target_path)
data_path = download_cifar100(target_path=target_path)

with open(data_path / "cifar-100-python/train", "rb") as f:
train_data = pickle.load(f, encoding="bytes")
@@ -128,8 +120,7 @@ def get_tiny_shakespeare(
train_dataset, test_dataset, vocab_size, encoder, decoder = get_tiny_shakespeare()
```
"""
download_tinyshakespeare(target_path=target_path)
data_path = get_data_path("tinyshakespeare", target_path)
data_path = download_tinyshakespeare(target_path=target_path)

def get_text():
with open(data_path / "input.txt", "r") as f:
@@ -172,8 +163,7 @@ def decode(latent: NDArray) -> str:


def get_titanic(target_path: Optional[str] = None) -> JaxonDataset:
download_titanic(target_path=target_path)
data_path = pathlib.Path(JAXONLOADER_PATH) / "titanic"
data_path = download_titanic(target_path=target_path)
train_df = pl.read_csv(data_path / "train.csv")

def _gender_to_int(df: pl.DataFrame) -> pl.DataFrame:
Expand Down
26 changes: 14 additions & 12 deletions jaxonloader/datasets/download.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
from pathlib import Path
from typing import Optional

from jaxonloader.utils import (
get_data_path,
jaxonloader_cache,
)


@jaxonloader_cache(dataset_name="mnist")
def download_mnist(*, target_path: Optional[str]) -> None:
return
def download_mnist(*, target_path: Optional[str]) -> Path:
return get_data_path("mnist", target_path)


@jaxonloader_cache(dataset_name="titanic")
def download_titanic(*, target_path: Optional[str]) -> None:
pass
def download_titanic(*, target_path: Optional[str]) -> Path:
return get_data_path("titanic", target_path)


@jaxonloader_cache(dataset_name="hms")
def download_hms(*, target_path: Optional[str]) -> None:
pass
def download_hms(*, target_path: Optional[str]) -> Path:
return get_data_path("hms", target_path)


@jaxonloader_cache(dataset_name="cifar10")
def download_cifar10(*, target_path: Optional[str]) -> None:
pass
def download_cifar10(*, target_path: Optional[str]) -> Path:
return get_data_path("cifar10", target_path)


@jaxonloader_cache(dataset_name="cifar100")
def download_cifar100(*, target_path: Optional[str]) -> None:
pass
def download_cifar100(*, target_path: Optional[str]) -> Path:
return get_data_path("cifar100", target_path)


@jaxonloader_cache(dataset_name="tinyshakespeare")
def download_tinyshakespeare(*, target_path: Optional[str]) -> None:
pass
def download_tinyshakespeare(*, target_path: Optional[str]) -> Path:
return get_data_path("tinyshakespeare", target_path)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jaxonloader"
version = "0.3.5"
version = "0.3.6"
description = "A dataloader, but for JAX"
readme = "README.md"
requires-python ="~=3.10"
Expand Down

0 comments on commit 8148a14

Please sign in to comment.