-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext_segmentation.py
184 lines (148 loc) · 6.84 KB
/
text_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
class TextSegmenter:
    """Split text into length-bounded segments, optionally with overlap.

    The text is cut on ``delimiter``, each piece is cleaned with
    ``preprocess_text`` and further cut after sentence-ending punctuation.
    The resulting units are then greedily packed (re-joined with
    ``delimiter``) into segments of at most ``max_length`` characters.

    All lengths are measured in characters (``len()`` of the string).
    """

    def __init__(self, delimiter="\n\n", max_length=5000, overlap_length=50,
                 replace_continuous_spaces=True, remove_urls=False, min_segment_length=100):
        """
        Initialize the text segmenter with given parameters.

        Args:
            delimiter (str): The delimiter to split text into segments; also
                used to re-join units when they are packed together
            max_length (int): Maximum length of each packed segment, in characters
            overlap_length (int): Number of trailing characters of the previous
                segment that are prepended to each following segment
            replace_continuous_spaces (bool): Whether to collapse runs of
                spaces/tabs and runs of newlines
            remove_urls (bool): Whether to remove URLs (and email addresses) from text
            min_segment_length (int): Segments shorter than this are merged
                with the following one when the result still fits max_length

        Raises:
            ValueError: If max_length is not a positive integer, overlap_length
                is not a non-negative integer, or overlap_length >= max_length.
        """
        if not isinstance(max_length, int) or max_length <= 0:
            raise ValueError("max_length must be a positive integer")
        if not isinstance(overlap_length, int) or overlap_length < 0:
            raise ValueError("overlap_length must be a non-negative integer")
        if overlap_length >= max_length:
            raise ValueError("overlap_length must be less than max_length")

        self.delimiter = delimiter
        self.max_length = max_length
        self.overlap_length = overlap_length
        self.replace_continuous_spaces = replace_continuous_spaces
        self.remove_urls = remove_urls
        self.min_segment_length = min_segment_length

        # Precompile the regular expressions once per instance.
        import re
        self.space_pattern = re.compile(r'[ \t]+')
        self.newline_pattern = re.compile(r'\n+')
        self.url_pattern = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
        self.email_pattern = re.compile(r'[\w\.-]+@[\w\.-]+\.\w+')
        # Sentence terminators: ideographic full stop (U+3002), ASCII "!?;" and
        # their full-width forms (U+FF01, U+FF1F, U+FF1B). Written as escapes
        # so the pattern survives encoding/normalization round-trips. A run of
        # terminators ("?!") counts as a single sentence end.
        self.chinese_pattern = re.compile(r'[\u3002!?;\uff01\uff1f\uff1b]+')

    def preprocess_text(self, text):
        """
        Preprocess text according to the configured rules.

        Args:
            text (str): Input text to preprocess

        Returns:
            str: Preprocessed text with leading/trailing whitespace stripped

        Raises:
            ValueError: If text is not a string.
        """
        if not isinstance(text, str):
            raise ValueError("Input text must be a string")

        if self.replace_continuous_spaces:
            # Replace continuous spaces and tabs with single space
            text = self.space_pattern.sub(' ', text)
            # Replace continuous newlines with single newline
            text = self.newline_pattern.sub('\n', text)

        if self.remove_urls:
            # Remove URLs
            text = self.url_pattern.sub('', text)
            # Email addresses are removed under the same flag
            text = self.email_pattern.sub('', text)

        return text.strip()

    def merge_short_segments(self, segments):
        """
        Merge segments that are too short into the segment that follows them.

        A segment shorter than ``min_segment_length`` absorbs the next segment
        (joined with the delimiter) as long as the merged text does not exceed
        ``max_length``.

        Args:
            segments (list): List of text segments

        Returns:
            list: List of merged text segments
        """
        if not segments:
            return segments

        merged = []
        current = segments[0]
        for segment in segments[1:]:
            merged_length = len(current) + len(self.delimiter) + len(segment)
            # Never merge past max_length: merging must not undo the packing limit.
            if len(current) < self.min_segment_length and merged_length <= self.max_length:
                current += self.delimiter + segment
            else:
                merged.append(current)
                current = segment

        if current:
            merged.append(current)
        return merged

    def _split_sentences(self, piece):
        """Cut piece after each sentence terminator, keeping the terminator."""
        sentences = []
        start = 0
        for match in self.chinese_pattern.finditer(piece):
            sentences.append(piece[start:match.end()])
            start = match.end()
        sentences.append(piece[start:])
        return [s.strip() for s in sentences if s.strip()]

    def _split_long_segment(self, segment):
        """Break an over-long segment into chunks of at most max_length characters.

        Words are packed with single spaces; a word that is itself longer than
        max_length is sliced into max_length-sized pieces.
        """
        limit = self.max_length
        chunks = []
        current = ""
        for word in segment.split():
            # A single word that cannot fit anywhere is sliced by characters.
            while len(word) > limit:
                if current:
                    chunks.append(current)
                    current = ""
                chunks.append(word[:limit])
                word = word[limit:]
            if not word:
                continue
            candidate = current + " " + word if current else word
            if len(candidate) > limit:
                # current is non-empty here because word alone fits the limit.
                chunks.append(current)
                current = word
            else:
                current = candidate
        if current:
            chunks.append(current)
        return chunks

    def segment_text(self, text):
        """
        Segment the input text according to the configured parameters.

        Note: the overlap is prepended after packing, so a segment that
        carries overlap can be up to ``overlap_length + len(delimiter)``
        characters longer than ``max_length``.

        Args:
            text (str): Input text to segment

        Returns:
            list: List of text segments

        Raises:
            ValueError: If text is not a string.
        """
        if not isinstance(text, str):
            raise ValueError("Input text must be a string")

        # Split on the delimiter BEFORE preprocessing: whitespace normalization
        # would otherwise collapse a delimiter such as "\n\n" and nothing
        # would ever be split. An empty delimiter means "do not split".
        raw_pieces = text.split(self.delimiter) if self.delimiter else [text]
        sentences = []
        for raw_piece in raw_pieces:
            sentences.extend(self._split_sentences(self.preprocess_text(raw_piece)))

        # Greedily pack sentences into segments of at most max_length characters.
        packed = []
        current = ""
        for sentence in sentences:
            if current:
                projected = len(current) + len(self.delimiter) + len(sentence)
            else:
                projected = len(sentence)
            if projected <= self.max_length:
                current = current + self.delimiter + sentence if current else sentence
                continue

            if current:
                packed.append(current)
            if len(sentence) > self.max_length:
                # Keep the last chunk open so following sentences can join it.
                *full_chunks, current = self._split_long_segment(sentence)
                packed.extend(full_chunks)
            else:
                current = sentence

        # Add the last segment if it exists
        if current:
            packed.append(current)

        # Merge short segments
        packed = self.merge_short_segments(packed)

        # Prefix every segment but the first with the tail of its predecessor.
        if self.overlap_length > 0 and len(packed) > 1:
            overlapped = [packed[0]]
            for previous, following in zip(packed, packed[1:]):
                tail = previous[-self.overlap_length:]
                overlapped.append(tail + self.delimiter + following)
            return overlapped

        return packed
# Demo: run the segmenter on a small sample when executed as a script.
if __name__ == "__main__":
    # Sample input containing a URL and an email-like token.
    example_text = """This is a sample text.
It contains multiple paragraphs and some extra spaces.
This is another paragraph with a URL: https://example.com
And an email: [email protected]
Let's see how it gets segmented."""

    # Default settings are used for the demo; segment in one step.
    segments = TextSegmenter().segment_text(example_text)

    # Report each segment under a 1-based heading.
    print("Segmented text:")
    number = 0
    for chunk in segments:
        number += 1
        print(f"\nSegment {number}:")
        print(chunk)