Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of 6K model #8

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
255 changes: 253 additions & 2 deletions pyrho/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,249 @@ def solveGo(tlag, Gd, Go0=1000, tol=1e-9):



def fit6Kstates(fluxSet, quickSet, run, vInd, params, method=defMethod): # , verbose=config.verbose):
"""
fluxSet := ProtocolData set (of Photocurrent objects) to fit
quickSet:= ProtocolData set (of Photocurrent objects) with short pulses to fit opsin activation rates
run := Index for the run within the ProtocolData set
vInd := Index for Voltage clamp value within the ProtocolData set
params := Parameters object of model parameters with initial values [and bounds, expressions]
method := Fitting algorithm for the optimiser to use
"""
# verbose := Text output (verbosity) level


plotResult = bool(config.verbose > 1)

nStates = '6K'

### Prepare the data
nRuns = fluxSet.nRuns
nPhis = fluxSet.nPhis
nVs = fluxSet.nVs

assert(0 < nPhis)
assert(0 <= run < nRuns)
assert(0 <= vInd < nVs)

Ions = [None for phiInd in range(nPhis)]
Ioffs = [None for phiInd in range(nPhis)]
tons = [None for phiInd in range(nPhis)]
toffs = [None for phiInd in range(nPhis)]
phis = []
Is = []
ts = []
Vs = []

Icycles = []
nfs = [] # Normalisation factors: e.g. /Ions[trial][-1] or /min(Ions[trial])


# Trim off phase data
#frac = 1
#chop = int(round(len(Ioffs[0])*frac))

for phiInd in range(nPhis):
targetPC = fluxSet.trials[run][phiInd][vInd]
#targetPC.alignToTime()
I = targetPC.I
t = targetPC.t
onInd = targetPC._idx_pulses_[0,0] ### Consider multiple pulse scenarios
offInd = targetPC._idx_pulses_[0,1]
Ions[phiInd] = I[onInd:offInd+1]
Ioffs[phiInd] = I[offInd:] #[I[offInd:] for I in Is]
tons[phiInd] = t[onInd:offInd+1]-t[onInd]
toffs[phiInd] = t[offInd:]-t[offInd] #[t[offInd:]-t[offInd] for t in ts]
#args=(Ioffs[phiInd][:chop+1],toffs[phiInd][:chop+1])
phi = targetPC.phi
phis.append(phi)

Is.append(I)
ts.append(t)
V = targetPC.V
Vs.append(V)

Icycles.append(I[onInd:])
nfs.append(I[offInd])
#nfs.append(targetPC.I_peak_)


### OFF PHASE
### 3a. OFF CURVE: Fit biexponential to off curve to find lambdas

OffKeys = ['Gd1', 'Gd2', 'Gf0', 'Ga3']

iOffPs = Parameters() # Create parameter dictionary
for k in OffKeys:
copyParam(k, params, iOffPs)

### Trim the first 10% of the off curve to allow I1 and I2 to empty?


### This is an approximation based on the 4-state model which ignores the effects of Go1 and Go2 after light off.

# lam1 + lam2 == Gd1 + Gd2 + Gf0 + Gb0
# lam1 * lam2 == Gd1*Gd2 + Gd1*Gb0 + Gd2*Gf0

#, Gd2, Gf0, Gb0: (Gd1 + Gd2 + Gf0 + Gb0)/2
#calcC = lambda b, Gd1, Gd2, Gf0, Gb0: np.sqrt(b**2 - (Gd1*Gd2 + Gd1*Gb0 + Gd2*Gf0))

def lams(p):

Gd1 = p['Gd1'].value
Gd2 = p['Gd2'].value
Ga3 = p['Ga3'].value

lam1 = Gd1
lam2 = (Gd2 + Ga3)
return lam1, lam2

# Create dummy parameters for each phi
for phiInd in range(nPhis):
Iss = Ioffs[phiInd][0]
if Iss < 0:
iOffPs.add('Islow_'+str(phiInd), value=0.2*Iss, vary=True, max=0)
iOffPs.add('Ifast_'+str(phiInd), value=0.8*Iss, vary=True, max=0, expr='{} - {}'.format(Iss, 'Islow_'+str(phiInd)))
else:
iOffPs.add('Islow_'+str(phiInd), value=0.2*Iss, vary=True, min=0)
iOffPs.add('Ifast_'+str(phiInd), value=0.8*Iss, vary=True, min=0, expr='{} - {}'.format(Iss, 'Islow_'+str(phiInd)))

def fit6Koff(p,t,trial):
Islow = p['Islow_'+str(trial)].value
Ifast = p['Ifast_'+str(trial)].value
lam1, lam2 = lams(p)
return Islow*np.exp(-lam1*t) + Ifast*np.exp(-lam2*t)

def err6Koff(p,Ioffs,toffs):
"""Normalise by the first element of the off-curve""" # [-1]
return np.r_[ [(Ioffs[i] - fit6Koff(p,toffs[i],i))/Ioffs[i][0] for i in range(len(Ioffs))] ]

#fitfunc = lambda p, t: -(p['a0'].value + p['a1'].value*np.exp(-lams(p)[0]*t) + p['a2'].value*np.exp(-lams(p)[1]*t))
##fitfunc = lambda p, t: -(p['a0'].value + p['a1'].value*np.exp(-p['lam1'].value*t) + p['a2'].value*np.exp(-p['lam2'].value*t))
#errfunc = lambda p, Ioff, toff: Ioff - fitfunc(p,toff)

offPmin = minimize(err6Koff, iOffPs, args=(Ioffs,toffs), method=method)#, fit_kws={'maxfun':100000})
pOffs = offPmin.params

reportFit(offPmin, "Off-phase fit report for the 6K-state model", method)
if config.verbose > 0:
print('Gd1 = {}; Gd2 = {}; Gf0 = {}'.format(pOffs['Gd1'].value, pOffs['Gd2'].value,
pOffs['Gf0'].value))

if plotResult:
lam1, lam2 = lams(pOffs)
plotOffPhaseFits(toffs, Ioffs, pOffs, phis, nStates, fit6Koff, lam1, lam2, Gd=None)


# Fix off-curve parameters
for k in OffKeys:
pOffs[k].vary = False


### Calculate Go (1/tau_opsin)
print('\nCalculating opsin activation rate')
# Assume that Gd1 > Gd2
# Assume that Gd = Gd1 for short pulses

def solveGo(tlag, Gd, Go0=1000, tol=1e-9):
Go, Go_m1 = Go0, 0
while abs(Go_m1 - Go) > tol:
Go_m1 = Go
Go = ((tlag*Gd) - np.log(Gd/Go_m1))/tlag
#Go_m1, Go = Go, ((tlag*Gd) - np.log(Gd/Go_m1))/tlag
return Go

#if 'shortPulse' in dataSet: # Fit Go
if quickSet.nRuns > 1:
#from scipy.optimize import curve_fit
# Fit tpeak = tpulse + tmaxatp0 * np.exp(-k*tpulse)
#dataSet['shortPulse'].getProtPeaks()
#tpeaks = dataSet['shortPulse'].IrunPeaks

#PD = dataSet['shortPulse']
PCs = [quickSet.trials[p][0][0] for p in range(quickSet.nRuns)] # Aligned to the pulse i.e. t_on = 0
#[pc.alignToTime() for pc in PCs]

#tpeaks = np.asarray([PD.trials[p][0][0].tpeak for p in range(PD.nRuns)]) # - PD.trials[p][0][0].t[0]
#tpulses = np.asarray([PD.trials[p][0][0].Dt_ons[0] for p in range(PD.nRuns)])
tpeaks = np.asarray([pc.t_peak_ for pc in PCs])
tpulses = np.asarray([pc.Dt_ons_[0] for pc in PCs])

devFunc = lambda tpulses, t0, k: tpulses + t0 * np.exp(-k*tpulses)
p0 = (0, 1)
popt, pcov = curve_fit(devFunc, tpulses, tpeaks, p0=p0)
if plotResult:
fig = plt.figure()
ax = fig.add_subplot(111, aspect='equal')
nPoints = 10*int(round(max(tpulses))+1) # 101
tsmooth = np.linspace(0, max(tpulses), nPoints)
ax.plot(tpulses, tpeaks, 'x')
ax.plot(tsmooth, devFunc(tsmooth, *popt))
ax.plot(tsmooth, tsmooth, '--')
ax.set_ylim([0, max(tpulses)]) #+5
ax.set_xlim([0, max(tpulses)]) #+5
#plt.tight_layout()
#plt.axis('equal')
plt.show()

# Solve iteratively Go = ((tlag*Gd) - np.log(Gd/Go))/tlag
Gd1 = pOffs['Gd1'].value
Go = solveGo(tlag=popt[0], Gd=Gd1, Go0=1000, tol=1e-9)
print('t_lag = {:.3g}; Gd = {:.3g} --> Go = {:.3g}'.format(popt[0], Gd1, Go))

elif quickSet.nRuns == 1: #'delta' in dataSet:
#PD = dataSet['delta']
#PCs = [PD.trials[p][0][0] for p in range(PD.nRuns)]
PC = quickSet.trials[0][0][0]
tlag = PC.Dt_lag_ # := Dt_lags_[0] ############################### Add to Photocurrent...
Go = solveGo(tlag=tlag, Gd=Gd1, Go0=1000, tol=1e-9)
print('t_lag = {:.3g}; Gd = {:.3g} --> Go = {:.3g}'.format(tlag, Gd1, Go))

else:
Go = 1 # Default
print('No data found to estimate Go: defaulting to Go = {}'.format(Go))


### ON PHASE

iOnPs = Parameters() # deepcopy(params)

# Set parameters from Off-curve optimisation
for k in OffKeys:
copyParam(k, pOffs, iOnPs)

# Set parameters from general rhodopsin analysis routines
for k in ['Go1', 'Go2', 'k1', 'k2', 'k3', 'k_f', 'k_b', 'gam', 'p', 'q', 'phi_m', 'g0', 'Gb', 'E', 'v0', 'v1']: #.extend(OffKeys):
copyParam(k, params, iOnPs)

# Set parameters from short pulse calculations
iOnPs['Go1'].value = Go; iOnPs['Go1'].vary = False
iOnPs['Go2'].value = Go; iOnPs['Go2'].vary = False

RhO = models['6K']()

### Trim down ton? Take 10% of data or one point every ms? ==> [0::5]

if config.verbose > 2:
print('Optimising ',end='')

onPmin = minimize(errOnPhase, iOnPs, args=(Ions,tons,RhO,Vs,phis), method=method)
pOns = onPmin.params

reportFit(onPmin, "On-phase fit report for the 6K-state model", method)

if config.verbose > 0:
print('k1 = {}; k2 = {}; k_f = {}; k_b = {}'.format(pOns['k1'].value, pOns['k2'].value,
pOns['k_f'].value, pOns['k_b'].value))
print('gam = {}; phi_m = {}; p = {}; q = {}'.format(pOns['gam'].value, pOns['phi_m'].value,
pOns['p'].value, pOns['q'].value))

fitParams = pOns

return fitParams, onPmin




#TODO: Tidy up and refactor getRecoveryPeaks and fitRecovery
def getRecoveryPeaks(recData, phiInd=None, vInd=None, usePeakTime=False):
Expand Down Expand Up @@ -1769,11 +2012,14 @@ def fitModel(dataSet, nStates='3', params=None, postFitOpt=True, relaxFact=2, me
"""Fit a model (with initial parameters) to a dataset of optogenetic photocurrents."""

### Define non-optimised parameters to exclude in post-fit optimisation
nonOptParams = ['Gr0', 'E', 'v0', 'v1']

if not isinstance(nStates, str):
nStates = str(nStates) # .lower()

if nStates == '3' or nStates == '4' or nStates == '6':
nonOptParams = ['Gr0', 'E', 'v0', 'v1']
elif nStates == '6K':
nonOptParams = ['E', 'v0', 'v1', 'Ga3']

if nStates not in modelParams:
print(f"Error in selecting model {nStates} - please choose from {list(modelParams)} states")
raise NotImplementedError(nStates)
Expand Down Expand Up @@ -2000,6 +2246,11 @@ def fitModel(dataSet, nStates='3', params=None, postFitOpt=True, relaxFact=2, me
constrainedParams = ['Gd1', 'Gd2', 'Gf0', 'Gb0', 'Go1', 'Go2']
#constrainedParams = ['Go1', 'Go2', 'Gf0', 'Gb0']
#nonOptParams.append(['Gd1', 'Gd2'])
elif nStates == '6K':
fittedParams, miniObj = fit6Kstates(setPC, quickSet, runInd, vIndm70, fitParams, method) # , verbose)
constrainedParams = ['Gd1', 'Gd2', 'Gf0', 'Go1', 'Go2', 'Gb']
#constrainedParams = ['Go1', 'Go2', 'Gf0', 'Gb0']
#nonOptParams.append(['Gd1', 'Gd2'])
else:
raise Exception(f'Invalid choice for nStates: {nStates}!')

Expand Down
Loading