autoimport from embedding/controlnet/lora folders designated in startup file

This commit is contained in:
Lincoln Stein
2023-06-27 12:30:53 -04:00
parent f15d28d141
commit e8ed0fad6c
7 changed files with 172 additions and 123 deletions

View File

@ -168,11 +168,27 @@ structure at initialization time by scanning the models directory. The
in-memory data structure can be resynchronized by calling
`manager.scan_models_directory()`.
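For example, a minimal sketch, assuming `manager` is an
already-initialized ModelManager (the loop is illustrative):

manager.scan_models_directory()      # resynchronize with the on-disk models directory
for entry in manager.list_models():  # each entry is a dict of model attributes
    print(entry['path'])             # 'path' is relative to the InvokeAI ROOTDIR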
Files and folders placed inside the `autoimport_dir` (path defined in
`invokeai.yaml`, defaulting to `ROOTDIR/autoimport`) will also be
scanned for new models at initialization time and added to
`models.yaml`. Files will not be moved from this location but
preserved in-place.
Files and folders placed inside the `autoimport` paths (paths
defined in `invokeai.yaml`) will also be scanned for new models at
initialization time and added to `models.yaml`. Files will not be
moved from this location but preserved in-place. These directories
are:
configuration    default                description
-------------    -------                -----------
autoimport_dir   autoimport/main        main models
lora_dir         autoimport/lora        LoRA/LyCORIS models
embedding_dir    autoimport/embedding   TI embeddings
controlnet_dir   autoimport/controlnet  ControlNet models
In actuality, models located in any of these directories are scanned
to determine their type, so it isn't strictly necessary to organize
the different types in this way. The following entry in
`invokeai.yaml` will cause all subdirectories within `autoimport` to
be scanned recursively, and any model files found will be imported
if recognized.
Paths:
  autoimport_dir: autoimport
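For illustration, a `Paths` stanza that names all four directories
explicitly might look like this (a sketch; the values shown are the
defaults from the table above):

Paths:
  autoimport_dir: autoimport/main
  lora_dir: autoimport/lora
  embedding_dir: autoimport/embedding
  controlnet_dir: autoimport/controlnet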
A model can be manually added by calling `add_model()` with the
model's name, base model, type and a dict of model attributes. See
@ -208,6 +224,7 @@ checkpoint or safetensors file.
The path points to a file or directory on disk. If a relative path,
the root is the InvokeAI ROOTDIR.
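For illustration, a minimal sketch of a manual registration, assuming
`manager` is an initialized ModelManager (the keyword names and the
attribute keys here are assumptions, not a verified signature):

manager.add_model(
    model_name='my-model',                      # the model's name
    base_model=BaseModelType.StableDiffusion1,  # base model (enum used elsewhere in this diff)
    model_type=ModelType.Main,                  # model type (enum used elsewhere in this diff)
    model_attributes={'path': 'main/sd-1/my-model.safetensors'},  # hypothetical path, relative to ROOTDIR
)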
"""
from __future__ import annotations
@ -660,7 +677,7 @@ class ModelManager(object):
):
loaded_files = set()
new_models_found = False
with Chdir(self.app_config.root_path):
for model_key, model_config in list(self.models.items()):
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@ -720,30 +737,38 @@ class ModelManager(object):
)
installed = set()
if not self.app_config.autoimport_dir:
return installed
autodir = self.app_config.root_path / self.app_config.autoimport_dir
if not (autodir and autodir.exists()):
return installed
known_paths = {(self.app_config.root_path / x['path']).resolve() for x in self.list_models()}
config = self.app_config
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
scanned_dirs = set()
for root, dirs, files in os.walk(autodir):
for d in dirs:
path = Path(root) / d
if path in known_paths:
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
installed.update(installer.heuristic_install(path))
scanned_dirs.add(path)
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors'}:
installed.update(installer.heuristic_install(path))
for autodir in [config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir]:
if autodir is None:
continue
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
for root, dirs, files in os.walk(autodir):
for d in dirs:
path = Path(root) / d
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
installed.update(installer.heuristic_install(path))
scanned_dirs.add(path)
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
installed.update(installer.heuristic_install(path))
return installed
def heuristic_import(self,

View File

@ -22,7 +22,7 @@ class ModelProbeInfo(object):
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal['diffusers','checkpoint']
format: Literal['diffusers','checkpoint', 'lycoris']
image_size: int
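As an orientation aid, a sketch of consuming a probe result; the
`probe()` call signature is assumed from the surrounding code, and
the file path is hypothetical:

from pathlib import Path

info = ModelProbe.probe(Path('my-model.safetensors'))  # classmethod; signature assumed
if info and info.format == 'lycoris':                  # 'lycoris' is the new format literal above
    print('LoRA/LyCORIS model with base', info.base_type)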
class ProbeBase(object):
@ -75,22 +75,23 @@ class ModelProbe(object):
between V2-Base and V2-768 SD models.
'''
if model_path:
format = 'diffusers' if model_path.is_dir() else 'checkpoint'
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
else:
format = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'diffusers' \
if format_type == 'diffusers' \
else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type)
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
format = probe.get_format()
model_info = ModelProbeInfo(
model_type = model_type,
base_type = base_type,
@ -116,10 +117,10 @@ class ModelProbe(object):
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
checkpoint = checkpoint or read_checkpoint_meta(model_path, scan=True)
checkpoint = checkpoint.get("state_dict", checkpoint)
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
for key in checkpoint.keys():
for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
@ -133,7 +134,7 @@ class ModelProbe(object):
else:
# diffusers-ti
if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise ValueError("Unable to determine model type")
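To illustrate the prefix heuristic above, a toy sketch (the tensor
content is fake; only the key name matters for classification):

import torch
from pathlib import Path

# a fabricated one-entry state dict: the 'model.diffusion_model.' prefix
# is what the heuristic above maps to ModelType.Main
fake_ckpt = {'model.diffusion_model.input_blocks.0.0.weight': torch.zeros(1)}
assert ModelProbe.get_model_type_from_checkpoint(Path('fake.ckpt'), fake_ckpt) == ModelType.Main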
@ -201,6 +202,9 @@ class ProbeBase(object):
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
pass
def get_format(self)->str:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(self,
checkpoint_path: Path,
@ -214,6 +218,9 @@ class CheckpointProbeBase(ProbeBase):
def get_base_type(self)->BaseModelType:
pass
def get_format(self)->str:
return 'checkpoint'
def get_variant_type(self)-> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Main:
@ -267,6 +274,9 @@ class VaeCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
return 'lycoris'
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
@ -286,6 +296,9 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return None
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
return None
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
if 'string_to_token' in checkpoint:
@ -332,6 +345,9 @@ class FolderProbeBase(ProbeBase):
def get_variant_type(self)->ModelVariantType:
return ModelVariantType.Normal
def get_format(self)->str:
return 'diffusers'
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
if self.model:
@ -387,6 +403,9 @@ class VaeFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion1
class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self)->str:
return None
def get_base_type(self)->BaseModelType:
path = self.folder_path / 'learned_embeds.bin'
if not path.exists():

View File

@ -397,7 +397,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(checkpoint)
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta"))
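As a usage sketch (the path is hypothetical), callers can pass
`scan=True` to run the malware scan before loading; for
non-safetensors files the tensors land on the 'meta' device, so no
weights are materialized:

meta = read_checkpoint_meta('autoimport/main/my-model.ckpt', scan=True)
state_dict = meta.get('state_dict', meta)
print(len(state_dict), 'entries in state dict')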