diff --git a/fitspy/spectra.py b/fitspy/spectra.py index 7404806..0a8fae8 100644 --- a/fitspy/spectra.py +++ b/fitspy/spectra.py @@ -3,18 +3,16 @@ "Spectra" """ import os -import csv from copy import deepcopy from concurrent.futures import ProcessPoolExecutor import matplotlib.pyplot as plt -from lmfit import Model, fit_report, Parameters +from lmfit import Model, Parameters from lmfit.model import ModelResult -from fitspy.utils import fileparts, check_or_rename -from fitspy.utils import save_to_json, load_from_json +from fitspy.utils import fileparts, save_to_json, load_from_json from fitspy.models import gaussian from fitspy.spectrum import Spectrum -from fitspy import MODELS, PARAMS +from fitspy import MODELS def fit(params): @@ -129,38 +127,11 @@ def save_results(self, dirname_res, fnames=None): if fnames is None: fnames = self.fnames - def write_params(fname_params, labels, models): - with open(fname_params, 'w', newline='') as fid: - writer = csv.writer(fid, delimiter=';') - writer.writerow(['label', 'model'] + PARAMS) - for label, model in zip(labels, models): - vals = [label, Spectrum.get_model_name(model)] - for key in PARAMS: - params = model.param_hints - if key in params.keys(): - vals.append(params[key]['value']) - else: - vals.append('') - writer.writerow(vals) - for fname in fnames: spectrum, _ = self.get_objects(fname) if spectrum.result_fit is not None: - # TODO : use Path, move write method in spectrum - _, name, _ = fileparts(fname) - - # results saving - fname_params = os.path.join(dirname_res, name + '.csv') - fname_params = check_or_rename(fname_params) - labels, models = spectrum.models_labels, spectrum.models - if len(models) > 0: - write_params(fname_params, labels, models) - - # statistics saving - fname_stats = os.path.join(dirname_res, name + '_stats.txt') - fname_stats = check_or_rename(fname_stats) - with open(fname_stats, 'w') as fid: - fid.write(fit_report(spectrum.result_fit)) + spectrum.save_params(dirname_res) + spectrum.save_stats(dirname_res) def save_figures(self, dirname_fig, fnames=None, bounds=None): """ diff --git a/fitspy/spectrum.py b/fitspy/spectrum.py index 326db15..3510b75 100644 --- a/fitspy/spectrum.py +++ b/fitspy/spectrum.py @@ -1,21 +1,23 @@ """ Class dedicated to spectrum processing """ +import os +import csv import itertools from copy import deepcopy import numpy as np import pandas as pd from scipy.signal import find_peaks -from lmfit import Model, report_fit +from lmfit import Model, report_fit, fit_report from lmfit.models import ConstantModel, LinearModel, ParabolicModel, \ ExponentialModel, ExpressionModel # pylint:disable=unused-import -from fitspy.utils import closest_index +from fitspy.utils import closest_index, fileparts, check_or_rename from fitspy.utils import save_to_json, load_from_json from fitspy.app.utils import convert_dict_from_tk_variables from fitspy.app.utils import dict_has_tk_variable from fitspy.baseline import BaseLine -from fitspy import MODELS, BKG_MODELS +from fitspy import MODELS, PARAMS, BKG_MODELS class Spectrum: @@ -121,7 +123,7 @@ def set_attributes(self, dict_attrs, **fit_kwargs): model.prefix = pfx else: model = Model(model, independent_vars=['x'], prefix=pfx) - model.param_hints = param_hints + model.param_hints = deepcopy(param_hints) self.models.append(model) if 'bkg_model' in keys and dict_attrs['bkg_model']: @@ -135,11 +137,11 @@ def set_attributes(self, dict_attrs, **fit_kwargs): model = Model(model, independent_vars=['x']) self.bkg_model = model self.bkg_model.name2 = model_name - self.bkg_model.param_hints = param_hints + self.bkg_model.param_hints = deepcopy(param_hints) if 'baseline' in keys: self.baseline = BaseLine() - for key, val in vars(self.baseline).items(): + for key in vars(self.baseline).keys(): if key in dict_attrs['baseline'].keys(): setattr(self.baseline, key, dict_attrs['baseline'][key]) @@ -529,6 +531,33 @@ def plot_residual(self, ax, factor=1): ax.plot(x, factor * residual, 'r', label=f"residual (x{factor})") ax.legend() + def save_params(self, dirname_params): + """ Save fit parameters in a '.csv' file located in 'dirname_params' """ + _, name, _ = fileparts(self.fname) + fname_params = os.path.join(dirname_params, name + '.csv') + fname_params = check_or_rename(fname_params) + if len(self.models) > 0: + with open(fname_params, 'w', newline='') as fid: + writer = csv.writer(fid, delimiter=';') + writer.writerow(['label', 'model'] + PARAMS) + for label, model in zip(self.models_labels, self.models): + vals = [label, self.get_model_name(model)] + for key in PARAMS: + params = model.param_hints + if key in params.keys(): + vals.append(params[key]['value']) + else: + vals.append('') + writer.writerow(vals) + + def save_stats(self, dirname_stats): + """ Save statistics in a '.txt' file located in 'dirname_stats' """ + _, name, _ = fileparts(self.fname) + fname_stats = os.path.join(dirname_stats, name + '_stats.txt') + fname_stats = check_or_rename(fname_stats) + with open(fname_stats, 'w') as fid: + fid.write(fit_report(self.result_fit)) + def save(self, fname_json=None): """ Return a 'dict_attrs' dictionary from the spectrum attributes and Save it if a 'fname_json' is given """