commit 91da562 (0 parents)
Showing 14 changed files with 1,255 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.idea | ||
.ipynb_checkpoints | ||
CLIP | ||
weights | ||
videos | ||
ChangeIt |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel | ||
# using pytorch:1.12.0-cuda11.3-cudnn8-devel results in training being 2x slower for some weird reason | ||
|
||
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub | ||
|
||
RUN apt-get update \ | ||
&& apt-get install ffmpeg wget git -y | ||
|
||
RUN pip install \ | ||
opencv-python \ | ||
pillow \ | ||
matplotlib \ | ||
scikit-learn \ | ||
scipy \ | ||
tqdm \ | ||
pandas \ | ||
ffmpeg-python \ | ||
ftfy \ | ||
regex \ | ||
imgaug | ||
|
||
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH | ||
|
||
COPY cuda_ops /tmp | ||
|
||
RUN cd /tmp \ | ||
&& TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6" python setup.py install \ | ||
&& rm -rf * |
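The Dockerfile above compiles the custom CUDA ops with `TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6"`. A minimal sketch of how one might check whether a local GPU is in that list is shown here. The helper names `arch_list_entry` and `is_covered` are hypothetical and not part of the repository, and the exact-match test is deliberately conservative: it does not account for any forward compatibility between architectures.

```python
def arch_list_entry(major: int, minor: int) -> str:
    """Format a CUDA compute capability as a TORCH_CUDA_ARCH_LIST entry."""
    return f"{major}.{minor}"


def is_covered(entry: str, arch_list: str) -> bool:
    """Return True if `entry` appears in a semicolon-separated arch list."""
    return entry in [a.strip() for a in arch_list.split(";") if a.strip()]


# Architecture list copied from the Dockerfile.
DOCKERFILE_ARCH_LIST = "6.1;7.0;7.5;8.0;8.6"

if __name__ == "__main__":
    try:
        import torch  # optional: only needed to query the local GPU
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        # get_device_capability returns a (major, minor) tuple.
        entry = arch_list_entry(*torch.cuda.get_device_capability(0))
        if is_covered(entry, DOCKERFILE_ARCH_LIST):
            print(f"Architecture {entry} is in the default list.")
        else:
            print(f"Architecture {entry} is missing; consider adding it "
                  "to TORCH_CUDA_ARCH_LIST in the Dockerfile.")
    else:
        print("No CUDA device visible; nothing to check.")
```

If the architecture is missing, the list in the `RUN ... python setup.py install` step is the place to extend.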
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Multi-Task Learning of Object State Changes from Uncurated Videos | ||
|
||
|
||
## Train the model on ChangeIt dataset | ||
1. **Set up the environment** | ||
- Our code can be run in a Docker container. Build the image by running the following command. | ||
Note that by default, we compile custom CUDA code for architectures 6.1, 7.0, 7.5, 8.0, and 8.6. | ||
You may need to update the Dockerfile with your GPU architecture. | ||
``` | ||
docker build -t multi-task-object-states . | ||
``` | ||
- Start a container from the image and open a shell inside it. | ||
``` | ||
docker run -it --rm --gpus all -v $(pwd):$(pwd) -w $(pwd) --user=$(id -u $USER):$(id -g $USER) multi-task-object-states bash | ||
``` | ||
2. **Download requirements** | ||
- Our code requires the CLIP repository, the CLIP model weights, and the ChangeIt dataset annotations. | ||
Run `./download_requirements.sh` to obtain these dependencies, or download them yourself. | ||
3. **Download dataset** | ||
- To replicate our experiments on the ChangeIt dataset, the dataset videos are required. | ||
Please download them and put them inside the `videos/*category*` folder. | ||
See the [ChangeIt GitHub page](https://github.com/soCzech/ChangeIt) for download instructions. | ||
4. **Train a model** | ||
- Run the training. | ||
``` | ||
python train.py --video_roots ./videos \ | ||
--dataset_root ./ChangeIt \ | ||
--train_backbone \ | ||
--augment \ | ||
--local_batch_size 2 | ||
``` | ||
- We trained the model on 32 GPUs, i.e., with a total batch size of 64 (32 GPUs × local batch size 2). | ||
- To run the code on multiple GPUs, simply run it on a machine with multiple GPUs. | ||
- To run the code on multiple nodes, run it once on each node. | ||
If you are not running on Slurm, you also need to set the environment variable `SLURM_NPROCS` | ||
to the total number of nodes and the variable `SLURM_PROCID` to the node ID, starting from zero. | ||
Make sure you also set `SLURM_JOBID` to some unique value. | ||
## Train the model on your dataset | ||
- To train the model on your dataset, complete steps **1.** and **2.** from above. | ||
- Put your videos into `*dir*/*category*` for every video category `*category*`. | ||
- Put your annotations for selected videos into `*dataset*/annotations/*category*`. | ||
Use the same [format](https://github.com/soCzech/ChangeIt/tree/main/annotations) as the ChangeIt dataset. | ||
- Run the training. | ||
``` | ||
python train.py --video_roots *dir* \ | ||
--dataset_root *dataset* \ | ||
--train_backbone \ | ||
--augment \ | ||
--local_batch_size 2 \ | ||
--ignore_video_weight | ||
``` | ||
- The `--ignore_video_weight` option disables the noise-adaptive weighting used for the noisy ChangeIt dataset. | ||
To use the noise-adaptive weighting, you also need to provide the `*dataset*/categories.csv` and `*dataset*/videos/*category*.csv` files. | ||
## Use a trained model | ||
Here is example code for running inference with a trained model. | ||
```python | ||
import torch | ||
# ClipClassifier and extract_frames are assumed to come from this repository's code; import them accordingly. | ||
checkpoint = torch.load("path/to/saved/model.pth", map_location="cpu") | ||
model = ClipClassifier(params=checkpoint["args"], | ||
                       n_classes=checkpoint["n_classes"], | ||
                       hidden_mlp_layers=checkpoint["hidden_mlp_layers"]).cuda() | ||
# Strip the "module." prefix that (Distributed)DataParallel adds to parameter names. | ||
model.load_state_dict({k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}) | ||
# video_fn is not defined in this snippet; it is presumably the path to the input video. | ||
video_frames = torch.from_numpy( | ||
    extract_frames(video_fn, fps=1, size=(398, 224), crop=(398 - 224, 0))) | ||
with torch.no_grad(): | ||
    predictions = model(video_frames.cuda()) | ||
state_pred, action_pred = torch.softmax(predictions["state"], -1), torch.softmax(predictions["action"], -1) | ||
``` | ||
|
||
|
||
## Acknowledgements | ||
The ordering constraint code has been adapted from the CVPR 2022 paper | ||
[Look for the Change: Learning Object States and State-Modifying Actions from Untrimmed Web Videos](https://arxiv.org/abs/2203.11637) | ||
available on [github.com](https://github.com/soCzech/LookForTheChange). |
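The README's multi-node instructions (set `SLURM_NPROCS`, `SLURM_PROCID`, and `SLURM_JOBID` when not running under Slurm) can be sketched as a small launcher helper. This is only an illustration under stated assumptions: `node_env` and `train_command` are hypothetical names, the flags are copied from the README's training command, and how `train.py` actually consumes the variables is not shown in this commit excerpt.

```python
import os


def node_env(num_nodes: int, node_id: int, job_id: str) -> dict:
    """Build an environment with the Slurm-style variables named in the README."""
    if not 0 <= node_id < num_nodes:
        raise ValueError("node_id must be in [0, num_nodes)")
    env = dict(os.environ)
    env.update({
        "SLURM_NPROCS": str(num_nodes),  # total number of nodes
        "SLURM_PROCID": str(node_id),    # node ID, starting from zero
        "SLURM_JOBID": str(job_id),      # some unique value
    })
    return env


def train_command(video_roots: str, dataset_root: str, local_batch_size: int = 2) -> list:
    """Assemble the training command shown in the README as an argument list."""
    return [
        "python", "train.py",
        "--video_roots", video_roots,
        "--dataset_root", dataset_root,
        "--train_backbone",
        "--augment",
        "--local_batch_size", str(local_batch_size),
    ]


# On node 1 of 4, one could then launch training with, for example:
#   import subprocess
#   subprocess.run(train_command("./videos", "./ChangeIt"),
#                  env=node_env(4, 1, "run-001"), check=True)
```

Passing the variables through `env` rather than exporting them globally keeps the settings local to the launched process.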
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch | ||
import _lookforthechange_ops | ||
|
||
|
||
def optimal_state_change(state_tensor, action_tensor, lens, delta, kappa, max_action_state_distance=500): | ||
    return _lookforthechange_ops.optimal_state_change( | ||
        state_tensor.contiguous(), action_tensor.contiguous(), lens, delta, kappa, max_action_state_distance) | ||
|
||
|
||
def optimal_state_change_indices(state_tensor, action_tensor, lens, max_action_state_distance=500): | ||
    return _lookforthechange_ops.optimal_state_change_indices( | ||
        state_tensor.contiguous(), action_tensor.contiguous(), lens, max_action_state_distance) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from setuptools import setup | ||
from torch.utils import cpp_extension | ||
|
||
library_dirs = cpp_extension.library_paths(cuda=True) | ||
include_dirs = cpp_extension.include_paths(cuda=True) | ||
|
||
print("library_dirs:", library_dirs) | ||
print("include_dirs:", include_dirs) | ||
|
||
setup( | ||
name="lookforthechange", | ||
version="2.0", | ||
install_requires=[ | ||
"numpy", | ||
"torch" | ||
], | ||
ext_modules=[ | ||
cpp_extension.CUDAExtension( | ||
name='_lookforthechange_ops', | ||
sources=[ | ||
'src/common.cpp', | ||
'src/optimal_state_change.cu', | ||
'src/optimal_state_change_indices.cu', | ||
], | ||
library_dirs=library_dirs, | ||
include_dirs=include_dirs | ||
) | ||
], | ||
packages=['lookforthechange'], | ||
package_dir={'lookforthechange': './python'}, | ||
cmdclass={'build_ext': cpp_extension.BuildExtension} | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#include "common.hpp" | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("optimal_state_change", &optimal_state_change, "Optimal State1-Action-State2 Sequence (GPU)"); | ||
m.def("optimal_state_change_indices", &optimal_state_change_indices, "Optimal State1-Action-State2 Sequence (GPU)"); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#include <torch/extension.h> | ||
|
||
// C++ interface | ||
#define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") | ||
#define CHECK_GPU(x) TORCH_CHECK(x.is_cuda(), #x " must be a GPU tensor") | ||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") | ||
#define CHECK_INPUT(x) do { CHECK_CPU(x); CHECK_CONTIGUOUS(x); } while (0) | ||
#define CHECK_CUDA_INPUT(x) do { CHECK_GPU(x); CHECK_CONTIGUOUS(x); } while (0) | ||
|
||
std::vector<torch::Tensor> optimal_state_change( | ||
torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int delta, int kappa, int max_action_state_distance); | ||
|
||
torch::Tensor optimal_state_change_indices( | ||
torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int max_action_state_distance); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
#include "common.hpp" | ||
|
||
__global__ void SingleStateChangeKernel( | ||
const float* state_tensor, | ||
const float* action_tensor, | ||
const int* lens, | ||
int* state_targets, | ||
int* action_targets, | ||
const int delta, | ||
const int kappa, | ||
const int max_action_state_distance | ||
) { | ||
const int batch_idx = blockIdx.x; | ||
const int video_len = blockDim.x; | ||
const int state1_pos = threadIdx.x; | ||
const int actual_len = lens[batch_idx]; | ||
|
||
// get pointer to shared memory | ||
extern __shared__ char shared_mem[]; | ||
int* state1_to_action_pos = reinterpret_cast<int*>(shared_mem); | ||
int* state1_to_state2_pos = state1_to_action_pos + video_len; | ||
float* state1_to_score = reinterpret_cast<float*>(state1_to_state2_pos + video_len); | ||
float* action_tensor_shared = state1_to_score + video_len; | ||
float* state_tensor_shared = action_tensor_shared + video_len; | ||
|
||
// load action and state tensors into shared memory | ||
action_tensor_shared[state1_pos] = action_tensor[batch_idx * video_len + state1_pos]; | ||
state_tensor_shared[2 * state1_pos + 0] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 0]; | ||
state_tensor_shared[2 * state1_pos + 1] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 1]; | ||
|
||
__syncthreads(); | ||
|
||
float best_score = -std::numeric_limits<float>::infinity(); | ||
int best_action_pos = 0, best_state2_pos = 0; // position of states/action for videos shorter than 3 | ||
|
||
for (int action_pos = state1_pos + 1; action_pos <= state1_pos + max_action_state_distance && action_pos < actual_len - 1; ++action_pos) { // -1: need at least one position for state2 | ||
float action_score = action_tensor_shared[action_pos]; | ||
|
||
for (int state2_pos = action_pos + 1; state2_pos <= action_pos + max_action_state_distance && state2_pos < actual_len; ++state2_pos) { | ||
float state2_score = state_tensor_shared[2 * state2_pos + 1]; // 2 states, +1 for second state | ||
|
||
float score = action_score * state2_score; | ||
if (score > best_score) { | ||
best_score = score; | ||
best_action_pos = action_pos; | ||
best_state2_pos = state2_pos; | ||
} | ||
} | ||
} | ||
|
||
state1_to_action_pos[state1_pos] = best_action_pos; | ||
state1_to_state2_pos[state1_pos] = best_state2_pos; | ||
state1_to_score[state1_pos] = best_score * state_tensor_shared[2 * state1_pos + 0]; | ||
|
||
__syncthreads(); | ||
|
||
if (state1_pos == 0) { // compute reduction only on the first thread | ||
best_score = state1_to_score[0]; | ||
int best_state1_pos = 0; | ||
for (int i = 1; i < actual_len - 2; ++i) { // -2: need at least one position for action and one for state2 | ||
if (best_score < state1_to_score[i]) { | ||
best_state1_pos = i; | ||
best_score = state1_to_score[i]; | ||
} | ||
} | ||
best_action_pos = state1_to_action_pos[best_state1_pos]; | ||
best_state2_pos = state1_to_state2_pos[best_state1_pos]; | ||
|
||
// FILL state_targets TENSOR | ||
// 0 .. default - no label | ||
// 1 .. initial state label | ||
// 2 .. end state label | ||
for (int i = best_state1_pos - delta; i <= best_state1_pos + delta; ++i) { | ||
if (i < 0 || i >= actual_len) continue; | ||
state_targets[batch_idx * video_len + i] = 1; | ||
} | ||
for (int i = best_state2_pos - delta; i <= best_state2_pos + delta; ++i) { | ||
if (i < 0 || i >= actual_len) continue; | ||
state_targets[batch_idx * video_len + i] = 2; | ||
} | ||
|
||
// FILL action_targets TENSOR | ||
// 0 .. default - no label | ||
// 1 .. no-action label | ||
// 2 .. action label | ||
for (int i = 0; i <= delta; ++i) { | ||
int j = best_action_pos - i - kappa; | ||
if (j < 0) { | ||
action_targets[batch_idx * video_len + 0] = 1; | ||
} else { | ||
action_targets[batch_idx * video_len + j] = 1; | ||
} | ||
|
||
int k = best_action_pos + i + kappa; | ||
if (k >= actual_len) { | ||
action_targets[batch_idx * video_len + actual_len - 1] = 1; | ||
} else { | ||
action_targets[batch_idx * video_len + k] = 1; | ||
} | ||
} | ||
for (int i = best_action_pos - delta; i <= best_action_pos + delta; ++i) { | ||
if (i < 0 || i >= actual_len) continue; | ||
action_targets[batch_idx * video_len + i] = 2; | ||
} | ||
} | ||
} | ||
|
||
std::vector<torch::Tensor> optimal_state_change( | ||
torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int delta, int kappa, int max_action_state_distance) { | ||
|
||
CHECK_CUDA_INPUT(state_tensor); | ||
CHECK_CUDA_INPUT(action_tensor); | ||
CHECK_CUDA_INPUT(lens); | ||
|
||
int batch_size = state_tensor.size(0); | ||
int video_len = state_tensor.size(1); | ||
|
||
TORCH_CHECK(state_tensor.size(2) == 2, "state_tensor must be of shape [batch, video_len, 2]"); | ||
TORCH_CHECK(action_tensor.size(2) == 1, "action_tensor must be of shape [batch, video_len, 1]"); | ||
|
||
auto options = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA); | ||
auto state_targets = torch::zeros({batch_size, video_len}, options); | ||
auto action_targets = torch::zeros({batch_size, video_len}, options); | ||
|
||
const int threads = video_len; | ||
const int blocks = batch_size; | ||
// store in shared memory: | ||
// best action position for each state1 position (1x int) | ||
// best state2 position for each state1 position (1x int) | ||
// best score for each state1 position (1x float) | ||
// action tensor (1x float) | ||
// state tensor (2x float) | ||
const int shared_mem = video_len * (2 * sizeof(int) + 4 * sizeof(float)); | ||
SingleStateChangeKernel<<<blocks, threads, shared_mem>>>( | ||
state_tensor.data_ptr<float>(), | ||
action_tensor.data_ptr<float>(), | ||
lens.data_ptr<int>(), | ||
state_targets.data_ptr<int>(), | ||
action_targets.data_ptr<int>(), | ||
delta, | ||
kappa, | ||
max_action_state_distance); | ||
|
||
return std::vector<torch::Tensor>{state_targets, action_targets}; | ||
} |
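To make the search in `SingleStateChangeKernel` easier to follow, here is a pure-Python reference sketch of the same logic for a single video. The function name `optimal_state_change_reference` and the list-based inputs are assumptions made for illustration; this is not code from the repository, and it is meant for reading and for checking small inputs rather than as a replacement for the CUDA op.

```python
import math


def optimal_state_change_reference(state, action, actual_len, delta, kappa, max_dist=500):
    """Brute-force search for the best (state1, action, state2) frame triple.

    state:  list of (p_initial_state, p_end_state) pairs, one per frame
    action: list of action scores, one per frame
    Assumes actual_len >= 3 and actual_len <= len(action).
    """
    video_len = len(action)
    best_score, best_s1, best_a, best_s2 = -math.inf, 0, 0, 0

    # The kernel runs this outer loop in parallel, one thread per state1 position.
    for s1 in range(max(1, actual_len - 2)):
        score_s1, a_s1, s2_s1 = -math.inf, 0, 0
        for a in range(s1 + 1, min(s1 + max_dist, actual_len - 2) + 1):
            for s2 in range(a + 1, min(a + max_dist, actual_len - 1) + 1):
                score = action[a] * state[s2][1]
                if score > score_s1:  # strict: the earliest maximum wins
                    score_s1, a_s1, s2_s1 = score, a, s2
        score_s1 *= state[s1][0]
        if s1 == 0 or score_s1 > best_score:
            best_score, best_s1, best_a, best_s2 = score_s1, s1, a_s1, s2_s1

    # State targets: 0 = no label, 1 = initial state, 2 = end state.
    state_targets = [0] * video_len
    for center, label in ((best_s1, 1), (best_s2, 2)):
        for i in range(center - delta, center + delta + 1):
            if 0 <= i < actual_len:
                state_targets[i] = label

    # Action targets: 0 = no label, 1 = no-action, 2 = action.
    action_targets = [0] * video_len
    for i in range(delta + 1):
        action_targets[max(best_a - i - kappa, 0)] = 1
        action_targets[min(best_a + i + kappa, actual_len - 1)] = 1
    for i in range(best_a - delta, best_a + delta + 1):
        if 0 <= i < actual_len:
            action_targets[i] = 2

    return (best_s1, best_a, best_s2), state_targets, action_targets


# Small example: the action peaks at frame 2, the end state at frames 4-5.
example_state = [(0.9, 0.1), (0.8, 0.2), (0.5, 0.5), (0.2, 0.8), (0.1, 0.9), (0.1, 0.9)]
example_action = [0.1, 0.1, 0.9, 0.1, 0.1, 0.1]
triple, s_targets, a_targets = optimal_state_change_reference(
    example_state, example_action, actual_len=6, delta=0, kappa=2)
```

Because the kernel launches one thread per frame (`threads = video_len`), `video_len` is bounded by CUDA's limit of 1024 threads per block.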