Fix issue in per-tensor quantization and missing input_rank (#65)
Chizkiyahu authored Oct 5, 2023
1 parent d708fcf commit e1d7f11
Showing 2 changed files with 6 additions and 0 deletions.
File 1 of 2:

@@ -155,8 +155,11 @@ def symbolic(g,
         # When None is passed as channel_axis, the op has no attribute of channel_axis,
         # which creates conflict with the onnxruntime function. For this reason, if we quantize
         # per-tensor and channel_axis is None, we set it to 0.
+        # per-tensor and input_rank is None, we set it to 4.
         if not per_channel and channel_axis is None:
             channel_axis = 0
+        if not per_channel and input_rank is None:
+            input_rank = 4
 
         return g.op(f"{ONNX_CUSTOM_OP_DOMAIN}::WeightsLUTPOTQuantizer", input_tensor,
                     g.op('Constant', value_t=torch.tensor(lut_values, dtype=torch.float32)),

File 2 of 2:

@@ -177,8 +177,11 @@ def symbolic(g,
         # When None is passed as channel_axis, the op has no attribute of channel_axis,
         # which creates conflict with the onnxruntime function. For this reason, if we quantize
         # per-tensor and channel_axis is None, we set it to 0.
+        # per-tensor and input_rank is None, we set it to 4.
         if not per_channel and channel_axis is None:
             channel_axis = 0
+        if not per_channel and input_rank is None:
+            input_rank = 4
 
         return g.op(f"{ONNX_CUSTOM_OP_DOMAIN}::WeightsLUTSymmetricQuantizer", input_tensor,
                     g.op('Constant', value_t=torch.tensor(lut_values, dtype=torch.float32)),
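
The added fallback mirrors the existing channel_axis handling in both quantizers: ONNX node attributes cannot carry None, so when exporting a per-tensor quantizer the symbolic function substitutes fixed placeholder values before emitting the custom op. A minimal standalone sketch of that pattern follows; the helper name _default_per_tensor_attrs is illustrative and not part of the repository.

# Illustrative sketch, not repository code: per-tensor quantizers carry no
# meaningful channel_axis or input_rank, but ONNX op attributes may not be None,
# so fixed placeholders are substituted before the custom op is built.
def _default_per_tensor_attrs(per_channel: bool, channel_axis, input_rank):
    if not per_channel and channel_axis is None:
        channel_axis = 0   # placeholder; ignored for per-tensor quantization
    if not per_channel and input_rank is None:
        input_rank = 4     # placeholder; ignored for per-tensor quantization
    return channel_axis, input_rank

# Example: a per-tensor quantizer exported with no axis/rank information.
print(_default_per_tensor_attrs(per_channel=False, channel_axis=None, input_rank=None))  # (0, 4)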
