Note
#Frame = #input_frame x #clip x #crop.
- #input_frame means how many frames are fed to the model during the test phase.
- #clip means temporal clips (e.g., 5 means the video is temporally sampled into five clips with different start indices).
- #crop means spatial crops (e.g., 3 for left/right/center crop).
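For example, the 16x5x3 entries in the tables below expand as follows (just arithmetic on the reported setting, not a new measurement):

```python
# 16x5x3 = 16 input frames per view, 5 temporal clips, 3 spatial crops.
input_frames, clips, crops = 16, 5, 3
print(input_frames * clips * crops)  # 240 frames are processed per video at test time
```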
For Kinetics-400, VideoMAE is trained for around 1600 epochs without any extra data. The following checkpoints are available in both TensorFlow SavedModel and h5 formats (see the loading sketch after the table).
Backbone | #Frame | Pre-train | Fine-tune | Top-1 | Top-5 |
---|---|---|---|---|---|
ViT-S | 16x5x3 | SavedModel/h5 | SavedModel/h5 | 79.0 | 93.8 |
ViT-B | 16x5x3 | SavedModel/h5 | SavedModel/h5 | 81.5 | 95.1 |
ViT-L | 16x5x3 | SavedModel/h5 | SavedModel/h5 | 85.2 | 96.8 |
ViT-H | 16x5x3 | ? | SavedModel/h5 | 86.6 | 97.1 |
? The official ViT-H backbone of VideoMAE has a weight issue in its pretrained model; see MCG-NJU/VideoMAE#89 for details.
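A minimal loading sketch, assuming a Kinetics-400 checkpoint has already been downloaded; the paths below are placeholders rather than the actual release file names:

```python
import tensorflow as tf
from tensorflow import keras

# SavedModel format: the extracted directory can be loaded directly.
model_tf = keras.models.load_model('path/to/videomae_vit_b_k400_savedmodel')

# h5 format: build the architecture first, then load the weights into it.
# model_tf = build_video_mae(...)  # the model builder used later in this document
# model_tf.load_weights('path/to/videomae_vit_b_k400.h5')
```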
For SSv2 (Something-Something V2), VideoMAE is trained for around 2400 epochs without any extra data.
Backbone | #Frame | Pre-train | Fine-tune | Top-1 | Top-5 |
---|---|---|---|---|---|
ViT-S | 16x2x3 | SavedModel/h5 | SavedModel/h5 | 66.8 | 90.3 |
ViT-B | 16x2x3 | SavedModel/h5 | SavedModel/h5 | 70.8 | 92.4 |
For UCF101, VideoMAE is trained for around 3200 epochs without any extra data.
Backbone | #Frame | Pre-train | Fine-tune | Top-1 | Top-5 |
---|---|---|---|---|---|
ViT-B | 16x5x3 | SavedModel/h5 | SavedModel/h5 | 91.3 | 98.5 |
The torch VideoMAE model can be loaded from the official repo, as sketched right below. The quick tests that follow compare both implementations and show that the logits match. Note that only the fine-tuned UCF-101 models are used for the demonstration.
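A minimal sketch of preparing the reference torch model, assuming model_pt has already been built from the official MCG-NJU/VideoMAE code and a fine-tuned UCF-101 checkpoint has been downloaded; the file name and the 'model' key are assumptions, not guarantees:

```python
import torch

# Illustrative checkpoint name; use the actual file from the official release.
ckpt = torch.load('videomae_vit_b_ucf101.pth', map_location='cpu')
state_dict = ckpt.get('model', ckpt)  # released checkpoints often nest weights under 'model'
model_pt.load_state_dict(state_dict)
model_pt.eval()
```

With model_pt and the TF model (model_tf) in hand, the logit-matching test: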
import numpy as np
import torch
import tensorflow as tf
from tensorflow import keras  # used later when reloading the SavedModel

# Torch expects (batch, channels, frames, height, width);
# the TF model expects (batch, frames, height, width, channels).
inputs_pt = torch.tensor(np.random.rand(4, 3, 16, 224, 224).astype('float32'))
inputs_tf = inputs_pt.detach().numpy().transpose(0, 2, 3, 4, 1)

model_pt.eval()
y_pred_pt = model_pt(inputs_pt.float())  # UCF-101 fine-tuned model
y_pred_pt = y_pred_pt.detach().numpy()
y_pred_pt.shape  # (4, 101)

y_pred_tf = model_tf(inputs_tf, training=False)
y_pred_tf = y_pred_tf.numpy()
y_pred_tf.shape  # (4, 101)

np.testing.assert_allclose(
    y_pred_tf,
    y_pred_pt,
    1e-5, 1e-5
)  # OK
Saving and Reloading Weights - check that saving and reloading the h5 weights is safe.
model_tf.save_weights(checkpoint_name + '.h5')
new_model_tf = build_video_mae(...)
new_model_tf.load_weights(checkpoint_name + '.h5')

# Let's check: weight matching
assert len(model_tf.weights) == len(new_model_tf.weights)
for a, b in zip(model_tf.weights, new_model_tf.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())  # OK

# Let's check: inference matching
test_input = tf.random.normal(
    [1, 16, 224, 224, 3], 0, 1, tf.float32
)
tf.nest.map_structure(
    np.testing.assert_allclose,
    model_tf.predict(test_input),
    new_model_tf.predict(test_input),
)  # OK
Saving and Reloading TF SavedModel - check if saving and reloading is safe.
model_tf.save(checkpoint_name)
loaded_model = keras.models.load_model(
    checkpoint_name
)

assert len(model_tf.weights) == len(loaded_model.weights)
for a, b in zip(model_tf.weights, loaded_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())  # OK

# Let's check: inference matching
test_input = tf.random.normal(
    [1, 16, 224, 224, 3], 0, 1, tf.float32
)
tf.nest.map_structure(
    np.testing.assert_allclose,
    model_tf.predict(test_input),
    loaded_model.predict(test_input),
)  # OK
Logit matching between the reloaded TF SavedModel and the torch model.
y_pred_pt = model_pt(inputs_pt.float())
y_pred_tf = loaded_model(inputs_tf, training=False)
print(y_pred_pt.shape, y_pred_tf.shape)
np.testing.assert_allclose(
    y_pred_tf.numpy(),
    y_pred_pt.detach().numpy(),
    1e-5, 1e-5
)  # OK
XLA compatibility - check that the reloaded TF SavedModel runs with jit_compile=True.
call_fn = tf.function(loaded_model, jit_compile=True)
%timeit _ = call_fn(inputs_tf, training=False)  # IPython magic; benchmarks the XLA-compiled call
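Beyond timing, a quick sanity check (a minimal sketch, assuming inputs_tf from the earlier logit-matching test is still in scope) that the XLA-compiled call agrees with the plain eager call:

```python
# Compare the XLA-compiled call against the eager call on the same inputs.
y_eager = loaded_model(inputs_tf, training=False)
y_xla = call_fn(inputs_tf, training=False)
np.testing.assert_allclose(y_xla.numpy(), y_eager.numpy(), 1e-5, 1e-5)
```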