diff --git a/AutoPM3_main.py b/AutoPM3_main.py new file mode 100644 index 0000000..d293efd --- /dev/null +++ b/AutoPM3_main.py @@ -0,0 +1,432 @@ +from langchain_community.llms import Ollama +from langchain.text_splitter import RecursiveCharacterTextSplitter,CharacterTextSplitter +from langchain_community.vectorstores import FAISS +from langchain.chains import RetrievalQA +from langchain import PromptTemplate +from langchain.globals import set_verbose, set_debug +import requests + +from bioc import biocxml +# Import the following stuff for implementing custom retrievers +from typing import List, Dict +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +from table_functions import table_extraction_n_sqlQA +from utils import extractTablesFromXML + +set_debug(False) + +# the second item is the number of returned chunks +# the third item is the abbreviated protein change notation +retriever_OK = [ False, 0, None ] + + +import textwrap +import os +import time +from argparse import ArgumentParser +import sys +import glob +import json +import re +import tempfile + + +os.environ['CURL_CA_BUNDLE'] = '' # Fix SSL error for Mutalyzer3 + +PROTEIN_MAPPING_FILE = './protein.txt' + +# enum types +VARIANT_QUERY = 0 +INTRANS_QUERY = 1 +C_VARIANT = 0 # c.123A>G +P_VARIANT = 1 # protein change + + +from func_timeout import func_set_timeout +import func_timeout + + +class VariantSpecificRetriever(BaseRetriever): + documents: List[Document] + k: int + protein_map: Dict[str, str] + + # Assumes "query" to be the target variant (in HGVS notation) + def _get_relevant_documents(self, query): + variant = query + # Remove the contig name (NM_xxxxxx) + target_var = variant.split(":")[-1] + # Remove c.() from the variant notation (by default is c.(123A>G) or c.123A>G) + var_dna = target_var.replace('c.', '').replace('(', '').replace(')', '') + # Translate the mutation to protein change using Mutalyzer + var_protein = None + var_protein_short = None + try: + r = requests.get(f'https://mutalyzer.nl/api/normalize/{variant}?only_variants=false') + j = r.json() + returned_prot = j['protein']['description'].split(':')[-1] + # Remove the p.() + m = re.match(r'p.\((.*)\)', returned_prot) + prot = m.group(1) + if len(prot) < 5: # Too short (sometimes Mutalyzer returns something like p.(=) ) + raise Exception(f'Protein change too short: {returned_prot}') + # Sometimes the protein mutation is like Cys1447Glnfs29 but some papers write as Cys1447fs, + # so we remove the whole Glnfs part + var_protein = re.sub(r'[A-Za-z]{3}fs.*', '', prot) + # Convert the protein to short form ( -> ) + var_protein_short = var_protein + for (k,v) in self.protein_map.items(): + var_protein_short = var_protein_short.replace(k, v) + # Remove X and * (meaning Terminal) from the protein notation, since we don't know the paper is using which one + var_protein = var_protein.replace('X', '').replace('*', '') + var_protein_short = var_protein_short.replace('X', '').replace('*', '') + #print(f'Protein : {var_protein} ({var_protein_short})') + except KeyError as e: + #print('Protein: [ERROR] Not found by Mutalyzer') + pass + except Exception as e: + #print(f'Protein : [ERROR] {e}') + pass + + # Done with conversion. Now do the retrieval (= regex matching) + + retrieved_chunks = [] + # Construct the regex pattern for DNA: + # 1. 123A>G becomes \s*123\s*A>G (allow spaces around numbers) + # 2. 
Further becomes \s*123\s*123A\s*>\s*G (allow spaces around > ) + dna_pattern = re.sub('([0-9]+)', '\\\\s*\\1\\\\s*', re.escape(var_dna)) + dna_pattern = re.sub('(>)', '\\\\s*\\1\\\\s*', dna_pattern) + dna_matcher = re.compile(dna_pattern, re.IGNORECASE) + for chunk in self.documents: + # Re-encode the text to get rid of those annoying Unicode \x80\x89 (whitespaces) + text = chunk.page_content.encode('utf-8').decode('unicode_escape').encode('latin-1').decode('utf-8') + if dna_matcher.search(text) or \ + var_protein and chunk.page_content.find(var_protein) >= 0 or \ + var_protein_short and chunk.page_content.find(var_protein_short) >= 0: + retrieved_chunks += [ chunk ] + # If neither DNA nor protein change could retrieve anything, + # we resort to matching by positions only... + if not retrieved_chunks: + dig_dna = re.findall(r'\d+', var_dna) + dig_protein = re.findall(r'\d+', var_protein) if var_protein else None + dig_dna_matcher = re.compile('\D' + str(dig_dna[0]) + '\D') if dig_dna else None + dig_protein_matcher = re.compile('\D' + str(dig_protein[0]) + '\D') if dig_protein else None + for chunk in self.documents: + if dig_dna_matcher and dig_dna_matcher.search(chunk.page_content) or \ + dig_protein_matcher and dig_protein_matcher.search(chunk.page_content): + retrieved_chunks += [ chunk ] + if len(retrieved_chunks) > 0: + retriever_OK[0] = True + retriever_OK[1] = len(retrieved_chunks) + retriever_OK[2] = var_protein_short + return retrieved_chunks[:self.k] + + +# Load the protein abbreviatioon map from a file +def load_protein_map(filename): + with open(filename, 'r') as f: + lines = f.read().splitlines() + m = { x.split()[0]: x.split()[1] for x in lines } + return m + + +# Load paper from XML file +def load_xml_paper(filename, filter_tables=False): + out_doc = '' + with open(filename, 'r', encoding='utf8') as fp: # better use utf8 + collection = biocxml.load(fp) + document = collection.documents[0] + for passage in document.passages: + section_type = passage.infons.get('section_type', '').upper() + if filter_tables and section_type in [ 'TABLE', 'REF', 'COMP_INT', 'AUTH_CONT', 'SUPPL' ]: + pass # filter away this section + else: + out_doc += passage.text + '\n' + return out_doc + + + +template_PM3_answer_chain_llama3 = """\ +<|begin_of_text|><|start_header_id|>system<|end_header_id|> +You are a specialist in biogenetics, answer only based on user's input!<|eot_id|> +<|start_header_id|>user<|end_header_id|> +The variant in HGVS format is {question}, don't include this in your answer if condisering compound het variants. +Given the context: '{context}' and target variant {c_variant}. Answer the question: {proposedQuestion}<|eot_id|>. 
+<|start_header_id|>assistant<|end_header_id|>\n +""" + + + +def split_docs(documents,chunk_size=1500,chunk_overlap=100): +# Responsible for splitting the documents into several chunks + + # Initializing the RecursiveCharacterTextSplitter with + # chunk_size and chunk_overlap + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + + # Splitting the documents into chunks + chunks = text_splitter.split_documents(documents=documents) + + # returning the document chunks + return chunks + + + + + +# Creating the chain for Question Answering +def load_qa_chain(retriever, llm, prompt): + + return RetrievalQA.from_chain_type( + llm=llm, + retriever=retriever, # here we are using the vectorstore as a retriever + chain_type="stuff", + return_source_documents=True, # including source documents in output + chain_type_kwargs={'prompt': prompt, "verbose": False} # customizing the prompt + ) + + + + +# Prettifying the response +@func_set_timeout(300) +def get_answers_PM3(query, chain): + + # Getting response from chain + input_dict = {'query': query} + + response = chain(input_dict) + + return response + +def loadTextModel(model_name): + print("Loading model",model_name) + if "llama3" in model_name: + llm_a = Ollama(model=model_name,temperature=0.0, top_p = 0.9, stop=["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "<|reserved_special_token"]) + elif model_name == "phi3": + + llm_a = Ollama(model=model_name,temperature=0.0, top_p = 0.9, stop=["<|user|>","<|assistant|>","<|system|>","<|end|>","<|endoftext|>", "<|reserved_special_token"]) + else: + llm_a = Ollama(model=model_name,temperature=0.0, top_p = 0.9) + print("Loading model DONE") + return llm_a + +def main(): + parser = ArgumentParser(description='AutoPM3') + parser.add_argument( + '--model_name_text', + help="llm used for answering generated questions", + required=False, + default='llama3_loraFT-8b-f16', + ) + parser.add_argument( + '--model_name_table', + help="llm used for table queries", + required=False, + default='sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0', + ) + parser.add_argument( + '--query_variant', + help="query variant in HGVS format", + required=True, + ) + parser.add_argument( + '--paper_path', + help="paper_path of the query literature", + required=True, + ) + + + # print help message if no argument input + if len(sys.argv) <= 1 or sys.argv[1] == "-h" or sys.argv[1] == "--help": + parser.print_help(sys.stderr) + sys.exit(0) + + args = parser.parse_args() + results = query_variant_in_paper_xml(args.query_variant,args.paper_path,args.model_name_table,args.model_name_text) + print(results) + + +def query_variant_in_paper_xml(query_variant, xml_path, model_name_table, model_name_text): + + llm_a = loadTextModel(model_name_text) + llm_table = [Ollama(model=model_name_table, temperature=0.0, top_p=0.9) ] + + # Read protein abbreviation table + protein_map = load_protein_map(PROTEIN_MAPPING_FILE) + + # check if the query variant is in the correct format (TODO) + c_variant = query_variant.split(":")[-1] + + # Check if the paper (XML, PDF) exists + xml_fn =xml_path + + if not os.path.exists(xml_fn): + print('XML paper not found. 
Abort.') + sys.exit(-1) + + # Load the XML paper and filter away tables and useless sections, + # then split into chunks + doc_filtered = load_xml_paper(xml_fn, filter_tables=True) + doc_wrapper = [ Document(page_content = doc_filtered, metadata = {'source': 'local'}) ] + doc_chunks = split_docs(doc_wrapper) + # Try our custom retriever + variant_retriever = VariantSpecificRetriever(documents=doc_chunks, k=5, protein_map=protein_map) + variant_hgvs = query_variant + + try: + r = requests.get(f'https://mutalyzer.nl/api/normalize/{query_variant}?only_variants=false') + j = r.json() + protein = j['protein']['description'].split(':')[-1] + if protein == 'p.(=)': + raise Exception('invalid notation') + + c_protein_id = re.findall(r"\d+",protein) + except Exception as e: + + protein = None + protein_short = None + + + ########################## + # Do table queries first # + ########################## + table_src_contains_variant = False + table_query_results = [] + + # Find all the table CSV files for this PMID + #relevant_tables = [ f for f in table_csv_files if str(c_pmid) in f ] + relevant_tables = extractTablesFromXML(xml_fn) + c_variant_id = None + c_max = 0 + c_tmp_digit = re.findall(r"\d+",c_variant) + for c_digit in c_tmp_digit: + if len(c_digit) > c_max: # 3 + c_variant_id = c_digit + c_max = len(c_digit) + + if relevant_tables: + variant_alias = [c_variant_id, c_protein_id[0]] if protein is not None and len(c_protein_id) > 0 else [c_variant_id] + + csv_files = [] + csv_filenames = [] + # Write the extracted tables to temporary CSV files + for table in relevant_tables: + + tmpfile = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=True) + csv_files.append(tmpfile) + csv_filenames.append(tmpfile.name) + #print("temp tables",tmpfile.name) + table.to_csv(tmpfile.name,index=False) + table_query_return = table_extraction_n_sqlQA(csv_filenames, model_name_table, + query_variant_list=variant_alias, llm=llm_table, llm_qa=llm_table, show_errors=False) + + # Close and delete the temp files + for tmpfile in csv_files: + tmpfile.close() + + if table_query_return is None: + + pass + #print('[ERROR] Something went wrong with the tables.') + else: + ( table_query_results, table_src_contains_variant ) = table_query_return[:2] + # Collect the answer candidates from the returned results + table_results_plaintext = [] + for c_cmd in table_query_results: + for c_answer in c_cmd[1]: + if not isinstance(c_answer, tuple): + try: + table_results_plaintext.append(c_answer['plainText']) + except Exception as e: + pass + else: + table_results_plaintext = ["No table found"] + ########################### + # Now do the text queries # + ########################### + + # We use "protein" instead of "c_protein_id[0]" for text. + # - "protein": The protein change returned by Mutalzyer + # - "c_protein_id[0]": Only the digits in "protein" + + variant_alias = [c_variant, protein] if protein is not None and len(c_protein_id) > 0 else [c_variant] + + text_variant_hit = False + text_intrans_list = [] + text_src_contains_variant = False + text_variant_answer = "" + + MAX_RETRIES = 1 # number of retries before giving up + + PM3_answer_prompt = PromptTemplate.from_template(template_PM3_answer_chain_llama3) + for c_index, current_variant in enumerate(variant_alias): + + for query_type in range(2): # variant query, in-trans query + retriever_OK[0] = False + if query_type == VARIANT_QUERY: + my_predefined_query = f"Does the paper mention the queried variant ({current_variant}) and what is the surrounding context?" 
+ f"""if such variant is existed, say *YES* at first otherwise say *None* (focus on variant: {current_variant})""" + elif query_type == INTRANS_QUERY: + my_predefined_query = f"If {current_variant} is compound heterozygous with another variant, name it; if {current_variant} is homozygous, say homozygous; if no related variant is found, say *None*. List all results seperated by comma and wrap the answers by *.""" + + num_retries = 0 + query_success = False + while not query_success and num_retries <= MAX_RETRIES: + try: + PM3_answer_chain = load_qa_chain(variant_retriever, llm_a, PM3_answer_prompt.partial(proposedQuestion=my_predefined_query, c_variant=current_variant)) + cur_answers_all = get_answers_PM3(variant_hgvs, PM3_answer_chain) + if not retriever_OK[0]: + break + else: + text_src_contains_variant = True + protein_short = retriever_OK[2] + cur_answers = cur_answers_all['result'] + query_success = True + except func_timeout.exceptions.FunctionTimedOut: + del llm_a; + llm_a = loadTextModel(model_name_text) + num_retries += 1 + + if not query_success: + continue + + # Wrapping the text for better output in Jupyter Notebook + wrapped_text = textwrap.fill(cur_answers_all['result'], width=100) + + if query_type == VARIANT_QUERY: + source_doc = cur_answers_all['source_documents'] + c_variant_inRetrieved = False + for page in source_doc: + c_rsids = re.findall(current_variant if c_index == C_VARIANT else c_protein_id[0], page.page_content) + if len(c_rsids) > 0: + c_variant_inRetrieved = True + if 'yes' in cur_answers.lower(): + text_variant_hit = True + if c_index == C_VARIANT: + text_variant_answer = "\n- **[DNA match result]**:"+cur_answers + else: + text_variant_answer += "\n- **[Protein match result]**:"+cur_answers + elif text_variant_answer == "": + text_variant_answer = "\n- **Variant not found in text part!**" + + elif query_type == INTRANS_QUERY: + if "none" not in cur_answers.lower() or "contain" in cur_answers.lower(): + text_intrans_list.append(cur_answers) + #text_intrans_list.append(cur_answers) + + + table_results_plaintext_output = [str(xx).strip("\n") + "\n\n" for xx in table_results_plaintext] + #print(f'# Output Summary: \n \n## **Query Variant and Relative Intrans-variant/Genotype Found in PaperTables**: \n{"".join([str(xx) for xx in table_results_plaintext_output])} \n \n## **Query Variant Found in PaperText**: \n- {text_variant_answer if text_variant_hit != "" else "Variant not found in text part!"} \n \n## **Query Variant\'s Intrans-variant Found in PaperText**: \n{text_intrans_list if text_variant_answer != "Variant not found in text part!" and text_intrans_list else "None!"}') + results = f'# Output Summary: \n \n## **Query Variant and Relative Intrans-variant/Genotype Found in PaperTables**: \n{"".join([str(xx) for xx in table_results_plaintext_output])} \n \n## **Query Variant Found in PaperText**: \n- {text_variant_answer if text_variant_hit != "" else "Variant not found in text part!"} \n \n## **Query Variant\'s Intrans-variant Found in PaperText**: \n{text_intrans_list if text_variant_answer != "Variant not found in text part!" 
and text_intrans_list else "None!"}' + return results + + +if __name__ == "__main__": + + main() + diff --git a/PM3-Bench/README.md b/PM3-Bench/README.md new file mode 100644 index 0000000..e69de29 diff --git a/README.md b/README.md index 98bbb5a..cb9140c 100644 --- a/README.md +++ b/README.md @@ -1 +1,112 @@ -# AutoPM3 \ No newline at end of file +# AutoPM3: Enhancing Variant Interpretation via LLM-driven PM3 Evidence Extraction from Scientific Literature + +[![License](https://img.shields.io/badge/license-MIT-blue)](https://opensource.org/license/mit/) + + +Contact: Ruibang Luo, Shumin Li + +Email: rbluo@cs.hku.hk, lishumin@connect.hku.hk + + +## Introduction +We introduce AutoPM3, a method for automating the extraction of ACMG/AMP PM3 evidence from scientific literature using open-source LLMs. It combines an optimized RAG system for text comprehension and a TableLLM equipped with Text2SQL for data extraction. We evaluated AutoPM3 using our collected PM3-Bench, a dataset from ClinGen with 1,027 variant-publication pairs. AutoPM3 significantly outperformed other methods in variant hit and in trans variant identification, thanks to the four key modules. Additionally, we wrapped AutoPM3 with a user-friendly interface to enhances its accessibility. This study presents a powerful tool to improve rare disease diagnosis workflows by facilitating PM3-relevant evidence extraction from scientific literature. + +![](./images/img1.png) +--- + +## Contents + +- [Latest Updates](#latest-updates) +- [Installations](#installation) + - [Dependency Installation](#dependency-installation) + - [Ollama Setup](#using-ollama-to-host-llms) +- [Usage](#usage) + - [Quick Start](#quick-start) + - [Advanced Usage](#advanced-usage-of-the-python-script) +- [TODO](#todo) +--- + +## Latest Updates +* v0.1 (Oct, 2024): Initial release. +--- + +## Installation +### Dependency Installation +```bash +conda create -n AutoPM3 python=3.10 +conda activate AutoPM3 +pip3 install -r requirements.txt +``` + +### Using Ollama to host LLMs +1. Download Ollama [Guidance](https://github.com/ollama/ollama) +2. Change the directory of Ollama models: +```bash +# please change the target folder as you prefer +mkdir ollama_models +export OLLAMA_MODELS=./ollama_models +``` + +3. Launch Ollama server: + +```bash + +ollama serve + +``` + +3. Download sqlcoder-mistral-7B model and fine-tuned Llama3: +```bash +cd $OLLAMA_MODELS +wget https://huggingface.co/MaziyarPanahi/sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp-GGUF/resolve/main/sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0.gguf?download=true +mv 'sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0.gguf?download=true' 'sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0.gguf' +echo "FROM ./sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0.gguf" >Modelfile1 +ollama create sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0 -f Modelfile1 + +wget http://bio8.cs.hku.hk/AutoPM3/llama3_loraFT-8b-f16.gguf +echo "FROM ./llama3_loraFT-8b-f16.gguf" >Modelfile2 +ollama create llama3_loraFT-8b-f16 -f Modelfile2 + +``` + +5. Check the created models: + +```bash + +ollama list + +``` + +6. (Optional) Download other models as the backend of the RAG system: +``` +# e.g. download Llama3:70B +ollama pull llama3:70B + +``` + +## Usage + +### Quick start + +* Step 1. Launch the local web-server: +```bash +streamlit run lit.py +``` +* Step 2. Copy the following `http://localhost:8501` to the brower and start to use. 
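+
+* (Optional) Instead of the web interface, the core entry point can be called from Python. A minimal sketch, assuming the Ollama server is running and both models created above are available (the variant, paper path, and model names mirror the command-line example in the next section):
+
+```python
+from AutoPM3_main import query_variant_in_paper_xml
+
+# Returns a Markdown-formatted summary of the table and text evidence
+results = query_variant_in_paper_xml(
+    "NM_004004.5:c.516G>C",                             # query variant in HGVS format
+    "./xml_papers/20201936.xml",                        # BioC XML of the paper
+    "sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0",  # TableLLM backend
+    "llama3_loraFT-8b-f16",                             # text RAG backend
+)
+print(results)
+```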
+ +### Advanced usage of the python script + +* Check the help of AutoPM3_main.py +```bash +python AutoPM3_main.py -h +``` +* The example of running python scripts: +```bash +python AutoPM3_main.py +--query_variant "NM_004004.5:c.516G>C" ## HVGS format query variant +--paper_path ./xml_papers/20201936.xml ## paper path. +--model_name_text llama3_loraFT-8b-f16 ## change to llama3:70b or other hosted models as the backend of RAG as you prefer, noted that you need pull the model in Ollama in advance. +``` + +## TODO +* A fast set up for AutoPM3. \ No newline at end of file diff --git a/images/img1.png b/images/img1.png new file mode 100644 index 0000000..4bbce00 Binary files /dev/null and b/images/img1.png differ diff --git a/lit.py b/lit.py new file mode 100644 index 0000000..2d3b014 --- /dev/null +++ b/lit.py @@ -0,0 +1,66 @@ +import streamlit as st +import os +import requests + +from AutoPM3_main import query_variant_in_paper_xml + + +# Function to load a XML file from a URL +def load_xml(url,pmid): + + temp_paper_file_root = "./xml_papers" + if(not os.path.exists(temp_paper_file_root)): + os.mkdir(temp_paper_file_root) + fn = str(pmid)+".xml" + xml_path = os.path.join(temp_paper_file_root,fn) + if(os.path.exists(xml_path)): + return xml_path + + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" + } + response = requests.get(url, headers=headers) + if response.status_code == 200 and response.headers['Content-type'] == 'text/xml': + xml_data = response.content + # save it to the temp dir + with open(xml_path, 'wb') as f: + f.write(response.content) + return xml_path + else: + raise Exception('Invalid PMID. Make sure the publication has OpenAccess.') + return None + + +# Function to display model results +def display_results(model, data): + # Assuming 'model' is your trained model and 'data' is the input to the model + results = model.predict(data) + st.write(results) + +# Main +st.title('AutoPM3') + +variant_name = st.text_input('Step 1. Enter the variant (HGVS notation)') + +# Get the URL of the XML from the user +paper_url = '' +pmid = st.text_input('Step 2. 
Enter the PMID of the paper') +if pmid: + try: + pmid = int(pmid) + paper_url = f'https://www.ncbi.nlm.nih.gov/research/bionlp/RESTful/pmcoa.cgi/BioC_xml/{pmid}/unicode' + except ValueError: + st.write('Invalid PMID.') + +if st.button('Run', type='primary'): + summarized_results = "" + if paper_url and variant_name: + try: + # Load and display the XML + xml_path = load_xml(paper_url,pmid) + summarized_results = query_variant_in_paper_xml(variant_name, xml_path, 'sqlcoder-7b-Mistral-7B-Instruct-v0.2-slerp.Q8_0', 'llama3_loraFT-8b-f16') + # Display the summarized results + st.write(summarized_results) + except Exception as e: + st.write('An error has occurred.') + st.write(str(e)) diff --git a/protein.txt b/protein.txt new file mode 100644 index 0000000..0ac041c --- /dev/null +++ b/protein.txt @@ -0,0 +1,21 @@ +Ala A +Arg R +Asn N +Asp D +Cys C +Gln Q +Glu E +Gly G +His H +Ile I +Leu L +Lys K +Met M +Phe F +Pro P +Ser S +Ter X +Thr T +Trp W +Tyr Y +Val V diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7a432d3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,85 @@ +aiohappyeyeballs==2.4.3 +aiohttp==3.10.10 +aiosignal==1.3.1 +altair==5.4.1 +annotated-types==0.7.0 +anyio==4.6.2.post1 +async-timeout==4.0.3 +attrs==24.2.0 +beautifulsoup4==4.12.2 +bioc==2.1 +blinker==1.8.2 +cachetools==5.5.0 +certifi==2024.8.30 +charset-normalizer==3.4.0 +click==8.1.7 +dataclasses-json==0.6.7 +docopt==0.6.2 +exceptiongroup==1.2.2 +frozenlist==1.4.1 +func_timeout==4.3.5 +gitdb==4.0.11 +GitPython==3.1.43 +greenlet==3.1.1 +h11==0.14.0 +httpcore==1.0.6 +httpx==0.27.2 +idna==3.10 +intervaltree==3.1.0 +Jinja2==3.1.4 +jsonlines==4.0.0 +jsonpatch==1.33 +jsonpointer==3.0.0 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +langchain==0.2.16 +langchain-community==0.2.6 +langchain-core==0.2.41 +langchain-experimental==0.0.63 +langchain-text-splitters==0.2.4 +langsmith==0.1.136 +lxml==5.2.2 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +marshmallow==3.23.0 +mdurl==0.1.2 +multidict==6.1.0 +mypy-extensions==1.0.0 +narwhals==1.9.4 +numpy==1.26.4 +orjson==3.10.7 +packaging==24.1 +pandas==2.2.2 +pillow==10.4.0 +propcache==0.2.0 +protobuf==5.28.2 +pyarrow==17.0.0 +pydantic==2.9.2 +pydantic_core==2.23.4 +pydeck==0.9.1 +Pygments==2.18.0 +python-dateutil==2.9.0.post0 +pytz==2024.2 +PyYAML==6.0.2 +referencing==0.35.1 +requests==2.32.3 +requests-toolbelt==1.0.0 +rich==13.9.2 +rpds-py==0.20.0 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +sortedcontainers==2.4.0 +soupsieve==2.6 +SQLAlchemy==2.0.36 +streamlit==1.39.0 +tenacity==8.5.0 +toml==0.10.2 +tornado==6.4.1 +tqdm==4.66.5 +typing-inspect==0.9.0 +typing_extensions==4.12.2 +tzdata==2024.2 +urllib3==2.2.3 +watchdog==5.0.3 +yarl==1.15.5 diff --git a/table_functions.py b/table_functions.py new file mode 100644 index 0000000..c4a76aa --- /dev/null +++ b/table_functions.py @@ -0,0 +1,499 @@ +from langchain_community.llms import Ollama +from langchain.chains import RetrievalQA,RetrievalQAWithSourcesChain +from langchain import PromptTemplate +from langchain_core.output_parsers import StrOutputParser +import langchain +from langchain import SQLDatabase +from langchain_experimental.sql import SQLDatabaseChain +from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS +from langchain import LLMChain + +import pandas as pd +import ast +import textwrap +import os +import time +from argparse import ArgumentParser +import sys +import glob +import json +import re + + +import os +import time +from argparse import ArgumentParser +import sys +import json 
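+# math is used to process the extracted tables in chunks of five; func_timeout
+# guards LLM and SQL calls that may hang; sqlite3 hosts the tables queried by
+# the Text2SQL chain.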
+import math +import requests +from func_timeout import func_set_timeout +import func_timeout + + +import sqlite3 + +langchain.verbose = False + +TABLE_PARAMETER = "{TABLE_PARAMETER}" +c_tr_index = "{c_tr_index}" +DROP_TABLE_SQL = f"DROP TABLE {TABLE_PARAMETER};" +GET_TABLES_SQL = "SELECT name FROM sqlite_schema WHERE type='table';" +GET_ROW_SQL = f"""SELECT * FROM {TABLE_PARAMETER} WHERE "index" = {c_tr_index};""" +def delete_all_tables(con): + tables = get_tables(con) + delete_tables(con, tables) + +def get_row(con, c_table, c_index): + cur = con.cursor() + sql = GET_ROW_SQL.replace(TABLE_PARAMETER, c_table); sql = sql.replace(c_tr_index, str(c_index)) + cur.execute(sql) + rows = cur.fetchall() + cur.close() + return rows + + +def get_tables(con): + cur = con.cursor() + cur.execute(GET_TABLES_SQL) + tables = cur.fetchall() + cur.close() + return tables + + +def delete_tables(con, tables): + cur = con.cursor() + for table, in tables: + sql = DROP_TABLE_SQL.replace(TABLE_PARAMETER, table) + cur.execute(sql) + cur.close() + +@func_set_timeout(40) +def table2text(llm, tableRow, question): + llm_chain = LLMChain( + llm=llm, + prompt=PromptTemplate.from_template(template_PM3_table2text) + ) + + result = llm_chain.generate([{"tableData":tableRow, "question":question}]) + return result.generations[0][0].text + + +@func_set_timeout(40) +def tableNtext_qa(llm, tableRow, pt, question): + llm_chain = LLMChain( + llm=llm, + prompt=PromptTemplate.from_template(template_PM3_tableNtext_qa) + ) + + result = llm_chain.generate([{"tableData":tableRow, "pt":pt, "question":question}]) + return result.generations[0][0].text + +@func_set_timeout(40) +def wrapper(func, query): + return(func(query)) + + + +def is_number(s): + try: + float(s) + return True + except ValueError: + pass + + return False + + +# for benchmarking only, generate half-sturctured data in plain text from single table_row +template_PM3_table2text = """ +### System: +You are reading the structured data given in the Context and try to rephrase it in plain text. In each line, the attribute name(header) is on the left of *:*, then corresponding attribute value is on the right. + +### Context: +{tableData} + +### User: +Each variant/mutation must contain alphabet letters with several digits, don't make up non-existed variants/mutations. +Limit your answer under 25 words. +Stop the answer by the word *END*. +Please read the above provided structured data in context and just answer the given question in short plain text. Question: {question}'\ +### Response: + +""" + + + + +template_PM3_tableNtext_qa = """ +### System: +You are reading the structured data and it's corresponding plain text description given in the Context, try to answer user's question based on these. For structured data, in each line, the attribute name(header) is on the left of *:*, then corresponding attribute value is on the right. + +### Context: +structured data {tableData} + +plain text description {pt} + +### User: +Limit your answer under 100 words and don't repeat the context or any info you are given. Please read the above provided structured data and it's corresponding plain text description in context and just answer the given question. Question: {question}'\ +### Response: + +""" + + + + +# for q4,8 +sqlquery_template = 'Given an input question, first create a syntactically correct {dialect} query to run, then look at the results \ +of the query and return the answer. Strict your query to a short one and dont give a long answer. 
*Never* use limitation to limit your query like: LIMIT {top_k} except user asked for certain row.\nWhen no specific column names are given, you can check for the answer in all columns using "OR" operator.\n\n\ +Unless exactly match is required by user, use LIKE other than = in the query\n\ +Never sort the results. If user asks for certain row, use LIMIT operator!\n\ +Never give a sql that will return all content in the table if not explicitly asked\n\ +Only give one query ended with \';\' eachtime!\n\ +Carefully check the statement after WHERE clause, don\'t mix up column_name with user\'s query string, and keep the string integral for matching!\n\ +When using LIKE operator, note to put column names on the left and query string on the right, don\'t reverse it\n\ +Don\'t forget to append ; at the end of query and no order is needed!\n\nPay attention to use only the column \ +names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay \ +attention to which column is in which table.\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL \ +Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the \ +following tables:\n{table_info}\n\nQuestion: {input}' + + +""" +# for q2 +sqlquery_template = 'Given an input question, first create a syntactically correct {dialect} query to run, then look at the results \ +of the query and return the answer. Only generate a single short query statement to run. *Never* use limitation to limit your query like: LIMIT {top_k}.\nThe query you generate should check all columns using OR operator.\n\n\ +The OR operator should involve all columns you can see, don\'t only choose several columns by yourself\n\ +Unless exactly match is required by user, use LIKE other than = in the query\n\ +Don\'t forget to append ; at the end of query\n\nPay attention to use only the column \ +names that you can see in the schema description. Be careful to not query for columns that do not exist. 
Also, pay \ +attention to which column is in which table.\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL \ +Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the \ +following tables:\n{table_info}\n\nQuestion: {input}' +""" + +def table_extraction_n_sqlQA(current_paper_tables, model_name, query_variant_list, additional_question=None, llm=None, llm_qa=None, show_errors=True): + + df_list = [] + for c_table in current_paper_tables: + try: + df_list.append(pd.read_csv(c_table, header=None)) + except Exception as e: + print(f"current table is invalid: {c_table}") + valid_df_tables = [] + variant_insource = False + + for df in df_list: + c_len = 0 + for idx,row in df.iterrows(): + for row_i in range(len(row)): + c_len += len(str(row[row_i])) + + + c_num_col = len(df.axes[1]) + c_len /= len(list(df.iterrows())) + + #print(f"num_col: {c_num_col}; avg col_len= {c_len/c_num_col}") + + #print("-------------------------------") + if True:#c_len <= 80 and c_len >= 15: # filter extracted tables by avg row_len (large: plain text block, tiny: nonsense piece) + valid_df_tables.append(df) + for c_variant in query_variant_list: + for col in df.columns: + insource_result = df[df[col].astype(str).str.contains(c_variant, regex= True, na=False)] + if len(insource_result) > 0: + variant_insource = True + #print(insource_result) + break + if variant_insource: + break + #print(df) + table_chunks = math.ceil(len(valid_df_tables) / 5) + + basic_query_answers_list = [] + + for c_chunk in range(table_chunks): + shift = c_chunk*5 + conn = sqlite3.connect('PDFpaper_Table_extractions.db') + + delete_all_tables(conn) + + for df_idnex, df in enumerate(valid_df_tables[shift:shift+5 if c_chunkG" #"p.Asn346His" # c.104del # c.516G>C + + sqlite_db_path = "./PDFpaper_Table_extractions.db" + db = SQLDatabase.from_uri(f"sqlite:///{sqlite_db_path}") + + prompt = sqlquery_template.replace("{dialect}", db.dialect) + prompt = PromptTemplate.from_template(prompt) + + # return_direct: return the sql result directly + #@func_set_timeout(180)#设定函数超执行时间_ + db_chain = SQLDatabaseChain.from_llm(llm[0], db, prompt=prompt, use_query_checker=False, verbose=False, return_intermediate_steps=True, return_direct=True) + + c_unique_rows = [] + for c_vindx, query_variant in enumerate(query_variant_list): + for idx in range(current_end_index+1): + current_table = "current_paper_table_" + str(idx) + current_table_source_index = idx + shift + + ## basic QA + q0 = f"get the first row in table {current_table}? (take the result given by SQLResult:)" + try: + + result0 = wrapper(db_chain,q0) + + + except Exception as e: + if show_errors: + print(f"[ERROR] error when running question-0 on table {current_table}: {e}") + result0 = None + except func_timeout.exceptions.FunctionTimedOut: + result0 = None + del db_chain; llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + db_chain = SQLDatabaseChain.from_llm(llm[0], db, prompt=prompt, use_query_checker=False, verbose=False, return_intermediate_steps=True, return_direct=True) + + + ## basic QA + q1 = f"how many rows in table {current_table}? 
(take the result given by SQLResult:)" + try: + result1 = wrapper(db_chain,q1) + + except Exception as e: + if show_errors: + print(f"[ERROR] error when running question-1 on table {current_table}: {e}") + result1 = None + except func_timeout.exceptions.FunctionTimedOut: + result1 = None + del db_chain; llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + db_chain = SQLDatabaseChain.from_llm(llm[0], db, prompt=prompt, use_query_checker=False, verbose=False, return_intermediate_steps=True, return_direct=True) + + q2 = f"search for the string: '{query_variant}' through every column in table {current_table} using OR? (find all, no limit, column names should be like 0,1,2 as u can see in the schema)" + try: + result2 = wrapper(db_chain,q2) + + except Exception as e: + if show_errors: + print(f"[ERROR] error when running question-2 on table {current_table}: {e}") + result2 = None + except func_timeout.exceptions.FunctionTimedOut: + result2 = None + del db_chain; llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + db_chain = SQLDatabaseChain.from_llm(llm[0], db, prompt=prompt, use_query_checker=False, verbose=False, return_intermediate_steps=True, return_direct=True) + + q3 = f"Question: find all rows that contain the string '{query_variant}' in any column (don\'t only consider one column) (check all columns in table {current_table}) (find all, no limit)" + try: + result3 = wrapper(db_chain,q3) + + except Exception as e: + if show_errors: + print(f"[ERROR] error when running question-3 on table {current_table}: {e}") + result3 = None + except func_timeout.exceptions.FunctionTimedOut: + result3 = None + del db_chain; llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + db_chain = SQLDatabaseChain.from_llm(llm[0], db, prompt=prompt, use_query_checker=False, verbose=False, return_intermediate_steps=True, return_direct=True) + + q4 = f"Question: find all the rows that contain {query_variant} (query all columns in table {current_table} using OR) (find all, no limit)" + try: + result4 = wrapper(db_chain,q4) + + except Exception as e: + if show_errors: + print(f"[ERROR] error when running question-4 on table {current_table}") + result4 = None + except func_timeout.exceptions.FunctionTimedOut: + result4 = None + del db_chain; llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + db_chain = SQLDatabaseChain.from_llm(llm[0], db, prompt=prompt, use_query_checker=False, verbose=False, return_intermediate_steps=True, return_direct=True) + + basic_query_answers[current_table] = [(q1,result1["intermediate_steps"][-1] if result1 is not None else "sql error"),(q2,result2["intermediate_steps"][-1] if result2 is not None else "sql error"),(q3,result3["intermediate_steps"][-1] if result3 is not None else "sql error"), (q4,result4["intermediate_steps"][-1] if result4 is not None else "sql error")] + all_results = [result2, result3, result4] + + c_chunk_useful_answers = [] + + max_row_count = 0 + + for c_result in all_results: + if c_result is not None and len(c_result["result"]) > 0: + c_json_string = c_result["result"] + + c_list_results = ast.literal_eval(c_json_string.strip()) + + len_cl = len(c_list_results) + + for i in range(len_cl-1, -1, -1): + c_in = False + c_rowstr = json.dumps(c_list_results[i]) + for c_content in c_list_results[i]: + if query_variant in str(c_content): + if abs(len(query_variant) - len(str(c_content).strip())) >= 2: + if not is_number(str(c_content).strip()): + 
c_have_digits_0 = re.findall(rf"\d+{query_variant}\d+",str(c_content)) + c_have_digits_1 = re.findall(rf"{query_variant}\d+",str(c_content)) + c_have_digits_2 = re.findall(rf"\d+{query_variant}",str(c_content)) + c_single = re.findall(rf" {query_variant} ",str(c_content)) + if len(c_single) == len(c_have_digits_0) == len(c_have_digits_1) == len(c_have_digits_2) == 0: + if c_rowstr not in c_unique_rows and c_vindx == 0: # remove duplicate result from different querying variant_format + c_unique_rows.append(c_rowstr) + elif c_rowstr in c_unique_rows and c_vindx == 1: # remove duplicate result from different querying variant_format + break + c_in = True + break + if not c_in: + del c_list_results[i] + + c_rows_count = len(c_list_results) # choose max result of all similar questions + if not max_row_count < c_rows_count: + continue + + if result0 is not None: + r0_list = ast.literal_eval(result0["result"].strip())[0] + try: + r0_list = list(map(lambda x: x[1] + str(r0_list[:x[0]].count(x[1]) + 1) if r0_list.count(x[1]) > 1 else x[1], enumerate(r0_list))) + except Exception as e: + r0_list = None + + for c_index in range(len(c_list_results)): + if r0_list is None: + break + if len(r0_list) == len(c_list_results[c_index]): + + c_list_results[c_index] = dict(zip(r0_list, c_list_results[c_index])) + sub_dict = dict([(key, c_list_results[c_index][key]) for key in list(c_list_results[c_index].keys())[1:]]) + json_tablerow = json.dumps(sub_dict, indent=2) + c_question_summary = "rephrase and describe it in plain text" + c_question = "only list the existed variants/mutations in context in the following format *PatientID:... Variant:...*\nif no patient is explicitly mentioned put *PatientID:None* and don't mix up with variants/mutations. If no variants/mutations is explicitly mentioned put *Variant:None*." 
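+                        # Look up the rows directly above and below the matched row via the
+                        # SQLite "index" column; rows sharing the same first data column are
+                        # kept, since they may describe the same patient and carry the
+                        # in-trans variant of a compound heterozygote.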
+ + i_pre = int(list(c_list_results[c_index].values())[0]) - 1; + i_next = int(list(c_list_results[c_index].values())[0]) + 1; + json_list_answers = [] + if i_pre > 0: + pre_row = get_row(conn, current_table, i_pre)[0] + if str(pre_row[1]) == str(list(c_list_results[c_index].values())[1]): + json_list_answers.insert(0,dict(zip(r0_list[1:], list(pre_row)[1:]))) + #print(f"pre_row: {pre_row}") + + next_row = get_row(conn, current_table, i_next) + if len(next_row) > 0: + #print(f"next_row: {next_row[0]}") + next_row = next_row[0] + if str(next_row[1]) == str(list(c_list_results[c_index].values())[1]): + json_list_answers.append(dict(zip(r0_list[1:], list(next_row)[1:]))) + + #json_tablerow = json.dumps(json_list_answers, indent=2) + try: + c_text = table2text(llm[0], str(json_tablerow), c_question) + c_text_sum = table2text(llm[0], str(json_tablerow), c_question_summary) + except Exception as e: + + c_text = None + c_text_sum = None + llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + except func_timeout.exceptions.FunctionTimedOut: + + #print("func_timeout.exceptions.FunctionTimedOut") + #print("str(json_tablerow)",str(json_tablerow),"c_question",c_question) + c_text = None + c_text_sum = None + llm.clear(); + print("re-loading ollama") + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + + # try again + try: + c_text = table2text(llm[0], str(json_tablerow), c_question) + c_text_sum = table2text(llm[0], str(json_tablerow), c_question_summary) + except func_timeout.exceptions.FunctionTimedOut: + c_text='LLM running failed' + c_text_sum = "LLM running failed" + llm.clear(); + print("re-loading ollama") + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + if c_text is not None: + c_list_results[c_index]["plainText"] = f"## TableLLM Identified Record \n**Source**: - Table {current_table_source_index+1} - Row {c_list_results[c_index][0]} \n- **LLM extracted Variant/Genotypes with PatientID**: " + c_text +f" \n- **LLM Translated Row Summary**: {c_text_sum} " + f" \n- **Source Row Details**: {str(json_tablerow)} " + if c_text is not None: + + for c_tableRow in json_list_answers: + c_tableRow = json.dumps(c_tableRow, indent=2) + + try: + c_text = table2text(llm[0], str(c_tableRow), c_question) + c_text_sum = table2text(llm[0], str(c_tableRow), c_question_summary) + except Exception as e: + + c_text = None + c_text_sum = None + llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + except func_timeout.exceptions.FunctionTimedOut: + c_text = None + c_text_sum = None + llm.clear(); + llm.append(Ollama(model=model_name, temperature=0.0, top_p = 0.9)) + if c_text is not None: + c_list_results[c_index]["plainText"] += f"\n### Adjacent rows potentially contain intrans variant: \n- **LLM extracted Variant/Genotypes with PatientID**: " + c_text +f" \n- **LLM Translated Row Summary**: {c_text_sum} " + f" \n- **Source Row Details**: {str(c_tableRow)} " + + + + + else: c_list_results[c_index]["plainText"] = "GG" + + + #print(max_row_count, c_rows_count) + if max_row_count < c_rows_count: + max_row_count = c_rows_count + #c_chunk_useful_answers.append((c_result["intermediate_steps"][-2],c_json_string.strip())) + c_chunk_useful_answers = (c_result["intermediate_steps"][-2],c_list_results) + + + + ## given QA + if type(additional_question) is list: + for c_question in additional_question: + c_question = c_question + f" (only in table {current_table})" + try: + result = db_chain(c_question) + 
basic_query_answers[current_table].append((c_question,result["intermediate_steps"][-1])) + except Exception as e: + if show_errors: + print(f"error when running given question on table {current_table}: {e}") + basic_query_answers[current_table].append((c_question,["invalid_question!"])) + + elif type(additional_question) is str: + c_question = additional_question + f" (check all columns in table {current_table})" + try: + result = db_chain(c_question) + basic_query_answers[current_table].append((c_question,result["intermediate_steps"][-1])) + except Exception as e: + if show_errors: + print(f"error when running given question on table {current_table}: {e}") + basic_query_answers[current_table].append((c_question,["invalid_question!"])) + + + if len(c_chunk_useful_answers) > 0: + basic_query_answers_list.append(c_chunk_useful_answers) + + return [basic_query_answers_list, variant_insource] + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..6d8ab41 --- /dev/null +++ b/utils.py @@ -0,0 +1,178 @@ +import pandas as pd +from bioc import biocxml +from bs4 import BeautifulSoup +import re +from itertools import product + +import requests + + + + +def translateProtein2SingleChar(proteinNotation): + print("start translateProtein2SingleChar") + print("proteinNotation",proteinNotation) + try: + r = requests.get(f'https://mutalyzer.nl/api/normalize/{proteinNotation}?only_variants=false') + + j = r.json() + return j['equivalent_descriptions']['p'][0]['description'].split(":")[-1] + except Exception as e: + return None + + + +def sortVariantAlias(variantAlias,documents): + + for page in source_doc: + c_rsids = re.findall(current_variant,page.page_content) + + +def table_to_2d(table_tag): + rowspans = [] # track pending rowspans + rows = table_tag.find_all(['tr']) + # first scan, see how many columns we need + colcount = 0 + for r, row in enumerate(rows): + cells = row.find_all(['td', 'th'], recursive=False) + + ''' + colcount = max( + colcount, + sum(int(c.get('colspan', 1)) or 1 for c in cells[:-1]) + len(cells[-1:]) + len(rowspans)) + ''' + + colcount = max(colcount,sum(int(c.get('colspan', 1)) for c in cells) + len(rowspans)) + # update rowspan bookkeeping; 0 is a span to the bottom. + rowspans += [int(c.get('rowspan', 1)) or 1 or len(rows) - r for c in cells] + rowspans = [s - 1 for s in rowspans if s > 1] + + # it doesn't matter if there are still rowspan numbers 'active'; no extra + # rows to show in the table means the larger than 1 rowspan numbers in the + # last table row are ignored. 
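+    # The second pass below fills the matrix; a spanning cell's text is copied
+    # into every row/column position it covers, so merged cells are repeated
+    # rather than left empty.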
+ + # build an empty matrix for all possible cells + table = [[None] * colcount for row in rows] + + # fill matrix from row data + rowspans = {} # track pending rowspans, column number mapping to count + for row, row_elem in enumerate(rows): + span_offset = 0 # how many columns are skipped due to row and colspans + for col, cell in enumerate(row_elem.find_all(['td', 'th'], recursive=False)): + # adjust for preceding row and colspans + col += span_offset + while rowspans.get(col, 0): + span_offset += 1 + col += 1 + + # fill table data + rowspan = rowspans[col] = int(cell.get('rowspan', 1)) or len(rows) - row + colspan = int(cell.get('colspan', 1)) or colcount - col + # next column is offset by the colspan + span_offset += colspan - 1 + value = cell.get_text() + for drow, dcol in product(range(rowspan), range(colspan)): + try: + table[row + drow][col + dcol] = value + rowspans[col + dcol] = rowspan + except IndexError: + # rowspan or colspan outside the confines of the table + pass + + # update rowspan bookkeeping + rowspans = {c: s - 1 for c, s in rowspans.items() if s > 1} + + return table + + +def convert2DF(xml_data): + ''' + input: xml_data: the xml component of a table + + return: the converetd dataframe, collapse the table into 2D structure + + ''' + # parse XML string with BeautifulSoup + xml_data = re.sub(r'\\x..', '', xml_data) + + #print(xml_data) + soup = BeautifulSoup(xml_data, 'lxml') + + # find the table in the soup + table = soup.find('table') + + # if thead and tbody + heads = table.find('thead') + if(not heads): + # if no tags of thead in the table, lets make the head as the first row + new_body = table_to_2d(table) + df = pd.DataFrame(new_body[1:], columns=new_body[0]) + + return df + new_heads = table_to_2d(heads) + + # merge multiple row of heads by concatenate unique values + new_heads = [['' if value is None else value for value in row] for row in new_heads] + unique_columns = [list(set(column)) for column in zip(*new_heads)] + my_merged_header = [' '.join(column) for column in unique_columns] + + + + bodies = table.find('tbody') + + data = [] + new_body_data = table_to_2d(bodies) + for index,cur_body_row in enumerate(new_body_data): + + tmp_row = ['']*len(cur_body_row) + cur_body_row = ['' if value is None else value for value in cur_body_row] + + for k, cur_col in enumerate(cur_body_row): + tmp_row[k] = cur_col.replace("\\n",'') + + data.append(tmp_row) + df = pd.DataFrame(data, columns=my_merged_header) + return df +def extractTablesFromXML(XML_path): + + ''' + + XML_path: the path of the XML paper file + + return a list of dataframes, each df represents one table + + ''' + all_tables = [] + with biocxml.iterparse(XML_path) as reader: + for document in reader: + for i in range(len(document.passages)): + if(document.passages[i].infons['type']!='table'): + continue + cur_table_xml = document.passages[i].infons['xml'] + table_name = document.passages[i].infons['id'] + + df = convert2DF(cur_table_xml) + + all_tables.append(df) + + + + return all_tables + pass + +def reduceIntransDuplicates(text_in_trans_list): + + formated_in_trans_list = [] + for cur_answers in text_in_trans_list: + if('none' in cur_answers.lower() and 'contain' not in cur_answers.lower()): + continue + else: + formated_in_trans_list.extend(cur_answers.split(',')) + formated_in_trans_list = list(set(formated_in_trans_list)) + return formated_in_trans_list + pass +if __name__ == "__main__": + print("test extractTablesFromXML") + test_path = 
"/autofs/bal36md0/smli/smli/LLM-genome-curation/literatures/subset_report_XML/36546626.xml" + table_list = extractTablesFromXML(test_path) + print(len(table_list)) \ No newline at end of file