Commit 2fde026

updates

sayakpaul committed Nov 28, 2024
1 parent 58a0632 commit 2fde026

Showing 9 changed files with 466 additions and 1,375 deletions.
150 changes: 31 additions & 119 deletions training/mochi-1/args.py
@@ -28,6 +28,11 @@ def _get_model_args(parser: argparse.ArgumentParser) -> None:
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
+    parser.add_argument(
+        "--cast_dit",
+        action="store_true",
+        help="If we should cast DiT params to a lower precision.",
+    )


def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
@@ -38,58 +43,12 @@ def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
help=("A folder containing the training data."),
)
parser.add_argument(
"--dataset_file",
type=str,
default=None,
help=("Path to a CSV file if loading prompts/video paths using this format."),
)
parser.add_argument(
"--video_column",
type=str,
default="video",
help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
)
parser.add_argument(
"--id_token",
type=str,
default=None,
help="Identifier token appended to the start of each prompt if provided.",
)
parser.add_argument(
"--height_buckets",
nargs="+",
type=int,
default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
)
parser.add_argument(
"--width_buckets",
nargs="+",
type=int,
default=[256, 320, 384, 480, 512, 576, 720, 768, 848, 960, 1024, 1280, 1536],
)
parser.add_argument(
"--frame_buckets",
nargs="+",
type=int,
default=[84],
)
parser.add_argument(
"--load_tensors",
action="store_true",
help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.",
)
parser.add_argument(
"--random_flip",
"--caption_dropout",
type=float,
default=None,
help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
help=("Probability to drop out captions randomly."),
)
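Aside: caption dropout is conventionally implemented by blanking the prompt with the given probability, so the model also trains on unconditional examples; a sketch under that assumption (the dataset code is not shown in this diff):

    import random

    def maybe_drop_caption(caption: str, dropout_prob: float) -> str:
        # With probability `dropout_prob`, train on an empty caption.
        if dropout_prob and random.random() < dropout_prob:
            return ""
        return caption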

parser.add_argument(
"--dataloader_num_workers",
type=int,
@@ -140,23 +99,39 @@ def _get_validation_args(parser: argparse.ArgumentParser) -> None:
default=False,
help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
)
+    parser.add_argument(
+        "--fps",
+        type=int,
+        default=30,
+        help="FPS to use when serializing the output videos.",
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=480,
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=848,
+    )


def _get_training_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.")
parser.add_argument(
"--lora_alpha",
type=int,
-        default=64,
+        default=16,
help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
)
parser.add_argument(
"--target_modules",
nargs="+",
type=str,
default=["to_k", "to_q", "to_v", "to_out.0"],
help="Target modules to train LoRA for."
help="Target modules to train LoRA for.",
)
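Aside: these three options map directly onto a PEFT-style LoRA configuration; a sketch assuming the `peft` library is used downstream:

    from peft import LoraConfig

    # The effective LoRA scale is lora_alpha / rank; with the new
    # defaults (16 / 16) that scale is 1.0.
    lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        target_modules=args.target_modules,
    )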
parser.add_argument(
"--mixed_precision",
@@ -175,43 +150,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None:
default="mochi-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
-    parser.add_argument(
-        "--height",
-        type=int,
-        default=480,
-        help="All input videos are resized to this height.",
-    )
-    parser.add_argument(
-        "--width",
-        type=int,
-        default=848,
-        help="All input videos are resized to this width.",
-    )
-    parser.add_argument(
-        "--video_reshape_mode",
-        type=str,
-        default=None,
-        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
-    )
-    parser.add_argument("--fps", type=int, default=30, help="All input videos will be used at this FPS.")
-    parser.add_argument(
-        "--max_num_frames",
-        type=int,
-        default=84,
-        help="All input videos will be truncated to these many frames.",
-    )
-    parser.add_argument(
-        "--skip_frames_start",
-        type=int,
-        default=0,
-        help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
-    )
-    parser.add_argument(
-        "--skip_frames_end",
-        type=int,
-        default=0,
-        help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
-    )
parser.add_argument(
"--train_batch_size",
type=int,
@@ -256,25 +194,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None:
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
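Aside: a sketch of how `--gradient_accumulation_steps` is commonly honored with Hugging Face Accelerate (assumed here; the training loop is outside this diff, and `compute_loss` is a hypothetical helper):

    # `accelerator` is assumed to be constructed with
    # gradient_accumulation_steps=args.gradient_accumulation_steps.
    for batch in dataloader:
        with accelerator.accumulate(transformer):
            loss = compute_loss(batch)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()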
-    parser.add_argument(
-        "--weighting_scheme",
-        type=str,
-        default="none",
-        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
-        help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
-    )
-    parser.add_argument(
-        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
-    )
-    parser.add_argument(
-        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
-    )
-    parser.add_argument(
-        "--mode_scale",
-        type=float,
-        default=1.29,
-        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
-    )
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
@@ -283,19 +202,18 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--learning_rate",
type=float,
-        default=1e-4,
+        default=2e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
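Aside: when `--scale_lr` is set, the conventional scaling mirrors the help text above (a sketch; `accelerator` is an assumed Accelerate object):

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )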
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
default="cosine",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
@@ -304,7 +222,7 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--lr_warmup_steps",
type=int,
-        default=500,
+        default=200,
help="Number of steps for the warmup in the lr scheduler.",
)
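Aside: `--lr_scheduler` and `--lr_warmup_steps` typically flow into diffusers' `get_scheduler` helper; a sketch under that assumption (`optimizer` and `args.max_train_steps` are assumed to exist elsewhere):

    from diffusers.optimization import get_scheduler

    lr_scheduler = get_scheduler(
        args.lr_scheduler,  # now defaults to "cosine"
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,  # now defaults to 200
        num_training_steps=args.max_train_steps,
    )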
parser.add_argument(
@@ -331,12 +249,6 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None:
default=False,
help="Whether or not to use VAE tiling for saving memory.",
)
-    parser.add_argument(
-        "--noised_image_dropout",
-        type=float,
-        default=0.05,
-        help="Image condition dropout probability when finetuning image-to-video.",
-    )


def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
@@ -386,7 +298,7 @@ def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--weight_decay",
type=float,
-        default=1e-04,
+        default=0.01,
help="Weight decay to use for optimizer.",
)
parser.add_argument(
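Aside: taken together, the changed optimization defaults (learning rate 2e-4, cosine schedule with 200 warmup steps, weight decay 0.01) would plug into an optimizer roughly like this (a sketch; the optimizer construction lives outside this diff, and `trainable_params` is a hypothetical variable, e.g. the LoRA parameters):

    import torch

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.learning_rate,           # 2e-4 by default now
        weight_decay=args.weight_decay,  # 0.01 by default now
    )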
