Enhance textractor to better support RAG use cases, closes #603
davidmezzetti committed Nov 28, 2023
1 parent 4eeb0b9 commit 591e730
Showing 6 changed files with 229 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -23,7 +23,7 @@ endif
# Download test data
data:
mkdir -p /tmp/txtai
- wget -N https://github.com/neuml/txtai/releases/download/v3.5.0/tests.tar.gz -P /tmp
+ wget -N https://github.com/neuml/txtai/releases/download/v6.2.0/tests.tar.gz -P /tmp
tar -xvzf /tmp/tests.tar.gz -C /tmp

# Unit tests
27 changes: 16 additions & 11 deletions src/python/txtai/pipeline/data/segmentation.py
@@ -20,7 +20,7 @@ class Segmentation(Pipeline):
Segments text into logical units.
"""

- def __init__(self, sentences=False, lines=False, paragraphs=False, minlength=None, join=False):
+ def __init__(self, sentences=False, lines=False, paragraphs=False, minlength=None, join=False, sections=False):
"""
Creates a new Segmentation pipeline.
@@ -30,6 +30,7 @@ def __init__(self, sentences=False, lines=False, paragraphs=False, minlength=None, join=False):
paragraphs: tokenizes text into paragraphs if True, defaults to False
minlength: require at least minlength characters per text element, defaults to None
join: joins tokenized sections back together if True, defaults to False
+ sections: tokenizes text into sections if True, defaults to False. Splits using section or page breaks, depending on what's available
"""

if not NLTK:
@@ -38,6 +39,7 @@ def __init__(self, sentences=False, lines=False, paragraphs=False, minlength=None, join=False):
self.sentences = sentences
self.lines = lines
self.paragraphs = paragraphs
+ self.sections = sections
self.minlength = minlength
self.join = join

@@ -96,19 +98,23 @@ def parse(self, text):
if self.sentences:
content = [self.clean(x) for x in sent_tokenize(text)]
elif self.lines:
- content = [self.clean(x) for x in text.split("\n")]
+ content = [self.clean(x) for x in re.split(r"\n{1,}", text)]
elif self.paragraphs:
- content = [self.clean(x) for x in text.split("\n\n")]
+ content = [self.clean(x) for x in re.split(r"\n{2,}", text)]
+ elif self.sections:
+ split = r"\f" if "\f" in text else r"\n{3,}"
+ content = [self.clean(x) for x in re.split(split, text)]
else:
- content = [self.clean(text)]
+ content = self.clean(text)

- # Remove empty strings
- content = [x for x in content if x]
-
- if self.sentences or self.lines or self.paragraphs:
+ # Text tokenization enabled
+ if isinstance(content, list):
+ # Remove empty strings
+ content = [x for x in content if x]
return " ".join(content) if self.join else content

- return content[0] if content else content
+ # Default method that returns clean text
+ return content

def clean(self, text):
"""
@@ -121,8 +127,7 @@ def clean(self, text):
clean text
"""

- text = text.replace("\n", " ")
- text = re.sub(r"\s+", " ", text)
+ text = re.sub(r" +", " ", text)
text = text.strip()

return text if not self.minlength or len(text) >= self.minlength else None
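
With these changes, lines split on one or more newlines instead of a single literal "\n", paragraphs split on two or more newlines, and the new sections mode splits on form feed characters (\f) when present, falling back to runs of three or more newlines. A minimal sketch of the new option, assuming txtai is installed with the "pipeline" extra (the sample text is hypothetical):

from txtai.pipeline import Segmentation

# Sections split on form feeds when present, else on 3+ newlines
segment = Segmentation(sections=True)

text = "First section\fSecond section\fThird section"
print(segment(text))
# ['First section', 'Second section', 'Third section']
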
192 changes: 179 additions & 13 deletions src/python/txtai/pipeline/data/textractor.py
@@ -10,7 +10,7 @@

# Conditional import
try:
- from bs4 import BeautifulSoup
+ from bs4 import BeautifulSoup, NavigableString
from tika import parser

TIKA = True
@@ -25,33 +25,36 @@ class Textractor(Segmentation):
Extracts text from files.
"""

- def __init__(self, sentences=False, lines=False, paragraphs=False, minlength=None, join=False, tika=True):
+ def __init__(self, sentences=False, lines=False, paragraphs=False, minlength=None, join=False, tika=True, sections=False):
if not TIKA:
raise ImportError('Textractor pipeline is not available - install "pipeline" extra to enable')

- super().__init__(sentences, lines, paragraphs, minlength, join)
+ super().__init__(sentences, lines, paragraphs, minlength, join, sections)

# Determine if Tika (default if Java is available) or Beautiful Soup should be used
# Beautiful Soup only supports HTML, Tika supports a wide variety of file formats, including HTML.
self.tika = self.checkjava() if tika else False

+ # HTML to Text extractor
+ self.extract = Extract()

def text(self, text):
# Use Tika if available
if self.tika:
# Format file urls as local file paths
text = text.replace("file://", "")

- # text is a path to a file
- parsed = parser.from_file(text)
- return parsed["content"]
-
- # Fallback to Beautiful Soup
- text = f"file://{text}" if os.path.exists(text) else text
- with contextlib.closing(urlopen(text)) as connection:
- text = connection.read()
+ # Parse content to XHTML
+ parsed = parser.from_file(text, xmlContent=True)
+ text = parsed["content"]
+ else:
+ # Fallback to XHTML-only support, read data from url/path
+ text = f"file://{text}" if os.path.exists(text) else text
+ with contextlib.closing(urlopen(text)) as connection:
+ text = connection.read()

- soup = BeautifulSoup(text, features="html.parser")
- return soup.get_text()
+ # Extract text from HTML
+ return self.extract(text)

def checkjava(self, path=None):
"""
@@ -76,3 +79,166 @@ def checkjava(self, path=None):
return False

return True


class Extract:
"""
HTML to Text extractor. A limited set of Markdown is applied for organizing container elements such as tables and lists.
Visual formatting (bold, italic, styling, etc.) is not included.
"""

def __call__(self, html):
"""
Transforms input HTML into formatted text.
Args:
html: input html
Returns:
formatted text
"""

# HTML Parser
soup = BeautifulSoup(html, features="html.parser")

# Extract text from each body element
nodes = []
for body in soup.find_all("body"):
nodes.append(self.process(body))

# Return extracted text, fallback to default text extraction if no nodes found
return "\n".join(nodes) if nodes else soup.get_text()

def process(self, node):
"""
Extracts text from a node. This method applies transforms for containers, tables, lists and text.
Page breaks are detected and reflected in the output text as a page break character.
Args:
node: input node
Returns:
node text
"""

if node.name == "table":
return self.table(node)
if node.name in ("ul", "ol"):
return self.items(node)

# Get page break symbol, if available
page = node.name and node.get("class") and "page" in node.get("class")

# Get node children
children = self.children(node)

# Join elements into text
text = "\n".join(self.process(node) for node in children) if self.iscontainer(node, children) else self.text(node)

# Detect page breaks. Otherwise add node text.
return f"{text}\f" if page else text

def text(self, node):
"""
Text handler. This method flattens a node and its children into text.
Args:
node: input node
Returns:
node text
"""

# Get node children if available, otherwise use node as item
items = self.children(node)
items = items if items else [node]

# Join text elements
text = "".join(x.text for x in items)

# Return text, strip leading/trailing whitespace if this is a string only node
return text if node.name else text.strip()

def table(self, node):
"""
Table handler. This method transforms an HTML table into a Markdown-formatted table.
Args:
node: input node
Returns:
table as markdown
"""

elements, header = [], False

# Process all rows
rows = node.find_all("tr")
for row in rows:
# Get list of columns for row
columns = row.find_all(lambda tag: tag.name in ("th", "td"))

# Add columns with separator
elements.append(f"|{'|'.join(self.process(column) for column in columns)}|")

# If there are multiple rows, add header format row
if not header and len(rows) > 1:
elements.append(f"{'|---' * len(columns)}|")
header = True

# Join elements together as string
return "\n".join(elements)

def items(self, node):
"""
List handler. This method transforms an HTML ordered/unordered list into a Markdown-formatted list.
Args:
node: input node
Returns:
list as markdown
"""

elements = []
for x, element in enumerate(node.find_all("li")):
# Unordered lists use dashes. Ordered lists use numbers.
prefix = "-" if node.name == "ul" else f"{x + 1}."

# Add list element
elements.append(f" {prefix} {self.process(element)}")

# Join elements together as string
return "\n".join(elements)

def iscontainer(self, node, children):
"""
Analyzes a node and its children to determine if this is a container element. A container
element is defined as being a div, body or not having any string elements as children.
Args:
node: input node
children: input node's children
Returns:
True if this is a container element, False otherwise
"""

return node.name in ("div", "body") or (children and not any(isinstance(x, NavigableString) for x in children))

def children(self, node):
"""
Gets the node children, if available.
Args:
node: input node
Returns:
node children or None if not available
"""

if node.name and node.contents:
# Iterate over children and remove whitespace-only string nodes
return [node for node in node.contents if node.name or node.text.strip()]

return None
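
Taken together, these changes parse files to XHTML through Tika and pass that markup to the new Extract class, so page breaks survive as form feed characters and tables and lists come through as Markdown, giving downstream RAG prompts document structure to work with. A minimal usage sketch, assuming the "pipeline" extra is installed and Java is available for Tika; the file paths follow the test data layout from the Makefile above:

from txtai.pipeline import Textractor

# Split extracted text on section/page breaks
textractor = Textractor(sections=True)
sections = textractor("/tmp/txtai/document.pdf")

# Tables render as Markdown pipe tables in the extracted text
textractor = Textractor()
text = textractor("/tmp/txtai/spreadsheet.xlsx")
assert "|---|" in text
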
4 changes: 2 additions & 2 deletions test/python/testapi/testpipelines.py
@@ -277,7 +277,7 @@ def testTextractor(self):
text = self.client.get(f"textract?file={Utils.PATH}/article.pdf").json()

# Check length of text is as expected
- self.assertEqual(len(text), 2301)
+ self.assertEqual(len(text), 2334)

def testTextractorBatch(self):
"""
@@ -287,7 +287,7 @@ def testTextractorBatch(self):
path = Utils.PATH + "/article.pdf"

texts = self.client.post("batchtextract", json=[path, path]).json()
- self.assertEqual((len(texts[0]), len(texts[1])), (2301, 2301))
+ self.assertEqual((len(texts[0]), len(texts[1])), (2334, 2334))

def testTranscribe(self):
"""
33 changes: 30 additions & 3 deletions test/python/testpipeline/testtextractor.py
@@ -52,11 +52,24 @@ def testParagraphs(self):

textractor = Textractor(paragraphs=True)

- # Extract text as sentences
+ # Extract text as paragraphs
paragraphs = textractor(Utils.PATH + "/article.pdf")

# Check number of paragraphs is as expected
- self.assertEqual(len(paragraphs), 13)
+ self.assertEqual(len(paragraphs), 11)

def testSections(self):
"""
Test extraction to sections
"""

textractor = Textractor(sections=True)

# Extract as sections
paragraphs = textractor(Utils.PATH + "/document.pdf")

# Check number of sections is as expected
self.assertEqual(len(paragraphs), 3)

def testSentences(self):
"""
@@ -82,4 +95,18 @@ def testSingle(self):
text = textractor(Utils.PATH + "/article.pdf")

# Check length of text is as expected
- self.assertEqual(len(text), 2301)
+ self.assertEqual(len(text), 2334)

def testTable(self):
"""
Test table extraction
"""

textractor = Textractor()

# Extract text as a single block
for name in ["document.docx", "spreadsheet.xlsx"]:
text = textractor(f"{Utils.PATH}/{name}")

# Check for table header
self.assertTrue("|---|" in text)
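
As context for the table assertion above, this is what the new Extract transform produces for a small HTML table. Note that Extract is an internal helper rather than public API, so importing it directly as below is an assumption for illustration:

from txtai.pipeline.data.textractor import Extract

extract = Extract()
html = "<body><table><tr><th>Name</th><th>Value</th></tr><tr><td>a</td><td>1</td></tr></table></body>"
print(extract(html))

# Output:
# |Name|Value|
# |---|---|
# |a|1|
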
2 changes: 1 addition & 1 deletion test/python/testworkflow.py
@@ -392,7 +392,7 @@ def testStorageWorkflow(self):

results = list(workflow(["local://" + Utils.PATH, "test string"]))

- self.assertEqual(len(results), 19)
+ self.assertEqual(len(results), 22)

def testTemplateInput(self):
"""
