diff --git a/CMakeLists.txt b/CMakeLists.txt
index dcfea2f1..f7f4bcd7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -100,6 +100,37 @@ add_library(cytnx STATIC)
 set_property(TARGET cytnx PROPERTY C_VISIBILITY_PRESET hidden)
 set_property(TARGET cytnx PROPERTY VISIBILITY_INLINES_HIDDEN ON)
 
+if(BACKEND_TORCH)
+  message(STATUS "backend = pytorch")
+  target_compile_definitions(cytnx PUBLIC BACKEND_TORCH)
+
+  # let torch python expose where pytorch.cmake is installed
+  execute_process(
+    COMMAND bash -c "python -c 'import torch;print(torch.utils.cmake_prefix_path)'"
+    OUTPUT_VARIABLE TORCH_CMAKE_PATH_C
+  )
+  string(REGEX REPLACE "\n$" "" TORCH_CMAKE_PATH_C "${TORCH_CMAKE_PATH_C}")
+  set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} ${TORCH_CMAKE_PATH_C})
+  message(STATUS ${CMAKE_PREFIX_PATH})
+
+  find_package(Torch REQUIRED)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+  message(STATUS "pytorch: ${TORCH_INSTALL_PREFIX}")
+  message(STATUS "pytorch libs: ${TORCH_LIBRARIES}")
+
+  if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+    target_link_libraries(cytnx PUBLIC ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.dylib)
+  else()
+    target_link_libraries(cytnx PUBLIC ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so)
+    target_link_libraries(cytnx PRIVATE ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so)
+    target_link_libraries(cytnx INTERFACE ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so)
+  endif()
+
+else()
+  message(STATUS "backend = cytnx")
+  include(CytnxBKNDCMakeLists.cmake)
+endif() # Backend torch
+
 target_include_directories(cytnx
   PRIVATE
     ${CMAKE_CURRENT_SOURCE_DIR}/src
@@ -160,38 +191,6 @@ target_include_directories(cytnx SYSTEM
 target_link_libraries(cytnx PUBLIC Boost::boost ${LAPACK_LIBRARIES})
 FILE(APPEND "${CMAKE_BINARY_DIR}/cxxflags.tmp" "-I${Boost_INCLUDE_DIRS}\n" "")
 
-# ###
-if(BACKEND_TORCH)
-  message(STATUS "backend = pytorch")
-  target_compile_definitions(cytnx PUBLIC BACKEND_TORCH)
-
-  # let torch python expose where pytorch.cmake is installed
-  execute_process(
-    COMMAND bash -c "python -c 'import torch;print(torch.utils.cmake_prefix_path)'"
-    OUTPUT_VARIABLE TORCH_CMAKE_PATH_C
-  )
-  string(REGEX REPLACE "\n$" "" TORCH_CMAKE_PATH_C "${TORCH_CMAKE_PATH_C}")
-  set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} ${TORCH_CMAKE_PATH_C})
-  message(STATUS ${CMAKE_PREFIX_PATH})
-
-  find_package(Torch REQUIRED)
-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
-  message(STATUS "pytorch: ${TORCH_INSTALL_PREFIX}")
-  message(STATUS "pytorch libs: ${TORCH_LIBRARIES}")
-
-  if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
-    target_link_libraries(cytnx PUBLIC ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.dylib)
-  else()
-    target_link_libraries(cytnx PUBLIC ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so)
-    target_link_libraries(cytnx PRIVATE ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so)
-    target_link_libraries(cytnx INTERFACE ${TORCH_LIBRARIES} ${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so)
-  endif()
-
-else()
-  message(STATUS "backend = cytnx")
-  include(CytnxBKNDCMakeLists.cmake)
-endif() # Backend torch
-
 # #####################################################################
 # ## Get Gtest & benchmark
 # #####################################################################
diff --git a/Install.sh b/Install.sh
index ce73d624..4e7ab103 100644
--- a/Install.sh
+++ b/Install.sh
@@ -183,7 +183,7 @@ echo ${FLAG}
 mkdir build
 cd build
 cmake ../ ${FLAG} -DDEV_MODE=on
-# cmake ../ ${FLAG}
-make -j`nproc`
-# make install
-# ctest -j`nproc`
+make -j4
+make install
+ctest
+gcovr -r ../ .
--html-details cov.html diff --git a/include/backend_torch/Scalar.hpp b/include/backend_torch/Scalar.hpp index 63a7fe12..ae8aea7f 100644 --- a/include/backend_torch/Scalar.hpp +++ b/include/backend_torch/Scalar.hpp @@ -1,10 +1,1847 @@ #ifndef _H_Scalar_ #define _H_Scalar_ +#include "cytnx_error.hpp" +#include "Type.hpp" #ifdef BACKEND_TORCH #include - namespace cytnx { - class Scalar : public torch::Scalar {}; -} // namespace cytnx + class Scalar : public c10::Scalar { + public: + int _dtype = Type.Void; + ///@cond + struct Sproxy { + c10::intrusive_ptr _insimpl; + cytnx_uint64 _loc; + int _dtype = Type.Void; + Sproxy() {} + Sproxy(c10::intrusive_ptr _ptr, const cytnx_uint64 &idx) + : _insimpl(_ptr), _loc(idx), _dtype(Type.Void) {} + + Sproxy(const Sproxy &rhs) { + this->_insimpl = rhs._insimpl; + this->_loc = rhs._loc; + this->_dtype = rhs._dtype; + } + + // When used to set elems: + Sproxy &operator=(const Scalar &rc); + Sproxy &operator=(const cytnx_complex128 &rc); + Sproxy &operator=(const cytnx_complex64 &rc); + Sproxy &operator=(const cytnx_double &rc); + Sproxy &operator=(const cytnx_float &rc); + Sproxy &operator=(const cytnx_uint64 &rc); + Sproxy &operator=(const cytnx_int64 &rc); + Sproxy &operator=(const cytnx_uint32 &rc); + Sproxy &operator=(const cytnx_int32 &rc); + Sproxy &operator=(const cytnx_uint16 &rc); + Sproxy &operator=(const cytnx_int16 &rc); + Sproxy &operator=(const cytnx_bool &rc); + + Sproxy &operator=(const Sproxy &rc); + + Sproxy copy() const { + Sproxy out = *this; + return out; + } + + Scalar real(); + Scalar imag(); + bool exists() const; + + // When used to get elements: + // operator Scalar() const; + }; + + /// @brief default constructor + Scalar(){}; + + // init!! 
+ /// @brief init a Scalar with a cytnx::cytnx_complex128 + Scalar(const cytnx_complex128 &in) { this->Init_by_number(in); } + + /// @brief init a Scalar with a cytnx::cytnx_complex64 + Scalar(const cytnx_complex64 &in) { this->Init_by_number(in); } + + /// @brief init a Scalar with a cytnx::cytnx_double + Scalar(const cytnx_double &in) { this->Init_by_number(in); } + + /// @brief init a Scalar with a cytnx::cytnx_float + Scalar(const cytnx_float &in) { this->Init_by_number(in); } + + /// @brief init a Scalar with a cytnx::cytnx_uint64 + Scalar(const cytnx_uint64 &in) { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + + /// @brief init a Scalar with a cytnx::cytnx_int64 + Scalar(const cytnx_int64 &in) { this->Init_by_number(in); } + + /// @brief init a Scalar with a cytnx::cytnx_uint32 + Scalar(const cytnx_uint32 &in) { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + + /// @brief init a Scalar with a cytnx::cytnx_int32 + Scalar(const cytnx_int32 &in) { this->Init_by_number(in); } + + /// @brief init a Scalar with a cytnx::cytnx_uint16 + Scalar(const cytnx_uint16 &in) { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + + /// @brief init a Scalar with a cytnx::cytnx_int16 + Scalar(const cytnx_int16 &in) { this->Init_by_number(in); } + + /// @brief init a Scalar with a cytnx::cytnx_bool + Scalar(const cytnx_bool &in) { this->Init_by_number(in); } + + Scalar(c10::Scalar &rhs) : c10::Scalar(rhs) { + if (rhs.isComplex()) { + _dtype = Type.ComplexDouble; + } else if (rhs.isFloatingPoint()) { + _dtype = Type.Double; + } else if (rhs.isIntegral(/*includeBool=*/false)) { + _dtype = Type.Int64; + } else if (rhs.isBoolean()) { + _dtype = Type.Bool; + } else { + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + } + } + Scalar(c10::Scalar &&rhs) : c10::Scalar(rhs) { + if (rhs.isComplex()) { + _dtype = 
Type.ComplexDouble; + } else if (rhs.isFloatingPoint()) { + _dtype = Type.Double; + } else if (rhs.isIntegral(/*includeBool=*/false)) { + _dtype = Type.Int64; + } else if (rhs.isBoolean()) { + _dtype = Type.Bool; + } else { + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + } + } + + /** + * @brief Get the max value of the Scalar with the given \p dtype. + * @details This function is used to get the max value of the Scalar with the given \p dtype. + * That is, for example, if you want to get the max value of a Scalar with + * \p dtype = cytnx::Type.Int16, then you will get the max value of a 16-bit integer 32767. + * @param[in] dtype The data type of the Scalar. + * @return The max value of the Scalar with the given \p dtype. + */ + static Scalar maxval(const unsigned int &dtype) { + Scalar out(0, dtype); + switch (dtype) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] maxval not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] maxval not supported for complex type%s", "\n"); + break; + case Type.Double: + return Scalar(std::numeric_limits::max()); + break; + case Type.Float: + return Scalar(std::numeric_limits::max()); + break; + case Type.Int64: + return Scalar(std::numeric_limits::max()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return Scalar(std::numeric_limits::max()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return Scalar(std::numeric_limits::max()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return Scalar(true); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * 
@brief Get the min value of the Scalar with the given \p dtype. + * @details This function is used to get the min value of the Scalar with the given \p dtype. + * That is, for example, if you want to get the min value of a Scalar with + * \p dtype = cytnx::Type.Int16, then you will get the min value of a 16-bit integer -32768. + * @param[in] dtype The data type of the Scalar. + * @return The min value of the Scalar with the given \p dtype. + */ + static Scalar minval(const unsigned int &dtype) { + Scalar out(0, dtype); + switch (dtype) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] maxval not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] maxval not supported for complex type%s", "\n"); + break; + case Type.Double: + return Scalar(std::numeric_limits::min()); + break; + case Type.Float: + return Scalar(std::numeric_limits::min()); + break; + case Type.Int64: + return Scalar(std::numeric_limits::min()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return Scalar(std::numeric_limits::min()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return Scalar(std::numeric_limits::min()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return Scalar(false); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * @brief The constructor of the Scalar class. + * @details This constructor is used to init a Scalar with a given template value + * \p in and \p dtype (see cytnx::Type). + * @param[in] in The value of the Scalar. + * @param[in] dtype The data type of the Scalar. + * @return A Scalar object. 
+ * @note The \p dtype can be any of the cytnx::Type. + */ + template + Scalar(const T &in, const unsigned int &dtype) { + destroy(); + *this = Scalar(in); + switch (dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble()); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat()); + break; + case Type.Double: + *this = Scalar(toDouble()); + break; + case Type.Float: + *this = Scalar(toFloat()); + break; + case Type.Int64: + *this = Scalar(toLong()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool()); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + }; + + /// @cond + // move sproxy when use to get elements here. 
+ Scalar(const Sproxy &prox); + + // specialization of init: + ///@cond + void Init_by_number(const cytnx_complex128 &in) { + destroy(); + *this = c10::Scalar(c10::complex(in)); + this->_dtype = Type.ComplexDouble; + }; + void Init_by_number(const cytnx_complex64 &in) { + destroy(); + *this = c10::Scalar(c10::complex(in)); + this->_dtype = Type.ComplexFloat; + }; + void Init_by_number(const cytnx_double &in) { + destroy(); + *this = c10::Scalar(in); + this->_dtype = Type.Double; + } + void Init_by_number(const cytnx_float &in) { + destroy(); + *this = c10::Scalar(in); + this->_dtype = Type.Float; + } + void Init_by_number(const cytnx_int64 &in) { + destroy(); + *this = c10::Scalar(in); + this->_dtype = Type.Int64; + } + void Init_by_number(const cytnx_uint64 &in) { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + void Init_by_number(const cytnx_int32 &in) { + destroy(); + *this = c10::Scalar(in); + this->_dtype = Type.Int32; + } + void Init_by_number(const cytnx_uint32 &in) { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + void Init_by_number(const cytnx_int16 &in) { + destroy(); + *this = c10::Scalar(in); + this->_dtype = Type.Int16; + } + void Init_by_number(const cytnx_uint16 &in) { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + void Init_by_number(const cytnx_bool &in) { + destroy(); + *this = c10::Scalar(in); + this->_dtype = Type.Bool; + } + + // The copy constructor + Scalar(const Scalar &rhs) { *this = c10::Scalar(rhs); } + /// @endcond + + /// @brief The copy assignment of the Scalar class. + Scalar &operator=(const Scalar &rhs) { + *this = c10::Scalar(rhs); + return *this; + }; + + // copy assignment [Number]: + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_complex128 \p rhs. 
+ */ + Scalar &operator=(const cytnx_complex128 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_complex64 \p rhs. + */ + Scalar &operator=(const cytnx_complex64 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_double \p rhs. + */ + Scalar &operator=(const cytnx_double &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_float \p rhs. + */ + Scalar &operator=(const cytnx_float &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_uint64 \p rhs. + */ + Scalar &operator=(const cytnx_uint64 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_int64 \p rhs. + */ + Scalar &operator=(const cytnx_int64 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_uint32 \p rhs. + */ + Scalar &operator=(const cytnx_uint32 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_int32 \p rhs. + */ + Scalar &operator=(const cytnx_int32 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_uint16 \p rhs. + */ + Scalar &operator=(const cytnx_uint16 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_int16 \p rhs. 
+ */ + Scalar &operator=(const cytnx_int16 &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** @brief The copy assignment operator of the Scalar class with a given number + * cytnx::cytnx_bool \p rhs. + */ + Scalar &operator=(const cytnx_bool &rhs) { + this->Init_by_number(rhs); + return *this; + } + + /** + * @brief Type conversion function. + * @param[in] dtype The type of the output Scalar (see cytnx::Type for more details). + * @return The converted Scalar. + * @attention The function cannot convert from complex to real, please use + * cytnx::Scalar::real() or cytnx::Scalar::imag() to get the real or imaginary + * part of the Scalar instead. + */ + Scalar astype(const unsigned int &dtype) const { + Scalar out = *this; + switch (dtype) { + case Type.ComplexDouble: + out = Scalar(toComplexDouble()); + break; + case Type.ComplexFloat: + out = Scalar(toComplexFloat()); + break; + case Type.Double: + out = Scalar(toDouble()); + break; + case Type.Float: + out = Scalar(toFloat()); + break; + case Type.Int64: + out = Scalar(toLong()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + out = Scalar(toInt()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + out = Scalar(toShort()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + out = Scalar(toBool()); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * @brief Get the conjugate of the Scalar. That means return \f$ c^* \f$ if + * the Scalar is \f$ c \f$. + * @return The conjugate of the Scalar. 
+ */ + Scalar conj() const { + Scalar out = *this; + out = out.c10::Scalar::conj(); + return out; + } + + /** + * @brief Get the imaginary part of the Scalar. That means return \f$ \Im(c) \f$ if + * the Scalar is \f$ c \f$. + * @return The imaginary part of the Scalar. + */ + Scalar imag() const { + if (!isComplex()) { + cytnx_error_msg(true, "[ERROR] cannot get imaginary part of a real Scalar.%s", "\n"); + } else { + if (this->dtype() == Type.ComplexDouble) { + return Scalar(toComplexDouble().imag()); + } else if (this->dtype() == Type.ComplexFloat) { + return Scalar(toComplexFloat().imag()); + } else { + cytnx_error_msg(true, "[ERROR] This should not happen %s", "\n"); + } + } + } + + /** + * @brief Get the real part of the Scalar. That means return \f$ \Re(c) \f$ if + * the Scalar is \f$ c \f$. + * @return The real part of the Scalar. + */ + Scalar real() const { + if (!isComplex()) { + return *this; + } else { + if (this->dtype() == Type.ComplexDouble) { + return c10::Scalar(toComplexDouble().real()); + } else if (this->dtype() == Type.ComplexFloat) { + return c10::Scalar(toComplexFloat().real()); + } else { + cytnx_error_msg(true, "[ERROR] This should not happen %s", "\n"); + } + } + } + // Scalar& set_imag(const Scalar &in){ return *this;} + // Scalar& set_real(const Scalar &in){ return *this;} + + /** + * @brief Get the dtype of the Scalar (see cytnx::Type for more details). + */ + int dtype() const { return this->_dtype; } + + /// @cond + /* + * @brief On the pytorch-side only. Not documented. 
+ */ + void print_elem(std::ostream &os) const { + switch (_dtype) { + case Type.ComplexDouble: + os << "< " << toComplexDouble() << " >"; + break; + case Type.ComplexFloat: + os << "< " << toComplexFloat() << " >"; + break; + case Type.Double: + os << "< " << toDouble() << " >"; + break; + case Type.Float: + os << "< " << toFloat() << " >"; + break; + case Type.Int64: + os << "< " << toLong() << " >"; + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + os << "< " << toInt() << " >"; + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + os << "< " << toShort() << " >"; + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + os << "< " << toBool() << " >"; + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + /// @endcond + + /** + * @brief Print the Scalar to the standard output. + */ + void print() const { + print_elem(std::cout); + std::cout << std::string(" Scalar dtype: [") << Type.getname(this->_dtype) << std::string("]") + << std::endl; + } + + // casting + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_double. + explicit operator cytnx_double() const { return toDouble(); } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_float + explicit operator cytnx_float() const { return toFloat(); } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_uint64. + explicit operator cytnx_uint64() const { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_int64. 
+ explicit operator cytnx_int64() const { return toLong(); } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_uint32. + explicit operator cytnx_uint32() const { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_int32. + explicit operator cytnx_int32() const { return toInt(); } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_uint16. + explicit operator cytnx_uint16() const { + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_int16. + explicit operator cytnx_int16() const { return toShort(); } + + /// @brief The explicit casting operator of the Scalar class to cytnx::cytnx_bool. + explicit operator cytnx_bool() const { return toBool(); } + + /// @cond + // destructor + ~Scalar(){}; + /// @endcond + + // arithmetic: + ///@brief The addition assignment operator of the Scalar class with a given number (template). 
+ template + void operator+=(const T &rc) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() + rc); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() + rc); + break; + case Type.Double: + *this = Scalar(toDouble() + rc); + break; + case Type.Float: + *this = Scalar(toFloat() + rc); + break; + case Type.Int64: + *this = Scalar(toLong() + rc); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() + rc); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() + rc); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() + rc); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + ///@brief The addition assignment operator of the Scalar class with a given Scalar. 
+ void operator+=(const Scalar &rhs) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() + rhs.toComplexDouble()); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() + rhs.toComplexFloat()); + break; + case Type.Double: + *this = Scalar(toDouble() + rhs.toDouble()); + break; + case Type.Float: + *this = Scalar(toFloat() + rhs.toFloat()); + break; + case Type.Int64: + *this = Scalar(toLong() + rhs.toLong()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() + rhs.toInt()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() + rhs.toShort()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() + rhs.toBool()); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + ///@brief The subtraction assignment operator of the Scalar class with a given number + ///(template). 
+ template + void operator-=(const T &rc) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() - rc); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() - rc); + break; + case Type.Double: + *this = Scalar(toDouble() - rc); + break; + case Type.Float: + *this = Scalar(toFloat() - rc); + break; + case Type.Int64: + *this = Scalar(toLong() - rc); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() - rc); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() - rc); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() - rc); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + ///@brief The subtraction assignment operator of the Scalar class with a given Scalar. 
+ void operator-=(const Scalar &rhs) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() - rhs.toComplexDouble()); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() - rhs.toComplexFloat()); + break; + case Type.Double: + *this = Scalar(toDouble() - rhs.toDouble()); + break; + case Type.Float: + *this = Scalar(toFloat() - rhs.toFloat()); + break; + case Type.Int64: + *this = Scalar(toLong() - rhs.toLong()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() - rhs.toInt()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() - rhs.toShort()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() - rhs.toBool()); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + template + + ///@brief The multiplication assignment operator of the Scalar class with a given number + ///(template). 
+ void operator*=(const T &rc) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() * rc); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() * rc); + break; + case Type.Double: + *this = Scalar(toDouble() * rc); + break; + case Type.Float: + *this = Scalar(toFloat() * rc); + break; + case Type.Int64: + *this = Scalar(toLong() * rc); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() * rc); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() * rc); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() * rc); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + ///@brief The multiplication assignment operator of the Scalar class with a given Scalar. 
+ void operator*=(const Scalar &rhs) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() * rhs.toComplexDouble()); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() * rhs.toComplexFloat()); + break; + case Type.Double: + *this = Scalar(toDouble() * rhs.toDouble()); + break; + case Type.Float: + *this = Scalar(toFloat() * rhs.toFloat()); + break; + case Type.Int64: + *this = Scalar(toLong() * rhs.toLong()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() * rhs.toInt()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() * rhs.toShort()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() * rhs.toBool()); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + template + + /** + * @brief The division assignment operator of the Scalar class with a given number (template). 
+ */ + void operator/=(const T &rc) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() / rc); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() / rc); + break; + case Type.Double: + *this = Scalar(toDouble() / rc); + break; + case Type.Float: + *this = Scalar(toFloat() / rc); + break; + case Type.Int64: + *this = Scalar(toLong() / rc); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() / rc); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() / rc); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() / rc); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * @brief The division assignment operator of the Scalar class with a given Scalar. 
+ */ + void operator/=(const Scalar &rhs) { + switch (_dtype) { + case Type.ComplexDouble: + *this = Scalar(toComplexDouble() / rhs.toComplexDouble()); + break; + case Type.ComplexFloat: + *this = Scalar(toComplexFloat() / rhs.toComplexFloat()); + break; + case Type.Double: + *this = Scalar(toDouble() / rhs.toDouble()); + break; + case Type.Float: + *this = Scalar(toFloat() / rhs.toFloat()); + break; + case Type.Int64: + *this = Scalar(toLong() / rhs.toLong()); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = Scalar(toInt() / rhs.toInt()); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = Scalar(toShort() / rhs.toShort()); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + *this = Scalar(toBool() / rhs.toBool()); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /// @brief Set the Scalar to absolute value. (inplace) + void iabs() { *this = this->abs(); } + + /// @brief Set the Scalar to square root. (inplace) + void isqrt() { *this = this->sqrt(); } + + /** + * @brief The member function to get the absolute value of the Scalar. + * @note Compare to the iabs() function, this function will return a new Scalar object. + * @return The absolute value of the Scalar. 
+ * @see iabs() + */ + Scalar abs() const { + Scalar out = *this; + switch (_dtype) { + case Type.ComplexDouble: + out = Scalar(std::abs(toComplexDouble())); + break; + case Type.ComplexFloat: + out = Scalar(std::abs(toComplexFloat())); + break; + case Type.Double: + out = Scalar(std::abs(toDouble())); + break; + case Type.Float: + out = Scalar(std::abs(toFloat())); + break; + case Type.Int64: + out = Scalar(std::abs(toLong())); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + out = Scalar(std::abs(toInt())); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + out = Scalar(std::abs(toShort())); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + out = Scalar(std::abs(toBool())); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + return out; + } + + /** + * @brief The member function to get the square root of the Scalar. + * @note Compare to the isqrt() function, this function will return a new Scalar object. + * @return The square root of the Scalar. 
+ * @see isqrt() + */ + Scalar sqrt() const { + Scalar out = *this; + switch (_dtype) { + case Type.ComplexDouble: + out = Scalar(std::sqrt(toComplexDouble())); + break; + case Type.ComplexFloat: + out = Scalar(std::sqrt(toComplexFloat())); + break; + case Type.Double: + out = Scalar(std::sqrt(toDouble())); + break; + case Type.Float: + out = Scalar(std::sqrt(toFloat())); + break; + case Type.Int64: + out = Scalar(std::sqrt(toLong())); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + out = Scalar(std::sqrt(toInt())); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + out = Scalar(std::sqrt(toShort())); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + out = Scalar(std::sqrt(toBool())); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + return out; + } + + // comparison < + /** + * @brief Return whether the current Scalar is less than a given template number \p rc. + * @details That is, whether \f$ s < r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given number \p rc. 
+ * @see operator<(const Scalar &lhs, const Scalar &rhs) + */ + template + bool less(const T &rc) const { + Scalar tmp; + int rid = Type.cy_typeid(rc); + if (rid < this->dtype()) { + tmp = this->astype(rid); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() < rc; + break; + case Type.Float: + return tmp.toFloat() < rc; + break; + case Type.Int64: + return tmp.toLong() < rc; + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() < rc; + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() < rc; + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() < rc; + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * @brief Return whether the current Scalar is less than a given Scalar \p rhs. + * @details That is, whether \f$ s < r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given Scalar \p rhs. 
+ * @see operator<(const Scalar &lhs, const Scalar &rhs) + */ + bool less(const Scalar &rhs) const { + Scalar tmp; + if (rhs.dtype() < this->dtype()) { + tmp = this->astype(rhs.dtype()); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() < rhs.toDouble(); + break; + case Type.Float: + return tmp.toFloat() < rhs.toFloat(); + break; + case Type.Int64: + return tmp.toLong() < rhs.toLong(); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() < rhs.toInt(); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() < rhs.toShort(); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() < rhs.toBool(); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + // comparison <= + + /** + * @brief Return whether the current Scalar is less than or equal to a given template number + \p + * rc. + * @details That is, whether \f$ s \leq r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given number \p rc. 
+ * @see operator<=(const Scalar &lhs, const Scalar &rhs) + */ + template + bool leq(const T &rc) const { + Scalar tmp; + int rid = Type.cy_typeid(rc); + if (rid < this->dtype()) { + tmp = this->astype(rid); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() <= rc; + break; + case Type.Float: + return tmp.toFloat() <= rc; + break; + case Type.Int64: + return tmp.toLong() <= rc; + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() <= rc; + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() <= rc; + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() <= rc; + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * @brief Return whether the current Scalar is less than or equal to a given Scalar \p rhs. + * @details That is, whether \f$ s \leq r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given Scalar \p rhs. 
+ * @see operator<=(const Scalar &lhs, const Scalar &rhs) + */ + bool leq(const Scalar &rhs) const { + Scalar tmp; + if (rhs.dtype() < this->dtype()) { + tmp = this->astype(rhs.dtype()); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() <= rhs.toDouble(); + break; + case Type.Float: + return tmp.toFloat() <= rhs.toFloat(); + break; + case Type.Int64: + return tmp.toLong() <= rhs.toLong(); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() <= rhs.toInt(); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() <= rhs.toShort(); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() <= rhs.toBool(); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + // comparison > + /** + * @brief Return whether the current Scalar is greater than a given template number \p rc. + * @details That is, whether \f$ s > r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given number \p rc. 
+ * @see operator>(const Scalar &lhs, const Scalar &rhs) + */ + template + bool greater(const T &rc) const { + Scalar tmp; + int rid = Type.cy_typeid(rc); + if (rid < this->dtype()) { + tmp = this->astype(rid); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() > rc; + break; + case Type.Float: + return tmp.toFloat() > rc; + break; + case Type.Int64: + return tmp.toLong() > rc; + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() > rc; + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() > rc; + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() > rc; + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * @brief Return whether the current Scalar is greater than a given Scalar \p rhs. + * @details That is, whether \f$ s > r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given Scalar \p rhs. 
+ * @see operator>(const Scalar &lhs, const Scalar &rhs) + */ + bool greater(const Scalar &rhs) const { + Scalar tmp; + if (rhs.dtype() < this->dtype()) { + tmp = this->astype(rhs.dtype()); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() > rhs.toDouble(); + break; + case Type.Float: + return tmp.toFloat() > rhs.toFloat(); + break; + case Type.Int64: + return tmp.toLong() > rhs.toLong(); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() > rhs.toInt(); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() > rhs.toShort(); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() > rhs.toBool(); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + // comparison >= + + /** + * @brief Return whether the current Scalar is greater than or equal to a given template + number + * \p rc. + * @details That is, whether \f$ s \geq r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given number \p rc. 
+ * @see operator>=(const Scalar &lhs, const Scalar &rhs) + */ + template + bool geq(const T &rc) const { + Scalar tmp; + int rid = Type.cy_typeid(rc); + if (rid < this->dtype()) { + tmp = this->astype(rid); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() >= rc; + break; + case Type.Float: + return tmp.toFloat() >= rc; + break; + case Type.Int64: + return tmp.toLong() >= rc; + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() >= rc; + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() >= rc; + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() >= rc; + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + /** + * @brief Return whether the current Scalar is greater than or equal to a given Scalar \p + rhs. + * @details That is, whether \f$ s \geq r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given Scalar \p rhs. 
+ * @see operator>=(const Scalar &lhs, const Scalar &rhs) + */ + bool geq(const Scalar &rhs) const { + Scalar tmp; + if (rhs.dtype() < this->dtype()) { + tmp = this->astype(rhs.dtype()); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() >= rhs.toDouble(); + break; + case Type.Float: + return tmp.toFloat() >= rhs.toFloat(); + break; + case Type.Int64: + return tmp.toLong() >= rhs.toLong(); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() >= rhs.toInt(); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() >= rhs.toShort(); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() >= rhs.toBool(); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + // comparison == + + /** + * @brief Return whether the current Scalar is equal to a given template number \p rc. + * @details That is, whether \f$ s = r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given number \p rc. 
+ * @see operator==(const Scalar &lhs, const Scalar &rhs) + */ + template + bool eq(const T &rc) const { + Scalar tmp; + int rid = Type.cy_typeid(rc); + if (rid < this->dtype()) { + tmp = this->astype(rid); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() == rc; + break; + case Type.Float: + return tmp.toFloat() == rc; + break; + case Type.Int64: + return tmp.toLong() == rc; + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() == rc; + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() == rc; + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() == rc; + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + // /** + // * @brief Return whether the current Scalar is approximately equal to a given template number + // \p + // * rc. + // * @details That is, whether \f$ abs(s-r) + // bool approx_eq(const T &rc, cytnx_double tol = 1e-8) const { + // Scalar tmp; + // int rid = Type.cy_typeid(rc); + // if (rid < this->dtype()) { + // tmp = this->astype(rid); + // return tmp._impl->approx_eq(rc, tol); + // } else { + // return this->_impl->approx_eq(rc, tol); + // } + // } + + /** + * @brief Return whether the current Scalar is equal to a given Scalar \p rhs. 
+ * @details That is, whether \f$ s = r \f$, where \f$ s \f$ is the current Scalar + * itself and \f$ r \f$ is the given Scalar \p rhs. + * @see operator==(const Scalar &lhs, const Scalar &rhs) + */ + bool eq(const Scalar &rhs) const { + Scalar tmp; + if (rhs.dtype() < this->dtype()) { + tmp = this->astype(rhs.dtype()); + } else { + tmp = *this; + } + switch (tmp.dtype()) { + case Type.ComplexDouble: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.ComplexFloat: + cytnx_error_msg(true, "[ERROR] comparison not supported for complex type%s", "\n"); + break; + case Type.Double: + return tmp.toDouble() == rhs.toDouble(); + break; + case Type.Float: + return tmp.toFloat() == rhs.toFloat(); + break; + case Type.Int64: + return tmp.toLong() == rhs.toLong(); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + return tmp.toInt() == rhs.toInt(); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + return tmp.toShort() == rhs.toShort(); + break; + case Type.Uint16: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Bool: + return tmp.toBool() == rhs.toBool(); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + // /** + // * @brief Return whether the current Scalar is approximately equal to a given Scalar \p rhs. + // * @details That is, whether \f$ abs(s-r)dtype()) { + // tmp = this->astype(rhs.dtype()); + // return tmp._impl->approx_eq(rhs._impl, tol); + // } else { + // return this->_impl->approx_eq(rhs._impl, tol); + // } + // } + + // radd: Scalar + c + + /** + * @brief Return the addition of the current Scalar and a given template number \p rc. 
+ * @see operator+(const Scalar &lhs, const Scalar &rhs) + */ + template + Scalar radd(const T &rc) const { + Scalar out; + int rid = Type.cy_typeid(rc); + if (this->dtype() < rid) { + out = *this; + } else { + out = this->astype(rid); + } + out += rc; + return out; + } + + /** + * @brief Return the addition of the current Scalar and a given Scalar \p rhs. + * @see operator+(const Scalar &lhs, const Scalar &rhs) + */ + Scalar radd(const Scalar &rhs) const { + Scalar out; + if (this->dtype() < rhs.dtype()) { + out = *this; + } else { + out = this->astype(rhs.dtype()); + } + out += rhs; + return out; + } + + // rmul: Scalar * c + + /** + * @brief Return the multiplication of the current Scalar and a given template number \p rc. + * @see operator*(const Scalar &lhs, const Scalar &rhs) + */ + template + Scalar rmul(const T &rc) const { + Scalar out; + int rid = Type.cy_typeid(rc); + if (this->dtype() < rid) { + out = *this; + } else { + out = this->astype(rid); + } + out *= rc; + return out; + } + + /** + * @brief Return the multiplication of the current Scalar and a given Scalar \p rhs. + * @see operator*(const Scalar &lhs, const Scalar &rhs) + */ + Scalar rmul(const Scalar &rhs) const { + Scalar out; + if (this->dtype() < rhs.dtype()) { + out = *this; + } else { + out = this->astype(rhs.dtype()); + } + out *= rhs; + return out; + } + + // rsub: Scalar - c + + /** + * @brief Return the subtraction of the current Scalar and a given template number \p rc. + * @see operator-(const Scalar &lhs, const Scalar &rhs) + */ + template + Scalar rsub(const T &rc) const { + Scalar out; + int rid = Type.cy_typeid(rc); + if (this->dtype() < rid) { + out = *this; + } else { + out = this->astype(rid); + } + out -= rc; + return out; + } + + /** + * @brief Return the subtraction of the current Scalar and a given Scalar \p rhs. 
+ * @see operator-(const Scalar &lhs, const Scalar &rhs) + */ + Scalar rsub(const Scalar &rhs) const { + Scalar out; + if (this->dtype() < rhs.dtype()) { + out = *this; + } else { + out = this->astype(rhs.dtype()); + } + out -= rhs; + return out; + } + + // rdiv: Scalar / c + + /** + * @brief Return the division of the current Scalar and a given template number \p rc. + * @see operator/(const Scalar &lhs, const Scalar &rhs) + */ + template + Scalar rdiv(const T &rc) const { + Scalar out; + int rid = Type.cy_typeid(rc); + if (this->dtype() < rid) { + out = *this; + } else { + out = this->astype(rid); + } + out /= rc; + return out; + } + + /** + * @brief Return the division of the current Scalar and a given Scalar \p rhs. + * @see operator/(const Scalar &lhs, const Scalar &rhs) + */ + Scalar rdiv(const Scalar &rhs) const { + Scalar out; + if (this->dtype() < rhs.dtype()) { + out = *this; + } else { + out = this->astype(rhs.dtype()); + } + out /= rhs; + return out; + } + + /* + //operator: + template + Scalar operator+(const T &rc){ + return this->radd(rc); + } + template + Scalar operator*(const T &rc){ + return this->rmul(rc); + } + template + Scalar operator-(const T &rc){ + return this->rsub(rc); + } + template + Scalar operator/(const T &rc){ + return this->rdiv(rc); + } + + template + bool operator<(const T &rc){ + return this->less(rc); + } + + template + bool operator>(const T &rc){ + return this->greater(rc); + } + + template + bool operator<=(const T &rc){ + return this->leq(rc); + } + + template + bool operator>=(const T &rc){ + return this->geq(rc); + } + */ + }; +}; // namespace cytnx #endif #endif diff --git a/src/backend_torch/CMakeLists.txt b/src/backend_torch/CMakeLists.txt index bf8f2765..07411eac 100644 --- a/src/backend_torch/CMakeLists.txt +++ b/src/backend_torch/CMakeLists.txt @@ -5,4 +5,5 @@ target_sources_local(cytnx PRIVATE Type_convert.cpp + Scalar.cpp ) diff --git a/src/backend_torch/Scalar.cpp b/src/backend_torch/Scalar.cpp new file mode 
100644 index 00000000..a34de6a3 --- /dev/null +++ b/src/backend_torch/Scalar.cpp @@ -0,0 +1,300 @@ +#include "backend_torch/Scalar.hpp" + +using namespace std; +namespace cytnx { + cytnx_complex128 complex128(const Scalar& in) { return complex(in.toComplexDouble()); } + + cytnx_complex64 complex64(const Scalar& in) { return complex(in.toComplexFloat()); } + + std::ostream& operator<<(std::ostream& os, const Scalar& in) { + in.print_elem(os); + os << std::string(" dtype: [") << Type.getname(in._dtype) << std::string("]"); + return os; + } + + // ladd: c + Scalar: + Scalar operator+(const Scalar& lc, const Scalar& rs) { return rs.radd(lc); }; + + // lmul c * Scalar; + Scalar operator*(const Scalar& lc, const Scalar& rs) { return rs.rmul(lc); }; + + // lsub c - Scalar; + Scalar operator-(const Scalar& lc, const Scalar& rs) { return lc.rsub(rs); }; + + // ldiv c / Scalar; + Scalar operator/(const Scalar& lc, const Scalar& rs) { return lc.rdiv(rs); }; + + // lless c < Scalar; + bool operator<(const Scalar& lc, const Scalar& rs) { return lc.less(rs); }; + + // lless c > Scalar; + bool operator>(const Scalar& lc, const Scalar& rs) { return lc.greater(rs); }; + + // lless c <= Scalar; + bool operator<=(const Scalar& lc, const Scalar& rs) { return lc.leq(rs); }; + + // lless c >= Scalar; + bool operator>=(const Scalar& lc, const Scalar& rs) { return lc.geq(rs); }; + + // eq c == Scalar; + bool operator==(const Scalar& lc, const Scalar& rs) { + // if (lc.dtype() < rs.dtype()) + // return lc.geq(rs); + // else + // return rs.geq(lc); + return lc.eq(rs); + }; + + Scalar abs(const Scalar& c) { return c.abs(); }; + + Scalar sqrt(const Scalar& c) { return c.sqrt(); }; + + at::Tensor StorageImpl2Tensor(const c10::intrusive_ptr& impl, int dtype) { + c10::Storage sto = c10::Storage(impl); + + if (dtype == Type.Bool) { + at::TensorImpl tnimpl = + at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, caffe2::TypeMeta::Make()); + return at::Tensor( + 
c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.ComplexDouble) { + at::TensorImpl tnimpl = at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, + caffe2::TypeMeta::Make>()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.ComplexFloat) { + at::TensorImpl tnimpl = at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, + caffe2::TypeMeta::Make>()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.Int64) { + at::TensorImpl tnimpl = at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, + caffe2::TypeMeta::Make()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.Int32) { + at::TensorImpl tnimpl = at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, + caffe2::TypeMeta::Make()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.Int16) { + at::TensorImpl tnimpl = at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, + caffe2::TypeMeta::Make()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.Double) { + at::TensorImpl tnimpl = + at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, caffe2::TypeMeta::Make()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.Float) { + at::TensorImpl tnimpl = + at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, caffe2::TypeMeta::Make()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } else if (dtype == Type.Float) { + at::TensorImpl tnimpl = + at::TensorImpl(std::move(sto), c10::DispatchKeySet::FULL, caffe2::TypeMeta::Make()); + return at::Tensor( + c10::intrusive_ptr::reclaim(&tnimpl)); + } + cytnx_error_msg(true, "[ERROR] StorageImpl2Tensor: unsupported dtype%s", "\n"); + return at::Tensor(); + } + + // Scalar proxy: + // Sproxy + Scalar::Sproxy& Scalar::Sproxy::operator=(const Scalar::Sproxy& rc) { + if (this->_insimpl.get() == 0) { + 
this->_insimpl = rc._insimpl; + this->_loc = rc._loc; + return *this; + } else { + if ((rc._insimpl == this->_insimpl) && (rc._loc == this->_loc)) { + return *this; + } else { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + at::Tensor tnrc = StorageImpl2Tensor(rc._insimpl, rc._dtype); + tnthis[this->_loc] = tnrc[rc._loc]; + return *this; + } + } + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const Scalar& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = rc; + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_complex128& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = c10::complex(rc); + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_complex64& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = c10::complex(rc); + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_double& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = rc; + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_float& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = rc; + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_uint64& rc) { + cytnx_error_msg(true, + "[ERROR] invalid dtype for scalar operator=, pytorch backend doesn't support " + "unsigned type %s", + "\n"); + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_int64& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = rc; + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_uint32& rc) { + cytnx_error_msg(true, + "[ERROR] invalid dtype for scalar operator=, pytorch backend doesn't support " + "unsigned type %s", + "\n"); + return *this; + } + 
Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_int32& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = rc; + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_uint16& rc) { + cytnx_error_msg(true, + "[ERROR] invalid dtype for scalar operator=, pytorch backend doesn't support " + "unsigned type %s", + "\n"); + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_int16& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = rc; + return *this; + } + Scalar::Sproxy& Scalar::Sproxy::operator=(const cytnx_bool& rc) { + at::Tensor tnthis = StorageImpl2Tensor(this->_insimpl, this->_dtype); + tnthis[this->_loc] = rc; + return *this; + } + + bool Scalar::Sproxy::exists() const { return this->_dtype != Type.Void; }; + + Scalar Scalar::Sproxy::real() { return Scalar(*this).real(); } + Scalar Scalar::Sproxy::imag() { return Scalar(*this).imag(); } + + Scalar::Scalar(const Sproxy& prox) { + switch (prox._dtype) { + case Type.ComplexDouble: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.ComplexFloat: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.Double: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.Float: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.Int64: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.Uint64: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int32: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.Uint32: + cytnx_error_msg(true, "[ERROR] no support for unsigned dtype for torch backend %s", "\n"); + break; + case Type.Int16: + *this = 
StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.Uint16: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + case Type.Bool: + *this = StorageImpl2Tensor(prox._insimpl, prox._dtype)[prox._loc].item(); + break; + default: + cytnx_error_msg(true, "[ERROR] invalid dtype for torch backend %s", "\n"); + break; + } + } + + // // Storage Init interface. + // //============================= + // inline Scalar_base* ScIInit_cd() { + // Scalar_base* out = new ComplexDoubleScalar(); + // return out; + // } + // inline Scalar_base* ScIInit_cf() { + // Scalar_base* out = new ComplexFloatScalar(); + // return out; + // } + // inline Scalar_base* ScIInit_d() { + // Scalar_base* out = new DoubleScalar(); + // return out; + // } + // inline Scalar_base* ScIInit_f() { + // Scalar_base* out = new FloatScalar(); + // return out; + // } + // inline Scalar_base* ScIInit_u64() { + // Scalar_base* out = new Uint64Scalar(); + // return out; + // } + // inline Scalar_base* ScIInit_i64() { + // Scalar_base* out = new Int64Scalar(); + // return out; + // } + // inline Scalar_base* ScIInit_u32() { + // Scalar_base* out = new Uint32Scalar(); + // return out; + // } + // inline Scalar_base* ScIInit_i32() { + // Scalar_base* out = new Int32Scalar(); + // return out; + // } + // inline Scalar_base* ScIInit_u16() { + // Scalar_base* out = new Uint16Scalar(); + // return out; + // } + // inline Scalar_base* ScIInit_i16() { + // Scalar_base* out = new Int16Scalar(); + // return out; + // } + // inline Scalar_base* ScIInit_b() { + // Scalar_base* out = new BoolScalar(); + // return out; + // } + // Scalar_init_interface::Scalar_init_interface() { + // if (!inited) { + // UScIInit[this->Double] = ScIInit_d; + // UScIInit[this->Float] = ScIInit_f; + // UScIInit[this->ComplexDouble] = ScIInit_cd; + // UScIInit[this->ComplexFloat] = ScIInit_cf; + // UScIInit[this->Uint64] = ScIInit_u64; + // UScIInit[this->Int64] = ScIInit_i64; + // 
UScIInit[this->Uint32] = ScIInit_u32; + // UScIInit[this->Int32] = ScIInit_i32; + // UScIInit[this->Uint16] = ScIInit_u16; + // UScIInit[this->Int16] = ScIInit_i16; + // UScIInit[this->Bool] = ScIInit_b; + // inited = true; + // } + // } + + // Scalar_init_interface __ScII; +} // namespace cytnx