From a2f98f79b93dc4177f0f8ee66e4f8d00c482df52 Mon Sep 17 00:00:00 2001
From: borg323 <39573933+borg323@users.noreply.github.com>
Date: Tue, 14 Nov 2023 23:05:53 +0200
Subject: [PATCH] onnx2leela options to fix tf exported onnx models (#1928)

* onnx2leela options to fix tf exported onnx models

* initial support for onnx external data

* autodetect fp16 onnx model

* ensure outputs have the same data type as the input
---
 scripts/compile_proto.py        |  10 ++
 src/lc0ctl/onnx2leela.cc        | 227 +++++++++++++++++++++++++++++---
 src/neural/onnx/converter.cc    |   3 +
 src/neural/onnx/network_onnx.cc |  12 +-
 src/neural/onnx/onnx.proto      |  13 +-
 5 files changed, 242 insertions(+), 23 deletions(-)

diff --git a/scripts/compile_proto.py b/scripts/compile_proto.py
index 1a21c2da4e..9163683a4e 100755
--- a/scripts/compile_proto.py
+++ b/scripts/compile_proto.py
@@ -376,10 +376,14 @@ def GenerateFunctionDeclarations(self, w):
         w.Write("%s* add_%s();" % (cpp_type, name))
       else:
         w.Write("void add_%s(%s val);" % (name, cpp_type))
+      # Using a vector here breaks API compatibility with the standard
+      # protobuf library, but it is more convenient.
       w.Write("const std::vector<%s>& %s() const;" % (var_cpp_type, name))
+      w.Write("std::vector<%s>* mutable_%s();" % (var_cpp_type, name))
       if self.type.IsMessage():
         w.Write("const %s& %s(size_t idx) const;" % (cpp_type, name))
+        w.Write("%s* mutable_%s(size_t idx);" % (cpp_type, name))
       else:
         w.Write("%s %s(size_t idx) const;" % (cpp_type, name))
       w.Write("size_t %s_size() const;" % (name))
 
@@ -408,10 +412,16 @@ def GenerateFunctionDefinitions(self, w, class_name):
       w.Write(
           "inline const std::vector<%s>& %s::%s() const { return %s_; }" %
          (var_cpp_type, class_name, name, name))
+      w.Write(
+          "inline std::vector<%s>* %s::mutable_%s() { return &%s_; }" %
+          (var_cpp_type, class_name, name, name))
       if self.type.IsMessage():
         w.Write(
             "inline const %s& %s::%s(size_t idx) const { return %s_[idx]; }" %
             (cpp_type, class_name, name, name))
+        w.Write(
+            "inline %s* %s::mutable_%s(size_t idx) { return &%s_[idx]; }" %
+            (cpp_type, class_name, name, name))
       else:
         w.Write(
             "inline %s %s::%s(size_t idx) const { return %s_[idx]; }" %
diff --git a/src/lc0ctl/onnx2leela.cc b/src/lc0ctl/onnx2leela.cc
index 08b505c696..d036138251 100644
--- a/src/lc0ctl/onnx2leela.cc
+++ b/src/lc0ctl/onnx2leela.cc
@@ -35,6 +35,7 @@
 #include "neural/onnx/onnx.pb.h"
 #include "proto/net.pb.h"
 #include "utils/files.h"
+#include "utils/fp16_utils.h"
 #include "utils/optionsparser.h"
 
 namespace lczero {
@@ -81,8 +82,6 @@ const OptionId kMovesLeftFormatId("moves-left-format", "MovesLeftFormat",
                                   "Format of the moves left head output.");
 
 // ONNX options.
-const OptionId kOnnxDataTypeId("onnx-data-type", "OnnxDataType",
-                               "Data type to feed into the neural network.");
 const OptionId kOnnxInputId{"onnx-input", "OnnxInput",
                             "The name of the input ONNX node."};
 const OptionId kOnnxOutputValueId{
@@ -97,6 +96,12 @@ const OptionId kOnnxOutputMlhId{"onnx-output-mlh", "OnnxOutputMlh",
 const OptionId kValidateModelId{"validate-weights", "ValidateWeights",
                                 "Do a basic check of the provided ONNX file."};
+const OptionId kFixRule50Id{
+    "fix-rule50", "",
+    "Fix tensorflow exported onnx that needs rule50 input scaling."};
+const OptionId kFixWdlSoftmaxId{
+    "fix-wdl-softmax", "",
+    "Fix tensorflow exported onnx that is missing wdl output softmax."};
 
 bool ProcessParameters(OptionsParser* options) {
   using pblczero::NetworkFormat;
@@ -122,10 +127,6 @@ bool ProcessParameters(OptionsParser* options) {
                        NetworkFormat::MovesLeftFormat_Name)) =
       NetworkFormat::MovesLeftFormat_Name(NetworkFormat::MOVES_LEFT_V1);
   // Onnx options.
-  options->Add<ChoiceOption>(kOnnxDataTypeId,
-                             GetAllEnumValues(OnnxModel::DataType_AllValues,
-                                              OnnxModel::DataType_Name)) =
-      OnnxModel::DataType_Name(OnnxModel::FLOAT);
   options->Add<StringOption>(kOnnxInputId);
   options->Add<StringOption>(kOnnxOutputPolicyId);
   options->Add<StringOption>(kOnnxOutputValueId);
   options->Add<StringOption>(kOnnxOutputWdlId);
   options->Add<StringOption>(kOnnxOutputMlhId);
   options->Add<BoolOption>(kValidateModelId) = true;
+  options->Add<BoolOption>(kFixRule50Id) = false;
+  options->Add<BoolOption>(kFixWdlSoftmaxId) = false;
 
   if (!options->ProcessAllFlags()) return false;
 
@@ -142,10 +145,8 @@ bool ProcessParameters(OptionsParser* options) {
   return true;
 }
 
-bool ValidateNetwork(const pblczero::Net& weights) {
+bool ValidateNetwork(const pblczero::Net& weights, pblczero::ModelProto& onnx) {
   const auto& onnx_model = weights.onnx_model();
-  pblczero::ModelProto onnx;
-  onnx.ParseFromString(onnx_model.model());
 
   if (!onnx.has_ir_version()) {
     CERR << "ONNX file doesn't appear to have version specified. Likely not an "
@@ -230,6 +231,190 @@ bool ValidateNetwork(const pblczero::Net& weights) {
   return true;
 }
 
+void FixRule50(pblczero::ModelProto& model, const std::string& in, bool fp16) {
+  std::string name = "rule50fix";
+
+  for (size_t i = 0; i < model.graph().node_size(); i++) {
+    auto node = model.graph().node(i);
+    for (size_t j = 0; j < node.input_size(); j++) {
+      if (node.input(j) == in) {
+        CERR << "Inserting scaling between " << in << " and " << node.name();
+        model.mutable_graph()->mutable_node(i)->mutable_input()->at(j) =
+            std::string(name);
+      }
+    }
+  }
+
+  auto* init = model.mutable_graph()->add_initializer();
+  init->set_name(name + "_weights");
+  init->add_dims(112);
+  init->add_dims(1);
+  init->add_dims(1);
+  if (fp16) {
+    init->set_data_type(pblczero::TensorProto::FLOAT16);
+    std::vector<uint16_t> rule50weights(112, FP32toFP16(1.0f));
+    rule50weights[109] = FP32toFP16(1.0f / 99);
+    init->set_raw_data(
+        std::string(reinterpret_cast<const char*>(rule50weights.data()),
+                    rule50weights.size() * sizeof(uint16_t)));
+  } else {
+    init->set_data_type(pblczero::TensorProto::FLOAT);
+    std::vector<float> rule50weights(112, 1.0f);
+    rule50weights[109] = 1.0f / 99;
+    init->set_raw_data(
+        std::string(reinterpret_cast<const char*>(rule50weights.data()),
+                    rule50weights.size() * sizeof(float)));
+  }
+
+  auto* new_node = model.mutable_graph()->add_node();
+  new_node->set_name(name);
+  new_node->set_op_type("Mul");
+  new_node->add_input(in);
+  new_node->add_output(name);
+
+  new_node->add_input(name + "_weights");
+}
+
+void FixWdlSoftmax(pblczero::ModelProto& model, const std::string& out) {
+  std::string name = "softmax_fix";
+
+  for (size_t i = 0; i < model.graph().node_size(); i++) {
+    auto node = model.graph().node(i);
+    for (size_t j = 0; j < node.output_size(); j++) {
+      if (node.output(j) == out) {
+        CERR << "Inserting softmax between " << node.name() << " and " << out;
+        model.mutable_graph()->mutable_node(i)->mutable_output()->at(j) =
+            std::string(name);
+        break;
+      }
+    }
+  }
+
+  auto* new_node = model.mutable_graph()->add_node();
+  new_node->set_name(name);
+  new_node->set_op_type("Softmax");
+  new_node->add_input(name);
+  new_node->add_output(out);
+}
+
+pblczero::OnnxModel_DataType GetDataType(pblczero::ModelProto& model,
+                                         const std::string& name) {
+  using pblczero::TensorProto;
+  using pblczero::OnnxModel;
+  for (auto& in : model.graph().input()) {
+    if (in.name() == name && in.has_type() && in.type().has_tensor_type() &&
+        in.type().tensor_type().has_elem_type()) {
+      auto data_type = in.type().tensor_type().elem_type();
+      switch (data_type) {
+        case TensorProto::FLOAT:
+          return OnnxModel::FLOAT;
+        case TensorProto::FLOAT16:
+          return OnnxModel::FLOAT16;
+        default:
+          throw Exception("Unsupported data type: " +
+                          TensorProto::DataType_Name(data_type));
+      }
+    }
+  }
+  return OnnxModel::FLOAT;
+}
+
+bool EnsureOutDataType(pblczero::ModelProto& model, const std::string& name,
+                       pblczero::OnnxModel_DataType data_type) {
+  // Check if output has the correct data type and set it if not.
+  for (size_t i = 0; i < model.graph().output_size(); i++) {
+    auto out = model.graph().output(i);
+    if (out.name() == name) {
+      if (!out.has_type()) {
+        model.mutable_graph()->mutable_output(i)->mutable_type();
+      }
+      if (!out.type().has_tensor_type()) {
+        model.mutable_graph()
+            ->mutable_output(i)
+            ->mutable_type()
+            ->mutable_tensor_type();
+      }
+      if (!out.type().tensor_type().has_elem_type() ||
+          out.type().tensor_type().elem_type() !=
+              static_cast<pblczero::TensorProto::DataType>(data_type)) {
+        model.mutable_graph()
+            ->mutable_output(i)
+            ->mutable_type()
+            ->mutable_tensor_type()
+            ->set_elem_type(
+                static_cast<pblczero::TensorProto::DataType>(data_type));
+        break;
+      }
+      return false;
+    }
+  }
+
+  // Insert a cast to the correct data type.
+  for (size_t i = 0; i < model.graph().node_size(); i++) {
+    auto node = model.graph().node(i);
+    for (size_t j = 0; j < node.output_size(); j++) {
+      if (node.output(j) == name) {
+        CERR << "Inserting cast between " << node.name() << " and " << name;
+        model.mutable_graph()->mutable_node(i)->mutable_output()->at(j) =
+            std::string(name + "/cast");
+        break;
+      }
+    }
+  }
+
+  auto* new_node = model.mutable_graph()->add_node();
+  new_node->set_name(name + "/cast");
+  new_node->set_op_type("Cast");
+  new_node->add_input(name + "/cast");
+  new_node->add_output(name);
+  auto* attr = new_node->add_attribute();
+  attr->set_name("to");
+  attr->set_type(pblczero::AttributeProto::INT);
+  attr->set_i(data_type);
+  return true;
+}
+
+bool MaybeFixOnnx(pblczero::ModelProto& model, const OptionsDict& dict,
+                  pblczero::OnnxModel_DataType data_type) {
+  bool updated = false;
+
+  // Input.
+  if (dict.OwnExists(kOnnxInputId)) {
+    if (dict.Get<bool>(kFixRule50Id)) {
+      FixRule50(model, dict.Get<std::string>(kOnnxInputId),
+                data_type == pblczero::OnnxModel::FLOAT16);
+      updated = true;
+    }
+  }
+
+  // Policy.
+  if (dict.OwnExists(kOnnxOutputPolicyId)) {
+    updated |= EnsureOutDataType(
+        model, dict.Get<std::string>(kOnnxOutputPolicyId), data_type);
+  }
+
+  // Value.
+  if (dict.OwnExists(kOnnxOutputValueId)) {
+    updated |= EnsureOutDataType(
+        model, dict.Get<std::string>(kOnnxOutputValueId), data_type);
+  }
+  if (dict.OwnExists(kOnnxOutputWdlId)) {
+    auto out = dict.Get<std::string>(kOnnxOutputWdlId);
+    if (dict.Get<bool>(kFixWdlSoftmaxId)) {
+      FixWdlSoftmax(model, out);
+      updated = true;
+    }
+    updated |= EnsureOutDataType(model, out, data_type);
+  }
+
+  // Mlh.
+  if (dict.OwnExists(kOnnxOutputMlhId)) {
+    updated |= EnsureOutDataType(model, dict.Get<std::string>(kOnnxOutputMlhId),
+                                 data_type);
+  }
+
+  return updated;
+}
+
 }  // namespace
 
 void ConvertOnnxToLeela() {
@@ -240,6 +425,10 @@ void ConvertOnnxToLeela() {
 
   const OptionsDict& dict = options_parser.GetOptionsDict();
 
+  auto onnx_model = ReadFileToString(dict.Get<std::string>(kInputFilenameId));
+  pblczero::ModelProto model;
+  model.ParseFromString(onnx_model);
+
   pblczero::Net out_weights;
   out_weights.set_magic(0x1c0);
   // ONNX networks appeared in v0.28.
@@ -249,17 +438,18 @@ void ConvertOnnxToLeela() {
   auto format = out_weights.mutable_format()->mutable_network_format();
   format->set_network(NetworkFormat::NETWORK_ONNX);
   auto onnx = out_weights.mutable_onnx_model();
-  onnx->set_data_type(GetEnumValueFromString(
-      dict.Get<std::string>(kOnnxDataTypeId), OnnxModel::DataType_AllValues,
-      OnnxModel::DataType_Name));
+  auto data_type = OnnxModel::FLOAT;
 
   // Input.
   format->set_input(GetEnumValueFromString(
       dict.Get<std::string>(kInputFormatId),
       NetworkFormat::InputFormat_AllValues, NetworkFormat::InputFormat_Name));
   if (dict.OwnExists(kOnnxInputId)) {
-    onnx->set_input_planes(dict.Get<std::string>(kOnnxInputId));
+    auto in = dict.Get<std::string>(kOnnxInputId);
+    onnx->set_input_planes(in);
+    data_type = GetDataType(model, in);
   }
+  onnx->set_data_type(data_type);
 
   // Policy.
   format->set_policy(GetEnumValueFromString(
@@ -289,8 +479,13 @@ void ConvertOnnxToLeela() {
     onnx->set_output_mlh(dict.Get<std::string>(kOnnxOutputMlhId));
   }
 
-  onnx->set_model(ReadFileToString(dict.Get<std::string>(kInputFilenameId)));
-  if (dict.Get<bool>(kValidateModelId) && !ValidateNetwork(out_weights)) {
+  if (MaybeFixOnnx(model, dict, data_type)) {
+    onnx->set_model(model.OutputAsString());
+  } else {
+    onnx->set_model(onnx_model);
+  }
+  if (dict.Get<bool>(kValidateModelId) &&
+      !ValidateNetwork(out_weights, model)) {
     return;
   }
   WriteStringToGzFile(dict.Get<std::string>(kOutputFilenameId),
@@ -300,4 +495,4 @@ void ConvertOnnxToLeela() {
   COUT << "Done.";
 }
 
-}  // namespace lczero
\ No newline at end of file
+}  // namespace lczero
diff --git a/src/neural/onnx/converter.cc b/src/neural/onnx/converter.cc
index a69e2b335a..59297b3ead 100644
--- a/src/neural/onnx/converter.cc
+++ b/src/neural/onnx/converter.cc
@@ -884,6 +884,9 @@ void Converter::GenerateOnnx(pblczero::OnnxModel* onnx) {
   LegacyWeights weights(src_.weights());
   OnnxBuilder builder(options_.opset);
 
+  onnx->set_data_type(GetDataType() == pblczero::TensorProto::FLOAT16
+                          ? pblczero::OnnxModel::FLOAT16
+                          : pblczero::OnnxModel::FLOAT);
   onnx->set_input_planes(options_.input_planes_name);
   builder.AddInput(options_.input_planes_name,
                    {options_.batch_size, 112, 8, 8}, GetDataType());
diff --git a/src/neural/onnx/network_onnx.cc b/src/neural/onnx/network_onnx.cc
index 82768acdd7..d83ac37cf0 100644
--- a/src/neural/onnx/network_onnx.cc
+++ b/src/neural/onnx/network_onnx.cc
@@ -82,8 +82,8 @@ class OnnxComputation : public NetworkComputation {
 class OnnxNetwork : public Network {
  public:
   OnnxNetwork(const WeightsFile& file, const OptionsDict& options,
-              OnnxProvider provider, int gpu, int threads, bool fp16,
-              int batch_size, int steps);
+              OnnxProvider provider, int gpu, int threads, int batch_size,
+              int steps);
   std::unique_ptr<NetworkComputation> NewComputation() override {
     if (fp16_) {
       return std::make_unique<OnnxComputation<Ort::Float16_t>>(this);
     }
@@ -321,13 +321,13 @@ Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int threads,
 }
 
 OnnxNetwork::OnnxNetwork(const WeightsFile& file, const OptionsDict&,
-                         OnnxProvider provider, int gpu, int threads, bool fp16,
+                         OnnxProvider provider, int gpu, int threads,
                          int batch_size, int steps)
     : onnx_env_(ORT_LOGGING_LEVEL_WARNING, "lc0"),
       steps_(steps),
      capabilities_{file.format().network_format().input(),
                     file.format().network_format().moves_left()},
-      fp16_(fp16),
+      fp16_(file.onnx_model().data_type() == pblczero::OnnxModel::FLOAT16),
       batch_size_(batch_size),
       provider_(provider) {
   // Sanity checks.
@@ -396,7 +396,7 @@ std::unique_ptr<Network> MakeOnnxNetwork(const std::optional<WeightsFile>& w,
   if (w->has_onnx_model()) {
     return std::make_unique<OnnxNetwork>(*w, opts, kProvider, gpu, threads,
-                                         false, batch_size, steps);
+                                         batch_size, steps);
   } else {
     if (w->format().network_format().network() !=
             pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT &&
@@ -448,7 +448,7 @@ std::unique_ptr<Network> MakeOnnxNetwork(const std::optional<WeightsFile>& w,
 
     auto converted = ConvertWeightsToOnnx(*w, converter_options);
     return std::make_unique<OnnxNetwork>(converted, opts, kProvider, gpu,
-                                         threads, fp16, batch_size, steps);
+                                         threads, batch_size, steps);
   }
 }
diff --git a/src/neural/onnx/onnx.proto b/src/neural/onnx/onnx.proto
index 01a5ebc69a..5c667192e9 100644
--- a/src/neural/onnx/onnx.proto
+++ b/src/neural/onnx/onnx.proto
@@ -55,6 +55,17 @@ message TensorProto {
   optional string name = 8;
   optional bytes raw_data = 9;
   optional string doc_string = 12;
+  repeated StringStringEntryProto external_data = 13;
+  enum DataLocation {
+    DEFAULT = 0;
+    EXTERNAL = 1;
+  }
+  optional DataLocation data_location = 14;
+}
+
+message StringStringEntryProto {
+  optional string key = 1;
+  optional string value = 2;
 }
 
 message AttributeProto {
@@ -147,4 +158,4 @@ message ModelProto {
   optional string doc_string = 6;
   optional GraphProto graph = 7;
   repeated OperatorSetIdProto opset_import = 8;
-}
\ No newline at end of file
+}
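
Usage sketch: with this patch, onnx2leela autodetects FLOAT vs FLOAT16 from the
ONNX input tensor (so the removed --onnx-data-type flag is no longer needed) and
inserts Cast nodes so the outputs match the input type. A tensorflow-exported
network could then be converted with something like the command below, assuming
the pre-existing --input/--output filename flags; the file and node names here
are hypothetical:

  $ ./lc0 onnx2leela --input=tf_export.onnx --output=tf_export.pb.gz \
        --fix-rule50 --fix-wdl-softmax \
        --onnx-input=input_planes --onnx-output-wdl=wdl_head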