Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit: resolve conflicts
@@ -593,9 +593,12 @@ script, which will perform a full upgrade in place."""
    config = InvokeAIAppConfig.get_config()
    config.parse_args(['--root',str(dest_root)])

    # TODO: revisit
    # assert (dest_root / 'models').is_dir(), f"{dest_root} does not contain a 'models' subdirectory"
    # assert (dest_root / 'invokeai.yaml').exists(), f"{dest_root} does not contain an InvokeAI init file."
    # TODO: revisit - don't rely on invokeai.yaml to exist yet!
    dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
    if not dest_is_setup:
        import invokeai.frontend.install.invokeai_configure
        from invokeai.backend.install.invokeai_configure import initialize_rootdir
        initialize_rootdir(dest_root, True)

    do_migrate(src_root,dest_root)
@@ -71,8 +71,6 @@ class ModelInstallList:
|
||||
class InstallSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
# scan_directory: Path = None
|
||||
# autoscan_on_startup: bool=False
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo():
|
||||
@ -121,8 +119,8 @@ class ModelInstall(object):
|
||||
installed_models = self.mgr.list_models()
|
||||
for md in installed_models:
|
||||
base = md['base_model']
|
||||
model_type = md['type']
|
||||
name = md['name']
|
||||
model_type = md['model_type']
|
||||
name = md['model_name']
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
|
@@ -36,6 +36,9 @@ from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0

# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75

# actual size of a gig
GIG = 1073741824
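For context: GIG is a binary gigabyte (2**30 bytes), so the 2.75 GB VRAM reserve above corresponds to roughly 2.95 billion bytes. A minimal sketch of the conversion, not part of the diff:

    GIG = 1073741824                                     # 2**30 bytes
    reserved = 2.75 * GIG                                # DEFAULT_MAX_VRAM_CACHE_SIZE in bytes = 2,952,790,016
    print(f"{reserved / GIG:.2f}GB held in reserve")     # -> 2.75GB held in reserve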
@@ -82,6 +85,7 @@ class ModelCache(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float=DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device=torch.device('cuda'),
|
||||
storage_device: torch.device=torch.device('cpu'),
|
||||
precision: torch.dtype=torch.float16,
|
||||
@ -99,12 +103,11 @@ class ModelCache(object):
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
'''
|
||||
#max_cache_size = 9999
|
||||
self.model_infos: Dict[str, ModelBase] = dict()
|
||||
self.lazy_offloading = lazy_offloading
|
||||
#self.sequential_offload: bool=sequential_offload
|
||||
self.precision: torch.dtype=precision
|
||||
self.max_cache_size: int=max_cache_size
|
||||
self.max_cache_size: float=max_cache_size
|
||||
self.max_vram_cache_size: float=max_vram_cache_size
|
||||
self.execution_device: torch.device=execution_device
|
||||
self.storage_device: torch.device=storage_device
|
||||
self.sha_chunksize=sha_chunksize
|
||||
@ -201,14 +204,22 @@ class ModelCache(object):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
|
||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load)
|
||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
||||
|
||||
class ModelLocker(object):
|
||||
def __init__(self, cache, key, model, gpu_load):
|
||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||
'''
|
||||
:param cache: The model_cache object
|
||||
:param key: The key of the model to lock in GPU
|
||||
:param model: The model to lock
|
||||
:param gpu_load: True if load into gpu
|
||||
:param size_needed: Size of the model to load
|
||||
'''
|
||||
self.gpu_load = gpu_load
|
||||
self.cache = cache
|
||||
self.key = key
|
||||
self.model = model
|
||||
self.size_needed = size_needed
|
||||
self.cache_entry = self.cache._cached_models[self.key]
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
@ -222,7 +233,7 @@ class ModelCache(object):
|
||||
|
||||
try:
|
||||
if self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models()
|
||||
self.cache._offload_unlocked_models(self.size_needed)
|
||||
|
||||
if self.model.device != self.cache.execution_device:
|
||||
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
|
||||
@@ -337,14 +348,20 @@ class ModelCache(object):
|
||||
|
||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
|
||||
def _offload_unlocked_models(self):
|
||||
for model_key, cache_entry in self._cached_models.items():
|
||||
def _offload_unlocked_models(self, size_needed: int=0):
|
||||
reserved = self.max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated()
|
||||
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x:x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.locked and cache_entry.loaded:
|
||||
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
|
||||
with VRAMUsage() as mem:
|
||||
cache_entry.model.to(self.storage_device)
|
||||
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
|
||||
vram_in_use += mem.vram_used # note vram_used is negative
|
||||
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||
sha = hashlib.sha256()
|
||||
|
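The hunk above changes ModelCache._offload_unlocked_models so that, before a model is locked into the GPU, unlocked models are moved back to the storage device, smallest first, until VRAM use drops to the configured reserve. A self-contained sketch of that eviction loop follows; CacheEntry and the byte sizes are illustrative stand-ins rather than InvokeAI's real cache types:

    from dataclasses import dataclass

    GIG = 1073741824

    @dataclass
    class CacheEntry:               # illustrative stand-in for the real cache entry type
        size: int                   # model size in bytes
        locked: bool = False
        loaded: bool = True         # True means the weights currently sit in VRAM

    def offload_until_under_reserve(cached: dict, vram_in_use: int, reserve_gb: float = 2.75):
        """Evict unlocked, loaded models (smallest first) until VRAM use is at or under the reserve."""
        reserved = reserve_gb * GIG
        for key, entry in sorted(cached.items(), key=lambda kv: kv[1].size):
            if vram_in_use <= reserved:
                break
            if not entry.locked and entry.loaded:
                entry.loaded = False        # stands in for cache_entry.model.to(storage_device)
                vram_in_use -= entry.size   # the real code measures freed VRAM with VRAMUsage()
        return vram_in_use

    # Example: three cached models occupying 5 GB of VRAM, none locked
    cache = {"sd-1.5": CacheEntry(2 * GIG), "vae": CacheEntry(1 * GIG), "controlnet": CacheEntry(2 * GIG)}
    print(offload_until_under_reserve(cache, vram_in_use=5 * GIG) / GIG)   # -> 2.0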
@@ -231,6 +231,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import hashlib
|
||||
import textwrap
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
@ -246,11 +247,12 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
||||
from .model_cache import ModelCache, ModelLocker
|
||||
from .model_search import ModelSearch
|
||||
from .models import (
|
||||
BaseModelType, ModelType, SubModelType,
|
||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||
ModelConfigBase, ModelNotFoundException,
|
||||
)
|
||||
ModelConfigBase, ModelNotFoundException, InvalidModelException,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
# The config file version doesn't have to start at release version, but it will help
|
||||
@ -274,10 +276,6 @@ class ModelInfo():
|
||||
def __exit__(self,*args, **kwargs):
|
||||
self.context.__exit__(*args, **kwargs)
|
||||
|
||||
class InvalidModelError(Exception):
|
||||
"Raised when an invalid model is requested"
|
||||
pass
|
||||
|
||||
class AddModelResult(BaseModel):
|
||||
name: str = Field(description="The name of the model after installation")
|
||||
model_type: ModelType = Field(description="The type of model")
|
||||
@ -314,6 +312,9 @@ class ModelManager(object):
|
||||
self.config_path = None
|
||||
if isinstance(config, (str, Path)):
|
||||
self.config_path = Path(config)
|
||||
if not self.config_path.exists():
|
||||
logger.warning(f'The file {self.config_path} was not found. Initializing a new file')
|
||||
self.initialize_model_config(self.config_path)
|
||||
config = OmegaConf.load(self.config_path)
|
||||
|
||||
elif not isinstance(config, DictConfig):
|
||||
@ -322,9 +323,31 @@ class ModelManager(object):
|
||||
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
|
||||
# TODO: metadata not found
|
||||
# TODO: version check
|
||||
|
||||
self.app_config = InvokeAIAppConfig.get_config()
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
max_cache_size=max_cache_size,
|
||||
max_vram_cache_size = self.app_config.max_vram_cache_size,
|
||||
execution_device = device_type,
|
||||
precision = precision,
|
||||
sequential_offload = sequential_offload,
|
||||
logger = logger,
|
||||
)
|
||||
|
||||
self._read_models(config)
|
||||
|
||||
def _read_models(self, config: Optional[DictConfig] = None):
|
||||
if not config:
|
||||
if self.config_path:
|
||||
config = OmegaConf.load(self.config_path)
|
||||
else:
|
||||
return
|
||||
|
||||
self.models = dict()
|
||||
for model_key, model_config in config.items():
|
||||
if model_key.startswith('_'):
|
||||
continue
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
@ -332,20 +355,20 @@ class ModelManager(object):
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self.app_config = InvokeAIAppConfig.get_config()
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
max_cache_size=max_cache_size,
|
||||
execution_device = device_type,
|
||||
precision = precision,
|
||||
sequential_offload = sequential_offload,
|
||||
logger = logger,
|
||||
)
|
||||
self.cache_keys = dict()
|
||||
|
||||
# add controlnet, lora and textual_inversion models from disk
|
||||
self.scan_models_directory()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Call this when `models.yaml` has been changed externally.
|
||||
This will reinitialize internal data structures
|
||||
"""
|
||||
# Reread models directory; note that this will reinitialize the cache,
|
||||
# causing otherwise unreferenced models to be removed from memory
|
||||
self._read_models()
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -386,6 +409,16 @@ class ModelManager(object):
|
||||
def _get_model_cache_path(self, model_path):
|
||||
return self.app_config.models_path / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def initialize_model_config(cls, config_path: Path):
|
||||
"""Create empty config file"""
|
||||
with open(config_path,'w') as yaml_file:
|
||||
yaml_file.write(yaml.dump({'__metadata__':
|
||||
{'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -516,7 +549,10 @@ class ModelManager(object):
|
||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||
models = []
|
||||
for model_key in model_keys:
|
||||
model_config = self.models[model_key]
|
||||
model_config = self.models.get(model_key)
|
||||
if not model_config:
|
||||
self.logger.error(f'Unknown model {model_name}')
|
||||
raise KeyError(f'Unknown model {model_name}')
|
||||
|
||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
if base_model is not None and cur_base_model != base_model:
|
||||
@ -527,9 +563,9 @@ class ModelManager(object):
|
||||
model_dict = dict(
|
||||
**model_config.dict(exclude_defaults=True),
|
||||
# OpenAPIModelInfoBase
|
||||
name=cur_model_name,
|
||||
model_name=cur_model_name,
|
||||
base_model=cur_base_model,
|
||||
type=cur_model_type,
|
||||
model_type=cur_model_type,
|
||||
)
|
||||
|
||||
models.append(model_dict)
|
||||
@ -578,6 +614,7 @@ class ModelManager(object):
|
||||
rmtree(str(model_path))
|
||||
else:
|
||||
model_path.unlink()
|
||||
self.commit()
|
||||
|
||||
# LS: tested
|
||||
def add_model(
|
||||
@@ -634,11 +671,61 @@ class ModelManager(object):
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
'''
|
||||
Rename or rebase a model.
|
||||
'''
|
||||
if new_name is None and new_base is None:
|
||||
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||
return
|
||||
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
model_cfg = self.models.get(model_key, None)
|
||||
if not model_cfg:
|
||||
raise KeyError(f"Unknown model: {model_key}")
|
||||
|
||||
old_path = self.app_config.root_path / model_cfg.path
|
||||
new_name = new_name or model_name
|
||||
new_base = new_base or base_model
|
||||
new_key = self.create_key(new_name, new_base, model_type)
|
||||
if new_key in self.models:
|
||||
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
|
||||
|
||||
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||
if old_path.is_relative_to(self.app_config.models_path):
|
||||
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
|
||||
move(old_path, new_path)
|
||||
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||
|
||||
# clean up caches
|
||||
old_model_cache = self._get_model_cache_path(old_path)
|
||||
if old_model_cache.exists():
|
||||
if old_model_cache.is_dir():
|
||||
rmtree(str(old_model_cache))
|
||||
else:
|
||||
old_model_cache.unlink()
|
||||
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
for cache_id in cache_ids:
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models.pop(model_key, None) # delete
|
||||
self.models[new_key] = model_cfg
|
||||
self.commit()
|
||||
|
||||
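A hedged usage sketch for the new rename_model() API above, assuming a ModelManager built from a models.yaml path (construction details elided) and illustrative model names; import paths follow the module layout visible in this diff:

    from invokeai.backend.model_management.model_manager import ModelManager
    from invokeai.backend.model_management.models import BaseModelType, ModelType

    mgr = ModelManager("models.yaml")   # path is illustrative

    # Rename: the folder is moved if it lives under the managed models/ tree,
    # the models.yaml key is rewritten, and stale .cache entries are removed.
    mgr.rename_model(
        model_name="my-old-name",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        new_name="my-new-name",
    )

    # Rebase: keep the name but move the model under a different base-model bucket.
    mgr.rename_model(
        model_name="my-new-name",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        new_base=BaseModelType.StableDiffusion2,
    )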
def convert_model (
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
dest_directory: Optional[Path]=None,
|
||||
) -> AddModelResult:
|
||||
'''
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -665,14 +752,14 @@ class ModelManager(object):
|
||||
)
|
||||
checkpoint_path = self.app_config.root_path / info["path"]
|
||||
old_diffusers_path = self.app_config.models_path / model.location
|
||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
|
||||
try:
|
||||
move(old_diffusers_path,new_diffusers_path)
|
||||
info["model_format"] = "diffusers"
|
||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info.pop('config')
|
||||
|
||||
result = self.add_model(model_name, base_model, model_type,
|
||||
@ -802,6 +889,8 @@ class ModelManager(object):
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except InvalidModelException:
|
||||
self.logger.warning(f"Not a valid model: {model_path}")
|
||||
except NotImplementedError as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
@ -810,6 +899,7 @@ class ModelManager(object):
|
||||
if (new_models_found or imported_models) and self.config_path:
|
||||
self.commit()
|
||||
|
||||
|
||||
def autoimport(self)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||
@ -817,57 +907,42 @@ class ModelManager(object):
|
||||
# avoid circular import
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||
|
||||
|
||||
|
||||
class ScanAndImport(ModelSearch):
|
||||
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||
super().__init__(directories, logger)
|
||||
self.installer = installer
|
||||
self.ignore = ignore
|
||||
|
||||
def on_search_started(self):
|
||||
self.new_models_found = dict()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
if model not in self.ignore:
|
||||
self.new_models_found.update(self.installer.heuristic_import(model))
|
||||
|
||||
def on_search_completed(self):
|
||||
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
|
||||
|
||||
def models_found(self):
|
||||
return self.new_models_found
|
||||
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
model_manager = self,
|
||||
prediction_type_helper = ask_user_for_prediction_type,
|
||||
)
|
||||
|
||||
scanned_dirs = set()
|
||||
|
||||
config = self.app_config
|
||||
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
|
||||
|
||||
for autodir in [config.autoimport_dir,
|
||||
config.lora_dir,
|
||||
config.embedding_dir,
|
||||
config.controlnet_dir]:
|
||||
if autodir is None:
|
||||
continue
|
||||
|
||||
self.logger.info(f'Scanning {autodir} for models to import')
|
||||
installed = dict()
|
||||
|
||||
autodir = self.app_config.root_path / autodir
|
||||
if not autodir.exists():
|
||||
continue
|
||||
|
||||
items_scanned = 0
|
||||
new_models_found = dict()
|
||||
|
||||
for root, dirs, files in os.walk(autodir):
|
||||
items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
scanned_dirs.add(path)
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
import_result = installer.heuristic_import(path)
|
||||
new_models_found.update(import_result)
|
||||
|
||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||
installed.update(new_models_found)
|
||||
|
||||
return installed
|
||||
known_paths = {config.root_path / x['path'] for x in self.list_models()}
|
||||
directories = {config.root_path / x for x in [config.autoimport_dir,
|
||||
config.lora_dir,
|
||||
config.embedding_dir,
|
||||
config.controlnet_dir]
|
||||
}
|
||||
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
||||
scanner.search()
|
||||
return scanner.models_found()
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
@ -905,3 +980,4 @@ class ModelManager(object):
|
||||
successfully_installed.update(installed)
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
||||
|
@ -11,7 +11,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
@ -74,6 +74,7 @@ class ModelMerger(object):
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
@ -85,7 +86,7 @@ class ModelMerger(object):
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
@ -111,7 +112,7 @@ class ModelMerger(object):
|
||||
merged_pipe = self.merge_diffusion_models(
|
||||
model_paths, alpha, merge_method, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
|
@ -61,7 +61,7 @@ class ModelProbe(object):
|
||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise Exception("model parameter {model} is neither a Path, nor a model")
|
||||
raise ValueError("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(cls,
|
||||
@ -240,7 +240,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Cannot determine variant type")
|
||||
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
@ -254,7 +254,7 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
raise Exception("Cannot determine base type")
|
||||
raise ValueError("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
type = self.get_base_type()
|
||||
@ -335,7 +335,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise Exception("Unable to determine base type for {self.checkpoint_path}")
|
||||
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
@ -428,7 +428,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
config_file = self.folder_path / 'config.json'
|
||||
if not config_file.exists():
|
||||
raise Exception(f"Cannot determine base type for {self.folder_path}")
|
||||
raise ValueError(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file,'r') as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
@ -445,7 +445,7 @@ class LoRAFolderProbe(FolderProbeBase):
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise Exception('Unknown LoRA format encountered')
|
||||
raise ValueError('Unknown LoRA format encountered')
|
||||
return LoRACheckpointProbe(model_file,None).get_base_type()
|
||||
|
||||
############## register probe classes ######
|
||||
|
invokeai/backend/model_management/model_search.py (new file, 103 lines)
@@ -0,0 +1,103 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, types
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class ModelSearch(ABC):
|
||||
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.directories = directories
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self):
|
||||
self.on_search_started()
|
||||
for dir in self.directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
|
||||
def walk_directory(self, path: Path):
|
||||
for root, dirs, files in os.walk(path):
|
||||
if str(Path(root).name).startswith('.'):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
def on_search_started(self):
|
||||
self.models_found: Set[Path] = set()
|
||||
|
||||
def on_model_found(self,model: Path):
|
||||
self.models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
pass
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return self.models_found
|
||||
|
||||
|
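A short usage sketch for the new ModelSearch/FindModels classes above, assuming a directory of downloaded checkpoints (the path is illustrative). FindModels simply collects every model path it encounters; custom behaviour comes from subclassing ModelSearch and overriding on_search_started, on_model_found, and on_search_completed, as the ScanAndImport class in the autoimport hunk does:

    from pathlib import Path
    from invokeai.backend.model_management.model_search import FindModels

    finder = FindModels([Path("/data/downloads")])
    for model_path in finder.list_models():   # runs search() and returns the collected paths (a set, despite the List annotation)
        print(model_path)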
@@ -2,7 +2,7 @@ import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .sdxl import StableDiffusionXLModel
|
||||
from .vae import VaeModel
|
||||
@ -54,9 +54,9 @@ MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
name: str
|
||||
model_name: str
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
@ -65,7 +65,9 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
model_configs.discard(None)
|
||||
MODEL_CONFIGS.extend(model_configs)
|
||||
|
||||
for cfg in model_configs:
|
||||
# LS: sort to get the checkpoint configs first, which makes
|
||||
# for a better template in the Swagger docs
|
||||
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
@ -73,7 +75,7 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
|
||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||
__annotations__ = dict(
|
||||
type=Literal[model_type.value],
|
||||
model_type=Literal[model_type.value],
|
||||
),
|
||||
))
|
||||
|
||||
|
@ -15,6 +15,9 @@ from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
pass
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
@ -60,7 +63,6 @@ class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
class Config:
|
||||
|
@ -1,8 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
@ -13,6 +12,8 @@ from .base import (
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
@@ -59,10 +60,20 @@ class ControlNetModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in controlnet model")
|
||||
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
model = None
|
||||
for variant in ['fp16',None]:
|
||||
try:
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
)
|
||||
break
|
||||
except:
|
||||
pass
|
||||
if not model:
|
||||
raise ModelNotFoundException()
|
||||
|
||||
# calc more accurate size
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
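The hunk above makes ControlNet loading try the fp16 variant first and fall back to the default weights, raising ModelNotFoundException if neither loads. A tidier sketch of that fallback (not the literal diff content, with the bare except narrowed to Exception; the import path for ModelNotFoundException is assumed from this diff's module layout):

    from invokeai.backend.model_management.models.base import ModelNotFoundException

    def load_with_variant_fallback(model_class, model_path, torch_dtype):
        """Try the fp16 variant first, then the default weights."""
        last_error = None
        for variant in ("fp16", None):
            try:
                return model_class.from_pretrained(model_path, torch_dtype=torch_dtype, variant=variant)
            except Exception as err:        # the diff itself uses a bare `except: pass`
                last_error = err
        raise ModelNotFoundException(str(model_path)) from last_error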
@@ -73,10 +84,18 @@ class ControlNetModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
else:
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -9,6 +9,7 @@ from .base import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
@ -56,10 +57,18 @@ class LoRAModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return LoRAModelFormat.Diffusers
|
||||
else:
|
||||
return LoRAModelFormat.LyCORIS
|
||||
if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
|
||||
return LoRAModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -13,6 +13,7 @@ from .base import (
|
||||
SilenceWarnings,
|
||||
read_checkpoint_meta,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from omegaconf import OmegaConf
|
||||
@ -33,8 +34,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
assert model_type == ModelType.Main
|
||||
@ -95,10 +95,18 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -197,10 +205,18 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(model_path):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
else:
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -9,6 +9,7 @@ from .base import (
|
||||
SubModelType,
|
||||
classproperty,
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
@ -59,7 +60,18 @@ class TextualInversionModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
return None
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
|
||||
return None # diffusers-ti
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return None
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -15,6 +15,7 @@ from .base import (
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
@ -75,10 +76,18 @@ class VaeModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
return VaeModelFormat.Diffusers
|
||||
else:
|
||||
return VaeModelFormat.Checkpoint
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
return VaeModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
@ -127,7 +127,7 @@ class AddsMaskGuidance:
|
||||
|
||||
def _t_for_field(self, field_name: str, t):
|
||||
if field_name == "pred_original_sample":
|
||||
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
|
||||
return self.scheduler.timesteps[-1]
|
||||
return t
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
@ -631,7 +631,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
control_latent_input = torch.cat([unet_latent_input] * 2)
|
||||
|
||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
|
||||
encoder_hidden_states = conditioning_data.text_embeddings
|
||||
else:
|
||||
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings])
|
||||
|
@@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
# fast batched path
|
||||
|
||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||
|
||||
if cond.shape[1] < max_len:
|
||||
conditioning_attention_mask = torch.cat([
|
||||
conditioning_attention_mask,
|
||||
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||
], dim=1)
|
||||
|
||||
cond = torch.cat([
|
||||
cond,
|
||||
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
||||
], dim=1)
|
||||
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = conditioning_attention_mask
|
||||
else:
|
||||
encoder_attention_mask = torch.cat([
|
||||
encoder_attention_mask,
|
||||
conditioning_attention_mask,
|
||||
])
|
||||
|
||||
return cond, encoder_attention_mask
|
||||
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
encoder_attention_mask = None
|
||||
if unconditioning.shape[1] != conditioning.shape[1]:
|
||||
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
||||
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
||||
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
||||
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||
x_twice, sigma_twice, both_conditionings,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
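The final hunk pads the unconditioned and conditioned embeddings to a common sequence length and builds an encoder attention mask so both can be run in one batched UNet call. A self-contained sketch of that padding logic; shapes are illustrative and the helper name differs from the diff's nested _pad_conditioning:

    import torch

    def pad_conditioning(cond, target_len):
        """Right-pad (batch, seq, dim) embeddings with zeros and return a matching attention mask."""
        mask = torch.ones(cond.shape[0], cond.shape[1], dtype=cond.dtype)
        if cond.shape[1] < target_len:
            pad = target_len - cond.shape[1]
            mask = torch.cat([mask, torch.zeros(cond.shape[0], pad, dtype=cond.dtype)], dim=1)
            cond = torch.cat([cond, torch.zeros(cond.shape[0], pad, cond.shape[2], dtype=cond.dtype)], dim=1)
        return cond, mask

    uncond = torch.randn(1, 77, 768)     # e.g. a 77-token negative prompt
    cond = torch.randn(1, 154, 768)      # e.g. a longer positive prompt
    max_len = max(uncond.shape[1], cond.shape[1])
    uncond, uncond_mask = pad_conditioning(uncond, max_len)
    cond, cond_mask = pad_conditioning(cond, max_len)
    both = torch.cat([uncond, cond])                       # (2, 154, 768) batch for classifier-free guidance
    attention_mask = torch.cat([uncond_mask, cond_mask])   # (2, 154)
    print(both.shape, attention_mask.shape)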
|
||||
|