From 3a76b7d87a99d688fb14c3f2c93473dfa795c35a Mon Sep 17 00:00:00 2001 From: Marc Suchard Date: Fri, 8 Nov 2024 14:41:35 -0800 Subject: [PATCH 01/20] clean messy dependence between unrelated packages --- .../DirichletProcessOperator.java | 13 +- .../DirichletProcessPriorLogger.java | 5 +- .../LineageSpecificBranchModel.java | 214 +++++++++--------- src/dr/inference/model/BoundedSpace.java | 10 +- ...flectiveHamiltonianMonteCarloOperator.java | 5 +- src/dr/math/matrixAlgebra/WrappedVector.java | 12 + 6 files changed, 136 insertions(+), 123 deletions(-) diff --git a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java index 332fd226a1..154477a095 100644 --- a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java +++ b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java @@ -27,6 +27,7 @@ package dr.evomodel.branchmodel.lineagespecific; +import dr.math.matrixAlgebra.WrappedVector; import org.apache.commons.math.MathException; import dr.inference.model.CompoundLikelihood; @@ -244,7 +245,7 @@ private void doOperate() throws MathException { if (DEBUG) { System.out.println("N[-index]: "); - dr.app.bss.Utils.printArray(occupancy); + System.out.println(new WrappedVector.Raw(occupancy)); } Likelihood clusterLikelihood = (Likelihood) likelihood.getLikelihood(index); @@ -288,11 +289,11 @@ private void doOperate() throws MathException { clusterProbs[i] = logprob; }// END: i loop - dr.app.bss.Utils.exponentiate(clusterProbs); + exponentiate(clusterProbs); if (DEBUG) { System.out.println("P(z[index] | z[-index]): "); - dr.app.bss.Utils.printArray(clusterProbs); + System.out.println(new WrappedVector.Raw(clusterProbs)); } // sample @@ -308,6 +309,12 @@ private void doOperate() throws MathException { }// END: doOperate + public static void exponentiate(double[] array) { + for (int i = 0; i < array.length; i++) { + array[i] = Math.exp(array[i]); + } + }// END: exponentiate + @Override public String getOperatorName() { return DirichletProcessOperatorParser.DIRICHLET_PROCESS_OPERATOR; diff --git a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java index def82debb5..d0c2537c26 100644 --- a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java +++ b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java @@ -30,13 +30,12 @@ import java.util.ArrayList; import java.util.List; -import dr.app.bss.Utils; -import dr.inference.distribution.ParametricMultivariateDistributionModel; import dr.inference.loggers.LogColumn; import dr.inference.loggers.Loggable; import dr.inference.loggers.NumberColumn; import dr.inference.model.CompoundParameter; import dr.inference.model.Parameter; +import dr.math.MathUtils; import dr.math.distributions.NormalDistribution; public class DirichletProcessPriorLogger implements Loggable { @@ -112,7 +111,7 @@ private void getNew() { this.categoryProbabilities = getCategoryProbs(); - this.newCategoryIndex = Utils.sample(categoryProbabilities); + this.newCategoryIndex = MathUtils.randomChoicePDF(categoryProbabilities); this.meanForCategory = uniquelyRealizedParameters .getParameterValue(newCategoryIndex); diff --git a/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java b/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java index 811f68f469..0edca6ea3c 100644 --- a/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java +++ b/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java @@ -40,8 +40,8 @@ import dr.evomodel.tree.DefaultTreeModel; import dr.evomodel.treelikelihood.BeagleTreeLikelihood; import dr.evomodel.treelikelihood.PartialsRescalingScheme; -import dr.app.beagle.tools.BeagleSequenceSimulator; -import dr.app.beagle.tools.Partition; +//import dr.app.beagle.tools.BeagleSequenceSimulator; +//import dr.app.beagle.tools.Partition; import dr.evolution.alignment.Alignment; import dr.evolution.alignment.ConvertAlignment; import dr.evolution.datatype.Codons; @@ -204,111 +204,111 @@ protected void acceptState() { // }// END: acceptState - public static void main(String[] args) { - - try { - - // the seed of the BEAST - MathUtils.setSeed(666); - - // create tree - NewickImporter importer = new NewickImporter( - "(SimSeq1:73.7468,(SimSeq2:25.256989999999995,SimSeq3:45.256989999999995):18.48981);"); - TreeModel tree = new DefaultTreeModel(importer.importTree(null)); - - // create site model - GammaSiteRateModel siteRateModel = new GammaSiteRateModel( - "siteModel"); - - // create branch rate model - BranchRateModel branchRateModel = new DefaultBranchRateModel(); - - int sequenceLength = 10; - ArrayList partitionsList = new ArrayList(); - - // create Frequency Model - Parameter freqs = new Parameter.Default(new double[]{ - 0.0163936, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344 // - }); - FrequencyModel freqModel = new FrequencyModel(Codons.UNIVERSAL, - freqs); - - // create substitution model - Parameter alpha = new Parameter.Default(1, 10); - Parameter beta = new Parameter.Default(1, 5); - MG94HKYCodonModel mg94 = new MG94K80CodonModel(Codons.UNIVERSAL, alpha, beta, freqModel, new CodonOptions()); - - HomogeneousBranchModel substitutionModel = new HomogeneousBranchModel(mg94); - - // create partition - Partition partition1 = new Partition(tree, // - substitutionModel,// - siteRateModel, // - branchRateModel, // - freqModel, // - 0, // from - sequenceLength - 1, // to - 1 // every - ); - - partitionsList.add(partition1); - - // feed to sequence simulator and generate data - BeagleSequenceSimulator simulator = new BeagleSequenceSimulator( - partitionsList); - - Alignment alignment = simulator.simulate(false, false); - - ConvertAlignment convert = new ConvertAlignment(Nucleotides.INSTANCE, - GeneticCode.UNIVERSAL, alignment); - - - List substModels = new ArrayList(); - for (int i = 0; i < 2; i++) { -// alpha = new Parameter.Default(1, 10 ); -// beta = new Parameter.Default(1, 5 ); -// mg94 = new MG94HKYCodonModel(Codons.UNIVERSAL, alpha, beta, -// freqModel); - substModels.add(mg94); - } - - Parameter uCategories = new Parameter.Default(2, 0); -// CountableBranchCategoryProvider provider = new CountableBranchCategoryProvider.IndependentBranchCategoryModel(tree, uCategories); - - LineageSpecificBranchModel branchSpecific = new LineageSpecificBranchModel(tree, freqModel, substModels, //provider, - uCategories); - - BeagleTreeLikelihood like = new BeagleTreeLikelihood(convert, // - tree, // - branchSpecific, // - siteRateModel, // - branchRateModel, // - null, // - false, // - PartialsRescalingScheme.DEFAULT, true); - - BeagleTreeLikelihood gold = new BeagleTreeLikelihood(convert, // - tree, // - substitutionModel, // - siteRateModel, // - branchRateModel, // - null, // - false, // - PartialsRescalingScheme.DEFAULT, true); - - System.out.println("likelihood (gold) = " + gold.getLogLikelihood()); - System.out.println("likelihood = " + like.getLogLikelihood()); - - } catch (Exception e) { - e.printStackTrace(); - } - - }// END: main +// public static void main(String[] args) { +// +// try { +// +// // the seed of the BEAST +// MathUtils.setSeed(666); +// +// // create tree +// NewickImporter importer = new NewickImporter( +// "(SimSeq1:73.7468,(SimSeq2:25.256989999999995,SimSeq3:45.256989999999995):18.48981);"); +// TreeModel tree = new DefaultTreeModel(importer.importTree(null)); +// +// // create site model +// GammaSiteRateModel siteRateModel = new GammaSiteRateModel( +// "siteModel"); +// +// // create branch rate model +// BranchRateModel branchRateModel = new DefaultBranchRateModel(); +// +// int sequenceLength = 10; +// ArrayList partitionsList = new ArrayList(); +// +// // create Frequency Model +// Parameter freqs = new Parameter.Default(new double[]{ +// 0.0163936, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344 // +// }); +// FrequencyModel freqModel = new FrequencyModel(Codons.UNIVERSAL, +// freqs); +// +// // create substitution model +// Parameter alpha = new Parameter.Default(1, 10); +// Parameter beta = new Parameter.Default(1, 5); +// MG94HKYCodonModel mg94 = new MG94K80CodonModel(Codons.UNIVERSAL, alpha, beta, freqModel, new CodonOptions()); +// +// HomogeneousBranchModel substitutionModel = new HomogeneousBranchModel(mg94); +// +// // create partition +// Partition partition1 = new Partition(tree, // +// substitutionModel,// +// siteRateModel, // +// branchRateModel, // +// freqModel, // +// 0, // from +// sequenceLength - 1, // to +// 1 // every +// ); +// +// partitionsList.add(partition1); +// +// // feed to sequence simulator and generate data +// BeagleSequenceSimulator simulator = new BeagleSequenceSimulator( +// partitionsList); +// +// Alignment alignment = simulator.simulate(false, false); +// +// ConvertAlignment convert = new ConvertAlignment(Nucleotides.INSTANCE, +// GeneticCode.UNIVERSAL, alignment); +// +// +// List substModels = new ArrayList(); +// for (int i = 0; i < 2; i++) { +//// alpha = new Parameter.Default(1, 10 ); +//// beta = new Parameter.Default(1, 5 ); +//// mg94 = new MG94HKYCodonModel(Codons.UNIVERSAL, alpha, beta, +//// freqModel); +// substModels.add(mg94); +// } +// +// Parameter uCategories = new Parameter.Default(2, 0); +//// CountableBranchCategoryProvider provider = new CountableBranchCategoryProvider.IndependentBranchCategoryModel(tree, uCategories); +// +// LineageSpecificBranchModel branchSpecific = new LineageSpecificBranchModel(tree, freqModel, substModels, //provider, +// uCategories); +// +// BeagleTreeLikelihood like = new BeagleTreeLikelihood(convert, // +// tree, // +// branchSpecific, // +// siteRateModel, // +// branchRateModel, // +// null, // +// false, // +// PartialsRescalingScheme.DEFAULT, true); +// +// BeagleTreeLikelihood gold = new BeagleTreeLikelihood(convert, // +// tree, // +// substitutionModel, // +// siteRateModel, // +// branchRateModel, // +// null, // +// false, // +// PartialsRescalingScheme.DEFAULT, true); +// +// System.out.println("likelihood (gold) = " + gold.getLogLikelihood()); +// System.out.println("likelihood = " + like.getLogLikelihood()); +// +// } catch (Exception e) { +// e.printStackTrace(); +// } +// +// }// END: main @Override public Citation.Category getCategory() { diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 458dfd3143..22dcc609ff 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -27,13 +27,9 @@ package dr.inference.model; -import dr.app.bss.Utils; import dr.inference.operators.hmc.HamiltonianMonteCarloOperator; import dr.math.MathUtils; -import dr.math.matrixAlgebra.EJMLUtils; -import dr.math.matrixAlgebra.IllegalDimension; -import dr.math.matrixAlgebra.Matrix; -import dr.math.matrixAlgebra.SymmetricMatrix; +import dr.math.matrixAlgebra.*; import org.ejml.data.Complex64F; import org.ejml.data.DenseMatrix64F; import org.ejml.factory.DecompositionFactory; @@ -192,7 +188,7 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { System.out.println("Raw matrix to decompose: "); System.out.println(CinvV); System.out.print("Raw eigenvalues: "); - Utils.printArray(values); + System.out.println(new WrappedVector.Raw(values)); } for (int i = 0; i < values.length; i++) { values[i] = 1 / values[i]; @@ -282,7 +278,7 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); SymmetricMatrix X = compoundSymmetricMatrix(0.0, direction, dim); System.out.print("Eigenvalues: "); - Utils.printArray(values); + System.out.println(new WrappedVector.Raw(values)); Matrix S = new SymmetricMatrix(dim, dim); Matrix T = new SymmetricMatrix(dim, dim); diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index 052d6e3a33..a86a13e9c3 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -28,7 +28,6 @@ package dr.inference.operators.hmc; -import dr.app.bss.Utils; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.*; import dr.inference.operators.AdaptationMode; @@ -340,7 +339,7 @@ public boolean doReflection(double[] position, WrappedVector momentum) { if (DEBUG) { System.out.println("time: " + eventTime); System.out.print("start: "); - Utils.printArray(position); + System.out.println(new WrappedVector.Raw(position)); System.out.println(momentum); } @@ -348,7 +347,7 @@ public boolean doReflection(double[] position, WrappedVector momentum) { if (DEBUG) { System.out.print("end: "); - Utils.printArray(position); + System.out.println(new WrappedVector.Raw(position)); System.out.println(momentum); } diff --git a/src/dr/math/matrixAlgebra/WrappedVector.java b/src/dr/math/matrixAlgebra/WrappedVector.java index 1a4923ad0f..4655b8233e 100644 --- a/src/dr/math/matrixAlgebra/WrappedVector.java +++ b/src/dr/math/matrixAlgebra/WrappedVector.java @@ -86,6 +86,18 @@ public Raw(double[] buffer) { this(buffer, 0, buffer.length); } + public Raw(int[] in) { + this(convert(in), 0, in.length); + } + + private static double[] convert(int[] in) { + double[] buffer = new double[in.length]; + for (int i = 0; i < in.length; ++i) { + buffer[i] = in[i]; + } + return buffer; + } + @Override final public double get(final int i) { return buffer[offset + i]; From 9ebca49796712f7831b007cfd822f1f8d0e830dd Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sat, 9 Nov 2024 08:43:36 +0000 Subject: [PATCH 02/20] added feature to compare the target tree with a reference tree (i.e., the true tree or similar) --- .../app/tools/treeannotator/CladeSystem.java | 11 ++ .../tools/treeannotator/TreeAnnotator.java | 143 ++++++++++++------ 2 files changed, 111 insertions(+), 43 deletions(-) diff --git a/src/dr/app/tools/treeannotator/CladeSystem.java b/src/dr/app/tools/treeannotator/CladeSystem.java index 22b8eddd30..dca9b3081d 100644 --- a/src/dr/app/tools/treeannotator/CladeSystem.java +++ b/src/dr/app/tools/treeannotator/CladeSystem.java @@ -323,6 +323,16 @@ public int getCladeCount() { return cladeMap.keySet().size(); } + public int getCommonCladeCount(CladeSystem referenceCladeSystem) { + int count = 0; + for (Object key : cladeMap.keySet()) { + if (referenceCladeSystem.cladeMap.keySet().contains(key)) { + count ++; + } + } + return count; + } + // // Private stuff // @@ -333,4 +343,5 @@ public int getCladeCount() { private final Map cladeMap = new HashMap<>(); Clade rootClade; + } diff --git a/src/dr/app/tools/treeannotator/TreeAnnotator.java b/src/dr/app/tools/treeannotator/TreeAnnotator.java index 9d08573f08..763a8d9cfc 100644 --- a/src/dr/app/tools/treeannotator/TreeAnnotator.java +++ b/src/dr/app/tools/treeannotator/TreeAnnotator.java @@ -121,6 +121,7 @@ public TreeAnnotator(final int burninTrees, boolean computeESS, Target targetOption, String targetTreeFileName, + String referenceTreeFileName, String inputFileName, String outputFileName ) throws IOException { @@ -146,32 +147,32 @@ public TreeAnnotator(final int burninTrees, // read the clades in even if a target tree so it can have its stats reported // if (targetOption != Target.USER_TARGET_TREE) { - // if we are not just annotating a specific target tree - // then we need to read all the trees into a CladeSystem - // to get Clade and SubTree frequencies. - if (COUNT_TREES) { - countTrees(inputFileName); - progressStream.println("Reading trees..."); - } else { - totalTrees = 10000; - progressStream.println("Reading trees (assuming 10,000 trees)..."); - } + // if we are not just annotating a specific target tree + // then we need to read all the trees into a CladeSystem + // to get Clade and SubTree frequencies. + if (COUNT_TREES) { + countTrees(inputFileName); + progressStream.println("Reading trees..."); + } else { + totalTrees = 10000; + progressStream.println("Reading trees (assuming 10,000 trees)..."); + } - burnin = readTrees(inputFileName, burninTrees, burninStates, cladeSystem); + burnin = readTrees(inputFileName, burninTrees, burninStates, cladeSystem); - cladeSystem.calculateCladeCredibilities(totalTreesUsed); + cladeSystem.calculateCladeCredibilities(totalTreesUsed); - progressStream.println("Total trees read: " + totalTrees); - progressStream.println("Size of trees: " + taxa.getTaxonCount() + " tips"); - if (burninTrees > 0) { - progressStream.println("Ignoring first " + burninTrees + " trees" + - (burninStates > 0 ? " (" + burninStates + " states)." : "." )); - } else if (burninStates > 0) { - progressStream.println("Ignoring first " + burninStates + " states (" + burnin + " trees)."); - } + progressStream.println("Total trees read: " + totalTrees); + progressStream.println("Size of trees: " + taxa.getTaxonCount() + " tips"); + if (burninTrees > 0) { + progressStream.println("Ignoring first " + burninTrees + " trees" + + (burninStates > 0 ? " (" + burninStates + " states)." : "." )); + } else if (burninStates > 0) { + progressStream.println("Ignoring first " + burninStates + " states (" + burnin + " trees)."); + } - progressStream.println("Total unique clades: " + cladeSystem.getCladeCount()); - progressStream.println(); + progressStream.println("Total unique clades: " + cladeSystem.getCladeCount()); + progressStream.println(); // } MutableTree targetTree = null; @@ -179,7 +180,7 @@ public TreeAnnotator(final int burninTrees, switch (targetOption) { case USER_TARGET_TREE: { if (targetTreeFileName != null) { - targetTree = readUserTargetTree(targetTreeFileName, targetTree, cladeSystem); + targetTree = readUserTargetTree(targetTreeFileName, cladeSystem); } else { System.err.println("No user target tree specified."); System.exit(1); @@ -199,11 +200,21 @@ public TreeAnnotator(final int burninTrees, default: throw new IllegalArgumentException("Unknown targetOption"); } - // Help garbage collector - cladeSystem = null; - CladeSystem targetCladeSystem = new CladeSystem(targetTree); + + if (referenceTreeFileName != null) { + progressStream.println("Reading reference tree: " + referenceTreeFileName); + + MutableTree referenceTree = readTreeFile(referenceTreeFileName); + CladeSystem referenceCladeSystem = new CladeSystem(referenceTree); + + int commonCladeCount = targetCladeSystem.getCommonCladeCount(referenceCladeSystem); + progressStream.println("Clades in common with reference tree: " + commonCladeCount + + " (out of " + referenceCladeSystem.getCladeCount() + ")"); + progressStream.println(); + } + collectNodeAttributes(targetCladeSystem, inputFileName, burnin); annotateTargetTree(targetCladeSystem, heightsOption, targetTree); @@ -394,33 +405,40 @@ public void setupAttributes(Tree tree) { annotationAction.addAttributeNames(attributeNames); } - private MutableTree readUserTargetTree(String targetTreeFileName, MutableTree targetTree, CladeSystem cladeSystem) throws IOException { + private MutableTree readUserTargetTree(String targetTreeFileName, CladeSystem cladeSystem) throws IOException { progressStream.println("Reading user specified target tree, " + targetTreeFileName + ", ..."); - NexusImporter importer = new NexusImporter(new FileReader(targetTreeFileName)); + MutableTree targetTree = readTreeFile(targetTreeFileName); + + progressStream.println(); + double score = scoreTree(targetTree, cladeSystem); + progressStream.println("Target tree's log clade credibility: " + String.format("%.4f", score)); + reportStatistics(cladeSystem, targetTree); + reportStatisticTables(cladeSystem, targetTree); + + progressStream.println(); + return targetTree; + } + + + private static MutableTree readTreeFile(String treeFileName) throws IOException { + NexusImporter importer = new NexusImporter(new FileReader(treeFileName)); + Tree tree = null; try { - Tree tree = importer.importNextTree(); + tree = importer.importNextTree(); if (tree == null) { - NewickImporter x = new NewickImporter(new FileReader(targetTreeFileName)); + NewickImporter x = new NewickImporter(new FileReader(treeFileName)); tree = x.importNextTree(); } if (tree == null) { - System.err.println("No tree in target nexus or newick file " + targetTreeFileName); + System.err.println("No tree in nexus or newick file " + treeFileName); System.exit(1); } - targetTree = new FlexibleTree(tree); } catch (Importer.ImportException e) { System.err.println("Error Parsing Target Tree: " + e.getMessage()); System.exit(1); } - - progressStream.println(); - double score = scoreTree(targetTree, cladeSystem); - progressStream.println("Target tree's log clade credibility: " + String.format("%.4f", score)); - reportStatistics(cladeSystem, targetTree); - - progressStream.println(); - return targetTree; + return new FlexibleTree(tree); } private Tree getMCCTree(int burnin, CladeSystem cladeSystem, String inputFileName) @@ -472,6 +490,7 @@ private Tree getMCCTree(int burnin, CladeSystem cladeSystem, String inputFileNam progressStream.println("Best tree: " + bestTree.getId() + " (tree number " + bestTreeNumber + ")"); progressStream.println("Best tree's log clade credibility: " + String.format("%.4f", bestScore)); reportStatistics(cladeSystem, bestTree); + reportStatisticTables(cladeSystem, bestTree); progressStream.println(); return bestTree; @@ -490,6 +509,7 @@ private MutableTree getHIPSTRTree(CladeSystem cladeSystem) { progressStream.println(); progressStream.println("HIPSTR tree's log clade credibility: " + String.format("%.4f", score)); reportStatistics(cladeSystem, tree); +// reportStatisticTables(cladeSystem, tree); progressStream.println(); return tree; @@ -505,7 +525,26 @@ private static void reportStatistics(CladeSystem cladeSystem, Tree tree) { progressStream.println("Number of clades with credibility > 0.95: " + cladeSystem.getTopCladeCredibility(tree, 0.95) + " (out of " + cladeSystem.getTopCladeCredibility(0.95) + " in all trees)"); progressStream.println("Number of clades with credibility > 0.5: " + cladeSystem.getTopCladeCredibility(tree, 0.5) + - " (out of " + cladeSystem.getTopCladeCredibility(0.5) + " in all trees)"); + " (out of " + cladeSystem.getTopCladeCredibility(0.5) + " in all trees)"); + } + private static void reportStatisticTables(CladeSystem cladeSystem, Tree tree) { + int count = 100; +// double[] table = new double[count + 1]; +// for (int i = 0; i <= count; i++) { +// double threshold = ((double) (i)) / count; +// table[i] = cladeSystem.getTopCladeCredibility(tree, threshold); +// } + + progressStream.println("threshold, #clades"); + for (int i = 0; i <= count; i++) { + double threshold = ((double) (i)) / count; + progressStream.print(threshold); + progressStream.print(","); + progressStream.print(cladeSystem.getTopCladeCredibility(tree, threshold)); + progressStream.print(","); + progressStream.println(cladeSystem.getTopCladeCredibility(threshold)); + } + } private void annotateTargetTree(CladeSystem cladeSystem, HeightsSummary heightsOption, MutableTree targetTree) { @@ -624,6 +663,7 @@ public static void main(String[] args) throws IOException { Locale.setDefault(Locale.US); String targetTreeFileName = null; + String referenceTreeFileName = null; String inputFileName = null; String outputFileName = null; @@ -704,6 +744,7 @@ public static void main(String[] args) throws IOException { computeESS, targetOption, targetTreeFileName, + referenceTreeFileName, inputFileName, outputFileName); @@ -733,6 +774,7 @@ public static void main(String[] args) throws IOException { new Arguments.IntegerOption("burninTrees", "the number of trees to be considered as 'burn-in'"), new Arguments.RealOption("limit", "the minimum posterior probability for a node to be annotated"), new Arguments.StringOption("target", "target_file_name", "specifies a user target tree to be annotated"), + new Arguments.StringOption("reference", "tree_file_name", "specifies a reference tree for sampled trees to be compared with"), new Arguments.Option("help", "option to print this message"), new Arguments.Option("forceDiscrete", "forces integer traits to be treated as discrete traits."), new Arguments.StringOption("hpd2D", "the HPD interval to be used for the bivariate traits", "specifies a (vector of comma separated) HPD proportion(s)"), @@ -813,6 +855,10 @@ public static void main(String[] args) throws IOException { targetTreeFileName = arguments.getStringOption("target"); } + if (arguments.hasOption("reference")) { + referenceTreeFileName = arguments.getStringOption("reference"); + } + final String[] args2 = arguments.getLeftoverArguments(); switch (args2.length) { @@ -830,7 +876,18 @@ public static void main(String[] args) throws IOException { } } - new TreeAnnotator(burninTrees, burninStates, heights, posteriorLimit, hpd2D, computeESS, target, targetTreeFileName, inputFileName, outputFileName); + new TreeAnnotator( + burninTrees, + burninStates, + heights, + posteriorLimit, + hpd2D, + computeESS, + target, + targetTreeFileName, + referenceTreeFileName, + inputFileName, + outputFileName); if (target == Target.MAX_CLADE_CREDIBILITY) { progressStream.println("Found Maximum Clade Credibility (MCC) tree - citation: " + @@ -838,7 +895,7 @@ public static void main(String[] args) throws IOException { } else if (target == Target.HIPSTR) { progressStream.println("Constructed Highest Independent Posterior Sub-Tree Reconstruction (HIPSTR) tree - citation: In prep."); } else if (target == Target.USER_TARGET_TREE) { - progressStream.println("Loaded user target tree."); +// progressStream.println("Loaded user target tree."); } if (heights == HeightsSummary.CA_HEIGHTS) { From f49f0ca0789538246663793a71f57a9cceee00be Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sat, 9 Nov 2024 08:45:02 +0000 Subject: [PATCH 03/20] Commenting the table report out for now --- src/dr/app/tools/treeannotator/TreeAnnotator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/app/tools/treeannotator/TreeAnnotator.java b/src/dr/app/tools/treeannotator/TreeAnnotator.java index 763a8d9cfc..e6c8a45cd5 100644 --- a/src/dr/app/tools/treeannotator/TreeAnnotator.java +++ b/src/dr/app/tools/treeannotator/TreeAnnotator.java @@ -490,7 +490,7 @@ private Tree getMCCTree(int burnin, CladeSystem cladeSystem, String inputFileNam progressStream.println("Best tree: " + bestTree.getId() + " (tree number " + bestTreeNumber + ")"); progressStream.println("Best tree's log clade credibility: " + String.format("%.4f", bestScore)); reportStatistics(cladeSystem, bestTree); - reportStatisticTables(cladeSystem, bestTree); +// reportStatisticTables(cladeSystem, bestTree); progressStream.println(); return bestTree; From 2bcf2bc625ab8e797fdc7efc0f61f97df2519678 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Mon, 11 Nov 2024 22:16:25 +0100 Subject: [PATCH 04/20] missing case statement --- .../MarginalLikelihoodEstimationGenerator.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java index 8838f9cc06..21f59d727a 100644 --- a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java +++ b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java @@ -1025,8 +1025,11 @@ public void writeMLE(XMLWriter writer, MarginalLikelihoodEstimationOptions optio writer.writeIDref(RandomLocalClockModelParser.LOCAL_BRANCH_RATES, model.getPrefix() + BranchRateModel.BRANCH_RATES); break; + case MIXED_EFFECTS_CLOCK: + break; + default: - throw new IllegalArgumentException("Unknown clock model"); + throw new IllegalArgumentException("Unknown clock model: " + model.getClockType()); } } From e893e70a9e7c56902d9a8e6369c608406cc10117 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sat, 7 Dec 2024 19:24:00 +0100 Subject: [PATCH 05/20] relative rate bug fix for AA models --- src/dr/app/beauti/generator/SubstitutionModelGenerator.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java index 368448d91f..b972cec104 100644 --- a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java +++ b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java @@ -753,7 +753,7 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel if (options.useNuRelativeRates()) { Parameter parameter = model.getParameter("nu"); String prefix1 = options.getPrefix(); - if (!parameter.getSubParameters().isEmpty()) { + if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) { writeNuRelativeRateBlock(writer, prefix1, parameter); } } else { @@ -802,7 +802,9 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model if (options.useNuRelativeRates()) { Parameter parameter = model.getParameter("nu"); - writeNuRelativeRateBlock(writer, prefix, parameter); + if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) { + writeNuRelativeRateBlock(writer, prefix, parameter); + } } else { writeParameter(SiteModelParser.RELATIVE_RATE, "mu", model, writer); } From 639edb606ce677345b9be64aabfa7269633ea288 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sun, 8 Dec 2024 22:17:16 +0100 Subject: [PATCH 06/20] fixing working prior for mixed effect clock model --- .../MarginalLikelihoodEstimationGenerator.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java index 21f59d727a..8b19fdcfde 100644 --- a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java +++ b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java @@ -37,6 +37,8 @@ import dr.evolution.util.Taxa; import dr.evolution.util.Units; import dr.evomodel.branchratemodel.BranchRateModel; +import dr.evomodel.branchratemodel.BranchSpecificFixedEffects; +import dr.inference.distribution.DistributionLikelihood; import dr.evomodel.tree.DefaultTreeModel; import dr.evomodelxml.TreeWorkingPriorParsers; import dr.evomodelxml.branchratemodel.*; @@ -1026,6 +1028,20 @@ public void writeMLE(XMLWriter writer, MarginalLikelihoodEstimationOptions optio break; case MIXED_EFFECTS_CLOCK: + + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.RATES_PRIOR); + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.SCALE_PRIOR); + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.INTERCEPT_PRIOR); + + String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; + int number = 1; + String concat = coeff + number; + while (model.hasParameter(concat)) { + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); + number++; + concat = coeff + number; + } + break; default: From d911c6290131eb32199c4674993505770bff5051 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sun, 8 Dec 2024 22:19:52 +0100 Subject: [PATCH 07/20] fixing writeBranchRatesModelRef for mixed effects clock model --- .../app/beauti/generator/ClockModelGenerator.java | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/dr/app/beauti/generator/ClockModelGenerator.java b/src/dr/app/beauti/generator/ClockModelGenerator.java index 093411b4e5..af0c9be6b8 100644 --- a/src/dr/app/beauti/generator/ClockModelGenerator.java +++ b/src/dr/app/beauti/generator/ClockModelGenerator.java @@ -30,7 +30,6 @@ import dr.app.beauti.components.ComponentFactory; import dr.app.beauti.options.*; import dr.app.beauti.types.ClockType; -import dr.app.beauti.types.OperatorType; import dr.app.beauti.util.XMLWriter; import dr.evolution.util.Taxa; import dr.evomodel.branchratemodel.ArbitraryBranchRates; @@ -66,8 +65,6 @@ import dr.util.Attribute; import dr.xml.XMLParser; -import java.util.Map; - import static dr.inference.model.ParameterParser.PARAMETER; import static dr.inferencexml.distribution.PriorParsers.*; import static dr.inferencexml.distribution.shrinkage.BayesianBridgeLikelihoodParser.*; @@ -958,18 +955,18 @@ public static void writeBranchRatesModelRef(PartitionClockModel model, XMLWriter case MIXED_EFFECTS_CLOCK: //always write distribution likelihoods for rate, scale and intercept - writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR); - writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR); - writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR); + //writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR); + //writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR); + //writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR); //check for coefficients - String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; + /*String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; int number = 1; String concat = coeff + number; while (model.hasParameter(concat)) { writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); number++; concat = coeff + number; - } + }*/ tag = ArbitraryBranchRatesParser.ARBITRARY_BRANCH_RATES; id = model.getPrefix() + ArbitraryBranchRates.BRANCH_RATES; break; From fab2e581428dec5caada6e8297de654194ffe748 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sun, 8 Dec 2024 22:56:10 +0100 Subject: [PATCH 08/20] intermediate commit for trying to fix parameter prior generation - but partition clock models are no longer there by the time we reach writeParameterPrior --- .../generator/ParameterPriorGenerator.java | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/dr/app/beauti/generator/ParameterPriorGenerator.java b/src/dr/app/beauti/generator/ParameterPriorGenerator.java index 339932202e..501ec7c538 100644 --- a/src/dr/app/beauti/generator/ParameterPriorGenerator.java +++ b/src/dr/app/beauti/generator/ParameterPriorGenerator.java @@ -35,13 +35,13 @@ import dr.evolution.util.Taxa; import dr.evomodel.branchratemodel.BranchSpecificFixedEffects; import dr.evomodel.tree.DefaultTreeModel; +import dr.evomodelxml.branchratemodel.BranchSpecificFixedEffectsParser; import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser; import dr.evomodelxml.tree.CTMCScalePriorParser; import dr.evomodelxml.tree.MonophylyStatisticParser; import dr.inference.distribution.DistributionLikelihood; import dr.inference.model.ParameterParser; import dr.inferencexml.distribution.CachedDistributionLikelihoodParser; -import dr.inferencexml.distribution.DistributionLikelihoodParser; import dr.inferencexml.distribution.PriorParsers; import dr.inferencexml.model.BooleanLikelihoodParser; import dr.inferencexml.model.OneOnXPriorParser; @@ -58,19 +58,41 @@ */ public class ParameterPriorGenerator extends Generator { - //map parameters to prior IDs, for use with HMC - private HashMap mapParameterToPrior; + //map parameters to prior IDs, for use with HMC or other approaches that define their prior befor the XML block + private final HashMap mapParameterToPrior; public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] components) { super(options, components); //TODO don't like this being here, but will see how things pan out as more HMC approaches are added + int totalModels = options.getPartitionClockModels().size(); + List partitionClockModels = options.getPartitionClockModels(); mapParameterToPrior = new HashMap(); //HMC skygrid mapParameterToPrior.put(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION_PRIOR); //HMC relaxed clock - mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, BranchSpecificFixedEffects.LOCATION_PRIOR); - mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, BranchSpecificFixedEffects.RATES_PRIOR); - mapParameterToPrior.put(ClockType.HMCLN_SCALE, BranchSpecificFixedEffects.SCALE_PRIOR); + for (int i = 0; i < totalModels; i++) { + String prefix = partitionClockModels.get(i).getPrefix(); + mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.LOCATION_PRIOR); + mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, prefix + BranchSpecificFixedEffects.RATES_PRIOR); + mapParameterToPrior.put(ClockType.HMCLN_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR); + } + //mixed effects clock + //always write distribution likelihoods for rate, scale and intercept + for (int i = 0; i < totalModels; i++) { + String prefix = partitionClockModels.get(i).getPrefix(); + mapParameterToPrior.put(ClockType.ME_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.RATES_PRIOR); + mapParameterToPrior.put(ClockType.ME_CLOCK_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR); + mapParameterToPrior.put(BranchSpecificFixedEffectsParser.INTERCEPT, prefix + BranchSpecificFixedEffects.INTERCEPT_PRIOR); + //check for coefficients + String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; + int number = 1; + String concat = coeff + number; + while (partitionClockModels.get(i).hasParameter(concat)) { + mapParameterToPrior.put(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, prefix + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); + number++; + concat = coeff + number; + } + } } /** From 1f36ec6eb12df2ba0dddcebc635509e1b162a0be Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Mon, 9 Dec 2024 11:56:28 +0100 Subject: [PATCH 09/20] fix parameter prior generation for mixed effects clock model --- .../beauti/generator/ParameterPriorGenerator.java | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/dr/app/beauti/generator/ParameterPriorGenerator.java b/src/dr/app/beauti/generator/ParameterPriorGenerator.java index 501ec7c538..aa8c0613bb 100644 --- a/src/dr/app/beauti/generator/ParameterPriorGenerator.java +++ b/src/dr/app/beauti/generator/ParameterPriorGenerator.java @@ -64,9 +64,16 @@ public class ParameterPriorGenerator extends Generator { public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] components) { super(options, components); //TODO don't like this being here, but will see how things pan out as more HMC approaches are added + mapParameterToPrior = new HashMap(); + } + + /** + * Add all possibly previously defined priors to a HashMap + * Cannot be done in constructor as the models have not been defined by the user at that point + */ + public void addParametersToPrior() { int totalModels = options.getPartitionClockModels().size(); List partitionClockModels = options.getPartitionClockModels(); - mapParameterToPrior = new HashMap(); //HMC skygrid mapParameterToPrior.put(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION_PRIOR); //HMC relaxed clock @@ -88,7 +95,7 @@ public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] compone int number = 1; String concat = coeff + number; while (partitionClockModels.get(i).hasParameter(concat)) { - mapParameterToPrior.put(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, prefix + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); + mapParameterToPrior.put(concat, prefix + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); number++; concat = coeff + number; } @@ -101,6 +108,10 @@ public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] compone * @param writer the writer */ public void writeParameterPriors(XMLWriter writer) { + + //first make sure that all possibly previously defined priors are part of the HashMap + addParametersToPrior(); + boolean first = true; for (Map.Entry taxaBooleanEntry : options.taxonSetsMono.entrySet()) { From f89cb9421529ddc3b7a36dd0f095697b8408d976 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Mon, 9 Dec 2024 15:19:28 +0100 Subject: [PATCH 10/20] add HMC relaxed clock descriptions as String constants --- src/dr/app/beauti/types/ClockType.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dr/app/beauti/types/ClockType.java b/src/dr/app/beauti/types/ClockType.java index 09029542d9..e88d521d85 100644 --- a/src/dr/app/beauti/types/ClockType.java +++ b/src/dr/app/beauti/types/ClockType.java @@ -68,4 +68,7 @@ public String toString() { final public static String ACLD_MEAN = "acld.mean"; final public static String ACLD_STDEV = "acld.stdev"; + + final public static String HMC_CLOCK_RATES_DESCRIPTION = "HMC relaxed clock branch rates"; + final public static String HMC_CLOCK_LOCATION_SCALE_DESCRIPTION = "HMC relaxed clock location and scale"; } \ No newline at end of file From 3f46b69c1b5fd9372cc57aba2d204908fdcbff4b Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Mon, 9 Dec 2024 15:19:47 +0100 Subject: [PATCH 11/20] actually make use of said String constants --- src/dr/app/beauti/options/PartitionClockModel.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dr/app/beauti/options/PartitionClockModel.java b/src/dr/app/beauti/options/PartitionClockModel.java index dd2d6452f0..397812f069 100644 --- a/src/dr/app/beauti/options/PartitionClockModel.java +++ b/src/dr/app/beauti/options/PartitionClockModel.java @@ -130,7 +130,7 @@ public void initModelParametersAndOpererators() { .initial(1.0).mean(1.0).offset(0.0).partitionOptions(this).isPriorFixed(true) .isAdaptiveMultivariateCompatible(false).build(parameters); - new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, "HMC relaxed clock branch rates") + new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, ClockType.HMC_CLOCK_RATES_DESCRIPTION) .prior(PriorType.LOGNORMAL_HPM_PRIOR).initial(0.001).isNonNegative(true) .partitionOptions(this).isPriorFixed(true) .isAdaptiveMultivariateCompatible(false).build(parameters); @@ -216,11 +216,11 @@ public void initModelParametersAndOpererators() { createScaleOperator(ClockType.UCGD_SHAPE, demoTuning, rateWeights); //HMC relaxed clock - createOperator("HMCRCR", "HMC relaxed clock branch rates", + createOperator("HMCRCR", ClockType.HMC_CLOCK_RATES_DESCRIPTION, "Hamiltonian Monte Carlo relaxed clock branch rates operator", null, OperatorType.RELAXED_CLOCK_HMC_RATE_OPERATOR,-1 , 1.0); - createOperator("HMCRCS", "HMC relaxed clock location and scale", + createOperator("HMCRCS", ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION, "Hamiltonian Monte Carlo relaxed clock scale operator", null, OperatorType.RELAXED_CLOCK_HMC_SCALE_OPERATOR,-1 , 0.5); - //for the time being turn off the HMC relaxed clock scale kernel + //turn off the HMC relaxed clock scale kernel by default getOperator("HMCRCS").setUsed(false); createScaleOperator(ClockType.HMC_CLOCK_LOCATION, demoTuning, rateWeights); createScaleOperator(ClockType.HMCLN_SCALE, demoTuning, rateWeights); From 1c9935d9e34f7903d8ed148c2c399c828a0a00fe Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Mon, 9 Dec 2024 15:20:13 +0100 Subject: [PATCH 12/20] bug fix for generating XML code for HMC transition kernels for the HMC relaxed clock --- src/dr/app/beauti/generator/ClockModelGenerator.java | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/dr/app/beauti/generator/ClockModelGenerator.java b/src/dr/app/beauti/generator/ClockModelGenerator.java index af0c9be6b8..c19446180c 100644 --- a/src/dr/app/beauti/generator/ClockModelGenerator.java +++ b/src/dr/app/beauti/generator/ClockModelGenerator.java @@ -298,16 +298,19 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ writeCovarianceStatistic(writer, tag, prefix, treePrefix); - //TODO add more String constants for this type of code + boolean generateRatesGradient = false; boolean generateScaleGradient = false; for (Operator operator : options.selectOperators()) { - if (operator.getName().equals("HMC relaxed clock location and scale") && operator.isUsed()) { + if (operator.getName().equals(ClockType.HMC_CLOCK_RATES_DESCRIPTION) && operator.isUsed()) { + generateRatesGradient = true; + } + if (operator.getName().equals(ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION) && operator.isUsed()) { generateScaleGradient = true; } } - if (generateScaleGradient) { + if (generateRatesGradient) { //scale prior writer.writeOpenTag(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, @@ -349,6 +352,9 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ writer.writeCloseTag(LocationScaleGradientParser.LOCATION); writer.writeCloseTag(LocationScaleGradientParser.NAME); + } + + if (generateScaleGradient){ //scale gradient writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{ new Attribute.Default<>(XMLParser.ID, prefix + ScaleGradient.SCALE_GRADIENT), From fe31d4bbe80b1bfee1582ff4d7598df4789fbf46 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sat, 21 Dec 2024 11:35:44 +0100 Subject: [PATCH 13/20] GMRF transition kernels are not Gibbs operators --- src/dr/app/beauti/types/OperatorType.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dr/app/beauti/types/OperatorType.java b/src/dr/app/beauti/types/OperatorType.java index 6f0e47e7d2..d89f8729fd 100644 --- a/src/dr/app/beauti/types/OperatorType.java +++ b/src/dr/app/beauti/types/OperatorType.java @@ -28,7 +28,6 @@ package dr.app.beauti.types; import dr.evomodel.operators.BitFlipInSubstitutionModelOperator; -import dr.evomodelxml.operators.TreeNodeSlideParser; import dr.inference.operators.RateBitExchangeOperator; import dr.inferencexml.operators.ScaleOperatorParser; @@ -66,8 +65,8 @@ public enum OperatorType { NARROW_EXCHANGE("narrowExchange"), WIDE_EXCHANGE("wideExchange"), EMPIRICAL_TREE_SWAP("empiricalSwap"), - GMRF_GIBBS_OPERATOR("gmrfGibbsOperator"), - SKY_GRID_GIBBS_OPERATOR("gmrfGibbsOperator"), + GMRF_BLOCKUPDATE_OPERATOR("gmrfBlockUpdateOperator"), + SKY_GRID_BLOCKUPDATE_OPERATOR("gmrfBlockUpdateOperator"), SKY_GRID_HMC_OPERATOR("gmrfHMCOperator"), // PRECISION_GMRF_OPERATOR("precisionGMRFOperator"), WILSON_BALDING("wilsonBalding"), From b510b75a347b0e60c985d06bbe24bd9cc3bd322e Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sat, 21 Dec 2024 11:36:06 +0100 Subject: [PATCH 14/20] set tuning to 1 for GMRF transition kernels --- src/dr/app/beauti/options/PartitionTreePrior.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/dr/app/beauti/options/PartitionTreePrior.java b/src/dr/app/beauti/options/PartitionTreePrior.java index 3da2ee6c1e..3685518ef8 100644 --- a/src/dr/app/beauti/options/PartitionTreePrior.java +++ b/src/dr/app/beauti/options/PartitionTreePrior.java @@ -28,7 +28,6 @@ package dr.app.beauti.options; import dr.app.beauti.types.*; -import dr.evomodel.coalescent.VariableDemographicModel; import dr.evomodel.speciation.CalibrationPoints; import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser; import dr.evomodelxml.speciation.BirthDeathEpidemiologyModelParser; @@ -271,9 +270,9 @@ public void initModelParametersAndOpererators() { // "demographic.indicators", OperatorType.SCALE_WITH_INDICATORS, 0.5, 2 * demoWeights); createOperatorUsing2Parameters("gmrfGibbsOperator", "gmrfGibbsOperator", "Gibbs sampler for GMRF Skyride", "skyride.logPopSize", - "skyride.precision", OperatorType.GMRF_GIBBS_OPERATOR, -1, 2); + "skyride.precision", OperatorType.GMRF_BLOCKUPDATE_OPERATOR, 1, 2); createOperatorUsing2Parameters("gmrfSkyGridGibbsOperator", "skygrid.logPopSize", "Gibbs sampler for Bayesian SkyGrid", "skygrid.logPopSize", - GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_GIBBS_OPERATOR, -1, 2); + GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_BLOCKUPDATE_OPERATOR, 1, 2); createScaleOperator(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, "skygrid precision", 0.75, 1.0); createOperatorUsing2Parameters("gmrfSkyGridHMCOperator", "Multiple", "HMC transition kernel for Bayesian SkyGrid", "skygrid.logPopSize", GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_HMC_OPERATOR, -1, 2); @@ -292,7 +291,7 @@ public void initModelParametersAndOpererators() { createOperator(BirthDeathSerialSamplingModelParser.BDSS + "." + BirthDeathSerialSamplingModelParser.RELATIVE_MU, OperatorType.RANDOM_WALK_LOGIT, demoTuning, 1); createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "." - + BirthDeathSerialSamplingModelParser.PSI, demoTuning, 1); // todo random worl op ? + + BirthDeathSerialSamplingModelParser.PSI, demoTuning, 1); // todo random walk op ? createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "." + BirthDeathSerialSamplingModelParser.ORIGIN, demoTuning, 1); // createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "." From fd162f367c7e25a79a8d3a203a1ed0034a9fdb3b Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sat, 21 Dec 2024 11:36:53 +0100 Subject: [PATCH 15/20] write scaleFactor to XML for GMRF transition kernels --- .../app/beauti/generator/OperatorsGenerator.java | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/dr/app/beauti/generator/OperatorsGenerator.java b/src/dr/app/beauti/generator/OperatorsGenerator.java index 4234751be5..ede15bb1b4 100644 --- a/src/dr/app/beauti/generator/OperatorsGenerator.java +++ b/src/dr/app/beauti/generator/OperatorsGenerator.java @@ -231,11 +231,11 @@ private void writeOperator(Operator operator, XMLWriter writer) { case SCALE_WITH_INDICATORS: writeScaleWithIndicatorsOperator(operator, writer); break; - case GMRF_GIBBS_OPERATOR: - writeGMRFGibbsOperator(operator, prefix, writer); + case GMRF_BLOCKUPDATE_OPERATOR: + writeGMRFBlockUpdateOperator(operator, prefix, writer); break; - case SKY_GRID_GIBBS_OPERATOR: - writeSkyGridGibbsOperator(operator, prefix, writer); + case SKY_GRID_BLOCKUPDATE_OPERATOR: + writeSkyGridBlockUpdateOperator(operator, prefix, writer); break; case SKY_GRID_HMC_OPERATOR: writeSkyGridHMCOperator(operator, prefix, writer); @@ -518,12 +518,11 @@ private void writeSampleNonActiveOperator(Operator operator, XMLWriter writer) { writer.writeCloseTag(SampleNonActiveGibbsOperatorParser.SAMPLE_NONACTIVE_GIBBS_OPERATOR); } - private void writeSkyGridGibbsOperator(Operator operator, String treePriorPrefix, XMLWriter writer) { + private void writeSkyGridBlockUpdateOperator(Operator operator, String treePriorPrefix, XMLWriter writer) { writer.writeOpenTag( GMRFSkyrideBlockUpdateOperatorParser.GRID_BLOCK_UPDATE_OPERATOR, new Attribute[] { -// This is a Gibbs operator so shouldn't have a tuning parameter? -// new Attribute.Default(GMRFSkyrideBlockUpdateOperatorParser.SCALE_FACTOR, operator.getTuning()), + new Attribute.Default(GMRFSkyrideBlockUpdateOperatorParser.SCALE_FACTOR, operator.getTuning()), getWeightAttribute(operator.getWeight()) } ); @@ -688,7 +687,7 @@ private void writeShrinkageClockHMCOperator(Operator operator, String prefix, XM writer.writeCloseTag(HamiltonianMonteCarloOperatorParser.HMC_OPERATOR); } - private void writeGMRFGibbsOperator(Operator operator, String treePriorPrefix, XMLWriter writer) { + private void writeGMRFBlockUpdateOperator(Operator operator, String treePriorPrefix, XMLWriter writer) { writer.writeOpenTag( GMRFSkyrideBlockUpdateOperatorParser.BLOCK_UPDATE_OPERATOR, new Attribute[]{ From 41a001c5381ead17c7151a4cdfd0a9dc1709f878 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Tue, 24 Dec 2024 12:04:01 +0100 Subject: [PATCH 16/20] location scale gradient generator code was in the wrong place --- src/dr/app/beauti/generator/ClockModelGenerator.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dr/app/beauti/generator/ClockModelGenerator.java b/src/dr/app/beauti/generator/ClockModelGenerator.java index c19446180c..2c2dc44e07 100644 --- a/src/dr/app/beauti/generator/ClockModelGenerator.java +++ b/src/dr/app/beauti/generator/ClockModelGenerator.java @@ -340,6 +340,9 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ writer.writeIDref(DefaultTreeModel.TREE_MODEL, treePrefix + DefaultTreeModel.TREE_MODEL); writer.writeCloseTag(CTMCScalePriorParser.MODEL_NAME); + } + + if (generateScaleGradient){ //location gradient writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{ new Attribute.Default<>(XMLParser.ID, prefix + LocationGradient.LOCATION_GRADIENT), @@ -352,9 +355,6 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ writer.writeCloseTag(LocationScaleGradientParser.LOCATION); writer.writeCloseTag(LocationScaleGradientParser.NAME); - } - - if (generateScaleGradient){ //scale gradient writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{ new Attribute.Default<>(XMLParser.ID, prefix + ScaleGradient.SCALE_GRADIENT), From fb0d1ded52c150537898045caee6e3e3ca70a829 Mon Sep 17 00:00:00 2001 From: Plemey Date: Sat, 28 Dec 2024 15:15:15 +0100 Subject: [PATCH 17/20] stat for Simon --- .../ContinuousDiffusionStatistic.java | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java b/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java index f72aeac688..4fa6325c2f 100644 --- a/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java +++ b/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java @@ -74,6 +74,7 @@ public class ContinuousDiffusionStatistic extends Statistic.Abstract { public static final String SPEARMAN = "spearman"; public static final String CORRELATION_COEFFICIENT = "correlationCoefficient"; public static final String DISTANCE_TIME_CORRELATION = "distanceTimeCorrelation"; + public static final String SQUAREDDISTANCE_TIME4_CORRELATION = "squaredDistanceTimeFourCorrelation"; public static final String R_SQUARED = "Rsquared"; public static final String STATISTIC = "statistic"; public static final String TRAIT = "trait"; @@ -458,6 +459,18 @@ public double getStatisticValue(int dim) { Regression r = new Regression(convertDoubles(times),convertDoubles(distances)); return r.getCorrelationCoefficient(); } + } else if (summaryStat == summaryStatistic.SQUAREDDISTANCE_TIME4_CORRELATION) { + List squareddistances = squareElements(distances); + List timesFour = elementsTimesFour(times); + if (summaryMode == Mode.SPEARMAN) { + return getSpearmanRho(convertDoubles(timesFour),convertDoubles(squareddistances)); + } else if (summaryMode == Mode.R_SQUARED) { + Regression r = new Regression(convertDoubles(timesFour), convertDoubles(squareddistances)); + return r.getRSquared(); + } else { + Regression r = new Regression(convertDoubles(timesFour),convertDoubles(squareddistances)); + return r.getCorrelationCoefficient(); + } } else { return treeLength; } @@ -490,6 +503,22 @@ private double[] toArray(List list) { return returnArray; } + private static List squareElements(List inputList) { + List squaredList = new ArrayList<>(); + for (Double number : inputList) { + squaredList.add(number * number); + } + return squaredList; + } + + private static List elementsTimesFour (List inputList) { + List returnList = new ArrayList<>(); + for (Double number : inputList) { + returnList.add(number * 4); + } + return returnList; + } + private double[] imputeValue(double[] nodeValue, double[] parentValue, double time, double nodeHeight, double parentHeight, double[] precisionArray, double rate, boolean trueNoise) { final double scaledTimeChild = (time - nodeHeight) * rate; @@ -932,7 +961,8 @@ enum summaryStatistic { WAVEFRONT_DISTANCE, WAVEFRONT_DISTANCE_PHYLO, WAVEFRONT_RATE, - DISTANCE_TIME_CORRELATION + DISTANCE_TIME_CORRELATION, + SQUAREDDISTANCE_TIME4_CORRELATION } enum BranchSet { @@ -1023,6 +1053,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { System.err.println(name+": mode = "+mode+" ignored for "+DISTANCE_TIME_CORRELATION+", reverting to correlation coefficient mode"); statMode = Mode.CORRELATION_COEFFICIENT; } + } else if (statistic.equals(SQUAREDDISTANCE_TIME4_CORRELATION)) { + summaryStat = summaryStatistic.SQUAREDDISTANCE_TIME4_CORRELATION; + if (mode.equals(AVERAGE) || mode.equals(WEIGHTED_AVERAGE) || mode.equals(COEFFICIENT_OF_VARIATION) || mode.equals(MEDIAN)){ + System.err.println(name+": mode = "+mode+" ignored for "+SQUAREDDISTANCE_TIME4_CORRELATION+", reverting to correlation coefficient mode"); + statMode = Mode.CORRELATION_COEFFICIENT; + } } else if (statistic.equals(WAVEFRONT_DISTANCE)) { summaryStat = summaryStatistic.WAVEFRONT_DISTANCE; if (!mode.equals(WEIGHTED_AVERAGE)) { @@ -1065,6 +1101,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { summaryStat = summaryStatistic.DIFFUSION_COEFFICIENT; } else if (statistic.equals(DISTANCE_TIME_CORRELATION)) { summaryStat = summaryStatistic.DISTANCE_TIME_CORRELATION; + } else if (statistic.equals(SQUAREDDISTANCE_TIME4_CORRELATION)) { + summaryStat = summaryStatistic.SQUAREDDISTANCE_TIME4_CORRELATION; } else { System.err.println(name+": unknown statistic: "+statistic+". Reverting to diffusion rate."); summaryStat = summaryStatistic.DIFFUSION_RATE; From f37ecb6a49f0be86da0fcce5cd4a576ea1809624 Mon Sep 17 00:00:00 2001 From: Plemey Date: Mon, 30 Dec 2024 19:11:44 +0100 Subject: [PATCH 18/20] simplifying stat for Simon --- .../continuous/ContinuousDiffusionStatistic.java | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java b/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java index 4fa6325c2f..227b7c8aa5 100644 --- a/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java +++ b/src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java @@ -461,14 +461,13 @@ public double getStatisticValue(int dim) { } } else if (summaryStat == summaryStatistic.SQUAREDDISTANCE_TIME4_CORRELATION) { List squareddistances = squareElements(distances); - List timesFour = elementsTimesFour(times); if (summaryMode == Mode.SPEARMAN) { - return getSpearmanRho(convertDoubles(timesFour),convertDoubles(squareddistances)); + return getSpearmanRho(convertDoubles(times),convertDoubles(squareddistances)); } else if (summaryMode == Mode.R_SQUARED) { - Regression r = new Regression(convertDoubles(timesFour), convertDoubles(squareddistances)); + Regression r = new Regression(convertDoubles(times), convertDoubles(squareddistances)); return r.getRSquared(); } else { - Regression r = new Regression(convertDoubles(timesFour),convertDoubles(squareddistances)); + Regression r = new Regression(convertDoubles(times),convertDoubles(squareddistances)); return r.getCorrelationCoefficient(); } } else { @@ -511,14 +510,6 @@ private static List squareElements(List inputList) { return squaredList; } - private static List elementsTimesFour (List inputList) { - List returnList = new ArrayList<>(); - for (Double number : inputList) { - returnList.add(number * 4); - } - return returnList; - } - private double[] imputeValue(double[] nodeValue, double[] parentValue, double time, double nodeHeight, double parentHeight, double[] precisionArray, double rate, boolean trueNoise) { final double scaledTimeChild = (time - nodeHeight) * rate; From 00877981aa9b6bbf68487b33912b9f7be488ba87 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sun, 19 Jan 2025 13:50:26 +0000 Subject: [PATCH 19/20] A few tweaks to help optimisation --- .../tools/treeannotator/TreeAnnotator.java | 69 ++++++++++++++----- 1 file changed, 50 insertions(+), 19 deletions(-) diff --git a/src/dr/app/tools/treeannotator/TreeAnnotator.java b/src/dr/app/tools/treeannotator/TreeAnnotator.java index e6c8a45cd5..4dfdd573bc 100644 --- a/src/dr/app/tools/treeannotator/TreeAnnotator.java +++ b/src/dr/app/tools/treeannotator/TreeAnnotator.java @@ -67,6 +67,7 @@ public class TreeAnnotator extends BaseTreeTool { // Messages to stderr, output to stdout private static PrintStream progressStream = System.err; + private static final boolean extendedMetrics = false; private final CollectionAction collectionAction; private final AnnotationAction annotationAction; @@ -115,15 +116,15 @@ public String toString() { */ public TreeAnnotator(final int burninTrees, final long burninStates, - HeightsSummary heightsOption, - double posteriorLimit, - double[] hpd2D, - boolean computeESS, - Target targetOption, - String targetTreeFileName, - String referenceTreeFileName, - String inputFileName, - String outputFileName + final HeightsSummary heightsOption, + final double posteriorLimit, + final double[] hpd2D, + final boolean computeESS, + final Target targetOption, + final String targetTreeFileName, + final String referenceTreeFileName, + final String inputFileName, + final String outputFileName ) throws IOException { long totalStartTime = System.currentTimeMillis(); @@ -414,7 +415,7 @@ private MutableTree readUserTargetTree(String targetTreeFileName, CladeSystem cl double score = scoreTree(targetTree, cladeSystem); progressStream.println("Target tree's log clade credibility: " + String.format("%.4f", score)); reportStatistics(cladeSystem, targetTree); - reportStatisticTables(cladeSystem, targetTree); +// reportStatisticTables(cladeSystem, targetTree); progressStream.println(); return targetTree; @@ -504,6 +505,13 @@ private MutableTree getHIPSTRTree(CladeSystem cladeSystem) { MutableTree tree = treeBuilder.getHIPSTRTree(cladeSystem, taxa); double score = treeBuilder.getScore(); + // Test whether score returned by HIPSTRTreeBuilder is the same as that calculated de novo + // Generally seems to have very small (precision related) differences +// double score2 = scoreTree(tree, cladeSystem); +// if (score != score2) { +// System.err.println("HIPSTR Score: " + score + " vs recalculation: " + score2); +// } + long timeElapsed = (System.currentTimeMillis() - startTime) / 1000; progressStream.println("[" + timeElapsed + " secs]"); progressStream.println(); @@ -519,14 +527,37 @@ private static void reportStatistics(CladeSystem cladeSystem, Tree tree) { progressStream.println("Lowest individual clade credibility: " + String.format("%.4f", cladeSystem.getMinimumCladeCredibility(tree))); progressStream.println("Mean individual clade credibility: " + String.format("%.4f", cladeSystem.getMeanCladeCredibility(tree))); progressStream.println("Median individual clade credibility: " + String.format("%.4f", cladeSystem.getMedianCladeCredibility(tree))); - progressStream.println("Number of clades with credibility 1.0: " + cladeSystem.getTopCladeCredibility(tree, 1.0)); - progressStream.println("Number of clades with credibility > 0.99: " + cladeSystem.getTopCladeCredibility(tree, 0.99) + - " (out of " + cladeSystem.getTopCladeCredibility(0.99) + " in all trees)"); - progressStream.println("Number of clades with credibility > 0.95: " + cladeSystem.getTopCladeCredibility(tree, 0.95) + - " (out of " + cladeSystem.getTopCladeCredibility(0.95) + " in all trees)"); - progressStream.println("Number of clades with credibility > 0.5: " + cladeSystem.getTopCladeCredibility(tree, 0.5) + - " (out of " + cladeSystem.getTopCladeCredibility(0.5) + " in all trees)"); + progressStream.println("Number of clades with credibility 1.0: " + cladeSystem.getTopCladeCount(tree, 1.0)); + reportCladeCredibilityCount(cladeSystem, tree, 0.99); + reportCladeCredibilityCount(cladeSystem, tree, 0.95); + if (extendedMetrics) { + progressStream.println("Number of clades with credibility > 0.75: " + cladeSystem.getTopCladeCount(tree, 0.75) + + " (out of " + cladeSystem.getTopCladeCount(0.75) + " in all trees)"); + } + reportCladeCredibilityCount(cladeSystem, tree, 0.5); + if (extendedMetrics) { + progressStream.println("Number of clades with credibility > 0.25: " + cladeSystem.getTopCladeCount(tree, 0.25) + + " (out of " + cladeSystem.getTopCladeCount(0.25) + " in all trees)"); + progressStream.println("Number of clades with credibility > 0.10: " + cladeSystem.getTopCladeCount(tree, 0.1) + + " (out of " + cladeSystem.getTopCladeCount(0.1) + " in all trees)"); + progressStream.println("Number of clades with credibility > 0.05: " + cladeSystem.getTopCladeCount(tree, 0.05) + + " (out of " + cladeSystem.getTopCladeCount(0.05) + " in all trees)"); + } } + + private static void reportCladeCredibilityCount(CladeSystem cladeSystem, Tree tree, double threshold) { + int treeCladeCount = cladeSystem.getTopCladeCount(tree, threshold); + int allCladeCount = cladeSystem.getTopCladeCount(threshold); + progressStream.println("Number of clades with credibility > " + threshold + ": " + + treeCladeCount + + " (out of " + allCladeCount + " in all trees)"); + Set treeClades = cladeSystem.getTopClades(tree, threshold); + Set allClades = cladeSystem.getTopClades(threshold); + + Set diff = new HashSet<>(allClades); + diff.removeAll(treeClades); + } + private static void reportStatisticTables(CladeSystem cladeSystem, Tree tree) { int count = 100; // double[] table = new double[count + 1]; @@ -540,9 +571,9 @@ private static void reportStatisticTables(CladeSystem cladeSystem, Tree tree) { double threshold = ((double) (i)) / count; progressStream.print(threshold); progressStream.print(","); - progressStream.print(cladeSystem.getTopCladeCredibility(tree, threshold)); + progressStream.print(cladeSystem.getTopCladeCount(tree, threshold)); progressStream.print(","); - progressStream.println(cladeSystem.getTopCladeCredibility(threshold)); + progressStream.println(cladeSystem.getTopCladeCount(threshold)); } } From 0bb5b0ba495196985547b9d0e67b029d814db963 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sun, 19 Jan 2025 13:51:19 +0000 Subject: [PATCH 20/20] Tweaks to help optimisation --- .../app/tools/treeannotator/CladeSystem.java | 45 ++++++++++++++++++- .../treeannotator/HIPSTRTreeBuilder.java | 8 +++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/src/dr/app/tools/treeannotator/CladeSystem.java b/src/dr/app/tools/treeannotator/CladeSystem.java index dca9b3081d..504c58ea7b 100644 --- a/src/dr/app/tools/treeannotator/CladeSystem.java +++ b/src/dr/app/tools/treeannotator/CladeSystem.java @@ -34,7 +34,9 @@ import dr.stats.DiscreteStatistics; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; /** * @author Andrew Rambaut @@ -286,7 +288,7 @@ public boolean expectAllClades() { * @param threshold * @return */ - public int getTopCladeCredibility(Tree tree, double threshold) { + public int getTopCladeCount(Tree tree, double threshold) { final int[] count = {0}; traverseTree(tree, new CladeAction() { @Override @@ -304,12 +306,36 @@ public boolean expectAllClades() { return count[0]; } + /** + * Returns the set of clades in the tree with threshold credibility or higher + * @param tree + * @param threshold + * @return + */ + public Set getTopClades(Tree tree, double threshold) { + Set clades = new HashSet<>(); + traverseTree(tree, new CladeAction() { + @Override + public void actOnClade(Clade clade, Tree tree, NodeRef node) { + if (clade.getTaxon() == null && clade.getCredibility() >= threshold) { + clades.add(clade); + } + } + + @Override + public boolean expectAllClades() { + return true; + } + }); + return clades; + } + /** * Returns the number of clades in the clade system with threshold credibility or higher * @param threshold * @return */ - public int getTopCladeCredibility(double threshold) { + public int getTopCladeCount(double threshold) { int count = 0; for (Clade clade : cladeMap.values()) { if (clade.getCredibility() >= threshold) { @@ -319,6 +345,21 @@ public int getTopCladeCredibility(double threshold) { return count; } + /** + * Returns the set of clades in the clade system with threshold credibility or higher + * @param threshold + * @return + */ + public Set getTopClades(double threshold) { + Set clades = new HashSet<>(); + for (Clade clade : cladeMap.values()) { + if (clade.getCredibility() >= threshold) { + clades.add(clade); + } + } + return clades; + } + public int getCladeCount() { return cladeMap.keySet().size(); } diff --git a/src/dr/app/tools/treeannotator/HIPSTRTreeBuilder.java b/src/dr/app/tools/treeannotator/HIPSTRTreeBuilder.java index cfe6a37b9a..044923f53d 100644 --- a/src/dr/app/tools/treeannotator/HIPSTRTreeBuilder.java +++ b/src/dr/app/tools/treeannotator/HIPSTRTreeBuilder.java @@ -36,7 +36,7 @@ import java.util.Map; public class HIPSTRTreeBuilder { - private Map credibilityCache = new HashMap<>(); + private final Map credibilityCache = new HashMap<>(); public MutableTree getHIPSTRTree(CladeSystem cladeSystem, TaxonList taxonList) { BiClade rootClade = (BiClade)cladeSystem.getRootClade(); @@ -99,6 +99,12 @@ private double findHIPSTRTree(BiClade clade) { clade.bestRight = left; } } +// else if (leftLogCredibility + rightLogCredibility == bestLogCredibility) { +// if ((left.getSize() > 1 && left.getCredibility() >= 0.5) || +// (right.getSize() > 1 && right.getCredibility() >= 0.5)) { +// System.err.println("eek"); +// } +// } } logCredibility += bestLogCredibility; } else {