add cosine restart learning rate #2953

Open · wants to merge 3 commits into base: devel
1 change: 1 addition & 0 deletions deepmd/common.py
@@ -125,6 +125,7 @@ def gelu_wrapper(x):
"softplus": tf.nn.softplus,
"sigmoid": tf.sigmoid,
"tanh": tf.nn.tanh,
"swish": tf.nn.swish,
Review comment (Member): It seems that it has been renamed to silu: tensorflow/tensorflow#41066. (See the compatibility sketch after this file's diff.)

"gelu": gelu,
"gelu_tf": gelu_tf,
"None": None,
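Following up on the review comment above: a minimal compatibility sketch that prefers tf.nn.silu where it exists and falls back to tf.nn.swish on older TensorFlow builds. The dictionary name ACTIVATION_FN_DICT and the swish_fn helper are assumptions for illustration (the hunk does not show the dictionary's name), not code from this PR.

import tensorflow as tf

# Prefer the newer name (tf.nn.silu); older TF builds only expose tf.nn.swish.
swish_fn = getattr(tf.nn, "silu", None) or tf.nn.swish

# Trimmed stand-in for the activation dictionary edited in deepmd/common.py.
ACTIVATION_FN_DICT = {
    "tanh": tf.nn.tanh,
    "swish": swish_fn,
}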
56 changes: 34 additions & 22 deletions deepmd/train/trainer.py
@@ -58,6 +58,8 @@
)
from deepmd.utils.learning_rate import (
LearningRateExp,
LearningRateCos,
LearningRateCosRestarts,
)
from deepmd.utils.sess import (
run_sess,
@@ -113,13 +115,21 @@
            scale_lr_coef = np.sqrt(self.run_opt.world_size).real
        else:
            scale_lr_coef = 1.0
-        lr_type = lr_param.get("type", "exp")
-        if lr_type == "exp":
+        self.lr_type = lr_param.get("type", "exp")
+        if self.lr_type == "exp":
            lr = LearningRateExp(
                lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
            )
+        elif self.lr_type == "cos":
+            lr = LearningRateCos(
+                lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
+            )
+        elif self.lr_type == "cosrestart":
+            lr = LearningRateCosRestarts(
+                lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"]
+            )
        else:
-            raise RuntimeError("unknown learning_rate type " + lr_type)
+            raise RuntimeError("unknown learning_rate type " + self.lr_type)
        return lr, scale_lr_coef

        # learning rate
@@ -553,29 +563,31 @@
        is_first_step = True
        self.cur_batch = cur_batch
        if not self.multi_task_mode:
-            log.info(
-                "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
-                % (
-                    run_sess(self.sess, self.learning_rate),
-                    self.lr.value(cur_batch),
-                    self.lr.decay_steps_,
-                    self.lr.decay_rate_,
-                    self.lr.value(stop_batch),
-                )
-            )
-        else:
-            for fitting_key in self.fitting:
+            if self.lr_type == "exp":

Review comment (Member): It's not a good behavior to switch the learning rate in the Trainer. Instead, implement the method LearningRate.log_start (LearningRate should be an abstract base class and inherited by all learning rate classes) and call self.lr.log_start(self.sess) here. (See the sketch after this file's diff.)

                log.info(
-                    "%s: start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
+                    "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
                    % (
-                        fitting_key,
-                        run_sess(self.sess, self.learning_rate_dict[fitting_key]),
-                        self.lr_dict[fitting_key].value(cur_batch),
-                        self.lr_dict[fitting_key].decay_steps_,
-                        self.lr_dict[fitting_key].decay_rate_,
-                        self.lr_dict[fitting_key].value(stop_batch),
+                        run_sess(self.sess, self.learning_rate),
+                        self.lr.value(cur_batch),
+                        self.lr.decay_steps_,
+                        self.lr.decay_rate_,
+                        self.lr.value(stop_batch),
                    )
                )
+        else:
+            for fitting_key in self.fitting:
+                if self.lr_type == "exp":
+                    log.info(
+                        "%s: start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
+                        % (
+                            fitting_key,
+                            run_sess(self.sess, self.learning_rate_dict[fitting_key]),
+                            self.lr_dict[fitting_key].value(cur_batch),
+                            self.lr_dict[fitting_key].decay_steps_,
+                            self.lr_dict[fitting_key].decay_rate_,
+                            self.lr_dict[fitting_key].value(stop_batch),
+                        )
+                    )

        prf_options = None
        prf_run_metadata = None
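A rough sketch of the refactor suggested in the review comment above: an abstract LearningRate base class whose log_start method hides the schedule-specific fields from the Trainer. Only the names log_start, build, and value come from the reviewer; the exact signature, log format, and everything else here are assumptions, not code from this PR. run_sess is the helper already imported by trainer.py in this diff.

import logging
from abc import ABC, abstractmethod

from deepmd.utils.sess import run_sess

log = logging.getLogger(__name__)


class LearningRate(ABC):
    """Hypothetical base class for all learning-rate schedules."""

    @abstractmethod
    def build(self, global_step, stop_step=None):
        """Build and return the learning-rate tensor."""

    @abstractmethod
    def value(self, step: int) -> float:
        """Return the learning rate at a given training step."""

    def log_start(self, sess, lr_tensor, cur_batch: int, stop_batch: int) -> None:
        """Log the schedule once at the start of training.

        A subclass such as LearningRateExp could override this to also print
        decay_steps_ and decay_rate_, so the Trainer only ever calls
        self.lr.log_start(...) and never branches on self.lr_type.
        """
        log.info(
            "start training at lr %.2e (== %.2e), final lr will be %.2e"
            % (run_sess(sess, lr_tensor), self.value(cur_batch), self.value(stop_batch))
        )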
31 changes: 30 additions & 1 deletion deepmd/utils/argcheck.py
@@ -1010,13 +1010,42 @@ def learning_rate_exp():
]
return args

def learning_rate_cos():
    doc_start_lr = "The learning rate at the start of the training."
doc_stop_lr = "The desired learning rate at the end of the training."
doc_decay_steps = (
"Number of steps to decay over."
)

args = [
Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
Argument("decay_steps", int, optional=True, default=100000, doc=doc_decay_steps),
]
return args

def learning_rate_cosrestarts():
    doc_start_lr = "The learning rate at the start of the training."
doc_stop_lr = "The desired learning rate at the end of the training."
    doc_decay_steps = (
        "Number of steps in the first decay period."
    )

args = [
Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
Argument("decay_steps", int, optional=True, default=10000, doc=doc_decay_steps),
]
return args

def learning_rate_variant_type_args():
doc_lr = "The type of the learning rate."

return Variant(
"type",
-        [Argument("exp", dict, learning_rate_exp())],
+        [Argument("exp", dict, learning_rate_exp()),
+         Argument("cos", dict, learning_rate_cos()),
+         Argument("cosrestart", dict, learning_rate_cosrestarts())],
Review comment (Member): You may need to add some documentation to variants (doc="xxx"). Otherwise, no one knows what they are. (See the sketch after this file's diff.)

optional=True,
default_tag="exp",
doc=doc_lr,
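Picking up the review comment above: a sketch of the same Variant with doc strings attached to each alternative, followed by an example of the learning_rate section the new arguments would accept. The wording of the doc strings and the example values (which just repeat the defaults declared above) are assumptions, not code from this PR; the sketch reuses the learning_rate_exp/learning_rate_cos/learning_rate_cosrestarts helpers defined in this file.

from dargs import Argument, Variant  # already imported at the top of argcheck.py


def learning_rate_variant_type_args():
    doc_lr = "The type of the learning rate."
    doc_exp = "Exponentially decayed learning rate."
    doc_cos = "Cosine decayed learning rate (tf.train.cosine_decay)."
    doc_cosrestart = "Cosine decayed learning rate with warm restarts (tf.train.cosine_decay_restarts)."

    return Variant(
        "type",
        [
            Argument("exp", dict, learning_rate_exp(), doc=doc_exp),
            Argument("cos", dict, learning_rate_cos(), doc=doc_cos),
            Argument("cosrestart", dict, learning_rate_cosrestarts(), doc=doc_cosrestart),
        ],
        optional=True,
        default_tag="exp",
        doc=doc_lr,
    )


# Example learning_rate section selecting the restart schedule:
example_learning_rate = {
    "type": "cosrestart",
    "start_lr": 1.0e-3,
    "stop_lr": 1.0e-8,
    "decay_steps": 10000,
}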
171 changes: 171 additions & 0 deletions deepmd/utils/learning_rate.py
@@ -105,3 +105,174 @@
def value(self, step: int) -> float:
"""Get the lr at a certain step."""
return self.start_lr_ * np.power(self.decay_rate_, (step // self.decay_steps_))

class LearningRateCos:
r"""The cosine decaying learning rate.

The function returns the decayed learning rate. It is computed as:
```python
global_step = min(global_step, decay_steps)
cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
decayed = (1 - alpha) * cosine_decay + alpha
decayed_learning_rate = learning_rate * decayed
```

Parameters
----------
start_lr
Starting learning rate
stop_lr
        The desired learning rate at the end of the decay; it is converted internally to alpha = stop_lr / start_lr.
decay_steps
Number of steps to decay over.
"""

def __init__(
self,
start_lr: float,
stop_lr: float = 5e-8,
decay_steps: int = 100000,
) -> None:
"""Constructor."""
self.cd = {}
self.cd["start_lr"] = start_lr
self.cd["stop_lr"] = stop_lr
self.cd["decay_steps"] = decay_steps
self.start_lr_ = self.cd["start_lr"]
self.alpha_ = self.cd["stop_lr"]/self.cd["start_lr"]


def build(
self, global_step: tf.Tensor, stop_step: Optional[int] = None
) -> tf.Tensor:
"""Build the learning rate.

Parameters
----------
global_step
            The tf Tensor providing the global training step
stop_step
The stop step.

Returns
-------
learning_rate
The learning rate
"""
if stop_step is None:
self.decay_steps_ = (

self.cd["decay_steps"] if self.cd["decay_steps"] is not None else 100000
)
else:
self.stop_lr_ = (

self.cd["stop_lr"] if self.cd["stop_lr"] is not None else 5e-8
)
self.decay_steps_ = (

self.cd["decay_steps"]
if self.cd["decay_steps"] is not None
else stop_step
)

return tf.train.cosine_decay(

self.start_lr_,
global_step,
self.decay_steps_,
self.alpha_,
name="cosine",
)

def start_lr(self) -> float:
"""Get the start lr."""
return self.start_lr_


def value(self, step: int) -> float:
"""Get the lr at a certain step."""
step = min(step, self.decay_steps_)
cosine_decay = 0.5 * (1 + np.cos(np.pi * step / self.decay_steps_))
decayed = (1 - self.alpha_) * cosine_decay + self.alpha_
decayed_learning_rate = self.start_lr_ * decayed
return decayed_learning_rate

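To make the schedule concrete, here is a quick NumPy check of the formula quoted in the docstring above, using the defaults declared in argcheck (start_lr=1e-3, stop_lr=1e-8, decay_steps=100000). It mirrors LearningRateCos.value and is only an illustration, not part of the PR.

import numpy as np

start_lr, stop_lr, decay_steps = 1.0e-3, 1.0e-8, 100000
alpha = stop_lr / start_lr  # same ratio that __init__ stores in self.alpha_


def cos_lr(step):
    step = min(step, decay_steps)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * step / decay_steps))
    return start_lr * ((1 - alpha) * cosine_decay + alpha)


print(cos_lr(0))                 # 1.0e-03: the start_lr
print(cos_lr(decay_steps // 2))  # ~5.0e-04: halfway, the cosine term is 0.5
print(cos_lr(decay_steps))       # 1.0e-08: the stop_lr; the value stays there afterwards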


class LearningRateCosRestarts:
r"""The cosine decaying restart learning rate.

The function returns the cosine decayed learning rate while taking into account
possible warm restarts.
```
Review comment (Member): This line should be removed.


Parameters
----------
start_lr
Starting learning rate
stop_lr
        The desired learning rate at the end of each decay period; it is converted internally to alpha = stop_lr / start_lr.
    decay_steps
        Number of steps in the first decay period.
"""

def __init__(
self,
start_lr: float,
stop_lr: float = 5e-8,
decay_steps: int = 10000,
) -> None:
"""Constructor."""
self.cd = {}
self.cd["start_lr"] = start_lr
self.cd["stop_lr"] = stop_lr
self.cd["decay_steps"] = decay_steps
self.start_lr_ = self.cd["start_lr"]
self.alpha_ = self.cd["stop_lr"]/self.cd["start_lr"]


def build(
self, global_step: tf.Tensor, stop_step: Optional[int] = None
) -> tf.Tensor:
"""Build the learning rate.

Parameters
----------
global_step
            The tf Tensor providing the global training step
stop_step
The stop step.

Returns
-------
learning_rate
The learning rate
"""
if stop_step is None:
self.decay_steps_ = (

self.cd["decay_steps"] if self.cd["decay_steps"] is not None else 10000
)
else:
self.stop_lr_ = (

self.cd["stop_lr"] if self.cd["stop_lr"] is not None else 5e-8
)
self.decay_steps_ = (

self.cd["decay_steps"]
if self.cd["decay_steps"] is not None
else stop_step
)



return tf.train.cosine_decay_restarts(

learning_rate=self.start_lr_,
global_step=global_step,
first_decay_steps=self.decay_steps_,
alpha=self.alpha_,
name="cosinerestart",
)

def start_lr(self) -> float:
"""Get the start lr."""
return self.start_lr_


    def value(self, step: int) -> float:

Review comment (Collaborator): You may not need to implement the value method if you do not print the information regarding the learning rate at the beginning of the training: https://github.com/hellozhaoming/deepmd-kit/blob/05052c195308f61b63ce2bab130ce0e8cba60604/deepmd/train/trainer.py#L566

        """Get the lr at a certain step (needs revision: restarts are not yet accounted for)."""
step = min(step, self.decay_steps_)
cosine_decay = 0.5 * (1 + np.cos(np.pi * step / self.decay_steps_))
decayed = (1 - self.alpha_) * cosine_decay + self.alpha_
decayed_learning_rate = self.start_lr_ * decayed
return decayed_learning_rate

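On the "needs revision" note above: this value method reproduces only the first decay period and ignores the warm restarts that tf.train.cosine_decay_restarts actually performs. A possible revision is sketched below; the restart arithmetic follows the TensorFlow documentation for cosine_decay_restarts with its default t_mul=2.0 and m_mul=1.0, and should be treated as an untested assumption rather than part of this PR.

import numpy as np


def cos_restart_value(step, start_lr, first_decay_steps, alpha, t_mul=2.0, m_mul=1.0):
    """NumPy mirror of tf.train.cosine_decay_restarts (sketch, TF default parameters)."""
    completed_fraction = step / first_decay_steps
    if t_mul == 1.0:
        # Every restart period has the same length.
        i_restart = np.floor(completed_fraction)
        completed_fraction -= i_restart
    else:
        # Period i has length first_decay_steps * t_mul**i; locate the current period.
        i_restart = np.floor(
            np.log(1.0 - completed_fraction * (1.0 - t_mul)) / np.log(t_mul)
        )
        sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
        completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
    cosine_decayed = 0.5 * m_mul**i_restart * (1.0 + np.cos(np.pi * completed_fraction))
    decayed = (1.0 - alpha) * cosine_decayed + alpha
    return start_lr * decayed


# With the defaults used in this PR (start_lr=1e-3, stop_lr=1e-8, decay_steps=10000):
# cos_restart_value(5000, 1e-3, 10000, alpha=1e-8 / 1e-3)   -> ~5.0e-04 (mid-period)
# cos_restart_value(10000, 1e-3, 10000, alpha=1e-8 / 1e-3)  -> 1.0e-03 (warm restart)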