From ba3e7fa738baee678cd492465fcf3fd3c0b8bf03 Mon Sep 17 00:00:00 2001 From: Hemanth Kumar J <45484205+hemanthkumar17@users.noreply.github.com> Date: Wed, 22 Mar 2023 13:27:16 -0500 Subject: [PATCH] Update main_task_retrieval.py FIX Bug https://github.com/microsoft/UniVL/issues/27 --- main_task_retrieval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main_task_retrieval.py b/main_task_retrieval.py index c335a6b..d8ed4eb 100644 --- a/main_task_retrieval.py +++ b/main_task_retrieval.py @@ -440,6 +440,7 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu): sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) else: sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list) + sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) metrics = compute_metrics(sim_matrix) logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0]))) @@ -512,4 +513,4 @@ def main(): eval_epoch(args, model, test_dataloader, device, n_gpu) if __name__ == "__main__": - main() \ No newline at end of file + main()