Skip to content

Commit

Permalink
EXC Only Interface (#121)
Browse files Browse the repository at this point in the history
* Add RKS EXC-only integrator interface + UTs

* Add UKS/GKS EXC-only interfaces

* Make EXC-only path through GPU code work

* Explicitly disable UKS/GKS EXC only

* Update EXC Only with `master`: Squashed commit of the following:

commit 62ac01fab0767f13a7424bcea8a940c92e3aa7e2
Author: David Williams-Young <[email protected]>
Date:   Fri May 31 13:36:46 2024 -0700

    Fix EXC-only path for host integrators, UTs pass

commit 57f305c39ca4c858d5050c5dfc32c3b1ffab2e82
Merge: 1c844b4 92bbbe2
Author: David Williams-Young <[email protected]>
Date:   Fri May 31 12:28:27 2024 -0700

    Merge branch 'master' into merge_master

commit 92bbbe2
Author: David Williams-Young <[email protected]>
Date:   Fri May 31 10:50:04 2024 -0700

    actions/checkout@v3 -> actions/checkout@v4 (#131)

    Bump actions/checkout version to quiet GHA warnings

commit b61ee7a
Author: mikovtun <[email protected]>
Date:   Fri May 31 10:21:21 2024 -0700

    Spellcheck Error Messages (#129)

    * Fixed spelling in error messages and homogenized capitalization

    * Update include/gauxc/load_balancer.hpp

    Co-authored-by: David Williams-Young <[email protected]>

    ---------

    Co-authored-by: David Williams-Young <[email protected]>

commit 905c36a
Author: Ajay Panyala <[email protected]>
Date:   Fri May 31 09:07:08 2024 -0700

    cutlass requires cuda CC >= 8.0 (#130)

commit b9c2161
Author: David Williams-Young <[email protected]>
Date:   Tue May 21 15:13:43 2024 -0700

    Refactor ShellBatched Integrators (#127)

    * Refactor of RKS/UKS/GKS Host drivers to reduce code replication + various QoL

    * ShellBatched Refactor - Host compiles and tests, device untested

    * Fix new ShellBatched for Device, add additional std::future::get to proagate exceptions

    * Update copyright year on old files, shellbatched -> shell_batched

    * document why the extra std::future::get is there

commit 6a8f4bf
Author: David Williams-Young <[email protected]>
Date:   Wed May 8 09:32:51 2024 -0700

    Add Runtime Environment Query Functions (#126)

    * Add runtime environment query functions

    * Update gauxc_config.hpp.in

commit cf6b85c
Author: Ajay Panyala <[email protected]>
Date:   Tue May 7 09:48:17 2024 -0700

    install nccl module file

commit 3e44fcc
Author: David Williams-Young <[email protected]>
Date:   Mon May 6 16:01:48 2024 -0700

    Make CUTLASS a build-only dependency

* Add missing file

* Remove old shellbatched file
  • Loading branch information
wavefunction91 authored May 31, 2024
1 parent 92bbbe2 commit be9ac8d
Show file tree
Hide file tree
Showing 24 changed files with 617 additions and 60 deletions.
7 changes: 7 additions & 0 deletions include/gauxc/xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,20 @@ class XCIntegrator {
XCIntegrator( XCIntegrator&& ) noexcept;

value_type integrate_den( const MatrixType& );

value_type eval_exc( const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );
value_type eval_exc( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );
value_type eval_exc( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );

exc_vxc_type_rks eval_exc_vxc ( const MatrixType&,
const IntegratorSettingsXC& = IntegratorSettingsXC{} );
exc_vxc_type_uks eval_exc_vxc ( const MatrixType&, const MatrixType&,
const IntegratorSettingsXC& = IntegratorSettingsXC{} );
exc_vxc_type_gks eval_exc_vxc ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&,
const IntegratorSettingsXC& = IntegratorSettingsXC{});

exc_grad_type eval_exc_grad( const MatrixType& );

exx_type eval_exx ( const MatrixType&,
const IntegratorSettingsEXX& = IntegratorSettingsEXX{} );

Expand Down
21 changes: 21 additions & 0 deletions include/gauxc/xc_integrator/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ typename XCIntegrator<MatrixType>::value_type
return pimpl_->integrate_den(P);
};

template <typename MatrixType>
typename XCIntegrator<MatrixType>::value_type
XCIntegrator<MatrixType>::eval_exc( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_exc(P, ks_settings);
}

template <typename MatrixType>
typename XCIntegrator<MatrixType>::value_type
XCIntegrator<MatrixType>::eval_exc( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_exc(Ps, Pz, ks_settings);
}

template <typename MatrixType>
typename XCIntegrator<MatrixType>::value_type
XCIntegrator<MatrixType>::eval_exc( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_exc(Ps, Pz, Py, Px, ks_settings);
}

template <typename MatrixType>
typename XCIntegrator<MatrixType>::exc_vxc_type_rks
XCIntegrator<MatrixType>::eval_exc_vxc( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
Expand Down
38 changes: 38 additions & 0 deletions include/gauxc/xc_integrator/replicated/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,44 @@ typename ReplicatedXCIntegrator<MatrixType>::value_type
return N_EL;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::value_type
ReplicatedXCIntegrator<MatrixType>::eval_exc_( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {

if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
value_type EXC;

pimpl_->eval_exc( P.rows(), P.cols(), P.data(), P.rows(), &EXC, ks_settings );

return EXC;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::value_type
ReplicatedXCIntegrator<MatrixType>::eval_exc_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {

if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
value_type EXC;

const size_t n = Ps.rows();
pimpl_->eval_exc( n, n, Ps.data(), n, Pz.data(), n, &EXC, ks_settings );

return EXC;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::value_type
ReplicatedXCIntegrator<MatrixType>::eval_exc_( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) {

if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
value_type EXC;

const size_t n = Ps.rows();
pimpl_->eval_exc( n, n, Ps.data(), n, Pz.data(), n, Py.data(), n, Px.data(), n, &EXC, ks_settings );

return EXC;
}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::exc_vxc_type_rks
ReplicatedXCIntegrator<MatrixType>::eval_exc_vxc_( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ class ReplicatedXCIntegratorImpl {

virtual void integrate_den_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* N_EL ) = 0;

virtual void eval_exc_( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;
virtual void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;
virtual void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;

virtual void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* VXC, int64_t ldvxc,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) = 0;
Expand Down Expand Up @@ -81,6 +93,17 @@ class ReplicatedXCIntegratorImpl {
void integrate_den( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* N_EL );

void eval_exc( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& ks_settings );
void eval_exc( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& ks_settings );
void eval_exc( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& ks_settings );

void eval_exc_vxc( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* VXC, int64_t ldvxc,
value_type* EXC, const IntegratorSettingsXC& ks_settings );
Expand Down
3 changes: 3 additions & 0 deletions include/gauxc/xc_integrator/replicated_xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl<MatrixType> {
std::unique_ptr< pimpl_type > pimpl_;

value_type integrate_den_( const MatrixType& ) override;
value_type eval_exc_ ( const MatrixType&, const IntegratorSettingsXC& ) override;
value_type eval_exc_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
value_type eval_exc_ ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
exc_vxc_type_rks eval_exc_vxc_ ( const MatrixType&, const IntegratorSettingsXC& ) override;
exc_vxc_type_uks eval_exc_vxc_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC&) override;
exc_vxc_type_gks eval_exc_vxc_ ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
Expand Down
42 changes: 35 additions & 7 deletions include/gauxc/xc_integrator/xc_integrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class XCIntegratorImpl {
protected:

virtual value_type integrate_den_( const MatrixType& P ) = 0;

virtual value_type eval_exc_ ( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) = 0;
virtual value_type eval_exc_ ( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0;
virtual value_type eval_exc_ ( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) = 0;

virtual exc_vxc_type_rks eval_exc_vxc_ ( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exc_vxc_type_uks eval_exc_vxc_ ( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exc_vxc_type_gks eval_exc_vxc_ ( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px,
Expand All @@ -54,13 +59,38 @@ class XCIntegratorImpl {
* @param[in] P The density matrix
* @returns Approx Tr[P*S]
*/
value_type integrate_den( const MatrixType& P ) {
return integrate_den_(P);
}
value_type integrate_den( const MatrixType& P ) {
return integrate_den_(P);
}

/** Integrate EXC for RKS
*
* @param[in] P The alpha density matrix
* @returns Integrated EXC
*/
value_type eval_exc( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) {
return eval_exc_(P, ks_settings);
}

/** Integrate EXC for UKS
*
* @param[in] P The alpha density matrix
* @returns Integrated EXC
*/
value_type eval_exc( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {
return eval_exc_(Ps, Pz, ks_settings);
}

/** Integrate EXC for GKS
*
* @param[in] P The alpha density matrix
* @returns Integrated EXC
*/
value_type eval_exc( const MatrixType& Ps, const MatrixType& Pz, const MatrixType& Py, const MatrixType& Px, const IntegratorSettingsXC& ks_settings ) {
return eval_exc_(Ps, Pz, Py, Px, ks_settings);
}

/** Integrate EXC / VXC (Mean field terms) for RKS
*
* TODO: add API for UKS/GKS
*
* @param[in] P The alpha density matrix
* @returns EXC / VXC in a combined structure
Expand Down Expand Up @@ -89,8 +119,6 @@ class XCIntegratorImpl {
}

/** Integrate Exact Exchange for RHF
*
* TODO: add API for UHF/GHF
*
* @param[in] P The alpha density matrix
* @returns Excact Exchange Matrix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* See LICENSE.txt for details
*/
#include "incore_replicated_xc_device_integrator_integrate_den.hpp"
#include "incore_replicated_xc_device_integrator_exc.hpp"
#include "incore_replicated_xc_device_integrator_exc_vxc.hpp"
#include "incore_replicated_xc_device_integrator_exc_grad.hpp"
#include "incore_replicated_xc_device_integrator_exx.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ class IncoreReplicatedXCDeviceIntegrator :
void integrate_den_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* N_EL ) override;

void eval_exc_( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;
void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;
void eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;

void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* VXC, int64_t ldvxc,
value_type* EXC, const IntegratorSettingsXC& settings) override;
Expand Down Expand Up @@ -74,7 +85,7 @@ class IncoreReplicatedXCDeviceIntegrator :

void exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data );
XCDeviceData& device_data, bool do_vxc = true );

void exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp,
value_type* VXC, int64_t ldvxc, value_type* EXC, value_type *N_EL,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/**
* GauXC Copyright (c) 2020-2024, The Regents of the University of California,
* through Lawrence Berkeley National Laboratory (subject to receipt of
* any required approvals from the U.S. Dept. of Energy). All rights reserved.
*
* See LICENSE.txt for details
*/
#include "incore_replicated_xc_device_integrator.hpp"
#include "device/local_device_work_driver.hpp"
#include "device/xc_device_aos_data.hpp"
#include <fstream>
#include <gauxc/exceptions.hpp>
#include <gauxc/util/unused.hpp>

namespace GauXC {
namespace detail {

template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& settings ) {


if(Pz) GAUXC_GENERIC_EXCEPTION("UKS/GKS + EXC Only Device NYI");
const auto& basis = this->load_balancer_->basis();

// Check that P / VXC are sane
const int64_t nbf = basis.nbf();
if( m != n )
GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square");
if( m != nbf )
GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis");
if( ldps < nbf )
GAUXC_GENERIC_EXCEPTION("Invalid LDP");


// Get Tasks
auto& tasks = this->load_balancer_->get_tasks();

// Allocate Device memory
auto* lwd = dynamic_cast<LocalDeviceWorkDriver*>(this->local_work_driver_.get() );
auto rt = detail::as_device_runtime(this->load_balancer_->runtime());
auto device_data_ptr = lwd->create_device_data(rt);

GAUXC_MPI_CODE( MPI_Barrier(rt.comm());)

// Temporary electron count to judge integrator accuracy
value_type N_EL;

// Compute local contributions to EXC/VXC and retrieve
// data from device
this->timer_.time_op("XCIntegrator.LocalWork_EXC", [&](){
exc_vxc_local_work_( basis, Ps, ldps, nullptr, 0, EXC,
&N_EL, tasks.begin(), tasks.end(), *device_data_ptr);
});

GAUXC_MPI_CODE(
this->timer_.time_op("XCIntegrator.ImbalanceWait_EXC",[&](){
MPI_Barrier(this->load_balancer_->runtime().comm());
});
)

// Reduce Results in host mem
this->timer_.time_op("XCIntegrator.Allreduce_EXC", [&](){
this->reduction_driver_->allreduce_inplace( EXC, 1 , ReductionOp::Sum );
this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum );
});

}



template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
eval_exc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
value_type* EXC, const IntegratorSettingsXC& settings ) {

eval_exc_(m, n, Ps, ldps, Pz, ldpz, nullptr, 0, nullptr, 0, EXC, settings);

}

template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
eval_exc_( int64_t m, int64_t n, const value_type* P, int64_t ldp,
value_type* EXC, const IntegratorSettingsXC& settings ) {

eval_exc_(m, n, P, ldp, nullptr, 0, nullptr, 0, nullptr, 0, EXC, settings);

}


}
}

Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ template <typename ValueType>
void IncoreReplicatedXCDeviceIntegrator<ValueType>::
exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data ) {
XCDeviceData& device_data, bool do_vxc ) {


auto* lwd = dynamic_cast<LocalDeviceWorkDriver*>(this->local_work_driver_.get() );
Expand Down Expand Up @@ -195,7 +195,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
const auto nbf = basis.nbf();
const auto nshells = basis.nshells();
device_data.reset_allocations();
device_data.allocate_static_data_exc_vxc( nbf, nshells );
device_data.allocate_static_data_exc_vxc( nbf, nshells, do_vxc );
device_data.send_static_data_density_basis( P, ldp, basis );

// Zero integrands
Expand Down Expand Up @@ -257,6 +257,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
// Do scalar EXC/N_EL integrations
lwd->inc_exc( &device_data );
lwd->inc_nel( &device_data );
if( not do_vxc ) continue;

// Evaluate Z (+ M) matrix
if( func.is_mgga() ) {
Expand All @@ -272,7 +273,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
} // Loop over batches of batches

// Symmetrize VXC in device memory
lwd->symmetrize_vxc( &device_data );
if(do_vxc) lwd->symmetrize_vxc( &device_data );

}

Expand All @@ -287,7 +288,8 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
XCDeviceData& device_data ) {

// Get integrate and keep data on device
exc_vxc_local_work_( basis, P, ldp, task_begin, task_end, device_data );
const bool do_vxc = VXC;
exc_vxc_local_work_( basis, P, ldp, task_begin, task_end, device_data, do_vxc );
auto rt = detail::as_device_runtime(this->load_balancer_->runtime());
rt.device_backend()->master_queue_synchronize();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/
#include "shell_batched_replicated_xc_device_integrator.hpp"
#include "shell_batched_replicated_xc_integrator_integrate_den.hpp"
#include "shell_batched_replicated_xc_integrator_exc.hpp"
#include "shell_batched_replicated_xc_integrator_exc_vxc.hpp"
#include "shell_batched_replicated_xc_integrator_exc_grad.hpp"
#include "shell_batched_replicated_xc_integrator_exx.hpp"
Expand Down
Loading

0 comments on commit be9ac8d

Please sign in to comment.