-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Adding baseline in Elixir to run script to benchmark models.
- Loading branch information
1 parent
76a458b
commit 7dc22fd
Showing
4 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
/models/*/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
defmodule Comparison.Models do
  @moduledoc """
  Manages downloading and loading the models that are being benchmarked.

  It is inspired by the `App.Models` module in the Phoenix app.
  """
  require Logger

  @doc """
  Makes sure the files of `model` are present in its cache directory.

  `model` must expose `name` and `cache_path`, plus the `load_featurizer`,
  `load_tokenizer` and `load_generation_config` flags.

  When `force_download?` is truthy, everything under `model.cache_path` is
  deleted and the model is downloaded again. Otherwise the model is only
  downloaded when the `huggingface` folder inside `model.cache_path` is
  missing or empty.

  Always returns `:ok`. A failed download is logged as an error and is not
  raised.
  """
  def verify_and_download_model(model, force_download? \\ false) do
    cond do
      force_download? ->
        # Drop whatever was cached before so the download starts from scratch.
        File.rm_rf!(model.cache_path)
        download_model(model)

      cached?(model) ->
        :ok

      true ->
        download_model(model)
    end
  end

  @doc """
  Builds the image-to-text serving for `model` from the locally cached files.

  The result is meant to be run through `Nx.Serving`, for example
  `Nx.Serving.run(serving, image)`.

  This assumes the model files already exist locally (see
  `verify_and_download_model/2`): everything is loaded in offline mode and a
  failed load raises a `MatchError`. The featurizer, tokenizer and generation
  config are all required here, so a `KeyError` is raised when `model` did not
  ask for one of them to be loaded.
  """
  def serving(model) do
    params = load_offline_model_params(model)

    Bumblebee.Vision.image_to_text(
      params.model_info,
      params.featurizer,
      params.tokenizer,
      params.generation_config,
      compile: [batch_size: 10],
      defn_options: [compiler: EXLA],
      preallocate_params: true
    )
  end

  # Tells whether the `huggingface` folder inside the model's cache directory
  # exists and holds at least one entry. Anything that cannot be listed
  # (missing path, not a directory, ...) counts as "not cached".
  defp cached?(model) do
    case File.ls(Path.join(model.cache_path, "huggingface")) do
      {:ok, [_ | _]} -> true
      _ -> false
    end
  end

  # Lists the optional artifacts the model asked for, as `{key, loader}` pairs.
  # `key` is the name the artifact gets in the params map and `loader` is the
  # Bumblebee function that fetches it for a given repository.
  defp requested_artifacts(model) do
    candidates = [
      {:featurizer, model.load_featurizer, &Bumblebee.load_featurizer/1},
      {:tokenizer, model.load_tokenizer, &Bumblebee.load_tokenizer/1},
      {:generation_config, model.load_generation_config, &Bumblebee.load_generation_config/1}
    ]

    for {key, requested?, loader} <- candidates, requested?, do: {key, loader}
  end

  # Loads the model from the cache folder, without touching the network.
  # Returns a map with `:model_info` plus one entry per requested artifact
  # (`:featurizer`, `:tokenizer`, `:generation_config`).
  # Every load is matched assertively, so a missing cache entry crashes here.
  defp load_offline_model_params(model) do
    Logger.info("Loading #{model.name}...")

    repository = {:hf, model.name, cache_dir: model.cache_path, offline: true}
    {:ok, model_info} = Bumblebee.load_model(repository)

    model
    |> requested_artifacts()
    |> Map.new(fn {key, loader} ->
      {:ok, artifact} = loader.(repository)
      {key, artifact}
    end)
    |> Map.put(:model_info, model_info)
  end

  # Downloads the model and every requested artifact into the model's cache
  # directory. Loading through Bumblebee is what populates the cache; the
  # loaded values themselves are discarded.
  # Each artifact is attempted even if a previous one failed. Failures are
  # logged instead of being silently dropped. Returns `:ok`.
  defp download_model(model) do
    Logger.info("Downloading #{model.name}...")

    repository = {:hf, model.name, cache_dir: model.cache_path}
    downloads = [{:model, &Bumblebee.load_model/1} | requested_artifacts(model)]

    Enum.each(downloads, fn {artifact, loader} ->
      case loader.(repository) do
        {:ok, _loaded} ->
          :ok

        {:error, reason} ->
          Logger.error(
            "Failed to download the #{artifact} of #{model.name}: #{inspect(reason)}"
          )
      end
    end)
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
This folder will hold the models that are being benchmarked. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Install the dependencies this script needs (Bumblebee and EXLA). | ||
# The `config:` option makes `EXLA.Backend` the default Nx backend. | ||
Mix.install( | ||
[ | ||
{:bumblebee, "~> 0.4.2"}, | ||
{:exla, ">= 0.0.0"} | ||
], | ||
config: [nx: [default_backend: EXLA.Backend]] | ||
) | ||
|
||
# Define the model information struct used for each model being benchmarked. | ||
defmodule ModelInfo do
  @moduledoc """
  Information regarding the model being loaded.

  It holds the name of the model repository and the directory it will be saved into.
  It also has booleans to load each model parameter at will - this is because some
  models (like BLIP) require a featurizer, a tokenizer and a generation configuration.
  """

  @typedoc """
  * `:name` - name of the model repository.
  * `:cache_path` - directory the model files are saved into.
  * `:load_featurizer`, `:load_tokenizer`, `:load_generation_config` - whether the
    respective parameter should be loaded alongside the model.

  Every field defaults to `nil` when it is not given.
  """
  @type t :: %__MODULE__{
          name: String.t() | nil,
          cache_path: Path.t() | nil,
          load_featurizer: boolean() | nil,
          load_tokenizer: boolean() | nil,
          load_generation_config: boolean() | nil
        }

  defstruct [:name, :cache_path, :load_featurizer, :load_tokenizer, :load_generation_config]
end
|
||
|
||
# Benchmark module that when executed, will create a file with the results of the benchmark. | ||
defmodule Benchmark do
  @moduledoc """
  Entry point of the benchmark script.

  For now it only makes sure the benchmarked model is available locally and
  builds its serving; running the images through it is still to be done.
  """

  # Pull in the model manager module (`Comparison.Models`).
  Code.require_file("manage_models.exs")

  # Models to be benchmarked -------
  # Module attributes are evaluated when the module is compiled, so this path
  # is fixed to the directory the script was started from.
  @models_folder_path Path.join(File.cwd!(), "models")

  @model %ModelInfo{
    name: "Salesforce/blip-image-captioning-base",
    cache_path: Path.join(@models_folder_path, "blip-image-captioning-base"),
    load_featurizer: true,
    load_tokenizer: true,
    load_generation_config: true
  }

  @doc """
  Extracts the generated text out of a serving result.

  `result` must have the shape `%{results: [%{text: label}]}`, with exactly one
  entry in `:results`; anything else raises a `MatchError`.
  """
  def extract_label(result) do
    %{results: [%{text: label}]} = result
    label
  end

  @doc """
  Runs the benchmark.

  Verifies that the model exists locally (downloading it when needed) and
  returns the serving built for it.
  """
  def main() do
    # We first verify if the model exists and we download accordingly.
    Comparison.Models.verify_and_download_model(@model)

    # TODO: retrieve 50 images from the COCO dataset, run each one through the
    # serving, measure the time each prediction takes and save the prediction
    # and its execution time to a file.
    Comparison.Models.serving(@model)
  end
end
|
||
# Script entry point: the benchmark starts as soon as this file is evaluated. | ||
Benchmark.main() |