Skip to content

Commit

Permalink
add(torch-frontend): Adds torchvision frontend and moves torch.ops.to…
Browse files Browse the repository at this point in the history
…rchvision there (#26491)
  • Loading branch information
AnnaTz committed Oct 9, 2023
1 parent 3eadffb commit 0920967
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 8 deletions.
1 change: 1 addition & 0 deletions ivy/functional/frontends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"paddle": "2.5.1",
"sklearn": "1.3.0",
"xgboost": "1.7.6",
"torchvision": "0.15.2.",
}


Expand Down
1 change: 0 additions & 1 deletion ivy/functional/frontends/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ def promote_types_of_torch_inputs(

from . import nn
from .nn.functional import softmax, relu
from . import ops
from . import tensor
from .tensor import *
from . import blas_and_lapack_ops
Expand Down
1 change: 0 additions & 1 deletion ivy/functional/frontends/torch/ops/__init__.py

This file was deleted.

2 changes: 0 additions & 2 deletions ivy/functional/frontends/torch/ops/torchvision/__init__.py

This file was deleted.

23 changes: 23 additions & 0 deletions ivy/functional/frontends/torchvision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import sys


import ivy.functional.frontends.torch as torch
import ivy
from ivy.functional.frontends import set_frontend_to_specific_version


from . import ops


tensor = _frontend_array = torch.tensor


# setting to specific version #
# --------------------------- #

if ivy.is_local():
module = ivy.utils._importlib.import_cache[__name__]
else:
module = sys.modules[__name__]

set_frontend_to_specific_version(module)
File renamed without changes.
145 changes: 145 additions & 0 deletions ivy_tests/test_ivy/test_frontends/config/torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from .base import FrontendConfig, SupportedDtypes, SupportedDeviecs
import ivy


def get_config():
return TorchVisionFrontendConfig()


class TorchVisionFrontendConfig(FrontendConfig):
backend = ivy.with_backend("torch")

valid_devices = ["cpu", "gpu"]
invalid_devices = ["tpu"]

valid_dtypes = [
"int16",
"int32",
"int64",
"uint8",
"float16",
"float32",
"float64",
]

invalid_dtypes = [
"int8",
"uint16",
"uint32",
"uint64",
"bfloat16",
"complex64",
"complex128",
"bool",
]

valid_numeric_dtypes = [
"int16",
"int32",
"int64",
"uint8",
"float16",
"float32",
"float64",
]

invalid_numeric_dtypes = [
"int8",
"uint16",
"uint32",
"uint64",
"bfloat16",
"complex64",
"complex128",
"bool",
]

valid_int_dtypes = [
"int16",
"int32",
"int64",
"uint8",
]

invalid_int_dtypes = [
"int8",
"uint16",
"uint32",
"uint64",
]

valid_uint_dtypes = [
"uint8",
]

invalid_uint_dtypes = [
"uint16",
"uint32",
"uint64",
]

valid_float_dtypes = [
"float16",
"float32",
"float64",
]

invalid_float_dtypes = [
"bfloat16",
]

valid_complex_dtypes = []

invalid_complex_dtypes = [
"complex64",
"complex128",
]

@property
def supported_devices(self):
return SupportedDeviecs(
valid_devices=self.valid_devices, invalid_devices=self.invalid_devices
)

@property
def supported_dtypes(self):
return SupportedDtypes(
valid_dtypes=self.valid_dtypes,
invalid_dtypes=self.invalid_dtypes,
valid_numeric_dtypes=self.valid_numeric_dtypes,
invalid_numeric_dtypes=self.invalid_numeric_dtypes,
valid_int_dtypes=self.valid_int_dtypes,
invalid_int_dtypes=self.invalid_int_dtypes,
valid_uint_dtypes=self.valid_uint_dtypes,
invalid_uint_dtypes=self.invalid_uint_dtypes,
valid_float_dtypes=self.valid_float_dtypes,
invalid_float_dtypes=self.invalid_float_dtypes,
valid_complex_dtypes=self.valid_complex_dtypes,
invalid_complex_dtypes=self.invalid_complex_dtypes,
)

@property
def Dtype(self):
return self.backend.Dtype

@property
def Device(self):
return self.backend.Device

def native_array(self, x):
return self.backend.native_array(x)

def is_native_array(self, x):
return self.backend.is_native_array(x)

def to_numpy(self, x):
return self.backend.to_numpy(x)

def as_native_dtype(self, dtype: str):
return self.backend.as_native_dtype(dtype)

def as_native_device(self, device: str):
return self.backend.as_native_dev(device)

def isscalar(self, x):
return self.backend.isscalar(x)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import pytest


@pytest.fixture(scope="session")
def frontend():
return "torchvision"
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from hypothesis import strategies as st


# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
Expand Down Expand Up @@ -83,11 +84,11 @@ def _roi_align_helper(draw):

# nms
@handle_frontend_test(
fn_tree="torch.ops.torchvision.nms",
fn_tree="torchvision.ops.nms",
dts_boxes_scores_iou=_nms_helper(),
test_with_out=st.just(False),
)
def test_torch_nms(
def test_torchvision_nms(
*,
dts_boxes_scores_iou,
on_device,
Expand All @@ -112,11 +113,11 @@ def test_torch_nms(

# roi_align
@handle_frontend_test(
fn_tree="torch.ops.torchvision.roi_align",
fn_tree="torchvision.ops.roi_align",
inputs=_roi_align_helper(),
test_with_out=st.just(False),
)
def test_torch_roi_align(
def test_torchvision_roi_align(
*,
inputs,
on_device,
Expand Down
3 changes: 3 additions & 0 deletions run_tests_CLI/synchronize_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"misc": "test_misc",
"paddle": "test_frontends/test_paddle",
"scipy": "test_frontends/test_scipy",
"torchvision": "test_frontends/test_torchvision",
}


Expand Down Expand Up @@ -58,6 +59,7 @@ def keys_to_delete_from_db(all_tests, module, data, current_key=""):
"test_onnx",
"test_sklearn",
"test_xgboost",
"test_torchvision",
)
db_dict = {
"test_functional/test_core": ["core", 10],
Expand All @@ -77,6 +79,7 @@ def keys_to_delete_from_db(all_tests, module, data, current_key=""):
"test_onnx": ["onnx", 24],
"test_sklearn": ["sklearn", 25],
"test_xgboost": ["xgboost", 26],
"test_torchvision": ["torchvision", 27],
}


Expand Down

0 comments on commit 0920967

Please sign in to comment.