Skip to content

Commit

Permalink
split scratch
Browse files Browse the repository at this point in the history
  • Loading branch information
jonahm-LANL committed Jan 13, 2025
1 parent 2d98b03 commit 4d0ef18
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions test/test_pte.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ int main(int argc, char *argv[]) {
EOSAccessor eos(eos_v);

// scratch required for PTE solver
auto nscratch_vars =
std::max(PTESolverRhoTRequiredScratch(NMAT), PTESolverPTRequiredScratch(NMAT));
auto nscratch_vars_rt = PTESolverRhoTRequiredScratch(NMAT);
auto nscratch_vars_pt = PTESolverPTRequiredScratch(NMAT);

// state vars
#ifdef PORTABILITY_STRATEGY_KOKKOS
Expand All @@ -78,7 +78,8 @@ int main(int argc, char *argv[]) {
RView sie_v("sie", NPTS);
RView temp_v("temp", NPTS);
RView press_v("press", NPTS);
RView scratch_v("scratch", NTRIAL * nscratch_vars);
RView scratch_v_rt("scratch", NTRIAL * nscratch_vars_rt);
RView scratch_v_pt("scratch", NTRIAL * nscratch_vars_pt);
Kokkos::View<int *, atomic_view> hist_d("histogram", HIST_SIZE);
auto rho_vh = Kokkos::create_mirror_view(rho_v);
auto vfrac_vh = Kokkos::create_mirror_view(vfrac_v);
Expand All @@ -92,7 +93,8 @@ int main(int argc, char *argv[]) {
DataBox sie_d(sie_v.data(), NTRIAL, NMAT);
DataBox temp_d(temp_v.data(), NTRIAL, NMAT);
DataBox press_d(press_v.data(), NTRIAL, NMAT);
DataBox scratch_d(scratch_v.data(), NTRIAL * nscratch_vars);
DataBox scratch_d_rt(scratch_v.data(), NTRIAL * nscratch_vars_rt);
DataBox scratch_d_pt(scratch_v.data(), NTRIAL * nscratch_vars_pt);
DataBox rho_hm(rho_vh.data(), NTRIAL, NMAT);
DataBox vfrac_hm(vfrac_vh.data(), NTRIAL, NMAT);
DataBox sie_hm(sie_vh.data(), NTRIAL, NMAT);
Expand Down Expand Up @@ -172,10 +174,13 @@ int main(int argc, char *argv[]) {
const Real Tguess =
ApproxTemperatureFromRhoMatU(NMAT, eos, rho_tot * sie_tot, rho, vfrac);

Real *scratch_rt = &scratch_d_rt(t * nscratch_vars_rt);
Real *scratch_pt = &scratch_d_pt(t * nscratch_vars_pt);

auto method =
PTESolverRhoT<EOSAccessor, Indexer2D<decltype(rho_d)>, decltype(lambda)>(
NMAT, eos, 1.0, sie_tot, rho, vfrac, sie, temp, press, lambda,
&scratch_d(t * nscratch_vars), Tguess);
scratch_rt, Tguess);
auto status = PTESolver(method);

if (status.converged) {
Expand All @@ -186,7 +191,7 @@ int main(int argc, char *argv[]) {
auto method2 =
PTESolverPT<EOSAccessor, Indexer2D<decltype(rho_d)>, decltype(lambda)>(
NMAT, eos, 1.0, sie_tot, rho, vfrac, sie, temp, press, lambda,
&scratch_d(t * nscratch_vars), Tguess);
scratch_pt, Tguess);
status = PTESolver(method2);

if (status.converged) {
Expand Down

0 comments on commit 4d0ef18

Please sign in to comment.