Skip to content

Commit

Permalink
Merge branch 'master' into modern_arguments
Browse files Browse the repository at this point in the history
# Conflicts:
#	src/dr/app/tools/treeannotator/TreeAnnotator.java
  • Loading branch information
rambaut committed Jan 25, 2025
2 parents 455b74b + 0bb5b0b commit b9d945e
Show file tree
Hide file tree
Showing 19 changed files with 469 additions and 224 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import dr.evolution.util.Taxa;
import dr.evolution.util.Units;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.BranchSpecificFixedEffects;
import dr.inference.distribution.DistributionLikelihood;
import dr.evomodel.tree.DefaultTreeModel;
import dr.evomodelxml.TreeWorkingPriorParsers;
import dr.evomodelxml.branchratemodel.*;
Expand Down Expand Up @@ -1025,8 +1027,25 @@ public void writeMLE(XMLWriter writer, MarginalLikelihoodEstimationOptions optio
writer.writeIDref(RandomLocalClockModelParser.LOCAL_BRANCH_RATES, model.getPrefix() + BranchRateModel.BRANCH_RATES);
break;

case MIXED_EFFECTS_CLOCK:

writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.RATES_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.SCALE_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.INTERCEPT_PRIOR);

String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
int number = 1;
String concat = coeff + number;
while (model.hasParameter(concat)) {
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number);
number++;
concat = coeff + number;
}

break;

default:
throw new IllegalArgumentException("Unknown clock model");
throw new IllegalArgumentException("Unknown clock model: " + model.getClockType());
}
}

Expand Down
25 changes: 14 additions & 11 deletions src/dr/app/beauti/generator/ClockModelGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import dr.app.beauti.components.ComponentFactory;
import dr.app.beauti.options.*;
import dr.app.beauti.types.ClockType;
import dr.app.beauti.types.OperatorType;
import dr.app.beauti.util.XMLWriter;
import dr.evolution.util.Taxa;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
Expand Down Expand Up @@ -66,8 +65,6 @@
import dr.util.Attribute;
import dr.xml.XMLParser;

import java.util.Map;

import static dr.inference.model.ParameterParser.PARAMETER;
import static dr.inferencexml.distribution.PriorParsers.*;
import static dr.inferencexml.distribution.shrinkage.BayesianBridgeLikelihoodParser.*;
Expand Down Expand Up @@ -301,16 +298,19 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ

writeCovarianceStatistic(writer, tag, prefix, treePrefix);

//TODO add more String constants for this type of code
boolean generateRatesGradient = false;
boolean generateScaleGradient = false;

for (Operator operator : options.selectOperators()) {
if (operator.getName().equals("HMC relaxed clock location and scale") && operator.isUsed()) {
if (operator.getName().equals(ClockType.HMC_CLOCK_RATES_DESCRIPTION) && operator.isUsed()) {
generateRatesGradient = true;
}
if (operator.getName().equals(ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION) && operator.isUsed()) {
generateScaleGradient = true;
}
}

if (generateScaleGradient) {
if (generateRatesGradient) {

//scale prior
writer.writeOpenTag(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD,
Expand Down Expand Up @@ -340,6 +340,9 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ
writer.writeIDref(DefaultTreeModel.TREE_MODEL, treePrefix + DefaultTreeModel.TREE_MODEL);
writer.writeCloseTag(CTMCScalePriorParser.MODEL_NAME);

}

if (generateScaleGradient){
//location gradient
writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{
new Attribute.Default<>(XMLParser.ID, prefix + LocationGradient.LOCATION_GRADIENT),
Expand Down Expand Up @@ -958,18 +961,18 @@ public static void writeBranchRatesModelRef(PartitionClockModel model, XMLWriter

case MIXED_EFFECTS_CLOCK:
//always write distribution likelihoods for rate, scale and intercept
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR);
//writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR);
//writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR);
//writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR);
//check for coefficients
String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
/*String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
int number = 1;
String concat = coeff + number;
while (model.hasParameter(concat)) {
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number);
number++;
concat = coeff + number;
}
}*/
tag = ArbitraryBranchRatesParser.ARBITRARY_BRANCH_RATES;
id = model.getPrefix() + ArbitraryBranchRates.BRANCH_RATES;
break;
Expand Down
15 changes: 7 additions & 8 deletions src/dr/app/beauti/generator/OperatorsGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,11 @@ private void writeOperator(Operator operator, XMLWriter writer) {
case SCALE_WITH_INDICATORS:
writeScaleWithIndicatorsOperator(operator, writer);
break;
case GMRF_GIBBS_OPERATOR:
writeGMRFGibbsOperator(operator, prefix, writer);
case GMRF_BLOCKUPDATE_OPERATOR:
writeGMRFBlockUpdateOperator(operator, prefix, writer);
break;
case SKY_GRID_GIBBS_OPERATOR:
writeSkyGridGibbsOperator(operator, prefix, writer);
case SKY_GRID_BLOCKUPDATE_OPERATOR:
writeSkyGridBlockUpdateOperator(operator, prefix, writer);
break;
case SKY_GRID_HMC_OPERATOR:
writeSkyGridHMCOperator(operator, prefix, writer);
Expand Down Expand Up @@ -518,12 +518,11 @@ private void writeSampleNonActiveOperator(Operator operator, XMLWriter writer) {
writer.writeCloseTag(SampleNonActiveGibbsOperatorParser.SAMPLE_NONACTIVE_GIBBS_OPERATOR);
}

private void writeSkyGridGibbsOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
private void writeSkyGridBlockUpdateOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
writer.writeOpenTag(
GMRFSkyrideBlockUpdateOperatorParser.GRID_BLOCK_UPDATE_OPERATOR,
new Attribute[] {
// This is a Gibbs operator so shouldn't have a tuning parameter?
// new Attribute.Default<Double>(GMRFSkyrideBlockUpdateOperatorParser.SCALE_FACTOR, operator.getTuning()),
new Attribute.Default<Double>(GMRFSkyrideBlockUpdateOperatorParser.SCALE_FACTOR, operator.getTuning()),
getWeightAttribute(operator.getWeight())
}
);
Expand Down Expand Up @@ -688,7 +687,7 @@ private void writeShrinkageClockHMCOperator(Operator operator, String prefix, XM
writer.writeCloseTag(HamiltonianMonteCarloOperatorParser.HMC_OPERATOR);
}

private void writeGMRFGibbsOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
private void writeGMRFBlockUpdateOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
writer.writeOpenTag(
GMRFSkyrideBlockUpdateOperatorParser.BLOCK_UPDATE_OPERATOR,
new Attribute[]{
Expand Down
45 changes: 39 additions & 6 deletions src/dr/app/beauti/generator/ParameterPriorGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
import dr.evolution.util.Taxa;
import dr.evomodel.branchratemodel.BranchSpecificFixedEffects;
import dr.evomodel.tree.DefaultTreeModel;
import dr.evomodelxml.branchratemodel.BranchSpecificFixedEffectsParser;
import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.evomodelxml.tree.CTMCScalePriorParser;
import dr.evomodelxml.tree.MonophylyStatisticParser;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.model.ParameterParser;
import dr.inferencexml.distribution.CachedDistributionLikelihoodParser;
import dr.inferencexml.distribution.DistributionLikelihoodParser;
import dr.inferencexml.distribution.PriorParsers;
import dr.inferencexml.model.BooleanLikelihoodParser;
import dr.inferencexml.model.OneOnXPriorParser;
Expand All @@ -58,19 +58,48 @@
*/
public class ParameterPriorGenerator extends Generator {

//map parameters to prior IDs, for use with HMC
private HashMap<String, String> mapParameterToPrior;
//map parameters to prior IDs, for use with HMC or other approaches that define their prior befor the <mcmc> XML block
private final HashMap<String, String> mapParameterToPrior;

public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] components) {
super(options, components);
//TODO don't like this being here, but will see how things pan out as more HMC approaches are added
mapParameterToPrior = new HashMap<String, String>();
}

/**
* Add all possibly previously defined priors to a HashMap
* Cannot be done in constructor as the models have not been defined by the user at that point
*/
public void addParametersToPrior() {
int totalModels = options.getPartitionClockModels().size();
List<PartitionClockModel> partitionClockModels = options.getPartitionClockModels();
//HMC skygrid
mapParameterToPrior.put(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION_PRIOR);
//HMC relaxed clock
mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, BranchSpecificFixedEffects.LOCATION_PRIOR);
mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, BranchSpecificFixedEffects.RATES_PRIOR);
mapParameterToPrior.put(ClockType.HMCLN_SCALE, BranchSpecificFixedEffects.SCALE_PRIOR);
for (int i = 0; i < totalModels; i++) {
String prefix = partitionClockModels.get(i).getPrefix();
mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.LOCATION_PRIOR);
mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, prefix + BranchSpecificFixedEffects.RATES_PRIOR);
mapParameterToPrior.put(ClockType.HMCLN_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR);
}
//mixed effects clock
//always write distribution likelihoods for rate, scale and intercept
for (int i = 0; i < totalModels; i++) {
String prefix = partitionClockModels.get(i).getPrefix();
mapParameterToPrior.put(ClockType.ME_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.RATES_PRIOR);
mapParameterToPrior.put(ClockType.ME_CLOCK_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR);
mapParameterToPrior.put(BranchSpecificFixedEffectsParser.INTERCEPT, prefix + BranchSpecificFixedEffects.INTERCEPT_PRIOR);
//check for coefficients
String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
int number = 1;
String concat = coeff + number;
while (partitionClockModels.get(i).hasParameter(concat)) {
mapParameterToPrior.put(concat, prefix + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number);
number++;
concat = coeff + number;
}
}
}

/**
Expand All @@ -79,6 +108,10 @@ public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] compone
* @param writer the writer
*/
public void writeParameterPriors(XMLWriter writer) {

//first make sure that all possibly previously defined priors are part of the HashMap
addParametersToPrior();

boolean first = true;

for (Map.Entry<Taxa, Boolean> taxaBooleanEntry : options.taxonSetsMono.entrySet()) {
Expand Down
6 changes: 4 additions & 2 deletions src/dr/app/beauti/generator/SubstitutionModelGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel
if (options.useNuRelativeRates()) {
Parameter parameter = model.getParameter("nu");
String prefix1 = options.getPrefix();
if (!parameter.getSubParameters().isEmpty()) {
if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) {
writeNuRelativeRateBlock(writer, prefix1, parameter);
}
} else {
Expand Down Expand Up @@ -802,7 +802,9 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model

if (options.useNuRelativeRates()) {
Parameter parameter = model.getParameter("nu");
writeNuRelativeRateBlock(writer, prefix, parameter);
if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) {
writeNuRelativeRateBlock(writer, prefix, parameter);
}
} else {
writeParameter(SiteModelParser.RELATIVE_RATE, "mu", model, writer);
}
Expand Down
8 changes: 4 additions & 4 deletions src/dr/app/beauti/options/PartitionClockModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void initModelParametersAndOpererators() {
.initial(1.0).mean(1.0).offset(0.0).partitionOptions(this).isPriorFixed(true)
.isAdaptiveMultivariateCompatible(false).build(parameters);

new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, "HMC relaxed clock branch rates")
new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, ClockType.HMC_CLOCK_RATES_DESCRIPTION)
.prior(PriorType.LOGNORMAL_HPM_PRIOR).initial(0.001).isNonNegative(true)
.partitionOptions(this).isPriorFixed(true)
.isAdaptiveMultivariateCompatible(false).build(parameters);
Expand Down Expand Up @@ -216,11 +216,11 @@ public void initModelParametersAndOpererators() {
createScaleOperator(ClockType.UCGD_SHAPE, demoTuning, rateWeights);

//HMC relaxed clock
createOperator("HMCRCR", "HMC relaxed clock branch rates",
createOperator("HMCRCR", ClockType.HMC_CLOCK_RATES_DESCRIPTION,
"Hamiltonian Monte Carlo relaxed clock branch rates operator", null, OperatorType.RELAXED_CLOCK_HMC_RATE_OPERATOR,-1 , 1.0);
createOperator("HMCRCS", "HMC relaxed clock location and scale",
createOperator("HMCRCS", ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION,
"Hamiltonian Monte Carlo relaxed clock scale operator", null, OperatorType.RELAXED_CLOCK_HMC_SCALE_OPERATOR,-1 , 0.5);
//for the time being turn off the HMC relaxed clock scale kernel
//turn off the HMC relaxed clock scale kernel by default
getOperator("HMCRCS").setUsed(false);
createScaleOperator(ClockType.HMC_CLOCK_LOCATION, demoTuning, rateWeights);
createScaleOperator(ClockType.HMCLN_SCALE, demoTuning, rateWeights);
Expand Down
7 changes: 3 additions & 4 deletions src/dr/app/beauti/options/PartitionTreePrior.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
package dr.app.beauti.options;

import dr.app.beauti.types.*;
import dr.evomodel.coalescent.VariableDemographicModel;
import dr.evomodel.speciation.CalibrationPoints;
import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.evomodelxml.speciation.BirthDeathEpidemiologyModelParser;
Expand Down Expand Up @@ -271,9 +270,9 @@ public void initModelParametersAndOpererators() {
// "demographic.indicators", OperatorType.SCALE_WITH_INDICATORS, 0.5, 2 * demoWeights);

createOperatorUsing2Parameters("gmrfGibbsOperator", "gmrfGibbsOperator", "Gibbs sampler for GMRF Skyride", "skyride.logPopSize",
"skyride.precision", OperatorType.GMRF_GIBBS_OPERATOR, -1, 2);
"skyride.precision", OperatorType.GMRF_BLOCKUPDATE_OPERATOR, 1, 2);
createOperatorUsing2Parameters("gmrfSkyGridGibbsOperator", "skygrid.logPopSize", "Gibbs sampler for Bayesian SkyGrid", "skygrid.logPopSize",
GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_GIBBS_OPERATOR, -1, 2);
GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_BLOCKUPDATE_OPERATOR, 1, 2);
createScaleOperator(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, "skygrid precision", 0.75, 1.0);
createOperatorUsing2Parameters("gmrfSkyGridHMCOperator", "Multiple", "HMC transition kernel for Bayesian SkyGrid", "skygrid.logPopSize",
GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_HMC_OPERATOR, -1, 2);
Expand All @@ -292,7 +291,7 @@ public void initModelParametersAndOpererators() {
createOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
+ BirthDeathSerialSamplingModelParser.RELATIVE_MU, OperatorType.RANDOM_WALK_LOGIT, demoTuning, 1);
createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
+ BirthDeathSerialSamplingModelParser.PSI, demoTuning, 1); // todo random worl op ?
+ BirthDeathSerialSamplingModelParser.PSI, demoTuning, 1); // todo random walk op ?
createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
+ BirthDeathSerialSamplingModelParser.ORIGIN, demoTuning, 1);
// createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
Expand Down
3 changes: 3 additions & 0 deletions src/dr/app/beauti/types/ClockType.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,7 @@ public String toString() {

final public static String ACLD_MEAN = "acld.mean";
final public static String ACLD_STDEV = "acld.stdev";

final public static String HMC_CLOCK_RATES_DESCRIPTION = "HMC relaxed clock branch rates";
final public static String HMC_CLOCK_LOCATION_SCALE_DESCRIPTION = "HMC relaxed clock location and scale";
}
5 changes: 2 additions & 3 deletions src/dr/app/beauti/types/OperatorType.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
package dr.app.beauti.types;

import dr.evomodel.operators.BitFlipInSubstitutionModelOperator;
import dr.evomodelxml.operators.TreeNodeSlideParser;
import dr.inference.operators.RateBitExchangeOperator;
import dr.inferencexml.operators.ScaleOperatorParser;

Expand Down Expand Up @@ -66,8 +65,8 @@ public enum OperatorType {
NARROW_EXCHANGE("narrowExchange"),
WIDE_EXCHANGE("wideExchange"),
EMPIRICAL_TREE_SWAP("empiricalSwap"),
GMRF_GIBBS_OPERATOR("gmrfGibbsOperator"),
SKY_GRID_GIBBS_OPERATOR("gmrfGibbsOperator"),
GMRF_BLOCKUPDATE_OPERATOR("gmrfBlockUpdateOperator"),
SKY_GRID_BLOCKUPDATE_OPERATOR("gmrfBlockUpdateOperator"),
SKY_GRID_HMC_OPERATOR("gmrfHMCOperator"),
// PRECISION_GMRF_OPERATOR("precisionGMRFOperator"),
WILSON_BALDING("wilsonBalding"),
Expand Down
Loading

0 comments on commit b9d945e

Please sign in to comment.