Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Support for GPT-2 #163

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions OnnxBridge/Secfloat/backendRep.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def check_variables_to_delete(delete_order_list, code_list, counter, var_dict, i
code_list.append(delete_variable(var_dict[variable], indent + 1))
for variable in delete_order_list[counter]
]
code_list.append("\n\n")
code_list.append("\n")


def prepare_input(code_list, node, var_dict, input_taken, indent):
Expand All @@ -81,7 +81,6 @@ def prepare_input(code_list, node, var_dict, input_taken, indent):
code_list.append(
take_input(var_dict[node.name], node.shape, node.party, indent + 1)
)
code_list.append("\n\n")
input_taken.append(node.name)


Expand Down Expand Up @@ -117,13 +116,13 @@ def prepare_func(code_list, node, var_dict, value_info, input_taken, indent):
input_taken += node.outputs

operator = getattr(Operator, node.op_type)
code_list.append(str(f'{" " * (indent+1)}cout<<"Inside {node.op_type}"<<endl;'))
code_list.append(str(f'{" " * (indent+1)}cout << "Inside {node.op_type}" << endl;'))
code_list.append(
operator(
node.attrs, node.inputs, node.outputs, value_info, var_dict, indent + 1
)
)
code_list.append("\n\n")
code_list.append("")


def prepare_output(code_list, node, var_dict, indent):
Expand All @@ -145,7 +144,6 @@ def prepare_output(code_list, node, var_dict, indent):
code_list.append(
give_output(var_dict[node.name], node.shape, node.party, indent + 1)
)
code_list.append("\n\n")


def prepare_export(program, var_dict, value_info, backend, file_path):
Expand All @@ -162,16 +160,16 @@ def prepare_export(program, var_dict, value_info, backend, file_path):
input_dict = dict()

if backend == "SECFLOAT":
code_list.append(f'#include "{file_path}/lib_secfloat/common.cpp" \n\n\n')
code_list.append(f'#include "{file_path}/lib_secfloat/common.cpp" \n')
code_list.append(
"int main(int __argc, char **__argv)\n{\n\n __init(__argc, __argv);\n"
"int main(int __argc, char **__argv)\n{\n __init(__argc, __argv);\n"
)
elif backend == "SECFLOAT_CLEARTEXT":
code_list.append(
f'#include "{file_path}/lib_cleartext/cleartext_common.cpp" \n\n\n'
f'#include "{file_path}/lib_cleartext/cleartext_common.cpp" \n'
)
code_list.append(
"int main(int __argc, char **__argv)\n{\n\n int __party=0;\n"
"int main(int __argc, char **__argv)\n{\n int __party=0;\n"
)

for node in program:
Expand All @@ -184,7 +182,6 @@ def prepare_export(program, var_dict, value_info, backend, file_path):

logger.info("Starting Export...")
for node in program:

if isinstance(node, Input):
input_dict[node.name] = node
elif isinstance(node, Node):
Expand All @@ -205,7 +202,7 @@ def prepare_export(program, var_dict, value_info, backend, file_path):

counter += 1

code_list.append(" return 0;\n")
code_list.append(" return 0;\n")
code_list.append("}")
logger.info("Completed Export.")

Expand Down
Loading