Skip to content

Commit

Permalink
Allow changing folder_paths.base_path via command line argument. (#6600)
Browse files Browse the repository at this point in the history
* Reimpl. CLI arg directly inside folder_paths.

* Update tests to use CLI arg mocking.

* Revert last-minute refactor.

* Fix test state polution.
  • Loading branch information
webfiltered authored Jan 29, 2025
1 parent 13fd4d6 commit 222f48c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 10 deletions.
9 changes: 5 additions & 4 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")

parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
Expand Down Expand Up @@ -176,7 +177,7 @@ def is_valid_directory(path: Optional[str]) -> Optional[str]:
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)

parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")

if comfy.options.args_parsing:
args = parser.parse_args()
Expand Down
10 changes: 9 additions & 1 deletion folder_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,19 @@
from typing import Literal
from collections.abc import Collection

from comfy.cli_args import args

supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}

base_path = os.path.dirname(os.path.realpath(__file__))
# --base-directory - Resets all default paths configured in folder_paths with a new base path
if args.base_directory:
base_path = os.path.abspath(args.base_directory)
logging.info(f"Setting base directory to: {base_path}")
else:
base_path = os.path.dirname(os.path.realpath(__file__))

models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
Expand Down
74 changes: 69 additions & 5 deletions tests-unit/comfy_test/folder_path_test.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,45 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import sys
import pytest
import os
import tempfile
from unittest.mock import patch
from importlib import reload

import folder_paths
import comfy.cli_args
from comfy.options import enable_args_parsing
enable_args_parsing()


@pytest.fixture()
def clear_folder_paths():
# Clear the global dictionary before each test to ensure isolation
original = folder_paths.folder_names_and_paths.copy()
folder_paths.folder_names_and_paths.clear()
# Reload the module after each test to ensure isolation
yield
folder_paths.folder_names_and_paths = original
reload(folder_paths)

@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname


def test_get_directory_by_type():
@pytest.fixture
def set_base_dir():
def _set_base_dir(base_dir):
# Mock CLI args
with patch.object(sys, 'argv', ["main.py", "--base-directory", base_dir]):
reload(comfy.cli_args)
reload(folder_paths)
yield _set_base_dir
# Reload the modules after each test to ensure isolation
with patch.object(sys, 'argv', ["main.py"]):
reload(comfy.cli_args)
reload(folder_paths)


def test_get_directory_by_type(clear_folder_paths):
test_dir = "/test/dir"
folder_paths.set_output_directory(test_dir)
assert folder_paths.get_directory_by_type("output") == test_dir
Expand Down Expand Up @@ -96,3 +114,49 @@ def test_get_save_image_path(temp_dir):
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"


def test_base_path_changes(set_base_dir):
test_dir = os.path.abspath("/test/dir")
set_base_dir(test_dir)

assert folder_paths.base_path == test_dir
assert folder_paths.models_dir == os.path.join(test_dir, "models")
assert folder_paths.input_directory == os.path.join(test_dir, "input")
assert folder_paths.output_directory == os.path.join(test_dir, "output")
assert folder_paths.temp_directory == os.path.join(test_dir, "temp")
assert folder_paths.user_directory == os.path.join(test_dir, "user")

assert os.path.join(test_dir, "custom_nodes") in folder_paths.get_folder_paths("custom_nodes")

for name in ["checkpoints", "loras", "vae", "configs", "embeddings", "controlnet", "classifiers"]:
assert folder_paths.get_folder_paths(name)[0] == os.path.join(test_dir, "models", name)


def test_base_path_change_clears_old(set_base_dir):
test_dir = os.path.abspath("/test/dir")
set_base_dir(test_dir)

assert len(folder_paths.get_folder_paths("custom_nodes")) == 1

single_model_paths = [
"checkpoints",
"loras",
"vae",
"configs",
"clip_vision",
"style_models",
"diffusers",
"vae_approx",
"gligen",
"upscale_models",
"embeddings",
"hypernetworks",
"photomaker",
"classifiers",
]
for name in single_model_paths:
assert len(folder_paths.get_folder_paths(name)) == 1

for name in ["controlnet", "diffusion_models", "text_encoders"]:
assert len(folder_paths.get_folder_paths(name)) == 2

0 comments on commit 222f48c

Please sign in to comment.