diff --git a/models.py b/models.py index e5c4080..c91b643 100644 --- a/models.py +++ b/models.py @@ -275,7 +275,8 @@ def __init__(self): super(MultitaskClassifierBase, self).__init__() def forward(self, input_ids, entity_labels, - attention_mask=None, token_type_ids=None, kg_datas=None, position_ids=None, head_mask=None, rel_idxs=[], lidx=[], ridx=[], task='relation', args=None): + attention_mask=None, token_type_ids=None, kg_datas=None, position_ids=None, head_mask=None, + rel_idxs=None, lidx=[], ridx=[], task='relation', args=None):[], task='relation', args=None): ''' entity_labels are just for extracting proteins ''' @@ -716,4 +717,4 @@ def __init__(self, args, bert_weights_path='biobert_weights/scibert_scivocab_unc self.softmax_ent = nn.Softmax(dim=2) if args.use_knowledge: self.gnn = KnowledgeGNN(kg_embedding_dim=kg_embedding_dim, num_edge_embeddings=args.num_edge_embeddings, token_embedding_size=config.hidden_size, args=args, kg_pretrained_weights=kg_pretrained_weights) - \ No newline at end of file +