How do I put a constraint on SciPy curve fit?

Axon · May 14, 2013 · Viewed 15.5k times · Source

I'm trying to fit the distribution of some experimental values with a custom probability density function. Obviously, the integral of the resulting function should always be equal to 1, but the results of simple scipy.optimize.curve_fit(function, dataBincenters, dataCounts) never satisfy this condition. What is the best way to solve this problem?

Answer

Saullo G. P. Castro · May 19, 2013

You can define your own residuals function that includes a penalization term, as in the code below, where it is known beforehand that the integral over the interval must be 2. If you test without the penalization, you will see that what you are getting is the conventional curve_fit result:


import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit, leastsq
from scipy.integrate import quad

# noisy samples of sin(x) on [0, pi]; the exact integral of sin there is 2
x = np.linspace(0, np.pi, 100)
y = np.sin(x) + np.random.rand(len(x))*0.4

def func1(x, a0, a1, a2, a3):
    return a0 + a1*x + a2*x**2 + a3*x**3

# here you include the penalization factor
def residuals(p, x, y):
    integral = quad(func1, 0, np.pi, args=tuple(p))[0]
    penalization = abs(2. - integral)*10000
    return y - func1(x, *p) - penalization

popt1, pcov1 = curve_fit(func1, x, y)
# without full_output=True, leastsq returns (solution, integer status flag)
popt2, ier2 = leastsq(func=residuals, x0=(1., 1., 1., 1.), args=(x, y))
y_fit1 = func1(x, *popt1)
y_fit2 = func1(x, *popt2)
plt.scatter(x, y, marker='.')
plt.plot(x, y_fit1, color='g', label='curve_fit')
plt.plot(x, y_fit2, color='y', label='constrained')
plt.legend(); plt.xlim(-0.1, 3.5); plt.ylim(0, 1.4)
print('Exact   integral:', quad(np.sin, 0, np.pi)[0])
print('Approx integral1:', quad(func1, 0, np.pi, args=tuple(popt1))[0])
print('Approx integral2:', quad(func1, 0, np.pi, args=tuple(popt2))[0])
plt.show()

# Output from one run (the fitted values vary with the random noise):
#Exact   integral: 2.0
#Approx integral1: 2.60068579748
#Approx integral2: 2.00001911981
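
If a soft penalty is not strict enough, one alternative (not what the code above does) is to pass the integral condition as an equality constraint to scipy.optimize.minimize with the SLSQP method. The following is only a minimal sketch under that assumption: the names cubic, model_integral and fit_with_integral are made up for illustration, the model is the same cubic as func1, and the target value 2 and the interval [0, pi] are taken from the example above. For a probability density you would set the target to 1 and use your own function and interval.

```python
import numpy as np
from scipy.integrate import quad
from scipy.optimize import minimize

# Same cubic model as func1 in the answer (hypothetical name for this sketch).
def cubic(x, a0, a1, a2, a3):
    return a0 + a1*x + a2*x**2 + a3*x**3

def model_integral(p, lo, hi):
    # Numerical integral of the model over [lo, hi] for parameters p.
    return quad(cubic, lo, hi, args=tuple(p))[0]

def fit_with_integral(x, y, target, lo, hi, p0=(1., 1., 1., 1.)):
    # Objective: ordinary sum of squared residuals.
    def sse(p):
        return np.sum((y - cubic(x, *p))**2)

    # Equality constraint: SLSQP drives this function to zero.
    constraint = {'type': 'eq',
                  'fun': lambda p: model_integral(p, lo, hi) - target}
    return minimize(sse, p0, method='SLSQP', constraints=[constraint])

# Synthetic data as in the answer, with a fixed seed for repeatability.
rng = np.random.default_rng(0)
x = np.linspace(0, np.pi, 100)
y = np.sin(x) + rng.random(len(x))*0.4

result = fit_with_integral(x, y, target=2.0, lo=0.0, hi=np.pi)
print('Converged:', result.success)
print('Constrained integral:', model_integral(result.x, 0.0, np.pi))
```

Compared with the penalty approach, there is no weight such as 10000 to tune, but minimize does not return a covariance matrix the way curve_fit does, so parameter uncertainties would have to be estimated separately.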
