mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
LoRA model loading fixes (#3663)
This PR enables model manager importation of diffusers-style .bin LoRAs. However, since there is no backend support for this type of LoRA yet, attempts to use them will result in an unimplemented error. It closes #3636 and #3637
This commit is contained in:
commit
2595c1d86f
@ -193,7 +193,10 @@ class ModelInstall(object):
|
|||||||
models_installed.update(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||||
|
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||||
|
]
|
||||||
|
):
|
||||||
models_installed.update(self._install_path(path))
|
models_installed.update(self._install_path(path))
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
|
@ -785,7 +785,7 @@ class ModelManager(object):
|
|||||||
if path in known_paths or path.parent in scanned_dirs:
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
continue
|
continue
|
||||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||||
new_models_found.update(installer.heuristic_import(path))
|
new_models_found.update(installer.heuristic_import(path))
|
||||||
scanned_dirs.add(path)
|
scanned_dirs.add(path)
|
||||||
|
|
||||||
@ -794,7 +794,8 @@ class ModelManager(object):
|
|||||||
if path in known_paths or path.parent in scanned_dirs:
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
continue
|
continue
|
||||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||||
new_models_found.update(installer.heuristic_import(path))
|
import_result = installer.heuristic_import(path)
|
||||||
|
new_models_found.update(import_result)
|
||||||
|
|
||||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||||
installed.update(new_models_found)
|
installed.update(new_models_found)
|
||||||
|
@ -78,7 +78,6 @@ class ModelProbe(object):
|
|||||||
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
||||||
else:
|
else:
|
||||||
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||||
|
|
||||||
model_info = None
|
model_info = None
|
||||||
try:
|
try:
|
||||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||||
@ -105,7 +104,7 @@ class ModelProbe(object):
|
|||||||
) else 512,
|
) else 512,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
raise
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
@ -127,6 +126,8 @@ class ModelProbe(object):
|
|||||||
return ModelType.Vae
|
return ModelType.Vae
|
||||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
return ModelType.Lora
|
return ModelType.Lora
|
||||||
|
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||||
|
return ModelType.Lora
|
||||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||||
return ModelType.ControlNet
|
return ModelType.ControlNet
|
||||||
elif key in {"emb_params", "string_to_param"}:
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
@ -137,7 +138,7 @@ class ModelProbe(object):
|
|||||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
raise ValueError("Unable to determine model type")
|
raise ValueError(f"Unable to determine model type for {model_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||||
@ -167,7 +168,7 @@ class ModelProbe(object):
|
|||||||
return type
|
return type
|
||||||
|
|
||||||
# give up
|
# give up
|
||||||
raise ValueError("Unable to determine model type")
|
raise ValueError("Unable to determine model type for {folder_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||||
|
@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace):
|
|||||||
|
|
||||||
# this is where the TUI is called
|
# this is where the TUI is called
|
||||||
else:
|
else:
|
||||||
# needed because the torch library is loaded, even though we don't use it
|
# needed to support the probe() method running under a subprocess
|
||||||
# currently commented out because it has started generating errors (?)
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
# torch.multiprocessing.set_start_method("spawn")
|
|
||||||
|
|
||||||
# the third argument is needed in the Windows 11 environment in
|
# the third argument is needed in the Windows 11 environment in
|
||||||
# order to launch and resize a console window running this program
|
# order to launch and resize a console window running this program
|
||||||
|
Loading…
x
Reference in New Issue
Block a user