diff --git a/Model Training.ipynb b/Model Training.ipynb new file mode 100644 index 0000000..5992146 --- /dev/null +++ b/Model Training.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "premier-closing", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "documented-viewer", + "metadata": {}, + "outputs": [], + "source": [ + "import drGAT" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "numerical-karaoke", + "metadata": {}, + "outputs": [], + "source": [ + "# ?drGAT" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b2258cab-db0d-4cb2-9e6a-4f347265bf39", + "metadata": {}, + "outputs": [], + "source": [ + "data = torch.load('train.pt')" + ] + }, + { + "cell_type": "markdown", + "id": "caroline-patrick", + "metadata": {}, + "source": [ + "# Model training\n", + "If is_sample=True, then the # of epochs is 5" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "latter-kitchen", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using: cpu\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:51<00:00, 22.34s/it]\n" + ] + } + ], + "source": [ + "model, attention = drGAT.train(data, is_sample=True, is_save=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "thrown-valve", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0041, 0.0000, 0.0000, ..., 0.0041, 0.0041, 0.0041],\n", + " [0.0000, 0.0041, 0.0000, ..., 0.0041, 0.0041, 0.0041],\n", + " [0.0000, 0.0000, 0.0041, ..., 0.0041, 0.0041, 0.0041],\n", + " ...,\n", + " [0.0041, 0.0041, 0.0041, ..., 0.0041, 0.0000, 0.0000],\n", + " [0.0041, 0.0041, 0.0041, ..., 0.0000, 0.0041, 0.0000],\n", + " [0.0041, 0.0041, 0.0041, ..., 0.0000, 0.0000, 0.0041]])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "utility-collaboration", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "available-cameroon", + "metadata": {}, + "source": [ + "# Attention coefficient\n", + "This will be utilized [Fig3](https://github.com/inoue0426/new_drGAT/blob/main/Figs/Fig3.ipynb) and [Fig5](https://github.com/inoue0426/new_drGAT/blob/main/Figs/Fig5.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "crude-intermediate", + "metadata": {}, + "outputs": [], + "source": [ + "# attention = pd.DataFrame(attention)\n", + "# attention" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "executed-destination", + "metadata": {}, + "outputs": [], + "source": [ + "# attention.to_csv('attention.csv.gz', compression='gzip')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "korean-recipe", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "drGAT", + "language": "python", + "name": "drgat" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/train.pt b/train.pt new file mode 100644 index 0000000..b4bb26f Binary files /dev/null and b/train.pt differ