diff --git a/source/module_base/module_container/ATen/core/tensor.cpp b/source/module_base/module_container/ATen/core/tensor.cpp index 8cae92fcac..92babb361c 100644 --- a/source/module_base/module_container/ATen/core/tensor.cpp +++ b/source/module_base/module_container/ATen/core/tensor.cpp @@ -53,7 +53,8 @@ Tensor::Tensor(Tensor&& other) noexcept // However, Our subclass TensorMap, etc., do not own resources. // So, we do not need to declare a virtual destructor here. Tensor::~Tensor() { - if (buffer_) buffer_->unref(); + if (buffer_) { buffer_->unref(); +} } // Get the data type of the tensor. @@ -223,7 +224,8 @@ Tensor& Tensor::operator=(const Tensor& other) { this->device_ = other.device_; this->data_type_ = other.data_type_; this->shape_ = other.shape_; - if (buffer_) buffer_->unref(); + if (buffer_) { buffer_->unref(); +} this->buffer_ = new TensorBuffer(GetAllocator(device_), shape_.NumElements() * SizeOfType(data_type_)); @@ -241,7 +243,8 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept { this->data_type_ = other.data_type_; this->shape_ = other.shape_; - if (buffer_) buffer_->unref(); // Release current resource + if (buffer_) { buffer_->unref(); // Release current resource +} this->buffer_ = other.buffer_; other.buffer_ = nullptr; // Reset the other TensorBuffer. return *this; @@ -284,7 +287,8 @@ bool Tensor::AllocateFrom(const Tensor& other, const TensorShape& shape) { data_type_ = other.data_type_; device_ = other.device_; shape_ = shape; - if (buffer_) buffer_->unref(); + if (buffer_) { buffer_->unref(); +} buffer_ = new TensorBuffer(GetAllocator(device_), shape_.NumElements() * SizeOfType(data_type_)); return true; } @@ -324,6 +328,7 @@ Tensor Tensor::operator[](const int& index) const { // Overloaded operator<< for the Tensor class. std::ostream& operator<<(std::ostream& os, const Tensor& tensor) { std::ios::fmtflags flag(os.flags()); + std::streamsize precision = os.precision(); // save the current precision const int64_t num_elements = tensor.NumElements(); const DataType data_type = tensor.data_type(); const DeviceType device_type = tensor.device_type(); @@ -398,6 +403,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& tensor) { #endif // restore the os settings os.flags(flag); + os.precision(precision); // restore the precision return os; } diff --git a/source/module_cell/read_stru.cpp b/source/module_cell/read_stru.cpp index 51af080611..a312cba744 100644 --- a/source/module_cell/read_stru.cpp +++ b/source/module_cell/read_stru.cpp @@ -5,8 +5,8 @@ namespace unitcell { bool check_tau(const Atom* atoms, - const int ntype, - const int lat0) + const int& ntype, + const double& lat0) { ModuleBase::TITLE("UnitCell","check_tau"); ModuleBase::timer::tick("UnitCell","check_tau"); diff --git a/source/module_cell/read_stru.h b/source/module_cell/read_stru.h index 3827666966..cff2a0c331 100644 --- a/source/module_cell/read_stru.h +++ b/source/module_cell/read_stru.h @@ -6,8 +6,8 @@ namespace unitcell { bool check_tau(const Atom* atoms, - const int ntype, - const int lat0); + const int& ntype, + const double& lat0); bool read_atom_species(std::ifstream& ifa, std::ofstream& ofs_running, diff --git a/source/module_cell/test/unitcell_test.cpp b/source/module_cell/test/unitcell_test.cpp index 2a783630dc..3e93a7b44f 100644 --- a/source/module_cell/test/unitcell_test.cpp +++ b/source/module_cell/test/unitcell_test.cpp @@ -753,7 +753,7 @@ TEST_F(UcellTest, CheckDTau) } } -TEST_F(UcellTest, CheckTau) +TEST_F(UcellTest, CheckTauFalse) { UcellTestPrepare utp = UcellTestLib["C1H2-CheckTau"]; PARAM.input.relax_new = utp.relax_new; @@ -769,6 +769,33 @@ TEST_F(UcellTest, CheckTau) remove("checktau_warning"); } +TEST_F(UcellTest, CheckTauTrue) +{ + UcellTestPrepare utp = UcellTestLib["C1H2-CheckTau"]; + PARAM.input.relax_new = utp.relax_new; + ucell = utp.SetUcellInfo(); + GlobalV::ofs_warning.open("checktau_warning"); + int atom=0; + //cause the ucell->lat0 is 0.5,if the type of the check_tau has + //an int type,it will set to zero,and it will not pass the unittest + ucell->lat0=0.5; + ucell->nat=3; + for (int it=0;itntype;it++) + { + for(int ia=0; iaatoms[it].na; ++ia) + { + + for (int i=0;i<3;i++) + { + ucell->atoms[it].tau[ia][i]=((atom+i)/(ucell->nat*3.0)); + } + atom+=3; + } + } + EXPECT_EQ(unitcell::check_tau(ucell->atoms ,ucell->ntype, ucell->lat0),true); + GlobalV::ofs_warning.close(); +} + TEST_F(UcellTest, SelectiveDynamics) { UcellTestPrepare utp = UcellTestLib["C1H2-SD"];