From 0ddab2189501d4a3cbdd33435ea03fb4c8a6d0b7 Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Sun, 3 Mar 2024 18:53:10 +0100 Subject: [PATCH] export alias --- jaxonloader/__init__.py | 3 +++ jaxonloader/dataloader.py | 8 ++------ pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/jaxonloader/__init__.py b/jaxonloader/__init__.py index 71272b6..4d08a94 100644 --- a/jaxonloader/__init__.py +++ b/jaxonloader/__init__.py @@ -1,5 +1,8 @@ from jaxonloader._datasets import * # noqa from jaxonloader.dataloader import JaxonDataLoader, make # noqa import equinox as eqx +from jaxtyping import Array +from collections.abc import Callable Index = eqx.nn.State +JITJaxonDataLoader = Callable[[eqx.nn.State], tuple[Array, eqx.nn.State, bool]] diff --git a/jaxonloader/dataloader.py b/jaxonloader/dataloader.py index 32e8626..5b63390 100644 --- a/jaxonloader/dataloader.py +++ b/jaxonloader/dataloader.py @@ -1,10 +1,9 @@ -from collections.abc import Callable - import equinox as eqx import jax import jax.numpy as jnp from jaxtyping import Array, PRNGKeyArray +from jaxonloader import JITJaxonDataLoader from jaxonloader.dataset import JaxonDataset @@ -78,10 +77,7 @@ def make( drop_last: bool = False, key: PRNGKeyArray | None = None, jit: bool = True, -) -> ( - tuple[Callable[[eqx.nn.State], tuple[Array, eqx.nn.State, bool]], eqx.nn.State] - | tuple[JaxonDataLoader, eqx.nn.State] -): +) -> tuple[JITJaxonDataLoader, eqx.nn.State] | tuple[JaxonDataLoader, eqx.nn.State]: dataloader, index = eqx.nn.make_with_state(JaxonDataLoader)( dataset, batch_size, shuffle, drop_last, key=key ) diff --git a/pyproject.toml b/pyproject.toml index 6a05f29..5032dfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "jaxonloader" -version = "0.2.1" +version = "0.2.2" description = "A dataloader, but for JAX" readme = "README.md" requires-python ="~=3.10"