Skip to content

Commit

Permalink
Merge branch 'main' into dev_rolf
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfverberg committed Oct 21, 2024
2 parents ce4f615 + 7d92c3a commit d3a945c
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHAP/edd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
SetupNXdataReader,
UpdateNXdataReader,
NXdataSliceReader,
SliceNXdataReader,
)
# from CHAP.edd.writer import
53 changes: 53 additions & 0 deletions CHAP/edd/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,59 @@ def read(self, filename, dataset_id, detectors=None):
return {'coords': coords, 'signals': signals,
'attrs': attrs, 'data_points': data_points}

class SliceNXdataReader(Reader):
"""A reader class to load and slice an NXdata field from a NeXus
file. This class reads EDD (Energy Dispersive Diffraction) data
from an NXdata group and slices all fields according to the
provided slicing parameters.
"""
def read(self, filename, scan_number, inputdir=None):
"""Reads an NXdata group from a NeXus file and slices the
fields within it based on the provided scan number.
:param filename: The name of the NeXus file to read.
:type filename: str
:param scan_number: The scan number to use for slicing the
data.
:type scan_number: int
:param inputdir: The directory containing the input file,
defaults to None.
:type inputdir: str, optional
:return: The root object of the NeXus file with sliced NXdata
fields.
:rtype: NXroot
:raises ValueError: If no NXdata group is found in the file.
"""
import os
import numpy as np
from nexusformat.nexus import NXentry, NXfield

from CHAP.common import NexusReader
from CHAP.utils.general import nxcopy

reader = NexusReader()
nxroot = nxcopy(reader.read(os.path.join(inputdir, filename)))
nxdata = None
for nxname, nxobject in nxroot.items():
if isinstance(nxobject, NXentry):
nxdata = nxobject.data
if nxdata is None:
msg = 'Could not find NXdata group'
self.logger.error(msg)
raise ValueError(msg)

indices = np.argwhere(nxdata.SCAN_N.nxdata == scan_number).flatten()
for nxname, nxobject in nxdata.items():
if isinstance(nxobject, NXfield):
nxdata[nxname] = NXfield(
value=nxobject.nxdata[indices],
dtype=nxdata[nxname].dtype,
attrs=nxdata[nxname].attrs,
)

return nxroot

class UpdateNXdataReader(Reader):
"""Companion to `edd.SetupNXdataReader` and
Expand Down
2 changes: 2 additions & 0 deletions CHAP/edd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,7 @@ def get_spectra_fits(spectra, energies, peak_locations, detector):
numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray,
numpy.ndarray]
"""
from os import getpid
# Third party modules
from nexusformat.nexus import NXdata, NXfield

Expand Down Expand Up @@ -1432,6 +1433,7 @@ def get_spectra_fits(spectra, energies, peak_locations, detector):
# 'method': 'trf',
'method': 'leastsq',
# 'method': 'least_squares',
'memfolder': f'/tmp/{getpid()}_joblib_memmap',
}

# Perform uniform fit
Expand Down
4 changes: 1 addition & 3 deletions CHAP/utils/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ def __init__(self, nxdata, config):
self._best_values = None
self._inv_transpose = None
self._max_nfev = None
self._memfolder = None
self._memfolder = config.memfolder
self._new_parameters = None
self._num_func_eval = None
self._out_of_bounds = None
Expand Down Expand Up @@ -2319,7 +2319,6 @@ def freemem(self):
return
try:
rmtree(self._memfolder)
self._memfolder = None
except:
logger.warning('Could not clean-up automatically.')

Expand Down Expand Up @@ -2514,7 +2513,6 @@ def fit(self, config=None, **kwargs):
np.zeros(self._map_dim, dtype=np.float64)
for _ in range(num_new_parameters)]
else:
self._memfolder = 'joblib_memmap'
try:
mkdir(self._memfolder)
except FileExistsError:
Expand Down
1 change: 1 addition & 0 deletions CHAP/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ class FitConfig(BaseModel):
num_proc: conint(gt=0) = 1
plot: StrictBool = False
print_report: StrictBool = False
memfolder: str = 'joblib_memmap'

@field_validator('method')
@classmethod
Expand Down

0 comments on commit d3a945c

Please sign in to comment.