Explore & Label Space (#10)
* probe class

* change attack into adversarial

* lib changes

* added space submodule

* new demo

* explore interface

* label interface

* new interfaces

* clf training files

* new demo and download option

* bug comment

* train clf

* make act ds

* lower disk usage

* pre-commit update

* tensor dataset

* typo

* non tensor output handle

* config names

* infos

* all layers shortcut

* typo

* isort bug

* change isort version

* dropping test coverage constraint
Xmaster6y authored Apr 12, 2024
1 parent 7e68e03 commit 59a78f8
Showing 36 changed files with 2,705 additions and 1,121 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -130,4 +130,5 @@ dmypy.json

# Various files
ignored
+wandb
*secret*
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "explore-label-concepts"]
path = explore-label-concepts
url = [email protected]:spaces/Xmaster6y/explore-label-concepts
9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
-rev: 23.11.0
+rev: 24.3.0
hooks:
- id: black
args: ["--config", "pyproject.toml"]
@@ -16,16 +16,17 @@ repos:
- id: trailing-whitespace
- id: check-docstring-first
- repo: https://github.com/python-poetry/poetry
-rev: 1.7.0
+rev: 1.8.0
hooks:
- id: poetry-check
- repo: https://github.com/pre-commit/mirrors-mypy
-rev: v1.7.1
+rev: v1.9.0
hooks:
- id: mypy
additional_dependencies: ['types-requests', 'types-toml']
exclude: ^scripts/
- repo: https://github.com/pycqa/flake8
-rev: 6.1.0
+rev: 7.0.0
hooks:
- id: flake8
args: ['--ignore=E203,W503', '--per-file-ignores=__init__.py:F401']
16 changes: 16 additions & 0 deletions .vscode/launch.json
@@ -25,6 +25,22 @@
"module": "scripts.contrast_reading",
"console": "integratedTerminal",
"justMyCode": false
},
+{
+"name": "Script Train CLF",
+"type": "debugpy",
+"request": "launch",
+"module": "scripts.train_clf",
+"console": "integratedTerminal",
+"justMyCode": false
+},
+{
+"name": "Script Make Act Dataset",
+"type": "debugpy",
+"request": "launch",
+"module": "scripts.make_activation_dataset",
+"console": "integratedTerminal",
+"justMyCode": false
+}
]
}
8 changes: 4 additions & 4 deletions Makefile
@@ -9,9 +9,9 @@ test-assets:

.PHONY: tests
tests:
-poetry run pytest tests --cov=src --cov-report=term-missing --cov-fail-under=50 -s -v
+poetry run pytest tests --cov=src --cov-report=term-missing --cov-fail-under=1 -s -v

# API
-.PHONY: app-start
-app-start:
-poetry run python -m demo.main
+.PHONY: demo-explore-label-concepts
+demo-explore-label-concepts:
+poetry run python explore-label-concepts/app.py
18 changes: 0 additions & 18 deletions demo/main.py

This file was deleted.

1 change: 1 addition & 0 deletions explore-label-concepts
Submodule explore-label-concepts added at 9b4416
2,908 changes: 2,042 additions & 866 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion pyproject.toml
@@ -4,6 +4,7 @@ line-length = 79
[tool.isort]
profile = "black"
line_length = 79
+src_paths = ["src", "tests", "scripts", "docs", "demo"]

[tool.poetry]
name = "mulsi"
@@ -32,6 +33,8 @@ tensordict = "^0.3.0"
typeguard = "^4.1.5"
einops = "^0.7.0"
torchvision = "^0.17.0"
+datasets = {version = "^2.18.0", extras = ["scripts"]}
+wandb = {version = "^0.16.5", extras = ["scripts"]}

[tool.poetry.group.dev]
optional = true
@@ -46,4 +49,12 @@ gdown = "^5.1.0"
optional = true

[tool.poetry.group.demo.dependencies]
-gradio = "^4.14.0"
+gradio = {extras = ["oauth"], version = "^4.24.0"}
+jsonlines = "^4.0.0"
+
+[tool.poetry.group.scripts]
+optional = true
+
+[tool.poetry.group.scripts.dependencies]
+datasets = "^2.18.0"
+wandb = "^0.16.5"
2 changes: 2 additions & 0 deletions scripts/assets/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
27 changes: 27 additions & 0 deletions scripts/clip_reading.py
@@ -0,0 +1,27 @@
"""Simple FGSM attack.
"""

from PIL import Image
from torchvision.transforms.functional import pil_to_tensor
from transformers import CLIPModel, CLIPProcessor

from mulsi import AdversarialImage, DiffClipProcessor

####################
# HYPERPARAMETERS
####################
image_path = "assets/orange.jpg"
model_name = "openai/clip-vit-base-patch32"
epsilon = 0.1
####################

image = Image.open(image_path)
model = CLIPModel.from_pretrained(model_name)
model.eval()
for param in model.parameters():
param.requires_grad = False
processor = CLIPProcessor.from_pretrained(model_name)
diff_processor = DiffClipProcessor(processor.image_processor)

image_tensor = pil_to_tensor(image).float().unsqueeze(0)
adv_image = AdversarialImage(image_tensor, model, None)
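
Note: `AdversarialImage` and `DiffClipProcessor` are this repository's own `mulsi` classes, and their internals are not part of this diff. As a point of reference only, a generic FGSM step in plain PyTorch might look like the sketch below; it assumes a classifier-style `model` that returns logits and a `target` label tensor, neither of which comes from this commit.

```
import torch
import torch.nn.functional as F


def fgsm_step(image, model, target, epsilon):
    # One FGSM step: move each pixel by epsilon in the sign of the
    # loss gradient, which increases the loss to first order.
    image = image.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(image), target)
    loss.backward()
    return (image + epsilon * image.grad.sign()).detach()
```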
10 changes: 10 additions & 0 deletions scripts/constants.py
@@ -0,0 +1,10 @@
"""Constans for running the scripts.
"""

import os
import pathlib

HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

ASSETS_FOLDER = pathlib.Path(__file__).parent / "assets"
33 changes: 20 additions & 13 deletions scripts/contrast_reading.py
@@ -1,10 +1,15 @@
"""Simple showcase of the contrast reading vectors.
+Run with:
+```
+poetry run python -m scripts.contrast_reading
+```
"""

from transformers import AutoModelForCausalLM, AutoTokenizer

-from mulsi import RepresentationReader, TdTokenizer
-from mulsi.wrapper import LlmWrapper
+from mulsi import ContrastReader, LlmWrapper, TdTokenizer
+from scripts.utils import viz

####################
# HYPERPARAMETERS
@@ -14,9 +19,7 @@
cons_inputs = "I hate this codebase"
####################

-love_reader = RepresentationReader.from_name(
-    "contrast", pros_inputs=pros_inputs, cons_inputs=cons_inputs
-)
+love_reader = ContrastReader(pros_inputs=pros_inputs, cons_inputs=cons_inputs)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
td_tokenizer = TdTokenizer(tokenizer)
@@ -27,16 +30,20 @@
sentences = [
"I love this codebase",
"I hate this codebase",
"I like this codebase",
"Doing XAI research is what I love",
"My girlfriend loves me",
"I am neutral about this codebase",
"I like eating ice cream",
"When is the next train?",
]
+head_line = ("Sentence", "Cosine Similarity")
+table = []
for sentence in sentences:
-    print(
-        sentence,
-        love_reader.read(
-            wrapper=wrapper,
-            inputs=sentence,
-            reading_vector=love_reading_vector,
-        ),
-    )
+    cosim = love_reader.read(
+        wrapper=wrapper,
+        inputs=sentence,
+        reading_vector=love_reading_vector,
+    ).item()
+    table.append((sentence, f"{cosim:.2f}"))
+viz.table_print(headings=head_line, table=table)
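
For context, a contrast reading vector is the model's representation of the "pros" input minus that of the "cons" input, and `read` scores a sentence by cosine similarity against it. The sketch below is an illustrative guess at those mechanics in plain `transformers`/PyTorch, not the actual `mulsi.ContrastReader` implementation; the mean-pooling and the layer choice are assumptions.

```
import torch
import torch.nn.functional as F


def mean_hidden(model, tokenizer, text, layer=-1):
    # Mean-pool one layer's hidden states for a single input.
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)
    return out.hidden_states[layer].mean(dim=1).squeeze(0)


def contrast_read(model, tokenizer, pros, cons, sentence):
    # Reading vector: "pros" representation minus "cons" representation.
    reading_vector = mean_hidden(model, tokenizer, pros) - mean_hidden(
        model, tokenizer, cons
    )
    sentence_repr = mean_hidden(model, tokenizer, sentence)
    return F.cosine_similarity(sentence_repr, reading_vector, dim=0).item()
```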
148 changes: 148 additions & 0 deletions scripts/make_activation_dataset.py
@@ -0,0 +1,148 @@
"""Script to make a dataset of activations from a CLIP model.
Run with:
```
poetry run python -m scripts.make_activation_dataset
```
"""

import argparse
import re

import torch
import wandb
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor

from mulsi.hook import CacheHook, HookConfig
from scripts.constants import ASSETS_FOLDER, HF_TOKEN, WANDB_API_KEY
from scripts.utils.dataset import make_generators

####################
# HYPERPARAMETERS
####################
parser = argparse.ArgumentParser("make-activation-dataset")
parser.add_argument(
"--model_name", type=str, default="openai/clip-vit-base-patch32"
)
parser.add_argument(
"--dataset_name", type=str, default="Xmaster6y/fruit-vegetable-concepts"
)
parser.add_argument("--download_dataset", action="store_true", default=False)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--layers", type=str, default="0,6,11")
####################

ARGS = parser.parse_args()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Running on {DEVICE}")
hf_api = HfApi(token=HF_TOKEN)
wandb.login(key=WANDB_API_KEY) # type: ignore

processor = CLIPProcessor.from_pretrained(ARGS.model_name)
model = CLIPModel.from_pretrained(ARGS.model_name)
model.eval()
model.to(DEVICE)

if ARGS.layers == "*":
layers = [
str(i) for i in range(model.vision_model.config.num_hidden_layers)
]
else:
layers = ARGS.layers.split(",")

if ARGS.download_dataset:
hf_api.snapshot_download(
repo_id=ARGS.dataset_name,
repo_type="dataset",
local_dir=f"{ASSETS_FOLDER}/{ARGS.dataset_name}",
)

dataset = load_dataset(f"{ASSETS_FOLDER}/{ARGS.dataset_name}")
print(f"[INFO] Loaded dataset: {dataset}")


def collate_fn(batch):
images = []
infos = []
for x in batch:
images.append(x.pop("image"))
x.pop("original_name")
infos.append(x)
return images, infos


splits = ["train", "validation", "test"]
dataloaders = {
split: DataLoader(
dataset[split],
batch_size=ARGS.batch_size,
collate_fn=collate_fn,
)
for split in splits
}

cache_hook = CacheHook(
HookConfig(module_exp=rf".*\.layers\.({'|'.join(layers)})$")
)
handles = cache_hook.register(model.vision_model)
print(f"[INFO] Registered {len(handles)} hooks")


def make_batch_gen(
batched_activations,
infos,
):
def gen():
for activation, info in zip(batched_activations, infos):
yield {"activation": activation.cpu().float().numpy(), **info}

return gen


@torch.no_grad
def make_gen_list(
gen_dict,
dataloaders,
):
module_exp = re.compile(r".*\.layers\.(?P<layer>\d+)$")
for split, dataloader in dataloaders.items():
for batch in dataloader:
images, infos = batch
image_inputs = processor(
images=images,
return_tensors="pt",
)
image_inputs = {k: v.to(DEVICE) for k, v in image_inputs.items()}
model.vision_model(**image_inputs)
for module, batched_activations in cache_hook.storage.items():
m = module_exp.match(module)
layer = m.group("layer")
gen_dict[layer][split].append(
make_batch_gen(batched_activations[0].detach(), infos)
)


gen_dict = make_generators(
layers=layers,
splits=splits,
make_gen_list=make_gen_list,
dataloaders=dataloaders,
)
dataset_dict = {
f"layers.{layer}": DatasetDict(
{
split: Dataset.from_generator(gen_dict[layer][split])
for split in splits
}
)
for layer in layers
}

for layer_name, dataset in dataset_dict.items():
dataset.push_to_hub(
repo_id=ARGS.dataset_name.replace("concepts", "activations"),
config_name=layer_name,
)
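
`CacheHook` and `HookConfig` live in `mulsi.hook` and are not shown in this diff; the mechanism underneath is PyTorch forward hooks. Below is a minimal sketch of the same caching pattern, with the regex default and storage layout assumed to mirror the script above rather than taken from `mulsi` itself.

```
import re


def make_hook(name, storage):
    def hook(module, inputs, output):
        # CLIP encoder layers return a tuple; keep the hidden-state tensor.
        tensor = output[0] if isinstance(output, tuple) else output
        storage.setdefault(name, []).append(tensor.detach())
    return hook


def register_cache_hooks(model, storage, module_exp=r".*\.layers\.(0|6|11)$"):
    # Attach a caching hook to every module whose qualified name matches.
    pattern = re.compile(module_exp)
    return [
        module.register_forward_hook(make_hook(name, storage))
        for name, module in model.named_modules()
        if pattern.match(name)
    ]
```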