-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
Copy pathdisentangledAttentionPlugin.h
155 lines (114 loc) · 5.54 KB
/
disentangledAttentionPlugin.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_DISENTANGLED_ATTENTION_PLUGIN_H
#define TRT_DISENTANGLED_ATTENTION_PLUGIN_H
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include "common/serialize.hpp"
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
// One of the preferred ways of making a custom layer visible to TensorRT
// is to extend the IPluginV2 and IPluginCreator classes.
// For the requirements on overridden functions, check the TensorRT API docs.
namespace nvinfer1
{
namespace plugin
{
// using namespace nvinfer1;
// Version 1: regular relative position index
// Version 2: log bucket relative position index
// Selects which relative-position indexing scheme the kernel uses (see version
// comments above). Must stay a macro: it drives the #if below at preprocess time.
#define kDISENTANGLED_VERSION 2
#if kDISENTANGLED_VERSION == 1
// Tile size and block dim-y tuned for the version-1 (regular index) kernel.
constexpr int32_t kDISENTANGLED_TILESIZE = 32;
constexpr int32_t kDISENTANGLED_BLOCKDIMY = 8;
#elif kDISENTANGLED_VERSION == 2
// Tile size and block dim-y tuned for the version-2 (log-bucket index) kernel.
constexpr int32_t kDISENTANGLED_TILESIZE = 64;
constexpr int32_t kDISENTANGLED_BLOCKDIMY = 4;
#endif
// Host-side wrapper that launches the disentangled-attention CUDA kernel on
// `stream` with the supplied `grid`/`block` launch configuration.
// data0..data2 are device pointers to the three input tensors, with dimensions
// dimData0..dimData2; the output is written to `result` (dims: dimResult).
// `factor` is the attention scaling factor and `span` the relative-position span.
// NOTE(review): the exact roles of data0/data1/data2 (e.g. which holds the
// content-to-position vs position-to-content matrices) are defined in the .cu
// implementation — confirm there before relying on an ordering.
template <typename TDataType, int32_t tTileSize, int32_t tBlockDimY>
void disentangled_kernel_wrapper(TDataType const* data0, TDataType const* data1, TDataType const* data2,
    TDataType* result, dim3 dimData0, dim3 dimData1, dim3 dimData2, dim3 dimResult, TDataType factor, int32_t span,
    dim3 block, dim3 grid, cudaStream_t stream);
// TensorRT plugin implementing DeBERTa-style disentangled attention.
// Implements the IPluginV2DynamicExt interface so it supports dynamic shapes.
class DisentangledAttentionPlugin final : public nvinfer1::IPluginV2DynamicExt
{
public:
    // Default-construct an unconfigured plugin.
    DisentangledAttentionPlugin();

    // Construct from attributes: `span` is the relative-position span,
    // `factor` the attention scaling factor.
    DisentangledAttentionPlugin(int32_t span, float factor);

    // Construct by deserializing a buffer previously produced by serialize().
    DisentangledAttentionPlugin(void const* serialData, size_t serialLength);

    // Number of output tensors the plugin produces.
    int32_t getNbOutputs() const noexcept override;

    // DynamicExt plugins return the DimsExprs class instead of Dims.
    nvinfer1::DimsExprs getOutputDimensions(int32_t index, nvinfer1::DimsExprs const* inputs, int32_t nbInputDims,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override; // determine output dims based on input info

    // Called once before the plugin is used for inference; returns 0 on success.
    int32_t initialize() noexcept override;

    // Releases resources acquired in initialize().
    void terminate() noexcept override;

    // Scratch-space (in bytes) required by enqueue() for the given tensor descriptors.
    size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
        nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;

    // This is where the plugin work is done: launches the kernel on `stream`.
    int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    // Size in bytes of the buffer serialize() needs.
    size_t getSerializationSize() const noexcept override;

    // Writes the plugin attributes into `buffer` (getSerializationSize() bytes).
    void serialize(void* buffer) const noexcept override;

    // Reports whether the type/format at position `pos` is supported, given the others.
    bool supportsFormatCombination(
        int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;

    // Plugin type name used for registry lookup.
    char const* getPluginType() const noexcept override;

    // Plugin version string used for registry lookup.
    char const* getPluginVersion() const noexcept override;

    // Deep copy of this plugin for use in another execution context.
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;

    // Destroys this plugin instance (TensorRT calls this instead of delete).
    void destroy() noexcept override;

    // Data type of output `index` given the input types.
    nvinfer1::DataType getOutputDataType(
        int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;

    void setPluginNamespace(char const* pluginNamespace) noexcept override;

    char const* getPluginNamespace() const noexcept override;

    // Called by the builder/runtime to communicate the actual tensor descriptors.
    void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
        nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;

private:
    // Helper method for enqueue(): type-dispatched kernel launch.
    template <typename TDataType>
    void enqueueType(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, cudaStream_t stream, TDataType factor);

    std::string mNamespace; // namespace set via setPluginNamespace()

    // attributes (serialized/deserialized)
    int32_t mSpan; // relative-position span
    float mFactor; // attention scaling factor

    // Re-expose the base-class overloads hidden by the DynamicExt overrides above.
    using IPluginV2::getOutputDimensions;
    using IPluginV2::getWorkspaceSize;
    using IPluginV2::enqueue;
    using IPluginV2Ext::configurePlugin;
};
// Factory registered with the TensorRT plugin registry; creates and
// deserializes DisentangledAttentionPlugin instances.
class DisentangledAttentionPluginCreator : public nvinfer1::IPluginCreator
{
public:
    DisentangledAttentionPluginCreator();

    ~DisentangledAttentionPluginCreator() override = default;

    // Plugin name the registry matches against.
    char const* getPluginName() const noexcept override;

    // Plugin version the registry matches against.
    char const* getPluginVersion() const noexcept override;

    // The attribute fields (names/types) createPlugin() accepts.
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;

    // Builds a new plugin from the attribute values in `fc`.
    nvinfer1::IPluginV2DynamicExt* createPlugin(
        char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;

    // Rebuilds a plugin from a buffer produced by DisentangledAttentionPlugin::serialize().
    nvinfer1::IPluginV2DynamicExt* deserializePlugin(
        char const* name, void const* serialData, size_t serialLength) noexcept override;

    void setPluginNamespace(char const* pluginNamespace) noexcept override;

    char const* getPluginNamespace() const noexcept override;

private:
    // Shared across instances: the field collection advertised by getFieldNames().
    static nvinfer1::PluginFieldCollection mFC;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
    std::string mNamespace; // namespace set via setPluginNamespace()
};
} // namespace plugin
} // namespace nvinfer1
#endif // TRT_DISENTANGLED_ATTENTION_PLUGIN_H