Skip to content

Commit

Permalink
chore: analyze language identification models performance on short in…
Browse files Browse the repository at this point in the history
…gredient texts with precision-recall evaluation (#349) (#365)

* analyze language identification models performance on
short ingredient texts with precision-recall evaluation (#349)

* split script 03_calculate_metrics into inference and metrics

* chore: add new label for language identification in labeler.yml

* Update language_identification/scripts/01_extract_data.py

---------

Co-authored-by: Yulia Zhilyaeva <[email protected]>
Co-authored-by: Raphaël Bournhonesque <[email protected]>
  • Loading branch information
3 people authored Dec 12, 2024
1 parent 7bd75a3 commit 405eb1d
Show file tree
Hide file tree
Showing 12 changed files with 3,617 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ logo-ann:
student projects:
- ai-emlyon/**/*

language identification:
- language_identification/**/*
1,981 changes: 1,981 additions & 0 deletions language_identification/poetry.lock

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions language_identification/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[tool.poetry]
name = "language_identification"
version = "0.1.0"
description = ""
authors = ["Yulia Zhilyaeva"]
readme = "./scripts/README.md"
package-mode = false

[tool.poetry.dependencies]
python = "^3.12"
pandas = "^2.2.3"
tqdm = "^4.67.0"
datasets = "^3.1.0"
joblib = "^1.4.2"
openfoodfacts = "^2.1.0"
polyglot = "^16.7.4"
numpy = "^2.1.3"
langcodes = "^3.4.1"
scikit-learn = "^1.5.2"
fasttext = "^0.9.3"
lingua-language-detector = "^2.0.2"
huggingface-hub = "^0.26.2"
pyicu = "^2.14"
pycld2 = "^0.41"
morfessor = "^2.0.6"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
28 changes: 28 additions & 0 deletions language_identification/scripts/01_extract_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os
from pathlib import Path
import pandas as pd
from datasets import load_dataset


def main():
# extracting all texts with their languages from huggingface dataset

# path where to save selected data
dataset_file = os.path.join(Path(__file__).parent, "texts_with_lang.csv")

hf_dataset = load_dataset("openfoodfacts/product-database", split="food")

data = set()
for entry, main_lang in zip(hf_dataset["ingredients_text"], hf_dataset["lang"]): # iterate over products
for product_in_lang in entry:
if product_in_lang["text"]:
lang = main_lang if product_in_lang["lang"] == "main" else product_in_lang["lang"]
data.add((product_in_lang["text"], lang))

df = pd.DataFrame(data, columns=["ingredients_text", "lang"])
df.dropna(inplace=True)
df.to_csv(dataset_file, index=False)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
from pathlib import Path
import unicodedata
import pandas as pd
from tqdm import tqdm
from openfoodfacts import API, APIVersion, Environment
from polyglot.text import Text


def is_punctuation(word: str) -> bool:
"""
Check if the string `word` is a punctuation mark.
"""
return all(unicodedata.category(char).startswith("P") for char in word)


def select_texts_by_len(data: pd.DataFrame, min_len: int, max_len: int) -> pd.DataFrame:
"""
Select rows from `data` where the number of words in `ingredients_text`
is between `min_len` and `max_len` inclusively, excluding punctuation.
Args:
data: pandas.DataFrame with columns `ingredients_text`, `lang`
min_len: Minimum number of words.
max_len: Maximum number of words.
Returns:
A pandas DataFrame containing the rows from the `data` that satisfy the word count condition.
"""
selected_rows = []

for _, row in tqdm(data.iterrows(), total=len(data)):
# the object that recognizes individual words in text
text = Text(row.ingredients_text, hint_language_code=row.lang)
# `Text` recognizes punctuation marks as words
words = [word for word in text.words if not is_punctuation(word)]
if min_len <= len(words) <= max_len:
selected_rows.append(row)

selected_df = pd.DataFrame(selected_rows)
return selected_df


def main():
dataset_file = os.path.join(Path(__file__).parent, "texts_with_lang.csv")
all_data = pd.read_csv(dataset_file)

short_texts_df = select_texts_by_len(all_data, min_len=0, max_len=10)

# perform ingredient analysis
api = API(user_agent="langid",
version=APIVersion.v3,
environment=Environment.net)

threshold = 0.8
texts_with_known_ingredients = []

# select ingredients texts with the rate of known ingredients >= `threshold`
for i, row in tqdm(short_texts_df.iterrows(), total=len(short_texts_df)):
try:
ingredient_analysis_results = api.product.parse_ingredients(row.ingredients_text, lang=row.lang)
except RuntimeError:
continue

is_in_taxonomy = sum(dct.get("is_in_taxonomy", 0) for dct in ingredient_analysis_results)
is_in_taxonomy_rate = is_in_taxonomy / len(ingredient_analysis_results) \
if len(ingredient_analysis_results) > 0 else 0.

if is_in_taxonomy_rate >= threshold:
texts_with_known_ingredients.append(row)

texts_with_known_ingredients_df = pd.DataFrame(texts_with_known_ingredients)

# add short texts from manually checked data
manually_checked_data = pd.read_csv(os.path.join(Path(__file__).parent, "manually_checked_data.csv"))
short_texts_manual = select_texts_by_len(manually_checked_data, min_len=0, max_len=10)

# combine data and save
all_texts_under_10_words = pd.concat((short_texts_manual, texts_with_known_ingredients_df), ignore_index=True)
all_texts_under_10_words.drop_duplicates(inplace=True, ignore_index=True)
all_texts_under_10_words.to_csv(os.path.join(Path(__file__).parent, "texts_under_10_words.csv"), index=False)


if __name__ == "__main__":
main()
95 changes: 95 additions & 0 deletions language_identification/scripts/03_calculate_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import trange
from typing import Sequence

from sklearn.metrics import confusion_matrix

from language_identification.scripts.inference import calculate_fasttext_predictions, calculate_lingua_predictions


def replace_lang_code(model_predictions: list[str], mapping: dict[str, str]) -> None:
"""
Replace predicted language codes in the `model_predictions` list
using `mapping`, where:
- Keys represent the original language codes (predicted by the model)
- Values represent the target language codes to replace them with.
The purpose of this function is to standardize language codes
by combining multiple variants of the same language into a unified code
in order to match supported languages.
"""
for i in trange(len(model_predictions)):
if model_predictions[i] in mapping:
model_predictions[i] = mapping[model_predictions[i]]


def calculate_metrics(cm: np.ndarray, labels: Sequence, model_name: str) -> pd.DataFrame:
"""
Calculate precision, recall and f1-score.
Args:
cm: confusion matrix
labels: languages, for which the metrics need to be calculated
model_name: model name (needed for column names in DataFrame)
Returns: pandas.DataFrame with computed metrics for each language.
"""
tp_and_fn = cm.sum(axis=1)
tp_and_fp = cm.sum(axis=0)
tp = cm.diagonal()

precision = np.divide(tp, tp_and_fp, out=np.zeros_like(tp, dtype=float), where=tp_and_fp > 0)
recall = np.divide(tp, tp_and_fn, out=np.zeros_like(tp, dtype=float), where=tp_and_fn > 0)
f1 = np.divide(2 * precision * recall, precision + recall, out=np.zeros_like(precision, dtype=float),
where=(precision + recall) > 0)

df = pd.DataFrame({
"lang": labels,
f"{model_name}_precision": precision,
f"{model_name}_recall": recall,
f"{model_name}_f1": f1,
})

return df


def main():
texts_under_10_words = pd.read_csv(os.path.join(Path(__file__).parent, "texts_under_10_words.csv"))
texts = texts_under_10_words.ingredients_text.tolist()
true_labels = texts_under_10_words.lang.tolist()
possible_class_labels = texts_under_10_words["lang"].value_counts().index.tolist() # use value_counts in order to get sorted by frequency list

fasttext_preds = calculate_fasttext_predictions(texts)
lingua_preds = calculate_lingua_predictions(texts)

mapping = {"yue": "zh"} # yue is a type of Chinese
replace_lang_code(fasttext_preds, mapping)
replace_lang_code(lingua_preds, mapping)

predictions = [fasttext_preds, lingua_preds]
model_names = ["fasttext", "lingua"]
metrics = []
for preds, model_name in zip(predictions, model_names):
cm = confusion_matrix(true_labels, preds, labels=possible_class_labels)
cm_df = pd.DataFrame(cm, index=possible_class_labels, columns=possible_class_labels)
cm_df.to_csv(os.path.join(Path(__file__).parent, f"{model_name}_confusion_matrix.csv"))

metrics_df = calculate_metrics(cm, possible_class_labels, model_name)
metrics_df.set_index("lang", inplace=True)
metrics.append(metrics_df)

# combine results
metrics_df = pd.DataFrame(texts_under_10_words.lang.value_counts())
metrics_df = pd.concat((metrics_df, *metrics), axis=1)

# change columns order
metrics_df = metrics_df[
["count"] + [f"{model}_{metric}" for metric in ["precision", "recall", "f1"] for model in model_names]
]

metrics_df.to_csv(os.path.join(Path(__file__).parent, "10_words_metrics.csv"))

if __name__ == "__main__":
main()
42 changes: 42 additions & 0 deletions language_identification/scripts/10_words_metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
lang,count,fasttext_precision,lingua_precision,fasttext_recall,lingua_recall,fasttext_f1,lingua_f1
en,36319,0.937642,0.987990,0.881476,0.934366,0.908692,0.960430
fr,33759,0.980412,0.989449,0.828173,0.931849,0.897886,0.959786
de,15540,0.971413,0.978929,0.834112,0.954839,0.897542,0.966734
es,6010,0.843908,0.929216,0.855058,0.799721,0.849447,0.859619
it,4098,0.665311,0.826041,0.878067,0.912855,0.757025,0.867281
nl,3877,0.700655,0.905875,0.870591,0.925161,0.776434,0.915417
pl,1350,0.807497,0.911485,0.962879,0.978790,0.878369,0.943939
sv,1022,0.627437,0.833021,0.869000,0.949733,0.728721,0.887556
pt,839,0.543636,0.426197,0.838710,0.905367,0.659680,0.579566
fi,748,0.531029,0.565506,0.927298,0.982709,0.675325,0.717895
bg,712,0.982072,0.986667,0.834179,0.951768,0.902104,0.968903
hr,518,0.459313,0.883212,0.666667,0.889706,0.543897,0.886447
nb,376,0.371212,0.196429,0.710145,0.416667,0.487562,0.266990
ru,362,0.876877,0.974763,0.926984,0.916914,0.901235,0.944954
hu,235,0.402367,0.682119,0.886957,0.911504,0.553596,0.780303
cs,212,0.566038,0.811881,0.568720,0.828283,0.567376,0.820000
ro,163,0.334239,0.275395,0.836735,0.853147,0.477670,0.416382
da,115,0.211628,0.127152,0.834862,0.950495,0.337662,0.224299
ca,90,0.052419,0.116883,0.822785,0.805970,0.098560,0.204159
ja,60,1.000000,1.000000,0.305556,0.633333,0.468085,0.775510
ar,44,0.854167,0.974359,0.953488,0.926829,0.901099,0.950000
lv,34,0.198582,0.337079,0.848485,0.937500,0.321839,0.495868
zh,25,0.063830,0.531915,0.187500,1.000000,0.095238,0.694444
lt,23,0.068182,0.147887,0.954545,0.913043,0.127273,0.254545
el,22,1.000000,1.000000,0.954545,1.000000,0.976744,1.000000
tr,17,0.066667,0.121495,0.750000,0.928571,0.122449,0.214876
sk,16,0.081250,0.130000,0.812500,0.812500,0.147727,0.224138
uk,15,0.379310,0.423077,0.785714,0.846154,0.511628,0.564103
th,14,1.000000,0.928571,0.857143,0.928571,0.923077,0.928571
sl,14,0.062500,0.159420,0.833333,0.916667,0.116279,0.271605
he,13,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
bn,9,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
is,8,0.160000,0.112903,1.000000,1.000000,0.275862,0.202899
sr,8,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
no,7,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
et,5,0.005822,0.010471,0.800000,0.800000,0.011561,0.020672
id,5,0.015228,0.000000,0.750000,0.000000,0.029851,0.000000
ko,5,0.047059,1.000000,0.800000,1.000000,0.088889,1.000000
sq,3,0.000000,0.015385,0.000000,1.000000,0.000000,0.030303
ka,2,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
vi,1,0.001050,0.200000,1.000000,1.000000,0.002099,0.333333
16 changes: 16 additions & 0 deletions language_identification/scripts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Research of the quality of models on ingredient texts up to 10 words long.

`01_extract_data.py`: extracts all texts with their languages from [huggingface dataset](https://huggingface.co/datasets/openfoodfacts/product-database).

`02_select_short_texts_with_known_ingredients.py`: filters texts with length up to 10 words, performs ingredient analysis by OFF API, selects ingredient texts with at least 80% of known ingredients, adds short texts from manually checked data.

What is manually checked data: \
I created a validation dataset from texts from OFF (42 languages, 15-30 texts per language).
I took 30 random texts in each language, obtained language predictions using the Deepl API and two other models ([language-detection-fine-tuned-on-xlm-roberta-base](https://huggingface.co/ivanlau/language-detection-fine-tuned-on-xlm-roberta-base) and [multilingual-e5-language-detection](https://huggingface.co/Mike0307/multilingual-e5-language-detection)). For languages they don’t support, I used Google Translate and ChatGPT for verification. (As a result, after correcting the labels, some languages have fewer than 30 texts).


`03_calculate_metrics.py`: obtains predictions by [FastText](https://huggingface.co/facebook/fasttext-language-identification) and [lingua language detector](https://github.com/pemistahl/lingua-py) models for texts up to 10 words long, and calculates precision, recall and f1-score.

Results are in files: [metrics](./10_words_metrics.csv), [FastText confusion matrix](./fasttext_confusion_matrix.csv), [lingua confusion matrix](./lingua_confusion_matrix.csv).

It turned out that both models demonstrate low precision and high recall for some languages (indicating that the threshold might be too high and should be adjusted).
42 changes: 42 additions & 0 deletions language_identification/scripts/fasttext_confusion_matrix.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
,en,fr,de,es,it,nl,pl,sv,pt,fi,bg,hr,nb,ru,hu,cs,ro,da,ca,ja,ar,lv,zh,lt,el,tr,sk,uk,th,sl,he,bn,is,sr,no,et,id,ko,sq,ka,vi
en,31050,308,79,414,1092,353,119,161,23,188,0,51,72,0,67,19,116,77,305,0,0,25,25,22,0,47,2,0,0,5,0,0,11,0,0,248,29,45,8,0,264
fr,1406,25627,120,246,361,402,77,80,266,211,0,167,175,0,53,33,82,30,678,0,1,37,13,184,0,48,26,0,0,9,0,0,6,0,0,116,104,19,4,0,363
de,284,40,12505,30,170,575,29,219,46,110,0,52,92,0,127,4,5,153,42,0,0,22,3,14,0,19,9,0,0,48,0,0,8,0,0,252,29,11,4,0,90
es,112,56,15,4336,70,11,28,9,113,27,0,4,7,0,7,11,5,0,79,0,0,8,0,32,0,15,21,0,0,17,0,0,6,0,0,3,6,1,0,0,72
it,71,19,48,42,3435,9,12,10,42,22,0,2,10,0,3,7,30,8,37,0,0,4,0,6,0,0,2,0,0,5,0,0,0,0,0,3,8,3,0,0,74
nl,85,70,74,19,8,3209,16,11,4,19,0,1,23,0,7,7,2,26,11,0,0,6,0,8,0,8,2,0,0,2,0,0,2,0,0,35,3,0,1,0,27
pl,9,0,2,1,0,0,1271,1,1,1,0,3,0,0,12,0,0,0,2,0,0,4,0,0,0,10,1,0,0,0,0,0,0,0,0,0,1,0,0,0,1
sv,17,6,3,0,2,3,3,869,1,9,0,7,11,0,15,0,2,7,1,0,0,3,0,2,0,6,0,0,0,1,0,0,5,0,0,9,0,0,5,0,13
pt,12,6,7,37,4,1,1,0,598,2,0,0,3,0,0,2,3,0,13,0,0,0,0,3,0,3,4,0,0,4,0,0,0,0,0,0,0,0,0,0,10
fi,0,0,1,1,4,4,0,1,1,676,0,7,1,0,3,0,0,5,0,0,0,2,0,0,0,3,1,0,0,1,0,0,0,0,0,14,1,0,0,0,3
bg,15,0,3,0,0,0,0,4,0,0,493,0,0,41,0,2,0,0,0,0,1,0,0,0,0,0,0,10,0,0,0,0,0,19,0,0,0,2,0,0,1
hr,8,0,2,6,5,1,9,6,0,3,0,254,4,0,0,1,0,0,2,0,0,1,0,12,0,0,2,0,0,49,0,0,1,0,0,2,5,0,1,0,7
nb,11,5,3,0,5,7,1,7,2,2,0,3,245,0,3,0,0,29,0,0,0,0,0,1,0,6,0,0,0,1,0,0,3,0,0,1,6,0,0,0,4
ru,2,0,0,0,0,0,0,0,0,0,8,0,1,292,0,0,0,0,0,0,1,0,0,0,0,0,0,8,0,0,0,0,0,2,0,0,0,0,0,0,1
hu,3,1,0,1,0,1,3,0,1,0,0,1,0,0,204,4,0,0,1,0,0,0,0,0,0,0,9,0,0,0,0,0,0,0,0,0,0,0,1,0,0
cs,1,0,3,0,0,2,2,0,1,0,0,1,0,0,4,120,0,1,0,0,0,0,0,1,0,1,68,0,0,5,0,0,0,0,0,0,0,0,0,0,1
ro,3,1,2,0,3,1,0,1,1,1,0,0,1,0,0,0,123,2,4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,2
da,2,0,1,0,1,0,0,2,0,0,0,0,9,0,0,0,0,91,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
ca,3,0,0,2,3,0,1,1,0,2,0,0,0,0,0,0,0,0,65,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2
ja,10,0,4,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,11,1,0,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6
ar,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,41,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
lv,1,0,0,2,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,28,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
zh,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5
lt,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,21,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
el,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,21,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
tr,1,0,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,12,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
sk,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,13,0,0,0,0,0,0,0,0,0,0,0,0,0,0
uk,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,11,0,0,0,0,0,1,0,0,0,0,0,0,0
th,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,12,0,0,0,0,0,0,0,0,0,0,0,0
sl,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,10,0,0,0,0,0,0,0,0,0,0,0
he,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,13,0,0,0,0,0,0,0,0,0,0
bn,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0
is,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,0,0,0,0,0,0,0,0
sr,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,0,0,1
no,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
et,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,0
id,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,0,0
ko,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,0,0,1
sq,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
ka,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0
vi,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
Loading

0 comments on commit 405eb1d

Please sign in to comment.