Skip to content

Commit

Permalink
add dynamics
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicola Mueller committed Mar 20, 2024
1 parent a96a23f commit e654878
Show file tree
Hide file tree
Showing 19 changed files with 2,801 additions and 79 deletions.
113 changes: 87 additions & 26 deletions src/coalre/distribution/CoalescentWithReassortment.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import beast.base.core.Function;
import beast.base.core.Input;
import beast.base.evolution.tree.coalescent.PopulationFunction;
import coalre.statistics.NetworkStatsLogger;

import java.util.List;

Expand All @@ -31,16 +32,30 @@ public class CoalescentWithReassortment extends NetworkDistribution {
"reassortment rates that vary over time",
Input.Validate.XOR, reassortmentRateInput);

public Input<Double> maxHeightRatioInput = new Input<>(
"maxHeightRatio",
"if specified, above the ratio, only coalescent events are allowed.", Double.POSITIVE_INFINITY);

public PopulationFunction populationFunction;
public Input<Double> redFactorInput = new Input<>(
"redFactor",
"by how much the recombination rate should be reduced after reaching the maxHeightRatio.", 0.1);



public PopulationFunction populationFunction;
private Function reassortmentRate;
public PopulationFunction timeVaryingReassortmentRates;

public NetworkIntervals intervals;

private boolean isTimeVarying = false;

@Override
public double redFactor;

private boolean reduceReassortmentAfterSegTrees = false;


@Override
public void initAndValidate(){
populationFunction = populationFunctionInput.get();
intervals = networkIntervalsInput.get();
Expand All @@ -57,11 +72,15 @@ public double calculateLogP() {
// Calculate tree intervals
List<NetworkEvent> networkEventList = intervals.getNetworkEventList();

NetworkEvent prevEvent = null;
// get the mrca of all loci trees
double lociMRCA = maxHeightRatioInput.get()<Double.POSITIVE_INFINITY ?
NetworkStatsLogger.getLociMRCA(networkIntervalsInput.get().networkInput.get()) : Double.POSITIVE_INFINITY;

NetworkEvent prevEvent = null;

for (NetworkEvent event : networkEventList) {
if (prevEvent != null)
logP += intervalContribution(prevEvent, event);
logP += intervalContribution(prevEvent, event, lociMRCA);

switch (event.type) {
case COALESCENCE:
Expand All @@ -72,19 +91,21 @@ public double calculateLogP() {
break;

case REASSORTMENT:
logP += reassortment(event);
logP += reassortment(event, lociMRCA);
break;
}

if (logP==Double.NEGATIVE_INFINITY)
break;

prevEvent = event;
}
}
// System.out.println(networkIntervalsInput.get().networkInput.get());
// System.exit(0);
return logP;
}

private double reassortment(NetworkEvent event) {
private double reassortment(NetworkEvent event, double lociMRCA) {
// lp+=Math.log(reassortmentRate.getArrayValue())
// + event.segsSortedLeft * Math.log(intervals.getBinomialProb())
// + (event.segsToSort-event.segsSortedLeft)*Math.log(1-intervals.getBinomialProb())
Expand All @@ -93,14 +114,23 @@ private double reassortment(NetworkEvent event) {
double binomval = Math.pow(intervals.getBinomialProb(), event.segsSortedLeft)
* Math.pow(1-intervals.getBinomialProb(), event.segsToSort-event.segsSortedLeft)
+ Math.pow(intervals.getBinomialProb(), event.segsToSort-event.segsSortedLeft)
* Math.pow(1-intervals.getBinomialProb(), event.segsSortedLeft);

if (isTimeVarying)
return Math.log(timeVaryingReassortmentRates.getPopSize(event.time))
+ Math.log(binomval);
else
return Math.log(reassortmentRate.getArrayValue())
+ Math.log(binomval);
* Math.pow(1-intervals.getBinomialProb(), event.segsSortedLeft);

if (event.time<=(lociMRCA*maxHeightRatioInput.get())) {
if (isTimeVarying)
return Math.log(timeVaryingReassortmentRates.getPopSize(event.time))
+ Math.log(binomval);
else
return Math.log(reassortmentRate.getArrayValue())
+ Math.log(binomval);
}else{
if (isTimeVarying)
return Math.log(redFactor*timeVaryingReassortmentRates.getPopSize(event.time))
+ Math.log(binomval);
else
return Math.log(redFactor*reassortmentRate.getArrayValue())
+ Math.log(binomval);
}



Expand All @@ -115,19 +145,48 @@ private double coalesce(NetworkEvent event) {
return Math.log(1.0/populationFunction.getPopSize(event.time));
}

private double intervalContribution(NetworkEvent prevEvent, NetworkEvent nextEvent) {
private double intervalContribution(NetworkEvent prevEvent, NetworkEvent nextEvent, double lociMRCA) {

double result = 0.0;

// System.out.println(timeVaryingReassortmentRates.getIntegral(prevEvent.time, nextEvent.time));
// System.out.println(nextEvent.time - prevEvent.time);
if (isTimeVarying)
result += -prevEvent.totalReassortmentObsProb
* timeVaryingReassortmentRates.getIntegral(prevEvent.time, nextEvent.time);
else
result += -reassortmentRate.getArrayValue() * prevEvent.totalReassortmentObsProb
* (nextEvent.time - prevEvent.time);

// System.out.println(prevEvent.time + " " + nextEvent.time);

// if (nextEvent.time<3.3) {
// System.out.println(prevEvent.time + " " + nextEvent.time);
// System.out.println(timeVaryingReassortmentRates.getIntegral(prevEvent.time, nextEvent.time));
// }

if (nextEvent.time<(lociMRCA*maxHeightRatioInput.get())) {
if (isTimeVarying)
result += -prevEvent.totalReassortmentObsProb
* timeVaryingReassortmentRates.getIntegral(prevEvent.time, nextEvent.time);
else
result += -reassortmentRate.getArrayValue() * prevEvent.totalReassortmentObsProb
* (nextEvent.time - prevEvent.time);
}else if (prevEvent.time<(lociMRCA*maxHeightRatioInput.get())) {
if (isTimeVarying)
result += -prevEvent.totalReassortmentObsProb
* timeVaryingReassortmentRates.getIntegral(prevEvent.time, nextEvent.time-lociMRCA*maxHeightRatioInput.get());
else
result += -reassortmentRate.getArrayValue() * prevEvent.totalReassortmentObsProb
* (lociMRCA*maxHeightRatioInput.get() - prevEvent.time);

if (isTimeVarying)
result += -redFactor*prevEvent.totalReassortmentObsProb
* timeVaryingReassortmentRates.getIntegral(lociMRCA*maxHeightRatioInput.get(), nextEvent.time);
else
result += -redFactor*reassortmentRate.getArrayValue() * prevEvent.totalReassortmentObsProb
* (nextEvent.time - prevEvent.time-lociMRCA*maxHeightRatioInput.get());



}else{
if (isTimeVarying)
result += -prevEvent.totalReassortmentObsProb
* redFactor* timeVaryingReassortmentRates.getIntegral(prevEvent.time, nextEvent.time);
else
result += -redFactor * reassortmentRate.getArrayValue() * prevEvent.totalReassortmentObsProb
* (nextEvent.time - prevEvent.time);
}
result += -0.5*prevEvent.lineages*(prevEvent.lineages-1)
* populationFunction.getIntegral(prevEvent.time, nextEvent.time);

Expand All @@ -148,6 +207,8 @@ protected boolean requiresRecalculation() {

return super.requiresRecalculation();
}




}
4 changes: 2 additions & 2 deletions src/coalre/distribution/TipPrior.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ public double calculateLogP() {
if (taxon.getTaxonLabel().equals(taxonsetInput.get().getTaxonId(0))){
operatingNode = taxon;
MRCATime = dateOffsetInput.get().getValue() - operatingNode.getHeight();

logP += dist.logDensity(MRCATime);
break;
}
}

// System.out.println(logP);

return logP;
Expand Down
156 changes: 156 additions & 0 deletions src/coalre/dynamics/NeDynamicsFromSpline.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package coalre.dynamics;

import beast.base.core.Description;
import beast.base.core.Input;
import beast.base.core.Loggable;
import beast.base.evolution.tree.coalescent.PopulationFunction;

import java.io.PrintStream;
import java.util.List;


/**
* @author Nicola F. Mueller
*/
@Description("Populaiton function with values at certain time points that are interpolated in between. Parameter has to be in log space")
public class NeDynamicsFromSpline extends PopulationFunction.Abstract implements Loggable {

final public Input<Spline> splineInput = new Input<>("spline",
"Spline to use for the population function", Input.Validate.REQUIRED);

boolean NesKnown = false;

Spline spline;

boolean returnNaN = false;

@Override
public void initAndValidate() {
spline = splineInput.get();
}


@Override
public List<String> getParameterIds() {
return null;
}

@Override
public double getPopSize(double t) {
if (!spline.update())
return Double.NaN;

// check which time t is, if it is larger than the last time, return the last Ne
int interval = spline.gridPoints-1;
for (int i = 0; i < spline.gridPoints; i++){
if (t < spline.time[i]){
interval = i-1;
break;
}
}
return spline.I[interval]/(spline.transmissionRate[interval]*2);
}

public double getIntegral(double from, double to) {
if (!spline.update())
return Double.NaN;

// compute the integral of Ne's between from an to
double NeIntegral = 0;
int intervalFrom = spline.gridPoints-1;
for (int i = 0; i < spline.gridPoints; i++) {
// get the first time larger than from
if (from < spline.time[i]) {
intervalFrom = i-1;
break;
}
}

for (int i = intervalFrom; i < spline.gridPoints; i++) {
// if i==intervalFrom, we have to start compute the diff from there
if (spline.time[i+1] > to) {
if (i == intervalFrom) {
NeIntegral += (spline.transmissionRate[i]/spline.I[i]) * (to - from);
}else{
NeIntegral += (spline.transmissionRate[i]/spline.I[i]) * (to - spline.time[i]);
}
break;
}else if (i == intervalFrom) {
NeIntegral += (spline.transmissionRate[i]/spline.I[i]) * (spline.time[i+1] - from);
}else{
NeIntegral += (spline.transmissionRate[i] / spline.I[i]) * (spline.time[i + 1] - spline.time[i]);
}
}
return 2*NeIntegral;
}

@Override
public double getIntensity(double v) {
return getIntegral(0,v);
}

@Override
public double getInverseIntensity(double v) {
// divide by 2 to avoid a division in every step
v/=2;
for (int i = 0; i < spline.gridPoints; i++) {
v -= (spline.transmissionRate[i] / spline.I[i]) * (spline.time[i + 1] - spline.time[i]);
if (v<0){
v += (spline.transmissionRate[i] / spline.I[i]) * (spline.time[i + 1] - spline.time[i]);
// solve for the final time
return spline.time[i] + v/(spline.transmissionRate[i] / spline.I[i]);
}
}
return spline.time[spline.gridPoints-1] + v/(spline.transmissionRate[spline.gridPoints-1] / spline.I[spline.gridPoints-1]);

}

@Override
public boolean requiresRecalculation() {
return true;
}

@Override
public void store() {
super.store();
}

@Override
public void restore() {
super.restore();
}


@Override
public void init(PrintStream printStream) {
for (int i = 0; i < spline.gridPoints; i+=20) {
printStream.print("logNe_" + i + "\t");
}
for (int i = 0; i < spline.gridPoints; i+=20) {
printStream.print("logI_" + i + "\t");
}
for (int i = 0; i < spline.gridPoints; i+=1) {
printStream.print("transmissionRate" + i + "\t");
}

}

@Override
public void log(long l, PrintStream printStream) {
for (int i = 0; i < spline.gridPoints; i+=20) {
printStream.print(Math.log(spline.I[i]/spline.transmissionRate[i]) + "\t");
}
for (int i = 0; i < spline.gridPoints; i+=20) {
printStream.print(spline.I[i] + "\t");
}
for (int i = 0; i < spline.gridPoints; i+=1) {
printStream.print(spline.transmissionRate[i] + "\t");
}

}

@Override
public void close(PrintStream printStream) {

}
}
Loading

0 comments on commit e654878

Please sign in to comment.