Skip to content

Commit

Permalink
Add ccv_cnnp_rmsnorm.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 20, 2023
1 parent eb059a5 commit 2fa1c4d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
83 changes: 83 additions & 0 deletions lib/nnc/ccv_cnnp_model_addons.c
Original file line number Diff line number Diff line change
Expand Up @@ -2037,6 +2037,89 @@ static ccv_cnnp_model_t* _ccv_cnnp_group_norm_copy(const ccv_cnnp_model_t* const
return ccv_cnnp_group_norm(self->params.gnorm.group_axis, self->params.gnorm.groups, self->params.gnorm.epsilon, self->params.gnorm.reduce_axis, self->params.gnorm.reduce_count, self->super.is_trainable, self->super.name);
}

// MARK - RMSNorm Layer

// Model state for the RMSNorm layer.
typedef struct {
ccv_cnnp_model_t super; // Base model; kept as the first member because the code casts between this struct and ccv_cnnp_model_t*.
ccv_nnc_tensor_symbol_t output; // Storage for the single output symbol (super.outputs points here).
ccv_nnc_tensor_symbol_t scale; // Learnable scale symbol; scale.graph stays 0 until the model is built.
ccv_nnc_cmd_param_t params; // Command parameters; the rmsnorm member (epsilon, count, axis) is the one in use.
} ccv_cnnp_model_rmsnorm_t;

/**
 * Build the RMSNorm node for this model on the symbolic graph.
 *
 * Takes exactly one input symbol and produces exactly one output symbol. The scale
 * parameter symbol is created the first time this is called and reused afterwards.
 *
 * @param super The model, which must be a ccv_cnnp_model_rmsnorm_t.
 * @param graph The symbolic graph to add the symbols and the exec node to.
 * @param inputs The input tensor symbols (only inputs[0] is used).
 * @param input_size Must be 1.
 * @param outputs Receives the normalized output symbol in outputs[0].
 * @param output_size Must be 1.
 */
static void _ccv_cnnp_rmsnorm_build(ccv_cnnp_model_t* const super, ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, ccv_nnc_tensor_symbol_t* const outputs, const int output_size)
{
	assert(input_size == 1);
	assert(output_size == 1);
	ccv_cnnp_model_rmsnorm_t* const self = (ccv_cnnp_model_rmsnorm_t*)super;
	const ccv_nnc_tensor_param_t params = ccv_nnc_tensor_symbol_params(graph, inputs[0]);
	ccv_nnc_tensor_param_t scale_params = params;
	const int nd = ccv_nnc_tensor_nd(params.dim);
	int i;
	// The scale has extent 1 on every axis except the selected feature axes,
	// where it takes the extent of the input.
	for (i = 0; i < nd; i++)
		scale_params.dim[i] = 1;
	for (i = 0; i < self->params.rmsnorm.count; i++)
		scale_params.dim[self->params.rmsnorm.axis[i]] = params.dim[self->params.rmsnorm.axis[i]];
	// The scale is shared if this model is reused (this model has no bias).
	if (!self->scale.graph)
		self->scale = ccv_nnc_tensor_symbol_new(graph, scale_params, "scale");
	// Use the RMSNorm command: the parameters are stored in params.rmsnorm and the
	// node below is wired with (x, scale) -> (y, saved_inv_std), not the layer norm signature.
	const ccv_nnc_cmd_t rmsnorm = ccv_nnc_cmd(CCV_NNC_RMSNORM_FORWARD, 0, self->params, 0);
	ccv_nnc_tensor_param_t output_params[2];
	ccv_nnc_hint_tensor_auto(rmsnorm, (ccv_nnc_tensor_param_t []){
			params,
			scale_params,
		}, 2, ccv_nnc_no_hint, output_params, 2);
	const ccv_nnc_tensor_symbol_t output = ccv_nnc_tensor_symbol_new(graph, output_params[0], 0);
	// Second output of the command; it is not exposed as a model output.
	const ccv_nnc_tensor_symbol_t saved_inv_std = ccv_nnc_tensor_symbol_new(graph, output_params[1], "saved_inv_std");
	ccv_nnc_graph_exec_symbol_new(graph, rmsnorm, TENSOR_SYMBOL_LIST(inputs[0], self->scale), TENSOR_SYMBOL_LIST(output, saved_inv_std), "rmsnorm");
	outputs[0] = output;
}

/**
 * Initialize the parameters of the RMSNorm model.
 *
 * If the scale symbol exists, asks the supplied initializer to fill it
 * using CMD_RANDOM_UNIFORM_FORWARD(0, 1). Does nothing otherwise.
 */
static void _ccv_cnnp_rmsnorm_init_states(ccv_cnnp_model_t* const super, ccv_nnc_symbolic_graph_t* const graph, const ccv_cnnp_state_initializer_f initializer, void* const context)
{
	ccv_cnnp_model_rmsnorm_t* const rmsnorm_model = (ccv_cnnp_model_rmsnorm_t*)super;
	// No scale symbol has been created yet, nothing to initialize.
	if (!rmsnorm_model->scale.graph)
		return;
	initializer(context, CMD_RANDOM_UNIFORM_FORWARD(0, 1), ccv_nnc_no_hint, 0, 0, rmsnorm_model->scale);
}

/**
 * Report the parameters of the RMSNorm model.
 *
 * If the scale symbol exists, it is passed to add_to_array together with
 * the is_trainable flag. Does nothing otherwise.
 */
static void _ccv_cnnp_rmsnorm_add_to_parameter(ccv_cnnp_model_t* const super, const ccv_cnnp_add_to_array_f add_to_array, void* const parameters, const int is_trainable)
{
	ccv_cnnp_model_rmsnorm_t* const rmsnorm_model = (ccv_cnnp_model_rmsnorm_t*)super;
	// No scale symbol has been created yet, nothing to report.
	if (!rmsnorm_model->scale.graph)
		return;
	add_to_array(parameters, rmsnorm_model->scale, is_trainable);
}

// Forward declaration: the copy callback is defined after ccv_cnnp_rmsnorm because it calls it,
// but it has to be visible here to populate the vtable.
static ccv_cnnp_model_t* _ccv_cnnp_rmsnorm_copy(const ccv_cnnp_model_t* const super, void* const context);

// Callback table for the RMSNorm model.
static const ccv_cnnp_model_vtab_t ccv_cnnp_rmsnorm_isa = {
.build = _ccv_cnnp_rmsnorm_build,
.init_states = _ccv_cnnp_rmsnorm_init_states,
.add_to_parameter = _ccv_cnnp_rmsnorm_add_to_parameter,
.copy = _ccv_cnnp_rmsnorm_copy,
};

ccv_cnnp_model_t* ccv_cnnp_rmsnorm(const float epsilon, const int axis[CCV_NNC_MAX_DIM_ALLOC], const int axis_count, const int is_trainable, const char* const name)
{
ccv_cnnp_model_rmsnorm_t* const model_rmsnorm = (ccv_cnnp_model_rmsnorm_t*)cccalloc(1, sizeof(ccv_cnnp_model_rmsnorm_t));
model_rmsnorm->super.isa = &ccv_cnnp_rmsnorm_isa;
model_rmsnorm->super.input_size = 1;
model_rmsnorm->super.outputs = &model_rmsnorm->output;
model_rmsnorm->super.output_size = 1;
model_rmsnorm->super.is_trainable = is_trainable;
ccv_cnnp_model_copy_name(&model_rmsnorm->super, name);
model_rmsnorm->scale.d = CCV_NNC_NO_TENSOR_SYMBOL;
model_rmsnorm->scale.graph = 0;
model_rmsnorm->params.rmsnorm.epsilon = epsilon;
model_rmsnorm->params.rmsnorm.count = axis_count;
memcpy(model_rmsnorm->params.lnorm.axis, axis, sizeof(int) * axis_count);
return (ccv_cnnp_model_t*)model_rmsnorm;
}

/**
 * Create a new RMSNorm model with the same epsilon, axes, trainable flag and name
 * as the given one. The context argument is not used.
 */
static ccv_cnnp_model_t* _ccv_cnnp_rmsnorm_copy(const ccv_cnnp_model_t* const super, void* const context)
{
	const ccv_cnnp_model_rmsnorm_t* const source = (const ccv_cnnp_model_rmsnorm_t*)super;
	const float epsilon = source->params.rmsnorm.epsilon;
	const int* const axis = source->params.rmsnorm.axis;
	const int axis_count = source->params.rmsnorm.count;
	return ccv_cnnp_rmsnorm(epsilon, axis, axis_count, source->super.is_trainable, source->super.name);
}

// MARK - Batched Matrix Mul Layer

typedef struct {
Expand Down
10 changes: 10 additions & 0 deletions lib/nnc/ccv_nnc.h
Original file line number Diff line number Diff line change
Expand Up @@ -4257,6 +4257,16 @@ CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_layer_norm(const float epsilon, cons
* @return A group norm model.
*/
CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_group_norm(const int group_axis, const int groups, const float epsilon, const int reduce_axis[CCV_NNC_MAX_DIM_ALLOC], const int axis_count, const int is_trainable, const char* const name);
/**
 * A RMSNorm model.
 * @param epsilon The epsilon in the rmsnorm parameter.
 * @param axis The feature axes over which to compute the norm.
 * @param axis_count How many axes in the axis array are used as feature axes.
 * @param is_trainable Whether the parameters of this model can be trained.
 * @param name The unique name of the model.
 * @return A RMSNorm model.
 */
CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_rmsnorm(const float epsilon, const int axis[CCV_NNC_MAX_DIM_ALLOC], const int axis_count, const int is_trainable, const char* const name);
/**
* Add two input tensors together. Different from sum because this support broadcasting.
* @param p The weight for the first input.
Expand Down

0 comments on commit 2fa1c4d

Please sign in to comment.