Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deterministic order of testcases #2642

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 8 additions & 28 deletions tests/torch/models_hub_test/test_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Tuple

import pytest
Expand All @@ -18,32 +19,11 @@

from tests.torch.models_hub_test.common import BaseTestModel
from tests.torch.models_hub_test.common import ExampleType
from tests.torch.models_hub_test.common import ModelInfo
from tests.torch.models_hub_test.common import get_model_params
from tests.torch.models_hub_test.common import idfn


def filter_timm(timm_list: list) -> list:
unique_models = set()
filtered_list = []
ignore_set = {
"base", "mini", "small", "xxtiny", "xtiny", "tiny", "lite", "nano", "pico", "medium", "big",
"large", "xlarge", "xxlarge", "huge", "gigantic", "giant", "enormous", "xs", "xxs", "s", "m", "l", "xl"
} # fmt: skip
for name in timm_list:
# first: remove datasets
name_parts = name.split(".")
_name = "_".join(name.split(".")[:-1]) if len(name_parts) > 1 else name
# second: remove sizes
name_set = set([n for n in _name.split("_") if not n.isnumeric()])
name_set = name_set.difference(ignore_set)
name_join = "_".join(name_set)
if name_join not in unique_models:
unique_models.add(name_join)
filtered_list.append(name)
return filtered_list


def get_all_models() -> list:
m_list = timm.list_pretrained()
return filter_timm(m_list)
MODEL_LIST_FILE = Path(__file__).parent / "timm_models.txt"


class TestTimmModel(BaseTestModel):
Expand All @@ -54,6 +34,6 @@ def load_model(self, model_name: str) -> Tuple[nn.Module, ExampleType]:
example = (torch.randn(shape),)
return m, example

@pytest.mark.parametrize("name", get_all_models())
def test_nncf_wrap(self, name):
self.nncf_wrap(name)
@pytest.mark.parametrize("model_info", get_model_params(MODEL_LIST_FILE), ids=idfn)
def test_nncf_wrap(self, model_info: ModelInfo):
self.nncf_wrap(model_info.model_name)
Loading
Loading