start removing repo_id support

This commit is contained in:
Lincoln Stein 2023-06-10 09:57:23 -04:00
parent 887576d217
commit 959e64c9b3

View File

@ -354,23 +354,15 @@ class ModelManager(object):
mconfig = self.config[model_key]
# type already checked as it's part of key
location = None
if model_type == SDModelType.Diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
# intercept stanzas that point to checkpoint weights and replace them with the equivalent diffusers model
if mconfig.format in ["ckpt", "safetensors"]:
location = self.convert_ckpt_and_cache(mconfig)
location = self.convert_ckpt_and_cache(mconfig) # TODO: Maybe don't do this any longer?
elif mconfig.get('path'):
location = self.globals.root_dir / mconfig.get('path')
else:
location = mconfig.get('repo_id')
elif p := mconfig.get('path'):
location = self.globals.root_dir / p
elif r := mconfig.get('repo_id'):
location = r
elif w := mconfig.get('weights'):
location = self.globals.root_dir / w
else:
location = None
revision = mconfig.get('revision')
if model_type in [SDModelType.Lora, SDModelType.TextualInversion]:
@ -378,6 +370,9 @@ class ModelManager(object):
else:
hash = self.cache.model_hash(location, revision)
if not location:
return None
# If the caller is asking for part of the model and the config indicates
# an external replacement for that field, then we fetch the replacement
if submodel and mconfig.get(submodel):
@ -653,6 +648,7 @@ class ModelManager(object):
self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key]
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
@ -689,6 +685,7 @@ class ModelManager(object):
self.commit(commit_to_conf)
return self.create_key(model_name, SDModelType.Diffusers)
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_lora(
self,
path: Path,
@ -713,6 +710,7 @@ class ModelManager(object):
True
)
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_embedding(
self,
path: Path,