Skip to content

Commit

Permalink
fix source file name in unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioMirarchi committed Jul 20, 2024
1 parent c0d39b7 commit 1d04f6e
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/test_mdcath.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def test_mdcath(tmpdir):
num_atoms_list = np.linspace(50, 1000, 50)
source_file = h5py.File(join(tmpdir, "source.h5"), mode="w")
source_file = h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w")
for num_atoms in num_atoms_list:
z = np.zeros(int(num_atoms))
pos = np.zeros((100, int(num_atoms), 3))
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10):
pos = np.zeros((numFrames, num_entries, 3))
forces = np.zeros((numFrames, num_entries, 3))

source_file = h5py.File(join(tmpdir, "source.h5"), mode="w")
source_file = h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w")
s_group = source_file.create_group("A00")

s_group.attrs["numChains"] = 1
Expand Down Expand Up @@ -108,7 +108,6 @@ def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10):

dset = MDCATH(
root=tmpdir,
source_file=join(tmpdir, "source.h5"),
)
assert len(proc.open_files()) == n_open, "creating the dataset object opened a file"

Expand All @@ -125,7 +124,7 @@ def replacer(arr, skipframes):
@mark.parametrize("batch_size", [1, 10])
def test_mdcath_skipframes(tmpdir, skipframes, batch_size):

with h5py.File(join(tmpdir, "source.h5"), mode="w") as source_file:
with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
num_frames_list = np.linspace(50, 1000, 50).astype(int)
for num_frame in tqdm(num_frames_list, desc="Creating tmp files"):
z = np.zeros(100)
Expand Down Expand Up @@ -169,7 +168,7 @@ def test_mdcath_skipframes(tmpdir, skipframes, batch_size):
data.close()

dataset = MDCATH(
root=tmpdir, skipFrames=skipframes, source_file=join(tmpdir, "source.h5")
root=tmpdir, skipFrames=skipframes
)
dl = DataLoader(
dataset,
Expand Down

0 comments on commit 1d04f6e

Please sign in to comment.