Skip to content

Commit

Permalink
define spectrum.result_fit as an object
Browse files Browse the repository at this point in the history
  • Loading branch information
patquem committed Jan 24, 2024
1 parent 97ea49f commit 51459ec
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 36 deletions.
16 changes: 8 additions & 8 deletions fitspy/app/toplevels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib.pyplot as plt
from matplotlib.colors import rgb2hex
from lmfit import fit_report
from lmfit.model import ModelResult

from fitspy.app.utils import add, add_entry
from fitspy.app.callbacks import FIT_METHODS
Expand Down Expand Up @@ -127,7 +128,7 @@ def param_has_changed(self, i, key, arg):
value = max(min(param['max'], value), param['min'])
self.params[i][key][arg].set(f'{value:.4g}') # bound the value
param[arg] = value
self.spectrum.result_fit = None
self.spectrum.result_fit = lambda: None
self.plot() # pylint:disable=not-callable

def model_has_changed(self, i):
Expand All @@ -140,14 +141,14 @@ def model_has_changed(self, i):
x0 = spectrum.models[i].param_hints['x0']['value']
spectrum.models[i] = spectrum.create_model(i + 1, new_model_name,
x0=x0, ampli=ampli)
self.spectrum.result_fit = None
self.spectrum.result_fit = lambda: None
self.plot() # pylint:disable=not-callable
self.update()

def bkg_model_has_changed(self, _):
""" Update the 'bkg_model' """
self.spectrum.set_bkg_model(self.bkg_name.get())
self.spectrum.result_fit = None
self.spectrum.result_fit = lambda: None
self.plot() # pylint:disable=not-callable
self.update()

Expand All @@ -166,7 +167,7 @@ def delete_models(self):
for i, val in enumerate(reversed(self.models_delete)):
if val.get():
self.spectrum.del_model(nb_models - i - 1)
self.spectrum.result_fit = None
self.spectrum.result_fit = lambda: None
self.plot() # pylint:disable=not-callable
self.update()

Expand Down Expand Up @@ -253,10 +254,9 @@ def add_model(self, model, i, row, keys):
def update_stats(self):
""" Update the statistics """
self.text.delete(1.0, END)
if self.spectrum.result_fit is not None:
if hasattr(self.spectrum.result_fit, 'success'):
self.text.insert(END, fit_report(self.spectrum.result_fit))
self.text.pack()
if isinstance(self.spectrum.result_fit, ModelResult):
self.text.insert(END, fit_report(self.spectrum.result_fit))
self.text.pack()

def delete(self):
""" Delete all the values contained in frames """
Expand Down
4 changes: 2 additions & 2 deletions fitspy/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def picklable_model(model):
funcdefs[val.__name__] = val

for (values, success), spectrum in zip(results, spectra):
spectrum.result_fit = success
spectrum.result_fit.success = success
for model in spectrum.models:
for key in model.param_names:
model.set_param_hint(key[4:], value=values[key])
Expand Down Expand Up @@ -149,7 +149,7 @@ def save_results(self, dirname_res, fnames=None):

for fname in fnames:
spectrum, _ = self.get_objects(fname)
if spectrum.result_fit is not None:
if hasattr(spectrum.result_fit, "success"):
spectrum.save_params(dirname_res)
spectrum.save_stats(dirname_res)

Expand Down
57 changes: 31 additions & 26 deletions fitspy/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
from scipy.signal import find_peaks
from lmfit import Model, report_fit, fit_report
from lmfit.model import ModelResult
from lmfit.models import ConstantModel, LinearModel, ParabolicModel, \
ExponentialModel, ExpressionModel # pylint:disable=unused-import

Expand Down Expand Up @@ -89,7 +90,8 @@ class Spectrum:
An iteration consists in evaluating all the 'free' parameters once.
Default is 200.
result_fit: lmfit.ModelResult
Object resulting from lmfit fitting
Object resulting from lmfit fitting. Default value is an 'empty'
function that enables to address a 'result_fit.success' status
"""

def __init__(self):
Expand All @@ -114,7 +116,7 @@ def __init__(self):
self.fit_method = 'leastsq'
self.fit_negative = False
self.max_ite = 200
self.result_fit = None
self.result_fit = lambda: None

def set_attributes(self, dict_attrs, **fit_kwargs):
"""Set attributes from a dictionary (obtained from a .json reloading)"""
Expand Down Expand Up @@ -154,7 +156,7 @@ def set_attributes(self, dict_attrs, **fit_kwargs):
setattr(self.baseline, key, dict_attrs['baseline'][key])

if 'result_fit_success' in keys:
self.result_fit = dict_attrs['result_fit_success']
self.result_fit.success = dict_attrs['result_fit_success']

# COMPATIBILITY with 'old' models
#################################
Expand Down Expand Up @@ -354,7 +356,8 @@ def remove_models(self):
self.models = []
self.models_labels = []
self.models_index = itertools.count(start=1)
self.result_fit = None
self.bkg_model = None
self.result_fit = lambda: None

def set_bkg_model(self, bkg_name):
""" Set the 'bkg_model' attribute from 'bkg_name' """
Expand Down Expand Up @@ -477,7 +480,9 @@ def plot(self, ax, show_peaks=True, show_negative_values=True,
""" Plot the spectrum with the fitted models and Return the profiles """
lines = []
x, y = self.x, self.y
linewidth = 0.5 if self.result_fit is None else 1
linewidth = 0.5
if hasattr(self.result_fit, 'success') and self.result_fit.success:
linewidth = 1

ax.plot(x, y, 'ko-', lw=linewidth, ms=1)

Expand Down Expand Up @@ -516,27 +521,26 @@ def plot(self, ax, show_peaks=True, show_negative_values=True,
line, = ax.plot(x, y, lw=linewidth)
lines.append(line)

if self.result_fit is not None:
if linewidth == 1: # self.result_fit.success
ax.plot(x, y_bkg + y_models, 'b', label="Fitted profile")

return lines

def plot_residual(self, ax, factor=1):
""" Plot the residual x factor obtained after fitting """
if self.result_fit is not None:
x, y = self.x, self.y
if hasattr(self.result_fit, 'residual'):
residual = self.result_fit.residual
else:
y_fit = np.zeros_like(x)
for model in self.models:
y_fit += model.eval(model.make_params(), x=x)
if self.bkg_model is not None:
bkg_model = self.bkg_model
y_fit += bkg_model.eval(bkg_model.make_params(), x=x)
residual = y - y_fit
ax.plot(x, factor * residual, 'r', label=f"residual (x{factor})")
ax.legend()
x, y = self.x, self.y
if hasattr(self.result_fit, 'residual'):
residual = self.result_fit.residual
else:
y_fit = np.zeros_like(x)
for model in self.models:
y_fit += model.eval(model.make_params(), x=x)
if self.bkg_model is not None:
bkg_model = self.bkg_model
y_fit += bkg_model.eval(bkg_model.make_params(), x=x)
residual = y - y_fit
ax.plot(x, factor * residual, 'r', label=f"residual (x{factor})")
ax.legend()

def save_params(self, dirname_params):
""" Save fit parameters in a '.csv' file located in 'dirname_params' """
Expand All @@ -559,11 +563,12 @@ def save_params(self, dirname_params):

def save_stats(self, dirname_stats):
""" Save statistics in a '.txt' file located in 'dirname_stats' """
_, name, _ = fileparts(self.fname)
fname_stats = os.path.join(dirname_stats, name + '_stats.txt')
fname_stats = check_or_rename(fname_stats)
with open(fname_stats, 'w') as fid:
fid.write(fit_report(self.result_fit))
if isinstance(self.result_fit, ModelResult):
_, name, _ = fileparts(self.fname)
fname_stats = os.path.join(dirname_stats, name + '_stats.txt')
fname_stats = check_or_rename(fname_stats)
with open(fname_stats, 'w') as fid:
fid.write(fit_report(self.result_fit))

def save(self, fname_json=None):
""" Return a 'dict_attrs' dictionary from the spectrum attributes and
Expand Down Expand Up @@ -593,7 +598,7 @@ def save(self, fname_json=None):
models[i][model_name] = model.param_hints
dict_attrs['models'] = models

if self.result_fit is not None:
if hasattr(self.result_fit, 'success'):
dict_attrs['result_fit_success'] = self.result_fit.success

if fname_json is not None:
Expand Down

0 comments on commit 51459ec

Please sign in to comment.