Skip to content

Commit

Permalink
Merge pull request #116 from PLN-team/zipln
Browse files Browse the repository at this point in the history
Integrating ZIPLN to PLNmodels
  • Loading branch information
mahendra-mariadassou authored Jan 23, 2024
2 parents 3ac3d91 + ceb39af commit d3cd585
Show file tree
Hide file tree
Showing 37 changed files with 3,134 additions and 24 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ src/*.dll
# Plot in testthat
tests/testthat/Rplots.pdf

# Raw data files
data-raw/*.RData

# Mac bullsh.
.DS_Store

Expand Down
8 changes: 7 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: PLNmodels
Title: Poisson Lognormal Models
Version: 1.1.0
Version: 1.2.0
Authors@R: c(
person("Julien", "Chiquet", role = c("aut", "cre"), email = "[email protected]",
comment = c(ORCID = "0000-0002-3629-3429")),
Expand Down Expand Up @@ -37,6 +37,7 @@ Imports:
future.apply,
R6,
glassoFast,
pscl,
Matrix,
Rcpp,
nloptr,
Expand Down Expand Up @@ -86,13 +87,18 @@ Collate:
'PLNnetworkfit-S3methods.R'
'PLNnetworkfit-class.R'
'RcppExports.R'
'ZIPLNfit-class.R'
'ZIPLN.R'
'ZIPLNfit-S3methods.R'
'barents.R'
'import_utils.R'
'mollusk.R'
'oaks.R'
'plot_utils.R'
'scRNA.R'
'trichoptera.R'
'utils-pipe.R'
'utils-zipln.R'
'utils.R'
'zzz.R'
Language: en-US
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
S3method(coef,PLNLDAfit)
S3method(coef,PLNfit)
S3method(coef,PLNmixturefit)
S3method(coef,ZIPLNfit)
S3method(fitted,PLNfit)
S3method(fitted,PLNmixturefit)
S3method(fitted,ZIPLNfit)
S3method(getBestModel,PLNPCAfamily)
S3method(getBestModel,PLNmixturefamily)
S3method(getBestModel,PLNnetworkfamily)
Expand All @@ -22,9 +24,11 @@ S3method(plot,PLNnetworkfit)
S3method(predict,PLNLDAfit)
S3method(predict,PLNfit)
S3method(predict,PLNmixturefit)
S3method(predict,ZIPLNfit)
S3method(predict_cond,PLNfit)
S3method(sigma,PLNfit)
S3method(sigma,PLNmixturefit)
S3method(sigma,ZIPLNfit)
S3method(standard_error,PLNPCAfit)
S3method(standard_error,PLNfit)
S3method(standard_error,PLNfit_fixedcov)
Expand All @@ -42,6 +46,8 @@ export(PLNmixture)
export(PLNmixture_param)
export(PLNnetwork)
export(PLNnetwork_param)
export(ZIPLN)
export(ZIPLN_param)
export(coefficient_path)
export(compute_PLN_starting_point)
export(compute_offset)
Expand Down Expand Up @@ -83,6 +89,7 @@ importFrom(igraph,graph_from_adjacency_matrix)
importFrom(igraph,layout_in_circle)
importFrom(igraph,plot.igraph)
importFrom(magrittr,"%>%")
importFrom(pscl,zeroinfl)
importFrom(purrr,map)
importFrom(purrr,map2)
importFrom(purrr,map2_dbl)
Expand All @@ -93,8 +100,11 @@ importFrom(purrr,reduce)
importFrom(rlang,.data)
importFrom(stats,.getXlevels)
importFrom(stats,.lm.fit)
importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,coefficients)
importFrom(stats,fitted)
importFrom(stats,glm.control)
importFrom(stats,glm.fit)
importFrom(stats,lm.fit)
Expand Down
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# PLNmodels 1.1.0 (2023-08-24)
# Current (2024-01-23)

* Addition of ZIPLN() and ZIPLNfit-class to allow for zero-inflation in the (for now) standard PLN model (merge PR #116)

# PLNmodels 1.1.0 (2024-01-08)

* Update documentation of PLN*_param() functions to include torch optimization parameters
* Add (somehow) explicit error message when torch convergence fails
Expand Down
14 changes: 5 additions & 9 deletions R/PLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#' @rdname PLNfit
#' @include PLNfit-class.R
#' @importFrom R6 R6Class
#' @import torch
#'
#' @examples
#' \dontrun{
Expand Down Expand Up @@ -96,7 +97,6 @@ PLNfit <- R6Class(
Ji
},

#' @import torch
torch_optimize = function(data, params, config) {

#config$device = "mps"
Expand Down Expand Up @@ -184,7 +184,6 @@ PLNfit <- R6Class(
out
},


## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## PRIVATE METHODS FOR VARIANCE OF THE ESTIMATORS
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Expand Down Expand Up @@ -345,7 +344,6 @@ PLNfit <- R6Class(
## END OF PRIVATE METHODS
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


),
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## PUBLIC MEMBERS
Expand Down Expand Up @@ -373,10 +371,6 @@ PLNfit <- R6Class(
private$S <- control$inception$var_par$S
} else {
if (control$trace > 1) cat("\n Use LM after log transformation to define the inceptive model")
# fits <- lm.fit(weights * covariates, weights * log((1 + responses)/exp(offsets)))
# private$B <- matrix(fits$coefficients, d, p)
# private$M <- matrix(fits$residuals, n, p)
# private$S <- matrix(.1, n, p)
start_point <- compute_PLN_starting_point(Y = responses, X = covariates, O = offsets, w = weights)
private$B <- start_point$B
private$M <- start_point$M
Expand Down Expand Up @@ -552,7 +546,7 @@ PLNfit <- R6Class(
S <- VE$S
} else {
# otherwise set M = 0 and S = diag(Sigma)
M <- matrix(1, nrow = n_new, ncol = self$p)
M <- matrix(0, nrow = n_new, ncol = self$p)
S <- matrix(diag(private$Sigma), nrow = n_new, ncol = self$p, byrow = TRUE)
}

Expand Down Expand Up @@ -780,6 +774,7 @@ PLNfit_diagonal <- R6Class(
attr(Ji, "weights") <- as.numeric(data$w)
Ji
}

## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## END OF TORCH METHODS
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Expand Down Expand Up @@ -860,6 +855,7 @@ PLNfit_spherical <- R6Class(
attr(Ji, "weights") <- as.numeric(data$w)
Ji
}

## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## END OF TORCH METHODS FOR OPTIMIZATION
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Expand Down Expand Up @@ -913,7 +909,7 @@ PLNfit_fixedcov <- R6Class(
initialize = function(responses, covariates, offsets, weights, formula, control) {
super$initialize(responses, covariates, offsets, weights, formula, control)
private$optimizer$main <- ifelse(control$backend == "nlopt", nlopt_optimize_fixed, private$torch_optimize)
## ve step is the same as in the fullly parameterized covariance
## ve step is the same as in the fully parameterized covariance
private$Omega <- control$Omega
},
#' @description Call to the NLopt or TORCH optimizer and update of the relevant fields
Expand Down
36 changes: 36 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,42 @@ nlopt_optimize_vestep_spherical <- function(data, params, B, Omega, config) {
.Call('_PLNmodels_nlopt_optimize_vestep_spherical', PACKAGE = 'PLNmodels', data, params, B, Omega, config)
}

zipln_vloglik <- function(Y, X, O, Pi, Omega, B, R, M, S) {
.Call('_PLNmodels_zipln_vloglik', PACKAGE = 'PLNmodels', Y, X, O, Pi, Omega, B, R, M, S)
}

optim_zipln_Omega_full <- function(M, X, B, S) {
.Call('_PLNmodels_optim_zipln_Omega_full', PACKAGE = 'PLNmodels', M, X, B, S)
}

optim_zipln_Omega_spherical <- function(M, X, B, S) {
.Call('_PLNmodels_optim_zipln_Omega_spherical', PACKAGE = 'PLNmodels', M, X, B, S)
}

optim_zipln_Omega_diagonal <- function(M, X, B, S) {
.Call('_PLNmodels_optim_zipln_Omega_diagonal', PACKAGE = 'PLNmodels', M, X, B, S)
}

optim_zipln_B_dense <- function(M, X) {
.Call('_PLNmodels_optim_zipln_B_dense', PACKAGE = 'PLNmodels', M, X)
}

optim_zipln_zipar_covar <- function(R, init_B0, X0, configuration) {
.Call('_PLNmodels_optim_zipln_zipar_covar', PACKAGE = 'PLNmodels', R, init_B0, X0, configuration)
}

optim_zipln_R <- function(Y, X, O, M, S, Pi) {
.Call('_PLNmodels_optim_zipln_R', PACKAGE = 'PLNmodels', Y, X, O, M, S, Pi)
}

optim_zipln_M <- function(init_M, Y, X, O, R, S, B, Omega, configuration) {
.Call('_PLNmodels_optim_zipln_M', PACKAGE = 'PLNmodels', init_M, Y, X, O, R, S, B, Omega, configuration)
}

optim_zipln_S <- function(init_S, O, M, R, B, diag_Omega, configuration) {
.Call('_PLNmodels_optim_zipln_S', PACKAGE = 'PLNmodels', init_S, O, M, R, B, diag_Omega, configuration)
}

cpp_test_packing <- function() {
.Call('_PLNmodels_cpp_test_packing', PACKAGE = 'PLNmodels')
}
Expand Down
128 changes: 128 additions & 0 deletions R/ZIPLN.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#' Zero Inflated Poisson lognormal model
#'
#' Fit the multivariate Zero Inflated Poisson lognormal model with a variational algorithm. Use the (g)lm syntax for model specification (covariates, offsets, subset).
#'
#' @inheritParams PLN
#' @param control a list-like structure for controlling the optimization, with default generated by [ZIPLN_param()]. See the associated documentation
#' for details.
#' @param zi a character describing the model used for zero inflation, either of
#' - "single" (default, one parameter shared by all counts)
#' - "col" (one parameter per variable / feature)
#' - "row" (one parameter per sample / individual).
#' If covariates are specified in the formula RHS (see details) this parameter is ignored.
#'
#' @details
#' Covariates for the Zero-Inflation parameter (using a logistic regression model) can be specified in the formula RHS using the pipe
#' (`~ PLN effect | ZI effect`) to separate covariates for the PLN part of the model from those for the Zero-Inflation part.
#' Note that different covariates can be used for each part.
#'
#' @return an R6 object with class [`ZIPLNfit`]
#'
#' @rdname ZIPLN
#' @include ZIPLNfit-class.R
#' @examples
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' myPLN <- PLN(Abundance ~ 1, data = trichoptera)
#' ## Use different models for zero-inflation...
#' myZIPLN_single <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "single")
#' myZIPLN_row <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "row")
#' myZIPLN_col <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "col")
#' ## ...including logistic regression on covariates
#' myZIPLN_covar <- ZIPLN(Abundance ~ 1 | 1 + Wind, data = trichoptera)
#' dplyr::bind_rows(
#' myPLN$criteria,
#' myZIPLN_single$criteria,
#' myZIPLN_row$criteria,
#' myZIPLN_col$criteria,
#' myZIPLN_covar$criteria
#' )
#' @seealso The class [`ZIPLNfit`]
#' @importFrom stats model.frame model.matrix model.response model.offset terms as.formula
#' @export
ZIPLN <- function(formula, data, subset, zi = c("single", "row", "col"), control = ZIPLN_param()) {

## extract the data matrices and weights
args <- extract_model_zi(match.call(expand.dots = FALSE), parent.frame())
control$ziparam <- ifelse((args$zicovar), "covar", match.arg(zi))

## initialization
if (control$trace > 0) cat("\n Initialization...")
myPLN <- switch(control$covariance,
"diagonal" = ZIPLNfit_diagonal$new(args$Y , list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
"spherical" = ZIPLNfit_spherical$new(args$Y, list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
"fixed" = ZIPLNfit_fixed$new(args$Y , list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
"sparse" = ZIPLNfit_sparse$new(args$Y , list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
ZIPLNfit$new(args$Y, list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control)) # default: full covariance

## optimization
if (control$trace > 0) cat("\n Adjusting a ZI-PLN model with",
control$covariance,"covariance model and",
control$ziparam, "specific parameter(s) in Zero inflation component.")
myPLN$optimize(args$Y, list(PLN = args$X, ZI = args$X0), args$O, args$w, control$config_optim)

if (control$trace > 0) cat("\n DONE!\n")
myPLN
}

## -----------------------------------------------------------------
## Series of setter to default parameters for user's main functions

#' Control of a ZIPLN fit
#'
#' Helper to define list of parameters to control the PLN fit. All arguments have defaults.
#'
#' @inheritParams PLN_param
#' @param penalty a user-defined penalty to sparsify the residual covariance. Defaults to 0 (no sparsity).
#' @return list of parameters used during the fit and post-processing steps
#'
#' @inherit PLN_param details
#' @details See [PLN_param()] for a full description of the generic optimization parameters. ZIPLN_param() also has two additional parameters controlling the optimization due
#' the inner-outer loop structure of the optimizer:
#' * "ftol_out" outer solver stops when an optimization step changes the objective function by less than `ftol_out` multiplied by the absolute value of the parameter. Default is 1e-8
#' * "maxit_out" outer solver stops when the number of iteration exceeds `maxit_out`. Default is 100
#'
#' @export
ZIPLN_param <- function(
backend = c("nlopt"),
trace = 1,
covariance = c("full", "diagonal", "spherical", "fixed", "sparse"),
Omega = NULL,
penalty = 0,
config_post = list(),
config_optim = list(),
inception = NULL # pretrained ZIPLNfit used as initialization
) {

covariance <- match.arg(covariance)
if (covariance == "fixed") stopifnot("Omega must be provied for fixed covariance" = inherits(Omega, "matrix") | inherits(Omega, "Matrix")) |> try()
if (inherits(Omega, "matrix") | inherits(Omega, "Matrix")) covariance <- "fixed"
if (covariance == "sparse") stopifnot("You should provide a positive penalty when chosing 'sparse' covariance" = penalty > 0) |> try()
if (penalty > 0) covariance <- "sparse"
if (!is.null(inception)) stopifnot(isZIPLNfit(inception))

## post-treatment config
config_pst <- config_post_default_PLN
config_pst[names(config_post)] <- config_post
config_pst$trace <- trace

## optimization config
stopifnot(backend %in% c("nlopt"))
stopifnot(config_optim$algorithm %in% available_algorithms_nlopt)
config_opt <- config_default_nlopt
config_opt$trace <- trace
config_opt$ftol_out <- 1e-6
config_opt$maxit_out <- 100
config_opt[names(config_optim)] <- config_optim

structure(list(
backend = backend ,
trace = trace ,
covariance = covariance,
Omega = Omega ,
penalty = penalty ,
config_post = config_pst,
config_optim = config_opt,
inception = inception), class = "PLNmodels_param")

}
Loading

0 comments on commit d3cd585

Please sign in to comment.