{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cddfec5d",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import tensorflow as tf\n",
"from tensorflow.keras.utils import image_dataset_from_directory"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "2bb91ad6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 694 files belonging to 2 classes.\n",
"Found 200 files belonging to 2 classes.\n"
]
}
],
"source": [
"BATCH_SIZE = 32\n",
"NUM_CLASSES = 2\n",
"IMG_SIZE = 224\n",
"NUM_EPOCHS = 15\n",
"\n",
"TRAIN_DIR = \"C://Users//alvin//DATA//dsc//cxc-2022//data//alien_predator//train\"\n",
"TEST_DIR = \"C://Users//alvin//DATA//dsc//cxc-2022//data//alien_predator//validation\"\n",
"\n",
"train_ds = image_dataset_from_directory(TRAIN_DIR,\n",
" shuffle=True,\n",
" batch_size=BATCH_SIZE,\n",
" image_size=(IMG_SIZE, IMG_SIZE))\n",
"\n",
"test_ds = image_dataset_from_directory(TEST_DIR,\n",
" shuffle=True,\n",
" batch_size=BATCH_SIZE,\n",
" image_size=(IMG_SIZE, IMG_SIZE))\n",
" "
]
},
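{
"cell_type": "code",
"execution_count": null,
"id": "b1f2c3d4",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch (added for illustration, not part of the original run): sanity-check\n",
"# the datasets loaded above by printing the inferred class names and plotting a few\n",
"# training images. Only names defined above (train_ds) plus matplotlib are used.\n",
"class_names = train_ds.class_names\n",
"print(class_names)\n",
"\n",
"plt.figure(figsize=(8, 8))\n",
"for images, labels in train_ds.take(1):\n",
"    for i in range(9):\n",
"        plt.subplot(3, 3, i + 1)\n",
"        plt.imshow(images[i].numpy().astype('uint8'))\n",
"        plt.title(class_names[labels[i]])\n",
"        plt.axis('off')\n",
"plt.show()"
]
},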
{
"cell_type": "code",
"execution_count": 3,
"id": "a6ac7ebc",
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),\n",
" tf.keras.layers.Rescaling(1./255),\n",
" tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(NUM_CLASSES)\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8d01f717",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"22/22 [==============================] - 9s 405ms/step - loss: 1.2455 - acc: 0.5274 - val_loss: 0.6716 - val_acc: 0.6100\n",
"Epoch 2/15\n",
"22/22 [==============================] - 10s 435ms/step - loss: 0.6139 - acc: 0.6470 - val_loss: 0.6061 - val_acc: 0.6450\n",
"Epoch 3/15\n",
"22/22 [==============================] - 10s 450ms/step - loss: 0.5307 - acc: 0.7522 - val_loss: 0.6285 - val_acc: 0.6500\n",
"Epoch 4/15\n",
"22/22 [==============================] - 10s 444ms/step - loss: 0.4701 - acc: 0.7997 - val_loss: 0.5551 - val_acc: 0.7200\n",
"Epoch 5/15\n",
"22/22 [==============================] - 10s 438ms/step - loss: 0.3152 - acc: 0.8833 - val_loss: 0.5526 - val_acc: 0.7300\n",
"Epoch 6/15\n",
"22/22 [==============================] - 12s 536ms/step - loss: 0.2028 - acc: 0.9265 - val_loss: 0.7496 - val_acc: 0.7050\n",
"Epoch 7/15\n",
"22/22 [==============================] - 17s 753ms/step - loss: 0.1850 - acc: 0.9251 - val_loss: 0.7892 - val_acc: 0.6500\n",
"Epoch 8/15\n",
"22/22 [==============================] - 15s 677ms/step - loss: 0.1292 - acc: 0.9640 - val_loss: 1.0998 - val_acc: 0.6900\n",
"Epoch 9/15\n",
"22/22 [==============================] - 16s 729ms/step - loss: 0.0668 - acc: 0.9798 - val_loss: 0.9719 - val_acc: 0.7000\n",
"Epoch 10/15\n",
"22/22 [==============================] - 16s 734ms/step - loss: 0.0348 - acc: 0.9942 - val_loss: 1.0952 - val_acc: 0.6950\n",
"Epoch 11/15\n",
"22/22 [==============================] - 16s 718ms/step - loss: 0.0164 - acc: 0.9986 - val_loss: 1.3058 - val_acc: 0.7200\n",
"Epoch 12/15\n",
"22/22 [==============================] - 16s 718ms/step - loss: 0.0145 - acc: 0.9971 - val_loss: 1.3329 - val_acc: 0.7200\n",
"Epoch 13/15\n",
"22/22 [==============================] - 16s 725ms/step - loss: 0.0111 - acc: 0.9986 - val_loss: 1.4546 - val_acc: 0.7000\n",
"Epoch 14/15\n",
"22/22 [==============================] - 16s 732ms/step - loss: 0.0036 - acc: 1.0000 - val_loss: 1.6123 - val_acc: 0.7050\n",
"Epoch 15/15\n",
"22/22 [==============================] - 16s 735ms/step - loss: 0.0026 - acc: 1.0000 - val_loss: 1.6705 - val_acc: 0.7000\n"
]
}
],
"source": [
"model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(),\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['acc']\n",
")\n",
"\n",
"\n",
"history = model.fit(\n",
" train_ds,\n",
" validation_data=test_ds,\n",
" epochs=NUM_EPOCHS\n",
")"
]
},
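{
"cell_type": "code",
"execution_count": null,
"id": "4e5f6a7b",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch (added for illustration, not part of the original run): plot the\n",
"# accuracy curves recorded in `history` above. The widening gap between training\n",
"# and validation accuracy is the overfitting discussed in the next cell.\n",
"acc = history.history['acc']\n",
"val_acc = history.history['val_acc']\n",
"epochs_range = range(len(acc))\n",
"\n",
"plt.plot(epochs_range, acc, label='train acc')\n",
"plt.plot(epochs_range, val_acc, label='val acc')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('accuracy')\n",
"plt.legend()\n",
"plt.show()"
]
},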
{
"cell_type": "code",
"execution_count": 5,
"id": "7208139f",
"metadata": {},
"outputs": [],
"source": [
"# The accuracy curves above show clear overfitting: training accuracy reaches ~1.0\n",
"# while validation accuracy stalls around 0.70. Next, add Dropout layers to\n",
"# regularize the model."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "041bd4a4",
"metadata": {},
"outputs": [],
"source": [
"dropout_model = tf.keras.Sequential([\n",
" tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),\n",
" tf.keras.layers.Rescaling(1./255),\n",
" tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Dropout(0.3),\n",
" tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Dropout(0.3),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(NUM_CLASSES)\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dff898ba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"22/22 [==============================] - 19s 841ms/step - loss: 1.0004 - acc: 0.4755 - val_loss: 0.6842 - val_acc: 0.5400\n",
"Epoch 2/15\n",
"22/22 [==============================] - 19s 861ms/step - loss: 0.6597 - acc: 0.5461 - val_loss: 0.6651 - val_acc: 0.6650\n",
"Epoch 3/15\n",
"22/22 [==============================] - 18s 822ms/step - loss: 0.5992 - acc: 0.6830 - val_loss: 0.6338 - val_acc: 0.6550\n",
"Epoch 4/15\n",
"22/22 [==============================] - 19s 866ms/step - loss: 0.5560 - acc: 0.7478 - val_loss: 0.5589 - val_acc: 0.6850\n",
"Epoch 5/15\n",
"22/22 [==============================] - 19s 846ms/step - loss: 0.4456 - acc: 0.7968 - val_loss: 0.5684 - val_acc: 0.6950\n",
"Epoch 6/15\n",
"22/22 [==============================] - 19s 869ms/step - loss: 0.4187 - acc: 0.8516 - val_loss: 0.5490 - val_acc: 0.7200\n",
"Epoch 7/15\n",
"22/22 [==============================] - 19s 860ms/step - loss: 0.2824 - acc: 0.8818 - val_loss: 0.5506 - val_acc: 0.7100\n",
"Epoch 8/15\n",
"22/22 [==============================] - 20s 894ms/step - loss: 0.2260 - acc: 0.9222 - val_loss: 0.5818 - val_acc: 0.7000\n",
"Epoch 9/15\n",
"22/22 [==============================] - 20s 895ms/step - loss: 0.1209 - acc: 0.9640 - val_loss: 0.6214 - val_acc: 0.7200\n",
"Epoch 10/15\n",
"22/22 [==============================] - 19s 870ms/step - loss: 0.0976 - acc: 0.9741 - val_loss: 0.7811 - val_acc: 0.6700\n",
"Epoch 11/15\n",
"22/22 [==============================] - 20s 907ms/step - loss: 0.0605 - acc: 0.9813 - val_loss: 0.8814 - val_acc: 0.7250\n",
"Epoch 12/15\n",
"22/22 [==============================] - 19s 862ms/step - loss: 0.0546 - acc: 0.9798 - val_loss: 0.9069 - val_acc: 0.7300\n",
"Epoch 13/15\n",
"22/22 [==============================] - 20s 890ms/step - loss: 0.0386 - acc: 0.9914 - val_loss: 0.9644 - val_acc: 0.6900\n",
"Epoch 14/15\n",
"22/22 [==============================] - 20s 884ms/step - loss: 0.0475 - acc: 0.9827 - val_loss: 0.9680 - val_acc: 0.6850\n",
"Epoch 15/15\n",
"22/22 [==============================] - 20s 905ms/step - loss: 0.0415 - acc: 0.9870 - val_loss: 1.0254 - val_acc: 0.7500\n"
]
}
],
"source": [
"dropout_model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(),\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['acc']\n",
")\n",
"\n",
"\n",
"dropout_history = dropout_model.fit(\n",
" train_ds,\n",
" validation_data=test_ds,\n",
" epochs=NUM_EPOCHS\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01064e0f",
"metadata": {},
"outputs": [],
"source": [
"# The dropout model still overfits: training accuracy climbs above 0.98 while\n",
"# validation accuracy stays around 0.70-0.75."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11962da7",
"metadata": {},
"outputs": [],
"source": [
"# It might also be worth introducing the ReduceLROnPlateau and EarlyStopping\n",
"# callbacks here; a sketch follows in the next cell."
]
},
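{
"cell_type": "code",
"execution_count": null,
"id": "7c8d9e0f",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of the callbacks mentioned above (not used in the recorded runs):\n",
"# EarlyStopping stops training once val_loss stops improving and restores the best\n",
"# weights; ReduceLROnPlateau lowers the learning rate when val_loss plateaus.\n",
"# The patience and factor values below are illustrative assumptions, not tuned settings.\n",
"callbacks = [\n",
"    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3,\n",
"                                     restore_best_weights=True),\n",
"    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)\n",
"]\n",
"\n",
"# They would be passed to fit(), e.g.:\n",
"# dropout_model.fit(train_ds, validation_data=test_ds, epochs=NUM_EPOCHS, callbacks=callbacks)"
]
},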
{
"cell_type": "code",
"execution_count": 29,
"id": "0c98b8b5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 694 images belonging to 2 classes.\n",
"Found 200 images belonging to 2 classes.\n"
]
}
],
"source": [
"train_aug_ds = tf.keras.preprocessing.image.ImageDataGenerator(\n",
" brightness_range=[0.9, 1.1],\n",
" rotation_range=30,\n",
" width_shift_range=0.1,\n",
" height_shift_range=0.1,\n",
" shear_range=0.1,\n",
" zoom_range=0.1,\n",
" fill_mode=\"nearest\",\n",
" horizontal_flip=True,\n",
" vertical_flip=True,\n",
" rescale=1.0/255.0\n",
").flow_from_directory(TRAIN_DIR, batch_size=BATCH_SIZE, target_size=(IMG_SIZE, IMG_SIZE), shuffle=True, class_mode='sparse')\n",
"\n",
"\n",
"test_aug_ds = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255.0).flow_from_directory(\n",
" TEST_DIR, batch_size=BATCH_SIZE, target_size=(IMG_SIZE, IMG_SIZE), shuffle=True, class_mode='sparse'\n",
")"
]
},
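{
"cell_type": "code",
"execution_count": null,
"id": "9a0b1c2d",
"metadata": {},
"outputs": [],
"source": [
"# Hedged alternative (illustrative only, not what was run): roughly the same\n",
"# augmentations can be expressed with Keras preprocessing layers and prepended to\n",
"# the model instead of using ImageDataGenerator. The layer choices and ranges below\n",
"# are assumptions meant to mirror the generator settings above.\n",
"data_augmentation = tf.keras.Sequential([\n",
"    tf.keras.layers.RandomFlip('horizontal_and_vertical'),\n",
"    tf.keras.layers.RandomRotation(30 / 360),\n",
"    tf.keras.layers.RandomTranslation(0.1, 0.1),\n",
"    tf.keras.layers.RandomZoom(0.1)\n",
"])"
]
},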
{
"cell_type": "code",
"execution_count": 32,
"id": "5c30af59",
"metadata": {},
"outputs": [],
"source": [
"del dropout_model\n",
"\n",
"# Rebuild the dropout model without the Resizing/Rescaling layers, since the\n",
"# generators above already resize to IMG_SIZE and rescale to [0, 1].\n",
"dropout_model = tf.keras.Sequential([\n",
" tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Dropout(0.3),\n",
" tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),\n",
" tf.keras.layers.MaxPooling2D(),\n",
" tf.keras.layers.Dropout(0.3),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(NUM_CLASSES)\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "c892e485",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"22/22 [==============================] - 22s 978ms/step - loss: 1.4652 - acc: 0.4597 - val_loss: 0.6885 - val_acc: 0.5050\n",
"Epoch 2/20\n",
"22/22 [==============================] - 22s 988ms/step - loss: 0.6711 - acc: 0.5778 - val_loss: 0.6485 - val_acc: 0.6100\n",
"Epoch 3/20\n",
"22/22 [==============================] - 22s 988ms/step - loss: 0.6291 - acc: 0.6455 - val_loss: 0.6165 - val_acc: 0.6450\n",
"Epoch 4/20\n",
"22/22 [==============================] - 22s 990ms/step - loss: 0.5852 - acc: 0.6974 - val_loss: 0.7310 - val_acc: 0.5700\n",
"Epoch 5/20\n",
"22/22 [==============================] - 22s 1s/step - loss: 0.6158 - acc: 0.6657 - val_loss: 0.6325 - val_acc: 0.6400\n",
"Epoch 6/20\n",
"22/22 [==============================] - 22s 977ms/step - loss: 0.6196 - acc: 0.6671 - val_loss: 0.6182 - val_acc: 0.6500\n",
"Epoch 7/20\n",
"22/22 [==============================] - 22s 975ms/step - loss: 0.5712 - acc: 0.7075 - val_loss: 0.6478 - val_acc: 0.5900\n",
"Epoch 8/20\n",
"22/22 [==============================] - 22s 971ms/step - loss: 0.5752 - acc: 0.7032 - val_loss: 0.7221 - val_acc: 0.5600\n",
"Epoch 9/20\n",
"22/22 [==============================] - 21s 958ms/step - loss: 0.5577 - acc: 0.7248 - val_loss: 0.5665 - val_acc: 0.6950\n",
"Epoch 10/20\n",
"22/22 [==============================] - 23s 1s/step - loss: 0.5383 - acc: 0.7248 - val_loss: 0.5474 - val_acc: 0.7050\n",
"Epoch 11/20\n",
"22/22 [==============================] - 24s 1s/step - loss: 0.5226 - acc: 0.7450 - val_loss: 0.6188 - val_acc: 0.6400\n",
"Epoch 12/20\n",
"22/22 [==============================] - 22s 988ms/step - loss: 0.5254 - acc: 0.7450 - val_loss: 0.5570 - val_acc: 0.7300\n",
"Epoch 13/20\n",
"22/22 [==============================] - 22s 989ms/step - loss: 0.5132 - acc: 0.7522 - val_loss: 0.5408 - val_acc: 0.7200\n",
"Epoch 14/20\n",
"22/22 [==============================] - 22s 969ms/step - loss: 0.5049 - acc: 0.7536 - val_loss: 0.5888 - val_acc: 0.6800\n",
"Epoch 15/20\n",
"22/22 [==============================] - 21s 966ms/step - loss: 0.5174 - acc: 0.7392 - val_loss: 0.5437 - val_acc: 0.7300\n",
"Epoch 16/20\n",
"22/22 [==============================] - 22s 960ms/step - loss: 0.4946 - acc: 0.7695 - val_loss: 0.5308 - val_acc: 0.7150\n",
"Epoch 17/20\n",
"22/22 [==============================] - 22s 990ms/step - loss: 0.4957 - acc: 0.7622 - val_loss: 0.5240 - val_acc: 0.7450\n",
"Epoch 18/20\n",
"22/22 [==============================] - 14s 598ms/step - loss: 0.5104 - acc: 0.7565 - val_loss: 0.6348 - val_acc: 0.6900\n",
"Epoch 19/20\n",
"22/22 [==============================] - 12s 531ms/step - loss: 0.4926 - acc: 0.7695 - val_loss: 0.6126 - val_acc: 0.6900\n",
"Epoch 20/20\n",
"22/22 [==============================] - 12s 527ms/step - loss: 0.4826 - acc: 0.7911 - val_loss: 0.5567 - val_acc: 0.7300\n"
]
}
],
"source": [
"dropout_model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(),\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['acc']\n",
")\n",
"\n",
"\n",
"dropout_history = dropout_model.fit(\n",
" train_aug_ds,\n",
" validation_data=test_aug_ds,\n",
" epochs=NUM_EPOCHS + 5\n",
")"
]
},
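{
"cell_type": "code",
"execution_count": null,
"id": "2e3f4a5b",
"metadata": {},
"outputs": [],
"source": [
"# Hedged follow-up (not in the original notebook): evaluate the retrained model on\n",
"# the rescaled, non-augmented validation generator defined earlier. Only names\n",
"# defined above (dropout_model, test_aug_ds) are used.\n",
"test_loss, test_acc = dropout_model.evaluate(test_aug_ds)\n",
"print(f'validation accuracy: {test_acc:.3f}')"
]
}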
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}