Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): support spin virial #4545

Draft
wants to merge 1 commit into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions deepmd/pt/loss/ener_spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,22 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
rmse_ae.detach(), find_atom_ener
)

if self.has_v and "virial" in model_pred and "virial" in label:
find_virial = label.get("find_virial", 0.0)
pref_v = pref_v * find_virial
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)

if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return model_pred, loss, more_loss
Expand Down
17 changes: 17 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def forward_common(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
coord_corr_for_virial: Optional[torch.Tensor] = None,
) -> dict[str, torch.Tensor]:
"""Return model prediction.

Expand All @@ -153,6 +154,9 @@ def forward_common(
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
coord_corr_for_virial
The coordinates correction of the atoms for virial.
shape: nf x (nloc x 3)

Returns
-------
Expand Down Expand Up @@ -180,6 +184,14 @@ def forward_common(
mixed_types=True,
box=bb,
)
if coord_corr_for_virial is not None:
coord_corr_for_virial = coord_corr_for_virial.to(cc.dtype)
extended_coord_corr = torch.gather(
coord_corr_for_virial, 1, mapping.unsqueeze(-1).expand(-1, -1, 3)
)
else:
extended_coord_corr = None

model_predict_lower = self.forward_common_lower(
extended_coord,
extended_atype,
Expand All @@ -188,6 +200,7 @@ def forward_common(
do_atomic_virial=do_atomic_virial,
fparam=fp,
aparam=ap,
extended_coord_corr=extended_coord_corr,
)
model_predict = communicate_extended_output(
model_predict_lower,
Expand Down Expand Up @@ -242,6 +255,7 @@ def forward_common_lower(
do_atomic_virial: bool = False,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
extra_nlist_sort: bool = False,
extended_coord_corr: Optional[torch.Tensor] = None,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand All @@ -268,6 +282,8 @@ def forward_common_lower(
The data needed for communication for parallel inference.
extra_nlist_sort
whether to forcibly sort the nlist.
extended_coord_corr
coordinates correction for virial in extended region. nf x (nall x 3)

Returns
-------
Expand Down Expand Up @@ -299,6 +315,7 @@ def forward_common_lower(
cc_ext,
do_atomic_virial=do_atomic_virial,
create_graph=self.training,
extended_coord_corr=extended_coord_corr,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict
Expand Down
45 changes: 36 additions & 9 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ def process_spin_input(self, coord, atype, spin):
coord = coord.reshape(nframes, nloc, 3)
spin = spin.reshape(nframes, nloc, 3)
atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1)
virtual_coord = coord + spin * (self.virtual_scale_mask.to(atype.device))[
atype
].reshape([nframes, nloc, 1])
spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape(
[nframes, nloc, 1]
)
virtual_coord = coord + spin_dist
coord_spin = torch.concat([coord, virtual_coord], dim=-2)
return coord_spin, atype_spin
# for spin virial corr
coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2)
return coord_spin, atype_spin, coord_corr

def process_spin_input_lower(
self,
Expand All @@ -78,13 +81,18 @@ def process_spin_input_lower(
"""
nframes, nall = extended_coord.shape[:2]
nloc = nlist.shape[1]
virtual_extended_coord = extended_coord + extended_spin * (
extended_spin_dist = extended_spin * (
self.virtual_scale_mask.to(extended_atype.device)
)[extended_atype].reshape([nframes, nall, 1])
virtual_extended_coord = extended_coord + extended_spin_dist
virtual_extended_atype = extended_atype + self.ntypes_real
extended_coord_updated = concat_switch_virtual(
extended_coord, virtual_extended_coord, nloc
)
# for spin virial corr
extended_coord_corr = concat_switch_virtual(
torch.zeros_like(extended_coord), -extended_spin_dist, nloc
)
extended_atype_updated = concat_switch_virtual(
extended_atype, virtual_extended_atype, nloc
)
Expand All @@ -100,6 +108,7 @@ def process_spin_input_lower(
extended_atype_updated,
nlist_updated,
mapping_updated,
extended_coord_corr,
)

def process_spin_output(
Expand Down Expand Up @@ -367,7 +376,7 @@ def spin_sampled_func():
sampled = sampled_func()
spin_sampled = []
for sys in sampled:
coord_updated, atype_updated = self.process_spin_input(
coord_updated, atype_updated, _ = self.process_spin_input(
sys["coord"], sys["atype"], sys["spin"]
)
tmp_dict = {
Expand Down Expand Up @@ -398,7 +407,9 @@ def forward_common(
do_atomic_virial: bool = False,
) -> dict[str, torch.Tensor]:
nframes, nloc = atype.shape
coord_updated, atype_updated = self.process_spin_input(coord, atype, spin)
coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input(
coord, atype, spin
)
if aparam is not None:
aparam = self.expand_aparam(aparam, nloc * 2)
model_ret = self.backbone_model.forward_common(
Expand All @@ -408,6 +419,7 @@ def forward_common(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
coord_corr_for_virial=coord_corr_for_virial,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Confirm the compatibility of the new argument coord_corr_for_virial

Check if the backbone model's forward_common method is designed to accept coord_corr_for_virial. If not, update the backbone model accordingly or modify the call to prevent runtime errors.

)
model_output_type = self.backbone_model.model_output_type()
if "mask" in model_output_type:
Expand Down Expand Up @@ -454,6 +466,7 @@ def forward_common_lower(
extended_atype_updated,
nlist_updated,
mapping_updated,
extended_coord_corr_for_virial,
) = self.process_spin_input_lower(
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
)
Expand All @@ -469,6 +482,7 @@ def forward_common_lower(
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=extra_nlist_sort,
extended_coord_corr=extended_coord_corr_for_virial,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure extended_coord_corr is accepted by forward_common_lower

Similar to the previous comment, verify that self.backbone_model.forward_common_lower accepts extended_coord_corr as an argument. This prevents potential issues during model execution.

)
model_output_type = self.backbone_model.model_output_type()
if "mask" in model_output_type:
Expand Down Expand Up @@ -541,6 +555,11 @@ def translated_output_def(self):
output_def["force"].squeeze(-2)
output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"])
output_def["force_mag"].squeeze(-2)
if self.do_grad_c("energy"):
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
output_def["atom_virial"].squeeze(-3)
Comment on lines +558 to +562
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Assign the result of squeeze operations to reduce tensor dimensions

The squeeze methods in lines 560 and 562 do not modify tensors in place. Assign the results to ensure the dimensions are correctly reduced.

Apply this diff to fix the issue:

-        output_def["virial"].squeeze(-2)
+        output_def["virial"] = output_def["virial"].squeeze(-2)
-        output_def["atom_virial"].squeeze(-3)
+        output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self.do_grad_c("energy"):
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
output_def["atom_virial"].squeeze(-3)
if self.do_grad_c("energy"):
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
output_def["virial"] = output_def["virial"].squeeze(-2)
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)

return output_def

def forward(
Expand Down Expand Up @@ -569,7 +588,10 @@ def forward(
if self.backbone_model.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2)
# not support virial by far
if self.backbone_model.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
return model_predict

@torch.jit.export
Expand Down Expand Up @@ -606,5 +628,10 @@ def forward_lower(
model_predict["extended_force_mag"] = model_ret[
"energy_derv_r_mag"
].squeeze(-2)
# not support virial by far
if self.backbone_model.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
return model_predict
7 changes: 7 additions & 0 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def fit_output_to_model_output(
coord_ext: torch.Tensor,
do_atomic_virial: bool = False,
create_graph: bool = True,
extended_coord_corr: Optional[torch.Tensor] = None,
) -> dict[str, torch.Tensor]:
"""Transform the output of the fitting network to
the model output.
Expand Down Expand Up @@ -187,6 +188,12 @@ def fit_output_to_model_output(
model_ret[kk_derv_r] = dr
if vdef.c_differentiable:
assert dc is not None
if extended_coord_corr is not None:
dc_corr = (
dr.squeeze(-2).unsqueeze(-1)
@ extended_coord_corr.unsqueeze(-2)
).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005
dc = dc + dc_corr
model_ret[kk_derv_c] = dc
model_ret[kk_derv_c + "_redu"] = torch.sum(
model_ret[kk_derv_c].to(redu_prec), dim=1
Expand Down
12 changes: 6 additions & 6 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2602,9 +2602,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
for (int j = 0; j < natoms * 3; j++) {
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
}
// for (int j = 0; j < 9; j++) {
// virial[i][j] = virial_flat[i * 9 + j];
// }
for (int j = 0; j < 9; j++) {
virial[i][j] = virial_flat[i * 9 + j];
}
}
};
/**
Expand Down Expand Up @@ -2705,9 +2705,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
for (int j = 0; j < natoms * 3; j++) {
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
}
// for (int j = 0; j < 9; j++) {
// virial[i][j] = virial_flat[i * 9 + j];
// }
for (int j = 0; j < 9; j++) {
virial[i][j] = virial_flat[i * 9 + j];
}
for (int j = 0; j < natoms; j++) {
atom_energy[i][j] = atom_energy_flat[i * natoms + j];
}
Expand Down
10 changes: 5 additions & 5 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,11 @@ void DP_DeepSpinModelDeviCompute_variant(DP_DeepSpinModelDevi* dp,
flatten_vector(fm_flat, fm);
std::copy(fm_flat.begin(), fm_flat.end(), force_mag);
}
// if (virial) {
// std::vector<VALUETYPE> v_flat;
// flatten_vector(v_flat, v);
// std::copy(v_flat.begin(), v_flat.end(), virial);
// }
if (virial) {
std::vector<VALUETYPE> v_flat;
flatten_vector(v_flat, v);
std::copy(v_flat.begin(), v_flat.end(), virial);
}
if (atomic_energy) {
std::vector<VALUETYPE> ae_flat;
flatten_vector(ae_flat, ae);
Expand Down
25 changes: 11 additions & 14 deletions source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
c10::IValue energy_ = outputs.at("energy");
c10::IValue force_ = outputs.at("extended_force");
c10::IValue force_mag_ = outputs.at("extended_force_mag");
// spin model not suported yet
// c10::IValue virial_ = outputs.at("virial");
c10::IValue virial_ = outputs.at("virial");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
Expand All @@ -267,11 +266,11 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
dforce_mag.assign(
cpu_force_mag_.data_ptr<VALUETYPE>(),
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

// bkw map
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
Expand Down Expand Up @@ -415,8 +414,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
c10::IValue energy_ = outputs.at("energy");
c10::IValue force_ = outputs.at("force");
c10::IValue force_mag_ = outputs.at("force_mag");
// spin model not suported yet
// c10::IValue virial_ = outputs.at("virial");
c10::IValue virial_ = outputs.at("virial");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
Expand All @@ -431,11 +429,10 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
force_mag.assign(
cpu_force_mag_.data_ptr<VALUETYPE>(),
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
// spin model not suported yet
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
if (atomic) {
// c10::IValue atom_virial_ = outputs.at("atom_virial");
c10::IValue atom_energy_ = outputs.at("atom_energy");
Expand Down
17 changes: 16 additions & 1 deletion source/tests/pt/model/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,17 @@ def test(
cell = (cell) + 5.0 * torch.eye(3, device="cpu")
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
coord = torch.matmul(coord, cell)
spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
atype = torch.IntTensor([0, 0, 0, 1, 1])
# assumes input to be numpy tensor
coord = coord.numpy()
spin = spin.numpy()
cell = cell.numpy()
test_keys = ["energy", "force", "virial"]
test_spin = getattr(self, "test_spin", False)
if not test_spin:
test_keys = ["energy", "force", "virial"]
else:
test_keys = ["energy", "force", "force_mag", "virial"]

def np_infer(
new_cell,
Expand All @@ -157,6 +163,7 @@ def np_infer(
).unsqueeze(0),
torch.tensor(new_cell, device="cpu").unsqueeze(0),
atype,
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
)
# detach
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
Comment on lines +166 to 169
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure compatibility of tensor devices

When creating tensors within the np_infer function, ensure that all tensors are on the same device to prevent device mismatch errors, especially when env.DEVICE differs from "cpu".

Apply this diff to correct the device assignment:

-                    spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
+                    spins=torch.tensor(spin, device=new_cell.device).unsqueeze(0),

Committable suggestion skipped: line range outside the PR's diff.

Expand Down Expand Up @@ -251,3 +258,11 @@ def setUp(self) -> None:
self.type_split = False
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
def setUp(self) -> None:
model_params = copy.deepcopy(model_spin)
self.type_split = False
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)
3 changes: 2 additions & 1 deletion source/tests/pt/model/test_ener_spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_input_output_process(self) -> None:
nframes, nloc = self.coord.shape[:2]
self.real_ntypes = self.model.spin.get_ntypes_real()
# 1. test forward input process
coord_updated, atype_updated = self.model.process_spin_input(
coord_updated, atype_updated, _ = self.model.process_spin_input(
self.coord, self.atype, self.spin
)
# compare atypes of real and virtual atoms
Expand Down Expand Up @@ -174,6 +174,7 @@ def test_input_output_process(self) -> None:
extended_atype_updated,
nlist_updated,
mapping_updated,
_,
) = self.model.process_spin_input_lower(
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
)
Expand Down
Loading
Loading