diff --git a/test/test_pte.cpp b/test/test_pte.cpp index 13656a2f4b..52956a1416 100644 --- a/test/test_pte.cpp +++ b/test/test_pte.cpp @@ -66,8 +66,8 @@ int main(int argc, char *argv[]) { EOSAccessor eos(eos_v); // scratch required for PTE solver - auto nscratch_vars = - std::max(PTESolverRhoTRequiredScratch(NMAT), PTESolverPTRequiredScratch(NMAT)); + auto nscratch_vars_rt = PTESolverRhoTRequiredScratch(NMAT); + auto nscratch_vars_pt = PTESolverPTRequiredScratch(NMAT); // state vars #ifdef PORTABILITY_STRATEGY_KOKKOS @@ -78,7 +78,8 @@ int main(int argc, char *argv[]) { RView sie_v("sie", NPTS); RView temp_v("temp", NPTS); RView press_v("press", NPTS); - RView scratch_v("scratch", NTRIAL * nscratch_vars); + RView scratch_v_rt("scratch", NTRIAL * nscratch_vars_rt); + RView scratch_v_pt("scratch", NTRIAL * nscratch_vars_pt); Kokkos::View hist_d("histogram", HIST_SIZE); auto rho_vh = Kokkos::create_mirror_view(rho_v); auto vfrac_vh = Kokkos::create_mirror_view(vfrac_v); @@ -92,7 +93,8 @@ int main(int argc, char *argv[]) { DataBox sie_d(sie_v.data(), NTRIAL, NMAT); DataBox temp_d(temp_v.data(), NTRIAL, NMAT); DataBox press_d(press_v.data(), NTRIAL, NMAT); - DataBox scratch_d(scratch_v.data(), NTRIAL * nscratch_vars); + DataBox scratch_d_rt(scratch_v.data(), NTRIAL * nscratch_vars_rt); + DataBox scratch_d_pt(scratch_v.data(), NTRIAL * nscratch_vars_pt); DataBox rho_hm(rho_vh.data(), NTRIAL, NMAT); DataBox vfrac_hm(vfrac_vh.data(), NTRIAL, NMAT); DataBox sie_hm(sie_vh.data(), NTRIAL, NMAT); @@ -172,10 +174,13 @@ int main(int argc, char *argv[]) { const Real Tguess = ApproxTemperatureFromRhoMatU(NMAT, eos, rho_tot * sie_tot, rho, vfrac); + Real *scratch_rt = &scratch_d_rt(t * nscratch_vars_rt); + Real *scratch_pt = &scratch_d_pt(t * nscratch_vars_pt); + auto method = PTESolverRhoT, decltype(lambda)>( NMAT, eos, 1.0, sie_tot, rho, vfrac, sie, temp, press, lambda, - &scratch_d(t * nscratch_vars), Tguess); + scratch_rt, Tguess); auto status = PTESolver(method); if (status.converged) { @@ -186,7 +191,7 @@ int main(int argc, char *argv[]) { auto method2 = PTESolverPT, decltype(lambda)>( NMAT, eos, 1.0, sie_tot, rho, vfrac, sie, temp, press, lambda, - &scratch_d(t * nscratch_vars), Tguess); + scratch_pt, Tguess); status = PTESolver(method2); if (status.converged) {