Skip to content

Commit

Permalink
save_params() and save_stats() moved from spectra.py to spectrum.py
Browse files Browse the repository at this point in the history
  • Loading branch information
patquem committed Jan 22, 2024
1 parent 129ffc0 commit 74c2db6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 40 deletions.
39 changes: 5 additions & 34 deletions fitspy/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
"Spectra"
"""
import os
import csv
from copy import deepcopy
from concurrent.futures import ProcessPoolExecutor
import matplotlib.pyplot as plt
from lmfit import Model, fit_report, Parameters
from lmfit import Model, Parameters
from lmfit.model import ModelResult

from fitspy.utils import fileparts, check_or_rename
from fitspy.utils import save_to_json, load_from_json
from fitspy.utils import fileparts, save_to_json, load_from_json
from fitspy.models import gaussian
from fitspy.spectrum import Spectrum
from fitspy import MODELS, PARAMS
from fitspy import MODELS


def fit(params):
Expand Down Expand Up @@ -129,38 +127,11 @@ def save_results(self, dirname_res, fnames=None):
if fnames is None:
fnames = self.fnames

def write_params(fname_params, labels, models):
with open(fname_params, 'w', newline='') as fid:
writer = csv.writer(fid, delimiter=';')
writer.writerow(['label', 'model'] + PARAMS)
for label, model in zip(labels, models):
vals = [label, Spectrum.get_model_name(model)]
for key in PARAMS:
params = model.param_hints
if key in params.keys():
vals.append(params[key]['value'])
else:
vals.append('')
writer.writerow(vals)

for fname in fnames:
spectrum, _ = self.get_objects(fname)
if spectrum.result_fit is not None:
# TODO : use Path, move write method in spectrum
_, name, _ = fileparts(fname)

# results saving
fname_params = os.path.join(dirname_res, name + '.csv')
fname_params = check_or_rename(fname_params)
labels, models = spectrum.models_labels, spectrum.models
if len(models) > 0:
write_params(fname_params, labels, models)

# statistics saving
fname_stats = os.path.join(dirname_res, name + '_stats.txt')
fname_stats = check_or_rename(fname_stats)
with open(fname_stats, 'w') as fid:
fid.write(fit_report(spectrum.result_fit))
spectrum.save_params(dirname_res)
spectrum.save_stats(dirname_res)

def save_figures(self, dirname_fig, fnames=None, bounds=None):
"""
Expand Down
41 changes: 35 additions & 6 deletions fitspy/spectrum.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
"""
Class dedicated to spectrum processing
"""
import os
import csv
import itertools
from copy import deepcopy
import numpy as np
import pandas as pd
from scipy.signal import find_peaks
from lmfit import Model, report_fit
from lmfit import Model, report_fit, fit_report
from lmfit.models import ConstantModel, LinearModel, ParabolicModel, \
ExponentialModel, ExpressionModel # pylint:disable=unused-import

from fitspy.utils import closest_index
from fitspy.utils import closest_index, fileparts, check_or_rename
from fitspy.utils import save_to_json, load_from_json
from fitspy.app.utils import convert_dict_from_tk_variables
from fitspy.app.utils import dict_has_tk_variable
from fitspy.baseline import BaseLine
from fitspy import MODELS, BKG_MODELS
from fitspy import MODELS, PARAMS, BKG_MODELS


class Spectrum:
Expand Down Expand Up @@ -121,7 +123,7 @@ def set_attributes(self, dict_attrs, **fit_kwargs):
model.prefix = pfx
else:
model = Model(model, independent_vars=['x'], prefix=pfx)
model.param_hints = param_hints
model.param_hints = deepcopy(param_hints)
self.models.append(model)

if 'bkg_model' in keys and dict_attrs['bkg_model']:
Expand All @@ -135,11 +137,11 @@ def set_attributes(self, dict_attrs, **fit_kwargs):
model = Model(model, independent_vars=['x'])
self.bkg_model = model
self.bkg_model.name2 = model_name
self.bkg_model.param_hints = param_hints
self.bkg_model.param_hints = deepcopy(param_hints)

if 'baseline' in keys:
self.baseline = BaseLine()
for key, val in vars(self.baseline).items():
for key in vars(self.baseline).keys():
if key in dict_attrs['baseline'].keys():
setattr(self.baseline, key, dict_attrs['baseline'][key])

Expand Down Expand Up @@ -529,6 +531,33 @@ def plot_residual(self, ax, factor=1):
ax.plot(x, factor * residual, 'r', label=f"residual (x{factor})")
ax.legend()

def save_params(self, dirname_params):
""" Save fit parameters in a '.csv' file located in 'dirname_params' """
_, name, _ = fileparts(self.fname)
fname_params = os.path.join(dirname_params, name + '.csv')
fname_params = check_or_rename(fname_params)
if len(self.models) > 0:
with open(fname_params, 'w', newline='') as fid:
writer = csv.writer(fid, delimiter=';')
writer.writerow(['label', 'model'] + PARAMS)
for label, model in zip(self.models_labels, self.models):
vals = [label, self.get_model_name(model)]
for key in PARAMS:
params = model.param_hints
if key in params.keys():
vals.append(params[key]['value'])
else:
vals.append('')
writer.writerow(vals)

def save_stats(self, dirname_stats):
""" Save statistics in a '.txt' file located in 'dirname_stats' """
_, name, _ = fileparts(self.fname)
fname_stats = os.path.join(dirname_stats, name + '_stats.txt')
fname_stats = check_or_rename(fname_stats)
with open(fname_stats, 'w') as fid:
fid.write(fit_report(self.result_fit))

def save(self, fname_json=None):
""" Return a 'dict_attrs' dictionary from the spectrum attributes and
Save it if a 'fname_json' is given """
Expand Down

0 comments on commit 74c2db6

Please sign in to comment.