Skip to content

Commit

Permalink
revisit fit() using apply_model()
Browse files Browse the repository at this point in the history
  • Loading branch information
patquem committed Jan 24, 2024
1 parent 9dbb3cf commit 97ea49f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 46 deletions.
48 changes: 8 additions & 40 deletions fitspy/app/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@
from tkinter import END
from tkinter.messagebox import askyesno, showerror
from tkinter import filedialog as fd
from copy import deepcopy
import glob

from lmfit.models import ExpressionModel
import numpy as np
import matplotlib.pyplot as plt
import dill

from fitspy.spectra import Spectra, fit_mp
from fitspy.spectra import Spectra
from fitspy.spectra_map import SpectraMap
from fitspy.spectrum import Spectrum
from fitspy.utils import closest_index, check_or_rename
Expand Down Expand Up @@ -639,43 +635,15 @@ def fit(self, fnames=None):
fnames = fselector.filenames[0]
fnames = [fnames[i] for i in fselector.lbox[0].curselection()]

models = self.current_spectrum.models
models_labels = self.current_spectrum.models_labels
params = self.fit_settings.params
fit_negative = params['fit_negative_values'].get() == 'On'
max_ite = params['maximum_iterations'].get()
fit_method = params['fit_method'].get()

spectra = []
bkg_model_ref = self.tabview.spectrum.bkg_model
for fname in fnames:
spectrum, _ = self.spectra.get_objects(fname)
# to keep values defined by the user
if bkg_model_ref is not None:
if isinstance(bkg_model_ref, ExpressionModel):
spectrum.bkg_model = dill.copy(bkg_model_ref)
else:
spectrum.bkg_model = deepcopy(bkg_model_ref)
else:
spectrum.set_bkg_model(bkg_name)
spectra.append(spectrum)

ncpus = self.get_ncpus(nfiles=len(spectra))

if ncpus == 1:
for spectrum in spectra:
# ExpressionModel can not be serialized with Pickle,
# which deepcopy uses
if np.any([isinstance(x, ExpressionModel) for x in models]):
spectrum.models = dill.copy(models)
else:
spectrum.models = deepcopy(models)
spectrum.models_labels = models_labels.copy()
spectrum.fit(fit_method, fit_negative, max_ite)
else:
fit_mp(spectra, models,
fit_method, fit_negative, max_ite, ncpus, models_labels)
kwargs = {'fit_negative': params['fit_negative_values'].get() == 'On',
'max_ite': params['maximum_iterations'].get(),
'fit_method': params['fit_method'].get()}

model_dict = self.current_spectrum.save()
ncpus = self.get_ncpus(nfiles=len(fnames))
fit_only = True
self.spectra.apply_model(model_dict, fnames, ncpus, fit_only, **kwargs)
self.colorize_from_fit_status(fnames=fnames)
self.tabview.update()
self.tabview.update_stats()
Expand Down
14 changes: 8 additions & 6 deletions fitspy/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,23 @@ def load_model(fname_json, ind=0):
model = load_from_json(fname_json)[ind]
return model

def apply_model(self, model, fnames=None, ncpus=1, **fit_kwargs):
def apply_model(self, model, fnames=None, ncpus=1,
fit_only=False, **fit_kwargs):
"""
Apply 'model' to all or part of the spectra
Parameters
----------
model: dict
Dictionary issued from a .json model reloading
Dictionary related to the Spectrum object attributes (obtained from
Spectrum.save())
fnames: list of str, optional
List of spectrum filename to handle.
If None, apply the model to all the spectra
ncpus: int, optional
Number of CPU to work with in fitting
fit_only: bool, optional
Activation key to process only fittin
fit_kwargs: dict
Keywords arguments passed to spectrum.fit()
"""
Expand All @@ -230,12 +234,10 @@ def apply_model(self, model, fnames=None, ncpus=1, **fit_kwargs):
spectrum, _ = self.get_objects(fname)
spectrum.set_attributes(model, **fit_kwargs)
spectrum.fname = fname # reassign the correct fname
spectrum.preprocess()
if not fit_only:
spectrum.preprocess()
spectra.append(spectrum)

if len(spectra) == 0:
return

if ncpus == 1:
for spectrum in spectra:
spectrum.fit()
Expand Down

0 comments on commit 97ea49f

Please sign in to comment.