Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onert/python] Introduce BaseSession #14527

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions infra/nnfw/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,23 @@

# copy *py files to package_directory
PY_DIR = os.path.join(THIS_FILE_DIR, '../../../runtime/onert/api/python/package')
for py_file in os.listdir(PY_DIR):
if py_file.endswith(".py"):
src_path = os.path.join(PY_DIR, py_file)
dest_path = os.path.join(THIS_FILE_DIR, package_directory)
shutil.copy(src_path, dest_path)
print(f"Copied '{src_path}' to '{dest_path}'")
for root, dirs, files in os.walk(PY_DIR):
# Calculate the relative path from the source directory
rel_path = os.path.relpath(root, PY_DIR)
dest_dir = os.path.join(THIS_FILE_DIR, package_directory)
dest_sub_dir = os.path.join(dest_dir, rel_path)
print(f"dest_sub_dir '{dest_sub_dir}'")

# Ensure the corresponding destination subdirectory exists
os.makedirs(dest_sub_dir, exist_ok=True)

# Copy only .py files
for py_file in files:
if py_file.endswith(".py"):
src_path = os.path.join(root, py_file)
# dest_path = os.path.join(THIS_FILE_DIR, package_directory)
shutil.copy(src_path, dest_sub_dir)
print(f"Copied '{src_path}' to '{dest_sub_dir}'")

# remove architecture directory
if os.path.exists(package_directory):
Expand Down Expand Up @@ -142,6 +153,6 @@ def get_directories():
url='https://github.com/Samsung/ONE',
license='Apache-2.0, MIT, BSD-2-Clause, BSD-3-Clause, Mozilla Public License 2.0',
has_ext_modules=lambda: True,
packages=[package_directory],
packages=find_packages(),
package_data={package_directory: so_list},
install_requires=['numpy >= 1.19'])
8 changes: 7 additions & 1 deletion runtime/onert/api/python/package/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
__all__ = ['infer']
# Define the public API of the onert package
__all__ = ["infer", "tensorinfo"]

# Import and expose the infer module's functionalities
from . import infer

# Import and expose tensorinfo
from .common import tensorinfo as tensorinfo
3 changes: 3 additions & 0 deletions runtime/onert/api/python/package/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .basesession import BaseSession, tensorinfo

__all__ = ["BaseSession", "tensorinfo"]
75 changes: 75 additions & 0 deletions runtime/onert/api/python/package/common/basesession.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np

from ..native import libnnfw_api_pybind


def num_elems(tensor_info):
"""Get the total number of elements in nnfw_tensorinfo.dims."""
n = 1
for x in range(tensor_info.rank):
n *= tensor_info.dims[x]
return n


class BaseSession:
"""
Base class providing common functionality for inference and training sessions.
"""
def __init__(self, backend_session):
"""
Initialize the BaseSession with a backend session.
Args:
backend_session: A backend-specific session object (e.g., nnfw_session).
"""
self.session = backend_session
self.inputs = []
self.outputs = []

def __getattr__(self, name):
"""
Delegate attribute access to the bound NNFW_SESSION instance.
Args:
name (str): The name of the attribute or method to access.
Returns:
The attribute or method from the bound NNFW_SESSION instance.
"""
return getattr(self.session, name)

def set_inputs(self, size, inputs_array=[]):
"""
Set the input tensors for the session.
Args:
size (int): Number of input tensors.
inputs_array (list): List of numpy arrays for the input data.
"""
for i in range(size):
input_tensorinfo = self.session.input_tensorinfo(i)

if len(inputs_array) > i:
input_array = np.array(inputs_array[i], dtype=input_tensorinfo.dtype)
else:
print(
f"Model's input size is {size}, but given inputs_array size is {len(inputs_array)}.\n{i}-th index input is replaced by an array filled with 0."
)
input_array = np.zeros((num_elems(input_tensorinfo)),
dtype=input_tensorinfo.dtype)

self.session.set_input(i, input_array)
self.inputs.append(input_array)

def set_outputs(self, size):
"""
Set the output tensors for the session.
Args:
size (int): Number of output tensors.
"""
for i in range(size):
output_tensorinfo = self.session.output_tensorinfo(i)
output_array = np.zeros((num_elems(output_tensorinfo)),
dtype=output_tensorinfo.dtype)
self.session.set_output(i, output_array)
self.outputs.append(output_array)


def tensorinfo():
return libnnfw_api_pybind.infer.nnfw_tensorinfo()
58 changes: 0 additions & 58 deletions runtime/onert/api/python/package/infer.py

This file was deleted.

3 changes: 3 additions & 0 deletions runtime/onert/api/python/package/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .session import session

__all__ = ["session"]
27 changes: 27 additions & 0 deletions runtime/onert/api/python/package/infer/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from ..native import libnnfw_api_pybind
from ..common.basesession import BaseSession


class session(BaseSession):
"""
Class for inference using nnfw_session.
"""
def __init__(self, nnpackage_path, backends="cpu"):
"""
Initialize the inference session.
Args:
nnpackage_path (str): Path to the nnpackage file or directory.
backends (str): Backends to use, default is "cpu".
"""
super().__init__(libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends))
self.session.prepare()
self.set_outputs(self.session.output_size())

def inference(self):
ragmani marked this conversation as resolved.
Show resolved Hide resolved
"""
Perform model and get outputs
Returns:
list: Outputs from the model.
"""
self.session.run()
return self.outputs