Skip to content

Commit

Permalink
Read all the tensors in a PyTorch pth file. (#1106)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Oct 16, 2023
1 parent 588ad48 commit 0106b0b
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions candle-core/src/pickle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,16 @@ impl PthTensors {
Ok(Some(tensor))
}
}

/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}

0 comments on commit 0106b0b

Please sign in to comment.