diff --git a/dall_e/__init__.py b/dall_e/__init__.py index cd982fa..d76cdc5 100644 --- a/dall_e/__init__.py +++ b/dall_e/__init__.py @@ -1,9 +1,10 @@ -import io, requests +import io +import requests + import torch import torch.nn as nn -from dall_e.encoder import Encoder -from dall_e.decoder import Decoder +from dall_e.encoder import Encoder, Decoder from dall_e.utils import map_pixels, unmap_pixels def load_model(path: str, device: torch.device = None) -> nn.Module: