diff --git a/webapp/api/api/admin/__init__.py b/webapp/api/api/admin/__init__.py index c7036e94..ba61bf3d 100644 --- a/webapp/api/api/admin/__init__.py +++ b/webapp/api/api/admin/__init__.py @@ -5,13 +5,13 @@ admin.site.register(Entity) admin.site.register(MetaTaskValue) admin.site.register(MetaTask) -admin.site.register(MetaAnnotation) admin.site.register(Vocabulary) admin.site.register(Relation) admin.site.register(EntityRelation) admin.site.register(ProjectGroup, ProjectGroupAdmin) admin.site.register(ProjectAnnotateEntities, ProjectAnnotateEntitiesAdmin) admin.site.register(AnnotatedEntity, AnnotatedEntityAdmin) +admin.site.register(MetaAnnotation, MetaAnnotationAdmin) admin.site.register(ConceptDB, ConceptDBAdmin) admin.site.register(Document, DocumentAdmin) admin.site.register(ExportedProject, ExportedProjectAdmin) diff --git a/webapp/api/api/admin/actions.py b/webapp/api/api/admin/actions.py index feb754c4..57b5a16b 100644 --- a/webapp/api/api/admin/actions.py +++ b/webapp/api/api/admin/actions.py @@ -26,8 +26,10 @@ def reset_project(modeladmin, request, queryset): # Remove all annotations and cascade to meta anns AnnotatedEntity.objects.filter(project=project).delete() - # Set all validated documents to none + # Clear validated_docuents and prepared_documents project.validated_documents.clear() + project.prepared_documents.clear() + def download_without_text(modeladmin, request, queryset): diff --git a/webapp/api/api/admin/models.py b/webapp/api/api/admin/models.py index ed09f378..3768a4a5 100644 --- a/webapp/api/api/admin/models.py +++ b/webapp/api/api/admin/models.py @@ -8,7 +8,7 @@ from ..models import * _PROJECT_ANNO_ENTS_SETTINGS_FIELD_ORDER = ( - 'concept_db', 'vocab', 'cdb_search_filter', 'require_entity_validation', 'train_model_on_submit', + 'concept_db', 'vocab', 'model_pack', 'cdb_search_filter', 'require_entity_validation', 'train_model_on_submit', 'add_new_entities', 'restrict_concept_lookup', 'terminate_available', 'irrelevant_available', 'enable_entity_annotation_comments', 'tasks', 'relations' ) @@ -43,7 +43,8 @@ class ProjectAnnotateEntitiesAdmin(admin.ModelAdmin): actions = [download, download_without_text, download_without_text_with_doc_names, reset_project, clone_projects] list_filter = ('members', 'project_status', 'project_locked', 'annotation_classification') list_display = ['name'] - fields = (('group', 'name', 'description', 'annotation_guideline_link', 'members', 'dataset', 'validated_documents') + + fields = (('group', 'name', 'description', 'annotation_guideline_link', 'members', + 'dataset', 'validated_documents', 'prepared_documents') + _PROJECT_FIELDS_ORDER + _PROJECT_ANNO_ENTS_SETTINGS_FIELD_ORDER) @@ -55,7 +56,7 @@ def formfield_for_foreignkey(self, db_field, request, **kwargs): def formfield_for_manytomany(self, db_field, request, **kwargs): if db_field.name == 'cdb_search_filter': kwargs['queryset'] = ConceptDB.objects.all() - if db_field.name == 'validated_documents': + if db_field.name in ('validated_documents', 'prepared_documents'): project_id = request.path.replace('/admin/api/projectannotateentities/', '').split('/')[0] try: proj = ProjectAnnotateEntities.objects.get(id=int(project_id)) @@ -165,7 +166,13 @@ def metacats(self, obj): class MetaCATModelAdmin(admin.ModelAdmin): model = MetaCATModel list_display = ('name', 'meta_cat_dir') - list_filter = ['meta_task'] + + +class MetaAnnotationAdmin(admin.ModelAdmin): + model = MetaAnnotation + list_display = ('annotated_entity', 'meta_task', 'meta_task_value', 'acc', + 'predicted_meta_task_value', 'validated', 'last_modified') + list_filter = ('meta_task', 'meta_task_value', 'predicted_meta_task_value', 'validated') class DocumentAdmin(admin.ModelAdmin): diff --git a/webapp/api/api/metrics.py b/webapp/api/api/metrics.py index 3ffdd842..d9d2fe61 100644 --- a/webapp/api/api/metrics.py +++ b/webapp/api/api/metrics.py @@ -36,15 +36,21 @@ def calculate_metrics(project_ids: List[int], report_name: str): """ Computes metrics in a background task - :param projects: list of projects to compute metrics for. Uses the 'first' for the CDB, but - should be the same CDB, but will still try and compute metrics regardless + :param projects: list of projects to compute metrics for. + Uses the 'first' for the CDB / vocab or ModelPack, + but should be the same CDB, but will still try and compute metrics regardless. :return: computed metrics results """ logger.info('Calculating metrics for report: %s', report_name) projects = [ProjectAnnotateEntities.objects.filter(id=p_id).first() for p_id in project_ids] - cdb = CDB.load(projects[0].concept_db.cdb_file.path) - vocab = Vocab.load(projects[0].vocab.vocab_file.path) - cat = CAT(cdb, vocab, config=cdb.config) + if projects[0].cdb is None: + # assume the model pack is set. + cat = CAT.load_model_pack(projects[0].model_pack.model_pack.path) + else: + # assume the cdb / vocab is set in these projects + cdb = CDB.load(projects[0].concept_db.cdb_file.path) + vocab = Vocab.load(projects[0].vocab.vocab_file.path) + cat = CAT(cdb, vocab, config=cdb.config) project_data = retrieve_project_data(projects) metrics = ProjectMetrics(project_data, cat) report = metrics.generate_report() diff --git a/webapp/api/api/migrations/0082_remove_metacatmodel_meta_task_and_more.py b/webapp/api/api/migrations/0082_remove_metacatmodel_meta_task_and_more.py new file mode 100644 index 00000000..df09c193 --- /dev/null +++ b/webapp/api/api/migrations/0082_remove_metacatmodel_meta_task_and_more.py @@ -0,0 +1,98 @@ +# Generated by Django 5.0.6 on 2024-08-28 10:56 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('api', '0081_alter_metatask_name'), + ] + + operations = [ + migrations.RemoveField( + model_name='metacatmodel', + name='meta_task', + ), + migrations.AddField( + model_name='metaannotation', + name='predicted_meta_task_value', + field=models.ForeignKey(blank=True, help_text='meta annotation predicted by a MetaAnnotationModel', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='predicted_value', to='api.metataskvalue'), + ), + migrations.AddField( + model_name='metatask', + name='prediction_model', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='api.metacatmodel'), + ), + migrations.AddField( + model_name='project', + name='meta_cat_predictions', + field=models.BooleanField(default=False, help_text='If MetaTasks are setup on the project and there are associated MetaCATModel instances, display these predictions in the interface to be validated / corrected'), + ), + migrations.AddField( + model_name='projectannotateentities', + name='model_pack', + field=models.ForeignKey(blank=True, default=None, help_text='A MedCAT model pack. This will raise an exception if both the CDB and Vocab and ModelPack fields are set', null=True, on_delete=django.db.models.deletion.SET_NULL, to='api.modelpack'), + ), + migrations.AddField( + model_name='projectgroup', + name='meta_cat_predictions', + field=models.BooleanField(default=False, help_text='If MetaTasks are setup on the project and there are associated MetaCATModel instances, display these predictions in the interface to be validated / corrected'), + ), + migrations.AddField( + model_name='projectgroup', + name='model_pack', + field=models.ForeignKey(blank=True, default=None, help_text='A MedCAT model pack. This will raise an exception if both the CDB and Vocab and ModelPack fields are set', null=True, on_delete=django.db.models.deletion.SET_NULL, to='api.modelpack'), + ), + migrations.AlterField( + model_name='metaannotation', + name='validated', + field=models.BooleanField(default=False, help_text='If an annotation is not '), + ), + migrations.AlterField( + model_name='metacatmodel', + name='meta_cat_dir', + field=models.FilePathField(allow_folders=True, editable=False, help_text='The zip or dir for a MetaCAT model, not editable, is set via a model pack .zip upload'), + ), + migrations.AlterField( + model_name='metacatmodel', + name='name', + field=models.CharField(help_text='The task name followed by the underlying model impl', max_length=100), + ), + migrations.AlterField( + model_name='projectannotateentities', + name='concept_db', + field=models.ForeignKey(blank=True, help_text='The MedCAT CDB used to annotate / validate', null=True, on_delete=django.db.models.deletion.SET_NULL, to='api.conceptdb'), + ), + migrations.AlterField( + model_name='projectannotateentities', + name='tasks', + field=models.ManyToManyField(blank=True, default=None, help_text='The set of MetaAnnotation tasks configured for this project, this will default to the set of Tasks configured in a ModelPack if a model pack is used for the project', to='api.metatask'), + ), + migrations.AlterField( + model_name='projectannotateentities', + name='vocab', + field=models.ForeignKey(blank=True, help_text='The MedCAT Vocab used to annotate / validate', null=True, on_delete=django.db.models.deletion.SET_NULL, to='api.vocabulary'), + ), + migrations.AlterField( + model_name='projectgroup', + name='cdb_search_filter', + field=models.ManyToManyField(blank=True, help_text='The CDB that will be used for concept lookup. This specific CDB should have been "imported" via the CDB admin screen', related_name='project_group_concept_source', to='api.conceptdb'), + ), + migrations.AlterField( + model_name='projectgroup', + name='concept_db', + field=models.ForeignKey(blank=True, help_text='The MedCAT CDB used to annotate / validate', null=True, on_delete=django.db.models.deletion.SET_NULL, to='api.conceptdb'), + ), + migrations.AlterField( + model_name='projectgroup', + name='tasks', + field=models.ManyToManyField(blank=True, default=None, help_text='The set of MetaAnnotation tasks configured for this project, this will default to the set of Tasks configured in a ModelPack if a model pack is used for the project', to='api.metatask'), + ), + migrations.AlterField( + model_name='projectgroup', + name='vocab', + field=models.ForeignKey(blank=True, help_text='The MedCAT Vocab used to annotate / validate', null=True, on_delete=django.db.models.deletion.SET_NULL, to='api.vocabulary'), + ), + ] diff --git a/webapp/api/api/migrations/0083_project_prepared_documents_and_more.py b/webapp/api/api/migrations/0083_project_prepared_documents_and_more.py new file mode 100644 index 00000000..289715e0 --- /dev/null +++ b/webapp/api/api/migrations/0083_project_prepared_documents_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 5.0.6 on 2024-08-29 11:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('api', '0082_remove_metacatmodel_meta_task_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='project', + name='prepared_documents', + field=models.ManyToManyField(blank=True, default=None, help_text='Set automatically on each prep of a document', related_name='prepared_documents', to='api.document'), + ), + migrations.AlterField( + model_name='project', + name='validated_documents', + field=models.ManyToManyField(blank=True, default=None, help_text='Set automatically on each doc submission', to='api.document'), + ), + ] diff --git a/webapp/api/api/model_cache.py b/webapp/api/api/model_cache.py new file mode 100644 index 00000000..304ccf7c --- /dev/null +++ b/webapp/api/api/model_cache.py @@ -0,0 +1,127 @@ +import logging +import os +from typing import Dict + +import pkg_resources +from medcat.cat import CAT +from medcat.cdb import CDB +from medcat.vocab import Vocab + +from api.models import ConceptDB + +""" +Module level caches for CDBs, Vocabs and CAT instances. +""" +# Maps between IDs and objects +CDB_MAP = {} +VOCAB_MAP = {} +CAT_MAP = {} + + +logger = logging.getLogger(__name__) + + +def get_medcat_from_cdb_vocab(project, + cdb_map: Dict[str, CDB]=CDB_MAP, + vocab_map: Dict[str, Vocab]=VOCAB_MAP, + cat_map: Dict[str, CAT]=CAT_MAP) -> CAT: + cdb_id = project.concept_db.id + vocab_id = project.vocab.id + cat_id = str(cdb_id) + "-" + str(vocab_id) + if cat_id in cat_map: + cat = cat_map[cat_id] + else: + if cdb_id in cdb_map: + cdb = cdb_map[cdb_id] + else: + cdb_path = project.concept_db.cdb_file.path + try: + cdb = CDB.load(cdb_path) + except KeyError as ke: + mc_v = pkg_resources.get_distribution('medcat').version + if int(mc_v.split('.')[0]) > 0: + logger.error('Attempted to load MedCAT v0.x model with MCTrainer v1.x') + raise Exception('Attempted to load MedCAT v0.x model with MCTrainer v1.x', + 'Please re-configure this project to use a MedCAT v1.x CDB or consult the ' + 'MedCATTrainer Dev team if you believe this should work') from ke + raise + + custom_config = os.getenv("MEDCAT_CONFIG_FILE") + if custom_config is not None and os.path.exists(custom_config): + cdb.config.parse_config_file(path=custom_config) + else: + logger.info("No MEDCAT_CONFIG_FILE env var set to valid path, using default config available on CDB") + cdb_map[cdb_id] = cdb + + if vocab_id in vocab_map: + vocab = vocab_map[vocab_id] + else: + vocab_path = project.vocab.vocab_file.path + vocab = Vocab.load(vocab_path) + vocab_map[vocab_id] = vocab + cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) + cat_map[cat_id] = cat + return cat + + +def get_medcat_from_model_pack(project, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT: + model_pack_obj = project.model_pack + cat_id = 'mp' + str(model_pack_obj.id) + logger.info('Loading model pack from:%s', model_pack_obj.model_pack.path) + cat = CAT.load_model_pack(model_pack_obj.model_pack.path) + cat_map[cat_id] = cat + return cat + + +def get_medcat(project, + cdb_map: Dict[str, CDB]=CDB_MAP, + vocab_map: Dict[str, Vocab]=VOCAB_MAP, + cat_map: Dict[str, CAT]=CAT_MAP): + try: + if project.model_pack is None: + cat = get_medcat_from_cdb_vocab(project, cdb_map, vocab_map, cat_map) + else: + cat = get_medcat_from_model_pack(project, cat_map) + return cat + except AttributeError: + raise Exception('Failure loading Project ConceptDB, Vocab or Model Pack. Are these set correctly?') + + +def get_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP): + if project.concept_db is None or project.vocab is None: + return None + cdb_id = project.concept_db.id + vocab_id = project.vocab.id + cat_id = str(cdb_id) + "-" + str(vocab_id) + return cat_map.get(cat_id) + + +def clear_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP): + cdb_id = project.concept_db.id + vocab_id = project.vocab.id + cat_id = str(cdb_id) + "-" + str(vocab_id) + if cat_id in cat_map: + del cat_map[cat_id] + + +def get_cached_cdb(cdb_id: str, cdb_map: Dict[str, CDB]=CDB_MAP) -> CDB: + if cdb_id not in cdb_map: + cdb_obj = ConceptDB.objects.get(id=cdb_id) + cdb = CDB.load(cdb_obj.cdb_file.path) + cdb_map[cdb_id] = cdb + return cdb_map[cdb_id] + + +def clear_cached_cdb(cdb_id, cdb_map: Dict[str, CDB]=CDB_MAP): + if cdb_id in cdb_map: + del cdb_map[cdb_id] + + +def is_model_loaded(project, + cdb_map: Dict[str, CDB]=CDB_MAP, + cat_map: Dict[str, CAT]=CAT_MAP): + if project.concept_db is None: + # model pack is used. + return False if not project.model_pack else f'mp{project.model_pack.id}' in cat_map + else: + return False if not project.concept_db else project.concept_db.id in cdb_map diff --git a/webapp/api/api/models.py b/webapp/api/api/models.py index e8ed5e1c..0bfccd8c 100644 --- a/webapp/api/api/models.py +++ b/webapp/api/api/models.py @@ -77,6 +77,7 @@ def save(self, *args, **kwargs): # load MetaCATs try: + metaCATmodels = [] # should raise an error if there already is a MetaCAT model with this definition for meta_cat_dir, meta_cat in CAT.load_meta_cats(unpacked_model_pack_path): mc_model = MetaCATModel() @@ -84,6 +85,8 @@ def save(self, *args, **kwargs): mc_model.name = f'{meta_cat.config.general.category_name} - {meta_cat.config.model.model_name}' mc_model.save(unpack_load_meta_cat_dir=False) mc_model.get_or_create_meta_tasks_and_values(meta_cat) + metaCATmodels.append(mc_model) + self.meta_cats.add(*metaCATmodels) except Exception as exc: raise MedCATLoadException(f'Failure loading MetaCAT models - {unpacked_model_pack_path}') from exc super().save(*args, **kwargs) @@ -142,9 +145,10 @@ def __str__(self): class MetaCATModel(models.Model): - name = models.CharField(max_length=100) - meta_cat_dir = models.FilePathField(help_text='The zip or dir for a MetaCAT model', allow_folders=True) - meta_task = models.ForeignKey('MetaTask', on_delete=SET_NULL, blank=True, null=True) + name = models.CharField(max_length=100, help_text="The task name followed by the underlying model impl") + meta_cat_dir = models.FilePathField(help_text='The zip or dir for a MetaCAT model, not editable, ' + 'is set via a model pack .zip upload', + allow_folders=True, editable=False) def get_or_create_meta_tasks_and_values(self, meta_cat: MetaCAT): task = meta_cat.config.general.category_name @@ -152,8 +156,11 @@ def get_or_create_meta_tasks_and_values(self, meta_cat: MetaCAT): if not mt: mt = MetaTask() mt.name = task + mt.prediction_model = self + mt.save() + else: + mt.prediction_model = self mt.save() - self.meta_task = mt mt_vs = [] for meta_task_value in meta_cat.config.general.category_value2id.keys(): @@ -163,7 +170,8 @@ def get_or_create_meta_tasks_and_values(self, meta_cat: MetaCAT): mt_v.name = meta_task_value mt_v.save() mt_vs.append(mt_v) - self.meta_task.values.set(mt_vs) + mt.values.set(mt_vs) + self.save() def save(self, *args, unpack_load_meta_cat_dir=False, **kwargs): if unpack_load_meta_cat_dir: @@ -245,6 +253,10 @@ class Meta: annotation_classification = models.BooleanField(default=False, help_text="If these annotations are suitable " "for training a general purpose model. If" " in doubt uncheck this.") + meta_cat_predictions = models.BooleanField(default=False, help_text="If MetaTasks are setup on the project and " + "there are associated MetaCATModel instances, " + "display these predictions in the interface to " + "be validated / corrected") project_locked = models.BooleanField(default=False, help_text="Locked indicates annotation collection is complete and this dataset should " "not be touched any further.") project_status = models.CharField(max_length=1, choices=PROJECT_STATUSES, default="A", @@ -256,7 +268,11 @@ class Project(PolymorphicModel, ProjectFields): help_text='The list users that have access to this annotation project') group = models.ForeignKey('ProjectGroup', on_delete=models.SET_NULL, blank=True, null=True, help_text='The annotation project group that this project is part of') - validated_documents = models.ManyToManyField(Document, default=None, blank=True) + validated_documents = models.ManyToManyField(Document, default=None, blank=True, + help_text='Set automatically on each doc submission') + prepared_documents = models.ManyToManyField(Document, default=None, blank=True, + help_text='Set automatically on each prep of a document', + related_name='prepared_documents') def __str__(self): return str(self.name) @@ -353,6 +369,7 @@ class MetaTask(models.Model): description = models.TextField(default="", blank=True) ordering = models.PositiveSmallIntegerField(help_text="the order in which the meta task will appear in " "the Trainer Annotation project screen", default=0) + prediction_model = models.ForeignKey('MetaCATModel', null=True, blank=True, on_delete=models.SET_NULL) class Meta: ordering = ['ordering', 'name'] @@ -368,10 +385,13 @@ class ProjectAnnotateEntitiesFields(models.Model): class Meta: abstract = True - concept_db = models.ForeignKey('ConceptDB', on_delete=models.SET_NULL, blank=False, null=True, + concept_db = models.ForeignKey('ConceptDB', on_delete=models.SET_NULL, blank=True, null=True, help_text='The MedCAT CDB used to annotate / validate') - vocab = models.ForeignKey('Vocabulary', on_delete=models.SET_NULL, null=True, + vocab = models.ForeignKey('Vocabulary', on_delete=models.SET_NULL, blank=True, null=True, help_text='The MedCAT Vocab used to annotate / validate') + model_pack = models.ForeignKey('ModelPack', on_delete=models.SET_NULL, help_text="A MedCAT model pack. This will raise an exception if " + "both the CDB and Vocab and ModelPack fields are set", + default=None, null=True, blank=True) cdb_search_filter = models.ManyToManyField('ConceptDB', blank=True, default=None, help_text='The CDB that will be used for concept lookup. ' 'This specific CDB should have been "imported" ' @@ -399,10 +419,19 @@ class Meta: help_text="Enable to allow annotators to leave comments" " for each annotation") tasks = models.ManyToManyField('MetaTask', blank=True, default=None, - help_text='The set of MetaAnnotation tasks configured for this project') + help_text='The set of MetaAnnotation tasks configured for this project, ' + 'this will default to the set of Tasks configured in a ModelPack ' + 'if a model pack is used for the project') relations = models.ManyToManyField('Relation', blank=True, default=None, help_text='Relations that will be available for this project') + def save(self, *args, **kwargs): + if self.model_pack is None and (self.concept_db is None or self.vocab is None): + raise ValidationError('Must set at least the ModelPack or a Concept Database and Vocab Pair') + if self.model_pack and (self.concept_db is not None or self.vocab is not None): + raise ValidationError('Cannot set model pack and ConceptDB or a Vocab. You must use one or the other.') + super().save(*args, **kwargs) + class ProjectAnnotateEntities(Project, ProjectAnnotateEntitiesFields): """ @@ -419,7 +448,7 @@ class ProjectGroup(ProjectFields, ProjectAnnotateEntitiesFields): annotators = models.ManyToManyField(settings.AUTH_USER_MODEL, help_text="The set of users that will each be provided an annotation project", related_name='annotators') - cdb_search_filter = models.ManyToManyField('ConceptDB', blank=True, default=None, + cdb_search_filter = models.ManyToManyField('ConceptDB', blank=True, help_text='The CDB that will be used for concept lookup. ' 'This specific CDB should have been "imported" ' 'via the CDB admin screen', @@ -441,7 +470,10 @@ class MetaAnnotation(models.Model): meta_task = models.ForeignKey('MetaTask', on_delete=models.CASCADE) meta_task_value = models.ForeignKey('MetaTaskValue', on_delete=models.CASCADE) acc = models.FloatField(default=1) - validated = models.BooleanField(default=False) + predicted_meta_task_value = models.ForeignKey('MetaTaskValue', on_delete=models.CASCADE, + help_text='meta annotation predicted by a MetaAnnotationModel', + null=True, blank=True, related_name="predicted_value") + validated = models.BooleanField(help_text='If an annotation is not ', default=False) last_modified = models.DateTimeField(auto_now=True) def save(self, *args, **kwargs): diff --git a/webapp/api/api/signals.py b/webapp/api/api/signals.py index 6ae3b278..112b1fa5 100644 --- a/webapp/api/api/signals.py +++ b/webapp/api/api/signals.py @@ -3,13 +3,14 @@ import os import shutil -from django.core.exceptions import ObjectDoesNotExist +from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.db.models.fields.files import FileField -from django.db.models.signals import post_save, post_delete, pre_save +from django.db.models.signals import post_save, post_delete, pre_save, m2m_changed, pre_delete from django.dispatch import receiver from api.data_utils import dataset_from_file, delete_orphan_docs, upload_projects_export -from api.models import Dataset, ExportedProject, ModelPack +from api.models import Dataset, ExportedProject, ModelPack, ProjectFields, ProjectAnnotateEntitiesFields, MetaTask, \ + ProjectAnnotateEntities from core.settings import MEDIA_ROOT @@ -45,6 +46,13 @@ def save_exported_projects(sender, instance, **kwargs): upload_projects_export(json.load(open(instance.trainer_export_file.path))) +@receiver(pre_delete, sender=ModelPack) +def remove_model_pack_meta_cat_models(sender, instance, **kwargs): + if len(instance.meta_cats.all()) > 0: + for m_c in instance.meta_cats.all(): + m_c.delete(using=None, keep_parents=False) + + @receiver(post_delete, sender=ModelPack) def remove_model_pack_assets(sender, instance, **kwargs): try: @@ -55,14 +63,23 @@ def remove_model_pack_assets(sender, instance, **kwargs): try: if instance.vocab: instance.vocab.delete(using=None, keep_parents=False) - if len(instance.meta_cats.all()) > 0: - for m_c in instance.meta_cats.all(): - m_c.delete(using=None, keep_parents=False) except ObjectDoesNotExist: pass # if a vocab of a model pack is removed, this will cascade ModelPack removal. + try: # rm the model pack unzipped dir & model pack zip shutil.rmtree(instance.model_pack.path.replace(".zip", "")) os.remove(instance.model_pack.path) except FileNotFoundError: logger.warning("Failure removing Model pack dir or zip. Not found. Likely already deleted") + + +def project_tasks_changed(sender, instance, action, **kwargs): + # post_remove or post_add actions, overwrite to model_pack supplied MetaCAT tasks. + if (action.startswith('post') and type(instance) is ProjectAnnotateEntitiesFields and + instance.model_pack is not None): + instance.tasks.set([MetaTask.objects.filter(prediction_model_id=meta_cat.id).first() for meta_cat in + instance.model_pack.meta_cats.all()]) + + +m2m_changed.connect(project_tasks_changed, sender=ProjectAnnotateEntitiesFields.tasks.through) diff --git a/webapp/api/api/utils.py b/webapp/api/api/utils.py index efc0d75e..41039dfe 100644 --- a/webapp/api/api/utils.py +++ b/webapp/api/api/utils.py @@ -1,22 +1,21 @@ import json import logging import os -from typing import Union, Dict, List, Type +from typing import List -import pkg_resources +from background_task import background from django.contrib.auth.models import User from django.db.models.signals import post_save from django.dispatch import receiver from medcat.cat import CAT -from medcat.cdb import CDB from medcat.utils.filters import check_filters from medcat.utils.helpers import tkns_from_doc -from medcat.vocab import Vocab +from .model_cache import get_medcat from .models import Entity, AnnotatedEntity, ProjectAnnotateEntities, \ - ConceptDB + MetaAnnotation, MetaTask, Document -log = logging.getLogger('trainer') +logger = logging.getLogger('trainer') def remove_annotations(document, project, partial=False): @@ -26,13 +25,13 @@ def remove_annotations(document, project, partial=False): AnnotatedEntity.objects.filter(project=project, document=document, validated=False).delete() - log.debug(f"Unvalidated Annotations removed for:{document.id}") + logger.debug(f"Unvalidated Annotations removed for:{document.id}") else: # Removes everything AnnotatedEntity.objects.filter(project=project, document=document).delete() - log.debug(f"All Annotations removed for:{document.id}") + logger.debug(f"All Annotations removed for:{document.id}") except Exception as e: - log.debug(f"Something went wrong: {e}") + logger.debug(f"Something went wrong: {e}") def add_annotations(spacy_doc, user, project, document, existing_annotations, cat): @@ -41,6 +40,19 @@ def add_annotations(spacy_doc, user, project, document, existing_annotations, ca tkns_in = [] ents = [] existing_annos_intervals = [(ann.start_ind, ann.end_ind) for ann in existing_annotations] + # all MetaTasks and associated values + # that can be produced are expected to have available models + try: + metatask2obj = {task_name: MetaTask.objects.get(name=task_name) + for task_name in spacy_doc._.ents[0]._.meta_anns.keys()} + metataskvals2obj = {task_name: {v.name: v for v in MetaTask.objects.get(name=task_name).values.all()} + for task_name in spacy_doc._.ents[0]._.meta_anns.keys()} + except AttributeError: + # ignore meta_anns that are not present - i.e. non model pack preds, + # or model pack preds with no meta_anns + metatask2obj = {} + metataskvals2obj = {} + pass def check_ents(ent): return any((ea[0] < ent.start_char < ea[1]) or @@ -57,7 +69,9 @@ def check_ents(ent): tkns_in.append(tkn) ents.append(ent) + logger.debug('Found %s annotations to store', len(ents)) for ent in ents: + logger.debug('Processing annotation ent %s of %s', ents.index(ent), len(ents)) label = ent._.cui if not Entity.objects.filter(label=label).exists(): @@ -68,10 +82,11 @@ def check_ents(ent): else: entity = Entity.objects.get(label=label) - if AnnotatedEntity.objects.filter(project=project, - document=document, - start_ind=ent.start_char, - end_ind=ent.end_char).count() == 0: + ann_ent = AnnotatedEntity.objects.filter(project=project, + document=document, + start_ind=ent.start_char, + end_ind=ent.end_char).first() + if ann_ent is None: # If this entity doesn't exist already ann_ent = AnnotatedEntity() ann_ent.user = user @@ -90,6 +105,20 @@ def check_ents(ent): ann_ent.save() + # check the ent._.meta_anns if it exists + if hasattr(ent._, 'meta_anns') and len(metatask2obj) > 0 and len(metataskvals2obj) > 0: + logger.debug('Found %s meta annos on ent', len(ent._.meta_anns.items())) + for meta_ann_task, pred in ent._.meta_anns.items(): + meta_anno_obj = MetaAnnotation() + meta_anno_obj.predicted_meta_task_value = metataskvals2obj[meta_ann_task][pred['value']] + meta_anno_obj.meta_task = metatask2obj[meta_ann_task] + meta_anno_obj.annotated_entity = ann_ent + meta_anno_obj.meta_task_value = metataskvals2obj[meta_ann_task][pred['value']] + meta_anno_obj.acc = pred['confidence'] + meta_anno_obj.save() + logger.debug('Successfully saved %s', meta_anno_obj) + + def get_create_cdb_infos(cdb, concept, cui, cui_info_prop, code_prop, desc_prop, model_clazz): codes = [c[code_prop] for c in cdb.cui2info.get(cui, {}).get(cui_info_prop, []) if code_prop in c] @@ -113,7 +142,7 @@ def _remove_overlap(project, document, start, end): for ann in anns: if (start <= ann.start_ind <= end) or (start <= ann.end_ind <= end): - log.debug("Removed %s ", str(ann)) + logger.debug("Removed %s ", str(ann)) ann.delete() @@ -205,78 +234,32 @@ def train_medcat(cat, project, document): cat.config.linking['filters'].get('cuis_exclude').update([cui]) -def get_cached_medcat(CAT_MAP, project): - if project.concept_db is None or project.vocab is None: - return None - cdb_id = project.concept_db.id - vocab_id = project.vocab.id - cat_id = str(cdb_id) + "-" + str(vocab_id) - return CAT_MAP.get(cat_id) - - -def clear_cached_medcat(CAT_MAP, project): - cdb_id = project.concept_db.id - vocab_id = project.vocab.id - cat_id = str(cdb_id) + "-" + str(vocab_id) - if cat_id in CAT_MAP: - del CAT_MAP[cat_id] +@background(schedule=1, queue='doc_prep') +def prep_docs(project_id: List[int], doc_ids: List[int], user_id: int): + user = User.objects.get(id=user_id) + project = ProjectAnnotateEntities.objects.get(id=project_id) + docs = Document.objects.filter(id__in=doc_ids) + logger.info('Loading CAT object in bg process') + cat = get_medcat(project=project) -def get_medcat(CDB_MAP, VOCAB_MAP, CAT_MAP, project): - try: - cdb_id = project.concept_db.id - vocab_id = project.vocab.id - except AttributeError: - raise Exception('Failure loading Project Concept Database or Vocabulary. Are these set correctly?') - cat_id = str(cdb_id) + "-" + str(vocab_id) - - if cat_id in CAT_MAP: - cat = CAT_MAP[cat_id] - else: - if cdb_id in CDB_MAP: - cdb = CDB_MAP[cdb_id] - else: - cdb_path = project.concept_db.cdb_file.path - try: - cdb = CDB.load(cdb_path) - except KeyError as ke: - mc_v = pkg_resources.get_distribution('medcat').version - if int(mc_v.split('.')[0]) > 0: - log.error('Attempted to load MedCAT v0.x model with MCTrainer v1.x') - raise Exception('Attempted to load MedCAT v0.x model with MCTrainer v1.x', - 'Please re-configure this project to use a MedCAT v1.x CDB or consult the ' - 'MedCATTrainer Dev team if you believe this should work') from ke - raise - - custom_config = os.getenv("MEDCAT_CONFIG_FILE") - if custom_config is not None and os.path.exists(custom_config): - cdb.config.parse_config_file(path=custom_config) - else: - log.info("No MEDCAT_CONFIG_FILE env var set to valid path, using default config available on CDB") - CDB_MAP[cdb_id] = cdb - - if vocab_id in VOCAB_MAP: - vocab = VOCAB_MAP[vocab_id] - else: - vocab_path = project.vocab.vocab_file.path - vocab = Vocab.load(vocab_path) - VOCAB_MAP[vocab_id] = vocab - - # integrated model-pack spacy model not used. - # This assumes specified spacy model is installed... - # Next change will create conditional params to load CDB / Vocab, or - # model-packs directly for a project. - cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) - CAT_MAP[cat_id] = cat - return cat + # Set CAT filters + cat.config.linking['filters']['cuis'] = project.cuis + for doc in docs: + logger.info(f'Running MedCAT model over doc: {doc.id}') + spacy_doc = cat(doc.text) + anns = AnnotatedEntity.objects.filter(document=doc).filter(project=project) -def get_cached_cdb(cdb_id: str, CDB_MAP: Dict[str, CDB]) -> CDB: - if cdb_id not in CDB_MAP: - cdb_obj = ConceptDB.objects.get(id=cdb_id) - cdb = CDB.load(cdb_obj.cdb_file.path) - CDB_MAP[cdb_id] = cdb - return CDB_MAP[cdb_id] + add_annotations(spacy_doc=spacy_doc, + user=user, + project=project, + document=doc, + cat=cat, + existing_annotations=anns) + # add doc to prepared_documents + project.prepared_documents.add(doc) + project.save() @receiver(post_save, sender=ProjectAnnotateEntities) @@ -290,6 +273,7 @@ def save_project_anno(sender, instance, **kwargs): post_save.connect(save_project_anno, sender=ProjectAnnotateEntities) + def env_str_to_bool(var: str, default: bool): val = os.environ.get(var, default) if isinstance(val, str): diff --git a/webapp/api/api/views.py b/webapp/api/api/views.py index 5f2b0682..19102ad5 100644 --- a/webapp/api/api/views.py +++ b/webapp/api/api/views.py @@ -1,33 +1,28 @@ -import logging -import pickle import traceback -from datetime import datetime +from smtplib import SMTPException from tempfile import NamedTemporaryFile from background_task.models import Task, CompletedTask +from django.contrib.auth.views import PasswordResetView from django.http import HttpResponseBadRequest, HttpResponseServerError, HttpResponse from django.shortcuts import render from django.utils import timezone from django_filters import rest_framework as drf -from django.contrib.auth.views import PasswordResetView -from medcat.cdb import CDB from medcat.utils.helpers import tkns_from_doc from rest_framework import viewsets from rest_framework.decorators import api_view from rest_framework.response import Response -from smtplib import SMTPException -from core.settings import MEDIA_ROOT from .admin import download_projects_with_text, download_projects_without_text, \ import_concepts_from_cdb from .data_utils import upload_projects_export from .medcat_utils import ch2pt_from_pt2ch, get_all_ch, dedupe_preserve_order, snomed_ct_concept_path from .metrics import calculate_metrics +from .model_cache import get_medcat, get_cached_cdb, VOCAB_MAP, clear_cached_cdb, CAT_MAP, CDB_MAP, is_model_loaded from .permissions import * from .serializers import * from .solr_utils import collections_available, search_collection, ensure_concept_searchable -from .utils import get_cached_medcat, clear_cached_medcat, get_medcat, get_cached_cdb, \ - add_annotations, remove_annotations, train_medcat, create_annotation +from .utils import add_annotations, remove_annotations, train_medcat, create_annotation, prep_docs # For local testing, put envs """ @@ -39,10 +34,6 @@ logger = logging.getLogger(__name__) -# Maps between IDs and objects -CDB_MAP = {} -VOCAB_MAP = {} -CAT_MAP = {} # Get the basic version of MedCAT cat = None @@ -221,7 +212,7 @@ def post(self, request, *args, **kwargs): @api_view(http_method_names=['GET']) def get_anno_tool_conf(_): - return Response({k: v for k,v in os.environ.items()}) + return Response({k: v for k, v in os.environ.items()}) @api_view(http_method_names=['POST']) @@ -253,6 +244,12 @@ def prepare_documents(request): 'description': 'Missing CUI filter file, %s, cannot be found on the filesystem, ' 'but is still set on the project. To fix remove and reset the ' 'cui filter file' % project.cuis_file}, status=500) + + if request.data.get('bg_task'): + # execute model infer in bg + job = prep_docs(p_id, d_ids, user.id) + return Response({'bg_job_id': job.id}) + try: for d_id in d_ids: document = Document.objects.get(id=d_id) @@ -271,8 +268,8 @@ def prepare_documents(request): # If the document is not already annotated, annotate it if (len(anns) == 0 and not is_validated) or update: # Based on the project id get the right medcat - cat = get_medcat(CDB_MAP=CDB_MAP, VOCAB_MAP=VOCAB_MAP, - CAT_MAP=CAT_MAP, project=project) + cat = get_medcat(project=project) + logger.info('loaded medcat model for project: %s', project.id) # Set CAT filters cat.config.linking['filters']['cuis'] = cuis @@ -285,6 +282,10 @@ def prepare_documents(request): cat=cat, existing_annotations=anns) + # add doc to prepared_documents + project.prepared_documents.add(document) + project.save() + except Exception as e: stack = traceback.format_exc() return Response({'message': e.args[0] if len(e.args) > 0 else 'Internal Server Error', @@ -293,6 +294,24 @@ def prepare_documents(request): return Response({'message': 'Documents prepared successfully'}) +@api_view(http_method_names=['GET']) +def prepare_docs_bg_tasks(request): + proj_id = int(request.GET['project']) + running_doc_prep_tasks = Task.objects.filter(queue='doc_prep') + completed_doc_prep_tasks = CompletedTask.objects.filter(queue='doc_prep') + + def transform_task_params(task_params_str): + task_params = json.loads(task_params_str)[0] + return { + 'document': task_params[1][0], + 'user_id': task_params[2] + } + running_tasks = [transform_task_params(task.task_params) for task in running_doc_prep_tasks + if json.loads(task.task_params)[0][0] == proj_id] + complete_tasks = [transform_task_params(task.task_params) for task in completed_doc_prep_tasks + if json.loads(task.task_params)[0][0] == proj_id] + return Response({'running_tasks': running_tasks, 'comp_tasks': complete_tasks}) + @api_view(http_method_names=['POST']) def add_annotation(request): # Get project id @@ -310,8 +329,7 @@ def add_annotation(request): project = ProjectAnnotateEntities.objects.get(id=p_id) document = Document.objects.get(id=d_id) - cat = get_medcat(CDB_MAP=CDB_MAP, VOCAB_MAP=VOCAB_MAP, - CAT_MAP=CAT_MAP, project=project) + cat = get_medcat(project=project) id = create_annotation(source_val=source_val, selection_occurrence_index=sel_occur_idx, cui=cui, @@ -342,8 +360,7 @@ def add_concept(request): user = request.user project = ProjectAnnotateEntities.objects.get(id=p_id) document = Document.objects.get(id=d_id) - cat = get_medcat(CDB_MAP=CDB_MAP, VOCAB_MAP=VOCAB_MAP, - CAT_MAP=CAT_MAP, project=project) + cat = get_medcat(project=project) if cui in cat.cdb.cui2names: err_msg = f'Cannot add a concept "{name}" with cui:{cui}. CUI already linked to {cat.cdb.cui2names[cui]}' @@ -393,14 +410,13 @@ def import_cdb_concepts(request): def _submit_document(project: ProjectAnnotateEntities, document: Document): if project.train_model_on_submit: try: - cat = get_medcat(CDB_MAP=CDB_MAP, VOCAB_MAP=VOCAB_MAP, - CAT_MAP=CAT_MAP, project=project) + cat = get_medcat(project=project) train_medcat(cat, project, document) except Exception as e: if project.vocab.id: if len(VOCAB_MAP[project.vocab.id].unigram_table) == 0: return Exception('Vocab is missing the unigram table. On the vocab instance ' - 'use vocab.make_unigram_table() to build') + 'use vocab.make_unigram_table() to build') else: raise e @@ -445,8 +461,7 @@ def save_models(request): # Get project id p_id = request.data['project_id'] project = ProjectAnnotateEntities.objects.get(id=p_id) - cat = get_medcat(CDB_MAP=CDB_MAP, VOCAB_MAP=VOCAB_MAP, - CAT_MAP=CAT_MAP, project=project) + cat = get_medcat(project=project) cat.cdb.save(project.concept_db.cdb_file.path) @@ -546,8 +561,7 @@ def annotate_text(request): project = ProjectAnnotateEntities.objects.get(id=p_id) - cat = get_medcat(CDB_MAP=CDB_MAP, VOCAB_MAP=VOCAB_MAP, - CAT_MAP=CAT_MAP, project=project) + cat = get_medcat(project=project) cat.config.linking['filters']['cuis'] = set(cuis) spacy_doc = cat(message) @@ -627,9 +641,9 @@ def upload_deployment(request): @api_view(http_method_names=['GET', 'DELETE']) def cache_model(request, cdb_id): if request.method == 'GET': - get_cached_cdb(cdb_id, CDB_MAP) - elif request.method == 'DELETE' and cdb_id in CDB_MAP: - del CDB_MAP[cdb_id] + get_cached_cdb(cdb_id) + elif request.method == 'DELETE': + clear_cached_cdb(cdb_id) else: return Response(f'Invalid method or cdb_id:{cdb_id} is invalid / not loaded', 400) return Response('success', 200) @@ -637,8 +651,11 @@ def cache_model(request, cdb_id): @api_view(http_method_names=['GET']) def model_loaded(_): - return Response({p.id: False if not p.concept_db else p.concept_db.id in CDB_MAP - for p in ProjectAnnotateEntities.objects.all()}) + models_loaded = {} + for p in ProjectAnnotateEntities.objects.all(): + models_loaded[p.id] = is_model_loaded(p) + + return Response(models_loaded) @api_view(http_method_names=['GET', 'POST']) diff --git a/webapp/api/core/urls.py b/webapp/api/core/urls.py index 0b58b88e..4b12c3e6 100644 --- a/webapp/api/core/urls.py +++ b/webapp/api/core/urls.py @@ -27,6 +27,7 @@ path('api/anno-conf/', api.views.get_anno_tool_conf), path('api/search-concepts/', api.views.search_solr), path('api/prepare-documents/', api.views.prepare_documents), + path('api/prep-docs-bg-tasks/', api.views.prepare_docs_bg_tasks), path('api/api-token-auth/', auth_views.obtain_auth_token), path('admin/', admin.site.urls), path('api/api-auth/', include('rest_framework.urls', namespace='rest_framework')), diff --git a/webapp/frontend/src/components/common/DocumentSummary.vue b/webapp/frontend/src/components/common/DocumentSummary.vue index 6eea1418..7a6897b8 100644 --- a/webapp/frontend/src/components/common/DocumentSummary.vue +++ b/webapp/frontend/src/components/common/DocumentSummary.vue @@ -10,13 +10,31 @@
-