Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modified: #16

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ The data set covers four measurement series. Labels are provided as destructive
### Requirements
- Python 3.10
- PyTorch 1.11.0
and the packages defined in the requirements file (```pip3 install -r requirements.txt```)
and the packages defined in the requirements file (```pip install -r requirements.txt```)
- Download the data set to a local folder

## How to train
If all packages are installed and the data set was downloaded, the training can start.
This will train the HS-CNN model on the ripeness classification of avocados:

PYTHONPATH=$PYTHONPATH:. python3 classification/train.py --data_path /folder/of/downloaded/dataset/ --model deephs_net --fruit avocado --classification_type ripeness --seed 23312323
PYTHONPATH=$PYTHONPATH:. python classification/train.py --data_path "D:\College\Project" --model deephs_net --fruit mango --classification_type ripeness --seed 23312323 --log_path "D:\College\Project\test_log_path" --num_epochs 2

<img src="images/deephs_net_loss.png" alt="Loss" style="width: 300px;"/><br>
<img src="images/deephs_net_accuracy.png" alt="Accuracy" style="width: 300px;"/><br>
Expand All @@ -51,7 +51,7 @@ And this will train HS-CNN + HyveConv++ on the same classification task:
<img src="images/hyve_confusion.png" alt="Confusion" style="width: 300px;"/><br>
**Figure 2** - Training of HS-CNN + HyveConv++:

```PYTHONPATH=$PYTHONPATH:. python3 classification/train.py --help``` provides helpful information regarding the parameters.
```PYTHONPATH=$PYTHONPATH:. python classification/train.py --help``` provides helpful information regarding the parameters.
For more information about the training framework PyTorch-Lightning, we refer to the official documentation (https://pytorch-lightning.readthedocs.io/en/latest/).


Expand Down
41 changes: 32 additions & 9 deletions classification/train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@

#get cpu num core py os
import os


import argparse

import pytorch_lightning as lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.loggers import WandbLogger
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
Expand All @@ -26,6 +31,14 @@
)
import core.util as util

import wandb
import torch






AUGMENTATION_CONFIG_TRAIN = {
'random_flip': True,
'random_rotate': True,
Expand Down Expand Up @@ -321,7 +334,7 @@ def str2bool(v):
help="the root folder of dataset")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--log_path", type=str, default=None)
parser.add_argument("--online_logging", default=False, action='store_true')
parser.add_argument("--online_logging", default=True, action='store_true') #online logging true
parser.add_argument("--debug", default=False, action='store_true')
parser.add_argument("--model_checkpoint", default=None, type=str)

Expand Down Expand Up @@ -363,8 +376,12 @@ def main(hparams):
print("Hparams: %s" % hparams)

model = DeepHsModule(hparams)
logger = WandbLogger(offline=not hparams['online_logging'], save_dir=hparams['log_path'],
project='deephs') if 'logger' not in hparams.keys() else hparams['logger']
logger = WandbLogger(
offline=not hparams.get('online_logging', True), # Use parentheses here
save_dir=hparams.get('log_path', './logs'),
project='deephs') if 'logger' not in hparams else hparams['logger']



early_stop_callback = EarlyStopping(
monitor='val/loss',
Expand All @@ -382,12 +399,17 @@ def main(hparams):
mode='min'
)





num_cpu_cores = os.cpu_count() #ใช้ cpu
trainer = lightning.Trainer(max_epochs=opt.num_epochs,
accelerator='gpu',
devices=-1,
accelerator='cpu',
devices=1,
logger=logger,
strategy='ddp',
min_epochs=50,
strategy=None,
min_epochs=2,
callbacks=[LRLoggingCallback(),
early_stop_callback,
checkpoint_callback
Expand All @@ -408,7 +430,8 @@ def main(hparams):

if __name__ == "__main__":
opt = get_args()
num_gpus = torch.cuda.device_count()
#ปิด gpu
#num_gpus = torch.cuda.device_count()

# fix the seed for reproducibility
seed = opt.seed
Expand Down
2 changes: 1 addition & 1 deletion classification/train_multi_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def main(hparams):
mode='min'
)

trainer = lightning.Trainer(max_epochs=opt.num_epochs, gpus=-1, logger=logger,
trainer = lightning.Trainer(max_epochs=opt.num_epochs, gpus=0, logger=logger,
strategy='ddp',
min_epochs=50,
callbacks=[LRLoggingCallback(),
Expand Down
11 changes: 8 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
h5py==2.10.0
#upgrade h5py ,wandb,pytorch-lightning
#hey
h5py==3.11.0
wandb==0.18.0
pytorch-lightning==1.5.10

#classic
image==1.5.32
imageio==2.8.0
matplotlib==3.5.3
Expand All @@ -8,7 +14,6 @@ opencv-python==4.6.0.66
PyExifTool==0.1.1
pyparsing==3.0.9
python-dateutil==2.8.2
pytorch-lightning==1.7.6
scikit-image==0.19.2
scikit-learn==1.0.2
scipy==1.8.0
Expand All @@ -26,6 +31,6 @@ torchvision==0.12.0
tornado==6.2
tqdm==4.64.1
urllib3==1.26.12
wandb==0.13.3



1 change: 1 addition & 0 deletions run_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

17 changes: 17 additions & 0 deletions test_gpu_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
'''
import torch
import pytorch_lightning as pl
print(torch.__version__)
print(pl.__version__)
'''
import torch
import os
#print(f"MASTER_ADDR: {os.getenv('MASTER_ADDR')}")
#print(f"MASTER_PORT: {os.getenv('MASTER_PORT')}")

#print(os.cpu_count())



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
25 changes: 25 additions & 0 deletions wandb/offline-run-20240916_141122-l9un9d3d/files/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
_wandb:
value:
cli_version: 0.18.0
m: []
python_version: 3.10.0
t:
"1":
- 1
- 5
- 9
- 11
- 41
- 53
- 55
"3":
- 4
- 23
- 55
"4": 3.10.0
"5": 0.18.0
"8":
- 3
- 5
"12": 0.18.0
"13": windows-amd64
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
97 changes: 97 additions & 0 deletions wandb/offline-run-20240916_141122-l9un9d3d/files/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
absl-py==2.1.0
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
asgiref==3.8.1
async-timeout==4.0.3
attrs==24.2.0
cachetools==5.5.0
certifi==2024.8.30
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
cycler==0.12.1
Django==5.1.1
docker-pycreds==0.4.0
filelock==3.16.0
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.9.0
future==1.0.0
gitdb==4.0.11
GitPython==3.1.43
google-auth==2.34.0
google-auth-oauthlib==0.4.6
grpcio==1.66.1
h5py==3.11.0
idna==3.9
image==1.5.32
imageio==2.8.0
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.7
lightning==2.4.0
lightning-utilities==0.11.7
Markdown==3.7
MarkupSafe==2.1.5
matplotlib==3.5.3
mpmath==1.3.0
multidict==6.1.0
networkx==3.3
numpy==1.23.3
oauthlib==3.2.2
opencv-contrib-python==4.6.0.66
opencv-python==4.6.0.66
packaging==24.1
pandas==2.2.2
pathtools==0.1.2
pillow==10.4.0
pip==21.2.3
platformdirs==4.3.3
promise==2.3
protobuf==3.19.6
psutil==6.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pyDeprecate==0.3.1
PyExifTool==0.1.1
pyparsing==3.0.9
python-dateutil==2.8.2
pytorch-lightning==1.5.10
pytz==2024.2
PyWavelets==1.7.0
PyYAML==6.0.2
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
scikit-image==0.19.2
scikit-learn==1.0.2
scipy==1.8.0
seaborn==0.11.0
sentry-sdk==2.14.0
setproctitle==1.3.2
setuptools==59.5.0
shortuuid==1.0.13
six==1.16.0
smmap==5.0.0
spectral==0.23
sqlparse==0.5.1
sympy==1.13.2
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.5.0
tifffile==2024.8.30
torch==1.11.0+cpu
torchaudio==0.11.0
torchmetrics==0.9.3
torchvision==0.12.0+cpu
tornado==6.2
tqdm==4.64.1
typing_extensions==4.12.2
tzdata==2024.1
urllib3==1.26.12
wandb==0.18.0
Werkzeug==3.0.4
wheel==0.44.0
yarl==1.11.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
{
"os": "Windows-10-10.0.22631-SP0",
"python": "3.10.0",
"startedAt": "2024-09-16T07:11:22.358661Z",
"args": [
"--data_path",
"D:\\College\\Project",
"--model",
"deephs_net",
"--fruit",
"mango",
"--classification_type",
"ripeness",
"--seed",
"23312323",
"--num_epochs",
"2"
],
"program": "C:\\Users\\intha\\deephs_fruit\\classification\\train.py",
"codePath": "classification\\train.py",
"git": {
"remote": "https://github.com/hut22929/deephs_fruit.git",
"commit": "e74580eea097ac4af5f40b10344603c5fb7e93a3"
},
"root": "C:\\Users\\intha\\deephs_fruit",
"host": "Hut",
"username": "intha",
"executable": "C:\\Users\\intha\\deephs_fruit\\venv\\Scripts\\python.exe",
"codePathLocal": "classification\\train.py",
"cpu_count": 4,
"cpu_count_logical": 8,
"disk": {
"/": {
"total": "148813901824",
"used": "146011209728"
}
},
"memory": {
"total": "12674457600"
},
"cpu": {
"count": 4,
"countLogical": 8
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
25 changes: 25 additions & 0 deletions wandb/offline-run-20240916_142954-6krtpvlu/files/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
_wandb:
value:
cli_version: 0.18.0
m: []
python_version: 3.10.0
t:
"1":
- 1
- 5
- 9
- 11
- 41
- 53
- 55
"3":
- 4
- 23
- 55
"4": 3.10.0
"5": 0.18.0
"8":
- 3
- 5
"12": 0.18.0
"13": windows-amd64
Loading