Skip to content

Commit

Permalink
Refactor basic pattern recognizer
Browse files Browse the repository at this point in the history
  • Loading branch information
rolczynski committed Oct 29, 2020
1 parent ad089fb commit 59f05b7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
26 changes: 13 additions & 13 deletions aspect_based_sentiment_analysis/aux_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Set
from typing import Tuple
from dataclasses import dataclass
from dataclasses import field

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -52,27 +51,28 @@ class BasicReferenceRecognizer(ReferenceRecognizer, PretrainedConfig):
"""
The Basic Reference Recognizer predicts whether a text relates to an
aspect or not. Briefly, it represents a text and an aspect as two
vectors, and predicts that a text relates to an aspect if the cosine
similarity is bigger than a threshold. It calculates text and aspect
vectors, measures the cosine similarity between them, and then uses a simple
logistic regression to make a prediction. It calculates text and aspect
representations by summing their subtoken vectors, context-independent
embeddings that come from the embedding first layer.
This model has only one parameter, nonetheless, we show how to take a use
of the methods `save_pretrained` and `load_pretrained`. They are useful
especially for more complex models.
embeddings that come from the first embedding layer. This model has two
parameters (β_0, β_1). Use the `save_pretrained` and `from_pretrained`
methods to persist the model and to load it back for future use.
"""
threshold: float
model_type: str = field(default='reference_recognizer')
weights: Tuple[float, float]
model_type: str = 'reference_recognizer'

def __call__(
self,
example: TokenizedExample,
output: Output
) -> bool:
β_0, β_1 = self.weights
n = len(example.subtokens)
hidden_states = output.hidden_states[:, :n, :] # Trim padded tokens.
text_mask, aspect_mask = self.text_aspect_subtoken_masks(example)
similarity = self.transform(output.hidden_states, text_mask, aspect_mask)
is_reference = bool(similarity > self.threshold)
return is_reference
similarity = self.transform(hidden_states, text_mask, aspect_mask)
is_reference = β_0 + β_1 * similarity > 0
return bool(is_reference) # Do not use the numpy bool object.

@staticmethod
def transform(
Expand Down
7 changes: 5 additions & 2 deletions tests/absa/test_aux_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def test_basic_reference_recognizer():
text = 'the automobile is so cool and the service is prompt and curious.'
examples = [Example(text, 'breakfast'), Example(text, 'service'), Example(text, 'car')]
recognizer = BasicReferenceRecognizer(threshold=0.08)
recognizer = BasicReferenceRecognizer(weights=(-0.025, 44))
nlp = absa.load('absa/classifier-rest-0.2', reference_recognizer=recognizer)
predictions = nlp.transform(examples)
prediction_1, prediction_2, prediction_3 = predictions
Expand All @@ -28,7 +28,10 @@ def test_basic_reference_recognizer():
def test_basic_reference_recognizer_from_pretrained():
name = 'absa/basic_reference_recognizer-rest-0.1'
recognizer = BasicReferenceRecognizer.from_pretrained(name)
assert recognizer.threshold == 0.08
assert np.allclose(recognizer.weights, [-0.024, 44.443], atol=0.001)
name = 'absa/basic_reference_recognizer-lapt-0.1'
recognizer = BasicReferenceRecognizer.from_pretrained(name)
assert np.allclose(recognizer.weights, [-0.175, 40.165], atol=0.001)


def test_basic_reference_recognizer_transform():
Expand Down

0 comments on commit 59f05b7

Please sign in to comment.