-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathvit.h
124 lines (111 loc) · 3.3 KB
/
vit.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#pragma once
#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"
#include "ggml/examples/stb_image.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <thread>
#include <cinttypes>
#include <algorithm>
// Hyper-parameters describing a ViT classifier.
// The values below are only in-class defaults; presumably vit_model_load
// overwrites them from the model file (TODO confirm against the loader).
struct vit_hparams
{
    // Transformer dimensions.
    int32_t hidden_size{768};
    int32_t num_hidden_layers{12};
    int32_t num_attention_heads{12};

    // Output / input geometry.
    int32_t num_classes{1000};
    int32_t patch_size{8};
    int32_t img_size{224};

    // NOTE(review): weight type code; its exact encoding is not visible here.
    int32_t ftype{1};

    // Small epsilon; presumably used by layer norm (confirm at call sites).
    float eps{1e-6f};

    // Name of the resize interpolation mode.
    std::string interpolation{"bicubic"};

    // Presumably class id -> human-readable label (verify against loader).
    std::map<int, std::string> id2label;

    // Derived sizes; defined out of line (definitions not visible here).
    int32_t n_enc_head_dim() const;
    int32_t n_img_size() const;
    int32_t n_patch_size() const;
    int32_t n_img_embd() const;
};
// Weight tensors of one transformer encoder block.
// Pointers are presumably owned by the model's ggml context (not freed here).
// Every member defaults to nullptr so a block that has not been populated yet
// is never read as an indeterminate pointer.
struct vit_block
{
    // First normalization (weight / bias).
    struct ggml_tensor *norm1_w = nullptr;
    struct ggml_tensor *norm1_b = nullptr;

    // Fused QKV projection and attention output projection.
    struct ggml_tensor *qkv_w = nullptr;
    struct ggml_tensor *qkv_b = nullptr;
    struct ggml_tensor *proj_w = nullptr;
    struct ggml_tensor *proj_b = nullptr;

    // Second normalization (weight / bias).
    struct ggml_tensor *norm2_w = nullptr;
    struct ggml_tensor *norm2_b = nullptr;

    // Two-layer MLP.
    struct ggml_tensor *mlp_lin1_w = nullptr;
    struct ggml_tensor *mlp_lin1_b = nullptr;
    struct ggml_tensor *mlp_lin2_w = nullptr;
    struct ggml_tensor *mlp_lin2_b = nullptr;
};
// Final normalization plus linear classification head.
// Pointers are presumably owned by the model's ggml context; they default to
// nullptr so an unloaded head is never read as an indeterminate pointer.
struct classifier_head
{
    struct ggml_tensor *norm_w = nullptr;
    struct ggml_tensor *norm_b = nullptr;
    struct ggml_tensor *head_w = nullptr;
    struct ggml_tensor *head_b = nullptr;
};
// Image-encoder weights: embedding-related tensors plus the stack of
// transformer blocks. Tensor pointers are presumably owned by the model's
// ggml context; they default to nullptr so they are never indeterminate.
struct vit_image_encoder
{
    struct ggml_tensor *pe = nullptr;        // presumably positional embedding (confirm)
    struct ggml_tensor *cls_token = nullptr; // class token
    struct ggml_tensor *proj_w = nullptr;    // presumably patch-embedding projection (confirm)
    struct ggml_tensor *proj_b = nullptr;
    std::vector<vit_block> layers;           // one entry per encoder block
};
// Mutable per-inference state: output tensor, compute context and the
// scratch buffers used while building/evaluating the encoder graph.
// Raw pointers default to nullptr so an un-initialized state can be detected
// instead of holding indeterminate values.
struct vit_state
{
    struct ggml_tensor *prediction = nullptr;
    struct ggml_context *ctx = nullptr;
    std::vector<uint8_t> work_buffer;
    std::vector<uint8_t> buf_alloc_img_enc;
    std::vector<uint8_t> buf_compute_img_enc;
    struct ggml_allocr *allocr = nullptr;
};
// Loaded ViT model: hyper-parameters, encoder and classifier weights, the
// ggml context holding them, and a name -> tensor lookup table.
// `ctx` defaults to nullptr so code that inspects or frees it after a failed
// or skipped load never sees an indeterminate pointer.
struct vit_model
{
    vit_hparams hparams;
    vit_image_encoder enc_img;
    classifier_head classifier;
    struct ggml_context *ctx = nullptr;
    std::map<std::string, struct ggml_tensor *> tensors;
};
// 8-bit image buffer of nx x ny pixels.
// NOTE(review): channel count and memory layout of `data` are not visible
// here (presumably interleaved RGB, row-major — confirm with the loader).
// Dimensions default to 0 so an empty image is well-defined.
struct image_u8
{
    int nx = 0;
    int ny = 0;
    std::vector<uint8_t> data;
};
// Floating-point image buffer of nx x ny pixels (e.g. a preprocessed image).
// NOTE(review): channel count and memory layout of `data` are not visible
// here — confirm with vit_image_preprocess.
// Dimensions default to 0 so an empty image is well-defined.
struct image_f32
{
    int nx = 0;
    int ny = 0;
    std::vector<float> data;
};
// Runtime / command-line options for the ViT example.
struct vit_params
{
    int32_t seed = -1; // NOTE(review): -1 presumably means "pick a seed at runtime" — confirm
    // std::thread::hardware_concurrency() may return 0 when the value is not
    // computable; clamp to [1, 4] so the default never requests zero threads.
    int32_t n_threads = std::max<int32_t>(
        1, std::min<int32_t>(4, static_cast<int32_t>(std::thread::hardware_concurrency())));
    int32_t topk = 5;
    std::string model = "../ggml-model-f16.gguf"; // model path
    std::string fname_inp = "../assets/tench.jpg"; // image path
    float eps = 1e-6f;                             // epsilon used in LN
};
// Debug helper: prints tensor `t` labelled with `title`; `n` presumably limits
// how many values are printed (confirm in the definition).
void print_t_f32(const char *title, struct ggml_tensor *t, int n);

// NOTE(review): a `static` declaration in a header gives every including TU
// its own internal-linkage declaration, so only the TU that defines it
// (presumably vit.cpp) can call it. Kept `static` so it still matches a
// `static` definition there; consider moving it out of the header.
static void ggml_disconnect_node_from_graph(struct ggml_tensor *t);

// Evaluates `graph` with `n_threads` threads, using `buf` as work memory
// (presumably resized as needed — confirm in the definition).
void ggml_graph_compute_helper(std::vector<uint8_t> &buf, struct ggml_cgraph *graph, int n_threads);

// Loads the image at `fname` into `img`; the bool result presumably signals success.
bool load_image_from_file(const std::string &fname, image_u8 &img);

// Converts `img` into the float input `res` according to `params`;
// the bool result presumably signals success.
bool vit_image_preprocess(const image_u8 &img, image_f32 &res, const vit_hparams &params);

// Loads model weights and hyper-parameters from `fname` into `model`;
// the bool result presumably signals success.
bool vit_model_load(const std::string &fname, vit_model &model);

// Builds the image-encoder compute graph for `img` using `model` and `state`.
struct ggml_cgraph *vit_encode_image(const vit_model &model, vit_state &state, const image_f32 &img);

// Runs classification on `img1` and fills `predictions` with (score, class id)
// pairs (order of pair members taken from the type; ranking presumably top-k —
// confirm). NOTE(review): `img1` is taken by value (copies the image); changing
// it to a reference would also require changing the out-of-view definition.
// The meaning of the int return value is not visible here.
int vit_predict(const vit_model &model, vit_state &state, const image_f32 img1, const vit_params &params, std::vector<std::pair<float, int>> &predictions);

// Prints command-line usage, showing defaults from `params`.
void print_usage(int argc, char **argv, const vit_params &params);

// Parses command-line arguments into `params`; the bool result presumably
// signals success.
bool vit_params_parse(int argc, char **argv, vit_params &params);