Skip to content

Commit

Permalink
Fixed CUTLASS compilation. Error out if CUTLASS + U/GKS is attempted.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikovtun committed Jul 10, 2024
1 parent 74515ec commit 9f3242c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

namespace GauXC {

void AoSScheme1CUTLASSBase::eval_xmat(double fac, XCDeviceData* _data, bool do_grad ){
void AoSScheme1CUTLASSBase::eval_xmat(double fac, XCDeviceData* _data, bool do_grad, density_id den_id ){

if( do_grad ) GAUXC_GENERIC_EXCEPTION("CUTLASS + X Gradient NYI");
if( den_id != DEN_S ) GAUXC_GENERIC_EXCEPTION("CUTLASS + U/GKS NYI");

auto* data = dynamic_cast<Data*>(_data);
if( !data ) GAUXC_BAD_LWD_DATA_CAST();
Expand All @@ -33,7 +34,7 @@ void AoSScheme1CUTLASSBase::eval_xmat(double fac, XCDeviceData* _data, bool do_g
const auto submat_block_size = data->get_submat_chunk_size( nbf, 0 );
auto static_stack = data->static_stack;
auto aos_stack = data->aos_stack;
sym_pack_submat( ntasks, aos_stack.device_tasks, static_stack.dmat_device,
sym_pack_submat( ntasks, aos_stack.device_tasks, static_stack.dmat_s_device,
nbf, submat_block_size, data->device_backend_->queue() );

auto cutlass_stack = data->cutlass_stack;
Expand All @@ -50,14 +51,15 @@ void AoSScheme1CUTLASSBase::eval_xmat(double fac, XCDeviceData* _data, bool do_g
);
}

void AoSScheme1CUTLASSBase::inc_vxc( XCDeviceData* _data, bool do_m){
void AoSScheme1CUTLASSBase::inc_vxc( XCDeviceData* _data, density_id den_id, bool do_m){

auto* data = dynamic_cast<Data*>(_data);
if( !data ) GAUXC_BAD_LWD_DATA_CAST();

if( not data->device_backend_ ) GAUXC_UNINITIALIZED_DEVICE_BACKEND();

if(do_m) GAUXC_GENERIC_EXCEPTION("CUTLASS + MGGA NYI");
if( den_id != DEN_S ) GAUXC_GENERIC_EXCEPTION("CUTLASS + U/GKS NYI");

auto& tasks = data->host_device_tasks;
const auto ntasks = tasks.size();
Expand All @@ -81,7 +83,7 @@ void AoSScheme1CUTLASSBase::inc_vxc( XCDeviceData* _data, bool do_m){
auto static_stack = data->static_stack;
auto aos_stack = data->aos_stack;
sym_task_inc_potential( ntasks, aos_stack.device_tasks,
static_stack.vxc_device, nbf, submat_block_size,
static_stack.vxc_s_device, nbf, submat_block_size,
data->device_backend_->queue() );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ namespace GauXC {

struct AoSScheme1CUTLASSBase : public AoSScheme1Base {

void eval_xmat(double fac, XCDeviceData*, bool do_grad ) override final;
void inc_vxc( XCDeviceData*, bool ) override final;
void eval_xmat(double fac, XCDeviceData*, bool do_grad, density_id ) override final;
void inc_vxc( XCDeviceData*, density_id, bool ) override final;

struct Data;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void AoSScheme1CUTLASSBase::Data::pack_and_send(
std::vector<int64_t> ld64_dmat_host( ntask ), ld64_zmat_host( ntask ),
ld64_vmat_host( ntask ), ld64_bf_host( ntask );

double* static_dmat = static_stack.dmat_device;
double* static_dmat = static_stack.dmat_s_device;
const auto nbf = global_dims.nbf;

// host_device_tasks should be populated by parent impl called at top
Expand Down

0 comments on commit 9f3242c

Please sign in to comment.