Skip to content

Commit

Permalink
#285 do not recreate all components every time + optimize document collection filter
Browse files Browse the repository at this point in the history
  • Loading branch information
HermannKroll committed Nov 22, 2024
1 parent 507dc09 commit 0e9afc9
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 186 deletions.
11 changes: 7 additions & 4 deletions src/narraint/frontend/ui/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from narraint.queryengine.result import QueryDocumentResult, QueryDocumentResultList
from narraint.ranking.corpus import DocumentCorpus
from narraint.ranking.indexed_document import IndexedDocument
from narraint.recommender.recommendation import apply_recommendation
from narraint.recommender.recommendation import RecommendationSystem
from narrant.entity.entityresolver import EntityResolver

logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
Expand Down Expand Up @@ -76,6 +76,7 @@ def __new__(cls):
cls.explainer = EntityExplainer()
cls.keyword2graph = Keyword2GraphTranslation()
cls.corpus = DocumentCorpus()
cls.recommender = RecommendationSystem()
return cls._instance


Expand Down Expand Up @@ -654,7 +655,8 @@ def get_query(request):
results, cache_hit, time_needed = do_query_processing_with_caching(graph_query, document_collection)
result_ids = {r.document_id for r in results}
opt_query = QueryOptimizer.optimize_query(graph_query)
View().query_logger.write_query_log(time_needed, "-".join(sorted(document_collection)), cache_hit, len(result_ids),
View().query_logger.write_query_log(time_needed, "-".join(sorted(document_collection)), cache_hit,
len(result_ids),
query, opt_query)

results = TitleFilter.filter_documents(results, title_filter)
Expand Down Expand Up @@ -1473,13 +1475,15 @@ def get_data_sources(request):
except Exception:
return HttpResponse(status=500)


def get_classifications(request):
try:
available_classifications = ClassificationFilter.get_available_classifications()
return JsonResponse(status=200, data=dict(classifications=available_classifications))
except Exception:
return HttpResponse(status=500)


def get_recommend(request):
results_converted = []
is_aggregate = False
Expand All @@ -1498,7 +1502,6 @@ def get_recommend(request):
if not request.GET.keys():
return HttpResponse(status=500)


document_id = int(request.GET.get("query", ""))
query_collection = request.GET.get("query_col", "")
query_trans_string = document_id
Expand Down Expand Up @@ -1553,7 +1556,7 @@ def get_recommend(request):
valid_query = True
logging.info(f'Requested recommendation for document id: {document_id}')

json_data = apply_recommendation(document_id, query_collection, document_collections, View().corpus)
json_data = View().recommender.apply_recommendation(document_id, query_collection, document_collections)

results = []
graph_data = dict()
Expand Down
12 changes: 8 additions & 4 deletions src/narraint/recommender/first_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ class FirstStage:
Previously called FSConceptFlex, now default first stage implementation.
"""

def __init__(self, extractor: NarrativeCoreExtractor, document_collections: set):
def __init__(self, extractor: NarrativeCoreExtractor):
self.extractor = extractor
self.document_collections = document_collections
self.document_collections = None
self.session = SessionExtended.get()

def retrieve_documents_for(self, document: RecommenderDocument):
def retrieve_documents_for(self, document: RecommenderDocument, document_collections: [str]):
self.document_collections = list(document_collections)
# Compute the cores
core = self.extractor.extract_concept_core(document)

Expand All @@ -38,7 +39,10 @@ def retrieve_documents(self, concept: str, concept_type: str):
# Search for matching nodes but not for predicates (ignore direction)
q = q.filter(TagInvertedIndex.entity_id == concept)
q = q.filter(TagInvertedIndex.entity_type == concept_type)
q = q.filter(TagInvertedIndex.document_collection.in_(self.document_collections))
if len(self.document_collections) == 1:
q = q.filter(TagInvertedIndex.document_collection == self.document_collections[0])
else:
q = q.filter(TagInvertedIndex.document_collection.in_(self.document_collections))
document_ids = set()
for row in q:
document_ids.update(TagInvertedIndex.prepare_document_ids(row.document_ids))
Expand Down
Loading

0 comments on commit 0e9afc9

Please sign in to comment.