
Built-in support for saving/loading model weights in safetensors format #186

Open
minghuaw opened this issue Jan 27, 2025 · 1 comment
minghuaw commented Jan 27, 2025

We already have support for converting between mlx_rs::Array and safetensors::tensor::TensorView, so supporting this wouldn't be hard. However, there might be an asymmetry in the saving/loading API due to the lack of a public API for creating a SafeTensors from an Array. More specifically, the only public API for constructing a SafeTensors is SafeTensors::deserialize(buf), where buf is &[u8].
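The shape of that constraint can be sketched with a minimal, std-only stand-in for the safetensors type (the name and the `deserialize` entry point match the real crate, but the internals here are simplified placeholders, not the real parsing logic):

```rust
use std::collections::HashMap;

// Simplified stand-in for safetensors::SafeTensors: like the real type, it
// borrows from the serialized buffer, and its only public constructor is
// `deserialize` -- there is no `SafeTensors::new(tensors)` to build one
// directly from in-memory arrays.
struct SafeTensors<'data> {
    tensors: HashMap<String, &'data [u8]>,
}

impl<'data> SafeTensors<'data> {
    // The only way in: parse an existing byte buffer.
    // (Real header parsing elided; this toy treats the whole buffer as one tensor.)
    fn deserialize(buf: &'data [u8]) -> Result<Self, String> {
        let mut tensors = HashMap::new();
        tensors.insert("weight".to_string(), buf);
        Ok(SafeTensors { tensors })
    }
}
```

Because the type borrows its data and has no builder, "creating" a SafeTensors from arrays would require serializing to bytes first, which is why saving falls out more naturally as a write-to-file operation.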

This could end up with an API that looks like the following:

```rust
fn load_safetensors(model: &mut impl ModuleParameters, safetensors: SafeTensors<'_>) -> Result<()> { }

fn save_safetensors(model: &impl ModuleParameters, path: impl AsRef<Path>) -> Result<()> { }
```

where we have an asymmetry: we can only save to a file, rather than to a SafeTensors object.

Alternatively, we could do something similar to candle-nn, where both loading and saving take a Path to a safetensors file.
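The symmetric, path-based alternative can be sketched end to end with std-only stand-ins. The `Array` and `ModuleParameters` types below are hypothetical simplifications of the mlx-rs ones, and the on-disk format is a toy line-based placeholder rather than real safetensors serialization; the point is only the API shape, where both directions take `impl AsRef<Path>`:

```rust
use std::collections::HashMap;
use std::io::Write;
use std::path::Path;

// Hypothetical stand-ins for the mlx-rs types, for illustration only.
#[derive(Clone, Debug, PartialEq)]
struct Array(Vec<f32>);

trait ModuleParameters {
    fn parameters(&self) -> HashMap<String, Array>;
    fn update(&mut self, params: HashMap<String, Array>);
}

// Save all parameters to `path`. Toy format: one `name v1,v2,...` line per tensor.
fn save_safetensors(model: &impl ModuleParameters, path: impl AsRef<Path>) -> std::io::Result<()> {
    let mut f = std::fs::File::create(path)?;
    for (name, Array(data)) in model.parameters() {
        let values: Vec<String> = data.iter().map(f32::to_string).collect();
        writeln!(f, "{} {}", name, values.join(","))?;
    }
    Ok(())
}

// Load parameters from `path` and write them back into the model.
fn load_safetensors(model: &mut impl ModuleParameters, path: impl AsRef<Path>) -> std::io::Result<()> {
    let mut params = HashMap::new();
    for line in std::fs::read_to_string(path)?.lines() {
        if let Some((name, values)) = line.split_once(' ') {
            let data = values.split(',').filter_map(|v| v.parse().ok()).collect();
            params.insert(name.to_string(), Array(data));
        }
    }
    model.update(params);
    Ok(())
}
```

With both functions keyed on a path, save followed by load round-trips a model without the caller ever touching a SafeTensors value, which sidesteps the constructor asymmetry entirely.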

minghuaw (Collaborator, Author) commented

#178 provides a partial implementation of this feature; both saving and loading deal with an impl AsRef<Path>. The parts that are still missing in #178 are:

  • Performance of weight loading: loading weights in the mistral example is about two times slower than in the original Python example
  • Support for loading/saving from/to multiple files
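For the multi-file case, one small building block is discovering all the shards of a checkpoint before loading them one by one. A minimal std-only sketch (the helper name `safetensors_shards` is hypothetical, and sorting by file name assumes shards are named so that lexicographic order matches shard order):

```rust
use std::path::{Path, PathBuf};

// Collect every `.safetensors` shard in a directory, sorted by file name,
// so a multi-file checkpoint can be loaded shard by shard.
fn safetensors_shards(dir: impl AsRef<Path>) -> std::io::Result<Vec<PathBuf>> {
    let mut shards: Vec<PathBuf> = std::fs::read_dir(dir)?
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        .filter(|p| p.extension().map_or(false, |ext| ext == "safetensors"))
        .collect();
    shards.sort();
    Ok(shards)
}
```

Each returned path could then be fed to the path-based loader, merging every shard's tensors into the same parameter map.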
