onnx2leela options to fix tf exported onnx models (#1928)
* onnx2leela options to fix tf exported onnx models
* initial support for onnx external data
* autodetect fp16 onnx model
* ensure outputs have the same data type as the input
borg323 authored Nov 14, 2023
1 parent c285401 commit a2f98f7
Showing 5 changed files with 242 additions and 23 deletions.
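
Editor's note: with these changes, a TensorFlow-exported ONNX network can be converted with something like "lc0 onnx2leela --fix-rule50 --fix-wdl-softmax --onnx-input=<name> --onnx-output-wdl=<name> ..." (the subcommand spelling and the input/output filename flags are assumptions here; only the option names defined in the diff below are certain). The fp16/fp32 data type is now autodetected from the named input tensor, replacing the removed --onnx-data-type option.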
10 changes: 10 additions & 0 deletions scripts/compile_proto.py
@@ -376,10 +376,14 @@ def GenerateFunctionDeclarations(self, w):
      w.Write("%s* add_%s();" % (cpp_type, name))
    else:
      w.Write("void add_%s(%s val);" % (name, cpp_type))
    # Using a vector here breaks API compatibility with the standard
    # protobuf library, but it is more convenient.
    w.Write("const std::vector<%s>& %s() const;" %
            (var_cpp_type, name))
    w.Write("std::vector<%s>* mutable_%s();" % (var_cpp_type, name))
    if self.type.IsMessage():
      w.Write("const %s& %s(size_t idx) const;" % (cpp_type, name))
      w.Write("%s* mutable_%s(size_t idx);" % (cpp_type, name))
    else:
      w.Write("%s %s(size_t idx) const;" % (cpp_type, name))
    w.Write("size_t %s_size() const;" % (name))
@@ -408,10 +412,16 @@ def GenerateFunctionDefinitions(self, w, class_name):
    w.Write(
        "inline const std::vector<%s>& %s::%s() const { return %s_; }"
        % (var_cpp_type, class_name, name, name))
    w.Write(
        "inline std::vector<%s>* %s::mutable_%s() { return &%s_; }"
        % (var_cpp_type, class_name, name, name))
    if self.type.IsMessage():
      w.Write(
          "inline const %s& %s::%s(size_t idx) const { return %s_[idx]; }"
          % (cpp_type, class_name, name, name))
      w.Write(
          "inline %s* %s::mutable_%s(size_t idx) { return &%s_[idx]; }"
          % (cpp_type, class_name, name, name))
    else:
      w.Write(
          "inline %s %s::%s(size_t idx) const { return %s_[idx]; }" %
227 changes: 211 additions & 16 deletions src/lc0ctl/onnx2leela.cc
@@ -35,6 +35,7 @@
#include "neural/onnx/onnx.pb.h"
#include "proto/net.pb.h"
#include "utils/files.h"
#include "utils/fp16_utils.h"
#include "utils/optionsparser.h"

namespace lczero {
@@ -81,8 +82,6 @@ const OptionId kMovesLeftFormatId("moves-left-format", "MovesLeftFormat",
                                  "Format of the moves left head output.");

// ONNX options.
const OptionId kOnnxDataTypeId("onnx-data-type", "OnnxDataType",
                               "Data type to feed into the neural network.");
const OptionId kOnnxInputId{"onnx-input", "OnnxInput",
                            "The name of the input ONNX node."};
const OptionId kOnnxOutputValueId{
@@ -97,6 +96,12 @@ const OptionId kOnnxOutputMlhId{"onnx-output-mlh", "OnnxOutputMlh",

const OptionId kValidateModelId{"validate-weights", "ValidateWeights",
                                "Do a basic check of the provided ONNX file."};
const OptionId kFixRule50Id{
    "fix-rule50", "",
    "Fix tensorflow exported onnx that needs rule50 input scaling."};
const OptionId kFixWdlSoftmaxId{
    "fix-wdl-softmax", "",
    "Fix tensorflow exported onnx that is missing wdl output softmax."};

bool ProcessParameters(OptionsParser* options) {
  using pblczero::NetworkFormat;
@@ -122,17 +127,15 @@ bool ProcessParameters(OptionsParser* options) {
                                 NetworkFormat::MovesLeftFormat_Name)) =
      NetworkFormat::MovesLeftFormat_Name(NetworkFormat::MOVES_LEFT_V1);
  // Onnx options.
  options->Add<ChoiceOption>(kOnnxDataTypeId,
                             GetAllEnumValues(OnnxModel::DataType_AllValues,
                                              OnnxModel::DataType_Name)) =
      OnnxModel::DataType_Name(OnnxModel::FLOAT);
  options->Add<StringOption>(kOnnxInputId);
  options->Add<StringOption>(kOnnxOutputPolicyId);
  options->Add<StringOption>(kOnnxOutputValueId);
  options->Add<StringOption>(kOnnxOutputWdlId);
  options->Add<StringOption>(kOnnxOutputMlhId);

  options->Add<BoolOption>(kValidateModelId) = true;
  options->Add<BoolOption>(kFixRule50Id) = false;
  options->Add<BoolOption>(kFixWdlSoftmaxId) = false;

  if (!options->ProcessAllFlags()) return false;

@@ -142,10 +145,8 @@ bool ProcessParameters(OptionsParser* options) {
  return true;
}

bool ValidateNetwork(const pblczero::Net& weights) {
bool ValidateNetwork(const pblczero::Net& weights, pblczero::ModelProto& onnx) {
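  // The parsed ModelProto is now supplied by the caller, which may have
  // applied fixes to it, instead of being re-parsed from the serialized model.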
  const auto& onnx_model = weights.onnx_model();
  pblczero::ModelProto onnx;
  onnx.ParseFromString(onnx_model.model());

  if (!onnx.has_ir_version()) {
    CERR << "ONNX file doesn't appear to have version specified. Likely not an "
@@ -230,6 +231,190 @@ bool ValidateNetwork(const pblczero::Net& weights) {
  return true;
}

void FixRule50(pblczero::ModelProto& model, const std::string& in, bool fp16) {
  std::string name = "rule50fix";

  for (size_t i = 0; i < model.graph().node_size(); i++) {
    auto node = model.graph().node(i);
    for (size_t j = 0; j < node.input_size(); j++) {
      if (node.input(j) == in) {
        CERR << "Inserting scaling between " << in << " and " << node.name();
        model.mutable_graph()->mutable_node(i)->mutable_input()->at(j) =
            std::string(name);
      }
    }
  }

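  // Plane 109 of the 112-plane lc0 input is the rule-50 counter; the standard
  // encoding scales it by 1/99, which some tf exports omit, so a constant
  // per-plane Mul is spliced in front of the rest of the graph.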
  auto* init = model.mutable_graph()->add_initializer();
  init->set_name(name + "_weights");
  init->add_dims(112);
  init->add_dims(1);
  init->add_dims(1);
  if (fp16) {
    init->set_data_type(pblczero::TensorProto::FLOAT16);
    std::vector<uint16_t> rule50weights(112, FP32toFP16(1.0f));
    rule50weights[109] = FP32toFP16(1.0f / 99);
    init->set_raw_data(
        std::string(reinterpret_cast<const char*>(rule50weights.data()),
                    rule50weights.size() * sizeof(uint16_t)));
  } else {
    init->set_data_type(pblczero::TensorProto::FLOAT);
    std::vector<float> rule50weights(112, 1.0f);
    rule50weights[109] = 1.0f / 99;
    init->set_raw_data(
        std::string(reinterpret_cast<const char*>(rule50weights.data()),
                    rule50weights.size() * sizeof(float)));
  }
  auto* new_node = model.mutable_graph()->add_node();
  new_node->set_name(name);
  new_node->set_op_type("Mul");
  new_node->add_input(in);
  new_node->add_output(name);

  new_node->add_input(name + "_weights");
}

void FixWdlSoftmax(pblczero::ModelProto& model, const std::string& out) {
  std::string name = "softmax_fix";

  for (size_t i = 0; i < model.graph().node_size(); i++) {
    auto node = model.graph().node(i);
    for (size_t j = 0; j < node.output_size(); j++) {
      if (node.output(j) == out) {
        CERR << "Inserting softmax between " << node.name() << " and " << out;
        model.mutable_graph()->mutable_node(i)->mutable_output()->at(j) =
            std::string(name);
        break;
      }
    }
  }

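  // The producer now writes to the renamed tensor; the Softmax node added
  // below maps it back to the original output name, so the graph's declared
  // outputs are unchanged.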
  auto* new_node = model.mutable_graph()->add_node();
  new_node->set_name(name);
  new_node->set_op_type("Softmax");
  new_node->add_input(name);
  new_node->add_output(out);
}

pblczero::OnnxModel_DataType GetDataType(pblczero::ModelProto& model,
                                         const std::string& name) {
  using pblczero::TensorProto;
  using pblczero::OnnxModel;
  for (auto& in : model.graph().input()) {
    if (in.name() == name && in.has_type() && in.type().has_tensor_type() &&
        in.type().tensor_type().has_elem_type()) {
      auto data_type = in.type().tensor_type().elem_type();
      switch (data_type) {
        case TensorProto::FLOAT:
          return OnnxModel::FLOAT;
        case TensorProto::FLOAT16:
          return OnnxModel::FLOAT16;
        default:
          throw Exception("Unsupported data type: " +
                          TensorProto::DataType_Name(data_type));
      }
    }
  }
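  // Fall back to fp32 when the named input is absent or has no element type.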
  return OnnxModel::FLOAT;
}

bool EnsureOutDataType(pblczero::ModelProto& model, const std::string& name,
                       pblczero::OnnxModel_DataType data_type) {
  // Check if output has the correct data type and set it if not.
  for (size_t i = 0; i < model.graph().output_size(); i++) {
    auto out = model.graph().output(i);
    if (out.name() == name) {
      if (!out.has_type()) {
        model.mutable_graph()->mutable_output(i)->mutable_type();
      }
      if (!out.type().has_tensor_type()) {
        model.mutable_graph()
            ->mutable_output(i)
            ->mutable_type()
            ->mutable_tensor_type();
      }
      if (!out.type().tensor_type().has_elem_type() ||
          out.type().tensor_type().elem_type() !=
              static_cast<pblczero::TensorProto_DataType>(data_type)) {
        model.mutable_graph()
            ->mutable_output(i)
            ->mutable_type()
            ->mutable_tensor_type()
            ->set_elem_type(
                static_cast<pblczero::TensorProto_DataType>(data_type));
        break;
      }
      return false;
    }
  }

  // Insert a cast to the correct data type.
  for (size_t i = 0; i < model.graph().node_size(); i++) {
    auto node = model.graph().node(i);
    for (size_t j = 0; j < node.output_size(); j++) {
      if (node.output(j) == name) {
        CERR << "Inserting cast between " << node.name() << " and " << name;
        model.mutable_graph()->mutable_node(i)->mutable_output()->at(j) =
            std::string(name + "/cast");
        break;
      }
    }
  }

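  // The Cast node consumes the renamed tensor and restores the original
  // output name; its "to" attribute carries the target element type.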
  auto* new_node = model.mutable_graph()->add_node();
  new_node->set_name(name + "/cast");
  new_node->set_op_type("Cast");
  new_node->add_input(name + "/cast");
  new_node->add_output(name);
  auto* attr = new_node->add_attribute();
  attr->set_name("to");
  attr->set_type(pblczero::AttributeProto::INT);
  attr->set_i(data_type);
  return true;
}

bool MaybeFixOnnx(pblczero::ModelProto& model, const OptionsDict& dict,
                  pblczero::OnnxModel_DataType data_type) {
  bool updated = false;

  // Input.
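  // The rule50 fix is only applied when the input tensor name was given
  // explicitly on the command line.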
  if (dict.OwnExists<std::string>(kOnnxInputId)) {
    if (dict.Get<bool>(kFixRule50Id)) {
      FixRule50(model, dict.Get<std::string>(kOnnxInputId),
                data_type == pblczero::OnnxModel::FLOAT16);
      updated = true;
    }
  }

  // Policy.
  if (dict.OwnExists<std::string>(kOnnxOutputPolicyId)) {
    updated |= EnsureOutDataType(
        model, dict.Get<std::string>(kOnnxOutputPolicyId), data_type);
  }

  // Value.
  if (dict.OwnExists<std::string>(kOnnxOutputValueId)) {
    updated |= EnsureOutDataType(
        model, dict.Get<std::string>(kOnnxOutputValueId), data_type);
  }
  if (dict.OwnExists<std::string>(kOnnxOutputWdlId)) {
    auto out = dict.Get<std::string>(kOnnxOutputWdlId);
    if (dict.Get<bool>(kFixWdlSoftmaxId)) {
      FixWdlSoftmax(model, out);
      updated = true;
    }
    updated |= EnsureOutDataType(model, out, data_type);
  }

  // Mlh.
  if (dict.OwnExists<std::string>(kOnnxOutputMlhId)) {
    updated |= EnsureOutDataType(model, dict.Get<std::string>(kOnnxOutputMlhId),
                                 data_type);
  }

  return updated;
}

} // namespace

void ConvertOnnxToLeela() {
@@ -240,6 +425,10 @@ void ConvertOnnxToLeela() {

  const OptionsDict& dict = options_parser.GetOptionsDict();

  auto onnx_model = ReadFileToString(dict.Get<std::string>(kInputFilenameId));
  pblczero::ModelProto model;
  model.ParseFromString(onnx_model);

  pblczero::Net out_weights;
  out_weights.set_magic(0x1c0);
  // ONNX networks appeared in v0.28.
@@ -249,17 +438,18 @@
  auto format = out_weights.mutable_format()->mutable_network_format();
  format->set_network(NetworkFormat::NETWORK_ONNX);
  auto onnx = out_weights.mutable_onnx_model();
  onnx->set_data_type(GetEnumValueFromString(
      dict.Get<std::string>(kOnnxDataTypeId), OnnxModel::DataType_AllValues,
      OnnxModel::DataType_Name));
  auto data_type = OnnxModel::FLOAT;

  // Input.
  format->set_input(GetEnumValueFromString(
      dict.Get<std::string>(kInputFormatId),
      NetworkFormat::InputFormat_AllValues, NetworkFormat::InputFormat_Name));
  if (dict.OwnExists<std::string>(kOnnxInputId)) {
    onnx->set_input_planes(dict.Get<std::string>(kOnnxInputId));
    auto in = dict.Get<std::string>(kOnnxInputId);
    onnx->set_input_planes(in);
    data_type = GetDataType(model, in);
  }
  onnx->set_data_type(data_type);

  // Policy.
  format->set_policy(GetEnumValueFromString(
@@ -289,8 +479,13 @@
    onnx->set_output_mlh(dict.Get<std::string>(kOnnxOutputMlhId));
  }

  onnx->set_model(ReadFileToString(dict.Get<std::string>(kInputFilenameId)));
  if (dict.Get<bool>(kValidateModelId) && !ValidateNetwork(out_weights)) {
  if (MaybeFixOnnx(model, dict, data_type)) {
    onnx->set_model(model.OutputAsString());
  } else {
    onnx->set_model(onnx_model);
  }
  if (dict.Get<bool>(kValidateModelId) &&
      !ValidateNetwork(out_weights, model)) {
    return;
  }
  WriteStringToGzFile(dict.Get<std::string>(kOutputFilenameId),
@@ -300,4 +495,4 @@
  COUT << "Done.";
}

} // namespace lczero
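
Editor's note: FixWdlSoftmax and EnsureOutDataType above share a single graph-surgery idiom: rename the edge leaving the producer node, then append a node that consumes the renamed tensor and emits the original name, so nothing downstream needs to change. A minimal sketch of that idiom, using only the pblczero accessors visible in this diff (the helper name InsertUnaryOpAfter is hypothetical, not part of the commit):

// Splice a single-input op of type `op_type` in front of graph output `out`.
// Downstream consumers keep referring to `out`; only the producing edge is
// renamed, exactly as FixWdlSoftmax and EnsureOutDataType do above.
void InsertUnaryOpAfter(pblczero::ModelProto& model, const std::string& out,
                        const std::string& op_type) {
  const std::string renamed = out + "/" + op_type;
  // Redirect whichever node currently produces `out` to the renamed tensor.
  for (size_t i = 0; i < model.graph().node_size(); i++) {
    auto node = model.graph().node(i);
    for (size_t j = 0; j < node.output_size(); j++) {
      if (node.output(j) == out) {
        model.mutable_graph()->mutable_node(i)->mutable_output()->at(j) =
            renamed;
        break;
      }
    }
  }
  // Append the new node: renamed tensor in, original name out.
  auto* new_node = model.mutable_graph()->add_node();
  new_node->set_name(renamed);
  new_node->set_op_type(op_type);
  new_node->add_input(renamed);
  new_node->add_output(out);
}

For example, InsertUnaryOpAfter(model, wdl_output, "Softmax") would reproduce the effect of FixWdlSoftmax above.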
3 changes: 3 additions & 0 deletions src/neural/onnx/converter.cc
@@ -884,6 +884,9 @@ void Converter::GenerateOnnx(pblczero::OnnxModel* onnx) {
  LegacyWeights weights(src_.weights());
  OnnxBuilder builder(options_.opset);

  onnx->set_data_type(GetDataType() == pblczero::TensorProto::FLOAT16
                          ? pblczero::OnnxModel::FLOAT16
                          : pblczero::OnnxModel::FLOAT);
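  // Recording the data type in the weights file lets backends (see
  // network_onnx.cc below) choose fp16 vs fp32 without an extra flag.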
  onnx->set_input_planes(options_.input_planes_name);
  builder.AddInput(options_.input_planes_name, {options_.batch_size, 112, 8, 8},
                   GetDataType());
12 changes: 6 additions & 6 deletions src/neural/onnx/network_onnx.cc
@@ -82,8 +82,8 @@ class OnnxComputation : public NetworkComputation {
class OnnxNetwork : public Network {
 public:
  OnnxNetwork(const WeightsFile& file, const OptionsDict& options,
              OnnxProvider provider, int gpu, int threads, bool fp16,
              int batch_size, int steps);
              OnnxProvider provider, int gpu, int threads, int batch_size,
              int steps);
  std::unique_ptr<NetworkComputation> NewComputation() override {
    if (fp16_) {
      return std::make_unique<OnnxComputation<Ort::Float16_t>>(this);
@@ -321,13 +321,13 @@ Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int threads,
}

OnnxNetwork::OnnxNetwork(const WeightsFile& file, const OptionsDict&,
                         OnnxProvider provider, int gpu, int threads, bool fp16,
                         OnnxProvider provider, int gpu, int threads,
                         int batch_size, int steps)
    : onnx_env_(ORT_LOGGING_LEVEL_WARNING, "lc0"),
      steps_(steps),
      capabilities_{file.format().network_format().input(),
                    file.format().network_format().moves_left()},
      fp16_(fp16),
      fp16_(file.onnx_model().data_type() == pblczero::OnnxModel::FLOAT16),
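      // fp16 is now derived from the data type stored in the weights file
      // rather than from a constructor argument.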
      batch_size_(batch_size),
      provider_(provider) {
  // Sanity checks.
@@ -396,7 +396,7 @@ std::unique_ptr<Network> MakeOnnxNetwork(const std::optional<WeightsFile>& w,

  if (w->has_onnx_model()) {
    return std::make_unique<OnnxNetwork>(*w, opts, kProvider, gpu, threads,
                                         false, batch_size, steps);
                                         batch_size, steps);
  } else {
    if (w->format().network_format().network() !=
        pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT &&
@@ -448,7 +448,7 @@ std::unique_ptr<Network> MakeOnnxNetwork(const std::optional<WeightsFile>& w,

  auto converted = ConvertWeightsToOnnx(*w, converter_options);
  return std::make_unique<OnnxNetwork>(converted, opts, kProvider, gpu,
                                       threads, fp16, batch_size, steps);
                                       threads, batch_size, steps);
  }
}
