Skip to content

Commit

Permalink
Merge pull request #1993 from vqd8a/supernodal-sptrsv-copyD2H
Browse files Browse the repository at this point in the history
Only deep_copy from device to host if supernodal sptrsv algorithms are used
  • Loading branch information
lucbv authored Oct 10, 2023
2 parents 66cc282 + c1e6c76 commit 0aac17f
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2908,10 +2908,9 @@ void lower_tri_solve(TriSolveHandle &thandle, const RowMapType row_map,
// Keep this a host View, create device version and copy to back to host
// during scheduling This requires making sure the host view in the handle is
// properly updated after the symbolic phase
auto nodes_per_level = thandle.get_nodes_per_level();
auto hnodes_per_level = thandle.get_host_nodes_per_level();
auto nodes_grouped_by_level = thandle.get_nodes_grouped_by_level();
auto nodes_grouped_by_level_host = thandle.get_host_nodes_grouped_by_level();
auto nodes_per_level = thandle.get_nodes_per_level();
auto hnodes_per_level = thandle.get_host_nodes_per_level();
auto nodes_grouped_by_level = thandle.get_nodes_grouped_by_level();

#if defined(KOKKOSKERNELS_ENABLE_SUPERNODAL_SPTRSV)
using namespace KokkosSparse::Experimental;
Expand All @@ -2920,15 +2919,25 @@ void lower_tri_solve(TriSolveHandle &thandle, const RowMapType row_map,
using integer_view_host_t = typename TriSolveHandle::integer_view_host_t;
using scalar_t = typename ValuesType::non_const_value_type;
using range_type = Kokkos::pair<int, int>;
using row_map_host_view_t = Kokkos::View<size_type *, Kokkos::HostSpace>;

row_map_host_view_t row_map_host;

const scalar_t zero(0.0);
const scalar_t one(1.0);
Kokkos::deep_copy(nodes_grouped_by_level_host, nodes_grouped_by_level);

Kokkos::View<size_type *, Kokkos::HostSpace> row_map_host(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "host rowmap"),
row_map.extent(0));
Kokkos::deep_copy(row_map_host, row_map);
auto nodes_grouped_by_level_host = thandle.get_host_nodes_grouped_by_level();

if (thandle.get_algorithm() == SPTRSVAlgorithm::SUPERNODAL_NAIVE ||
thandle.get_algorithm() == SPTRSVAlgorithm::SUPERNODAL_ETREE ||
thandle.get_algorithm() == SPTRSVAlgorithm::SUPERNODAL_DAG) {
Kokkos::deep_copy(nodes_grouped_by_level_host, nodes_grouped_by_level);

row_map_host = row_map_host_view_t(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "host rowmap"),
row_map.extent(0));
Kokkos::deep_copy(row_map_host, row_map);
}

// inversion options
const bool invert_diagonal = thandle.get_invert_diagonal();
Expand Down Expand Up @@ -3293,19 +3302,26 @@ void upper_tri_solve(TriSolveHandle &thandle, const RowMapType row_map,
using integer_view_t = typename TriSolveHandle::integer_view_t;
using integer_view_host_t = typename TriSolveHandle::integer_view_host_t;
using scalar_t = typename ValuesType::non_const_value_type;
using range_type = Kokkos::pair<int, int>;
using row_map_host_view_t = Kokkos::View<size_type *, Kokkos::HostSpace>;

using range_type = Kokkos::pair<int, int>;
row_map_host_view_t row_map_host;

const scalar_t zero(0.0);
const scalar_t one(1.0);

auto nodes_grouped_by_level_host = thandle.get_host_nodes_grouped_by_level();
Kokkos::deep_copy(nodes_grouped_by_level_host, nodes_grouped_by_level);

Kokkos::View<size_type *, Kokkos::HostSpace> row_map_host(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "host rowmap"),
row_map.extent(0));
Kokkos::deep_copy(row_map_host, row_map);
if (thandle.get_algorithm() == SPTRSVAlgorithm::SUPERNODAL_NAIVE ||
thandle.get_algorithm() == SPTRSVAlgorithm::SUPERNODAL_ETREE ||
thandle.get_algorithm() == SPTRSVAlgorithm::SUPERNODAL_DAG) {
Kokkos::deep_copy(nodes_grouped_by_level_host, nodes_grouped_by_level);

row_map_host = row_map_host_view_t(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "host rowmap"),
row_map.extent(0));
Kokkos::deep_copy(row_map_host, row_map);
}

// supernode sizes
const int *supercols = thandle.get_supercols();
Expand Down

0 comments on commit 0aac17f

Please sign in to comment.