Skip to content

Commit

Permalink
some advances in cpp version onf sandwich
Browse files Browse the repository at this point in the history
  • Loading branch information
jchiquet committed Nov 21, 2024
1 parent 46fd6c3 commit e7f78d4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 5 deletions.
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ cpp_test_packing <- function() {
.Call('_PLNmodels_cpp_test_packing', PACKAGE = 'PLNmodels')
}

get_sandwich_variance_B <- function(Y, X, A, S, Sigma, Diag_Omega) {
.Call('_PLNmodels_get_sandwich_variance_B', PACKAGE = 'PLNmodels', Y, X, A, S, Sigma, Diag_Omega)
}

17 changes: 17 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,22 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// get_sandwich_variance_B
arma::mat get_sandwich_variance_B(const arma::mat& Y, const arma::mat& X, const arma::mat& A, const arma::mat& S, const arma::mat& Sigma, const arma::vec& Diag_Omega);
RcppExport SEXP _PLNmodels_get_sandwich_variance_B(SEXP YSEXP, SEXP XSEXP, SEXP ASEXP, SEXP SSEXP, SEXP SigmaSEXP, SEXP Diag_OmegaSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::mat& >::type Y(YSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type A(ASEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type S(SSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type Sigma(SigmaSEXP);
Rcpp::traits::input_parameter< const arma::vec& >::type Diag_Omega(Diag_OmegaSEXP);
rcpp_result_gen = Rcpp::wrap(get_sandwich_variance_B(Y, X, A, S, Sigma, Diag_Omega));
return rcpp_result_gen;
END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_PLNmodels_cpp_test_nlopt", (DL_FUNC) &_PLNmodels_cpp_test_nlopt, 0},
Expand All @@ -354,6 +370,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_PLNmodels_optim_zipln_M", (DL_FUNC) &_PLNmodels_optim_zipln_M, 9},
{"_PLNmodels_optim_zipln_S", (DL_FUNC) &_PLNmodels_optim_zipln_S, 7},
{"_PLNmodels_cpp_test_packing", (DL_FUNC) &_PLNmodels_cpp_test_packing, 0},
{"_PLNmodels_get_sandwich_variance_B", (DL_FUNC) &_PLNmodels_get_sandwich_variance_B, 6},
{NULL, NULL, 0}
};

Expand Down
24 changes: 19 additions & 5 deletions src/utils-R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,31 @@ arma::mat get_sandwich_variance_B(
const arma::vec & Diag_Omega
) {

arma::uword n = Y.n_rows ;
arma::uword p = Y.n_cols ;
arma::uword d = X.n_cols ;

arma::mat get_iCnB = [&A, &S, &D_omega, &Sigma](
) {
auto get_iCnB = [&A, &S, &Diag_Omega, &Sigma](arma::uword i) {
arma::vec a = A.row(i) ;
arma::vec s = S.row(i) ;
arma::mat D = diagmat(pow(a, -1) + pow(s, 4) / (1 + pow(s,2) * (a + Diag_Omega))) ;

return arma::inv_sympd(Sigma + D);
};

return ;
arma::mat YmA = Y - A ;
arma::mat Cn = arma::zeros(d*p, d*p) ;
arma::mat Dn = arma::zeros(d*p, d*p) ;
for (int i=0; i<n; i++) {
arma::mat xxt_i = X.col(i) * X.col(i).t() ;
arma::mat yyt_i = YmA.col(i) * YmA.col(i).t() ;
Cn = Cn - arma::kron(get_iCnB(i), xxt_i) / n ;
Dn = Dn + arma::kron(yyt_i, xxt_i) / n ;
}

arma::mat Cn_inv = arma::inv_sympd(Cn) ;


return ;
return (Cn_inv * Dn * Cn_inv) / n ;
}

// vcov_sandwich_B = function(Y, X) {
Expand Down

0 comments on commit e7f78d4

Please sign in to comment.