Skip to content

Commit

Permalink
Deterministic order of testscases
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Apr 18, 2024
1 parent 573b0c3 commit ae1bd35
Show file tree
Hide file tree
Showing 4 changed files with 480 additions and 35 deletions.
36 changes: 8 additions & 28 deletions tests/torch/models_hub_test/test_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Tuple

import pytest
Expand All @@ -18,32 +19,11 @@

from tests.torch.models_hub_test.common import BaseTestModel
from tests.torch.models_hub_test.common import ExampleType
from tests.torch.models_hub_test.common import ModelInfo
from tests.torch.models_hub_test.common import get_model_params
from tests.torch.models_hub_test.common import idfn


def filter_timm(timm_list: list) -> list:
unique_models = set()
filtered_list = []
ignore_set = {
"base", "mini", "small", "xxtiny", "xtiny", "tiny", "lite", "nano", "pico", "medium", "big",
"large", "xlarge", "xxlarge", "huge", "gigantic", "giant", "enormous", "xs", "xxs", "s", "m", "l", "xl"
} # fmt: skip
for name in timm_list:
# first: remove datasets
name_parts = name.split(".")
_name = "_".join(name.split(".")[:-1]) if len(name_parts) > 1 else name
# second: remove sizes
name_set = set([n for n in _name.split("_") if not n.isnumeric()])
name_set = name_set.difference(ignore_set)
name_join = "_".join(name_set)
if name_join not in unique_models:
unique_models.add(name_join)
filtered_list.append(name)
return filtered_list


def get_all_models() -> list:
m_list = timm.list_pretrained()
return filter_timm(m_list)
MODEL_LIST_FILE = Path(__file__).parent / "timm_models.txt"


class TestTimmModel(BaseTestModel):
Expand All @@ -54,6 +34,6 @@ def load_model(self, model_name: str) -> Tuple[nn.Module, ExampleType]:
example = (torch.randn(shape),)
return m, example

@pytest.mark.parametrize("name", get_all_models())
def test_nncf_wrap(self, name):
self.nncf_wrap(name)
@pytest.mark.parametrize("name", get_model_params(MODEL_LIST_FILE), ids=idfn)
def test_nncf_wrap(self, model_info: ModelInfo):
self.nncf_wrap(model_info.model_name)
Loading

0 comments on commit ae1bd35

Please sign in to comment.