Skip to content

Commit

Permalink
Create a new curand instead of reseeding. (#1089)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Oct 14, 2023
1 parent a193bf5 commit 9309cfc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
4 changes: 3 additions & 1 deletion candle-core/src/cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ impl BackendDevice for CudaDevice {
}

fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0.set_seed(seed).w()?;
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}

Expand Down
7 changes: 7 additions & 0 deletions candle-core/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}

pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
}
}

pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
Expand Down

0 comments on commit 9309cfc

Please sign in to comment.