LearningRateFinder not working with CLI optimizers #16787

Open
rusmux opened this issue Feb 16, 2023 · 2 comments
Labels
bug Something isn't working tuner

Comments

Contributor

rusmux commented Feb 16, 2023

Bug description

LearningRateFinder does not update the optimizer's learning rate if the optimizer is defined via the CLI or a YAML config file.

For example, I define in train.yaml:

...
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1.5e-3
...
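
For readers unfamiliar with the `class_path` / `init_args` convention used above: the pair names a class by its dotted import path plus the keyword arguments to construct it with. The sketch below is only an illustration of that idea, not the CLI's actual implementation. The `instantiate` helper is hypothetical, and `datetime.timedelta` is a stand-in so the example runs without torch (for an optimizer, the model parameters would additionally be passed as the first positional argument).

```python
import importlib
from typing import Any


def instantiate(class_path: str, init_args: dict[str, Any], *args: Any) -> Any:
    """Hypothetical helper: import `class_path` and call it with `init_args`."""
    # Split "package.module.ClassName" into module path and class name.
    module_name, _, class_name = class_path.rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    # Extra positional arguments (e.g. model parameters) go first.
    return cls(*args, **init_args)


# Stand-in config entry with the same shape as the optimizer entry above.
entry = {"class_path": "datetime.timedelta", "init_args": {"seconds": 90}}
delay = instantiate(entry["class_path"], entry["init_args"])
print(delay)  # 0:01:30
```

The point relevant to this issue is that the values in `init_args` (here `lr: 1.5e-3`) are plain constructor arguments taken from the config.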

And I set the callback:

LearningRateFinder(update_attr=True)

At the start, it finds the best learning rate:

Screenshot 82

But after that, it still uses the learning rate I provided:

Screenshot 83

I also tried doing it manually, like this:

Screenshot 84

But I had the same result.

How to reproduce the bug

Define an optimizer in a yaml config file. Add the `LearningRateFinder` callback.

Error messages and logs

# Error messages and logs here please

Environment

Current environment
* CUDA:
	- GPU:
		- NVIDIA RTX A4000
	- available:         True
	- version:           11.7
* Lightning:
	- lightning-utilities: 0.6.0.post0
	- pytorch-lightning: 1.9.1
	- torch:             1.13.1
	- torchmetrics:      0.11.1
	- torchvision:       0.14.1
* Packages:
	- aiobotocore:       2.4.2
	- aiofiles:          22.1.0
	- aiohttp:           3.8.4
	- aiohttp-retry:     2.8.3
	- aioitertools:      0.11.0
	- aiosignal:         1.3.1
	- aiosqlite:         0.18.0
	- albumentations:    1.3.0
	- amqp:              5.1.1
	- antlr4-python3-runtime: 4.9.3
	- anyio:             3.6.2
	- appdirs:           1.4.4
	- argcomplete:       2.0.0
	- argon2-cffi:       21.3.0
	- argon2-cffi-bindings: 21.2.0
	- arrow:             1.2.3
	- astor:             0.8.1
	- asttokens:         2.2.1
	- async-timeout:     4.0.2
	- asyncssh:          2.13.0
	- atpublic:          3.1.1
	- attrs:             22.2.0
	- babel:             2.11.0
	- backcall:          0.2.0
	- bandit:            1.7.4
	- beautifulsoup4:    4.11.2
	- billiard:          3.6.4.0
	- bleach:            6.0.0
	- boto3:             1.24.59
	- botocore:          1.27.59
	- celery:            5.2.7
	- certifi:           2022.12.7
	- cffi:              1.15.1
	- cfgv:              3.3.1
	- charset-normalizer: 3.0.1
	- clearml:           1.9.1
	- click:             8.1.3
	- click-didyoumean:  0.3.0
	- click-plugins:     1.1.1
	- click-repl:        0.2.0
	- colorama:          0.4.6
	- comm:              0.1.2
	- configobj:         5.0.8
	- contourpy:         1.0.7
	- cryptography:      39.0.1
	- cycler:            0.11.0
	- dacite:            1.8.0
	- darglint:          1.8.1
	- debugpy:           1.6.6
	- decorator:         5.1.1
	- defusedxml:        0.7.1
	- deprecated:        1.2.13
	- dictdiffer:        0.9.0
	- dill:              0.3.6
	- diskcache:         5.4.0
	- distlib:           0.3.6
	- distro:            1.8.0
	- dnspython:         2.3.0
	- docstring-parser:  0.15
	- docutils:          0.19
	- dpath:             2.1.4
	- dulwich:           0.21.2
	- dvc:               2.45.0
	- dvc-data:          0.40.1
	- dvc-http:          2.30.2
	- dvc-objects:       0.19.3
	- dvc-render:        0.1.2
	- dvc-s3:            2.21.0
	- dvc-studio-client: 0.4.0
	- dvc-task:          0.1.11
	- dvclive:           2.0.2
	- eradicate:         2.1.0
	- eventlet:          0.33.3
	- exceptiongroup:    1.1.0
	- executing:         1.2.0
	- fastjsonschema:    2.16.2
	- fiftyone:          0.18.0
	- fiftyone-brain:    0.9.2
	- fiftyone-db:       0.4.0
	- filelock:          3.9.0
	- flake8:            4.0.1
	- flake8-bandit:     3.0.0
	- flake8-broken-line: 0.5.0
	- flake8-bugbear:    22.12.6
	- flake8-commas:     2.1.0
	- flake8-comprehensions: 3.10.1
	- flake8-debugger:   4.1.2
	- flake8-docstrings: 1.7.0
	- flake8-eradicate:  1.4.0
	- flake8-isort:      4.2.0
	- flake8-polyfill:   1.0.2
	- flake8-quotes:     3.3.2
	- flake8-rst-docstrings: 0.2.7
	- flake8-string-format: 0.3.0
	- flatten-dict:      0.4.2
	- flufl.lock:        7.1.1
	- fonttools:         4.38.0
	- fqdn:              1.5.1
	- frozenlist:        1.3.3
	- fsspec:            2023.1.0
	- funcy:             1.18
	- furl:              2.1.3
	- future:            0.18.3
	- gitdb:             4.0.10
	- gitpython:         3.1.30
	- glob2:             0.7
	- grandalf:          0.8
	- graphql-core:      3.2.3
	- greenlet:          2.0.2
	- h11:               0.14.0
	- h2:                4.1.0
	- hpack:             4.0.0
	- httpcore:          0.16.3
	- httpx:             0.23.3
	- huggingface-hub:   0.12.0
	- hydra-core:        1.3.1
	- hypercorn:         0.14.3
	- hyperframe:        6.0.1
	- identify:          2.5.18
	- idna:              3.4
	- imageio:           2.25.1
	- importlib-resources: 5.10.2
	- iniconfig:         2.0.0
	- ipykernel:         6.21.2
	- ipython:           8.10.0
	- ipython-genutils:  0.2.0
	- ipywidgets:        8.0.4
	- isoduration:       20.11.0
	- isort:             5.12.0
	- iterative-telemetry: 0.0.7
	- jedi:              0.18.2
	- jinja2:            3.1.2
	- jmespath:          1.0.1
	- joblib:            1.2.0
	- json5:             0.9.11
	- jsonargparse:      4.19.0
	- jsonpointer:       2.3
	- jsonschema:        4.17.3
	- jupyter-client:    8.0.2
	- jupyter-contrib-core: 0.4.2
	- jupyter-contrib-nbextensions: 0.7.0
	- jupyter-core:      5.2.0
	- jupyter-events:    0.5.0
	- jupyter-highlight-selected-word: 0.2.0
	- jupyter-nbextensions-configurator: 0.6.1
	- jupyter-server:    2.2.1
	- jupyter-server-fileid: 0.6.0
	- jupyter-server-terminals: 0.4.4
	- jupyter-server-ydoc: 0.6.1
	- jupyter-ydoc:      0.2.2
	- jupyterlab:        3.6.1
	- jupyterlab-pygments: 0.2.2
	- jupyterlab-server: 2.19.0
	- jupyterlab-widgets: 3.0.5
	- kaleido:           0.2.1
	- kiwisolver:        1.4.4
	- kombu:             5.2.4
	- lightning-utilities: 0.6.0.post0
	- lxml:              4.9.2
	- markdown-it-py:    2.1.0
	- markupsafe:        2.1.2
	- matplotlib:        3.7.0
	- matplotlib-inline: 0.1.6
	- mccabe:            0.6.1
	- mdurl:             0.1.2
	- mistune:           2.0.5
	- mongoengine:       0.24.2
	- motor:             3.1.1
	- multidict:         6.0.4
	- nanotime:          0.5.2
	- nbclassic:         0.5.1
	- nbclient:          0.7.2
	- nbconvert:         7.2.9
	- nbformat:          5.7.3
	- ndjson:            0.3.1
	- nest-asyncio:      1.5.6
	- networkx:          3.0
	- nodeenv:           1.7.0
	- notebook:          6.5.2
	- notebook-shim:     0.2.2
	- numpy:             1.24.2
	- nvidia-cublas-cu11: 11.10.3.66
	- nvidia-cuda-nvrtc-cu11: 11.7.99
	- nvidia-cuda-runtime-cu11: 11.7.99
	- nvidia-cudnn-cu11: 8.5.0.96
	- omegaconf:         2.3.0
	- onnx:              1.13.0
	- opencv-python-headless: 4.7.0.68
	- orderedmultidict:  1.0.1
	- orjson:            3.8.6
	- packaging:         23.0
	- pandas:            1.5.3
	- pandocfilters:     1.5.0
	- parso:             0.8.3
	- pathlib2:          2.3.7.post1
	- pathspec:          0.11.0
	- patool:            1.12
	- pbr:               5.11.1
	- pep8-naming:       0.13.2
	- pexpect:           4.8.0
	- pickleshare:       0.7.5
	- pillow:            9.4.0
	- pip:               23.0
	- platformdirs:      3.0.0
	- plotly:            5.13.0
	- pluggy:            1.0.0
	- pprintpp:          0.4.0
	- pre-commit:        2.21.0
	- priority:          2.0.0
	- prometheus-client: 0.16.0
	- prompt-toolkit:    3.0.36
	- protobuf:          3.20.3
	- psutil:            5.9.4
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.2
	- pycodestyle:       2.8.0
	- pycparser:         2.21
	- pydocstyle:        6.3.0
	- pydot:             1.4.2
	- pyflakes:          2.4.0
	- pygit2:            1.11.1
	- pygments:          2.14.0
	- pygtrie:           2.5.0
	- pyjwt:             2.4.0
	- pymongo:           4.3.3
	- pyparsing:         3.0.9
	- pyrsistent:        0.19.3
	- pytest:            7.2.1
	- python-dateutil:   2.8.2
	- python-json-logger: 2.0.6
	- pytorch-lightning: 1.9.1
	- pytz:              2022.7.1
	- pytz-deprecation-shim: 0.1.0.post0
	- pywavelets:        1.4.1
	- pyyaml:            6.0
	- pyzmq:             25.0.0
	- qudida:            0.0.4
	- requests:          2.28.2
	- restructuredtext-lint: 1.4.0
	- retrying:          1.3.4
	- rfc3339-validator: 0.1.4
	- rfc3986:           1.5.0
	- rfc3986-validator: 0.1.1
	- rich:              13.3.1
	- ruamel.yaml:       0.17.21
	- ruamel.yaml.clib:  0.2.7
	- s3fs:              2023.1.0
	- s3transfer:        0.6.0
	- scikit-image:      0.19.3
	- scikit-learn:      1.2.1
	- scipy:             1.10.0
	- scmrepo:           0.1.9
	- send2trash:        1.8.0
	- setuptools:        67.3.1
	- shortuuid:         1.0.11
	- shtab:             1.5.8
	- six:               1.16.0
	- smmap:             5.0.0
	- sniffio:           1.3.0
	- snowballstemmer:   2.2.0
	- sortedcontainers:  2.4.0
	- soupsieve:         2.4
	- sqltrie:           0.0.28
	- sse-starlette:     0.10.3
	- sseclient-py:      1.7.2
	- stack-data:        0.6.2
	- starlette:         0.20.4
	- stevedore:         5.0.0
	- strawberry-graphql: 0.138.1
	- tabulate:          0.9.0
	- tenacity:          8.2.1
	- tensorboardx:      2.6
	- terminado:         0.17.1
	- threadpoolctl:     3.1.0
	- tifffile:          2023.2.3
	- timm:              0.6.12
	- tinycss2:          1.2.1
	- toml:              0.10.2
	- tomli:             2.0.1
	- tomlkit:           0.11.6
	- torch:             1.13.1
	- torchmetrics:      0.11.1
	- torchvision:       0.14.1
	- tornado:           6.2
	- tqdm:              4.64.1
	- traitlets:         5.9.0
	- typeshed-client:   2.2.0
	- typing-extensions: 4.5.0
	- tzdata:            2022.7
	- tzlocal:           4.2
	- universal-analytics-python3: 1.1.1
	- uri-template:      1.2.0
	- urllib3:           1.26.14
	- vine:              5.0.0
	- virtualenv:        20.19.0
	- voluptuous:        0.13.1
	- voxel51-eta:       0.8.3
	- wcwidth:           0.2.6
	- webcolors:         1.12
	- webencodings:      0.5.1
	- websocket-client:  1.5.1
	- wemake-python-styleguide: 0.17.0
	- wheel:             0.38.4
	- widgetsnbextension: 4.0.5
	- wrapt:             1.14.1
	- wsproto:           1.2.0
	- xmltodict:         0.13.0
	- y-py:              0.5.5
	- yarl:              1.8.2
	- ypy-websocket:     0.8.2
	- zc.lockfile:       2.0
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		-
	- processor:
	- python:            3.10.10
	- version:           #152-Ubuntu SMP Wed Nov 23 20:19:22 UTC 2022

More info

I think the problem lies in how and when optimizers and schedulers are instantiated. I ran the same code, but for the batch size instead of the learning rate, and it worked as expected:

Screenshot 85

It used the found batch size in training.

For now, as far as I understand, the way to use LearningRateFinder is to define configure_optimizers() manually in the LightningModule. But then I can't change the optimizer from the YAML config file.
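
The hypothesis in this post can be illustrated with a toy model in plain Python. This is a simplified sketch and not Lightning's code: `ToyOptimizer`, `CliConfiguredModel`, `ManuallyConfiguredModel`, and `toy_lr_finder` are all hypothetical names, and the suggested value `3e-4` is made up. It assumes that the finder only writes its suggestion to an attribute on the model, while a config-defined optimizer is built from the arguments captured from the config.

```python
from dataclasses import dataclass


@dataclass
class ToyOptimizer:
    """Stand-in for an optimizer; it only remembers its learning rate."""
    lr: float


class CliConfiguredModel:
    """Toy model whose optimizer arguments were fixed when the config was parsed."""

    def __init__(self, config_lr: float) -> None:
        self.lr = config_lr
        self._optimizer_init_args = {"lr": config_lr}  # captured once

    def configure_optimizers(self) -> ToyOptimizer:
        # Builds from the captured arguments and never reads self.lr.
        return ToyOptimizer(**self._optimizer_init_args)


class ManuallyConfiguredModel:
    """Toy model with a hand-written configure_optimizers that reads self.lr."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def configure_optimizers(self) -> ToyOptimizer:
        return ToyOptimizer(lr=self.lr)


def toy_lr_finder(model, suggested_lr: float) -> None:
    """Toy version of update_attr=True: store the suggestion on the model."""
    model.lr = suggested_lr


cli_model = CliConfiguredModel(1.5e-3)
manual_model = ManuallyConfiguredModel(1.5e-3)
for model in (cli_model, manual_model):
    toy_lr_finder(model, suggested_lr=3e-4)  # made-up suggestion

print(cli_model.configure_optimizers().lr)     # 0.0015 (config value kept)
print(manual_model.configure_optimizers().lr)  # 0.0003 (suggestion used)
```

Under these assumptions, only the model that reads its own `lr` attribute picks up the suggestion, which matches the behaviour described above.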

@rusmux rusmux added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Feb 16, 2023

weicao1990 commented Mar 8, 2023

Hi, I also ran into this issue. My solution is to add a before_fit method to your customized CLI class.

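def before_fit(self):
    # Run the LR finder before fit() starts.
    # Tuner must be imported from Lightning; the import path depends on the version.
    tuner = Tuner(self.trainer)
    tuner.lr_find(self.model, datamodule=self.datamodule)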

This way, PL executes configure_optimizers after the optimal LR has been found. If we use the LearningRateFinder callback instead, configure_optimizers is not executed again after the optimal LR is found.


stale bot commented Apr 13, 2023

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Apr 13, 2023
@awaelchli awaelchli added tuner and removed won't fix This will not be worked on needs triage Waiting to be triaged by maintainers labels Nov 25, 2023