
Loading a saved SplineModel fit result using load_modelresult does not work for more than 6 knots #985

Closed
B-Hartmann opened this issue Jan 16, 2025 · 6 comments · Fixed by #989


@B-Hartmann

lmfit 1.3.2
scipy 1.12.0
numpy 1.26.4
asteval 1.0.5
Windows 10
Python 3.10.4
Code run in a jupyter notebook using DataSpell

I would like to save the results of a fit using SplineModel for later use. I used the functions save_modelresult and load_modelresult:

import numpy as np
import matplotlib.pyplot as plt
from lmfit.model import save_modelresult, load_modelresult
from lmfit.models import SplineModel

x = np.linspace(-10, 10, 100)
y = 0.6*np.exp(-(x**2)/(1.3**2))

spl_model = SplineModel(xknots=np.linspace(-10, 10, 10))
params = spl_model.guess(y, x)
result = spl_model.fit(y, params, x=x)

save_modelresult(result, 'spline_modelresult.sav')

result_new = load_modelresult('spline_modelresult.sav')
print(result_new.fit_report())

plt.plot(x, y, 'o')
plt.plot(x, result_new.best_fit, '-')
plt.show()

I get the following error:

Traceback (most recent call last)
Cell In[26], line 15
     11 result = spl_model.fit(y, params, x=x)
     13 save_modelresult(result, 'spline_modelresult.sav')
---> 15 result_new = load_modelresult('spline_modelresult.sav')
     16 print(result_new.fit_report())
     18 plt.plot(x, y, 'o')

File ~\.environments\tracking\lib\site-packages\lmfit\model.py:1468, in load_modelresult(fname, funcdefs)
   1466 modres = ModelResult(Model(lambda x: x, None), params)
   1467 with open(fname) as fh:
-> 1468     mresult = modres.load(fh, funcdefs=funcdefs)
   1469 return mresult

File ~\.environments\tracking\lib\site-packages\lmfit\model.py:2079, in ModelResult.load(self, fp, funcdefs, **kws)
   2057 def load(self, fp, funcdefs=None, **kws):
   2058     """Load JSON representation of ModelResult from a file-like object.
   2059 
   2060     Parameters
   (...)
   2077 
   2078     """
-> 2079     return self.loads(fp.read(), funcdefs=funcdefs, **kws)

File ~\.environments\tracking\lib\site-packages\lmfit\model.py:1998, in ModelResult.loads(self, s, funcdefs, **kws)
   1995     raise AttributeError('ModelResult.loads() needs valid ModelResult')
   1997 # model
-> 1998 self.model = _buildmodel(decode4js(modres['model']), funcdefs=funcdefs)
   2000 if funcdefs:
   2001     # Remove model function so as not pass it into the _asteval.symtable
   2002     funcdefs.pop(self.model.func.__name__, None)

File ~\.environments\tracking\lib\site-packages\lmfit\model.py:1421, in _buildmodel(state, funcdefs)
   1415     model = ExpressionModel(func, name=name,
   1416                             independent_vars=ivars,
   1417                             param_names=pnames,
   1418                             nan_policy=nan_policy, **opts)
   1420 else:
-> 1421     model = Model(func, name=name, prefix=prefix,
   1422                   independent_vars=ivars, param_names=pnames,
   1423                   nan_policy=nan_policy, **opts)
   1425 for name, hint in phints.items():
   1426     model.set_param_hint(name, **hint)

File ~\.environments\tracking\lib\site-packages\lmfit\model.py:305, in Model.__init__(self, func, independent_vars, param_names, nan_policy, prefix, name, **kws)
    303 self.param_hints = {}
    304 self._param_names = []
--> 305 self._parse_params()
    306 if self.independent_vars is None:
    307     self.independent_vars = []

File ~\.environments\tracking\lib\site-packages\lmfit\model.py:612, in Model._parse_params(self)
    609 for arg in names:
    610     if (self._strip_prefix(arg) not in self._func_allargs or
    611             arg in self._forbidden_args):
--> 612         raise ValueError(self._invalid_par % (arg, fname))
    613 # the following as been changed from OrderedSet for the time being.
    614 self._param_names = names[:]

ValueError: Invalid parameter name ('s6') for function spline_model

I believe the following line is the problem: it allows only 6 coefficient arguments (s0 to s5) to be passed to the function, but I have 10 knots:
https://github.com/lmfit/lmfit-py/blob/1.3.2/lmfit/models.py#L394

Could this be solved by adding *s or **s to the function arguments, or something similar?
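To see why loading fails, here is a stripped-down sketch of the kind of name check that raises this error. It is not lmfit's actual `_parse_params` code: `check_param_names`, the stand-in `spline_model` with coefficients `s0` to `s5`, and `spline_model_kw` are hypothetical, assuming only that the check compares each saved parameter name against the model function's argument names.

```python
import inspect


def spline_model(x, s0=1, s1=1, s2=1, s3=1, s4=1, s5=1):
    """Hypothetical stand-in: a fixed signature with six coefficients."""
    return x


def check_param_names(func, names):
    """Raise ValueError for any name that is not an argument of func (hypothetical)."""
    sig = inspect.signature(func)
    allowed = set(sig.parameters)
    # a function that takes **kwargs accepts any keyword name
    has_varkw = any(p.kind is inspect.Parameter.VAR_KEYWORD
                    for p in sig.parameters.values())
    for name in names:
        if name not in allowed and not has_varkw:
            raise ValueError(f"Invalid parameter name ('{name}') "
                             f"for function {func.__name__}")


# ten knots give ten coefficients, s0 .. s9
names = [f"s{i}" for i in range(10)]
try:
    check_param_names(spline_model, names)
except ValueError as err:
    print(err)  # Invalid parameter name ('s6') for function spline_model


def spline_model_kw(x, **coefs):
    """Hypothetical variant: **kwargs accepts any coefficient name."""
    return x


check_param_names(spline_model_kw, names)  # passes without error
```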

@newville
Member

@B-Hartmann Hmm, yeah that's a problem ;).

I think we might need to special-case SplineModel when saving/loading models... that may need some work. There might be other options...

@paulmueller
Contributor

paulmueller commented Jan 17, 2025

I believe a good approach would be to

  1. refactor/generalize _parse_params a little (accepting keyword arguments for whether to perform certain checks)

    https://github.com/lmfit/lmfit-py/blob/1.3.2/lmfit/model.py#L506

  2. In the spline subclass reimplement _parse_params and super()-call the parent class implementation, skipping the problematic check and implementing a separate check instead.

It might also help to refactor a _check_params method out of _parse_params.

If this sounds good, I would be happy to submit a PR 👍
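A rough sketch of that refactoring, using toy classes rather than lmfit's own: `BaseModel`, `SplineLikeModel`, `_check_params`, and the `check` keyword are all hypothetical names for illustration, and the real `_parse_params` does considerably more argument introspection than shown here.

```python
import inspect


class BaseModel:
    """Hypothetical, heavily simplified stand-in for lmfit.Model."""

    def __init__(self, func, param_names):
        self.func = func
        self.param_names = list(param_names)
        self._parse_params()

    def _check_params(self, names):
        # default check: every name must be an argument of the model function
        allowed = set(inspect.signature(self.func).parameters)
        for name in names:
            if name not in allowed:
                raise ValueError(f"Invalid parameter name ('{name}')")

    def _parse_params(self, check=True):
        # ...argument introspection would happen here...
        if check:
            self._check_params(self.param_names)


class SplineLikeModel(BaseModel):
    """Hypothetical subclass with a variable number of coefficients."""

    def __init__(self, func, nknots):
        self.nknots = nknots  # must be set before the base __init__ parses
        super().__init__(func, [f"s{i}" for i in range(nknots)])

    def _parse_params(self, check=True):
        # skip the signature-based check and run a spline-specific one instead
        super()._parse_params(check=False)
        expected = [f"s{i}" for i in range(self.nknots)]
        if self.param_names != expected:
            raise ValueError("unexpected spline coefficient names")


def spline_func(x, s0=1, s1=1, s2=1, s3=1, s4=1, s5=1):
    return x


model = SplineLikeModel(spline_func, nknots=10)  # accepted: s0 .. s9
try:
    BaseModel(spline_func, [f"s{i}" for i in range(10)])
except ValueError as err:
    print(err)  # Invalid parameter name ('s6')
```

The subclass keeps the base-class parsing but swaps the signature check for one that only verifies the coefficient names follow the `s0`, `s1`, ... pattern.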

@B-Hartmann
Author

B-Hartmann commented Jan 17, 2025

A workaround that does what I need is to save the fit result using pickle and extract the parameter values from the pickled object when needed.
I can then create a new spline model based on the code from this discussion like so (only requires that the same spline degree and the same knots array are used as inputs for splrep):

import numpy as np
import matplotlib.pyplot as plt
from lmfit.models import SplineModel
from scipy.interpolate import splev, splrep
import pickle

x = np.linspace(-10, 10, 100)
y = 0.6*np.exp(-(x**2)/(2*2.3**2))

xknots = np.linspace(-8, 8, 12)
gmodel = SplineModel(xknots=xknots)
params = gmodel.guess(y, x)
result = gmodel.fit(y, params, x=x)

# pickle the summary dict of the fit result
with open('spline_model.pkl', 'wb') as f:
    pickle.dump(result.summary(), f)

# load the pickled summary instead and manually extract the coefficients
with open('spline_model.pkl', 'rb') as f:
    loaded_result = pickle.load(f)

# the second item of each entry in summary()["params"] is the parameter value
fit_param_new = [para_set[1] for para_set in loaded_result["params"]]
# the same values taken from the in-memory result, for comparison
fit_param = [result.params[p].value for p in result.params]

x_new = np.linspace(-10, 10, 200)
# rebuild the knot vector from the same xknots (cubic spline, k=3)
knots, _c, _k = splrep(xknots, np.ones(len(xknots)), k=3)
# pad the coefficients with k+1 = 4 copies of the last one
coefs = list(fit_param_new)
coefs.extend([coefs[-1]]*4)
y_pred = splev(x_new, [knots, np.array(coefs), 3])

plt.figure()
plt.plot(x, y, 'o', label="data")
plt.plot(x, result.best_fit, '-', label="original spline")
plt.plot(x_new, y_pred, "--k", linewidth=2, label="new spline from parameters")
plt.legend()
plt.ylim([-0.2, 0.8])
plt.show()
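If this workaround is needed in several places, the reconstruction could be wrapped in a small helper. The sketch below uses only SciPy; `spline_from_coefs` is a hypothetical name, and it assumes, as the workaround above does, that the spline's knot vector comes from calling `splrep` on the same `xknots` with the same order, and that the coefficients are listed in the order `s0`, `s1`, ...

```python
import numpy as np
from scipy.interpolate import splev, splrep


def spline_from_coefs(x, xknots, coefs, order=3):
    """Evaluate a spline from saved coefficients (hypothetical helper).

    Rebuilds the knot vector with splrep on the same xknots, then pads the
    coefficients with order+1 copies of the last one before calling splev.
    """
    xknots = np.asarray(xknots, dtype=float)
    knots, _c, _k = splrep(xknots, np.ones(len(xknots)), k=order)
    coefs = list(coefs)
    if len(coefs) != len(xknots):
        raise ValueError("need exactly one coefficient per knot")
    coefs.extend([coefs[-1]] * (order + 1))
    return splev(x, [knots, np.array(coefs), order])


# self-check: coefficients from an interpolating splrep fit reproduce splev
xk = np.linspace(-8, 8, 12)
t, c, k = splrep(xk, np.exp(-xk**2 / 10), k=3)
xs = np.linspace(-8, 8, 50)
print(np.allclose(spline_from_coefs(xs, xk, c[:len(xk)]), splev(xs, (t, c, k))))  # True
```

Only the coefficient values and `xknots` need to be stored, so plain JSON would work as well as pickle for this.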

@newville
Member

@B-Hartmann Yes, that would be a decent workaround, but we should also try to fix this....

@paulmueller I think that "refactor/generalize _parse_params" for this (essentially, 'allow varargs') is not that easy. SplineModel is a special case as the only model that supports a variable number of parameters. We don't otherwise allow model functions to use varargs.

Another possible workaround would be to "simply" increase the number of parameters defined in the signature at
https://github.com/lmfit/lmfit-py/blob/1.3.2/lmfit/models.py#L394, say up to 's99', and set the maximum number of spline knots to 100. I think that would work, and using more than 100 knots seems highly unlikely.
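The "easy route" could look roughly like the sketch below. `MAX_KNOTS`, `check_nknots`, and generating the `def` line with `exec` are illustrative choices, not what lmfit actually does; writing the 100 keyword arguments out by hand would give the same signature.

```python
import inspect

MAX_KNOTS = 100  # hypothetical upper limit on the number of knots

# Generate the source for a stub with the explicit signature
# spline_model(x, s0=1, s1=1, ..., s99=1), then compile it.
_args = ", ".join(f"s{i}=1" for i in range(MAX_KNOTS))
_src = f"def spline_model(x, {_args}):\n    return x\n"
_namespace = {}
exec(_src, _namespace)
spline_model = _namespace["spline_model"]


def check_nknots(xknots):
    """Reject knot arrays that the fixed signature cannot describe (hypothetical)."""
    if len(xknots) > MAX_KNOTS:
        raise ValueError(f"SplineModel supports at most {MAX_KNOTS} knots, "
                         f"got {len(xknots)}")


print(len(inspect.signature(spline_model).parameters))  # 101: x plus s0 .. s99
```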

@paulmueller
Copy link
Contributor

paulmueller commented Jan 21, 2025

@newville I would be motivated to work on both approaches (refactoring _parse_params or the simple approach). For the simple approach, there should probably be a warning when saving the model parameters with more than 99 knots, noting that loading the parameters will not work and pointing to this issue with the workaround.

If you have a preference (since you will have to review this), please let me know.
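Such a warning could be as small as the sketch below. `warn_if_unloadable` and `MAX_KNOTS` are hypothetical names, and the threshold of 100 assumes the enlarged signature ends at `s99`, i.e. 100 coefficients.

```python
import warnings

MAX_KNOTS = 100  # hypothetical limit matching a signature that ends at s99
ISSUE_URL = "https://github.com/lmfit/lmfit-py/issues/985"  # this issue


def warn_if_unloadable(nknots, max_knots=MAX_KNOTS):
    """Warn before saving a spline fit that could not be loaded back (hypothetical)."""
    if nknots > max_knots:
        warnings.warn(
            f"SplineModel with {nknots} knots: the saved result cannot be "
            f"reloaded with load_modelresult (limit is {max_knots} knots). "
            f"See {ISSUE_URL} for a workaround.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False
```

A save routine would call this with the number of knots before writing the file; saving still proceeds, but the user is told up front that reloading will fail.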

@newville
Member

@paulmueller I recommend the "easy route", and setting the maximum number of spline knots there to 100 should be good.

To be clear, lmfit Model can use var kws (**kwargs), and save/load model works with that. So, I think the parsing works okay and the problem is only with SplineModel.

Another thing to try would be to use **kwargs for SplineModel.

See:

import numpy as np
import time
from lmfit import Model, Parameters
from lmfit.model import save_modelresult, load_modelresult
import wxmplot.interactive as wi

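def my_poly(x, **params):
    # model function with a variable number of coefficients passed as **kwargs;
    # the sorted names C00, C01, ... correspond to powers 0, 1, ... of x
    val = 0.0
    parnames = sorted(params.keys())
    for i, pname in enumerate(parnames):
        val += params[pname]*x**i
    return val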

my_model = Model(my_poly)

# Parameter names and starting values
params = Parameters()
params.add('C00', value=-10)
params.add('C01', value=  5)
params.add('C02', value=  1)
params.add('C03', value=  0)
params.add('C04', value=  0)

x = np.linspace(-20, 20, 101)
y = -30.4 + 7.8*x - 0.5*x*x + 0.03 * x**3 + 0.009*x**4
y = y + np.random.normal(size=len(y), scale=3)


out = my_model.fit(y, params, x=x)
print(out.fit_report())
wi.plot(x, y, label='data', show_legend=True)
wi.plot(x, out.best_fit, label='fit')

print("Saving")
save_modelresult(out, 'test_modelresult.sav')

print("Waiting")
time.sleep(3)
new_result = load_modelresult('test_modelresult.sav')
print("Loaded model result")
print(new_result.fit_report())
print(new_result.eval(x=3))
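Applying the same **kwargs idea to a spline needs one extra detail: `my_poly` relies on `sorted()`, which works for zero-padded names like `C00` to `C04`, but for `s0` to `s11` plain string sorting gives `s0, s1, s10, s11, s2, ...`. The sketch below sorts numerically instead. `make_spline_func` is a hypothetical name, it ignores model prefixes, and it only shows the function itself; whether lmfit's save/load could restore such a closure (for example through the `funcdefs` argument of `load_modelresult` seen in the traceback above) is not tested here.

```python
import numpy as np
from scipy.interpolate import splev, splrep


def make_spline_func(xknots, order=3):
    """Return a model function taking spline coefficients as **kwargs (hypothetical)."""
    xknots = np.asarray(xknots, dtype=float)
    knots, _c, _k = splrep(xknots, np.ones(len(xknots)), k=order)

    def spline_func(x, **coefs):
        # sort s0, s1, ..., s10 numerically; sorted() on the strings alone
        # would put 's10' before 's2'
        names = sorted(coefs, key=lambda name: int(name[1:]))
        c = [coefs[name] for name in names]
        # pad with order+1 copies of the last coefficient
        c.extend([c[-1]] * (order + 1))
        return splev(x, [knots, np.array(c), order])

    return spline_func


# usage with 12 knots, i.e. coefficients s0 .. s11
f12 = make_spline_func(np.linspace(-8, 8, 12))
y12 = f12(np.linspace(-8, 8, 5), **{f"s{i}": 1.0 for i in range(12)})
```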
