start removing repo_id support

This commit is contained in:
Lincoln Stein 2023-06-10 09:57:23 -04:00
parent 887576d217
commit 959e64c9b3

View File

@ -354,23 +354,15 @@ class ModelManager(object):
mconfig = self.config[model_key] mconfig = self.config[model_key]
# type already checked as it's part of key # type already checked as it's part of key
location = None
if model_type == SDModelType.Diffusers: if model_type == SDModelType.Diffusers:
# intercept stanzas that point to checkpoint weights and replace them # intercept stanzas that point to checkpoint weights and replace them with the equivalent diffusers model
# with the equivalent diffusers model
if mconfig.format in ["ckpt", "safetensors"]: if mconfig.format in ["ckpt", "safetensors"]:
location = self.convert_ckpt_and_cache(mconfig) location = self.convert_ckpt_and_cache(mconfig) # TODO: Maybe don't do this any longer?
elif mconfig.get('path'): elif mconfig.get('path'):
location = self.globals.root_dir / mconfig.get('path') location = self.globals.root_dir / mconfig.get('path')
else:
location = mconfig.get('repo_id')
elif p := mconfig.get('path'): elif p := mconfig.get('path'):
location = self.globals.root_dir / p location = self.globals.root_dir / p
elif r := mconfig.get('repo_id'):
location = r
elif w := mconfig.get('weights'):
location = self.globals.root_dir / w
else:
location = None
revision = mconfig.get('revision') revision = mconfig.get('revision')
if model_type in [SDModelType.Lora, SDModelType.TextualInversion]: if model_type in [SDModelType.Lora, SDModelType.TextualInversion]:
@ -378,6 +370,9 @@ class ModelManager(object):
else: else:
hash = self.cache.model_hash(location, revision) hash = self.cache.model_hash(location, revision)
if not location:
return None
# If the caller is asking for part of the model and the config indicates # If the caller is asking for part of the model and the config indicates
# an external replacement for that field, then we fetch the replacement # an external replacement for that field, then we fetch the replacement
if submodel and mconfig.get(submodel): if submodel and mconfig.get(submodel):
@ -653,6 +648,7 @@ class ModelManager(object):
self.cache.uncache_model(self.cache_keys[model_key]) self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key] del self.cache_keys[model_key]
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_diffuser_model( def import_diffuser_model(
self, self,
repo_or_path: Union[str, Path], repo_or_path: Union[str, Path],
@ -689,6 +685,7 @@ class ModelManager(object):
self.commit(commit_to_conf) self.commit(commit_to_conf)
return self.create_key(model_name, SDModelType.Diffusers) return self.create_key(model_name, SDModelType.Diffusers)
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_lora( def import_lora(
self, self,
path: Path, path: Path,
@ -713,6 +710,7 @@ class ModelManager(object):
True True
) )
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_embedding( def import_embedding(
self, self,
path: Path, path: Path,