more refactoring; fixed place where rel conversion missed

This commit is contained in:
Lincoln Stein
2023-07-29 13:00:07 -04:00
parent 982a568349
commit 99daa97978
4 changed files with 39 additions and 48 deletions

View File

@ -586,7 +586,7 @@ class ModelManager(object):
# expose paths as absolute to help web UI
if path := model_dict.get("path"):
model_dict["path"] = str(self.app_config.root_path / path)
model_dict["path"] = str(self.resolve_model_path(path))
models.append(model_dict)
return models
@ -654,10 +654,9 @@ class ModelManager(object):
The returned dict has the same format as the dict returned by
model_info().
"""
# relativize paths as they go in - this makes it easier to move the root directory around
# relativize paths as they go in - this makes it easier to move the models directory around
if path := model_attributes.get("path"):
if Path(path).is_relative_to(self.app_config.models_path):
model_attributes["path"] = str(Path(path).relative_to(self.app_config.models_path))
model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes)
@ -715,7 +714,7 @@ class ModelManager(object):
if not model_cfg:
raise ModelNotFoundException(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
old_path = self.resolve_model_path(model_cfg.path)
new_name = new_name or model_name
new_base = new_base or base_model
new_key = self.create_key(new_name, new_base, model_type)
@ -725,11 +724,13 @@ class ModelManager(object):
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = (
self.app_config.root_path
/ "models"
/ BaseModelType(new_base).value
/ ModelType(model_type).value
/ new_name
self.resolve_model_path(
Path(
BaseModelType(new_base).value,
ModelType(model_type).value,
new_name,
)
)
)
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.models_path))
@ -810,9 +811,15 @@ class ModelManager(object):
return result
def resolve_model_path(self, path: str) -> Path:
def resolve_model_path(self, path: Union[Path,str]) -> Path:
"""return relative paths based on configured models_path"""
return self.app_config.models_path / path
def relative_model_path(self, model_path: Path) -> Path:
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
return model_path
def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@ -896,7 +903,7 @@ class ModelManager(object):
if model_config.path.startswith("models"):
model_config.path = str(Path(*Path(model_config.path).parts[1:]))
model_path = self.app_config.models_path.absolute() / model_config.path
model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
if model_class.save_to_config:
@ -915,7 +922,7 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type:
continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists():
continue # TODO: or create all folders?
@ -928,10 +935,8 @@ class ModelManager(object):
try:
if model_key in self.models:
raise DuplicateModelException(f"Model with key {model_key} added twice")
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
model_path = self.relative_model_path(model_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config
new_models_found = True
@ -942,12 +947,11 @@ class ModelManager(object):
except NotImplementedError as e:
self.logger.warning(e)
imported_models = self.autoimport()
imported_models = self.scan_autoimport_directory()
if (new_models_found or imported_models) and self.config_path:
self.commit()
def autoimport(self) -> Dict[str, AddModelResult]:
def scan_autoimport_directory(self) -> Dict[str, AddModelResult]:
"""
Scan the autoimport directory (if defined) and import new models, delete defunct models.
"""
@ -981,7 +985,7 @@ class ModelManager(object):
# LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI
try:
self.heuristic_import({config.models_path / "core/convert/sd-vae-ft-mse"})
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
except:
pass