Skip to content

Commit

Permalink
Made sure code is robust to having different precisions for ParticleR…
Browse files Browse the repository at this point in the history
…eal and Real.
  • Loading branch information
stevenhofmeyr committed Nov 7, 2024
1 parent 02b0290 commit 3fb168f
Show file tree
Hide file tree
Showing 13 changed files with 95 additions and 70 deletions.
2 changes: 1 addition & 1 deletion UrbanPop-scripts/extract_urbanpop_feather.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def print_header(df):
#include <sstream>
using std::string;
using float32_t = amrex::Real;
using float32_t = amrex::ParticleReal;
namespace UrbanPop {{
Expand Down
54 changes: 32 additions & 22 deletions src/AgentContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,13 @@ void AgentContainer::moveAgentsToWork ()
if (!inHospital(ip, ptd)) {
ParticleType& p = pstruct[ip];
if (is_census) { // using census data
p.pos(0) = (work_i_ptr[ip] + 0.5_prt)*dx[0];
p.pos(1) = (work_j_ptr[ip] + 0.5_prt)*dx[1];
p.pos(0) = static_cast<ParticleReal>((work_i_ptr[ip] + 0.5_rt) * dx[0]);
p.pos(1) = static_cast<ParticleReal>((work_j_ptr[ip] + 0.5_rt) * dx[1]);
} else {
Real lng, lat;
(*grid_to_lnglat_ptr)(work_i_ptr[ip], work_j_ptr[ip], lng, lat);
p.pos(0) = lng;
p.pos(1) = lat;
p.pos(0) = static_cast<ParticleReal>(lng);
p.pos(1) = static_cast<ParticleReal>(lat);
}
}
});
Expand Down Expand Up @@ -234,13 +234,13 @@ void AgentContainer::moveAgentsToHome ()
if (!inHospital(ip, ptd)) {
ParticleType& p = pstruct[ip];
if (is_census) { // using census data
p.pos(0) = (home_i_ptr[ip] + 0.5_prt) * dx[0];
p.pos(1) = (home_j_ptr[ip] + 0.5_prt) * dx[1];
p.pos(0) = static_cast<ParticleReal>((home_i_ptr[ip] + 0.5_rt) * dx[0]);
p.pos(1) = static_cast<ParticleReal>((home_j_ptr[ip] + 0.5_rt) * dx[1]);
} else {
Real lng, lat;
(*grid_to_lnglat_ptr)(home_i_ptr[ip], home_j_ptr[ip], lng, lat);
p.pos(0) = lng;
p.pos(1) = lat;
p.pos(0) = static_cast<ParticleReal>(lng);
p.pos(1) = static_cast<ParticleReal>(lat);
}
}
});
Expand Down Expand Up @@ -395,8 +395,8 @@ void AgentContainer::setAirTravel (const iMultiFab& unit_mf, AirTravelFlow& air,
int unit = unit_arr(home_i_ptr[i], home_j_ptr[i], 0);
int orgAirport= assigned_airport_ptr[unit];
int destAirport=-1;
float lowProb=0.0;
float random= amrex::Random(engine);
Real lowProb = 0.0_rt;
Real random = amrex::Random(engine);
//choose a destination airport for the agent (number of airports is often small, so let's visit in sequential order)
for(int idx= dest_airports_offset_ptr[orgAirport]; idx<dest_airports_offset_ptr[orgAirport+1]; idx++){
float hiProb= dest_airports_prob_ptr[idx];
Expand All @@ -408,7 +408,7 @@ void AgentContainer::setAirTravel (const iMultiFab& unit_mf, AirTravelFlow& air,
}
if(destAirport >=0){
int destUnit=-1;
float random1= amrex::Random(engine);
Real random1= amrex::Random(engine);
int low=arrivalUnits_offset_ptr[destAirport], high=arrivalUnits_offset_ptr[destAirport+1];
if(high-low<=16){
//this sequential algo. is very slow when we have to go through hundreds or thoudsands of units to select a destination
Expand Down Expand Up @@ -480,13 +480,13 @@ void AgentContainer::returnRandomTravel ()
ParticleType& p = pstruct[i];
random_travel_ptr[i] = -1;
if (is_census) {
p.pos(0) = (home_i_ptr[i] + 0.5_prt) * dx[0];
p.pos(1) = (home_j_ptr[i] + 0.5_prt) * dx[1];
p.pos(0) = static_cast<ParticleReal>((home_i_ptr[i] + 0.5_rt) * dx[0]);
p.pos(1) = static_cast<ParticleReal>((home_j_ptr[i] + 0.5_rt) * dx[1]);
} else {
Real lng, lat;
(*grid_to_lnglat_ptr)(home_i_ptr[i], home_j_ptr[i], lng, lat);
p.pos(0) = lng;
p.pos(1) = lat;
p.pos(0) = static_cast<ParticleReal>(lng);
p.pos(1) = static_cast<ParticleReal>(lat);
}
}
});
Expand All @@ -508,6 +508,9 @@ void AgentContainer::returnAirTravel ()
auto& plev = GetParticles(lev);
const auto dx = Geom(lev).CellSizeArray();

bool is_census = (ic_type == ExaEpi::ICType::Census);
auto grid_to_lnglat_ptr = &grid_to_lnglat;

#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
Expand All @@ -527,8 +530,15 @@ void AgentContainer::returnAirTravel ()
if (air_travel_ptr[i] >= 0) {
ParticleType& p = pstruct[i];
air_travel_ptr[i] = -1;
p.pos(0) = (home_i_ptr[i] + 0.5_prt) * dx[0];
p.pos(1) = (home_j_ptr[i] + 0.5_prt) * dx[1];
if (is_census) { // using census data
p.pos(0) = static_cast<ParticleReal>((home_i_ptr[i] + 0.5_rt) * dx[0]);
p.pos(1) = static_cast<ParticleReal>((home_j_ptr[i] + 0.5_rt) * dx[1]);
} else {
Real lng, lat;
(*grid_to_lnglat_ptr)(home_i_ptr[i], home_j_ptr[i], lng, lat);
p.pos(0) = static_cast<ParticleReal>(lng);
p.pos(1) = static_cast<ParticleReal>(lat);
}
}
});
}
Expand Down Expand Up @@ -577,13 +587,13 @@ void AgentContainer::updateStatus ( MFPtrVec& a_disease_stats /*!< Community-wis
if (inHospital(ip, ptd)) {
ParticleType& p = pstruct[ip];
if (is_census) {
p.pos(0) = (hosp_i_ptr[ip] + 0.5_prt) * dx[0];
p.pos(1) = (hosp_j_ptr[ip] + 0.5_prt) * dx[1];
p.pos(0) = static_cast<ParticleReal>((hosp_i_ptr[ip] + 0.5_prt) * dx[0]);
p.pos(1) = static_cast<ParticleReal>((hosp_j_ptr[ip] + 0.5_prt) * dx[1]);
} else {
Real lng, lat;
(*grid_to_lnglat_ptr)(hosp_i_ptr[ip], hosp_j_ptr[ip], lng, lat);
p.pos(0) = lng;
p.pos(1) = lat;
p.pos(0) = static_cast<ParticleReal>(lng);
p.pos(1) = static_cast<ParticleReal>(lat);
}
}
});
Expand Down Expand Up @@ -700,7 +710,7 @@ void AgentContainer::infectAgents ()
amrex::ParallelForRNG( np,
[=] AMREX_GPU_DEVICE (int i, amrex::RandomEngine const& engine) noexcept
{
prob_ptr[i] = 1.0_rt - prob_ptr[i];
prob_ptr[i] = 1.0_prt - prob_ptr[i];
if ( status_ptr[i] == Status::never ||
status_ptr[i] == Status::susceptible ) {
if (amrex::Random(engine) < prob_ptr[i]) {
Expand Down
8 changes: 4 additions & 4 deletions src/CensusData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,15 @@ void CensusData::initAgents (AgentContainer& pc, /*!< Agents */
}
}

agent.pos(0) = (i + 0.5_rt)*dx[0];
agent.pos(1) = (j + 0.5_rt)*dx[1];
agent.pos(0) = static_cast<ParticleReal>((i + 0.5_rt) * dx[0]);
agent.pos(1) = static_cast<ParticleReal>((j + 0.5_rt) * dx[1]);
agent.id() = pid+ip;
agent.cpu() = my_proc;

for (int d = 0; d < n_disease; d++) {
status_ptrs[d][ip] = 0;
counter_ptrs[d][ip] = 0.0_rt;
timer_ptrs[d][ip] = 0.0_rt;
counter_ptrs[d][ip] = 0.0_prt;
timer_ptrs[d][ip] = 0.0_prt;
}
age_group_ptr[ip] = age_group;
family_ptr[ip] = family_id_start + (ii / family_size);
Expand Down
20 changes: 12 additions & 8 deletions src/DiseaseParm.H
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "AgentDefinitions.H"

using amrex::Real;
using amrex::ParticleReal;
using amrex::Vector;

struct CaseTypes
Expand Down Expand Up @@ -160,18 +161,21 @@ struct DiseaseParm
/*! \brief Set this agent to infected status, and initialize disease periods. */
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
void setInfected ( int* status,
amrex::Real* counter,
amrex::Real* latent_period,
amrex::Real* infectious_period,
amrex::Real* incubation_period,
amrex::ParticleReal* counter,
amrex::ParticleReal* latent_period,
amrex::ParticleReal* infectious_period,
amrex::ParticleReal* incubation_period,
amrex::RandomEngine const& engine,
const DiseaseParm* lparm)
{
*status = Status::infected;
*counter = amrex::Real(0);
*latent_period = amrex::RandomNormal(lparm->latent_length_mean, lparm->latent_length_std, engine);
*infectious_period = amrex::RandomNormal(lparm->infectious_length_mean, lparm->infectious_length_std, engine);
*incubation_period = amrex::RandomNormal(lparm->incubation_length_mean, lparm->incubation_length_std, engine);
*counter = ParticleReal(0);
*latent_period =
static_cast<ParticleReal>(amrex::RandomNormal(lparm->latent_length_mean, lparm->latent_length_std, engine));
*infectious_period =
static_cast<ParticleReal>(amrex::RandomNormal(lparm->infectious_length_mean, lparm->infectious_length_std, engine));
*incubation_period =
static_cast<ParticleReal>(amrex::RandomNormal(lparm->incubation_length_mean, lparm->incubation_length_std, engine));
if (*latent_period < 0) { *latent_period = amrex::Real(0);}
if (*infectious_period < 0) { *infectious_period = amrex::Real(0);}
if (*incubation_period < 0) { *incubation_period = amrex::Real(0);}
Expand Down
11 changes: 6 additions & 5 deletions src/DiseaseStatus.H
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ void DiseaseStatus<AC,ACT,ACTD,A>::updateAgents(AC& a_agents, /*!< Agent contain
return;
}
else if (status_ptr[i] == Status::immune) {
counter_ptr[i] -= 1.0_rt;
if (counter_ptr[i] < 0.0_rt) {
counter_ptr[i] = 0.0_rt;
timer_ptr[i] = 0.0_rt;
counter_ptr[i] -= 1.0_prt;
if (counter_ptr[i] < 0.0_prt) {
counter_ptr[i] = 0.0_prt;
timer_ptr[i] = 0.0_prt;
status_ptr[i] = Status::susceptible;
return;
}
Expand Down Expand Up @@ -188,7 +188,8 @@ void DiseaseStatus<AC,ACT,ACTD,A>::updateAgents(AC& a_agents, /*!< Agent contain
else if (!inHospital(i,ptd)) {
if (counter_ptr[i] >= (latent_period_ptr[i] + infectious_period_ptr[i])) {
status_ptr[i] = Status::immune;
counter_ptr[i] = RandomNormal(immune_length_mean, immune_length_std, engine);
counter_ptr[i] =
static_cast<ParticleReal>(RandomNormal(immune_length_mean, immune_length_std, engine));
symptomatic_ptr[i] = SymptomStatus::presymptomatic;
withdrawn_ptr[i] = 0;
}
Expand Down
15 changes: 8 additions & 7 deletions src/HospitalModel.H
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void HospitalModel<PCType, PTDType, PType>::treatAgents(PCType& a_agents, /*!< A

AMREX_ALWAYS_ASSERT(status_ptrs[d][i] == Status::infected);
// decrement days in hospital
timer_ptrs[d][i] -= 1.0_rt;
timer_ptrs[d][i] -= 1.0_prt;
if (timer_ptrs[d][i] == 0) {
// finished hospitalization period
flag_status_ptr[i] = DiseaseStats::hospitalization + 1;
Expand All @@ -169,10 +169,11 @@ void HospitalModel<PCType, PTDType, PType>::treatAgents(PCType& a_agents, /*!< A
} else {
// If alive, hospitalized patient recovers
status_ptrs[d][i] = Status::immune;
counter_ptrs[d][i] = RandomNormal(immune_length_mean, immune_length_std, engine);
counter_ptrs[d][i] =
static_cast<ParticleReal>(RandomNormal(immune_length_mean, immune_length_std, engine));
symptomatic_ptrs[d][i] = SymptomStatus::presymptomatic;
withdrawn_ptr[i] = 0;
timer_ptrs[d][i] = 0.0_rt;
timer_ptrs[d][i] = 0.0_prt;
}
}
});
Expand Down Expand Up @@ -207,13 +208,13 @@ void HospitalModel<PCType, PTDType, PType>::treatAgents(PCType& a_agents, /*!< A
withdrawn_ptr[i] = 0;
PType& p = pstruct[i];
if (is_census) {
p.pos(0) = (home_i_ptr[i] + 0.5_prt) * dx[0];
p.pos(1) = (home_j_ptr[i] + 0.5_prt) * dx[1];
p.pos(0) = static_cast<ParticleReal>((home_i_ptr[i] + 0.5_rt) * dx[0]);
p.pos(1) = static_cast<ParticleReal>((home_j_ptr[i] + 0.5_rt) * dx[1]);
} else {
Real lng, lat;
(*grid_to_lnglat_ptr)(home_i_ptr[i], home_j_ptr[i], lng, lat);
p.pos(0) = lng;
p.pos(1) = lat;
p.pos(0) = static_cast<ParticleReal>(lng);
p.pos(1) = static_cast<ParticleReal>(lat);
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/InteractionModHome.H
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void InteractionModHome<PCType, PTDType, PType>::fastInteractHome (PCType& agent
auto prob_ptr = soa.GetRealData(RealIdx::nattribs + r0(d) + RealIdxDisease::prob).data();
auto lparm = agents.getDiseaseParameters_d(d);
auto lparm_h = agents.getDiseaseParameters_h(d);
Real scale = 1.0_prt; // TODO this should vary based on cell
Real scale = 1.0_rt; // TODO this should vary based on cell
Real infect = (1.0_rt - lparm_h->vac_eff);
// loop to count infectious agents in each group
ParallelFor(np, [=] AMREX_GPU_DEVICE (int i) noexcept {
Expand All @@ -179,8 +179,8 @@ void InteractionModHome<PCType, PTDType, PType>::fastInteractHome (PCType& agent
// and in neighborhood cluster, adjust the infected counts to avoid double-counting the overlap.
ParallelFor(np, [=] AMREX_GPU_DEVICE (int i) noexcept {
if (isSusceptible(i, ptd, d) && isHomeCandidate(i, ptd)) {
ParticleReal xmit_family_prob = 0;
ParticleReal xmit_nc_prob = 0;
Real xmit_family_prob = 0;
Real xmit_nc_prob = 0;
if (adults) {
xmit_family_prob = lparm->xmit_hh_adult[ptd.m_idata[IntIdx::age_group][i]];
xmit_nc_prob = lparm->xmit_nc_adult[ptd.m_idata[IntIdx::age_group][i]];
Expand All @@ -192,7 +192,7 @@ void InteractionModHome<PCType, PTDType, PType>::fastInteractHome (PCType& agent
AMREX_ALWAYS_ASSERT(community <= max_communities);
int family_i = community * max_family + family_ptr[i];
int num_infected_family = infected_family_d_ptr[family_i];
ParticleReal family_prob = 1.0_prt - infect * xmit_family_prob * scale;
Real family_prob = 1.0_rt - infect * xmit_family_prob * scale;
prob_ptr[i] *= static_cast<ParticleReal>(std::pow(family_prob, num_infected_family));
if (!ptd.m_idata[IntIdx::withdrawn][i]) {
int num_infected_family_not_withdrawn = infected_family_not_withdrawn_d_ptr[family_i];
Expand All @@ -201,7 +201,7 @@ void InteractionModHome<PCType, PTDType, PType>::fastInteractHome (PCType& agent
int nc = (community * max_nborhood + nborhood_ptr[i]) * num_ncs + cluster;
int num_infected_nc = infected_nc_d_ptr[nc] - num_infected_family_not_withdrawn;
AMREX_ALWAYS_ASSERT(num_infected_nc >= 0);
ParticleReal nc_prob = 1.0_prt - infect * xmit_nc_prob * scale;
Real nc_prob = 1.0_rt - infect * xmit_nc_prob * scale;
prob_ptr[i] *= static_cast<ParticleReal>(std::pow(nc_prob, num_infected_nc));
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/InteractionModHomeNborhood.H
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ void InteractionModHomeNborhood<PCType, PTDType, PType>::fastInteractHomeNborhoo
auto prob_ptr = soa.GetRealData(RealIdx::nattribs + r0(d) + RealIdxDisease::prob).data();
auto lparm = agents.getDiseaseParameters_d(d);
auto lparm_h = agents.getDiseaseParameters_h(d);
Real scale = 1.0_prt; // TODO this should vary based on cell
Real infect = (1.0_rt - lparm_h->vac_eff);
Real scale = 1.0_rt; // TODO this should vary based on cell
Real infect = 1.0_rt - lparm_h->vac_eff;

ParallelFor(np, [=] AMREX_GPU_DEVICE (int i) noexcept {
if (isInfectious(i, ptd, d) && isCandidate(i, ptd)) {
Expand All @@ -152,9 +152,9 @@ void InteractionModHomeNborhood<PCType, PTDType, PType>::fastInteractHomeNborhoo
int num_infected_nborhood = infected_nborhood_d_ptr[community * max_nborhood + nborhood];
int num_infected_community = infected_community_d_ptr[community];
AMREX_ALWAYS_ASSERT(num_infected_community >= num_infected_nborhood);
ParticleReal comm_prob = 1.0_prt - infect * lparm->xmit_comm[ptd.m_idata[IntIdx::age_group][i]] * scale;
Real comm_prob = 1.0_prt - infect * lparm->xmit_comm[ptd.m_idata[IntIdx::age_group][i]] * scale;
prob_ptr[i] *= static_cast<ParticleReal>(std::pow(comm_prob, num_infected_community - num_infected_nborhood));
ParticleReal nborhood_prob = 1.0_prt - infect * lparm->xmit_hood[ptd.m_idata[IntIdx::age_group][i]] * scale;
Real nborhood_prob = 1.0_prt - infect * lparm->xmit_hood[ptd.m_idata[IntIdx::age_group][i]] * scale;
prob_ptr[i] *= static_cast<ParticleReal>(std::pow(nborhood_prob, num_infected_nborhood));
}
});
Expand Down
Loading

0 comments on commit 3fb168f

Please sign in to comment.