sa_search incompatible with int numpy array #50

Open
rangehow opened this issue Dec 11, 2024 · 7 comments

@rangehow

rangehow commented Dec 11, 2024

Dear project developers,

Hello, I encountered an issue with the sa_search method in combination with numpy. The test code is as follows:

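import numpy as np
from pydivsufsort import divsufsort, sa_search

if __name__ == "__main__":
    np.random.seed(0)
    s = np.random.randint(0, 10, size=1000000)  # default integer dtype, not uint8

    sa = divsufsort(s)
    query_length = 10
    start_idx = np.random.randint(0, len(s) - query_length)
    query = s[start_idx:start_idx + query_length]
    sa_search(s, sa, query)  # raises AssertionError (see traceback below)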

And I get the following error:

Traceback (most recent call last):
  File "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/ruanjunhao04/llm-sft-data/ruanjunhao/ndp/ndp/infini_gram/suffix_array_utils.py", line 115, in <module>
    sa_search(s,sa,query)
  File "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/ruanjunhao04/env/rjh/lib/python3.12/site-packages/pydivsufsort/divsufsort.py", line 197, in sa_search
    inp_p = _get_bytes_pointer(inp)
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-aipnlp/ruanjunhao04/env/rjh/lib/python3.12/site-packages/pydivsufsort/divsufsort.py", line 58, in _get_bytes_pointer
    assert inp.dtype == np.uint8
           ^^^^^^^^^^^^^^^^^^^^^
AssertionError

It seems that sa_search does not natively support integer NumPy arrays other than uint8. For now I am using the following workaround, and I would appreciate any advice.

import numpy as np
from pydivsufsort import divsufsort, kasai


def build_suffix_array(s: np.ndarray) -> np.ndarray:
    # Suffix array of the integer sequence s
    return divsufsort(s)

def build_lcp_array(s, sa):
    # LCP array computed with Kasai's algorithm
    return kasai(s, sa)

def binary_search(sa, lcp, s, query):
    # lcp is accepted for interface compatibility but is not used here:
    # lcp[i] relates neighbouring suffixes to each other, not to the query.
    query_len = len(query)
    # Compare tuples of Python ints so the ordering is numeric for any value
    # range (raw little-endian bytes would misorder values >= 256).
    # Assumes the suffix array orders symbols numerically.
    query_key = tuple(query.tolist())

    def prefix(i):
        # First query_len symbols of the i-th suffix in sorted order
        suffix_start = sa[i]
        return tuple(s[suffix_start:suffix_start + query_len].tolist())

    # Lower bound: first suffix whose prefix is >= query
    start, end = 0, len(sa)
    while start < end:
        mid = (start + end) // 2
        if prefix(mid) < query_key:
            start = mid + 1
        else:
            end = mid
    first_occurrence = start

    # Upper bound: first suffix whose prefix is > query
    end = len(sa)
    while start < end:
        mid = (start + end) // 2
        if prefix(mid) <= query_key:
            start = mid + 1
        else:
            end = mid
    last_occurrence = start - 1

    # An empty range means the query does not occur in s
    if last_occurrence < first_occurrence:
        return -1, -1
    return first_occurrence, last_occurrence

def retrieve_num_substrings(sa, lcp, s, query, extend=0):
    assert extend <= 1

    first, last = binary_search(sa, lcp, s, query)
    if first == -1:
        return 0, (None, None)
    
    return (last - first + 1), (first, last)


def get_retrieved_substrings(first, last, sa, s, query, extend=1):
    assert extend <= 1

    # maybe slow
    matching_substrings = []
    for i in range(first, last + 1):    
        start_index = sa[i]
        matching_substrings.append(s[start_index:start_index + len(query) + extend])

    return matching_substrings


def retrieve_substrings(sa, lcp, s, query, extend=1):
    assert extend <= 1

    num_matches, (first, last) = retrieve_num_substrings(sa, lcp, s, query, extend)
    
    if num_matches == 0:
        return []

    return get_retrieved_substrings(first, last, sa, s, query, extend)

# Testing code
if __name__ == "__main__":
    np.random.seed(0)
    s = np.random.randint(0, 10, size=1000000)

    sa = build_suffix_array(s)
    query_length = 10
    start_idx = np.random.randint(0, len(s) - query_length)
    query = s[start_idx:start_idx + query_length]

    lcp = build_lcp_array(s, sa)

    substrings = retrieve_substrings(sa, lcp, s, query, extend=1)
    print(len(substrings), "match(es) found")

Lastly, please allow me to express my sincere respect for your valuable time once again.

@louisabraham
Owner

Could you try casting your input arrays as int8?

In randint, just add dtype=np.int8

@rangehow
Author

> Could you try casting your input arrays as int8?
>
> In randint, just add dtype=np.int8

Thank you very much for your prompt reply. However, my use case requires the elements of the NumPy array to have a range of approximately 0 to 128k, which far exceeds the range of uint8. If there is a better way to handle this, please let me know. Thank you again.

@louisabraham
Owner

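I think this is an edge case I don't cover.
You could convert them to base 256, so each element becomes 3 uint8 elements (256^3 covers values up to about 16.7M). Then search the uint8 array and discard matches whose position is not 0 modulo 3, since those do not start on an element boundary.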
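
A minimal sketch of the base-256 idea, assuming non-negative token IDs below 2^24 and big-endian digit order. The helper names (`encode_base256`, `find_token_matches`) are hypothetical, and the search uses a plain `bytes.find` scan as a stand-in for a suffix-array search so that the example stays self-contained:

```python
import numpy as np

BYTES_PER_TOKEN = 3  # 256**3 > 128k, so three base-256 digits are enough


def encode_base256(tokens):
    """Expand each token ID into 3 big-endian uint8 digits (hypothetical helper)."""
    tokens = np.asarray(tokens)
    # Assumption: IDs are non-negative and fit in three bytes
    assert tokens.size == 0 or (tokens.min() >= 0 and tokens.max() < 256 ** BYTES_PER_TOKEN)
    tokens = tokens.astype(np.uint32)
    shifts = np.array([16, 8, 0], dtype=np.uint32)
    digits = (tokens[:, None] >> shifts) & 0xFF
    return digits.astype(np.uint8).ravel()


def find_token_matches(text_bytes, query_bytes):
    """Stand-in search: scan for byte matches and keep only token-aligned ones."""
    haystack, needle = text_bytes.tobytes(), query_bytes.tobytes()
    positions = []
    pos = haystack.find(needle)
    while pos != -1:
        # A match at a byte offset that is not 0 modulo 3 straddles two tokens
        if pos % BYTES_PER_TOKEN == 0:
            positions.append(pos // BYTES_PER_TOKEN)
        pos = haystack.find(needle, pos + 1)
    return positions


tokens = np.array([256, 1, 65536], dtype=np.int64)
query = np.array([65536], dtype=np.int64)
matches = find_token_matches(encode_base256(tokens), encode_base256(query))
print(matches)
```

Here the bytes of the query also occur at byte offset 1, spanning the tokens 256 and 1; the modulo-3 filter drops that spurious hit and keeps only token index 2. With a suffix array, the same filter would presumably be applied to the positions returned for the encoded uint8 array.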

@louisabraham
Owner

I can probably fix this in ~12h.

@rangehow
Author

rangehow commented Dec 11, 2024

Thank you very much for your help. I am using your library to process a dataset after it has been tokenized for a large language model. More specifically, I am following https://arxiv.org/pdf/2401.17377 (Chapter 3), which seems to mention a similar approach of converting token IDs in the range 0-128k into some kind of numeral system. Perhaps it could serve as a reference for you.

@louisabraham
Owner

See the new test:

    a = np.array([0, 256, 0, 256], np.uint16)
    s = WonderString(a)
    pat = np.array([0, 256], np.uint16)
    ans = s.search(pat, return_positions=True)
    assert sorted(ans) == [0, 2]

It uses the undocumented WonderString interface and requires return_positions=True. Unfortunately this function is linear in the number of matches (with return_positions=False the complexity is independent of the number of matches).

I want to leave this issue open as the best solution would be to reimplement libdivsufsort.sa_search, which is a lot more work.

@rangehow
Author

Thank you for your time. I will test the performance difference between this interface and my simple Python implementation :)
