Skip to content

Commit

Permalink
feat: Updated src/test_main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 23, 2023
1 parent cc9abe2 commit 7fe9106
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions src/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,18 @@ def test_data_loading_and_preprocessing(mocker: MockerFixture):
mock_mnist.assert_called_once_with('.', download=True, train=True, transform=transform)
mock_dataloader.assert_called_once_with(trainset, batch_size=64, shuffle=True)

assert isinstance(trainset, datasets.MNIST)
assert isinstance(trainloader, DataLoader)
pytest.assume(isinstance(trainset, datasets.MNIST))
pytest.assume(isinstance(trainloader, DataLoader))

def test_model_definition():
model = Net()

assert isinstance(model, Net)
assert isinstance(model.fc1, torch.nn.Linear)
assert isinstance(model.fc2, torch.nn.Linear)
assert isinstance(model.fc3, torch.nn.Linear)

input_data = torch.randn(64, 1, 28, 28)
output = model(input_data)

assert output.size() == (64, 10)
assert output.dtype == torch.float32
pytest.assume(isinstance(model, Net))
pytest.assume(isinstance(model.fc1, torch.nn.Linear))
pytest.assume(isinstance(model.fc2, torch.nn.Linear))
pytest.assume(isinstance(model.fc3, torch.nn.Linear))
pytest.assume(output.size() == (64, 10))
pytest.assume(output.dtype == torch.float32)

def test_forward_method(mocker: MockerFixture):
mock_relu = mocker.patch('torch.nn.functional.relu')
Expand All @@ -49,5 +45,5 @@ def test_forward_method(mocker: MockerFixture):
mock_relu.assert_any_call(model.fc2(mock_relu.return_value))
mock_log_softmax.assert_called_once_with(model.fc3(mock_relu.return_value), dim=1)

assert output.size() == (64, 10)
assert output.dtype == torch.float32
pytest.assume(output.size() == (64, 10))
pytest.assume(output.dtype == torch.float32)

0 comments on commit 7fe9106

Please sign in to comment.