update long-vid configs
yingqinghe committed Aug 21, 2023
1 parent 2303e6a commit 51bdbf6
Showing 3 changed files with 294 additions and 4 deletions.
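Both configs below follow the `target`/`params` convention of latent-diffusion codebases: `target` names a class by dotted import path, and `params` are its constructor kwargs. A minimal sketch of the conventional helper that resolves such blocks (lvdm is assumed to ship an equivalent; this is illustrative, not copied from the repo):

import importlib

def get_obj_from_str(string: str):
    # resolve a dotted path like "lvdm.models.ddpm3d.FrameInterpPredLatentDiffusion"
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    # every `target:` block in the configs below is built this way:
    # the class named by `target` is called with `params` as kwargs
    if "target" not in config:
        raise KeyError("expected a `target` key to instantiate")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))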
144 changes: 144 additions & 0 deletions configs/lvdm_long/sky_interp.yaml
@@ -0,0 +1,144 @@
model:
  base_learning_rate: 8.0e-5 # 1.5e-04
  scale_lr: False
  target: lvdm.models.ddpm3d.FrameInterpPredLatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0155
    log_every_t: 200
    timesteps: 1000
    loss_type: l1
    first_stage_key: image
    cond_stage_key: "image"
    image_size: 32
    channels: 4
    monitor: val/loss_simple_ema
    conditioning_key: concat-adm-mask
    cond_stage_config: null
    noisy_cond: True
    max_noise_level: 250
    cond_stage_trainable: False
    concat_mode: False
    scale_by_std: False
    scale_factor: 0.33422927
    shift_factor: 1.4606637
    encoder_type: 3d
    rand_temporal_mask: true
    p_interp: 0.9
    p_pred: 0.0
    n_prevs: null
    split_clips: False
    downfactor_t: null # used to split video frames into clips before encoding
    clip_length: null

    unet_config:
      target: lvdm.models.modules.openaimodel3d.FrameInterpPredUNet
      params:
        num_classes: 251 # timesteps for noise conditioning (max_noise_level + 1)
        image_size: 32
        in_channels: 5
        out_channels: 4
        model_channels: 256
        attention_resolutions:
          - 8
          - 4
          - 2
        num_res_blocks: 3
        channel_mult:
          - 1
          - 2
          - 3
          - 4
        num_heads: 4
        use_temporal_transformer: False
        use_checkpoint: true
        legacy: False
        # temporal
        kernel_size_t: 1
        padding_t: 0
        temporal_length: 5
        use_relative_position: True
        use_scale_shift_norm: True
    first_stage_config:
      target: lvdm.models.autoencoder3d.AutoencoderKL
      params:
        monitor: "val/rec_loss"
        embed_dim: 4
        lossconfig: __is_first_stage__
        ddconfig:
          double_z: True
          z_channels: 4
          encoder:
            target: lvdm.models.modules.aemodules3d.Encoder
            params:
              n_hiddens: 32
              downsample: [4, 8, 8]
              image_channel: 3
              norm_type: group
              padding_type: replicate
              double_z: True
              z_channels: 4
          decoder:
            target: lvdm.models.modules.aemodules3d.Decoder
            params:
              n_hiddens: 32
              upsample: [4, 8, 8]
              z_channels: 4
              image_channel: 3
              norm_type: group

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 0
    wrap: false
    train:
      target: lvdm.data.frame_dataset.VideoFrameDataset
      params:
        data_root: /dockerdata/sky_timelapse
        resolution: 256
        video_length: 20
        dataset_name: sky
        subset_split: train
        spatial_transform: center_crop_resize
        clip_step: 1
        temporal_transform: rand_clips
    validation:
      target: lvdm.data.frame_dataset.VideoFrameDataset
      params:
        data_root: /dockerdata/sky_timelapse
        resolution: 256
        video_length: 20
        dataset_name: sky
        subset_split: test
        spatial_transform: center_crop_resize
        clip_step: 1
        temporal_transform: rand_clips

lightning:
  callbacks:
    image_logger:
      target: lvdm.utils.callbacks.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 8
        increase_log_steps: False
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: "{epoch:06}-{step:09}"
        save_weights_only: False
        every_n_epochs: 200
        every_n_train_steps: null
  trainer:
    benchmark: True
    batch_size: 2
    num_workers: 0
    num_nodes: 1
    max_epochs: 2000
  modelcheckpoint:
    target: pytorch_lightning.callbacks.ModelCheckpoint
    params:
      every_n_epochs: 1
      filename: "{epoch:04}-{step:06}"
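The `rand_temporal_mask` settings above (`p_interp: 0.9`, `p_pred: 0.0`, `n_prevs: null`) decide, per training sample, which frames act as conditioning. A minimal sketch of the sampling logic this implies, assuming 1 marks a conditioning frame and 0 a frame to be denoised; the function and its exact semantics are assumptions, not the repo's API:

import random
import torch

def sample_temporal_mask(t, p_interp, p_pred, n_prevs):
    # 1 = frame given as (noisy) conditioning, 0 = frame to denoise
    mask = torch.zeros(t)
    u = random.random()
    if u < p_interp:
        # interpolation task: endpoints given, middle frames generated
        mask[0] = mask[-1] = 1.0
    elif u < p_interp + p_pred and n_prevs:
        # prediction task: the first n previous frames given as context
        mask[: random.choice(n_prevs)] = 1.0
    # otherwise: fully unconditional, mask stays all zeros
    return mask

# with p_interp=0.9 the interp model trains mostly on endpoint interpolation
print(sample_temporal_mask(5, p_interp=0.9, p_pred=0.0, n_prevs=None))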
145 changes: 145 additions & 0 deletions configs/lvdm_long/sky_pred.yaml
@@ -0,0 +1,145 @@
model:
  base_learning_rate: 8.0e-5 # 1.5e-04
  scale_lr: False
  target: lvdm.models.ddpm3d.FrameInterpPredLatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0155
    log_every_t: 200
    timesteps: 1000
    loss_type: l1
    first_stage_key: image
    cond_stage_key: "image"
    image_size: 32
    channels: 4
    monitor: val/loss_simple_ema
    conditioning_key: concat-adm-mask
    cond_stage_config: null
    noisy_cond: True
    max_noise_level: 250
    cond_stage_trainable: False
    concat_mode: False
    scale_by_std: False
    scale_factor: 0.33422927
    shift_factor: 1.4606637
    encoder_type: 3d
    rand_temporal_mask: true
    p_interp: 0.0
    p_pred: 0.5
    n_prevs: [1,]
    split_clips: False
    downfactor_t: null # used to split video frames into clips before encoding
    clip_length: null
    latent_frame_stride: 4

    unet_config:
      target: lvdm.models.modules.openaimodel3d.FrameInterpPredUNet
      params:
        num_classes: 251 # timesteps for noise conditioning (max_noise_level + 1)
        image_size: 32
        in_channels: 5
        out_channels: 4
        model_channels: 256
        attention_resolutions:
          - 8
          - 4
          - 2
        num_res_blocks: 3
        channel_mult:
          - 1
          - 2
          - 3
          - 4
        num_heads: 4
        use_temporal_transformer: False
        use_checkpoint: true
        legacy: False
        # temporal
        kernel_size_t: 1
        padding_t: 0
        temporal_length: 4
        use_relative_position: True
        use_scale_shift_norm: True
    first_stage_config:
      target: lvdm.models.autoencoder3d.AutoencoderKL
      params:
        monitor: "val/rec_loss"
        embed_dim: 4
        lossconfig: __is_first_stage__
        ddconfig:
          double_z: True
          z_channels: 4
          encoder:
            target: lvdm.models.modules.aemodules3d.Encoder
            params:
              n_hiddens: 32
              downsample: [4, 8, 8]
              image_channel: 3
              norm_type: group
              padding_type: replicate
              double_z: True
              z_channels: 4
          decoder:
            target: lvdm.models.modules.aemodules3d.Decoder
            params:
              n_hiddens: 32
              upsample: [4, 8, 8]
              z_channels: 4
              image_channel: 3
              norm_type: group

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 0
    wrap: false
    train:
      target: lvdm.data.frame_dataset.VideoFrameDataset
      params:
        data_root: /dockerdata/sky_timelapse
        resolution: 256
        video_length: 64
        dataset_name: sky
        subset_split: train
        spatial_transform: center_crop_resize
        clip_step: 1
        temporal_transform: rand_clips
    validation:
      target: lvdm.data.frame_dataset.VideoFrameDataset
      params:
        data_root: /dockerdata/sky_timelapse
        resolution: 256
        video_length: 64
        dataset_name: sky
        subset_split: test
        spatial_transform: center_crop_resize
        clip_step: 1
        temporal_transform: rand_clips

lightning:
  callbacks:
    image_logger:
      target: lvdm.utils.callbacks.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 8
        increase_log_steps: False
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: "{epoch:06}-{step:09}"
        save_weights_only: False
        every_n_epochs: 100
        every_n_train_steps: null
  trainer:
    benchmark: True
    batch_size: 2
    num_workers: 0
    num_nodes: 1
    max_epochs: 2000
  modelcheckpoint:
    target: pytorch_lightning.callbacks.ModelCheckpoint
    params:
      every_n_epochs: 1
      filename: "{epoch:04}-{step:06}"
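Compared with sky_interp.yaml, this config flips the task mix toward prediction (`p_pred: 0.5`, `n_prevs: [1,]`) and works at a coarse temporal rate (`latent_frame_stride: 4`). The pair `noisy_cond: True` / `max_noise_level: 250` matches the UNet's `num_classes: 251` (one class per noise level, 0..250): conditioning frames are themselves diffused to a sampled level that the UNet embeds like a class label. A minimal sketch of that forward-noising step in standard DDPM notation; names are assumptions, not the repo's API:

import torch

def noise_condition(cond_z, alphas_cumprod, max_noise_level=250):
    # cond_z: conditioning latents of shape (b, c, t, h, w)
    b = cond_z.shape[0]
    # one noise level per sample, in [0, max_noise_level]
    s = torch.randint(0, max_noise_level + 1, (b,), device=cond_z.device)
    a = alphas_cumprod[s].view(b, 1, 1, 1, 1)
    # standard DDPM forward process q(z_s | z_0)
    noisy = a.sqrt() * cond_z + (1.0 - a).sqrt() * torch.randn_like(cond_z)
    return noisy, s  # s is embedded by the UNet (num_classes = 251)

At sampling time the same knob surfaces as `--sample_cond_noise_level 100` in the script below.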
9 changes: 5 additions & 4 deletions shellscripts/sample_lvdm_long.sh
@@ -1,6 +1,6 @@
 
-CKPT_PRED="$TBD"
-CKPT_INTERP="$TBD"
+CKPT_PRED="models/lvdm_long/sky_pred.ckpt"
+CKPT_INTERP="models/lvdm_long/sky_interp.ckpt"
 AEPATH="models/ae/ae_sky.ckpt"
 CONFIG_PRED="configs/lvdm_long/sky_pred.yaml"
 CONFIG_INTERP="configs/lvdm_long/sky_interp.yaml"
@@ -19,6 +19,7 @@ python scripts/sample_uncond_long_videos.py \
 model.params.first_stage_config.params.ckpt_path=$AEPATH \
 --sample_cond_noise_level 100 \
 --uncond_scale 0.1 \
---n_pred_steps 2
+--n_pred_steps 2 \
+--sample_type ddim --ddim_steps 50
 
-# if use DDIM: add: `--sample_type ddim --ddim_steps 50`
+# if use DDPM: remove: `--sample_type ddim --ddim_steps 50`

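Taken together, the script drives a two-stage pipeline: the prediction checkpoint autoregressively extends a sparse latent video (every 4th latent frame, cf. `latent_frame_stride: 4`, with `--n_pred_steps 2` extensions), and the interpolation checkpoint densifies each gap. A schematic sketch of that loop with placeholder sampler callables, not the repo's actual interfaces:

import torch

def sample_long_video(pred_sample, interp_sample, n_pred_steps=2):
    # stage 1: sparse, autoregressive prediction over strided latent frames
    sparse = pred_sample(prevs=None)              # first clip: unconditional
    for _ in range(n_pred_steps):                 # cf. --n_pred_steps 2
        last = sparse[:, :, -1:]                  # n_prevs: [1,] -> one context frame
        sparse = torch.cat([sparse, pred_sample(prevs=last)], dim=2)
    # stage 2: fill the stride-1 skipped frames inside each sparse gap;
    # interp_sample is assumed to return the left endpoint plus in-betweens
    chunks = [interp_sample(endpoints=sparse[:, :, i : i + 2])
              for i in range(sparse.shape[2] - 1)]
    chunks.append(sparse[:, :, -1:])              # close with the final frame
    return torch.cat(chunks, dim=2)               # dense latent video to decode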