Change default batch_size for finetuning to max_batch_size for a model #189

Merged Sep 25, 2024 (9 commits)
Changes from all commits
2 changes: 1 addition & 1 deletion README.md
@@ -242,7 +242,7 @@ client.fine_tuning.create(
model = 'mistralai/Mixtral-8x7B-Instruct-v0.1',
n_epochs = 3,
n_checkpoints = 1,
-batch_size = 4,
+batch_size = "max",
learning_rate = 1e-5,
suffix = 'my-demo-finetune',
wandb_api_key = '1a2b3c4d5e.......',
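
For context, a minimal sketch of what the new default means for SDK users (file ID and model name are illustrative): leaving batch_size out is now equivalent to passing "max", which the client resolves to the largest batch size the model supports.

from together import Together

client = Together()

# Equivalent to omitting batch_size entirely: "max" is resolved against the
# model's training limits before the job is submitted.
job = client.fine_tuning.create(
    training_file="file-abc123",
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    batch_size="max",
)
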
64 changes: 55 additions & 9 deletions src/together/cli/api/finetune.py
@@ -3,15 +3,17 @@
import json
from datetime import datetime
from textwrap import wrap
from typing import Any, Literal

import click
from click.core import ParameterSource # type: ignore[attr-defined]
from rich import print as rprint
from tabulate import tabulate

from together import Together
-from together.types.finetune import DownloadCheckpointType
from together.cli.api.utils import INT_WITH_MAX
from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
+from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits


_CONFIRMATION_MESSAGE = (
@@ -56,7 +58,7 @@ def fine_tuning(ctx: click.Context) -> None:
@click.option(
"--n-checkpoints", type=int, default=1, help="Number of checkpoints to save"
)
@click.option("--batch-size", type=int, default=16, help="Train batch size")
@click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
@click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
@click.option(
"--lora/--no-lora",
@@ -93,7 +95,7 @@ def create(
n_epochs: int,
n_evals: int,
n_checkpoints: int,
-batch_size: int,
+batch_size: int | Literal["max"],
learning_rate: float,
lora: bool,
lora_r: int,
@@ -107,20 +109,64 @@
"""Start fine-tuning"""
client: Together = ctx.obj

training_args: dict[str, Any] = dict(
training_file=training_file,
model=model,
n_epochs=n_epochs,
validation_file=validation_file,
n_evals=n_evals,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
learning_rate=learning_rate,
lora=lora,
lora_r=lora_r,
lora_dropout=lora_dropout,
lora_alpha=lora_alpha,
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
)

model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
model=model
)

if lora:
-learning_rate_source = click.get_current_context().get_parameter_source(  # type: ignore[attr-defined]
-"learning_rate"
-)
-if learning_rate_source == ParameterSource.DEFAULT:
-learning_rate = 1e-3
if model_limits.lora_training is None:
raise click.BadParameter(
f"LoRA fine-tuning is not supported for the model `{model}`"
)

default_values = {
"lora_r": model_limits.lora_training.max_rank,
"batch_size": model_limits.lora_training.max_batch_size,
"learning_rate": 1e-3,
}
for arg in default_values:
arg_source = ctx.get_parameter_source(arg)  # type: ignore[attr-defined]
if arg_source == ParameterSource.DEFAULT:
training_args[arg] = default_values[arg]

if ctx.get_parameter_source("lora_alpha") == ParameterSource.DEFAULT: # type: ignore[attr-defined]
training_args["lora_alpha"] = training_args["lora_r"] * 2
else:
if model_limits.full_training is None:
raise click.BadParameter(
f"Full fine-tuning is not supported for the model `{model}`"
)

for param in ["lora_r", "lora_dropout", "lora_alpha", "lora_trainable_modules"]:
-param_source = click.get_current_context().get_parameter_source(param)  # type: ignore[attr-defined]
+param_source = ctx.get_parameter_source(param)  # type: ignore[attr-defined]
if param_source != ParameterSource.DEFAULT:
raise click.BadParameter(
f"You set LoRA parameter `{param}` for a full fine-tuning job. "
f"Please change the job type with --lora or remove `{param}` from the arguments"
)

batch_size_source = ctx.get_parameter_source("batch_size") # type: ignore[attr-defined]
if batch_size_source == ParameterSource.DEFAULT:
training_args["batch_size"] = model_limits.full_training.max_batch_size

if n_evals <= 0 and validation_file:
log_warn(
"Warning: You have specified a validation file but the number of evaluation loops is set to 0. No evaluations will be performed."
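
The default handling above relies on click's ParameterSource to tell user-supplied values apart from option defaults. A standalone toy sketch of the same pattern (command and option names are hypothetical):

import click
from click.core import ParameterSource  # type: ignore[attr-defined]


@click.command()
@click.option("--batch-size", default="max", help="Train batch size")
def train(batch_size: str) -> None:
    ctx = click.get_current_context()
    # DEFAULT means the user did not pass --batch-size, so the CLI is free to
    # substitute a model-specific value such as max_batch_size.
    if ctx.get_parameter_source("batch_size") == ParameterSource.DEFAULT:
        click.echo("batch size not given; using the model's maximum")
    else:
        click.echo(f"user-provided batch size: {batch_size}")


if __name__ == "__main__":
    train()
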
21 changes: 21 additions & 0 deletions src/together/cli/api/utils.py
@@ -0,0 +1,21 @@
import click

from typing import Literal


class AutoIntParamType(click.ParamType):
name = "integer"

def convert(
self, value: str, param: click.Parameter | None, ctx: click.Context | None
) -> int | Literal["max"] | None:
if isinstance(value, int):
return value

if value == "max":
return "max"

self.fail("Invalid integer value: {value}")


INT_WITH_MAX = AutoIntParamType()
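
For illustration, a small usage sketch of the new param type (the demo command is hypothetical; assumes convert falls back to int(value) for plain numbers, as in the fix above):

import click

from together.cli.api.utils import INT_WITH_MAX


@click.command()
@click.option("--batch-size", type=INT_WITH_MAX, default="max")
def demo(batch_size) -> None:
    # Prints 8 for `--batch-size 8` and 'max' when the flag is omitted
    # or given as `--batch-size max`.
    click.echo(repr(batch_size))


if __name__ == "__main__":
    demo()
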
4 changes: 2 additions & 2 deletions src/together/legacy/finetune.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal

import together
from together.legacy.base import API_KEY_WARNING, deprecated
@@ -43,7 +43,7 @@ def create(
model=model,
n_epochs=n_epochs,
n_checkpoints=n_checkpoints,
-batch_size=batch_size,
+batch_size=batch_size if isinstance(batch_size, int) else "max",
learning_rate=learning_rate,
suffix=suffix,
wandb_api_key=wandb_api_key,
110 changes: 99 additions & 11 deletions src/together/resources/finetune.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Literal

from rich import print as rprint

@@ -13,14 +14,15 @@
FinetuneListEvents,
FinetuneRequest,
FinetuneResponse,
FinetuneTrainingLimits,
FullTrainingType,
LoRATrainingType,
TogetherClient,
TogetherRequest,
TrainingType,
)
from together.types.finetune import DownloadCheckpointType
-from together.utils import log_warn, normalize_key
+from together.utils import log_warn_once, normalize_key


class FineTuning:
@@ -36,16 +38,17 @@ def create(
validation_file: str | None = "",
n_evals: int | None = 0,
n_checkpoints: int | None = 1,
-batch_size: int | None = 16,
+batch_size: int | Literal["max"] = "max",
learning_rate: float | None = 0.00001,
lora: bool = False,
-lora_r: int | None = 8,
+lora_r: int | None = None,
lora_dropout: float | None = 0,
-lora_alpha: float | None = 8,
+lora_alpha: float | None = None,
lora_trainable_modules: str | None = "all-linear",
suffix: str | None = None,
wandb_api_key: str | None = None,
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
) -> FinetuneResponse:
"""
Method to initiate a fine-tuning job
@@ -58,7 +61,7 @@
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
Defaults to 1.
-batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
+batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
learning_rate (float, optional): Learning rate multiplier to use for training
Defaults to 0.00001.
lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
@@ -72,24 +75,59 @@
Defaults to None.
verbose (bool, optional): whether to print the job parameters before submitting a request.
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Hyperparameter limits for the model being fine-tuned.
Defaults to None.

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
"""

if batch_size == "max":
log_warn_once(
"Starting from together>=1.3.0, "
"the default batch size is set to the maximum allowed value for each model."
)

requestor = api_requestor.APIRequestor(
client=self._client,
)

if model_limits is None:
model_limits = self.get_model_limits(model=model)

training_type: TrainingType = FullTrainingType()
if lora:
if model_limits.lora_training is None:
raise ValueError(
"LoRA adapters are not supported for the selected model."
)
lora_r = (
lora_r if lora_r is not None else model_limits.lora_training.max_rank
)
lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
training_type = LoRATrainingType(
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_trainable_modules=lora_trainable_modules,
)

batch_size = (
batch_size
if batch_size != "max"
else model_limits.lora_training.max_batch_size
)
else:
if model_limits.full_training is None:
raise ValueError(
"Full training is not supported for the selected model."
)
Comment on lines +120 to +124
Member
It looks like you have duplicated validation logic here and in cli/finetune.py. Maybe it's best to extract it to a function, call it in cli/finetune.py, and re-raise the exception as click.BadParameter if necessary?

Contributor Author
This logic is mostly for mypy: without this check, it reports that full_training may be None in the following lines.
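
For illustration, one possible shape of the shared helper discussed here (hypothetical name, not part of this PR); the CLI could catch the ValueError and re-raise it as click.BadParameter:

from together.types import FinetuneTrainingLimits


def check_fine_tuning_support(
    model: str, model_limits: FinetuneTrainingLimits, lora: bool
) -> None:
    # Shared validation usable by both the SDK and the CLI wrapper.
    if lora and model_limits.lora_training is None:
        raise ValueError(f"LoRA fine-tuning is not supported for the model `{model}`")
    if not lora and model_limits.full_training is None:
        raise ValueError(f"Full fine-tuning is not supported for the model `{model}`")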

batch_size = (
batch_size
if batch_size != "max"
else model_limits.full_training.max_batch_size
)

finetune_request = FinetuneRequest(
model=model,
training_file=training_file,
@@ -121,12 +159,6 @@ def create(

assert isinstance(response, TogetherResponse)

-# TODO: Remove after next LoRA default change
-log_warn(
-"Some of the jobs run _directly_ from the together-python library might be trained using LoRA adapters. "
-"The version range when this change occurred is from 1.2.3 to 1.2.6."
-)

return FinetuneResponse(**response.data)

def list(self) -> FinetuneList:
@@ -305,6 +337,34 @@ def download(
size=file_size,
)

def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
"""
Requests training limits for a specific model

Args:
model (str): Name of the model to get limits for

Returns:
FinetuneTrainingLimits: Object containing training limits for the model
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)

model_limits_response, _, _ = requestor.request(
options=TogetherRequest(
method="GET",
url="fine-tunes/models/limits",
params={"model_name": model},
),
stream=False,
)

model_limits = FinetuneTrainingLimits(**model_limits_response.data)

return model_limits
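
A short usage sketch of this method (model name illustrative); these limits are what create() consults when batch_size is "max":

from together import Together

client = Together()

limits = client.fine_tuning.get_model_limits(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1"
)

# What "max" resolves to for each training type, when supported.
if limits.full_training is not None:
    print("full fine-tuning max batch size:", limits.full_training.max_batch_size)
if limits.lora_training is not None:
    print("LoRA max batch size:", limits.lora_training.max_batch_size)
    print("LoRA max rank:", limits.lora_training.max_rank)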


class AsyncFineTuning:
def __init__(self, client: TogetherClient) -> None:
@@ -493,3 +553,31 @@ async def download(
"AsyncFineTuning.download not implemented. "
"Please use FineTuning.download function instead."
)

async def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
"""
Requests training limits for a specific model

Args:
model (str): Name of the model to get limits for

Returns:
FinetuneTrainingLimits: Object containing training limits for the model
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)

model_limits_response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="GET",
url="fine-tunes/models/limits",
params={"model": model},
),
stream=False,
)

model_limits = FinetuneTrainingLimits(**model_limits_response.data)

return model_limits
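
The async counterpart can be exercised the same way; a minimal sketch (assumes the AsyncTogether client, which exposes the same fine_tuning resource):

import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    limits = await client.fine_tuning.get_model_limits(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    print(limits.max_num_epochs, limits.min_learning_rate, limits.max_learning_rate)


asyncio.run(main())
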
2 changes: 2 additions & 0 deletions src/together/types/__init__.py
@@ -29,6 +29,7 @@
FullTrainingType,
LoRATrainingType,
TrainingType,
FinetuneTrainingLimits,
)
from together.types.images import (
ImageRequest,
@@ -71,4 +72,5 @@
"LoRATrainingType",
"RerankRequest",
"RerankResponse",
"FinetuneTrainingLimits",
]
18 changes: 18 additions & 0 deletions src/together/types/finetune.py
@@ -263,3 +263,21 @@ class FinetuneDownloadResult(BaseModel):
filename: str | None = None
# size in bytes
size: int | None = None


class FinetuneFullTrainingLimits(BaseModel):
max_batch_size: int
min_batch_size: int


class FinetuneLoraTrainingLimits(FinetuneFullTrainingLimits):
max_rank: int
target_modules: List[str]


class FinetuneTrainingLimits(BaseModel):
max_num_epochs: int
max_learning_rate: float
min_learning_rate: float
full_training: FinetuneFullTrainingLimits | None = None
lora_training: FinetuneLoraTrainingLimits | None = None
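
For reference, a sketch of how a limits payload would map onto these models (all field values invented for illustration):

from together.types import FinetuneTrainingLimits

payload = {
    "max_num_epochs": 20,
    "max_learning_rate": 1.0,
    "min_learning_rate": 1e-6,
    "full_training": {"max_batch_size": 96, "min_batch_size": 8},
    "lora_training": {
        "max_batch_size": 128,
        "min_batch_size": 8,
        "max_rank": 64,
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    },
}

limits = FinetuneTrainingLimits(**payload)
assert limits.lora_training is not None
print(limits.lora_training.max_batch_size)  # 128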