Skip to content

Commit

Permalink
[luci] Introduce MinimumMSE quantization algorithm
Browse files Browse the repository at this point in the history
This commit introduces MinimumMSE quantization algorithm.

ONE-DCO-1.0-Signed-off-by: Vyacheslav Bazhenov <[email protected]>
  • Loading branch information
Vyacheslav Bazhenov committed Jul 25, 2024
1 parent cd2f583 commit 7684e7a
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 6 deletions.
6 changes: 6 additions & 0 deletions compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ struct LayerInfo
QuantizationGranularity granularity;
};

enum struct QuantizationAlgorithm
{
Common = 0,
MinimumMSE = 1,
};

} // namespace luci

#endif // __LUCI_QUANTIZATION_PARAMETERS_H__
106 changes: 102 additions & 4 deletions compiler/luci/pass/src/QuantizeWeightsOnly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc
template <loco::DataType out_type>
void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
std::vector<float> &scaling_factor, std::vector<float> &nudged_min,
std::vector<float> &nudged_max, int32_t &channel_dim_index)
std::vector<float> &nudged_max, int32_t &channel_dim_index,
const luci::QuantizationAlgorithm &alg_type)
{
assert(node->dtype() == loco::DataType::FLOAT32);
assert(out_type == loco::DataType::S4 || out_type == loco::DataType::S8 ||
Expand All @@ -91,7 +92,104 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec
quantized_values[cal_offset(dimension, indices)] =
static_cast<int32_t>(std::round(data * scaling_factor_inv));
};
if (alg_type == QuantizationAlgorithm::MinimumMSE)
{
std::vector<float> max_scale(min.size());
for (int i = 0; i < min.size(); ++i)
{
max_scale[i] = std::max(std::fabs(min[i]), std::fabs(max[i]));
}
std::vector<double> channel_mse(min.size());
std::vector<double> channel_min_mse(min.size(), std::numeric_limits<double>::max());

auto calculate_mse = [&](uint32_t *indices, loco::TensorShape &dimension,
int channel_dim_index) {
int channel_idx = indices[channel_dim_index];
auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
double diff =
data - quantized_values[cal_offset(dimension, indices)] * scaling_factor[channel_idx];
channel_mse[channel_idx] += diff * diff;
};

std::vector<float> scaling_factor_base = scaling_factor;
std::vector<std::pair<float, float>> golden_start_end(min.size());
const auto kSearchIterations = 100;
const auto kPhi = 1.618033988749894848204586834365638118;
const auto kRangeCoefficient = 0.1;
for (int i = 0; i < max_scale.size(); ++i)
{
golden_start_end[i].first = scaling_factor_base[i] * (1.0 - kRangeCoefficient);
golden_start_end[i].second = scaling_factor_base[i] * (1.0 + kRangeCoefficient);
}


for (int i = 0; i < kSearchIterations; ++i)
{
for (int j = 0; j < scaling_factor.size(); ++j)
{
scaling_factor[j] = golden_start_end[j].second -
(golden_start_end[j].second - golden_start_end[j].first) / kPhi;
}
for (auto &val : channel_mse)
{
val = 0;
}
iterate_per_channel(node, channel_dim_index, quantize);
iterate_per_channel(node, channel_dim_index, calculate_mse);
auto channel_mse_x1 = channel_mse;

for (int j = 0; j < scaling_factor.size(); ++j)
{
scaling_factor[j] = golden_start_end[j].first +
(golden_start_end[j].second - golden_start_end[j].first) / kPhi;
}
for (auto &val : channel_mse)
{
val = 0;
}
iterate_per_channel(node, channel_dim_index, quantize);
iterate_per_channel(node, channel_dim_index, calculate_mse);
auto channel_mse_x2 = channel_mse;

for (int k = 0; k < channel_mse_x1.size(); ++k)
{
if (channel_mse_x1[k] > channel_mse_x2[k])
{
golden_start_end[k].first =
golden_start_end[k].second -
(golden_start_end[k].second - golden_start_end[k].first) / kPhi;
}
else
{
golden_start_end[k].second =
golden_start_end[k].first +
(golden_start_end[k].second - golden_start_end[k].first) / kPhi;
}
}
}
for (int i = 0; i < golden_start_end.size(); ++i)
{
scaling_factor[i] = (golden_start_end[i].first + golden_start_end[i].second) / 2;
}
iterate_per_channel(node, channel_dim_index, quantize);
iterate_per_channel(node, channel_dim_index, calculate_mse);
auto channel_mse_opt = channel_mse;
scaling_factor = scaling_factor_base;
iterate_per_channel(node, channel_dim_index, quantize);
iterate_per_channel(node, channel_dim_index, calculate_mse);
auto channel_mse_base = channel_mse;

// Checking if found scale is better than base
for (int i = 0; i < channel_mse_base.size(); ++i)
{
if (channel_mse_opt[i] < channel_mse_base[i])
scaling_factor[i] = (golden_start_end[i].first + golden_start_end[i].second) / 2;
else
channel_mse_opt[i] = channel_mse_base[i];
}
}
iterate_per_channel(node, channel_dim_index, quantize);

node->dtype(out_type); // change the type of tensor
Expand Down Expand Up @@ -167,17 +265,17 @@ void QuantizeWeightsOnly::quantize_weights(luci::CircleConst *weights)
if (output_type == loco::DataType::S4)
{
sym_wquant_per_channel<loco::DataType::S4>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
nudged_max, channel_dim_index, algorithm);
}
else if (output_type == loco::DataType::S8)
{
sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
nudged_max, channel_dim_index, algorithm);
}
else if (output_type == loco::DataType::S16)
{
sym_wquant_per_channel<loco::DataType::S16>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
nudged_max, channel_dim_index, algorithm);
}
else
{
Expand Down
6 changes: 4 additions & 2 deletions compiler/luci/pass/src/QuantizeWeightsOnly.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ namespace luci
*/
struct QuantizeWeightsOnly final : public luci::CircleNodeMutableVisitor<void>
{
QuantizeWeightsOnly(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
: input_type(input), output_type(output), granularity(gr)
QuantizeWeightsOnly(loco::DataType input, loco::DataType output, QuantizationGranularity gr,
QuantizationAlgorithm alg = QuantizationAlgorithm::Common)
: input_type(input), output_type(output), granularity(gr), algorithm(alg)
{
}

loco::DataType input_type;
loco::DataType output_type;
QuantizationGranularity granularity;
QuantizationAlgorithm algorithm;

private:
void quantize_weights(luci::CircleConst *weights);
Expand Down

0 comments on commit 7684e7a

Please sign in to comment.