Skip to content

Commit

Permalink
Merge branch 'develop' into remove_ctx
Browse files Browse the repository at this point in the history
  • Loading branch information
Critsium-xy authored Jan 17, 2025
2 parents 7565f9a + ce7acd0 commit ba5d2a5
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 9 deletions.
14 changes: 10 additions & 4 deletions source/module_base/module_container/ATen/core/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ Tensor::Tensor(Tensor&& other) noexcept
// However, Our subclass TensorMap, etc., do not own resources.
// So, we do not need to declare a virtual destructor here.
Tensor::~Tensor() {
if (buffer_) buffer_->unref();
if (buffer_) { buffer_->unref();
}
}

// Get the data type of the tensor.
Expand Down Expand Up @@ -223,7 +224,8 @@ Tensor& Tensor::operator=(const Tensor& other) {
this->device_ = other.device_;
this->data_type_ = other.data_type_;
this->shape_ = other.shape_;
if (buffer_) buffer_->unref();
if (buffer_) { buffer_->unref();
}

this->buffer_ = new TensorBuffer(GetAllocator(device_), shape_.NumElements() * SizeOfType(data_type_));

Expand All @@ -241,7 +243,8 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
this->data_type_ = other.data_type_;
this->shape_ = other.shape_;

if (buffer_) buffer_->unref(); // Release current resource
if (buffer_) { buffer_->unref(); // Release current resource
}
this->buffer_ = other.buffer_;
other.buffer_ = nullptr; // Reset the other TensorBuffer.
return *this;
Expand Down Expand Up @@ -284,7 +287,8 @@ bool Tensor::AllocateFrom(const Tensor& other, const TensorShape& shape) {
data_type_ = other.data_type_;
device_ = other.device_;
shape_ = shape;
if (buffer_) buffer_->unref();
if (buffer_) { buffer_->unref();
}
buffer_ = new TensorBuffer(GetAllocator(device_), shape_.NumElements() * SizeOfType(data_type_));
return true;
}
Expand Down Expand Up @@ -324,6 +328,7 @@ Tensor Tensor::operator[](const int& index) const {
// Overloaded operator<< for the Tensor class.
std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
std::ios::fmtflags flag(os.flags());
std::streamsize precision = os.precision(); // save the current precision
const int64_t num_elements = tensor.NumElements();
const DataType data_type = tensor.data_type();
const DeviceType device_type = tensor.device_type();
Expand Down Expand Up @@ -398,6 +403,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
#endif
// restore the os settings
os.flags(flag);
os.precision(precision); // restore the precision
return os;
}

Expand Down
4 changes: 2 additions & 2 deletions source/module_cell/read_stru.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
namespace unitcell
{
bool check_tau(const Atom* atoms,
const int ntype,
const int lat0)
const int& ntype,
const double& lat0)
{
ModuleBase::TITLE("UnitCell","check_tau");
ModuleBase::timer::tick("UnitCell","check_tau");
Expand Down
4 changes: 2 additions & 2 deletions source/module_cell/read_stru.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
namespace unitcell
{
bool check_tau(const Atom* atoms,
const int ntype,
const int lat0);
const int& ntype,
const double& lat0);

bool read_atom_species(std::ifstream& ifa,
std::ofstream& ofs_running,
Expand Down
29 changes: 28 additions & 1 deletion source/module_cell/test/unitcell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ TEST_F(UcellTest, CheckDTau)
}
}

TEST_F(UcellTest, CheckTau)
TEST_F(UcellTest, CheckTauFalse)
{
UcellTestPrepare utp = UcellTestLib["C1H2-CheckTau"];
PARAM.input.relax_new = utp.relax_new;
Expand All @@ -769,6 +769,33 @@ TEST_F(UcellTest, CheckTau)
remove("checktau_warning");
}

TEST_F(UcellTest, CheckTauTrue)
{
UcellTestPrepare utp = UcellTestLib["C1H2-CheckTau"];
PARAM.input.relax_new = utp.relax_new;
ucell = utp.SetUcellInfo();
GlobalV::ofs_warning.open("checktau_warning");
int atom=0;
//cause the ucell->lat0 is 0.5,if the type of the check_tau has
//an int type,it will set to zero,and it will not pass the unittest
ucell->lat0=0.5;
ucell->nat=3;
for (int it=0;it<ucell->ntype;it++)
{
for(int ia=0; ia<ucell->atoms[it].na; ++ia)
{

for (int i=0;i<3;i++)
{
ucell->atoms[it].tau[ia][i]=((atom+i)/(ucell->nat*3.0));
}
atom+=3;
}
}
EXPECT_EQ(unitcell::check_tau(ucell->atoms ,ucell->ntype, ucell->lat0),true);
GlobalV::ofs_warning.close();
}

TEST_F(UcellTest, SelectiveDynamics)
{
UcellTestPrepare utp = UcellTestLib["C1H2-SD"];
Expand Down

0 comments on commit ba5d2a5

Please sign in to comment.