diff --git a/README.md b/README.md
index 6d82541..8cf074b 100644
--- a/README.md
+++ b/README.md
@@ -35,19 +35,23 @@ Usage: dinglehopper [OPTIONS] GT OCR [REPORT_PREFIX]
     their text and falls back to plain text if no ALTO or PAGE is detected.

     The files GT and OCR are usually a ground truth document and the result of
-    an OCR software, but you may use dinglehopper to compare two OCR results.
-    In that case, use --no-metrics to disable the then meaningless metrics and
-    also change the color scheme from green/red to blue.
+    an OCR software, but you may use dinglehopper to compare two OCR results. In
+    that case, use --metrics='' to disable the then meaningless metrics and also
+    change the color scheme from green/red to blue.

     The comparison report will be written to $REPORT_PREFIX.{html,json}, where
-    $REPORT_PREFIX defaults to "report". The reports include the character
-    error rate (CER) and the word error rate (WER).
+    $REPORT_PREFIX defaults to "report". Depending on your configuration, the
+    reports include the character error rate (CER), the word error rate (WER)
+    and the flexible character accuracy (FCA).
+
+    The metrics can be chosen via a comma-separated combination of their
+    acronyms like "--metrics=cer,wer,fca".

     By default, the text of PAGE files is extracted on 'region' level. You may
     use "--textequiv-level line" to extract from the level of TextLine tags.

 Options:
-  --metrics / --no-metrics  Enable/disable metrics and green/red
+  --metrics                 Enable different metrics like cer, wer and fca.
   --textequiv-level LEVEL   PAGE TextEquiv level to extract text from
   --progress                Show progress bar
   --help                    Show this message and exit.
@@ -80,12 +84,12 @@ The OCR-D processor has these parameters:

 | Parameter                 | Meaning                                                              |
 | ------------------------- | -------------------------------------------------------------------- |
-| `-P metrics false`        | Disable metrics and the green-red color scheme (default: enabled)    |
+| `-P metrics cer,wer`      | Enable character error rate and word error rate (default)            |
 | `-P textequiv_level line` | (PAGE) Extract text from TextLine level (default: TextRegion level)  |

 For example:
 ~~~
-ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics false
+ocrd-dinglehopper -I ABBYY-FULLTEXT,OCR-D-OCR-CALAMARI -O OCR-D-OCR-COMPARE-ABBYY-CALAMARI -P metrics cer,wer
 ~~~

 Developer information
diff --git a/qurator/dinglehopper/__init__.py b/qurator/dinglehopper/__init__.py
index 8e58101..dc45a8f 100644
--- a/qurator/dinglehopper/__init__.py
+++ b/qurator/dinglehopper/__init__.py
@@ -3,3 +3,8 @@
 from .character_error_rate import *
 from .word_error_rate import *
 from .align import *
+from .flexible_character_accuracy import (
+    flexible_character_accuracy,
+    split_matches,
+    Match,
+)
diff --git a/qurator/dinglehopper/align.py b/qurator/dinglehopper/align.py
index c7e7733..08bb3f5 100644
--- a/qurator/dinglehopper/align.py
+++ b/qurator/dinglehopper/align.py
@@ -8,11 +8,20 @@ def align(t1, t2):
     return seq_align(s1, s2)


-def seq_align(s1, s2):
+def seq_align_linewise(s1, s2, ops):
+    """Align two lists of lines linewise."""
+    assert len(s1) == len(s2)
+    assert len(s2) == len(ops)
+    for l1, l2, line_ops in zip(s1, s2, ops):
+        yield from seq_align(l1, l2, ops=line_ops)
+
+
+def seq_align(s1, s2, ops=None):
     """Align general sequences."""
     s1 = list(s1)
     s2 = list(s2)
-    ops = seq_editops(s1, s2)
+    if not ops:
+        ops = seq_editops(s1, s2)
     i = 0
     j = 0
diff --git a/qurator/dinglehopper/cli.py b/qurator/dinglehopper/cli.py
index 09c26f0..c9b347f 100644
--- a/qurator/dinglehopper/cli.py
+++ b/qurator/dinglehopper/cli.py
@@ -1,3 +1,4 @@
+import json
 import os

 import click
@@ -6,14 +7,15 @@
 from uniseg.graphemecluster import grapheme_clusters

 from .character_error_rate import character_error_rate_n
+from .flexible_character_accuracy import flexible_character_accuracy, split_matches
 from .word_error_rate import word_error_rate_n, words_normalized
-from .align import seq_align
+from .align import seq_align, seq_align_linewise
 from .extracted_text import ExtractedText
 from .ocr_files import extract
 from .config import Config


-def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none):
+def gen_diff_report(gt_in, ocr_in, css_prefix, joiner, none, matches=None):
     gtx = ""
     ocrx = ""
@@ -41,7 +43,32 @@ def format_thing(t, css_classes=None, id_=None):
         else:
             return "{html_t}".format(html_t=html_t)

-    if isinstance(gt_in, ExtractedText):
+    ops, ocr_ids = None, None
+    seq_align_fun = seq_align
+    if matches:
+        seq_align_fun = seq_align_linewise
+        gt_things, ocr_things, ops = split_matches(matches)
+        # We have to reconstruct the order of the OCR lines here because the
+        # FCA calculation mixed them up.
+        ocr_lines = [match.ocr for match in matches]
+        ocr_lines_sorted = sorted(ocr_lines, key=lambda x: x.line + x.start / 10000)
+
+        ocr_line_region_id = {}
+        pos = 0
+        for ocr_line in ocr_lines_sorted:
+            if ocr_line.line not in ocr_line_region_id.keys():
+                try:
+                    ocr_line_region_id[ocr_line.line] = ocr_in.segment_id_for_pos(pos)
+                except AssertionError:
+                    pass
+            pos += ocr_line.length
+
+        ocr_ids = {None: None}
+        pos = 0
+        for ocr_line in ocr_lines:
+            for _ in ocr_line.text:
+                ocr_ids[pos] = ocr_line_region_id[ocr_line.line]
+                pos += 1
+    elif isinstance(gt_in, ExtractedText):
         if not isinstance(ocr_in, ExtractedText):
             raise TypeError()
         # XXX splitting should be done in ExtractedText
@@ -53,17 +80,20 @@ def format_thing(t, css_classes=None, id_=None):
     g_pos = 0
     o_pos = 0
-    for k, (g, o) in enumerate(seq_align(gt_things, ocr_things)):
+    for k, (g, o) in enumerate(seq_align_fun(gt_things, ocr_things, ops=ops)):
         css_classes = None
         gt_id = None
         ocr_id = None
         if g != o:
             css_classes = "{css_prefix}diff{k} diff".format(css_prefix=css_prefix, k=k)
             if isinstance(gt_in, ExtractedText):
-                gt_id = gt_in.segment_id_for_pos(g_pos) if g is not None else None
-                ocr_id = ocr_in.segment_id_for_pos(o_pos) if o is not None else None
                 # Deletions and inserts only produce one id + None, UI must
                 # support this, i.e. display for the one id produced
+                gt_id = gt_in.segment_id_for_pos(g_pos) if g else None
+                if ocr_ids:
+                    ocr_id = ocr_ids.get(o_pos, None)
+                else:
+                    ocr_id = ocr_in.segment_id_for_pos(o_pos) if o else None

         gtx += joiner + format_thing(g, css_classes, gt_id)
         ocrx += joiner + format_thing(o, css_classes, ocr_id)
@@ -83,28 +113,37 @@ def format_thing(t, css_classes=None, id_=None):
     )


-def process(gt, ocr, report_prefix, *, metrics=True, textequiv_level="region"):
+def process(gt, ocr, report_prefix, *, metrics="cer,wer", textequiv_level="region"):
     """Check OCR result against GT.

-    The @click decorators change the signature of the decorated functions, so we keep this undecorated version and use
-    Click on a wrapper.
+    The @click decorators change the signature of the decorated functions,
+    so we keep this undecorated version and use Click on a wrapper.
""" + cer, char_diff_report, n_characters = None, None, None + wer, word_diff_report, n_words = None, None, None + fca, fca_diff_report = None, None gt_text = extract(gt, textequiv_level=textequiv_level) ocr_text = extract(ocr, textequiv_level=textequiv_level) - cer, n_characters = character_error_rate_n(gt_text, ocr_text) - wer, n_words = word_error_rate_n(gt_text, ocr_text) - - char_diff_report = gen_diff_report( - gt_text, ocr_text, css_prefix="c", joiner="", none="·" - ) + if "cer" in metrics or not metrics: + cer, n_characters = character_error_rate_n(gt_text, ocr_text) + char_diff_report = gen_diff_report( + gt_text, ocr_text, css_prefix="c", joiner="", none="·" + ) - gt_words = words_normalized(gt_text) - ocr_words = words_normalized(ocr_text) - word_diff_report = gen_diff_report( - gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯" - ) + if "wer" in metrics: + gt_words = words_normalized(gt_text) + ocr_words = words_normalized(ocr_text) + wer, n_words = word_error_rate_n(gt_text, ocr_text) + word_diff_report = gen_diff_report( + gt_words, ocr_words, css_prefix="w", joiner=" ", none="⋯" + ) + if "fca" in metrics: + fca, fca_matches = flexible_character_accuracy(gt_text, ocr_text) + fca_diff_report = gen_diff_report( + gt_text, ocr_text, css_prefix="c", joiner="", none="·", matches=fca_matches + ) def json_float(value): """Convert a float value to an JSON float. @@ -124,6 +163,7 @@ def json_float(value): ) ) env.filters["json_float"] = json_float + env.filters["json_dumps"] = json.dumps for report_suffix in (".html", ".json"): template_fn = "report" + report_suffix + ".j2" @@ -137,8 +177,10 @@ def json_float(value): n_characters=n_characters, wer=wer, n_words=n_words, + fca=fca, char_diff_report=char_diff_report, word_diff_report=word_diff_report, + fca_diff_report=fca_diff_report, metrics=metrics, ).dump(out_fn) @@ -148,7 +190,9 @@ def json_float(value): @click.argument("ocr", type=click.Path(exists=True)) @click.argument("report_prefix", type=click.Path(), default="report") @click.option( - "--metrics/--no-metrics", default=True, help="Enable/disable metrics and green/red" + "--metrics", + default="cer,wer", + help="Enable different metrics like cer, wer and fca.", ) @click.option( "--textequiv-level", @@ -166,12 +210,16 @@ def main(gt, ocr, report_prefix, metrics, textequiv_level, progress): The files GT and OCR are usually a ground truth document and the result of an OCR software, but you may use dinglehopper to compare two OCR results. In - that case, use --no-metrics to disable the then meaningless metrics and also + that case, use --metrics='' to disable the then meaningless metrics and also change the color scheme from green/red to blue. The comparison report will be written to $REPORT_PREFIX.{html,json}, where - $REPORT_PREFIX defaults to "report". The reports include the character error - rate (CER) and the word error rate (WER). + $REPORT_PREFIX defaults to "report". Depending on your configuration the + reports include the character error rate (CER), the word error rate (WER) + and the flexible character accuracy (FCA). + + The metrics can be chosen via a comma separated combination of their acronyms + like "--metrics=cer,wer,fca". By default, the text of PAGE files is extracted on 'region' level. You may use "--textequiv-level line" to extract from the level of TextLine tags. 
diff --git a/qurator/dinglehopper/extracted_text.py b/qurator/dinglehopper/extracted_text.py
index 9703b6b..c779836 100644
--- a/qurator/dinglehopper/extracted_text.py
+++ b/qurator/dinglehopper/extracted_text.py
@@ -1,4 +1,5 @@
 import enum
+import logging
 import re
 import unicodedata
 from contextlib import suppress
@@ -8,7 +9,8 @@
 import attr
 import numpy as np
 from lxml import etree as ET
-from ocrd_utils import getLogger
+
+LOG = logging.getLogger("processor.OcrdDinglehopperEvaluate")


 class Normalization(enum.Enum):
@@ -239,7 +241,6 @@ def get_textequiv_unicode(text_segment, nsmap) -> str:

 def get_first_textequiv(textequivs, segment_id):
     """Get the first TextEquiv based on index or conf order if index is not present."""
-    log = getLogger("processor.OcrdDinglehopperEvaluate")
     if len(textequivs) == 1:
         return textequivs[0]

@@ -248,20 +249,20 @@ def get_first_textequiv(textequivs, segment_id):
     nan_mask = np.isnan(indices)
     if np.any(~nan_mask):
         if np.any(nan_mask):
-            log.warning("TextEquiv without index in %s.", segment_id)
+            LOG.warning("TextEquiv without index in %s.", segment_id)
         index = np.nanargmin(indices)
     else:
         # try ordering by conf
         confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float)
         if np.any(~np.isnan(confidences)):
-            log.info(
+            LOG.info(
                 "No index attributes, use 'conf' attribute to sort TextEquiv in %s.",
                 segment_id,
             )
             index = np.nanargmax(confidences)
         else:
             # fallback to the first entry in case neither index nor conf is present
-            log.warning("No index attributes, use first TextEquiv in %s.", segment_id)
+            LOG.warning("No index attributes, use first TextEquiv in %s.", segment_id)
             index = 0
     return textequivs[index]
diff --git a/qurator/dinglehopper/flexible_character_accuracy.py b/qurator/dinglehopper/flexible_character_accuracy.py
new file mode 100644
index 0000000..4ace63c
--- /dev/null
+++ b/qurator/dinglehopper/flexible_character_accuracy.py
@@ -0,0 +1,455 @@
+"""
+Implementation of the flexible character accuracy
+
+Citation:
+    Flexible character accuracy measure for reading-order-independent evaluation
+    C. Clausner, S. Pletschacher, A. Antonacopoulos
+    Pattern Recognition Letters, Volume 131, March 2020, Pages 390-397
+Link: http://www.primaresearch.org/publications/PRL_Clausner_FlexibleCharacterAccuracy
+DOI: https://doi.org/10.1016/j.patrec.2020.02.003
+
+Note that we deviated from the original algorithm in some places.
+"""
+
+import sys
+from collections import Counter
+from functools import lru_cache, reduce, partial
+from itertools import product, takewhile
+from multiprocessing import cpu_count, get_context
+from typing import List, Tuple, Optional, Union
+
+from Levenshtein import editops
+
+from . import ExtractedText
+
+if sys.version_info.minor == 5:
+    from .flexible_character_accuracy_ds_35 import (
+        PartVersionSpecific,
+        Match,
+        Distance,
+        Coefficients,
+    )
+else:
+    from .flexible_character_accuracy_ds import (
+        PartVersionSpecific,
+        Match,
+        Distance,
+        Coefficients,
+    )
+
+
+def flexible_character_accuracy(
+    gt: Union[str, ExtractedText],
+    ocr: Union[str, ExtractedText],
+    n_cpu: int = cpu_count(),
+) -> Tuple[float, List[Match]]:
+    """Calculate the flexible character accuracy.
+
+    Reference: contains steps 1-7 of the flexible character accuracy algorithm.
+
+    :param gt: The ground truth text.
+    :param ocr: The text to compare the ground truth with.
+    :param n_cpu: number of CPUs to use for multiprocessing.
+    :return: Score between 0 and 1 and match objects.
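+
+    Example (an illustrative doctest sketch; ``n_cpu=1`` keeps the
+    computation in the main process):
+
+    >>> score, matches = flexible_character_accuracy("a\\nb", "b\\na", n_cpu=1)
+    >>> score
+    1.0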
+ """ + + if isinstance(gt, ExtractedText): + gt = gt.text + if isinstance(ocr, ExtractedText): + ocr = ocr.text + + best_score = -sys.maxsize + best_matches = [] + # TODO: should this be configurable? + coeffs = ( + Coefficients( + edit_dist=edit_dist, length_diff=length_diff, offset=offset, length=length + ) + for edit_dist, length_diff, offset, length in product( + range(15, 31, 5), range(0, 24, 3), range(0, 4, 1), range(0, 6, 1) + ) + ) + with get_context("spawn").Pool(processes=n_cpu) as pool: + # Steps 1 - 6 of the flexible character accuracy algorithm. + # We only use multiprocessing if we have more than 2 cpus available. + # Otherwise the overhead for creating processes and filling caches is too big. + map_fun = partial(pool.imap_unordered, chunksize=10) if n_cpu > 2 else map + for matches in map_fun( + partial(match_with_coefficients, gt=gt, ocr=ocr), coeffs + ): + # Step 7 of the flexible character accuracy algorithm. + score = character_accuracy_for_matches(matches) + if score > best_score: + best_score = score + best_matches = matches + # early breaking: we only need one perfect fit + if best_score >= 1: + break + return best_score, best_matches + + +def match_with_coefficients(coef: Coefficients, gt: str, ocr: str) -> List[Match]: + """Match ground truth with ocr and consider a given set of coefficients. + + Reference: contains steps 1 - 6 of the flexible character accuracy algorithm. + + :return: A list of match objects to score and align the texts. + """ + # Steps 1 and 2 of the flexible character accuracy algorithm. + ocr_lines = initialize_lines(ocr) + gt_lines = initialize_lines(gt) + + matches = [] + + # Step 5 of the flexible character accuracy algorithm. + while len(gt_lines) != 0 and len(ocr_lines) != 0: + # Steps 3 and 4 of the flexible character accuracy algorithm. + match = match_longest_gt_lines(gt_lines, ocr_lines, coef) + if match: + matches.append(match) + + # Step 6 of the flexible character accuracy algorithm. + # remaining lines are considered as deletes and inserts + deletes = [ + distance(line, Part(text="", line=line.line, start=line.start)) + for line in gt_lines + ] + inserts = [ + distance(Part(text="", line=line.line, start=line.start), line) + for line in ocr_lines + ] + + return [*matches, *deletes, *inserts] + + +def match_longest_gt_lines( + gt_lines: List["Part"], ocr_lines: List["Part"], coef: Coefficients +) -> Optional[Match]: + """Find the best match for the longest line(s) in ground truth. + + The longest lines in ground truth are matched against lines in ocr to find the + best matching pair. This pair is then either considered a match on a full line + or the line(s) is splitted and the non matching parts are added back to the list. + + Reference: contains steps 3 and 4 of the flexible character accuracy algorithm. + + :return: Possible match object. + """ + best_score, best_match, best_gt, best_ocr = -sys.maxsize, None, None, None + if not ocr_lines: + return best_match + + # Step 3 of the flexible character accuracy algorithm (variation). + # We do not only take the longest line from ground truth but decide on a length + # threshold and take all lines from ground truth bigger than the threshold. 
+    length_threshold = min(gt_lines[0].length, ocr_lines[0].length) - 1
+    for gt_line in takewhile(lambda line: line.length > length_threshold, gt_lines):
+        match, ocr_line = match_gt_line(gt_line, ocr_lines, coef)
+        score = -sys.maxsize if not match else character_accuracy(match.dist)
+        if score > best_score:
+            best_score, best_match, best_gt, best_ocr = score, match, gt_line, ocr_line
+        # early breaking: we only need one perfect fit
+        if best_score >= 1:
+            break
+
+    # Step 4 of the flexible character accuracy algorithm.
+    if best_match:
+        remove_or_split(best_gt, best_match.gt, gt_lines)
+        remove_or_split(best_ocr, best_match.ocr, ocr_lines)
+
+    return best_match
+
+
+def match_gt_line(
+    gt_line: "Part", ocr_lines: List["Part"], coef: Coefficients
+) -> Tuple[Optional[Match], Optional["Part"]]:
+    """Match the given ground truth line against all the lines in OCR.
+
+    Reference: contains step 3 of the flexible character accuracy algorithm.
+
+    TODO: Make penalty function configurable?
+
+    :return: Match object and the matched OCR line.
+    """
+    min_penalty = sys.maxsize
+    best_match, best_ocr = None, None
+    gt_line_length = gt_line.length
+    gt_line_start = gt_line.start
+    for ocr_line in ocr_lines:
+        match = match_lines(gt_line, ocr_line)
+        if match:
+            penalty = calculate_penalty(
+                gt_line_length,
+                ocr_line.length,
+                gt_line_start,
+                ocr_line.start,
+                match.gt.start,
+                match.ocr.start,
+                match.dist,
+                coef,
+            )
+            if penalty < min_penalty:
+                min_penalty, best_match, best_ocr = penalty, match, ocr_line
+    return best_match, best_ocr
+
+
+@lru_cache(maxsize=100000)
+def match_lines(gt_line: "Part", ocr_line: "Part") -> Optional[Match]:
+    """Match two lines by searching for a naive local alignment.
+
+    The shorter line is moved along the longer line
+    until the editing distance is minimized.
+
+    Reference: see figure 2 in doi:10.1016/j.patrec.2020.02.003.
+
+    TODO: make distance function configurable?
+    TODO: use @cache annotation in Python 3.9?
+
+    :return: Match object if one is found.
+    """
+    min_length = min(gt_line.length, ocr_line.length)
+    best_match = None
+    best_i, best_j = 0, 0
+    if min_length == 0:
+        return best_match
+    length_diff = gt_line.length - ocr_line.length
+    min_edit_dist = sys.maxsize
+
+    gt_parts = [
+        (i, gt_line.substring(rel_start=i, rel_end=i + min_length))
+        for i in range(0, max(1, length_diff + 1))
+    ]
+    ocr_parts = [
+        (j, ocr_line.substring(rel_start=j, rel_end=j + min_length))
+        for j in range(0, max(1, -1 * length_diff + 1))
+    ]
+
+    # add full line
+    gt_parts = [*gt_parts, (0, gt_line)]
+    ocr_parts = [*ocr_parts, (0, ocr_line)]
+
+    for i, gt_part in gt_parts:
+        for j, ocr_part in ocr_parts:
+            match = distance(gt_part, ocr_part)
+            edit_dist = score_edit_distance(match.dist)
+            if edit_dist < min_edit_dist and match.dist.replace < min_length:
+                min_edit_dist = edit_dist
+                best_match = match
+                best_i, best_j = i, j
+    # elongate at the end for handling deletes
+    if best_match and (best_match.dist.delete or best_match.dist.replace):
+        part_length = best_match.gt.length
+        additional_length = best_match.dist.delete + best_match.dist.replace
+        for k in range(part_length + 1, part_length + additional_length + 1):
+            match = distance(
+                gt_line.substring(rel_start=best_i, rel_end=best_i + k),
+                ocr_line.substring(rel_start=best_j, rel_end=best_j + k),
+            )
+            edit_dist = score_edit_distance(match.dist)
+            if edit_dist < min_edit_dist and match.dist.replace < min_length:
+                min_edit_dist = edit_dist
+                best_match = match
+    # is delete a better option?
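+    # Checking the all-delete variant guards against keeping a poor local
+    # alignment when the ocr line has (almost) nothing in common with gt_line.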
+    match = distance(gt_line, Part(text="", line=ocr_line.line, start=ocr_line.start))
+    edit_dist = score_edit_distance(match.dist)
+    if edit_dist < min_edit_dist:
+        best_match = match
+
+    return best_match
+
+
+@lru_cache(maxsize=100000)
+def distance(gt: "Part", ocr: "Part") -> Match:
+    """Calculate the editing distance between the two lines.
+
+    Using the already available `editops()` function with the Levenshtein distance.
+
+    TODO: use @cache annotation in Python 3.9?
+    TODO: wait for qurator-spk/dinglehopper#48 for efficient editops.
+
+    :return: Match object containing the lines and the editing operations.
+    """
+    ops = editops(gt.text, ocr.text)
+    edits = Counter([edit[0] for edit in ops])
+    edits["match"] = gt.length - edits["delete"] - edits["replace"]
+    return Match(gt=gt, ocr=ocr, dist=Distance(**edits), ops=ops)
+
+
+def score_edit_distance(dist: Distance) -> int:
+    """Calculate the edit distance for a match.
+
+    Formula: $deletes + inserts + 2 * replacements$
+
+    :return: Sum of deletes and inserts plus twice the replacements.
+    """
+    return dist.delete + dist.insert + 2 * dist.replace
+
+
+@lru_cache(100000)
+def calculate_penalty(
+    gt_length: int,
+    ocr_length: int,
+    gt_start: int,
+    ocr_start: int,
+    gt_match_start: int,
+    ocr_match_start: int,
+    dist: Distance,
+    coef: Coefficients,
+) -> float:
+    """Calculate the penalty for a given match.
+
+    For details and discussion see Section 3 in doi:10.1016/j.patrec.2020.02.003.
+
+    :return: Penalty for the given match.
+    """
+    min_edit_dist = score_edit_distance(dist)
+    length_diff = abs(gt_length - ocr_length)
+    substring_length = min(gt_length, ocr_length)
+    offset = 0.0
+    if length_diff > 1:
+        substring_pos = max(gt_match_start - gt_start, ocr_match_start - ocr_start)
+        offset = length_diff / 2 - abs(substring_pos - length_diff / 2)
+    return (
+        min_edit_dist * coef.edit_dist
+        + length_diff * coef.length_diff
+        + offset * coef.offset
+        - substring_length * coef.length
+    )
+
+
+def character_accuracy_for_matches(matches: List[Match]) -> float:
+    """Character accuracy of a full text represented by a list of matches.
+
+    See `character_accuracy` for details.
+    """
+    agg = reduce(
+        lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter()
+    )  # type: Counter
+
+    score = character_accuracy(Distance(**agg))
+    return score
+
+
+def character_accuracy(edits: Distance) -> float:
+    """Character accuracy calculated from the necessary edit operations.
+
+    Edit operations are the edits needed to transform one text into the other.
+
+    The character accuracy is given by $1 - errors / characters$.
+
+    Errors are replacements, deletes and inserts.
+
+    Note that it is possible to have more errors than characters, in which case the
+    character accuracy turns negative.
+
+    Comparing two empty strings (having no edits) results in a character accuracy of 1.
+    """
+    errors = edits.replace + edits.delete + edits.insert
+    chars = edits.match + edits.replace + edits.delete
+    if not chars and not errors:
+        # comparison of empty strings is considered a full match
+        score = 1.0
+    elif not chars:
+        score = -errors
+    else:
+        score = 1.0 - errors / chars
+    return score
+
+
+def initialize_lines(text: str) -> List["Part"]:
+    """Split a text into lines and convert them to our line data object.
+
+    The line objects are sorted by their length, descending.
+
+    Reference: contains steps 1 and 2 of the flexible character accuracy algorithm.
+
+    :param text: Text to split into lines.
+    :return: List of sorted line objects.
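+
+    Example (illustrative; assumes the NamedTuple-based ``Part`` repr):
+
+    >>> initialize_lines("22\\n1\\n333")
+    [Part(text='333', line=2, start=0), Part(text='22', line=0, start=0), Part(text='1', line=1, start=0)]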
+ """ + lines = [ + Part(text=line, line=i, start=0) + for i, line in enumerate(text.splitlines()) + if len(line) > 0 + ] + lines.sort(key=lambda x: x.length, reverse=True) + return lines + + +def remove_or_split(original: "Part", match: "Part", lines: List["Part"]) -> bool: + """Removes the matched line or splits it into parts. + + Reference: contains step 4 of the flexible character accuracy algorithm. + + :return: True if line was splitted. + """ + splitted = False + del lines[lines.index(original)] + if match.length < original.length: + lines.extend(original.split(match)) + # sorting for ocr is not mentioned in the paper, but is used as tie breaking =) + lines.sort(key=lambda x: x.length, reverse=True) + splitted = True + return splitted + + +def split_matches( + matches: List[Match], linesep="\n" +) -> Tuple[List[str], List[str], List[List]]: + """Extracts text segments and editing operations in separate lists. + + :param matches: List of match objects. + :param linesep: Character(s) or line separation. + :return: List of ground truth segments, ocr segments and editing operations. + """ + matches = sorted(matches, key=lambda m: m.gt.line + m.gt.start / 10000) + line = 0 + gt, ocr, ops = [], [], [] + for match in matches: + if match.gt.line > line: + gt.append(linesep) + ocr.append(linesep) + ops.extend([[]] * len(linesep)) + line = match.gt.line + gt.append(match.gt.text) + ocr.append(match.ocr.text) + ops.append(match.ops) + return gt, ocr, ops + + +class Part(PartVersionSpecific): + @property + def end(self) -> int: + return self.start + self.length + + @property + def length(self) -> int: + return len(self.text) + + def split(self, split: "Part") -> List["Part"]: + """Split the line part by another and returns the remaining parts. + + `abc.split("b")` will return ´["a", "c"]`. + + :param split: The line part we want to use to split. + :return: The parts before and after the split. + """ + rest = [] + if self.start < split.start: + rest.append(self.substring(rel_end=split.start - self.start)) + if split.end < self.end: + rest.append(self.substring(rel_start=split.end - self.start)) + return rest + + def substring(self, rel_start: int = 0, rel_end: int = None) -> "Part": + """Get part of the given line. + + Automatically handles the offset of the line. + Therefore `substring(rel_start=2)` will return `Part[start+rel_start:]`. + + :param rel_start: start relative to the part of the line. + :param rel_end: end relative to the part of the line. + :return: Extracted part of the given part of the line. + """ + text = self.text[rel_start:rel_end] + start = self.start + rel_start + return Part(line=self.line, text=text, start=start) diff --git a/qurator/dinglehopper/flexible_character_accuracy_ds.py b/qurator/dinglehopper/flexible_character_accuracy_ds.py new file mode 100644 index 0000000..ac5595c --- /dev/null +++ b/qurator/dinglehopper/flexible_character_accuracy_ds.py @@ -0,0 +1,48 @@ +""" +Datastructures to be used with the Flexible Character Accuracy Algorithm + +Separated because of version compatibility issues with Python 3.5. +""" + +from typing import List, NamedTuple + + +class PartVersionSpecific(NamedTuple): + """Represent a line or part of a line. + + This data object is maintained to be able to reproduce the original text. 
+ """ + + text: str = "" + line: int = 0 + start: int = 0 + + +class Distance(NamedTuple): + """Represent distance between two sequences.""" + + match: int = 0 + replace: int = 0 + delete: int = 0 + insert: int = 0 + + +class Match(NamedTuple): + """Represent a calculated match between ground truth and the ocr result.""" + + gt: "Part" + ocr: "Part" + dist: "Distance" + ops: List + + +class Coefficients(NamedTuple): + """Coefficients to calculate penalty for substrings. + + See Section 3 in doi:10.1016/j.patrec.2020.02.003 + """ + + edit_dist: int = 25 + length_diff: int = 20 + offset: int = 1 + length: int = 4 diff --git a/qurator/dinglehopper/flexible_character_accuracy_ds_35.py b/qurator/dinglehopper/flexible_character_accuracy_ds_35.py new file mode 100644 index 0000000..17384ac --- /dev/null +++ b/qurator/dinglehopper/flexible_character_accuracy_ds_35.py @@ -0,0 +1,83 @@ +""" +Datastructures to be used with the Flexible Character Accuracy Algorithm + +Separated because of version compatibility issues with Python 3.5. +""" + +from collections import namedtuple +from typing import Dict + + +class PartVersionSpecific: + def __init__(self, text: str = "", line: int = 0, start: int = 0): + self.text = text + self.line = line + self.start = start + + def __eq__(self, other): + return ( + self.line == other.line + and self.start == other.start + and self.text == other.text + ) + + def __hash__(self): + return hash(self.text) ^ hash(self.line) ^ hash(self.start) + + def _asdict(self) -> Dict: + return { + "text": self.text, + "line": self.line, + "start": self.start, + } + + +class Distance: + def __init__( + self, match: int = 0, replace: int = 0, delete: int = 0, insert: int = 0 + ): + self.match = match + self.replace = replace + self.delete = delete + self.insert = insert + + def _asdict(self) -> Dict: + return { + "match": self.match, + "replace": self.replace, + "delete": self.delete, + "insert": self.insert, + } + + def __eq__(self, other): + return ( + self.match == other.match + and self.replace == other.replace + and self.delete == other.delete + and self.insert == other.insert + ) + + def __hash__(self): + return ( + hash(self.match) + ^ hash(self.replace) + ^ hash(self.delete) + ^ hash(self.insert) + ) + + +Match = namedtuple("Match", ["gt", "ocr", "dist", "ops"]) + + +class Coefficients: + def __init__( + self, + edit_dist: int = 25, + length_diff: int = 20, + offset: int = 1, + length: int = 4, + ): + self.edit_dist = edit_dist + self.length_diff = length_diff + self.offset = offset + self.length = length diff --git a/qurator/dinglehopper/ocr_files.py b/qurator/dinglehopper/ocr_files.py index 57ebd3f..6f2dd40 100644 --- a/qurator/dinglehopper/ocr_files.py +++ b/qurator/dinglehopper/ocr_files.py @@ -125,7 +125,7 @@ def page_text(tree, *, textequiv_level="region"): def plain_extract(filename): - with open(filename, "r") as f: + with open(filename, "r", encoding="utf8") as f: return ExtractedText( None, [ diff --git a/qurator/dinglehopper/ocrd-tool.json b/qurator/dinglehopper/ocrd-tool.json index 1e2b9b0..f8d480e 100644 --- a/qurator/dinglehopper/ocrd-tool.json +++ b/qurator/dinglehopper/ocrd-tool.json @@ -19,9 +19,10 @@ ], "parameters": { "metrics": { - "type": "boolean", - "default": true, - "description": "Enable/disable metrics and green/red" + "type": "string", + "enum": ["", "cer", "wer", "fca", "cer,wer", "cer,fca", "wer,fca", "cer,wer,fca"], + "default": "cer,wer", + "description": "Enable different metrics like cer, wer and fca." 
     },
     "textequiv_level": {
       "type": "string",
diff --git a/qurator/dinglehopper/templates/report.html.j2 b/qurator/dinglehopper/templates/report.html.j2
index 0c2f464..be764db 100644
--- a/qurator/dinglehopper/templates/report.html.j2
+++ b/qurator/dinglehopper/templates/report.html.j2
@@ -40,16 +40,31 @@
 {% if metrics %}
         <h2>Metrics</h2>
-        <p>CER: {{ cer|round(4) }}</p>
-        <p>WER: {{ wer|round(4) }}</p>
+        {% if cer is not none %}
+        <p>CER: {{ cer|round(4) }}</p>
+        {% endif %}
+        {% if wer is not none %}
+        <p>WER: {{ wer|round(4) }}</p>
+        {% endif %}
+        {% if fca is not none %}
+        <p>FCA: {{ fca|round(4) }}</p>
+        {% endif %}
 {% endif %}

+{% if char_diff_report %}
         <h2>Character differences</h2>
         {{ char_diff_report }}
+{% endif %}
+{% if word_diff_report %}
         <h2>Word differences</h2>
         {{ word_diff_report }}
+{% endif %}
+{% if fca_diff_report %}
+        <h2>Flexible character accuracy differences</h2>
+        {{ fca_diff_report }}
+{% endif %}
diff --git a/qurator/dinglehopper/templates/report.json.j2 b/qurator/dinglehopper/templates/report.json.j2
index 0e8af03..a632590 100644
--- a/qurator/dinglehopper/templates/report.json.j2
+++ b/qurator/dinglehopper/templates/report.json.j2
@@ -1,10 +1,11 @@
 {
-    "gt": "{{ gt }}",
-    "ocr": "{{ ocr }}",
 {% if metrics %}
-    "cer": {{ cer|json_float }},
-    "wer": {{ wer|json_float }},
+    {% if cer is not none %}"cer": {{ cer|json_float }},{% endif %}
+    {% if wer is not none %}"wer": {{ wer|json_float }},{% endif %}
+    {% if fca is not none %}"fca": {{ fca|json_float }},{% endif %}
+    {% if n_characters is not none %}"n_characters": {{ n_characters }},{% endif %}
+    {% if n_words is not none %}"n_words": {{ n_words }},{% endif %}
 {% endif %}
-    "n_characters": {{ n_characters }},
-    "n_words": {{ n_words }}
+    "gt": {{ gt|json_dumps }},
+    "ocr": {{ ocr|json_dumps }}
 }
diff --git a/qurator/dinglehopper/tests/extracted_text_test.py b/qurator/dinglehopper/tests/extracted_text_test.py
index 8a81587..c39b3a3 100644
--- a/qurator/dinglehopper/tests/extracted_text_test.py
+++ b/qurator/dinglehopper/tests/extracted_text_test.py
@@ -4,6 +4,7 @@

 import pytest
 from lxml import etree as ET
+from ocrd_utils import getLogger
 from uniseg.graphemecluster import grapheme_clusters

 from .. import seq_align, ExtractedText
@@ -117,6 +118,7 @@ def test_align():
 )
 def test_textequiv(attributes, expected_index, expected_log, caplog):
     """Test that extracting text from a PAGE TextEquiv is working without index attr."""
+    getLogger("processor.OcrdDinglehopperEvaluate")
     caplog.set_level(logging.INFO)
     xml = '<?xml version="1.0"?>'
     ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15"
@@ -134,6 +136,7 @@ def test_textequiv(attributes, expected_index, expected_log, caplog):
     result = ExtractedText.from_text_segment(
         root, {"page": ns}, textequiv_level="line"
     ).text
+
     if expected_index is None:
         assert not result
     else:
diff --git a/qurator/dinglehopper/tests/test_flexible_character_accuracy.py b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py
new file mode 100644
index 0000000..6ef316b
--- /dev/null
+++ b/qurator/dinglehopper/tests/test_flexible_character_accuracy.py
@@ -0,0 +1,463 @@
+"""
+Tests for the implementation of the flexible character accuracy
+
+Citation:
+    Flexible character accuracy measure for reading-order-independent evaluation
+    C. Clausner, S. Pletschacher, A.
Antonacopoulos + Pattern Recognition Letters, Volume 131, March 2020, Pages 390-397 +Link: http://www.primaresearch.org/publications/PRL_Clausner_FlexibleCharacterAccuracy +DOI: 10.1016/j.patrec.2020.02.003 +""" + +import pytest +from lxml import etree as ET + +from ..flexible_character_accuracy import * + +CASE_ARGS = "gt,ocr,first_line_score,all_line_score" + +SIMPLE_CASES = [ + ("a", "", 0, 0), + ("a", "a", 1, 1), + ("a\nb", "a\nb", 1, 1), + ("a\nb", "b\na", 1, 1), + ("aaa\nbbb\nccc", "ccc\naaa\nbbb", 1, 1), + ("aaa\nbbb\nccc", "aaa\nbbb", 1, 1 - 3 / 9), + ("bbb", "aaa\nbbb\nccc", 1, 1 - 6 / 3), + ("a", "a\nbb\nccc", 1, 1 - 5 / 1), + ("bb", "a\nbb\nccc", 1, 1 - 4 / 2), + ("abcd", "ab\ne", 1, 1 - 3 / 4), +] + +COMPLEX_CASES = [ + ("accc", "a\nbb\nccc", 1, 1 - 2 / 4), + ("aaa\nbbb\nccc", "bbb", 1, 1 - 6 / 9), +] + +EXTENDED_CASES = [ + # See figure 4 in 10.1016/j.patrec.2020.02.003 + # A: No errors + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), 1, 1), + # B: Different ordering of text blocks + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4), 1, 1), + # C: Merge across columns + ( + (0, 1, 2, 11, 3, 4, 11, 5, 6, 7, 11, 8, 9), + (0, 1, 2, 5, 6, 7, 11, 3, 4, 8, 9), + 1, + 0.964, + ), + # D: Over-segmentation + ( + (0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), + (0, 1, 2, 11, 5, 6, 7, 11, 3, 4, 11, 8, 9), + 1, + 0.966, + ), + # E: Part missing + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4), 1, 0.50), + # E.2: Part missing + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (5, 6, 7, 8, 9), 1, 0.50), + # F: All missing + ((0, 1, 2, 3, 4, 11, 5, 6, 7, 8, 9), (), 1, 0), + # G: Added parts + ((0, 1, 2, 3, 4), (0, 1, 2, 3, 4, 11, 5, 6), 1, 0.621), +] + +EDIT_ARGS = "gt,ocr,expected_dist" + +SIMPLE_EDITS = [ + (Part(text="a"), Part(text="a"), Distance(match=1)), + (Part(text="aaa"), Part(text="aaa"), Distance(match=3)), + ( + Part(text="abbbbcd"), + Part(text="bbbbede"), + Distance(match=5, replace=1, insert=1, delete=1), + ), +] + + +def extended_case_to_text(gt, ocr): + """Generate sentence from reading order encoding. + + See figure 4 in 10.1016/j.patrec.2020.02.003. 
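+
+    Example (illustrative; index 11 is the line break):
+
+    >>> extended_case_to_text((0, 11, 1), (1, 11, 0))
+    ('Eight\\nhappy', 'happy\\nEight')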
+ """ + sentence = ( + "Eight", + "happy", + "frogs", + "scuba", + "dived", + "Jenny", + "chick", + "flaps", + "white", + "wings", + "", + "\n", + ) + + gt_sentence = " ".join(sentence[i] for i in gt).replace(" \n ", "\n") + ocr_sentence = " ".join(sentence[i] for i in ocr).replace(" \n ", "\n") + return gt_sentence, ocr_sentence + + +@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) +def test_flexible_character_accuracy_str(gt, ocr, first_line_score, all_line_score): + score, _ = flexible_character_accuracy(gt, ocr, 1) + assert score == pytest.approx(all_line_score) + + +@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) +def test_flexible_character_accuracy_xml(gt, ocr, first_line_score, all_line_score): + def get_extracted_text(text: str): + xml = '' + ns = "http://schema.primaresearch.org/PAGE/gts/pagecontent/2018-07-15" + + textline_tmpl = ( + '{1}' + "" + ) + xml_tmpl = '{0}{2}' + + textlines = [ + textline_tmpl.format(i, line) for i, line in enumerate(text.splitlines()) + ] + xml_text = xml_tmpl.format(xml, ns, "".join(textlines)) + root = ET.fromstring(xml_text) + extracted_text = ExtractedText.from_text_segment( + root, {"page": ns}, textequiv_level="line" + ) + return extracted_text + + gt_text = get_extracted_text(gt) + ocr_text = get_extracted_text(ocr) + score, _ = flexible_character_accuracy(gt_text, ocr_text, 1) + assert score == pytest.approx(all_line_score) + + +@pytest.mark.parametrize( + "config,ocr", + [ + ( + "Config I", + "1 hav\nnospecial\ntalents.\n" + 'I am one\npassionate\ncuriousity."\n' + "Alberto\nEmstein", + ), + ( + "Config II", + "1 hav\nnospecial\ntalents. Alberto\n" + 'I am one Emstein\npassionate\ncuriousity."', + ), + ( + "Config III", + "Alberto\nEmstein\n" + "1 hav\nnospecial\ntalents.\n" + 'I am one\npassionate\ncuriousity."', + ), + ], +) +def test_flexible_character_accuracy(config, ocr): + """Tests from figure 3 in the 10.1016/j.patrec.2020.02.003.""" + gt = ( + '"I have\nno special\ntalent.\n' + 'I am only\npassionately\ncurious."\n' + "Albert\nEinstein" + ) + replacements, inserts, deletes = 3, 5, 7 + chars = len(gt) - gt.count("\n") + assert chars == 68 + + # We consider whitespace as error and in Config II two additional + # whitespaces have been introduced. One will be counted as insert. + # The other whitespace will be counted as replacement, + # additionally reducing the number of deletes. 
+ if config == "Config II": + inserts += 1 + replacements += 1 + deletes -= 1 + + expected_dist = Distance( + match=chars - deletes - replacements, + replace=replacements, + insert=inserts, + delete=deletes, + ) + expected_score = character_accuracy(expected_dist) + + result, matches = flexible_character_accuracy(gt, ocr, 1) + agg = reduce( + lambda acc, match: acc + Counter(match.dist._asdict()), matches, Counter() + ) + dist = Distance(**agg) + assert dist == expected_dist + assert result == pytest.approx(expected_score, abs=0.0005) + + +@pytest.mark.parametrize(CASE_ARGS, EXTENDED_CASES) +def test_flexible_character_accuracy_extended( + gt, ocr, first_line_score, all_line_score +): + """Tests from figure 4 in the 10.1016/j.patrec.2020.02.003.""" + gt_sentence, ocr_sentence = extended_case_to_text(gt, ocr) + result, _ = flexible_character_accuracy(gt_sentence, ocr_sentence, 1) + assert result == pytest.approx(all_line_score, abs=0.001) + + +@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES, *EXTENDED_CASES]) +def test_match_with_coefficients(gt, ocr, first_line_score, all_line_score): + coef = Coefficients() + if not isinstance(gt, str): + gt, ocr = extended_case_to_text(gt, ocr) + matches = match_with_coefficients(coef, gt, ocr) + score = character_accuracy_for_matches(matches) + assert score == pytest.approx(all_line_score, abs=0.001) + + +@pytest.mark.parametrize(CASE_ARGS, [*SIMPLE_CASES, *COMPLEX_CASES]) +def test_match_longest_gt_lines(gt, ocr, first_line_score, all_line_score): + coef = Coefficients() + gt_lines = initialize_lines(gt) + ocr_lines = initialize_lines(ocr) + match = match_longest_gt_lines(gt_lines, ocr_lines, coef) + score = 0 + if match: + score = character_accuracy(match.dist) + assert score == pytest.approx(first_line_score) + + +@pytest.mark.parametrize( + CASE_ARGS, + [ + *SIMPLE_CASES, + ("accc", "a\nbb\nccc", 1.0, 1.0), + ], +) +def test_match_gt_line(gt, ocr, first_line_score, all_line_score): + coef = Coefficients() + gt_lines = initialize_lines(gt) + ocr_lines = initialize_lines(ocr) + match, _ = match_gt_line(gt_lines[0], ocr_lines, coef) + score = 0 + if match: + score = character_accuracy(match.dist) + assert score == pytest.approx(first_line_score) + + +@pytest.mark.parametrize( + "original,match,expected_lines", + [ + (Part(), Part(), []), + (Part(text="abc"), Part(), [Part(text="abc")]), + (Part(text="abc"), Part("d"), [Part(text="bc", start=1)]), + (Part(text="abc"), Part("a", start=100), [Part(text="abc")]), + (Part(text="abc"), Part("a"), [Part(text="bc", start=1)]), + ( + Part(text="abc"), + Part("b", start=1), + [Part(text="a"), Part(text="c", start=2)], + ), + (Part(text="abc"), Part("c", start=2), [Part(text="ab")]), + ], +) +def test_remove_or_split(original, match, expected_lines): + lines = [original] + splitted = remove_or_split(original, match, lines) + assert splitted == (len(lines) > 0) + assert lines == expected_lines + + +@pytest.mark.parametrize( + EDIT_ARGS, + [ + *SIMPLE_EDITS, + (Part(text="a"), Part(text="b"), Distance(delete=1)), + (Part(text="ab"), Part(text="c"), Distance(delete=2)), + (Part(text="abc"), Part(text="d"), Distance(delete=3)), + (Part(text="aaa"), Part(text="bbb"), Distance(delete=3)), + (Part(text="aaabbbaaa"), Part(text="bbb"), Distance(match=3)), + (Part(text="bbb"), Part(text="aaabbbaaa"), Distance(match=3)), + (Part(text=""), Part(text=""), None), + (Part(text="abcd"), Part(text="acd"), Distance(match=3, delete=1)), + (Part(text="abc"), Part(text="abdc"), Distance(match=3, insert=1)), + ( + 
Part(text="aaabbbaaaddd"), + Part(text="aaabcbaaa"), + Distance(match=8, replace=1), + ), + ( + Part(text="aaabbbccc"), + Part(text="aaabbbdddccc"), + Distance(match=9, insert=3), + ), + ], +) +def test_match_lines(gt, ocr, expected_dist): + match = match_lines(gt, ocr) + if not expected_dist: + assert match is None + else: + assert match.gt.text in gt.text + assert match.ocr.text in ocr.text + assert match.dist == expected_dist + + +@pytest.mark.parametrize( + EDIT_ARGS, + [ + *SIMPLE_EDITS, + (Part(text="").substring(), Part(text=""), Distance()), + (Part(text="ab").substring(), Part("a"), Distance(match=1, delete=1)), + (Part(text="a").substring(), Part("ab"), Distance(match=1, insert=1)), + (Part(text="a"), Part(text="b"), Distance(replace=1)), + (Part(text="aaa"), Part(text="bbb"), Distance(replace=3)), + ], +) +def test_distance(gt, ocr, expected_dist): + match = distance(gt, ocr) + assert match.gt == gt + assert match.ocr == ocr + assert match.dist == expected_dist + + +@pytest.mark.parametrize( + "matches,expected_dist", + [ + ([], 1), + ([Match(gt=Part(text=""), ocr=Part(text=""), dist=Distance(), ops=[])], 1), + ( + [ + Match( + gt=Part(text="abee"), + ocr=Part("ac"), + dist=Distance(match=1, replace=1, delete=2), + ops=[], + ), + Match( + gt=Part(text="cd"), + ocr=Part("ceff"), + dist=Distance(match=1, replace=1, insert=2), + ops=[], + ), + ], + 1 - 6 / 6, + ), + ], +) +def test_character_accuracy_matches(matches, expected_dist): + assert character_accuracy_for_matches(matches) == pytest.approx(expected_dist) + + +@pytest.mark.parametrize( + "dist,expected_dist", + [ + (Distance(), 1), + (Distance(match=1), 1), + (Distance(replace=1), 0), + (Distance(delete=1), 0), + (Distance(insert=1), -1), + (Distance(match=1, insert=1), 0), + (Distance(match=1, insert=2), 1 - 2 / 1), + (Distance(match=2, insert=1), 0.5), + (Distance(match=1, delete=1), 0.5), + ], +) +def test_character_accuracy_dist(dist, expected_dist): + assert character_accuracy(dist) == pytest.approx(expected_dist) + + +@pytest.mark.parametrize( + "line,subline,expected_rest", + [ + (Part(), Part(), []), + (Part("aaa bbb"), Part("aaa bbb"), []), + (Part("aaa bbb"), Part("aaa"), [Part(" bbb", start=3)]), + (Part("aaa bbb"), Part("bbb", start=4), [Part("aaa ")]), + (Part("aaa bbb", start=3), Part("aaa", start=3), [Part(" bbb", start=6)]), + (Part("aaa bbb", start=3), Part("bbb", start=7), [Part("aaa ", start=3)]), + ( + Part("aaa bbb ccc"), + Part("bbb", start=4), + [Part("aaa "), Part(" ccc", start=7)], + ), + ( + Part("aaa bbb ccc", start=3), + Part("bbb", start=7), + [Part("aaa ", start=3), Part(" ccc", start=10)], + ), + (Part("aaa bbb"), Part(" ", start=3), [Part("aaa"), Part("bbb", start=4)]), + ( + Part("aaa bbb", start=3), + Part(" ", start=6), + [Part("aaa", start=3), Part("bbb", start=7)], + ), + ], +) +def test_split_line(line, subline, expected_rest): + rest = line.split(subline) + assert len(rest) == len(expected_rest) + assert set(rest) == set(expected_rest) + + +def test_initialize_lines(): + lines = initialize_lines("") + assert lines == [] + + lines = initialize_lines("22\n1\n333") + line1 = Part(text="22", line=0, start=0) + line2 = Part("1", line=1, start=0) + line3 = Part("333", line=2, start=0) + assert lines == [line3, line1, line2] + + +@pytest.mark.parametrize( + "matches,expected_gt,expected_ocr,expected_ops", + [ + ([], [], [], []), + ( + [Match(gt=Part(text="aaa"), ocr=Part(text="aaa"), dist=Distance(), ops=[])], + ["aaa"], + ["aaa"], + [[]], + ), + ( + [ + Match( + gt=Part(text="aaa", line=1), + 
ocr=Part(text="aaa"), + dist=Distance(), + ops=[], + ), + Match( + gt=Part(text="bbb", line=2), + ocr=Part(text="bbc"), + dist=Distance(), + ops=[["replace", 2]], + ), + ], + ["\n", "aaa", "\n", "bbb"], + ["\n", "aaa", "\n", "bbc"], + [[], [], [], [["replace", 2]]], + ), + ], +) +def test_split_matches(matches, expected_gt, expected_ocr, expected_ops): + gt_segments, ocr_segments, ops = split_matches(matches) + assert gt_segments == expected_gt + assert ocr_segments == expected_ocr + assert ops == expected_ops + + +@pytest.mark.parametrize( + "line,start,end,expected", + [ + (Part(text=""), 0, None, Part(text="")), + (Part(text="a"), 0, None, Part(text="a")), + (Part(text="ab"), 0, 1, Part(text="a")), + (Part(text="abc"), 0, -1, Part(text="ab")), + (Part(text="ab"), 1, None, Part(text="b", start=1)), + ], +) +def test_line_substring(line, start, end, expected): + assert line.substring(rel_start=start, rel_end=end) == expected diff --git a/qurator/dinglehopper/tests/test_integ_cli_valid_json.py b/qurator/dinglehopper/tests/test_integ_cli_valid_json.py index 9d52329..1092a92 100644 --- a/qurator/dinglehopper/tests/test_integ_cli_valid_json.py +++ b/qurator/dinglehopper/tests/test_integ_cli_valid_json.py @@ -1,4 +1,5 @@ import json +from itertools import combinations import pytest from .util import working_directory @@ -7,9 +8,19 @@ @pytest.mark.integration -def test_cli_json(tmp_path): +@pytest.mark.parametrize( + "metrics", + [ + *(("",), ("cer",), ("wer",), ("fca",)), + *combinations(("cer", "wer", "fca"), 2), + ("cer", "wer", "fca"), + ], +) +def test_cli_json(metrics, tmp_path): """Test that the cli/process() yields a loadable JSON report""" + expected_values = {"cer": 0.2, "wer": 1.0, "fca": 0.8} + with working_directory(str(tmp_path)): with open("gt.txt", "w") as gtf: gtf.write("AAAAA") @@ -18,12 +29,18 @@ def test_cli_json(tmp_path): with open("gt.txt", "r") as gtf: print(gtf.read()) - process("gt.txt", "ocr.txt", "report") + + process("gt.txt", "ocr.txt", "report", metrics=",".join(metrics)) + with open("report.json", "r") as jsonf: print(jsonf.read()) with open("report.json", "r") as jsonf: j = json.load(jsonf) - assert j["cer"] == pytest.approx(0.2) + for metric, expected_value in expected_values.items(): + if metric in metrics: + assert j[metric] == pytest.approx(expected_values[metric]) + else: + assert metric not in j.keys() @pytest.mark.integration @@ -36,7 +53,23 @@ def test_cli_json_cer_is_infinity(tmp_path): with open("ocr.txt", "w") as ocrf: ocrf.write("Not important") - process("gt.txt", "ocr.txt", "report") + process("gt.txt", "ocr.txt", "report", metrics="cer,wer,fca") with open("report.json", "r") as jsonf: j = json.load(jsonf) assert j["cer"] == pytest.approx(float("inf")) + assert j["fca"] == pytest.approx(-13) + + +def test_cli_json_cer_0_in_report(tmp_path): + """Test that the cli/process() yields a loadable JSON report when CER == 0""" + + with working_directory(str(tmp_path)): + with open("gt.txt", "w") as gtf: + gtf.write("Lorem Ipsum") + + process("gt.txt", "gt.txt", "report", metrics="cer,wer,fca") + with open("report.json", "r") as jsonf: + j = json.load(jsonf) + assert j["cer"] == pytest.approx(0) + assert j["wer"] == pytest.approx(0) + assert j["fca"] == pytest.approx(1) diff --git a/qurator/dinglehopper/tests/test_integ_flexible_character_accuracy.py b/qurator/dinglehopper/tests/test_integ_flexible_character_accuracy.py new file mode 100644 index 0000000..f7299bd --- /dev/null +++ b/qurator/dinglehopper/tests/test_integ_flexible_character_accuracy.py @@ 
-0,0 +1,69 @@ +import os + +import pytest +from lxml import etree as ET + +from .. import distance, page_text, extract +from .. import flexible_character_accuracy, split_matches + + +@pytest.mark.parametrize("file", ["table-order-0002.xml", "table-no-reading-order.xml"]) +@pytest.mark.integration +def test_fac_ignoring_reading_order(file): + data_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "data", "table-order" + ) + expected = "1\n2\n3\n4\n5\n6\n7\n8\n9" + + gt = page_text(ET.parse(os.path.join(data_dir, "table-order-0001.xml"))) + assert gt == expected + + ocr = page_text(ET.parse(os.path.join(data_dir, file))) + assert distance(gt, ocr) > 0 + + fac, matches = flexible_character_accuracy(gt, ocr) + assert fac == pytest.approx(1.0) + + gt_segments, ocr_segments, ops = split_matches(matches) + assert not any(ops) + assert "".join(gt_segments) == expected + assert "".join(ocr_segments) == expected + + +@pytest.mark.integration +@pytest.mark.parametrize( + "gt,ocr,expected", + [ + ( + "brochrnx_73075507X/00000139.gt.page.xml", + "brochrnx_73075507X/00000139.ocrd-tess.ocr.page.xml", + 0.93, + ), + ( + "actevedef_718448162/OCR-D-GT-PAGE/00000024.page.xml", + "actevedef_718448162/OCR-D-OCR-TESS/OCR-D-OCR-TESS_0001.xml", + 0.96, + ), + ( + "actevedef_718448162/OCR-D-GT-PAGE/00000024.page.xml", + "actevedef_718448162/OCR-D-OCR-CALAMARI/OCR-D-OCR-CALAMARI_0001.xml", + 0.97, + ), + ( + "lorem-ipsum/lorem-ipsum-scan.gt.page.xml", + "lorem-ipsum/lorem-ipsum-scan.ocr.tesseract.alto.xml", + 1.0, + ), + ( + "lorem-ipsum/lorem-ipsum-scan-bad.gt.page.xml", + "lorem-ipsum/lorem-ipsum-scan-bad.ocr.tesseract.alto.xml", + 0.98, + ), + ], +) +def test_ocr_files(gt, ocr, expected): + data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") + gt_et = extract(os.path.join(data_dir, gt)) + ocr_et = extract(os.path.join(data_dir, ocr)) + score, _ = flexible_character_accuracy(gt_et, ocr_et) + assert score == pytest.approx(expected, abs=0.01) diff --git a/requirements.txt b/requirements.txt index 7bb53ac..99172c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ ocrd >= 2.20.1 attrs multimethod == 1.3 # latest version to officially support Python 3.5 tqdm +python-levenshtein
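
As a quick end-to-end illustration (a sketch mirroring the unit tests above, not part of the patch), the new metric and the match splitting can be exercised directly:

~~~
from qurator.dinglehopper import flexible_character_accuracy, split_matches

if __name__ == "__main__":
    # Same lines, different reading order: FCA scores 1.0 even though a
    # character error rate over the raw texts would be high.
    gt = "Eight happy frogs\nscuba dived"
    ocr = "scuba dived\nEight happy frogs"
    score, matches = flexible_character_accuracy(gt, ocr, 1)  # n_cpu=1: main process
    print(score)  # 1.0

    # The matches can be re-serialized in ground-truth order for diff reports.
    gt_segments, ocr_segments, ops = split_matches(matches)
    print("".join(gt_segments))  # "Eight happy frogs\nscuba dived"
~~~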