PyTorch-VeLO

VeLO optimizer usable from PyTorch.

The wrapping is very basic: we try to let JAX do everything so that we do not have to re-implement the optimizer in PyTorch.

The environment variable XLA_PYTHON_CLIENT_PREALLOCATE=false is set automatically so that JAX does not preallocate most of the GPU memory and leaves room for PyTorch.
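The README does not show how the flag is set. A minimal sketch of one way to do it, assuming a hypothetical helper `disable_jax_preallocation` that runs before JAX creates its GPU backend (in practice, before `import jax`). It uses `setdefault` so that a value the user already exported wins:

```python
import os
from typing import MutableMapping, Optional


def disable_jax_preallocation(env: Optional[MutableMapping[str, str]] = None) -> None:
    """Ask JAX not to reserve most of the GPU memory up front.

    Hypothetical helper, not part of pytorch_velo. XLA reads this variable
    when the GPU backend is created, so call it before importing jax.
    """
    if env is None:
        env = os.environ
    # Keep any value the user has already set explicitly.
    env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
```

Passing a plain dict instead of `os.environ` makes the helper easy to test without touching the real process environment.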

Installation

python3 -m pip install git+https://github.com/janEbert/PyTorch-VeLO.git

Usage

from pytorch_velo import VeLO

# [...]

train_steps = epochs * len(dataset)  # Assuming `dataset` is already batched.
opt = VeLO(params, num_training_steps=train_steps, weight_decay=0.0)

# Use like any other PyTorch optimizer.
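The snippet above assumes `dataset` yields batches, so `len(dataset)` is already the number of optimizer steps per epoch (the same holds for `len(loader)` of a `torch.utils.data.DataLoader` over a map-style dataset). If you only know the number of examples, a sketch of the computation, assuming one optimizer step per batch and a hypothetical helper name:

```python
def num_training_steps(epochs: int, num_samples: int, batch_size: int,
                       drop_last: bool = False) -> int:
    """Total optimizer steps for `epochs` passes over `num_samples` examples.

    Hypothetical helper: one step per batch. With drop_last=True the final,
    incomplete batch of each epoch is skipped (as DataLoader(drop_last=True) does).
    """
    if drop_last:
        steps_per_epoch = num_samples // batch_size
    else:
        # Ceiling division in integer arithmetic: a partial last batch still counts.
        steps_per_epoch = (num_samples + batch_size - 1) // batch_size
    return epochs * steps_per_epoch
```

For example, `num_training_steps(3, 1000, 32)` counts 32 batches per epoch (the last one partial), i.e. 96 steps, which would then be passed as `num_training_steps=` to `VeLO`.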

Caution

Alpha-level software: not well tested and probably very slow.

Only parameters with trivial (default row-major, contiguous) strides are supported; support for other strides would have to be implemented on the JAX side (see jax-ml/jax#8082).
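To illustrate what "trivial strides" means, here is a sketch that computes the row-major element strides for a shape and compares them with a tensor's actual strides. The function names are hypothetical; in PyTorch, `Tensor.is_contiguous()` performs essentially this check natively and `Tensor.contiguous()` returns a copy with trivial strides. Like PyTorch, the sketch ignores the stride of size-1 dimensions, since it never changes which element is addressed:

```python
from typing import Sequence, Tuple


def row_major_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Element strides of a C-contiguous (row-major) array of this shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= max(size, 1)  # treat zero-sized dims like size 1, as PyTorch does
    return tuple(reversed(strides))


def has_trivial_strides(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True if `strides` match the row-major layout (size-1 dims ignored)."""
    if len(shape) != len(strides):
        return False
    expected = row_major_strides(shape)
    return all(size == 1 or s == e
               for size, s, e in zip(shape, strides, expected))
```

A freshly allocated `(2, 3, 4)` tensor has strides `(12, 4, 1)` and passes; its transpose-like view of shape `(3, 2)` with strides `(1, 3)` (the `.t()` of a `(2, 3)` tensor) does not.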

With jax==0.3.21 (installed automatically as a dependency of learned_optimization at the time of writing), the jax.default_device context manager does not work. To force JAX to run the optimizer on the CPU, set the environment variable JAX_PLATFORMS=cpu instead.
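Because JAX reads its platform configuration when it is first imported, the variable has to be set before anything imports JAX, and importing pytorch_velo presumably does. A minimal sketch with a hypothetical helper that fails loudly if it is called too late:

```python
import os
import sys
from typing import MutableMapping, Optional


def force_jax_cpu(env: Optional[MutableMapping[str, str]] = None) -> None:
    """Restrict JAX to its CPU backend (hypothetical helper).

    Call it at the very top of the training script, before importing jax
    or pytorch_velo.
    """
    if "jax" in sys.modules:
        # Setting the variable now would most likely be ignored.
        raise RuntimeError(
            "JAX is already imported; set JAX_PLATFORMS=cpu before "
            "importing jax or pytorch_velo."
        )
    if env is None:
        env = os.environ
    env["JAX_PLATFORMS"] = "cpu"
```

Setting the variable in the shell when launching the script achieves the same without any code change.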