Skip to content

Commit

Permalink
Drop torch.compile fullgraph test (#19166)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Dec 16, 2023
1 parent 3b1643c commit 5c36e99
Showing 1 changed file with 0 additions and 39 deletions.
39 changes: 0 additions & 39 deletions tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from re import escape
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
Expand All @@ -35,7 +34,6 @@
)
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
Expand Down Expand Up @@ -1204,40 +1202,3 @@ def test_verify_launch_called():
fabric.launch()
assert fabric._launched
fabric._validate_launched()


@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
@RunIf(dynamo=True)
@pytest.mark.parametrize(
"kwargs",
[
{},
pytest.param({"precision": "16-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")),
pytest.param({"precision": "64-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")),
],
)
def test_fabric_with_torchdynamo_fullgraph(kwargs):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.l = torch.nn.Linear(10, 10)

def forward(self, x):
# forward gets compiled
assert torch._dynamo.is_compiling()
return self.l(x)

def fn(model, x):
assert torch._dynamo.is_compiling()
a = x * 10
return model(a)

fabric = Fabric(devices=1, accelerator="cpu", **kwargs)
model = MyModel()
fmodel = fabric.setup(model)
# we are compiling a function that calls model.forward() inside
cfn = torch.compile(fn, fullgraph=True)
x = torch.randn(10, 10, device=fabric.device)
# pass the fabric wrapped model to the compiled function, so that it gets compiled too
out = cfn(fmodel, x)
assert isinstance(out, torch.Tensor)

0 comments on commit 5c36e99

Please sign in to comment.