mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix relative model paths to be against config.models_path, not root
This commit is contained in:
parent
974175be45
commit
9968ff2893
@ -171,7 +171,6 @@ from pydantic import BaseSettings, Field, parse_obj_as
|
|||||||
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
MODEL_CORE = Path("models/core")
|
|
||||||
DB_FILE = Path("invokeai.db")
|
DB_FILE = Path("invokeai.db")
|
||||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||||
|
|
||||||
@ -357,7 +356,7 @@ def _find_root() -> Path:
|
|||||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||||
if os.environ.get("INVOKEAI_ROOT"):
|
if os.environ.get("INVOKEAI_ROOT"):
|
||||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
|
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||||
root = (venv.parent).resolve()
|
root = (venv.parent).resolve()
|
||||||
else:
|
else:
|
||||||
root = Path("~/invokeai").expanduser().resolve()
|
root = Path("~/invokeai").expanduser().resolve()
|
||||||
|
@ -181,7 +181,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
|
|
||||||
|
|
||||||
def download_conversion_models():
|
def download_conversion_models():
|
||||||
target_dir = config.root_path / "models/core/convert"
|
target_dir = config.models_path / "core/convert"
|
||||||
kwargs = dict() # for future use
|
kwargs = dict() # for future use
|
||||||
try:
|
try:
|
||||||
logger.info("Downloading core tokenizers and text encoders")
|
logger.info("Downloading core tokenizers and text encoders")
|
||||||
|
@ -103,6 +103,7 @@ class ModelInstall(object):
|
|||||||
access_token: str = None,
|
access_token: str = None,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
# force model manager to be a singleton
|
||||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||||
self.datasets = OmegaConf.load(Dataset_path)
|
self.datasets = OmegaConf.load(Dataset_path)
|
||||||
self.prediction_helper = prediction_type_helper
|
self.prediction_helper = prediction_type_helper
|
||||||
@ -273,6 +274,7 @@ class ModelInstall(object):
|
|||||||
logger.error(f"Unable to download {url}. Skipping.")
|
logger.error(f"Unable to download {url}. Skipping.")
|
||||||
info = ModelProbe().heuristic_probe(location)
|
info = ModelProbe().heuristic_probe(location)
|
||||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
models_path = shutil.move(location, dest)
|
models_path = shutil.move(location, dest)
|
||||||
|
|
||||||
# staged version will be garbage-collected at this time
|
# staged version will be garbage-collected at this time
|
||||||
@ -346,7 +348,7 @@ class ModelInstall(object):
|
|||||||
if key in self.datasets:
|
if key in self.datasets:
|
||||||
description = self.datasets[key].get("description") or description
|
description = self.datasets[key].get("description") or description
|
||||||
|
|
||||||
rel_path = self.relative_to_root(path)
|
rel_path = self.relative_to_root(path,self.config.models_path)
|
||||||
|
|
||||||
attributes = dict(
|
attributes = dict(
|
||||||
path=str(rel_path),
|
path=str(rel_path),
|
||||||
@ -386,8 +388,8 @@ class ModelInstall(object):
|
|||||||
attributes.update(dict(config=str(legacy_conf)))
|
attributes.update(dict(config=str(legacy_conf)))
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
def relative_to_root(self, path: Path) -> Path:
|
def relative_to_root(self, path: Path, root: None) -> Path:
|
||||||
root = self.config.root_path
|
root = root or self.config.root_path
|
||||||
if path.is_relative_to(root):
|
if path.is_relative_to(root):
|
||||||
return path.relative_to(root)
|
return path.relative_to(root)
|
||||||
else:
|
else:
|
||||||
|
@ -63,7 +63,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
|
|||||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig, MODEL_CORE
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from .models import BaseModelType, ModelVariantType
|
from .models import BaseModelType, ModelVariantType
|
||||||
@ -81,7 +81,7 @@ if is_accelerate_available():
|
|||||||
from accelerate.utils import set_module_tensor_to_device
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
|
||||||
logger = InvokeAILogger.getLogger(__name__)
|
logger = InvokeAILogger.getLogger(__name__)
|
||||||
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / MODEL_CORE / "convert"
|
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert"
|
||||||
|
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
@ -1281,7 +1281,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
if (
|
if (
|
||||||
model_version == BaseModelType.StableDiffusion2
|
model_version == BaseModelType.StableDiffusion2
|
||||||
and original_config["model"]["params"]["parameterization"] == "v"
|
and original_config["model"]["params"].get("parameterization") == "v"
|
||||||
):
|
):
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
|
@ -456,7 +456,7 @@ class ModelManager(object):
|
|||||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
model_path = self.app_config.root_path / model_config.path
|
model_path = self.app_config.models_path / model_config.path
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
@ -623,7 +623,7 @@ class ModelManager(object):
|
|||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
# if model inside invoke models folder - delete files
|
# if model inside invoke models folder - delete files
|
||||||
model_path = self.app_config.root_path / model_cfg.path
|
model_path = self.app_config.models_path / model_cfg.path
|
||||||
cache_path = self._get_model_cache_path(model_path)
|
cache_path = self._get_model_cache_path(model_path)
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
rmtree(str(cache_path))
|
rmtree(str(cache_path))
|
||||||
@ -656,8 +656,8 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
# relativize paths as they go in - this makes it easier to move the root directory around
|
# relativize paths as they go in - this makes it easier to move the root directory around
|
||||||
if path := model_attributes.get("path"):
|
if path := model_attributes.get("path"):
|
||||||
if Path(path).is_relative_to(self.app_config.root_path):
|
if Path(path).is_relative_to(self.app_config.models_path):
|
||||||
model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path))
|
model_attributes["path"] = str(Path(path).relative_to(self.app_config.models_path))
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_config = model_class.create_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
@ -732,7 +732,7 @@ class ModelManager(object):
|
|||||||
/ new_name
|
/ new_name
|
||||||
)
|
)
|
||||||
move(old_path, new_path)
|
move(old_path, new_path)
|
||||||
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
model_cfg.path = str(new_path.relative_to(self.app_config.models_path))
|
||||||
|
|
||||||
# clean up caches
|
# clean up caches
|
||||||
old_model_cache = self._get_model_cache_path(old_path)
|
old_model_cache = self._get_model_cache_path(old_path)
|
||||||
@ -795,7 +795,7 @@ class ModelManager(object):
|
|||||||
info["path"] = (
|
info["path"] = (
|
||||||
str(new_diffusers_path)
|
str(new_diffusers_path)
|
||||||
if dest_directory
|
if dest_directory
|
||||||
else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
else str(new_diffusers_path.relative_to(self.app_config.models_path))
|
||||||
)
|
)
|
||||||
info.pop("config")
|
info.pop("config")
|
||||||
|
|
||||||
@ -883,10 +883,17 @@ class ModelManager(object):
|
|||||||
new_models_found = False
|
new_models_found = False
|
||||||
|
|
||||||
self.logger.info(f"Scanning {self.app_config.models_path} for new models")
|
self.logger.info(f"Scanning {self.app_config.models_path} for new models")
|
||||||
with Chdir(self.app_config.root_path):
|
with Chdir(self.app_config.models_path):
|
||||||
for model_key, model_config in list(self.models.items()):
|
for model_key, model_config in list(self.models.items()):
|
||||||
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
model_path = self.app_config.root_path.absolute() / model_config.path
|
|
||||||
|
# Patch for relative path bug in older models.yaml - paths should not
|
||||||
|
# be starting with a hard-coded 'models'. This will also fix up
|
||||||
|
# models.yaml when committed.
|
||||||
|
if model_config.path.startswith('models'):
|
||||||
|
model_config.path = str(Path(*Path(model_config.path).parts[1:]))
|
||||||
|
|
||||||
|
model_path = self.app_config.models_path.absolute() / model_config.path
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
@ -919,8 +926,8 @@ class ModelManager(object):
|
|||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
raise DuplicateModelException(f"Model with key {model_key} added twice")
|
raise DuplicateModelException(f"Model with key {model_key} added twice")
|
||||||
|
|
||||||
if model_path.is_relative_to(self.app_config.root_path):
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
model_path = model_path.relative_to(self.app_config.root_path)
|
model_path = model_path.relative_to(self.app_config.models_path)
|
||||||
|
|
||||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
@ -971,7 +978,7 @@ class ModelManager(object):
|
|||||||
# LS: hacky
|
# LS: hacky
|
||||||
# Patch in the SD VAE from core so that it is available for use by the UI
|
# Patch in the SD VAE from core so that it is available for use by the UI
|
||||||
try:
|
try:
|
||||||
self.heuristic_import({config.root_path / "models/core/convert/sd-vae-ft-mse"})
|
self.heuristic_import({config.models_path / "core/convert/sd-vae-ft-mse"})
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ def _convert_ckpt_and_cache(
|
|||||||
"""
|
"""
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
weights = app_config.root_path / model_config.path
|
weights = app_config.models_path / model_config.path
|
||||||
config_file = app_config.root_path / model_config.config
|
config_file = app_config.root_path / model_config.config
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
BufferBox,
|
BufferBox,
|
||||||
name="Log Messages",
|
name="Log Messages",
|
||||||
editable=False,
|
editable=False,
|
||||||
max_height=8,
|
max_height=15,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
@ -399,7 +399,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
self.ok_button.hidden = True
|
self.ok_button.hidden = True
|
||||||
self.display()
|
self.display()
|
||||||
|
|
||||||
# for communication with the subprocess
|
# TO DO: Spawn a worker thread, not a subprocess
|
||||||
parent_conn, child_conn = Pipe()
|
parent_conn, child_conn = Pipe()
|
||||||
p = Process(
|
p = Process(
|
||||||
target=process_and_execute,
|
target=process_and_execute,
|
||||||
@ -414,7 +414,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
self.subprocess_connection = parent_conn
|
self.subprocess_connection = parent_conn
|
||||||
self.subprocess = p
|
self.subprocess = p
|
||||||
app.install_selections = InstallSelections()
|
app.install_selections = InstallSelections()
|
||||||
# process_and_execute(app.opt, app.install_selections)
|
|
||||||
|
|
||||||
def on_back(self):
|
def on_back(self):
|
||||||
self.parentApp.switchFormPrevious()
|
self.parentApp.switchFormPrevious()
|
||||||
@ -489,8 +488,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
||||||
saved_messages = self.monitor.entry_widget.values
|
saved_messages = self.monitor.entry_widget.values
|
||||||
# autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value)
|
|
||||||
# autoscan = self.pipeline_models['autoscan_on_startup'].value
|
|
||||||
|
|
||||||
app.main_form = app.addForm(
|
app.main_form = app.addForm(
|
||||||
"MAIN",
|
"MAIN",
|
||||||
@ -544,13 +541,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
if downloads := section.get("download_ids"):
|
if downloads := section.get("download_ids"):
|
||||||
selections.install_models.extend(downloads.value.split())
|
selections.install_models.extend(downloads.value.split())
|
||||||
|
|
||||||
# load directory and whether to scan on startup
|
|
||||||
# if self.parentApp.autoload_pending:
|
|
||||||
# selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value)
|
|
||||||
# self.parentApp.autoload_pending = False
|
|
||||||
# selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value
|
|
||||||
|
|
||||||
|
|
||||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -639,6 +629,10 @@ def process_and_execute(
|
|||||||
selections: InstallSelections,
|
selections: InstallSelections,
|
||||||
conn_out: Connection = None,
|
conn_out: Connection = None,
|
||||||
):
|
):
|
||||||
|
# need to reinitialize config in subprocess
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
|
||||||
# set up so that stderr is sent to conn_out
|
# set up so that stderr is sent to conn_out
|
||||||
if conn_out:
|
if conn_out:
|
||||||
translator = StderrToMessage(conn_out)
|
translator = StderrToMessage(conn_out)
|
||||||
@ -685,9 +679,6 @@ def select_and_download_models(opt: Namespace):
|
|||||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||||
config.precision = precision
|
config.precision = precision
|
||||||
helper = lambda x: ask_user_for_prediction_type(x)
|
helper = lambda x: ask_user_for_prediction_type(x)
|
||||||
# if do_listings(opt):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
installer = ModelInstall(config, prediction_type_helper=helper)
|
installer = ModelInstall(config, prediction_type_helper=helper)
|
||||||
if opt.list_models:
|
if opt.list_models:
|
||||||
installer.list_models(opt.list_models)
|
installer.list_models(opt.list_models)
|
||||||
@ -706,8 +697,6 @@ def select_and_download_models(opt: Namespace):
|
|||||||
# needed to support the probe() method running under a subprocess
|
# needed to support the probe() method running under a subprocess
|
||||||
torch.multiprocessing.set_start_method("spawn")
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
|
|
||||||
# the third argument is needed in the Windows 11 environment in
|
|
||||||
# order to launch and resize a console window running this program
|
|
||||||
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
||||||
installApp = AddModelApplication(opt)
|
installApp = AddModelApplication(opt)
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user