From e6548788950909309477c2f67d88396bd1ea32d6 Mon Sep 17 00:00:00 2001 From: Nicola Mueller Date: Wed, 20 Mar 2024 15:44:16 -0700 Subject: [PATCH] add dynamics --- .../CoalescentWithReassortment.java | 113 ++- src/coalre/distribution/TipPrior.java | 4 +- src/coalre/dynamics/NeDynamicsFromSpline.java | 156 ++++ src/coalre/dynamics/NotAKnotSpline.java | 187 ++++ .../RecombinationDynamicsFromSpline.java | 147 ++++ src/coalre/dynamics/Spline.java | 328 +++++++ .../SplineTransmissionDifference.java | 95 +++ .../AddRemoveReassortmentCoalescent.java | 8 +- .../GibbsOperatorAboveSegmentRoots.java | 60 +- .../operators/MultiTipDatesRandomWalker.java | 20 +- src/coalre/operators/NetworkOperator.java | 31 +- src/coalre/operators/TipReheight.java | 22 +- src/coalre/simulator/SIRwithReassortment.java | 583 +++++++++++++ .../SuperspreadingSIRwithReassortment.java | 278 ++++++ ...preadingStructuredSIRwithReassortment.java | 798 ++++++++++++++++++ src/coalre/statistics/NetworkStatsLogger.java | 26 + src/main/java/org/example/Main.java | 7 + src/module-info.java | 8 + version.xml | 9 +- 19 files changed, 2801 insertions(+), 79 deletions(-) create mode 100644 src/coalre/dynamics/NeDynamicsFromSpline.java create mode 100644 src/coalre/dynamics/NotAKnotSpline.java create mode 100644 src/coalre/dynamics/RecombinationDynamicsFromSpline.java create mode 100644 src/coalre/dynamics/Spline.java create mode 100644 src/coalre/dynamics/SplineTransmissionDifference.java create mode 100644 src/coalre/simulator/SIRwithReassortment.java create mode 100644 src/coalre/simulator/SuperspreadingSIRwithReassortment.java create mode 100644 src/coalre/simulator/SuperspreadingStructuredSIRwithReassortment.java create mode 100644 src/main/java/org/example/Main.java create mode 100644 src/module-info.java diff --git a/src/coalre/distribution/CoalescentWithReassortment.java b/src/coalre/distribution/CoalescentWithReassortment.java index 2cbeef6..911ea64 100644 --- a/src/coalre/distribution/CoalescentWithReassortment.java +++ b/src/coalre/distribution/CoalescentWithReassortment.java @@ -5,6 +5,7 @@ import beast.base.core.Function; import beast.base.core.Input; import beast.base.evolution.tree.coalescent.PopulationFunction; +import coalre.statistics.NetworkStatsLogger; import java.util.List; @@ -31,8 +32,17 @@ public class CoalescentWithReassortment extends NetworkDistribution { "reassortment rates that vary over time", Input.Validate.XOR, reassortmentRateInput); + public Input maxHeightRatioInput = new Input<>( + "maxHeightRatio", + "if specified, above the ratio, only coalescent events are allowed.", Double.POSITIVE_INFINITY); - public PopulationFunction populationFunction; + public Input redFactorInput = new Input<>( + "redFactor", + "by how much the recombination rate should be reduced after reaching the maxHeightRatio.", 0.1); + + + + public PopulationFunction populationFunction; private Function reassortmentRate; public PopulationFunction timeVaryingReassortmentRates; @@ -40,7 +50,12 @@ public class CoalescentWithReassortment extends NetworkDistribution { private boolean isTimeVarying = false; - @Override + public double redFactor; + + private boolean reduceReassortmentAfterSegTrees = false; + + + @Override public void initAndValidate(){ populationFunction = populationFunctionInput.get(); intervals = networkIntervalsInput.get(); @@ -57,11 +72,15 @@ public double calculateLogP() { // Calculate tree intervals List networkEventList = intervals.getNetworkEventList(); - NetworkEvent prevEvent = null; + // get the mrca of all loci 
trees + double lociMRCA = maxHeightRatioInput.get() splineInput = new Input<>("spline", + "Spline to use for the population function", Input.Validate.REQUIRED); + + boolean NesKnown = false; + + Spline spline; + + boolean returnNaN = false; + + @Override + public void initAndValidate() { + spline = splineInput.get(); + } + + + @Override + public List getParameterIds() { + return null; + } + + @Override + public double getPopSize(double t) { + if (!spline.update()) + return Double.NaN; + + // check which time t is, if it is larger than the last time, return the last Ne + int interval = spline.gridPoints-1; + for (int i = 0; i < spline.gridPoints; i++){ + if (t < spline.time[i]){ + interval = i-1; + break; + } + } + return spline.I[interval]/(spline.transmissionRate[interval]*2); + } + + public double getIntegral(double from, double to) { + if (!spline.update()) + return Double.NaN; + + // compute the integral of Ne's between from an to + double NeIntegral = 0; + int intervalFrom = spline.gridPoints-1; + for (int i = 0; i < spline.gridPoints; i++) { + // get the first time larger than from + if (from < spline.time[i]) { + intervalFrom = i-1; + break; + } + } + + for (int i = intervalFrom; i < spline.gridPoints; i++) { + // if i==intervalFrom, we have to start compute the diff from there + if (spline.time[i+1] > to) { + if (i == intervalFrom) { + NeIntegral += (spline.transmissionRate[i]/spline.I[i]) * (to - from); + }else{ + NeIntegral += (spline.transmissionRate[i]/spline.I[i]) * (to - spline.time[i]); + } + break; + }else if (i == intervalFrom) { + NeIntegral += (spline.transmissionRate[i]/spline.I[i]) * (spline.time[i+1] - from); + }else{ + NeIntegral += (spline.transmissionRate[i] / spline.I[i]) * (spline.time[i + 1] - spline.time[i]); + } + } + return 2*NeIntegral; + } + + @Override + public double getIntensity(double v) { + return getIntegral(0,v); + } + + @Override + public double getInverseIntensity(double v) { + // divide by 2 to avoid a division in every step + v/=2; + for (int i = 0; i < spline.gridPoints; i++) { + v -= (spline.transmissionRate[i] / spline.I[i]) * (spline.time[i + 1] - spline.time[i]); + if (v<0){ + v += (spline.transmissionRate[i] / spline.I[i]) * (spline.time[i + 1] - spline.time[i]); + // solve for the final time + return spline.time[i] + v/(spline.transmissionRate[i] / spline.I[i]); + } + } + return spline.time[spline.gridPoints-1] + v/(spline.transmissionRate[spline.gridPoints-1] / spline.I[spline.gridPoints-1]); + + } + + @Override + public boolean requiresRecalculation() { + return true; + } + + @Override + public void store() { + super.store(); + } + + @Override + public void restore() { + super.restore(); + } + + + @Override + public void init(PrintStream printStream) { + for (int i = 0; i < spline.gridPoints; i+=20) { + printStream.print("logNe_" + i + "\t"); + } + for (int i = 0; i < spline.gridPoints; i+=20) { + printStream.print("logI_" + i + "\t"); + } + for (int i = 0; i < spline.gridPoints; i+=1) { + printStream.print("transmissionRate" + i + "\t"); + } + + } + + @Override + public void log(long l, PrintStream printStream) { + for (int i = 0; i < spline.gridPoints; i+=20) { + printStream.print(Math.log(spline.I[i]/spline.transmissionRate[i]) + "\t"); + } + for (int i = 0; i < spline.gridPoints; i+=20) { + printStream.print(spline.I[i] + "\t"); + } + for (int i = 0; i < spline.gridPoints; i+=1) { + printStream.print(spline.transmissionRate[i] + "\t"); + } + + } + + @Override + public void close(PrintStream printStream) { + + } +} \ No newline at end 
of file diff --git a/src/coalre/dynamics/NotAKnotSpline.java b/src/coalre/dynamics/NotAKnotSpline.java new file mode 100644 index 0000000..c7a23dc --- /dev/null +++ b/src/coalre/dynamics/NotAKnotSpline.java @@ -0,0 +1,187 @@ +package coalre.dynamics; + +import beast.base.core.Description; +import beast.base.core.Input; +import beast.base.inference.CalculationNode; +import beast.base.inference.parameter.RealParameter; +import org.apache.commons.math3.linear.*; + + +/** + * @author Nicola F. Mueller + */ +@Description("Populaiton function with values at certain time points that are interpolated in between. Parameter has to be in log space") +public class NotAKnotSpline extends CalculationNode { + + final public Input InfectedInput = new Input<>("logInfected", + "Nes over time in log space", Input.Validate.REQUIRED); + final public Input rateShiftsInput = new Input<>("rateShifts", + "When to switch between elements of Ne", Input.Validate.REQUIRED); + final public Input uninfectiousRateInput = new Input<>("uninfectiousRate", + "Rate at which individuals become uninfectious", Input.Validate.REQUIRED); + final public Input gridPointsInput = new Input<>("gridPoints", + "Number of grid points to use for the spline calculation", 1000); + final public Input infectedIsNeInput = new Input<>("infectedIsNe", + "Whether the infected parameter is actually the number of infected or the logNe", false); + + RealParameter infected; + RealParameter rateShifts; + RealParameter uninfectiousRate; + int gridPoints; + + double[] transmissionRate; + double[] transmissionRateStored; + + double[] I; + double[] I_stored; + + double[][] splineCoeffs; + double[][] splineCoeffs_stored; + + double[] time; + + boolean ratesKnows=false; + boolean isValid = true; + boolean infectedIsNe = false; + + @Override + public void initAndValidate() { + infected = InfectedInput.get(); + rateShifts = rateShiftsInput.get(); + infected.setDimension(rateShifts.getDimension()); + uninfectiousRate = uninfectiousRateInput.get(); + gridPoints = gridPointsInput.get(); + infectedIsNe = infectedIsNeInput.get(); + recalculateRates(); + } + + // computes the Ne's at the break points from the growth rates and the transmission rates + private void recalculateRates() { + notAKnotCubicSpline(); + // make the time grid from 0 to rateShifts.getArrayValue(rateShifts.getDimension()-1) using gridPoints + time = new double[gridPoints+1]; + I = new double[gridPoints+1]; + transmissionRate = new double[gridPoints+1]; + double dt = rateShifts.getArrayValue(rateShifts.getDimension()-1) / (time.length-1); + int j = 0; + int k = j-1; + isValid = true; + for (int i=0; i < gridPoints; i++) { + // update the time for this grid point + time[i] = i*dt; + // find the interval in which this grid point lies + if (time[i] >= rateShifts.getArrayValue(j)) { + j++; + k++; + if (k==rateShifts.getDimension()-1) { + k--; + } + } + // get the time diff from the last point where logI was estimated + double timeDiff = time[i]-rateShifts.getArrayValue(k); + double timeDiff2 = timeDiff*timeDiff; + double timeDiff3 = timeDiff2*timeDiff; + // compute the number of infected individuals at the grid points + I[i] = Math.exp(splineCoeffs[k][0]*timeDiff3 + splineCoeffs[k][1]*timeDiff2 + splineCoeffs[k][2]*timeDiff + splineCoeffs[k][3]); + // compute the transmission rate at the grid points, from dI/dt and the recovery rate, the minus in front + // of I is because the transmission rates are forward in time, but the dI/dt is backward in time + if (infectedIsNe) { + transmissionRate[i] = 1; + 
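+                // when the parameter already represents logNe, no transmission rate is derived from the spline derivative; a constant placeholder is used instead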
}else { + transmissionRate[i] = uninfectiousRate.getValue() - + I[i] * (3 * splineCoeffs[k][0] * timeDiff2 + 2 * splineCoeffs[k][1] * timeDiff + splineCoeffs[k][2]); + } + if (transmissionRate[i] < 0) { + isValid = false; + } + } + // set the last time point to the last time point in the spline + time[gridPoints] = Double.POSITIVE_INFINITY; + // set the last I to the last I in the spline + I[gridPoints] = I[gridPoints-1]; + // set the last transmission rate to the last transmission rate in the spline + transmissionRate[gridPoints] = transmissionRate[gridPoints-1]; + ratesKnows = true; + } + + public boolean update() { + if (!ratesKnows) { + recalculateRates(); + } + return isValid; + } + + @Override + public boolean requiresRecalculation() { + ratesKnows = false; + return true; + } + + @Override + public void store() { + transmissionRateStored = new double[transmissionRate.length]; + System.arraycopy(transmissionRate, 0, transmissionRateStored, 0, transmissionRate.length); + I_stored = new double[I.length]; + System.arraycopy(I, 0, I_stored, 0, I.length); + super.store(); + } + + @Override + public void restore() { + ratesKnows=false; + super.restore(); + } + + + + /** computes the coefficients for the cubic spline interpolation + * de Boor, Carl. A Practical Guide to Splines. Springer-Verlag, New York: 1978 + */ + public void notAKnotCubicSpline() { + int n = rateShifts.getDimension(); + + // Calculate h values (difference between x values) + double[] h = new double[n - 1]; + for (int i = 0; i < n - 1; i++) { + h[i] = rateShifts.getArrayValue(i + 1) - rateShifts.getArrayValue(i); + } + + // Calculate the difference in y values + double[] delta = new double[n - 1]; + for (int i = 0; i < n - 1; i++) { + delta[i] = (infected.getArrayValue(i + 1) - infected.getArrayValue(i)) / h[i]; + } + + // Create the tridiagonal system + RealMatrix A = new Array2DRowRealMatrix(n, n); + RealVector r = new ArrayRealVector(n); + + // Set up the system A*mu = r for the not-a-knot condition + A.setEntry(0, 0, h[1]); + A.setEntry(0, 1, -(h[0] + h[1])); + A.setEntry(0, 2, h[0]); + A.setEntry(n - 1, n - 3, h[n - 2]); + A.setEntry(n - 1, n - 2, -(h[n - 2] + h[n - 3])); + A.setEntry(n - 1, n - 1, h[n - 3]); + + for (int i = 1; i < n - 1; i++) { + A.setEntry(i, i - 1, h[i - 1]); + A.setEntry(i, i, 2 * (h[i - 1] + h[i])); + A.setEntry(i, i + 1, h[i]); + r.setEntry(i, 6 * (delta[i] - delta[i - 1])); + } + + // Solve for mu + DecompositionSolver solver = new LUDecomposition(A).getSolver(); + RealVector mu = solver.solve(r); + + // Store coefficients for each spline segment + splineCoeffs = new double[n - 1][4]; + for (int i = 0; i < n - 1; i++) { + splineCoeffs[i][0] = (mu.getEntry(i + 1) - mu.getEntry(i)) / (6 * h[i]); + splineCoeffs[i][1] = mu.getEntry(i) / 2; + splineCoeffs[i][2] = delta[i] - h[i] * (2 * mu.getEntry(i) + mu.getEntry(i + 1)) / 6; + splineCoeffs[i][3] = infected.getArrayValue(i); + } + } +} \ No newline at end of file diff --git a/src/coalre/dynamics/RecombinationDynamicsFromSpline.java b/src/coalre/dynamics/RecombinationDynamicsFromSpline.java new file mode 100644 index 0000000..f86993a --- /dev/null +++ b/src/coalre/dynamics/RecombinationDynamicsFromSpline.java @@ -0,0 +1,147 @@ +package coalre.dynamics; + +import beast.base.core.Description; +import beast.base.core.Input; +import beast.base.core.Loggable; +import beast.base.evolution.tree.coalescent.PopulationFunction; +import beast.base.inference.parameter.RealParameter; + +import java.io.PrintStream; +import java.util.List; + + +/** + * @author 
Nicola F. Mueller + */ +@Description("Computes time varying recombination rates as a population"+ + " function from spline interpolation of the number of infected over time") +public class RecombinationDynamicsFromSpline extends PopulationFunction.Abstract implements Loggable { + final public Input InfectedToRhoInput = new Input<>("InfectedToRho", + "the value that maps the number of infected or the Ne to the reassortment rate ", Input.Validate.REQUIRED); + final public Input splineInput = new Input<>("spline", + "Spline to use for the population function", Input.Validate.REQUIRED); + + boolean NesKnown = false; + + Spline spline; + + RealParameter InfectedToRho; + + double rateRatio; + + @Override + public void initAndValidate() { + InfectedToRho = InfectedToRhoInput.get(); + spline = splineInput.get(); + } + + + @Override + public List getParameterIds() { + return null; + } + + @Override + public double getPopSize(double t) { + if (!spline.update()) + return Double.NaN; + + // check which time t is, if it is larger than the last time, return the last Ne + int interval = spline.gridPoints-1; + for (int i = 0; i < spline.gridPoints; i++){ + if (t < spline.time[i]){ + interval = i-1; + break; + } + } + return spline.I[interval]*InfectedToRho.getValue(); + } + + public double getIntegral(double from, double to) { + if (!spline.update()) + return Double.NaN; + + // compute the integral of Ne's between from an to + double NeIntegral = 0; + int intervalFrom = spline.gridPoints-1; + for (int i = 0; i < spline.gridPoints; i++) { + // get the first time larger than from + if (from < spline.time[i]) { + intervalFrom = i-1; + break; + } + } + for (int i = intervalFrom; i < spline.gridPoints; i++) { + double rate = spline.I[i]*InfectedToRho.getValue(); + // if i==intervalFrom, we have to start compute the diff from there + if (spline.time[i+1] > to) { + if (i == intervalFrom) { + NeIntegral += rate * (to - from); + }else{ + NeIntegral += rate * (to - spline.time[i]); + } + break; + }else if (i == intervalFrom) { + NeIntegral += rate * (spline.time[i+1] - from); + }else{ + NeIntegral += rate * (spline.time[i + 1] - spline.time[i]); + } + } + return NeIntegral; + } + + @Override + public double getIntensity(double v){ + return getIntegral(0, v); + } + + @Override + public double getInverseIntensity(double v) { + for (int i = 0; i < spline.gridPoints; i++) { + double rate = spline.I[i]*InfectedToRho.getValue(); + v -= rate * (spline.time[i + 1] - spline.time[i]); + if (v<0){ + v += rate * (spline.time[i + 1] - spline.time[i]); + // solve for the final time + return spline.time[i] + v/rate; + } + } + return spline.time[spline.gridPoints-1] + v/(spline.I[spline.gridPoints-1]*InfectedToRho.getValue()); + } + + @Override + public boolean requiresRecalculation() { + return true; + } + + @Override + public void store() { + super.store(); + } + + @Override + public void restore() { + super.restore(); + } + + @Override + public void init(PrintStream printStream) { + for (int i = 0; i < spline.gridPoints; i+=20) { + printStream.print("reassortment" + i + "\t"); + } + } + + @Override + public void log(long l, PrintStream printStream) { + for (int i = 0; i < spline.gridPoints; i+=20) { + printStream.print(spline.I[i]*InfectedToRho.getValue() + "\t"); + } + } + + @Override + public void close(PrintStream printStream) { + + } + + +} \ No newline at end of file diff --git a/src/coalre/dynamics/Spline.java b/src/coalre/dynamics/Spline.java new file mode 100644 index 0000000..892c469 --- /dev/null +++ 
b/src/coalre/dynamics/Spline.java @@ -0,0 +1,328 @@ +package coalre.dynamics; + +import beast.base.core.Description; +import beast.base.core.Input; +import beast.base.core.Loggable; +import beast.base.inference.CalculationNode; +import beast.base.inference.parameter.RealParameter; +import org.apache.commons.math3.linear.*; + +import java.io.PrintStream; + + +/** + * @author Nicola F. Mueller + */ +@Description("Populaiton function with values at certain time points that are interpolated in between. Parameter has to be in log space") +public class Spline extends CalculationNode implements Loggable { + final public Input InfectedInput = new Input<>("logInfected", + "Nes over time in log space", Input.Validate.REQUIRED); + final public Input rateShiftsInput = new Input<>("rateShifts", + "When to switch between elements of Ne", Input.Validate.REQUIRED); + final public Input uninfectiousRateInput = new Input<>("uninfectiousRate", + "Rate at which individuals become uninfectious", Input.Validate.REQUIRED); + final public Input gridPointsInput = new Input<>("gridPoints", + "Number of grid points to use for the spline calculation", 1000); + final public Input infectedIsNeInput = new Input<>("infectedIsNe", + "Whether the infected parameter is actually the number of infected or the logNe", false); + + RealParameter infected; + RealParameter rateShifts; + RealParameter uninfectiousRate; + int gridPoints; + + double[] transmissionRate; + double[] transmissionRateStored; + + double[] I; + double[] I_stored; + + double[][] splineCoeffs; + double[][] splineCoeffs_stored; + + double[] time; + + boolean ratesKnows=false; + boolean isValid = true; + boolean infectedIsNe = false; + + @Override + public void initAndValidate() { + infected = InfectedInput.get(); + rateShifts = rateShiftsInput.get(); + infected.setDimension(rateShifts.getDimension()); + uninfectiousRate = uninfectiousRateInput.get(); + gridPoints = gridPointsInput.get(); + infectedIsNe = infectedIsNeInput.get(); + recalculateRates(); + } + + // computes the Ne's at the break points from the growth rates and the transmission rates + private void recalculateRates() { +// notAKnotCubicSpline(); + computeAkimaSplineCoefficients(); + +// clampedCubicSpline(); + // make the time grid from 0 to rateShifts.getArrayValue(rateShifts.getDimension()-1) using gridPoints + time = new double[gridPoints+1]; + I = new double[gridPoints+1]; + transmissionRate = new double[gridPoints+1]; + double dt = rateShifts.getArrayValue(rateShifts.getDimension()-1) / (time.length-1); + int j = 0; + int k = j-1; + isValid = true; + for (int i=0; i < gridPoints; i++) { + // update the time for this grid point + time[i] = i*dt; + // find the interval in which this grid point lies + if (time[i] >= rateShifts.getArrayValue(j)) { + j++; + k++; + if (k==rateShifts.getDimension()-1) { + k--; + } + } + // get the time diff from the last point where logI was estimated + double timeDiff = time[i]-rateShifts.getArrayValue(k); + double timeDiff2 = timeDiff*timeDiff; + double timeDiff3 = timeDiff2*timeDiff; + // compute the number of infected individuals at the grid points + I[i] = Math.exp(splineCoeffs[k][0]*timeDiff3 + splineCoeffs[k][1]*timeDiff2 + splineCoeffs[k][2]*timeDiff + splineCoeffs[k][3]); + // compute the transmission rate at the grid points, from dI/dt and the recovery rate, the minus in front + // of I is because the transmission rates are forward in time, but the dI/dt is backward in time + if (infectedIsNe) { + transmissionRate[i] = 1; + }else { + transmissionRate[i] = 
uninfectiousRate.getValue() - + (3 * splineCoeffs[k][0] * timeDiff2 + 2 * splineCoeffs[k][1] * timeDiff + splineCoeffs[k][2]); + } + if (transmissionRate[i] < 0) { +// System.out.println(time[i] + " " + I[i] + " " + transmissionRate[i] + " " + splineCoeffs[k][0]*timeDiff3 + " " + splineCoeffs[k][1]*timeDiff2 + " " + splineCoeffs[k][2]*timeDiff + " " + splineCoeffs[k][3]); + isValid = false; + } + } + // set the last time point to the last time point in the spline + time[gridPoints] = Double.POSITIVE_INFINITY; + // set the last I to the last I in the spline + I[gridPoints] = I[gridPoints-1]; + // set the last transmission rate to the last transmission rate in the spline + transmissionRate[gridPoints] = transmissionRate[gridPoints-1]; + ratesKnows = true; + } + + public boolean update() { + if (!ratesKnows) { + recalculateRates(); + } + return isValid; + } + + @Override + public boolean requiresRecalculation() { + ratesKnows = false; + return true; + } + + @Override + public void store() { + transmissionRateStored = new double[transmissionRate.length]; + System.arraycopy(transmissionRate, 0, transmissionRateStored, 0, transmissionRate.length); + I_stored = new double[I.length]; + System.arraycopy(I, 0, I_stored, 0, I.length); + super.store(); + } + + @Override + public void restore() { + recalculateRates(); + super.restore(); + } + + + + /** computes the coefficients for the cubic spline interpolation + * de Boor, Carl. A Practical Guide to Splines. Springer-Verlag, New York: 1978 + */ + public void notAKnotCubicSpline() { + int n = rateShifts.getDimension(); + + // Calculate h values (difference between x values) + double[] h = new double[n - 1]; + for (int i = 0; i < n - 1; i++) { + h[i] = rateShifts.getArrayValue(i + 1) - rateShifts.getArrayValue(i); + } + + // Calculate the difference in y values + double[] delta = new double[n - 1]; + for (int i = 0; i < n - 1; i++) { + delta[i] = (infected.getArrayValue(i + 1) - infected.getArrayValue(i)) / h[i]; + } + + // Create the tridiagonal system + RealMatrix A = new Array2DRowRealMatrix(n, n); + RealVector r = new ArrayRealVector(n); + + // Set up the system A*mu = r for the not-a-knot condition + A.setEntry(0, 0, h[1]); + A.setEntry(0, 1, -(h[0] + h[1])); + A.setEntry(0, 2, h[0]); + A.setEntry(n - 1, n - 3, h[n - 2]); + A.setEntry(n - 1, n - 2, -(h[n - 2] + h[n - 3])); + A.setEntry(n - 1, n - 1, h[n - 3]); + + for (int i = 1; i < n - 1; i++) { + A.setEntry(i, i - 1, h[i - 1]); + A.setEntry(i, i, 2 * (h[i - 1] + h[i])); + A.setEntry(i, i + 1, h[i]); + r.setEntry(i, 6 * (delta[i] - delta[i - 1])); + } + + // Solve for mu + DecompositionSolver solver = new LUDecomposition(A).getSolver(); + RealVector mu = solver.solve(r); + + // Store coefficients for each spline segment + splineCoeffs = new double[n - 1][4]; + for (int i = 0; i < n - 1; i++) { + splineCoeffs[i][0] = (mu.getEntry(i + 1) - mu.getEntry(i)) / (6 * h[i]); + splineCoeffs[i][1] = mu.getEntry(i) / 2; + splineCoeffs[i][2] = delta[i] - h[i] * (2 * mu.getEntry(i) + mu.getEntry(i + 1)) / 6; + splineCoeffs[i][3] = infected.getArrayValue(i); + } + } + + /** + * Clamped cubic spline interpolation with specified first derivatives at the end points. 
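+     * The end derivatives are taken here as the chord slopes of the first and last data intervals.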
+ */ + public void clampedCubicSpline() { + int n = rateShifts.getDimension(); + + // Calculate h values (difference between x values) + double[] h = new double[n - 1]; + for (int i = 0; i < n - 1; i++) { + h[i] = rateShifts.getArrayValue(i + 1) - rateShifts.getArrayValue(i); + } + + // Calculate the difference in y values + double[] delta = new double[n - 1]; + for (int i = 0; i < n - 1; i++) { + delta[i] = (infected.getArrayValue(i + 1) - infected.getArrayValue(i)) / h[i]; + } + + // Create the tridiagonal system + RealMatrix A = new Array2DRowRealMatrix(n, n); + RealVector r = new ArrayRealVector(n); + + // get the derivate of the first intervals + double ddtstart = (infected.getArrayValue(1)-infected.getArrayValue(0))/(rateShifts.getArrayValue(1)-rateShifts.getArrayValue(0)); + // get the derivate of the last intervals + double ddtend = (infected.getArrayValue(n-1)-infected.getArrayValue(n-2))/(rateShifts.getArrayValue(n-1)-rateShifts.getArrayValue(n-2)); + + // Clamped condition at the start + A.setEntry(0, 0, 2 * h[0]); + A.setEntry(0, 1, h[0]); + r.setEntry(0, ((delta[0] - ddtstart) * 3)); + + // Clamped condition at the end + A.setEntry(n - 1, n - 2, h[n - 2]); + A.setEntry(n - 1, n - 1, 2 * h[n - 2]); + r.setEntry(n - 1, ((ddtend - delta[n - 2]) * 3)); + + // Set up the middle equations of the tridiagonal system + for (int i = 1; i < n - 1; i++) { + A.setEntry(i, i - 1, h[i - 1]); + A.setEntry(i, i, 2 * (h[i - 1] + h[i])); + A.setEntry(i, i + 1, h[i]); + r.setEntry(i, 3 * (delta[i] - delta[i - 1])); + } + + // Solve for mu + DecompositionSolver solver = new LUDecomposition(A).getSolver(); + RealVector mu = solver.solve(r); + + // Store coefficients for each spline segment + splineCoeffs = new double[n - 1][4]; + for (int i = 0; i < n - 1; i++) { + splineCoeffs[i][0] = (mu.getEntry(i + 1) - mu.getEntry(i)) / (6 * h[i]); + splineCoeffs[i][1] = mu.getEntry(i) / 2; + splineCoeffs[i][2] = delta[i] - h[i] * (2 * mu.getEntry(i) + mu.getEntry(i + 1)) / 6; + splineCoeffs[i][3] = infected.getArrayValue(i); + } + } + + + private void computeAkimaSplineCoefficients() { + int n = rateShifts.getDimension(); + double[] slopes = new double[n-1]; + + // Compute central differences for slopes + for (int i = 0; i < (n-1); i++) { + slopes[i] = (infected.getArrayValue(i + 1) - infected.getArrayValue(i )) / (rateShifts.getArrayValue(i + 1) - rateShifts.getArrayValue(i)); + } + + // Calculate the Akima weights and the weighted slopes + double[] weightedSlopes = new double[n]; + for (int i = 2; i < (n - 2); i++) { + double weight1 = Math.abs(slopes[i + 1] - slopes[i]); + double weight2 = Math.abs(slopes[i-1] - slopes[i - 2]); + if (weight1 + weight2 == 0) + weightedSlopes[i] = (slopes[i + 1] + slopes[i]) / 2; + else + weightedSlopes[i] = (weight1 * slopes[i-1] + weight2 * slopes[i]) / (weight1 + weight2); + } + + weightedSlopes[0] = slopes[0]; + weightedSlopes[1] = (slopes[0] + slopes[1])/2; + + // Extrapolate the end slopes + weightedSlopes[n-1] = slopes[n-2]; + weightedSlopes[n - 2] = (slopes[n-3]-slopes[n-2])/2; + + // Calculate coefficients for Akima spline segments + splineCoeffs = new double[n - 1][4]; + for (int i = 0; i < n - 1; i++) { + splineCoeffs[i] = computeAkimaCoefficients(rateShifts.getArrayValue(i), rateShifts.getArrayValue(i + 1), + infected.getArrayValue(i), infected.getArrayValue(i + 1), + weightedSlopes[i], weightedSlopes[i + 1]); + } + } + + /** + * Method to compute coefficients for an Akima spline segment. 
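+     * Returns the coefficients {c3, c2, c1, c0} of the cubic Hermite segment
+     * c3*(x-x0)^3 + c2*(x-x0)^2 + c1*(x-x0) + c0 that matches the values y0, y1 and the
+     * first derivatives d0, d1 at the two end points.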
+ */ + private double[] computeAkimaCoefficients(double x0, double x1, double y0, double y1, double d0, double d1) { + double dx = x1 - x0; + double[] coeffs = new double[4]; + coeffs[3] = y0; + coeffs[2] = d0; + coeffs[1] = (3 * (y1 - y0) / dx - 2 * d0 - d1) / dx; + coeffs[0] = (d0 + d1 - 2 * (y1 - y0) / dx) / (dx * dx); + return coeffs; + } + + @Override + public void init(PrintStream out) { + for (int i = 0; i < splineCoeffs.length; i++){ + for (int j = 0; j < splineCoeffs[i].length; j++){ + out.print("splineCoeffs_" + i + "_" + j + "\t"); + } + } + } + + @Override + public void log(long l, PrintStream printStream) { + for (int i = 0; i < splineCoeffs.length; i++){ + for (int j = 0; j < splineCoeffs[i].length; j++){ + printStream.print(splineCoeffs[i][j] + "\t"); + } + } + } + + @Override + public void close(PrintStream printStream){ + + } + + + + } \ No newline at end of file diff --git a/src/coalre/dynamics/SplineTransmissionDifference.java b/src/coalre/dynamics/SplineTransmissionDifference.java new file mode 100644 index 0000000..203cb1b --- /dev/null +++ b/src/coalre/dynamics/SplineTransmissionDifference.java @@ -0,0 +1,95 @@ +package coalre.dynamics; + +import beast.base.core.Function; +import beast.base.core.Input; +import beast.base.inference.CalculationNode; + +public class SplineTransmissionDifference extends CalculationNode implements Function { + + public Input splineInput = new Input<>("spline", "spline to use for the population function", Input.Validate.REQUIRED); + + double[] difference; + double[] storedDifference; + boolean needsRecompute = true; + + @Override + public void initAndValidate() { + if (splineInput.get().infectedIsNe) { + difference = new double[splineInput.get().splineCoeffs.length]; + storedDifference = new double[splineInput.get().splineCoeffs.length]; + }else { + difference = new double[splineInput.get().splineCoeffs.length - 1]; + storedDifference = new double[splineInput.get().splineCoeffs.length - 1]; + } + } + + @Override + public int getDimension() { + return difference.length; + } + + @Override + public double getArrayValue() { + if (needsRecompute) { + compute(); + } + return difference[0]; + } + + @Override + public double getArrayValue(int dim) { + if (needsRecompute) { + compute(); + } + return difference[dim]; + } + + void compute() { + + if (splineInput.get().infectedIsNe){ + double[] value = new double[splineInput.get().splineCoeffs.length+1]; + for (int i = 0; i <= splineInput.get().splineCoeffs.length; i++) { + value[i] = splineInput.get().InfectedInput.get().getArrayValue(i); + } + + for (int i = 1; i <= splineInput.get().splineCoeffs.length; i++) { + difference[i - 1] = value[i - 1] - value[i]; + } + + }else { + double[] transmissionRates = new double[splineInput.get().splineCoeffs.length]; + for (int i = 0; i < splineInput.get().splineCoeffs.length; i++) { + transmissionRates[i] = splineInput.get().uninfectiousRate.getValue() - + splineInput.get().splineCoeffs[i][2]; + } + + for (int i = 1; i < splineInput.get().splineCoeffs.length; i++) { + difference[i - 1] = transmissionRates[i - 1] - transmissionRates[i]; + } + } + + needsRecompute = false; + } + + @Override + public void store() { + System.arraycopy(difference, 0, storedDifference, 0, difference.length); + super.store(); + } + + @Override + public void restore() { + double [] tmp = storedDifference; + storedDifference = difference; + difference = tmp; + super.restore(); + } + + @Override + public boolean requiresRecalculation() { + needsRecompute = true; + return true; + } + + +} 
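The dynamics classes above share one mapping: the spline coefficients interpolate log I(t) (infected over time, in backward time), the transmission rate is recovered from the derivative of that curve and the becoming-uninfectious rate (in the Akima-based Spline class, beta(t) = gamma - d log I/dt), the coalescent population size used by NeDynamicsFromSpline is Ne(t) = I(t) / (2 beta(t)), and RecombinationDynamicsFromSpline scales I(t) by InfectedToRho to obtain a time-varying reassortment rate. The following standalone sketch is not part of this patch; the class and method names (SplineSegmentDemo, logI, dLogI) and the numeric values are illustrative assumptions only. It evaluates one cubic segment and derives beta and Ne the same way Spline.recalculateRates and NeDynamicsFromSpline.getPopSize do:

    public class SplineSegmentDemo {

        // cubic in dt = t - t0, same coefficient order as Spline.splineCoeffs: {cubic, quadratic, linear, constant}
        static double logI(double[] c, double dt) {
            return ((c[0] * dt + c[1]) * dt + c[2]) * dt + c[3];
        }

        // derivative of the cubic with respect to dt
        static double dLogI(double[] c, double dt) {
            return (3 * c[0] * dt + 2 * c[1]) * dt + c[2];
        }

        public static void main(String[] args) {
            double[] c = {0.0, 0.1, -0.3, Math.log(100)}; // assumed coefficients of one log I(t) segment
            double gamma = 0.2;                           // assumed becoming-uninfectious rate
            for (double dt = 0.0; dt <= 1.0; dt += 0.25) {
                double I = Math.exp(logI(c, dt));
                double beta = gamma - dLogI(c, dt);       // transmission rate, as in Spline.recalculateRates
                double Ne = I / (2.0 * beta);             // population size, as in NeDynamicsFromSpline.getPopSize
                System.out.printf("dt=%.2f  I=%.1f  beta=%.3f  Ne=%.1f%n", dt, I, beta, Ne);
            }
        }
    }

If beta goes negative anywhere on the grid, the spline classes flag the state as invalid and the population functions return NaN, which is how an unphysical log I(t) trajectory is rejected during inference.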
diff --git a/src/coalre/operators/AddRemoveReassortmentCoalescent.java b/src/coalre/operators/AddRemoveReassortmentCoalescent.java index 519a22d..717f835 100644 --- a/src/coalre/operators/AddRemoveReassortmentCoalescent.java +++ b/src/coalre/operators/AddRemoveReassortmentCoalescent.java @@ -28,6 +28,8 @@ public void initAndValidate() { @Override public double networkProposal() { double logHR; + network.startEditing(this); + if (Randomizer.nextBoolean()) { logHR = addRecombination(); }else { @@ -69,6 +71,11 @@ public double networkProposal() { double transformedTimeToNextCoal = Randomizer.nextExponential(rate); double timeToNextCoal = coalescentDistr.populationFunction.getInverseIntensity( transformedTimeToNextCoal + currentTransformedTime); + +// if (timeToNextCoal binomialProbInput = new Input<>("binomialProb", "Probability of a given segment choosing a particular parent."); + public Input maxHeightRatioInput = new Input<>( + "maxHeightRatio", + "if specified, above the ratio, only coalescent events are allowed.", Double.POSITIVE_INFINITY); + + public Input redFactorInput = new Input<>( + "redFactor", + "by how much the recombination rate should be reduced after reaching the maxHeightRatio.", 0.1); + + private int nSegments; @@ -53,9 +60,8 @@ public double networkProposal() { double resimulate() { network.startEditing(this); - - // get the place where to cut - double maxHeight = getMaxSegmentMRCA(); + // get the maximum height of the segment tree roots + double maxHeight = NetworkStatsLogger.getLociMRCA(network); // get all network edges List networkEdges = new ArrayList<>(network.getEdges()); @@ -73,6 +79,9 @@ public double networkProposal() { // simulate the rest of the network starting from mxHeight double currentTime = maxHeight; double timeUntilNextSample = Double.POSITIVE_INFINITY; + // get the time when the reassortment rates are reduced + double recChangeTime = maxHeight*maxHeightRatioInput.get(); + double redFactor = 1.0; do { // get the current propensities int k = startingEdges.size(); @@ -83,16 +92,22 @@ public double networkProposal() { transformedTimeToNextCoal + currentTransformedTime) - currentTime; - double timeToNextReass = k>=1 ? Randomizer.nextExponential(k*reassortmentRate.getValue()) : Double.POSITIVE_INFINITY; + double timeToNextReass = k>=1 ? 
Randomizer.nextExponential(k*reassortmentRate.getValue() * redFactor) : Double.POSITIVE_INFINITY; // next event time double timeUntilNextEvent = Math.min(timeToNextCoal, timeToNextReass); - if (timeUntilNextEvent < timeUntilNextSample) { - currentTime += timeUntilNextEvent; - if (timeUntilNextEvent == timeToNextCoal) - coalesce(currentTime, startingEdges); - else - reassort(currentTime, startingEdges); + if ((timeUntilNextEvent+currentTime)>recChangeTime) { + currentTime = recChangeTime; + redFactor *= redFactorInput.get(); + recChangeTime = Double.POSITIVE_INFINITY; + }else { + if (timeUntilNextEvent < timeUntilNextSample) { + currentTime += timeUntilNextEvent; + if (timeUntilNextEvent == timeToNextCoal) + coalesce(currentTime, startingEdges); + else + reassort(currentTime, startingEdges); + } } } @@ -108,17 +123,6 @@ public double networkProposal() { } - double getMaxSegmentMRCA(){ - double maxHeight = 0.0; - for (int i = 0; i < segmentTreesInput.get().size(); i++){ - double height = segmentTreesInput.get().get(i).getRoot().getHeight(); - if (height>maxHeight) - maxHeight=height; - } - - return maxHeight; - } - private void coalesce(double coalescentTime, List extantLineages) { // Sample the pair of lineages that are coalescing: NetworkEdge lineage1 = extantLineages.get(Randomizer.nextInt(extantLineages.size())); diff --git a/src/coalre/operators/MultiTipDatesRandomWalker.java b/src/coalre/operators/MultiTipDatesRandomWalker.java index 82e2408..c5b9fe2 100644 --- a/src/coalre/operators/MultiTipDatesRandomWalker.java +++ b/src/coalre/operators/MultiTipDatesRandomWalker.java @@ -1,9 +1,5 @@ package coalre.operators; -import java.text.DecimalFormat; -import java.util.ArrayList; -import java.util.List; - import beast.base.core.Description; import beast.base.core.Input; import beast.base.core.Input.Validate; @@ -13,6 +9,10 @@ import beast.base.evolution.tree.Tree; import beast.base.util.Randomizer; +import java.text.DecimalFormat; +import java.util.ArrayList; +import java.util.List; + @Description("Randomly moves tip dates on a tree by randomly selecting one from (a subset of) taxa") @@ -62,6 +62,7 @@ public void initAndValidate() { } taxonIndices[k++] = taxonIndex; } + } else { taxonIndices = new int[treeInput.get().getTaxaNames().length]; for (int i = 0; i < taxonIndices.length; i++) { @@ -87,9 +88,9 @@ public double proposal() { double value = node.getHeight(); double newValue = value+difference; - - - if (newValue > node.getParent().getHeight()) { // || newValue < 0.0) { + + + if (newValue > node.getParent().getHeight()) { // || newValue < 0.0) { return Double.NEGATIVE_INFINITY; } @@ -98,8 +99,9 @@ public double proposal() { return Double.NEGATIVE_INFINITY; } node.setHeight(newValue); - - } + + + } return 0.0; } diff --git a/src/coalre/operators/NetworkOperator.java b/src/coalre/operators/NetworkOperator.java index 38a6371..ea0d480 100644 --- a/src/coalre/operators/NetworkOperator.java +++ b/src/coalre/operators/NetworkOperator.java @@ -1,12 +1,10 @@ package coalre.operators; import beast.base.core.Input; -import beast.base.inference.Operator; import beast.base.core.Log; import beast.base.evolution.tree.Tree; -import beast.base.util.Binomial; +import beast.base.inference.Operator; import beast.base.util.Randomizer; -import cern.colt.Arrays; import coalre.network.Network; import coalre.network.NetworkEdge; import coalre.network.NetworkNode; @@ -81,6 +79,16 @@ public double proposal() { // System.out.println(segmentTrees.get(segIdx) +";"); // System.out.println("========="); +// String[] 
trees = new String[segmentTrees.size()]; +// if (this instanceof TipReheight) { +// if (((TipReheight) this).taxonsetInput.get().getTaxonId(0).contentEquals("RVA/1070/Thailand|2013-01-01")) { +// System.out.println(network); +// for (int i =0; i < segmentTrees.size(); i++) { +// trees[i] = segmentTrees.get(i) + ";"; +// } +// } +// } + double logHR = networkProposal(); @@ -88,6 +96,23 @@ public double proposal() { for (int segIdx=0; segIdx recoveryRateInput = new Input<>("recoveryRate", "recovery rate", Input.Validate.REQUIRED); + + public Input waningImmunityRateInput = new Input<>("waningImmunityRate", "waning immunity rate", Input.Validate.REQUIRED); + + public Input reassortmenProbabilityInput = new Input<>("reassortmenProbability", "reassortment probability", Input.Validate.REQUIRED); + + public Input samplingProbabilityInput = new Input<>("samplingProbability", "sampling probability", Input.Validate.REQUIRED); + + public Input populationSizeInput = new Input<>("populationSize", "population size", Input.Validate.REQUIRED); + + public Input nSegmentsInput = new Input<>("nSegments","Number of segments. Used if no segment trees are supplied."); + + public Input simulationTimeInput = new Input<>("simulationTime", "simulation time", Input.Validate.REQUIRED); + + public Input fileNameInput = new Input<>("fileName", + "Name of file to write simulated network to."); + + public Input enableSegTreeUpdateInput = new Input<>("enableSegmentTreeUpdate", + "If false, segment tree objects won't be updated to agree with simulated " + + "network. (Default true.)", true); + + public Input> segmentTreesInput = new Input<>("segmentTree", + "One or more segment trees to initialize.", new ArrayList<>()); + + public Input minSamplesInput = new Input<>("minSamples", + "Minimum number of samples in the tree, otherwise, simulation is repeated.", 0); + int nSegments; + + int states; + + List lineages; + List sirEvents; + + int S,I,R; + + @Override + public void initAndValidate(){ + + List IDs = new ArrayList<>(); + if (segmentTreesInput.get().isEmpty()) { + nSegments = (int) nSegmentsInput.get(); + }else { + nSegments = segmentTreesInput.get().size(); + segmentNames = new String[nSegments]; + // initialize names of segments + baseName = ""; + for (int segIdx=0; segIdx taxonList = new ArrayList<>(); + for (NetworkNode leaf : getLeafNodes()) { + Taxon taxon = new Taxon(leaf.getTaxonLabel()); + taxonList.add(taxon); + } + // make a TaxonSet from the taxa + TaxonSet taxonset = new TaxonSet(); + taxonset.initByName("taxon", taxonList); + + // create the trait set using the sampling times + TraitSet traitSet = new TraitSet(); + // make a string of values with leaf.getTaxonLabel()+"=" +leaf.getHeight() and individual entries seprated by , + String values = ""; + for (NetworkNode leaf : getLeafNodes()) { + values += leaf.getTaxonLabel()+"=" +leaf.getHeight() + ","; + } + traitSet.initByName("value", values.substring(0, values.length()-1), "taxa", taxonset, "traitname", "date-backward"); + + // Update segment trees: + if (enableSegTreeUpdateInput.get()) { + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + // get the segment tree, set the taxonsets and update the segment tree based on the network + + + + Tree segmentTree = new Tree(); + segmentTree.initByName("taxonset", taxonset); + updateSegmentTree(segmentTree, segIdx); + segmentTree.setID(IDs.get(segIdx)); +// segmentTreesInput.get().add( segmentTree); + segmentTree.setEverythingDirty(false); + segmentTreesInput.get().set(segIdx, segmentTree); + } + } + // Write 
simulated network to file if requested + super.initAndValidate(); + + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + System.out.println(segmentTreesInput.get().get(segIdx).getID()); + System.out.println(segmentTreesInput.get().get(segIdx).getRoot().toNewick()); + System.out.println(segmentTreesInput.get().get(segIdx).getNodeCount()); + + } + + // Write simulated network to file if requested + if (fileNameInput.get() != null) { + try (PrintStream ps = new PrintStream(fileNameInput.get())) { + + ps.println(toString().replace("0.0)", "0.0000000001)")); + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + ps.println(segmentTreesInput.get().get(segIdx).getRoot().toNewick() +";"); + } + } catch (FileNotFoundException ex) { + throw new RuntimeException("Error writing to output file '" + + fileNameInput.get() + "'."); + } + } + super.initAndValidate(); + } + + protected void simulateNetwork() { + List sampledIndividuals = new ArrayList<>(); + + // set the starting conditions for simulations + do { + sirEvents = new ArrayList<>(); + + double S = populationSizeInput.get() - 1; + double I = 1; + double R = 0; + int individualNr = 0; + // start the simulations while keeping track of the different individuals + List activeIndividuals = new ArrayList<>(); + sampledIndividuals = new ArrayList<>(); + + Individual root = new Individual(individualNr); + individualNr++; + activeIndividuals.add(root); + double time = 0.0; + do { + double transmissionRate = I*(S + I - 1)/(S+I+R) * transmissionRateInput.get(); + double recoveryRate = I * recoveryRateInput.get(); + double waningRate = R * waningImmunityRateInput.get(); + + double totalRate = transmissionRate + recoveryRate + waningRate; + double nextEventTime = Math.log(1.0 / Randomizer.nextDouble()) / totalRate; + + // pick which event happens next based on the rates, by drawing at random + double random = Randomizer.nextDouble() * totalRate; + + if (random < transmissionRate) { + // transmission event + // pick a random individual to transmit to + Individual parent = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size())); + // set time for parent + parent.setTime(time + nextEventTime); + // create two new individuals + Individual child1 = new Individual(individualNr); + individualNr++; + Individual child2 = new Individual(individualNr); + individualNr++; + // start the first new lineage + parent.addChild(child1); + child1.addParent(parent); + // add the child to the second parent + parent.addChild(child2); + child2.addParent(parent); + + // add the child to the list of active individuals + activeIndividuals.add(child1); + + // pick if the individual to be infected + double probCoInf = S / (double) (S + I - 1); +// System.out.println(probCoInf); + if (Randomizer.nextDouble() > probCoInf) { + // co-infection event + // pick another active individual that is not the parent and not child1 + Individual parent2 = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size()-1)); + while (parent2 == parent) { + parent2 = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size()-1)); + } + parent2.setTime(time + nextEventTime); + + // add the child to the second parent + child2.setTime(time + nextEventTime); + + // create a third child + Individual child3 = new Individual(individualNr); + individualNr++; + child3.addParent(child2); + child2.addChild(child3); + + child3.addParent(parent2); + parent2.addChild(child3); + + activeIndividuals.add(child3); + activeIndividuals.remove(parent2); + sirEvents.add(new SIREvents(0, time + 
nextEventTime)); + } else { + // update the population counts + S--; + I++; + activeIndividuals.add(child2); + sirEvents.add(new SIREvents(1, time + nextEventTime)); + } + // remove the parent + activeIndividuals.remove(parent); + } else if (random < (transmissionRate + recoveryRate)) { + // pick a random individual to recover + Individual individual = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size())); + individual.setTime(time + nextEventTime); + // choose if that individual will be sampled + if (Randomizer.nextDouble() < samplingProbabilityInput.get()) { + // sample the individual + sampledIndividuals.add(individual); + } + // remove the individual from the list of active individuals + activeIndividuals.remove(individual); + // update the population counts + I--; + R++; + sirEvents.add(new SIREvents(2, time + nextEventTime)); + } else { + System.out.println(transmissionRate + recoveryRate); + System.out.println(transmissionRate + " " + recoveryRate + " " + waningRate); + System.out.println(I); + System.out.println((S + I) ); + System.exit(0); + // waning immunity event + // update the population counts + R--; + S++; + sirEvents.add(new SIREvents(3, time + nextEventTime)); + } + time += nextEventTime; +// System.out.println("S: " + S + " I: " + I + " R: " + R + " " + activeIndividuals.size() + " " + sampledIndividuals.size()); + } while (time < simulationTimeInput.get() && I > 0); + System.out.println("start building network from " + sampledIndividuals.size() + " sampled individuals" + " simulation time was " + time); + + + // build the network from the sampled individuals + }while (sampledIndividuals.size() sampledIndividuals) { + // get the time of the most recent sample + double mrsi_time = 0.0; + for (Individual individual:sampledIndividuals){ + if(individual.getTime()>mrsi_time){ + mrsi_time = individual.getTime(); + } + } + + System.out.println("mrsi time: " + mrsi_time); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + lineages = new ArrayList<>(); + + int sampledIndividualNr = 0; + int totalObservedCoal=0; + + List activeIndividuals = new ArrayList<>(); + List activeEdges = new ArrayList<>(); + System.out.print("start"); + while (activeIndividuals.size()>1 || sampledIndividuals.size()>0){ +// System.out.println("get next sampling time"); + // get the next sampling event time, while keeping track of the index of the individual + double nextSamplingTime = Double.NEGATIVE_INFINITY; + int nextSamplingIndex = -1; + for (int i=0;inextSamplingTime){ + nextSamplingTime = sampledIndividuals.get(i).getTime(); + nextSamplingIndex = i; + } + } +// System.out.println("get next active time"); + // get the activeIndividuals time, while keeping track of the index of the individual + double nextActiveTime = Double.NEGATIVE_INFINITY; + int nextActiveIndex = -1; + for (int i=0;i0) { + if (activeIndividuals.get(i).getParents().size() == 2){ + double nextParentTime = Math.max(activeIndividuals.get(i).getParents().get(0).getTime(), + activeIndividuals.get(i).getParents().get(1).getTime()); + if (nextParentTime>=nextActiveTime) { + nextActiveTime = nextParentTime; + nextActiveIndex = i; + } + } else if (activeIndividuals.get(i).getParents().get(0).getTime() > nextActiveTime) { + nextActiveTime = activeIndividuals.get(i).getParents().get(0).getTime(); + nextActiveIndex = i; + } 
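+                    // the disabled branch below covers the case where two candidates share the same parent time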
+// else if ((activeIndividuals.get(i).getParents().get(0).getTime() == nextActiveTime)){ +// // if the other individual has two parents, then do nothing +// if (activeIndividuals.get(nextActiveIndex).getParents().size()==2){ +// System.out.println("do nothing");} +// else{ +// System.out.println(activeIndividuals.get(nextActiveIndex).getParents().size()); +// System.out.println(activeIndividuals.get(i).getParents().size()); +// System.out.println("----------------"); +// } +// +// } + } + } +// System.out.println(" done " + nextSamplingTime +" " + nextActiveTime + ""); + + + // depending on which one is next + if (nextSamplingTime>nextActiveTime){ +// System.out.println("sampling event"); + NetworkNode sampledNode = new NetworkNode(); + sampledNode.setHeight(mrsi_time-nextSamplingTime); + sampledNode.setTaxonLabel("sample_no"+sampledIndividualNr); + sampledNode.setTaxonIndex(sampledIndividualNr); + sampledIndividualNr++; + + // set segments for the sampled node, where all bits are true + BitSet segments = new BitSet(nSegments); + // set all bits to true + segments.set(0,nSegments); + // make a new edge + NetworkEdge edge = new NetworkEdge(null, sampledNode, segments); + activeEdges.add(edge); + activeIndividuals.add(sampledIndividuals.get(nextSamplingIndex)); + // remove the individual from the list of sampled individuals + sampledIndividuals.remove(nextSamplingIndex); + + lineages.add(new SIREvents(0, nextSamplingTime)); + }else { + // check if the individual has two parents + if (activeIndividuals.get(nextActiveIndex).getParents().size()==2){ +// System.out.println("co-infection event"); + // get the corresponding edge + NetworkEdge edge = activeEdges.get(nextActiveIndex); + // pick randomly which segments go left and which go right + BitSet segmentsLeft = new BitSet(nSegments); + BitSet segmentsRight = new BitSet(nSegments); + for (int i=0;i 0 + if (segmentsLeft.cardinality()>0 && segmentsRight.cardinality()>0){ + // compute the average observation probability + double prob = 0.0; + for (NetworkEdge e : activeEdges){ + prob += 1-2*Math.pow(0.5, e.hasSegments.cardinality()); + } + prob = prob/activeEdges.size(); + +// System.out.println("observed reassortment event"); + // make a reassortment node + NetworkNode reassortmentNode = new NetworkNode(); + reassortmentNode.setHeight(mrsi_time-nextActiveTime); + // set the child edge + edge.parentNode = reassortmentNode; + reassortmentNode.addChildEdge(edge); + // make two new edges + NetworkEdge edgeLeft = new NetworkEdge(null, reassortmentNode, segmentsLeft); + NetworkEdge edgeRight = new NetworkEdge(null, reassortmentNode, segmentsRight); + // add the edges to the activeEdges list + reassortmentNode.addParentEdge(edgeLeft); + reassortmentNode.addParentEdge(edgeRight); + activeEdges.add(edgeLeft); + activeEdges.add(edgeRight); + + edgeLeft.childNode=reassortmentNode; + edgeRight.childNode=reassortmentNode; + + // add the parent individuals to the active individuals + activeIndividuals.add(activeIndividuals.get(nextActiveIndex).getParents().get(0)); + activeIndividuals.add(activeIndividuals.get(nextActiveIndex).getParents().get(1)); + + // remove the individual from the activeIndividuals list + activeIndividuals.remove(nextActiveIndex); + activeEdges.remove(nextActiveIndex); + lineages.add(new SIREvents(1, nextActiveTime, prob)); + }else{ + // the event is not observed, just replace the individual with its parent that has cardinality>0 + if (segmentsLeft.cardinality()>0) { + activeIndividuals.set(nextActiveIndex, 
activeIndividuals.get(nextActiveIndex).getParents().get(0)); + }else{ + activeIndividuals.set(nextActiveIndex, activeIndividuals.get(nextActiveIndex).getParents().get(1)); + } + } + }else{ + // get the other child for this coalescent event + Individual otherChild = activeIndividuals.get(nextActiveIndex).getParents().get(0).getChildren().get(0); + if (otherChild.equals(activeIndividuals.get(nextActiveIndex))) { + otherChild = activeIndividuals.get(nextActiveIndex).getParents().get(0).getChildren().get(1); + } + // check if the other child is in the activeIndividuals list + if (activeIndividuals.contains(otherChild)) { + totalObservedCoal++; + // make a coalescent node + NetworkNode coalNode = new NetworkNode(); + coalNode.setHeight(mrsi_time - nextActiveTime); + // get the two child edges + NetworkEdge childEdge1 = activeEdges.get(nextActiveIndex); + NetworkEdge childEdge2 = activeEdges.get(activeIndividuals.indexOf(otherChild)); + // add the parent node to the edges + childEdge1.parentNode = coalNode; + childEdge2.parentNode = coalNode; + // add the edges to the coalNode + coalNode.addChildEdge(childEdge1); + coalNode.addChildEdge(childEdge2); + // make a new parent edge with the segments being the union of the two child edge segments + BitSet segments = (BitSet) childEdge1.hasSegments.clone(); + segments.or(childEdge2.hasSegments); + NetworkEdge parentEdge = new NetworkEdge(null, coalNode, segments); + // add the parent edge to the activeEdges list + activeEdges.add(parentEdge); + // remove the two child edges from the activeEdges list + activeEdges.remove(childEdge1); + activeEdges.remove(childEdge2); + activeIndividuals.add(activeIndividuals.get(nextActiveIndex).getParents().get(0)); + // remove the two children from the activeIndividuals list + activeIndividuals.remove(nextActiveIndex); + activeIndividuals.remove(otherChild); + lineages.add(new SIREvents(2, nextActiveTime)); + } else { + // replace the activeIndividuals with its parent + activeIndividuals.set(nextActiveIndex, activeIndividuals.get(nextActiveIndex).getParents().get(0)); + } + } + } + // check if all the individuals in activeIndividuals have different numbers + for (int i=0; i parent; + List child; + + double time; + int number; + + + int S, I; + + public Individual(int number){ + parent = new ArrayList<>(); + child = new ArrayList<>(); + this.number = number; + } + + public List getParents(){ + return parent; + } + + public List getChildren(){ + return child; + } + + public double getTime(){ + return time; + } + + public void setTime(double time){ + this.time = time; + } + + public void addParent(Individual parent){ + this.parent.add(parent); + } + + public void addChild(Individual child){ + this.child.add(child); + } + + // build a method to check if two individual objects ar the same + public boolean equals(Individual other){ + if (this.number == other.number){ + return true; + }else{ + return false; + } + } + + // define what happens in toString + public String toString(){ + return "" + number ; + } + + public void setS(int S){ + this.S = S; + } + public int getS(){ + return S; + } + + public void setI(int I){ + this.I = I; + } + + public int getI(){ + return I; + } + + } + + protected class SIREvents{ + int eventType; + double time; + + double prob = 0.0; + + public SIREvents(int eventType, double time){ + this.eventType = eventType; + this.time = time; + } + + public SIREvents(int eventType, double time, double prob){ + this.eventType = eventType; + this.time = time; + this.prob = prob; + } + + @Override + public 
String toString() { + return eventType + ":" + time + ":" + prob; + } + } + + @Override + public void init(PrintStream out) { + out.println("SIR\tlineages\t"); + } + + @Override + public void close(PrintStream out) { + + } + + @Override + public void log(long sample, PrintStream out) { + out.println(sirEvents + "\t" + lineages+"\t"); + } +} + diff --git a/src/coalre/simulator/SuperspreadingSIRwithReassortment.java b/src/coalre/simulator/SuperspreadingSIRwithReassortment.java new file mode 100644 index 0000000..0d2cdd8 --- /dev/null +++ b/src/coalre/simulator/SuperspreadingSIRwithReassortment.java @@ -0,0 +1,278 @@ +package coalre.simulator; + +import beast.base.core.Input; +import beast.base.core.Loggable; +import beast.base.evolution.alignment.Taxon; +import beast.base.evolution.alignment.TaxonSet; +import beast.base.evolution.tree.TraitSet; +import beast.base.evolution.tree.Tree; +import beast.base.util.Randomizer; +import coalre.network.NetworkNode; + +import java.io.FileNotFoundException; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; + +public class SuperspreadingSIRwithReassortment extends SIRwithReassortment implements Loggable { + + public Input kInput = new Input<>("k", + "k-value for the negative binomial distribution", 0.1); + + @Override + public void initAndValidate(){ + + List IDs = new ArrayList<>(); + if (segmentTreesInput.get().isEmpty()) { + nSegments = (int) nSegmentsInput.get(); + }else { + nSegments = segmentTreesInput.get().size(); + segmentNames = new String[nSegments]; + // initialize names of segments + baseName = ""; + for (int segIdx=0; segIdx taxonList = new ArrayList<>(); + for (NetworkNode leaf : getLeafNodes()) { + Taxon taxon = new Taxon(leaf.getTaxonLabel()); + taxonList.add(taxon); + } + // make a TaxonSet from the taxa + TaxonSet taxonset = new TaxonSet(); + taxonset.initByName("taxon", taxonList); + + // create the trait set using the sampling times + TraitSet traitSet = new TraitSet(); + // make a string of values with leaf.getTaxonLabel()+"=" +leaf.getHeight() and individual entries seprated by , + String values = ""; + for (NetworkNode leaf : getLeafNodes()) { + values += leaf.getTaxonLabel()+"=" +leaf.getHeight() + ","; + } + traitSet.initByName("value", values.substring(0, values.length()-1), "taxa", taxonset, "traitname", "date-backward"); + + // Update segment trees: + if (enableSegTreeUpdateInput.get()) { + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + // get the segment tree, set the taxonsets and update the segment tree based on the network + Tree segmentTree = new Tree(); + segmentTree.initByName("taxonset", taxonset); + updateSegmentTree(segmentTree, segIdx); + segmentTree.setID(IDs.get(segIdx)); +// segmentTreesInput.get().add( segmentTree); + segmentTree.setEverythingDirty(false); + segmentTreesInput.get().set(segIdx, segmentTree); + } + } + // Write simulated network to file if requested + super.initAndValidate(); + + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + System.out.println(segmentTreesInput.get().get(segIdx).getID()); + System.out.println(segmentTreesInput.get().get(segIdx).getRoot().toNewick()); + System.out.println(segmentTreesInput.get().get(segIdx).getNodeCount()); + + } + + // Write simulated network to file if requested + if (fileNameInput.get() != null) { + try (PrintStream ps = new PrintStream(fileNameInput.get())) { + + ps.println(toString().replace("0.0)", "0.0000000001)")); + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + 
ps.println(segmentTreesInput.get().get(segIdx).getRoot().toNewick() +";"); + } + } catch (FileNotFoundException ex) { + throw new RuntimeException("Error writing to output file '" + + fileNameInput.get() + "'."); + } + } + super.initAndValidate(); + } + + @Override + protected void simulateNetwork() { + List sampledIndividuals = new ArrayList<>(); + + // set the starting conditions for simulations + do { + sirEvents = new ArrayList<>(); + + double S = populationSizeInput.get() - 1; + double I = 1; + double R = 0; + int individualNr = 0; + // start the simulations while keeping track of the different individuals + List activeIndividuals = new ArrayList<>(); + sampledIndividuals = new ArrayList<>(); + + Individual root = new Individual(individualNr); + individualNr++; + activeIndividuals.add(root); + double time = 0.0; + do { + double transmissionRate = I * recoveryRateInput.get(); + double recoveryRate = I * recoveryRateInput.get(); + double waningRate = R * waningImmunityRateInput.get(); + + double totalRate = transmissionRate + recoveryRate + waningRate; + double nextEventTime = Math.log(1.0 / Randomizer.nextDouble()) / totalRate; + + // pick which event happens next based on the rates, by drawing at random + double random = Randomizer.nextDouble() * totalRate; + + if (random < transmissionRate) { + // pick the number of offsprings from a negative binomial distribution with R and k + // from a gamma and a poisson distribution + double secondary_infections = transmissionRateInput.get()/recoveryRateInput.get(); + double gamma = Randomizer.nextGamma(kInput.get(), kInput.get() / (double) secondary_infections); + int nOffspring = (int) Randomizer.nextPoisson(gamma); + int isR = 0; + // for each offspring, randomly sample if it hits an R + for (int i = 0; i < nOffspring; i++) { + if (Randomizer.nextDouble() < (R+1) / (double) (S + I + R)) { + // sample the individual + isR++; + } + } + nOffspring -= isR; + + List involvedIndividuals = new ArrayList<>(); + + // don't do anything if there are no offspring + if (nOffspring > 0){ +// System.out.println(nOffspring + " " + activeIndividuals.size()); + // transmission event + // pick a random individual to transmit to + Individual parent = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size())); + + // build the superspreading event as a series of regular infection events + for (int i = 0; i < nOffspring; i++) { + System.out.println( " time " + time); + // set time for parent + parent.setTime(time + nextEventTime); + // create two new individuals + Individual child1 = new Individual(individualNr); + individualNr++; + Individual child2 = new Individual(individualNr); + individualNr++; + // start the first new lineage + parent.addChild(child1); + child1.addParent(parent); + // add the child to the second parent + parent.addChild(child2); + child2.addParent(parent); + + // add the child to the list of active individuals + activeIndividuals.add(child1); + + involvedIndividuals.add(child1.number); + involvedIndividuals.add(child2.number); + involvedIndividuals.add(parent.number); + + + // pick if the individual to be infected, assuming an individual is co-infected once at most + System.out.println(I + " coinf " + activeIndividuals.size()); + double probCoInf = S / (double) (S + activeIndividuals.size() - 2 - i); + if (Randomizer.nextDouble() > probCoInf) { + // offset for the number of active individuals to draw from + int totOptions = i==0? 
1 : 2 + i; + System.out.println(totOptions + " " + activeIndividuals.size()); + + // co-infection event + // pick another active individual that is not the parent, while making sure it is not + // the same as any of the previous individuals in this event + Individual parent2 = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size()-totOptions)); + while (parent2 == parent) { + parent2 = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size()-totOptions)); + } + parent2.setTime(time + nextEventTime); + + // add the child to the second parent + child2.setTime(time + nextEventTime); + + // create a third child + Individual child3 = new Individual(individualNr); + +// System.out.println("other lineage " + parent2.number); + +// involvedIndividuals.add(parent2.number); +// involvedIndividuals.add(child3.number); + + individualNr++; + child3.addParent(child2); + child2.addChild(child3); + + child3.addParent(parent2); + parent2.addChild(child3); + + activeIndividuals.add(child3); + activeIndividuals.remove(parent2); + sirEvents.add(new SIREvents(0, time + nextEventTime)); + } else { + // update the population counts + S--; + I++; + activeIndividuals.add(child2); + sirEvents.add(new SIREvents(1, time + nextEventTime)); + } + // remove the parent + activeIndividuals.remove(parent); + parent = child1; + time += 0.0000000001; + } + +// System.out.println(involvedIndividuals); + // check if all the involved individuals are unique +// if (involvedIndividuals.size() != involvedIndividuals.stream().distinct().count()) { +// System.out.println("not unique"); +// System.out.println(involvedIndividuals); +// System.exit(0); +// } + + } + } else if (random < (transmissionRate + recoveryRate)) { + // pick a random individual to recover + Individual individual = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size())); + individual.setTime(time + nextEventTime); + // choose if that individual will be sampled + if (Randomizer.nextDouble() < samplingProbabilityInput.get()) { + // sample the individual + sampledIndividuals.add(individual); + } + // remove the individual from the list of active individuals + activeIndividuals.remove(individual); + // update the population counts + I--; + R++; + sirEvents.add(new SIREvents(2, time + nextEventTime)); + } else { + System.out.println(transmissionRate + recoveryRate); + System.out.println(transmissionRate + " " + recoveryRate + " " + waningRate); + System.out.println(I); + System.out.println((S + I) ); + System.exit(0); + // waning immunity event + // update the population counts + R--; + S++; + sirEvents.add(new SIREvents(3, time + nextEventTime)); + } + time += nextEventTime; +// System.out.println(time + " S: " + S + " I: " + I + " R: " + R + " " + activeIndividuals.size() + " " + sampledIndividuals.size()); + } while (time < simulationTimeInput.get() && I > 0); + System.out.println("start building network from " + sampledIndividuals.size() + " sampled individuals" + " simulation time was " + time); + // build the network from the sampled individuals + }while (sampledIndividuals.size() transmissionRateInput = new Input<>("transmissionRate", "transmission rate", Input.Validate.REQUIRED); + + public Input recoveryRateInput = new Input<>("recoveryRate", "recovery rate", Input.Validate.REQUIRED); + + public Input waningImmunityRateInput = new Input<>("waningImmunityRate", "waning immunity rate", Input.Validate.REQUIRED); + + public Input reassortmenProbabilityInput = new Input<>("reassortmenProbability", "reassortment probability", 
Input.Validate.REQUIRED); + + public Input samplingProbabilityInput = new Input<>("samplingProbability", "sampling probability", Input.Validate.REQUIRED); + + public Input populationSizeInput = new Input<>("populationSize", "population size", Input.Validate.REQUIRED); + + public Input nSegmentsInput = new Input<>("nSegments","Number of segments. Used if no segment trees are supplied."); + + public Input simulationTimeInput = new Input<>("simulationTime", "simulation time", Input.Validate.REQUIRED); + + public Input fileNameInput = new Input<>("fileName", + "Name of file to write simulated network to."); + + public Input enableSegTreeUpdateInput = new Input<>("enableSegmentTreeUpdate", + "If false, segment tree objects won't be updated to agree with simulated " + + "network. (Default true.)", true); + + public Input> segmentTreesInput = new Input<>("segmentTree", + "One or more segment trees to initialize.", new ArrayList<>()); + + public Input minSamplesInput = new Input<>("minSamples", + "Minimum number of samples in the tree, otherwise, simulation is repeated.", 0); + public Input kInput = new Input<>("k", + "k-value for the negative binomial distribution", Input.Validate.REQUIRED); + + public Input migrationRatesInput = new Input<>("migrationRates", + "Migration rates between segments", Input.Validate.REQUIRED); + + int nSegments; + int states; + + List lineages; + + List structuredSirEvents; + + int I[]; + int S[]; + int R[]; + + public void init(PrintStream out) { + out.println("SIR\tlineages\t"); + } + + @Override + public void log(long sample, PrintStream out) { + out.println(structuredSirEvents + "\t" + lineages+"\t"); + } + + @Override + public void close(PrintStream out) { + + } + + + + @Override + public void initAndValidate(){ + states = transmissionRateInput.get().getDimension(); + + List IDs = new ArrayList<>(); + if (segmentTreesInput.get().isEmpty()) { + nSegments = (int) nSegmentsInput.get(); + }else { + nSegments = segmentTreesInput.get().size(); + segmentNames = new String[nSegments]; + // initialize names of segments + baseName = ""; + for (int segIdx=0; segIdx taxonList = new ArrayList<>(); + for (NetworkNode leaf : getLeafNodes()) { + Taxon taxon = new Taxon(leaf.getTaxonLabel()); + taxonList.add(taxon); + } + // make a TaxonSet from the taxa + TaxonSet taxonset = new TaxonSet(); + taxonset.initByName("taxon", taxonList); + + // create the trait set using the sampling times + TraitSet traitSet = new TraitSet(); + // make a string of values with leaf.getTaxonLabel()+"=" +leaf.getHeight() and individual entries seprated by , + String values = ""; + for (NetworkNode leaf : getLeafNodes()) { + values += leaf.getTaxonLabel()+"=" +leaf.getHeight() + ","; + } + traitSet.initByName("value", values.substring(0, values.length()-1), "taxa", taxonset, "traitname", "date-backward"); + + // Update segment trees: + if (enableSegTreeUpdateInput.get()) { + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + // get the segment tree, set the taxonsets and update the segment tree based on the network + Tree segmentTree = new Tree(); + segmentTree.initByName("taxonset", taxonset); + updateSegmentTree(segmentTree, segIdx); + segmentTree.setID(IDs.get(segIdx)); +// segmentTreesInput.get().add( segmentTree); + segmentTree.setEverythingDirty(false); + segmentTreesInput.get().set(segIdx, segmentTree); + } + } + // Write simulated network to file if requested + super.initAndValidate(); + + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + 
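+            // Debug output: print each updated segment tree's ID, Newick string and node count so
+            // the segment-tree update from the simulated network can be checked by eye.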
System.out.println(segmentTreesInput.get().get(segIdx).getID()); + System.out.println(segmentTreesInput.get().get(segIdx).getRoot().toNewick()); + System.out.println(segmentTreesInput.get().get(segIdx).getNodeCount()); + + } + + // Write simulated network to file if requested + if (fileNameInput.get() != null) { + try (PrintStream ps = new PrintStream(fileNameInput.get())) { + + ps.println(toString().replace("0.0)", "0.0000000001)")); + for (int segIdx = 0; segIdx < nSegments; segIdx++) { + ps.println(segmentTreesInput.get().get(segIdx).getRoot().toNewick() +";"); + } + } catch (FileNotFoundException ex) { + throw new RuntimeException("Error writing to output file '" + + fileNameInput.get() + "'."); + } + } + super.initAndValidate(); + } + + protected void simulateNetwork() { + List sampledIndividuals = new ArrayList<>(); + + // set the starting conditions for simulations + do { + structuredSirEvents = new ArrayList<>(); + // sample initial location + int initLoc = Randomizer.nextInt(states); + S = new int[states]; + I = new int[states]; + R = new int[states]; + for (int i = 0; i < states; i++) { + if (i == initLoc) { + S[i] = populationSizeInput.get().getValue(i) - 1; + I[i] = 1; + R[i] = 0; + } else { + S[i] = populationSizeInput.get().getValue(i); + I[i] = 0; + R[i] = 0; + } + } + int individualNr = 0; + // start the simulations while keeping track of the different individuals + List activeIndividuals = new ArrayList<>(); + sampledIndividuals = new ArrayList<>(); + + StructuredIndividual root = new StructuredIndividual(individualNr, initLoc); + individualNr++; + activeIndividuals.add(root); + double time = 0.0; + System.out.println("start"); + do { + double transmissionRate = 0; + double recoveryRate = 0; + double waningRate = 0; + for (int i = 0; i < states; i++) { + transmissionRate += I[i] * recoveryRateInput.get().getValue(i); + recoveryRate += I[i] * recoveryRateInput.get().getValue(i); + waningRate += R[i] * waningImmunityRateInput.get().getValue(i); + } + double migrationRate = 0; + int c = 0; + for (int i = 0; i < states; i++) { + for (int j = 0; j < states; j++) { + if (i!=j){ + migrationRate += I[i] * migrationRatesInput.get().getValue(c); + c++; + } + } + } + + double totalRate = transmissionRate + recoveryRate + waningRate + migrationRate; + double nextEventTime = Math.log(1.0 / Randomizer.nextDouble()) / totalRate; + + // pick which event happens next based on the rates, by drawing at random + double random = Randomizer.nextDouble() * totalRate; + + if (random < transmissionRate) { +// System.out.println("t"); + ReturnVal rv = transmit(activeIndividuals, transmissionRate, time, nextEventTime, individualNr); + time = rv.time; + individualNr = rv.individualNr; + } else if (random < (transmissionRate + recoveryRate)) { +// System.out.println("r"); + individualNr = recover(activeIndividuals, recoveryRate, time, nextEventTime, individualNr, sampledIndividuals); + }else if (random < (transmissionRate + recoveryRate + migrationRate)){ +// System.out.println("m"); + individualNr = migrate(activeIndividuals, migrationRate, time, nextEventTime, individualNr); + } else { + System.out.println(transmissionRate + recoveryRate); + System.out.println(transmissionRate + " " + recoveryRate + " " + waningRate); + System.out.println(I); +// System.out.println((S + I) ); + System.exit(0); + // waning immunity event + // update the population counts +// R--; +// S++; +// structuredSirEvents.add(new SIREvents(3, time + nextEventTime)); + } + time += nextEventTime; + // log every 1000th iteration 
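+                // Besides the progress report below, verify that the per-deme infected counts I[i]
+                // match the types of the currently active individuals; the run aborts on a mismatch.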
+ if (individualNr % 1000 == 0) + System.out.println(time + " S: " + getSumI(S) + " I: " + getSumI(I) + " R: " + getSumI(R) + " samples: " + sampledIndividuals.size() + " active: " + activeIndividuals.size()); + int[] activeI = new int[I.length]; + for (StructuredIndividual i : activeIndividuals){ + activeI[i.type]++; + } + for (int i = 0; i < states; i++) { + if (I[i] != activeI[i]) { + System.out.println("error"); + System.out.println(I[i] + " " + activeI[i]); + System.exit(0); + } + } + } while (time < simulationTimeInput.get() && getSumI(I) > 0); + System.out.println("start building network from " + sampledIndividuals.size() + " sampled individuals" + " simulation time was " + time); + // build the network from the sampled individuals + }while (sampledIndividuals.size() activeIndividuals, double transmissionRate, + double time, double nextEventTime, int individualNr){ + // pick the location of the transmission event + double randomT = Randomizer.nextDouble(); + double cummulative = 0; + for (int i = 0; i < states; i++){ + cummulative += I[i] * recoveryRateInput.get().getValue(i)/transmissionRate; + if (randomT<=cummulative){ + // pick the number of offsprings from a negative binomial distribution with R and k + // from a gamma and a poisson distribution + double secondary_infections = transmissionRateInput.get().getArrayValue()/recoveryRateInput.get().getArrayValue(i); + double gamma = Randomizer.nextGamma(kInput.get(), kInput.get() / (double) secondary_infections); + int nOffspring = (int) Randomizer.nextPoisson(gamma); + int isR = 0; + // for each offspring, randomly sample if it hits an R + for (int j = 0; j < nOffspring; j++) { + if (Randomizer.nextDouble() < (R[i]+1) / (double) (S[i] + I[i] + R[i])) { + // sample the individual + isR++; + } + } + nOffspring -= isR; + + // don't do anything if there are no offspring + if (nOffspring > 0){ + // transmission event + // pick a random individual to transmit to +// System.out.println("a " + activeIndividuals + " " + i + " " + Arrays.toString(I) + " " + transmissionRate); + StructuredIndividual parent = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size())); + while (parent.type!=i){ + parent = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size())); + } +// System.out.println(i + " "+ activeIndividuals + " "); + + int alreadyInfectedCount = 0; + + // build the superspreading event as a series of regular infection events + for (int j = 0; j < nOffspring; j++) { + // set time for parent + parent.setTime(time + nextEventTime); + // create two new individuals + StructuredIndividual child1 = new StructuredIndividual(individualNr, i); + individualNr++; + StructuredIndividual child2 = new StructuredIndividual(individualNr, i); + individualNr++; + // start the first new lineage + parent.addChild(child1); + child1.addParent(parent); + // add the child to the second parent + parent.addChild(child2); + child2.addParent(parent); + + // add the child to the list of active individuals + activeIndividuals.add(child1); + + // pick if the individual to be infected, assuming an individual is co-infected once at most + double probCoInf = S[i] / (double) (S[i] + I[i]-1-j); + if (Randomizer.nextDouble() > probCoInf) { + // offset for the number of active individuals to draw from + int totOptions = j==0? 
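+                                // (assumed intent) totOptions excludes the lineages appended to
+                                // activeIndividuals earlier in this superspreading event, so the
+                                // co-infection partner below is drawn from pre-existing infections only.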
1 : 2 + j; + + // co-infection event + // pick another active individual that is not the parent, while making sure it is not + // the same as any of the previous individuals in this event + + StructuredIndividual parent2 = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size()-totOptions)); + while (parent2 == parent || parent2.type!=i) { + parent2 = activeIndividuals.get(Randomizer.nextInt(activeIndividuals.size()-totOptions)); + } +// System.out.println(i + " " + parent2); + + parent2.setTime(time + nextEventTime); + + // add the child to the second parent + child2.setTime(time + nextEventTime); + + // create a third child + StructuredIndividual child3 = new StructuredIndividual(individualNr, i); + + individualNr++; + child3.addParent(child2); + child2.addChild(child3); + + child3.addParent(parent2); + parent2.addChild(child3); + + activeIndividuals.add(child3); + activeIndividuals.remove(parent2); + structuredSirEvents.add(new StructuredSIREvents(0, time + nextEventTime, i)); + } else { +// System.out.println("..."); + + // update the population counts + S[i]--; + I[i]++; + activeIndividuals.add(child2); + structuredSirEvents.add(new StructuredSIREvents(1, time + nextEventTime, i)); +// System.out.println("..."); + + } + // remove the parent + activeIndividuals.remove(parent); + parent = child1; + time += 0.0000000001; + } + } + return new ReturnVal(time, individualNr); + } + } + return new ReturnVal(time, individualNr); + } + + private int migrate(List activeIndividuals, + double migrationRate, double time, double nextEventTime, int individualNr){ + // pick the route of migration + double randomMig = Randomizer.nextDouble(); + double cummulative = 0; + int c = 0; + for (int i = 0; i activeIndividuals, double recoveryRate, + double time, double nextEventTime, int individualNr, + List sampledIndividuals){ + // pick the state to recover from + double randomRec = Randomizer.nextDouble(); + double cummulative = 0; + for (int i = 0; i sampledIndividuals) { + // get the time of the most recent sample + double mrsi_time = 0.0; + for (StructuredIndividual individual : sampledIndividuals){ + if(individual.getTime()>mrsi_time){ + mrsi_time = individual.getTime(); + } + } + + System.out.println("mrsi time: " + mrsi_time); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + System.out.println("================="); + lineages = new ArrayList<>(); + + int sampledIndividualNr = 0; + int totalObservedCoal=0; + + List activeIndividuals = new ArrayList<>(); + List activeEdges = new ArrayList<>(); + System.out.println("start"); + while (activeIndividuals.size()>1 || sampledIndividuals.size()>0){ + System.out.println("get next sampling time"); + // get the next sampling event time, while keeping track of the index of the individual + double nextSamplingTime = Double.NEGATIVE_INFINITY; + int nextSamplingIndex = -1; + for (int i=0;inextSamplingTime){ + nextSamplingTime = sampledIndividuals.get(i).getTime(); + nextSamplingIndex = i; + } + } + System.out.println("get next active time"); + // get the activeIndividuals time, while keeping track of the index of the individual + double nextActiveTime = Double.NEGATIVE_INFINITY; + int nextActiveIndex = -1; + for (int i=0;i0) { + if (activeIndividuals.get(i).getParents().size() == 2){ + double nextParentTime = Math.max(activeIndividuals.get(i).getParents().get(0).getTime(), + 
activeIndividuals.get(i).getParents().get(1).getTime()); + if (nextParentTime>=nextActiveTime) { + nextActiveTime = nextParentTime; + nextActiveIndex = i; + } + } else if (activeIndividuals.get(i).getParents().get(0).getTime() > nextActiveTime) { + nextActiveTime = activeIndividuals.get(i).getParents().get(0).getTime(); + nextActiveIndex = i; + } +// else if ((activeIndividuals.get(i).getParents().get(0).getTime() == nextActiveTime)){ +// // if the other individual has two parents, then do nothing +// if (activeIndividuals.get(nextActiveIndex).getParents().size()==2){ +// System.out.println("do nothing");} +// else{ +// System.out.println(activeIndividuals.get(nextActiveIndex).getParents().size()); +// System.out.println(activeIndividuals.get(i).getParents().size()); +// System.out.println("----------------"); +// } +// +// } + } + } + System.out.println(" done " + nextSamplingTime +" " + nextActiveTime + ""); + + + // depending on which one is next + if (nextSamplingTime>nextActiveTime){ + System.out.println("sampling event"); + NetworkNode sampledNode = new NetworkNode(); + sampledNode.setHeight(mrsi_time-nextSamplingTime); + sampledNode.setTaxonLabel("sample_no"+sampledIndividualNr); + sampledNode.setTaxonIndex(sampledIndividualNr); + sampledNode.setMetaData(",type="+sampledIndividuals.get(nextSamplingIndex).type); + sampledIndividualNr++; + + // set segments for the sampled node, where all bits are true + BitSet segments = new BitSet(nSegments); + // set all bits to true + segments.set(0,nSegments); + // make a new edge + NetworkEdge edge = new NetworkEdge(null, sampledNode, segments); + activeEdges.add(edge); + activeIndividuals.add(sampledIndividuals.get(nextSamplingIndex)); + // remove the individual from the list of sampled individuals + sampledIndividuals.remove(nextSamplingIndex); + + lineages.add(new StructuredSIREvents(0, nextSamplingTime, -1)); + }else { +// check if the individual has two parents + if (activeIndividuals.get(nextActiveIndex).getParents().size()==2){ + System.out.println("co-infection event"); + // get the corresponding edge + NetworkEdge edge = activeEdges.get(nextActiveIndex); + // pick randomly which segments go left and which go right + BitSet segmentsLeft = new BitSet(nSegments); + BitSet segmentsRight = new BitSet(nSegments); + for (int i=0;i 0 + if (segmentsLeft.cardinality()>0 && segmentsRight.cardinality()>0){ + // compute the average observation probability + double prob = 0.0; + for (NetworkEdge e : activeEdges){ + prob += 1-2*Math.pow(0.5, e.hasSegments.cardinality()); + } + prob = prob/activeEdges.size(); + + System.out.println("observed reassortment event"); + // make a reassortment node + NetworkNode reassortmentNode = new NetworkNode(); + reassortmentNode.setMetaData(",type="+activeIndividuals.get(nextActiveIndex).type); + reassortmentNode.setHeight(mrsi_time-nextActiveTime); + // set the child edge + edge.parentNode = reassortmentNode; + reassortmentNode.addChildEdge(edge); + // make two new edges + NetworkEdge edgeLeft = new NetworkEdge(null, reassortmentNode, segmentsLeft); + NetworkEdge edgeRight = new NetworkEdge(null, reassortmentNode, segmentsRight); + // add the edges to the activeEdges list + reassortmentNode.addParentEdge(edgeLeft); + reassortmentNode.addParentEdge(edgeRight); + activeEdges.add(edgeLeft); + activeEdges.add(edgeRight); + + edgeLeft.childNode=reassortmentNode; + edgeRight.childNode=reassortmentNode; + + // add the parent individuals to the active individuals + 
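+                        // Backwards in time, a co-infection corresponds to a reassortment event: both
+                        // parents of the co-infected individual become active ancestral lineages, each
+                        // carrying one side of the random segment split made above.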
activeIndividuals.add(activeIndividuals.get(nextActiveIndex).getParents().get(0)); + activeIndividuals.add(activeIndividuals.get(nextActiveIndex).getParents().get(1)); + + // remove the individual from the activeIndividuals list + activeIndividuals.remove(nextActiveIndex); + activeEdges.remove(nextActiveIndex); + lineages.add(new StructuredSIREvents(1, nextActiveTime, prob)); + }else{ + // the event is not observed, just replace the individual with its parent that has cardinality>0 + if (segmentsLeft.cardinality()>0) { + activeIndividuals.set(nextActiveIndex, activeIndividuals.get(nextActiveIndex).getParents().get(0)); + }else{ + activeIndividuals.set(nextActiveIndex, activeIndividuals.get(nextActiveIndex).getParents().get(1)); + } + } + }else{ + System.out.println("coalescent event"); + // get the other child for this coalescent event + StructuredIndividual otherChild = activeIndividuals.get(nextActiveIndex).getParents().get(0).getChildren().get(0); + if (otherChild.equals(activeIndividuals.get(nextActiveIndex))) { + otherChild = activeIndividuals.get(nextActiveIndex).getParents().get(0).getChildren().get(1); + } + // check if the other child is in the activeIndividuals list + if (activeIndividuals.contains(otherChild)) { + totalObservedCoal++; + // make a coalescent node + NetworkNode coalNode = new NetworkNode(); + coalNode.setMetaData(",type="+activeIndividuals.get(nextActiveIndex).type); + + coalNode.setHeight(mrsi_time - nextActiveTime); + // get the two child edges + NetworkEdge childEdge1 = activeEdges.get(nextActiveIndex); + NetworkEdge childEdge2 = activeEdges.get(activeIndividuals.indexOf(otherChild)); + // add the parent node to the edges + childEdge1.parentNode = coalNode; + childEdge2.parentNode = coalNode; + // add the edges to the coalNode + coalNode.addChildEdge(childEdge1); + coalNode.addChildEdge(childEdge2); + // make a new parent edge with the segments being the union of the two child edge segments + BitSet segments = (BitSet) childEdge1.hasSegments.clone(); + segments.or(childEdge2.hasSegments); + NetworkEdge parentEdge = new NetworkEdge(null, coalNode, segments); + // add the parent edge to the activeEdges list + activeEdges.add(parentEdge); + // remove the two child edges from the activeEdges list + activeEdges.remove(childEdge1); + activeEdges.remove(childEdge2); + activeIndividuals.add(activeIndividuals.get(nextActiveIndex).getParents().get(0)); + // remove the two children from the activeIndividuals list + activeIndividuals.remove(nextActiveIndex); + activeIndividuals.remove(otherChild); + lineages.add(new StructuredSIREvents(2, nextActiveTime, -1)); + } else { + // replace the activeIndividuals with its parent + activeIndividuals.set(nextActiveIndex, activeIndividuals.get(nextActiveIndex).getParents().get(0)); + } + } + } + // check if all the individuals in activeIndividuals have different numbers + for (int i=0; i parent; + List child; + + double time; + int number; + + int type; + + + int S, I; + + public StructuredIndividual(int number, int type){ + parent = new ArrayList<>(); + child = new ArrayList<>(); + this.number = number; + this.type = type; + } + + public List getParents(){ + return parent; + } + + public List getChildren(){ + return child; + } + + public double getTime(){ + return time; + } + + public void setTime(double time){ + this.time = time; + } + + public void addParent(StructuredIndividual parent){ + this.parent.add(parent); + } + + public void addChild(StructuredIndividual child){ + this.child.add(child); + } + + // build a method to check 
if two individual objects ar the same + public boolean equals(StructuredIndividual other){ + if (this.number == other.number){ + return true; + }else{ + return false; + } + } + + // define what happens in toString + public String toString(){ + return "" + number + ":" + type; + } + + public void setS(int S){ + this.S = S; + } + public int getS(){ + return S; + } + + public void setI(int I){ + this.I = I; + } + + public int getI(){ + return I; + } + + } + + protected class StructuredSIREvents{ + int eventType; + double time; + double prob = 0.0; + + int[] fromto; + + int type; + + public StructuredSIREvents(int eventType, double time, int type){ + this.eventType = eventType; + this.time = time; + this.type = type; + } + + public StructuredSIREvents(int eventType, double time, double prob){ + this.eventType = eventType; + this.time = time; + this.prob = prob; + } + + public StructuredSIREvents(int eventType, double time, int[] fromto){ + this.eventType = eventType; + this.time = time; + this.fromto = fromto; + } + + + @Override + public String toString() { + if (fromto!=null){ + return eventType + ":" + time + ":" + prob + ":" + fromto[0] +"_" + fromto[1]; + }else{ + return eventType + ":" + time + ":" + prob + ":" + type; + } + } + + + } + + +} + diff --git a/src/coalre/statistics/NetworkStatsLogger.java b/src/coalre/statistics/NetworkStatsLogger.java index dea3add..70fb2ea 100644 --- a/src/coalre/statistics/NetworkStatsLogger.java +++ b/src/coalre/statistics/NetworkStatsLogger.java @@ -135,4 +135,30 @@ public static double getTotalHeight(Network network, double[] rootHeights) { return maxHeight; } + public static double getLociMRCA(Network network){ + double maxHeight = 0.0; + for (int i = 0; i < network.getSegmentCount(); i++){ + double height = getHeightSegmentsRoot(network.getRootEdge(), i); + if (height > maxHeight) + maxHeight = height; + } + return maxHeight; + } + + private static double getHeightSegmentsRoot(NetworkEdge edge, int segment){ + NetworkNode n = edge.childNode; + if (n.isCoalescence()){ + if (n.getChildEdges().get(0).hasSegments.get(segment) + && n.getChildEdges().get(1).hasSegments.get(segment)) { + return n.getHeight(); + }else if (n.getChildEdges().get(0).hasSegments.get(segment)){ + return getHeightSegmentsRoot(n.getChildEdges().get(0), segment); + }else{ + return getHeightSegmentsRoot(n.getChildEdges().get(1), segment); + } + }else{ + return getHeightSegmentsRoot(n.getChildEdges().get(0), segment); + } + } + } diff --git a/src/main/java/org/example/Main.java b/src/main/java/org/example/Main.java new file mode 100644 index 0000000..407f157 --- /dev/null +++ b/src/main/java/org/example/Main.java @@ -0,0 +1,7 @@ +package org.example; + +public class Main { + public static void main(String[] args) { + System.out.println("Hello world!"); + } +} \ No newline at end of file diff --git a/src/module-info.java b/src/module-info.java new file mode 100644 index 0000000..626eea7 --- /dev/null +++ b/src/module-info.java @@ -0,0 +1,8 @@ +/** + * + */ +/** + * + */ +module CoalRe { +} \ No newline at end of file diff --git a/version.xml b/version.xml index 05082f1..bad836b 100644 --- a/version.xml +++ b/version.xml @@ -1,4 +1,4 @@ - + @@ -26,6 +26,9 @@ + + + @@ -33,6 +36,10 @@ + + + +
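Note: the superspreading simulators above draw the number of secondary infections per transmission
event from a negative binomial distribution with mean R0 = transmissionRate/recoveryRate and
dispersion k, realized as a gamma-Poisson mixture. The sketch below is a minimal, illustrative
version of that draw only; the class name is hypothetical, and it assumes (as the code above does)
that Randomizer.nextGamma(shape, rate) returns a gamma variate with the given shape and rate.

import beast.base.util.Randomizer;

public class OffspringDrawSketch {

    /** Draw a NegBin(mean = r0, dispersion = k) offspring count via a gamma-Poisson mixture. */
    static int drawOffspring(double r0, double k) {
        // Gamma with shape k and rate k/r0 has mean r0 and variance r0*r0/k.
        double lambda = Randomizer.nextGamma(k, k / r0);
        // Conditional on lambda the count is Poisson(lambda); marginally it is negative binomial.
        return (int) Randomizer.nextPoisson(lambda);
    }

    public static void main(String[] args) {
        // Example: R0 = 2, k = 0.1 gives a heavily overdispersed offspring distribution.
        for (int i = 0; i < 5; i++)
            System.out.println(drawOffspring(2.0, 0.1));
    }
}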