mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactor logic/small fixes
This commit is contained in:
parent
160b5d7992
commit
23c22ac933
@ -311,7 +311,7 @@ class ModelManager(object):
|
|||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
self.app_config = InvokeAIAppConfig.get_config()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
max_cache_size=max_cache_size,
|
max_cache_size=max_cache_size,
|
||||||
@ -362,6 +362,9 @@ class ModelManager(object):
|
|||||||
|
|
||||||
return (model_name, base_model, model_type)
|
return (model_name, base_model, model_type)
|
||||||
|
|
||||||
|
def _get_model_cache_path(self, model_path):
|
||||||
|
return self.app_config.models_path / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest()
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -382,19 +385,13 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# if model not found try to find it (maybe file just pasted)
|
# if model not found try to find it (maybe file just pasted)
|
||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
# TODO: find by mask or try rescan?
|
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||||
path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
|
if model_key not in self.models:
|
||||||
if False: # model_path = next(find_by_mask(path_mask)):
|
|
||||||
model_path = None # TODO:
|
|
||||||
model_config = model_class.probe_config(model_path)
|
|
||||||
self.models[model_key] = model_config
|
|
||||||
else:
|
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise Exception(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# if it known model check that target path exists (if manualy deleted)
|
model_config = self.models[model_key]
|
||||||
else:
|
model_path = self.app_config.root_path / model_config.path
|
||||||
# logic repeated twice(in rescan too) any way to optimize?
|
|
||||||
model_path = self.globals.root_path / self.models[model_key].path
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
self.models[model_key].error = ModelError.NotFound
|
self.models[model_key].error = ModelError.NotFound
|
||||||
@ -404,16 +401,6 @@ class ModelManager(object):
|
|||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise Exception(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# reset model errors?
|
|
||||||
|
|
||||||
model_config = self.models[model_key]
|
|
||||||
|
|
||||||
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
|
|
||||||
# /models/{base_model}/{model_type}/{name}/
|
|
||||||
# massage relative paths into absolute ones
|
|
||||||
model_path = model_path or self.globals.root_path / model_config.path
|
|
||||||
model_config.path = model_path
|
|
||||||
|
|
||||||
# vae/movq override
|
# vae/movq override
|
||||||
# TODO:
|
# TODO:
|
||||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||||
@ -426,7 +413,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# TODO: path
|
# TODO: path
|
||||||
# TODO: is it accurate to use path as id
|
# TODO: is it accurate to use path as id
|
||||||
dst_convert_path = self.globals.models_dir / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest()
|
dst_convert_path = self._get_model_cache_path(model_path)
|
||||||
model_path = model_class.convert_if_required(
|
model_path = model_class.convert_if_required(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
@ -547,9 +534,12 @@ class ModelManager(object):
|
|||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
# if model inside invoke models folder - delete files
|
# if model inside invoke models folder - delete files
|
||||||
model_path = self.globals.root_path / model_cfg.path
|
model_path = self.app_config.root_path / model_cfg.path
|
||||||
|
cache_path = self._get_model_cache_path(model_path)
|
||||||
|
if cache_path.exists():
|
||||||
|
rmtree(str(cache_path))
|
||||||
|
|
||||||
if model_path.is_relative_to(self.globals.models_path):
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
if model_path.is_dir():
|
if model_path.is_dir():
|
||||||
rmtree(str(model_path))
|
rmtree(str(model_path))
|
||||||
else:
|
else:
|
||||||
@ -576,18 +566,30 @@ class ModelManager(object):
|
|||||||
model_config = model_class.create_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
assert (
|
if clobber or model_key not in self.models:
|
||||||
clobber or model_key not in self.models
|
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
||||||
), f'attempt to overwrite existing model definition "{model_key}"'
|
|
||||||
|
|
||||||
self.models[model_key] = model_config
|
old_model = self.models.pop(model_key, False)
|
||||||
|
if old_model is not None:
|
||||||
|
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
||||||
|
|
||||||
if clobber and model_key in self.cache_keys:
|
# remove conversion cache as config changed
|
||||||
|
old_model_path = self.app_config.root_path / old_model.path
|
||||||
|
old_model_cache = self._get_model_cache_path(old_model_path)
|
||||||
|
if old_model_cache.exists():
|
||||||
|
if old_model_cache.is_dir():
|
||||||
|
rmtree(str(old_model_cache))
|
||||||
|
else:
|
||||||
|
old_model_cache.unlink()
|
||||||
|
|
||||||
|
# remove in-memory cache
|
||||||
# note: it not garantie to release memory(model can has other references)
|
# note: it not garantie to release memory(model can has other references)
|
||||||
cache_ids = self.cache_keys.pop(model_key, [])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
for cache_id in cache_ids:
|
for cache_id in cache_ids:
|
||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
|
self.models[model_key] = model_config
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
self.logger.info(f"Finding Models In: {search_folder}")
|
||||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||||
@ -628,7 +630,7 @@ class ModelManager(object):
|
|||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
assert config_file_path is not None,'no config file path to write to'
|
assert config_file_path is not None,'no config file path to write to'
|
||||||
config_file_path = self.globals.root_dir / config_file_path
|
config_file_path = self.app_config.root_path / config_file_path
|
||||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||||
outfile.write(self.preamble())
|
outfile.write(self.preamble())
|
||||||
@ -651,16 +653,20 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def scan_models_directory(self):
|
def scan_models_directory(
|
||||||
|
self,
|
||||||
|
base_model: Optional[BaseModelType] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
|
):
|
||||||
loaded_files = set()
|
loaded_files = set()
|
||||||
new_models_found = False
|
new_models_found = False
|
||||||
|
|
||||||
with Chdir(self.globals.root_path):
|
with Chdir(self.app_config.root_path):
|
||||||
for model_key, model_config in list(self.models.items()):
|
for model_key, model_config in list(self.models.items()):
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
model_path = str(model_config.path)
|
model_path = self.app_config.root_path / model_config.path
|
||||||
if not os.path.exists(model_path):
|
if not model_path.exists():
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
model_config.error = ModelError.NotFound
|
model_config.error = ModelError.NotFound
|
||||||
else:
|
else:
|
||||||
@ -668,24 +674,29 @@ class ModelManager(object):
|
|||||||
else:
|
else:
|
||||||
loaded_files.add(model_path)
|
loaded_files.add(model_path)
|
||||||
|
|
||||||
for base_model in BaseModelType:
|
for cur_base_model in BaseModelType:
|
||||||
for model_type in ModelType:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
continue
|
||||||
models_dir = os.path.join(self.globals.models_dir, base_model, model_type)
|
|
||||||
|
|
||||||
if not os.path.exists(models_dir):
|
for cur_model_type in ModelType:
|
||||||
|
if model_type is not None and cur_model_type != model_type:
|
||||||
|
continue
|
||||||
|
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||||
|
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
|
||||||
|
|
||||||
|
if not models_dir.exists():
|
||||||
continue # TODO: or create all folders?
|
continue # TODO: or create all folders?
|
||||||
|
|
||||||
for entry_name in os.listdir(models_dir):
|
for model_path in models_dir.iterdir():
|
||||||
model_path = os.path.join(models_dir, entry_name)
|
|
||||||
if model_path not in loaded_files: # TODO: check
|
if model_path not in loaded_files: # TODO: check
|
||||||
model_path = Path(model_path)
|
|
||||||
model_name = model_path.name if model_path.is_dir() else model_path.stem
|
model_name = model_path.name if model_path.is_dir() else model_path.stem
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
|
||||||
|
|
||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
raise Exception(f"Model with key {model_key} added twice")
|
raise Exception(f"Model with key {model_key} added twice")
|
||||||
|
|
||||||
|
if model_path.is_relative_to(self.app_config.root_path):
|
||||||
|
model_path = model_path.relative_to(self.app_config.root_path)
|
||||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
new_models_found = True
|
new_models_found = True
|
||||||
@ -701,18 +712,18 @@ class ModelManager(object):
|
|||||||
'''
|
'''
|
||||||
# avoid circular import
|
# avoid circular import
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
installer = ModelInstall(config = self.globals,
|
installer = ModelInstall(config = self.app_config,
|
||||||
model_manager = self)
|
model_manager = self)
|
||||||
|
|
||||||
installed = set()
|
installed = set()
|
||||||
if not self.globals.autoimport_dir:
|
if not self.app_config.autoimport_dir:
|
||||||
return installed
|
return installed
|
||||||
|
|
||||||
autodir = self.globals.root_path / self.globals.autoimport_dir
|
autodir = self.app_config.root_path / self.app_config.autoimport_dir
|
||||||
if not (autodir and autodir.exists()):
|
if not (autodir and autodir.exists()):
|
||||||
return installed
|
return installed
|
||||||
|
|
||||||
known_paths = {(self.globals.root_path / x['path']).resolve() for x in self.list_models()}
|
known_paths = {(self.app_config.root_path / x['path']).resolve() for x in self.list_models()}
|
||||||
scanned_dirs = set()
|
scanned_dirs = set()
|
||||||
for root, dirs, files in os.walk(autodir):
|
for root, dirs, files in os.walk(autodir):
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
@ -752,7 +763,7 @@ class ModelManager(object):
|
|||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
successfully_installed = set()
|
successfully_installed = set()
|
||||||
|
|
||||||
installer = ModelInstall(config = self.globals,
|
installer = ModelInstall(config = self.app_config,
|
||||||
prediction_type_helper = prediction_type_helper,
|
prediction_type_helper = prediction_type_helper,
|
||||||
model_manager = self)
|
model_manager = self)
|
||||||
for thing in items_to_import:
|
for thing in items_to_import:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user