linalg internal uses NULL pointers for undefined functions #487

Open
manuschneider opened this issue Oct 2, 2024 · 3 comments
Labels: Low priority (An issue which is not urgent), suggestion (Suggestion for current codebase)


@manuschneider
Collaborator

Here is an example of the current behavior of many backend functions for different dtypes:

In src/backend/linalg_internal_interface.hpp, a vector of function pointers is declared:

```cpp
std::vector<Qrfunc_oii> QR_ii;
```

Then, in src/backend/linalg_internal_interface.cpp, this vector is filled with the implementations for the different dtypes:

```cpp
QR_ii = vector<Qrfunc_oii>(5);

QR_ii[Type.ComplexDouble] = QR_internal_cd;
QR_ii[Type.ComplexFloat] = QR_internal_cf;
QR_ii[Type.Double] = QR_internal_d;
QR_ii[Type.Float] = QR_internal_f;
```

I see two problems with this. First, the vector in this case only has 5 entries, so accessing QR_ii[Type.Int] is an out-of-bounds access: std::vector::operator[] does no bounds checking, and whatever error results does not tell the user much.
Second, QR_ii contains 5 elements, but only 4 of them are assigned. The remaining entry is value-initialized to a NULL pointer, and calling it leads to a segfault or a kernel crash. This makes debugging very hard.
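Both failure modes can be reproduced outside Cytnx. The sketch below is illustrative only: `Func`, `impl_a`, `impl_b`, `make_partial_table`, and `checked_lookup` are hypothetical names, and slot 0 simply stands in for whichever entry is left unassigned. It relies on two standard guarantees: `std::vector<T>(n)` value-initializes function pointers to `nullptr`, and `std::vector::at()` throws `std::out_of_range` where `operator[]` performs no check.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a dispatch signature such as Qrfunc_oii.
using Func = int (*)(int);

int impl_a(int x) { return x + 1; }
int impl_b(int x) { return x * 2; }

// Mirrors the QR_ii setup: 5 slots, only 4 of them assigned.
std::vector<Func> make_partial_table() {
  std::vector<Func> table(5);  // all 5 slots start as nullptr
  table[1] = impl_a;
  table[2] = impl_b;
  table[3] = impl_a;
  table[4] = impl_b;
  return table;  // table[0] stays nullptr; calling it is undefined behavior
}

// A checked lookup that turns both problems into descriptive exceptions.
Func checked_lookup(const std::vector<Func>& table, std::size_t dtype) {
  Func f = table.at(dtype);  // throws std::out_of_range if dtype >= table.size()
  if (f == nullptr) {
    throw std::runtime_error("no implementation registered for this dtype");
  }
  return f;
}
```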

Here is a suggestion to make things more robust: provide a fallback implementation for every internal function and use it as the default value of each dispatch vector.
In src/backend/linalg_internal_interface.cpp:

```cpp
QR_ii = vector<Qrfunc_oii>(N_Type, QR_internal_fallback);

QR_ii[Type.ComplexDouble] = QR_internal_cd;
QR_ii[Type.ComplexFloat] = QR_internal_cf;
QR_ii[Type.Double] = QR_internal_d;
QR_ii[Type.Float] = QR_internal_f;
```

In src/backend/linalg_internal_cpu/QR_internal.hpp:

```cpp
void QR_internal_fallback(const boost::intrusive_ptr<Storage_base> &in,
                          boost::intrusive_ptr<Storage_base> &Q,
                          boost::intrusive_ptr<Storage_base> &R,
                          boost::intrusive_ptr<Storage_base> &D,
                          boost::intrusive_ptr<Storage_base> &tau, const cytnx_int64 &M,
                          const cytnx_int64 &N, const bool &is_d);
```

In src/backend/linalg_internal_cpu/QR_internal.cpp:

```cpp
// Registered for every dtype without a real QR implementation:
// fails with a descriptive error instead of calling through a NULL pointer.
void QR_internal_fallback(const boost::intrusive_ptr<Storage_base> &in,
                          boost::intrusive_ptr<Storage_base> &Q,
                          boost::intrusive_ptr<Storage_base> &R,
                          boost::intrusive_ptr<Storage_base> &D,
                          boost::intrusive_ptr<Storage_base> &tau, const cytnx_int64 &M,
                          const cytnx_int64 &N, const bool &is_d) {
  cytnx_error_msg(true, "[ERROR][linalg_internal] QR_internal not implemented for this data type%s", "\n");
}
```

However, this would have to be done for many internal functions in the backend.
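A self-contained sketch of the fallback-table idea, under simplifying assumptions: `ScaleFunc` is a hypothetical signature standing in for `Qrfunc_oii`, the `DType` enum is hypothetical and its ordering is not meant to match Cytnx's `Type`, and a plain `throw` stands in for `cytnx_error_msg`.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical dtype enumeration; N_Type is the number of dtypes.
enum DType : unsigned { Void = 0, ComplexDouble, ComplexFloat, Double, Float, Int64, N_Type };

// Hypothetical simplified signature standing in for Qrfunc_oii.
using ScaleFunc = double (*)(double);

double scale_d(double x) { return 2.0 * x; }
double scale_f(double x) { return static_cast<float>(x) * 2.0f; }

// Fallback for every dtype without a real implementation:
// it fails loudly with a descriptive message instead of segfaulting.
double scale_fallback(double) {
  throw std::runtime_error(
      "[ERROR][linalg_internal] Scale_internal not implemented for this data type");
}

std::vector<ScaleFunc> make_scale_table() {
  // Every slot starts as the fallback, so the table covers all N_Type
  // dtypes and no entry can be nullptr.
  std::vector<ScaleFunc> table(N_Type, scale_fallback);
  table[Double] = scale_d;
  table[Float] = scale_f;
  return table;
}
```

Because the vector is constructed with `N_Type` copies of the fallback, it can neither be too short nor contain a NULL entry; the price is one extra fallback function per operation.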

@manuschneider manuschneider added suggestion Suggestion for current codebase Low priority An issue which is not urgent labels Oct 2, 2024
@manuschneider manuschneider changed the title linalg internal uses zero pointers for undefined functions linalg internal uses NULL pointers for undefined functions Oct 2, 2024
@IvanaGyro
Collaborator

Many of the xx_cd, xx_cf, ... functions do similar things. Another way to ease the pain is to let the callers take responsibility for type checking, and then templatize the xx_cd, xx_cf, ... functions or use function overloading, e.g. xx(int value) {} and xx(double value) {}. That way, some errors will be caught at compile time.
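A minimal sketch of the template variant, assuming hypothetical names (`is_supported`, `Scale_internal`) and a `std::vector` in place of Cytnx storage. One function template replaces the per-dtype copies, and the `static_assert` turns a call with an unsupported element type such as `int` into a compile-time error rather than a runtime crash.

```cpp
#include <cassert>
#include <complex>
#include <type_traits>
#include <vector>

// Hypothetical trait listing the element types the operation supports.
template <typename T> struct is_supported : std::false_type {};
template <> struct is_supported<double> : std::true_type {};
template <> struct is_supported<float> : std::true_type {};
template <> struct is_supported<std::complex<double>> : std::true_type {};
template <> struct is_supported<std::complex<float>> : std::true_type {};

// One template replaces hypothetical Scale_internal_cd/_cf/_d/_f copies.
// Instantiating it with an unsupported T (e.g. int) fails to compile.
template <typename T>
void Scale_internal(std::vector<T>& data, const T& factor) {
  static_assert(is_supported<T>::value, "Scale_internal: unsupported element type");
  for (auto& x : data) x *= factor;
}
```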

@manuschneider
Collaborator Author

Indeed, it would be possible to drop these function pointer vectors and let the calling function handle the dispatch. But then every calling function would need a switch-case statement to call the correct implementation. I think the current implementation is more convenient: one can just call QR_ii[dtype] for any dtype. However, it is not very robust, because of the NULL pointers and the vectors that are shorter than N_Type.

Function overloading would be a clean way to implement this, but unfortunately it does not work here. The function arguments are tensors, and the implementation depends on their dtype, which is only known at runtime. So the C++ types of the arguments are always the same.

The only alternative I can think of is to provide a single function that checks the dtype of the input tensors and calls the matching implementation (possibly a bit slower, though).
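That single-function alternative might look like the sketch below. `Dtype`, `Storage`, `Sum_kernel`, and `Sum` are hypothetical simplifications, not Cytnx types. The dtype is checked once per call in a `switch` (not once per element), and the `default` branch reports unsupported dtypes with a clear error.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical runtime dtype tag and type-erased storage.
enum class Dtype { Double, Float, Int64 };

struct Storage {
  Dtype dtype;
  std::vector<double> d;  // used when dtype == Dtype::Double
  std::vector<float> f;   // used when dtype == Dtype::Float
};

// Typed kernel, analogous to the *_d / *_f internal functions.
template <typename T>
T Sum_kernel(const std::vector<T>& v) {
  T acc{};
  for (const T& x : v) acc += x;
  return acc;
}

// Single entry point: inspects the runtime dtype and forwards to the
// matching kernel instead of indexing a function pointer table.
double Sum(const Storage& s) {
  switch (s.dtype) {
    case Dtype::Double:
      return Sum_kernel(s.d);
    case Dtype::Float:
      return static_cast<double>(Sum_kernel(s.f));
    default:
      throw std::runtime_error("[ERROR] Sum not implemented for this data type");
  }
}
```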

@IvanaGyro
Collaborator

IvanaGyro commented Oct 7, 2024

Instead of manually writing an invalid fallback for each function, we can generate them with a template. This solution needs C++20, because it uses a class type (StringLiteral) as a non-type template parameter.

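```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// Wraps a string literal so it can be passed as a template argument (C++20).
template<std::size_t N>
struct StringLiteral {
    constexpr StringLiteral(const char (&str)[N]) {
        std::copy_n(str, N, value);
    }

    char value[N];
};

// Placeholder usable for any signature; reports the name of the missing function.
template<StringLiteral Name, typename Return, typename... Args>
Return NullFunction(Args...) {
    std::cout << "Calling a function pointing to nullptr: " << Name.value << std::endl;
    return Return();
}
```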

Here is a demo.
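The same idea can be combined with the dispatch vectors from the first post. The sketch below is a hypothetical C++17 variant, not the linked demo: it passes the name as a `const char*` template argument (which C++17 accepts for a namespace-scope `constexpr` array), deduces the full signature from the function pointer type, and throws instead of printing and returning `Return()`. `Fallback`, `make_table`, and `kQrName` are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Primary template: intentionally left undefined.
template <const char* Name, typename Fn>
struct Fallback;

// For a function pointer type Return(*)(Args...), generate a function with
// exactly that signature that reports which operation is missing.
template <const char* Name, typename Return, typename... Args>
struct Fallback<Name, Return (*)(Args...)> {
  static Return call(Args...) {
    throw std::runtime_error(std::string("[ERROR] ") + Name +
                             " not implemented for this data type");
  }
};

// Build a dispatch table of size n where every slot is the generated fallback.
template <const char* Name, typename Fn>
std::vector<Fn> make_table(std::size_t n) {
  return std::vector<Fn>(n, &Fallback<Name, Fn>::call);
}

// Template arguments of pointer type need an object with static storage duration.
inline constexpr char kQrName[] = "QR_internal";
```

Real implementations are then assigned over the fallback slots, exactly as in the suggestion above, and any slot left untouched produces an error naming the operation.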
