diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 39fb301aad..2aa93ac720 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -248,7 +248,6 @@ class TextToLatentsInvocation(BaseInvocation): feature_extractor=None, requires_safety_checker=False, precision="float16" if unet.dtype == torch.float16 else "float32", - #precision="float16", # TODO: ) def prep_control_data(self, diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 0b9c267c7f..c92f875a4f 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -40,456 +40,7 @@ from invokeai.app.services.config import get_invokeai_config from .lora import LoRAModel, TextualInversionModel -def get_model_path(repo_id_or_path: str): - globals = get_invokeai_config() - - if os.path.exists(repo_id_or_path): - return repo_id_or_path - - cache = scan_cache_dir(globals.cache_dir) - for repo in cache.repos: - if repo.repo_id != repo_id_or_path: - continue - for rev in repo.revisions: - if "main" in rev.refs: - return rev.snapshot_path - raise Exception(f"{repo_id_or_path} - not found") - -def calc_model_size_by_fs( - repo_id_or_path: str, - subfolder: Optional[str] = None, - variant: Optional[str] = None -): - model_path = get_model_path(repo_id_or_path) - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - - # this can happen when, for example, the safety checker - # is not downloaded. - if not os.path.exists(model_path): - return 0 - - all_files = os.listdir(model_path) - all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] - - fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f]) - bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f]) - other_files = set(all_files) - fp16_files - bit8_files - - if variant is None: - files = other_files - elif variant == "fp16": - files = fp16_files - elif variant == "8bit": - files = bit8_files - else: - raise NotImplementedError(f"Unknown variant: {variant}") - - # try read from index if exists - index_postfix = ".index.json" - if variant is not None: - index_postfix = f".index.{variant}.json" - - for file in files: - if not file.endswith(index_postfix): - continue - try: - with open(os.path.join(model_path, file), "r") as f: - index_data = json.loads(f.read()) - return int(index_data["metadata"]["total_size"]) - except: - pass - - # calculate files size if there is no index file - formats = [ - (".safetensors",), # safetensors - (".bin",), # torch - (".onnx", ".pb"), # onnx - (".msgpack",), # flax - (".ckpt",), # tf - (".h5",), # tf2 - ] - - for file_format in formats: - model_files = [f for f in files if f.endswith(file_format)] - if len(model_files) == 0: - continue - - model_size = 0 - for model_file in model_files: - file_stats = os.stat(os.path.join(model_path, model_file)) - model_size += file_stats.st_size - return model_size - - #raise NotImplementedError(f"Unknown model structure! 
Files: {all_files}") - return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu - - -def calc_model_size_by_data(model) -> int: - if isinstance(model, DiffusionPipeline): - return _calc_pipeline_by_data(model) - elif isinstance(model, torch.nn.Module): - return _calc_model_by_data(model) - else: - return 0 - - -def _calc_pipeline_by_data(pipeline) -> int: - res = 0 - for submodel_key in pipeline.components.keys(): - submodel = getattr(pipeline, submodel_key) - if submodel is not None and isinstance(submodel, torch.nn.Module): - res += _calc_model_by_data(submodel) - return res - - -def _calc_model_by_data(model) -> int: - mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()]) - mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()]) - mem = mem_params + mem_bufs # in bytes - return mem - - - - -class SDModelType(str, Enum): - Diffusers = "diffusers" - Classifier = "classifier" - UNet = "unet" - TextEncoder = "text_encoder" - Tokenizer = "tokenizer" - Vae = "vae" - Scheduler = "scheduler" - Lora = "lora" - TextualInversion = "textual_inversion" - ControlNet = "control_net" - -class BaseModel(str, Enum): - StableDiffusion1_5 = "SD-1" - StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization - StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization - -class ModelInfoBase: - #model_path: str - #model_type: SDModelType - - def __init__(self, repo_id_or_path: str, model_type: SDModelType): - self.repo_id_or_path = repo_id_or_path # TODO: or use allways path? - self.model_path = get_model_path(repo_id_or_path) - self.model_type = model_type - - def _definition_to_type(self, subtypes: List[str]) -> Type: - if len(subtypes) < 2: - raise Exception("Invalid subfolder definition!") - if subtypes[0] in ["diffusers", "transformers"]: - res_type = sys.modules[subtypes[0]] - subtypes = subtypes[1:] - - else: - res_type = sys.modules["diffusers"] - res_type = getattr(res_type, "pipelines") - - - for subtype in subtypes: - res_type = getattr(res_type, subtype) - return res_type - - -class DiffusersModelInfo(ModelInfoBase): - #child_types: Dict[str, Type] - #child_sizes: Dict[str, int] - - def __init__(self, repo_id_or_path: str, model_type: SDModelType): - assert model_type == SDModelType.Diffusers - super().__init__(repo_id_or_path, model_type) - - self.child_types: Dict[str, Type] = dict() - self.child_sizes: Dict[str, int] = dict() - - try: - config_data = DiffusionPipeline.load_config(repo_id_or_path) - #config_data = json.loads(os.path.join(self.model_path, "model_index.json")) - except: - raise Exception("Invalid diffusers model! 
(model_index.json not found or invalid)") - - config_data.pop("_ignore_files", None) - - # retrieve all folder_names that contain relevant files - child_components = [k for k, v in config_data.items() if isinstance(v, list)] - - for child_name in child_components: - child_type = self._definition_to_type(config_data[child_name]) - self.child_types[child_name] = child_type - self.child_sizes[child_name] = calc_model_size_by_fs(repo_id_or_path, subfolder=child_name) - - - def get_size(self, child_type: Optional[SDModelType] = None): - if child_type is None: - return sum(self.child_sizes.values()) - else: - return self.child_sizes[child_type] - - - def get_model( - self, - child_type: Optional[SDModelType] = None, - torch_dtype: Optional[torch.dtype] = None, - ): - # return pipeline in different function to pass more arguments - if child_type is None: - raise Exception("Child model type can't be null on diffusers model") - if child_type not in self.child_types: - return None # TODO: or raise - - # TODO: - for variant in ["fp16", "main", None]: - try: - model = self.child_types[child_type].from_pretrained( - self.repo_id_or_path, - subfolder=child_type.value, - cache_dir=get_invokeai_config().cache_dir, - torch_dtype=torch_dtype, - variant=variant, - ) - break - except Exception as e: - print("====ERR LOAD====") - print(f"{variant}: {e}") - - # calc more accurate size - self.child_sizes[child_type] = calc_model_size_by_data(model) - return model - - - def get_pipeline(self, **kwargs): - return DiffusionPipeline.from_pretrained( - self.repo_id_or_path, - **kwargs, - ) - - -class EmptyConfigLoader(ConfigMixin): - - @classmethod - def load_config(cls, *args, **kwargs): - cls.config_name = kwargs.pop("config_name") - return super().load_config(*args, **kwargs) - - -class ClassifierModelInfo(ModelInfoBase): - #child_types: Dict[str, Type] - #child_sizes: Dict[str, int] - - def __init__(self, repo_id_or_path: str, model_type: SDModelType): - assert model_type == SDModelType.Classifier - super().__init__(repo_id_or_path, model_type) - - self.child_types: Dict[str, Type] = dict() - self.child_sizes: Dict[str, int] = dict() - - try: - main_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="config.json") - #main_config = json.loads(os.path.join(self.model_path, "config.json")) - except: - raise Exception("Invalid classifier model! (config.json not found or invalid)") - - self._load_tokenizer(main_config) - self._load_text_encoder(main_config) - self._load_feature_extractor(main_config) - - - def _load_tokenizer(self, main_config: dict): - try: - tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json") - #tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json")) - except: - raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)") - - if "tokenizer_class" in tokenizer_config: - tokenizer_class_name = tokenizer_config["tokenizer_class"] - elif "model_type" in main_config: - tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]] - else: - raise Exception("Invalid classifier model! 
(Failed to detect tokenizer type)") - - self.child_types[SDModelType.Tokenizer] = self._definition_to_type(["transformers", tokenizer_class_name]) - self.child_sizes[SDModelType.Tokenizer] = 0 - - - def _load_text_encoder(self, main_config: dict): - if "architectures" in main_config and len(main_config["architectures"]) > 0: - text_encoder_class_name = main_config["architectures"][0] - elif "model_type" in main_config: - text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]] - else: - raise Exception("Invalid classifier model! (Failed to detect text_encoder type)") - - self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name]) - self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path) - - - def _load_feature_extractor(self, main_config: dict): - self.child_sizes[SDModelType.FeatureExtractor] = 0 - try: - feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json") - except: - return # feature extractor not passed with t5 - - try: - feature_extractor_class_name = feature_extractor_config["feature_extractor_type"] - self.child_types[SDModelType.FeatureExtractor] = self._definition_to_type(["transformers", feature_extractor_class_name]) - except: - raise Exception("Invalid classifier model! (Unknown feature_extrator type)") - - - def get_size(self, child_type: Optional[SDModelType] = None): - if child_type is None: - return sum(self.child_sizes.values()) - else: - return self.child_sizes[child_type] - - - def get_model( - self, - child_type: Optional[SDModelType] = None, - torch_dtype: Optional[torch.dtype] = None, - ): - if child_type is None: - raise Exception("Child model type can't be null on classififer model") - if child_type not in self.child_types: - return None # TODO: or raise - - model = self.child_types[child_type].from_pretrained( - self.repo_id_or_path, - subfolder=child_type.value, - cache_dir=get_invokeai_config().cache_dir, - torch_dtype=torch_dtype, - ) - # calc more accurate size - self.child_sizes[child_type] = calc_model_size_by_data(model) - return model - - - -class VaeModelInfo(ModelInfoBase): - #vae_class: Type - #model_size: int - - def __init__(self, repo_id_or_path: str, model_type: SDModelType): - assert model_type == SDModelType.Vae - super().__init__(repo_id_or_path, model_type) - - try: - config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json") - #config = json.loads(os.path.join(self.model_path, "config.json")) - except: - raise Exception("Invalid vae model! (config.json not found or invalid)") - - try: - vae_class_name = config.get("_class_name", "AutoencoderKL") - self.vae_class = self._definition_to_type(["diffusers", vae_class_name]) - self.model_size = calc_model_size_by_fs(repo_id_or_path) - except: - raise Exception("Invalid vae model! 
(Unkown vae type)") - - def get_size(self, child_type: Optional[SDModelType] = None): - if child_type is not None: - raise Exception("There is no child models in vae model") - return self.model_size - - def get_model( - self, - child_type: Optional[SDModelType] = None, - torch_dtype: Optional[torch.dtype] = None, - ): - if child_type is not None: - raise Exception("There is no child models in vae model") - - model = self.vae_class.from_pretrained( - self.repo_id_or_path, - cache_dir=get_invokeai_config().cache_dir, - torch_dtype=torch_dtype, - ) - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - -class LoRAModelInfo(ModelInfoBase): - #model_size: int - - def __init__(self, file_path: str, model_type: SDModelType): - assert model_type == SDModelType.Lora - # check manualy as super().__init__ will try to resolve repo_id too - if not os.path.exists(file_path): - raise Exception("Model not found") - super().__init__(file_path, model_type) - - self.model_size = os.path.getsize(file_path) - - def get_size(self, child_type: Optional[SDModelType] = None): - if child_type is not None: - raise Exception("There is no child models in lora") - return self.model_size - - def get_model( - self, - child_type: Optional[SDModelType] = None, - torch_dtype: Optional[torch.dtype] = None, - ): - if child_type is not None: - raise Exception("There is no child models in lora") - - model = LoRAModel.from_checkpoint( - file_path=self.model_path, - dtype=torch_dtype, - ) - - self.model_size = model.calc_size() - return model - - -class TextualInversionModelInfo(ModelInfoBase): - #model_size: int - - def __init__(self, file_path: str, model_type: SDModelType): - assert model_type == SDModelType.TextualInversion - # check manualy as super().__init__ will try to resolve repo_id too - if not os.path.exists(file_path): - raise Exception("Model not found") - super().__init__(file_path, model_type) - - self.model_size = os.path.getsize(file_path) - - def get_size(self, child_type: Optional[SDModelType] = None): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - return self.model_size - - def get_model( - self, - child_type: Optional[SDModelType] = None, - torch_dtype: Optional[torch.dtype] = None, - ): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - - model = TextualInversionModel.from_checkpoint( - file_path=self.model_path, - dtype=torch_dtype, - ) - - self.model_size = model.embedding.nelement() * model.embedding.element_size() - return model - - -MODEL_TYPES = { - SDModelType.Diffusers: DiffusersModelInfo, - SDModelType.Classifier: ClassifierModelInfo, - SDModelType.Vae: VaeModelInfo, - SDModelType.Lora: LoRAModelInfo, - SDModelType.TextualInversion: TextualInversionModelInfo, -} +from .models import MODEL_CLASSES # Maximum size of the cache, in gigs @@ -499,10 +50,6 @@ DEFAULT_MAX_CACHE_SIZE = 6.0 # actual size of a gig GIG = 1073741824 -# TODO: -class EmptyScheduler(SchedulerMixin, ConfigMixin): - pass - class ModelLocker(object): "Forward declaration" pass @@ -583,12 +130,10 @@ class ModelCache(object): self, model_path: str, model_type: SDModelType, - revision: Optional[str] = None, submodel_type: Optional[SDModelType] = None, ): - revision = revision or "main" - key = f"{model_path}:{model_type}:{revision}" + key = f"{model_path}:{model_type}" if submodel_type: key += f":{submodel_type}" return key @@ -606,55 +151,51 @@ class ModelCache(object): def _get_model_info( self, 
        model_path: str,
-        model_type: SDModelType,
-        revision: str,
+        model_class: Type[ModelBase],
+        base_model: BaseModelType,
+        model_type: ModelType,
     ):
         model_info_key = self.get_key(
             model_path=model_path,
             model_type=model_type,
-            revision=revision,
             submodel_type=None,
         )

         if model_info_key not in self.model_infos:
-            if model_type not in MODEL_TYPES:
-                raise Exception(f"Unknown/unsupported model type: {model_type}")
-
-            self.model_infos[model_info_key] = MODEL_TYPES[model_type](
+            self.model_infos[model_info_key] = model_class(
                 model_path,
+                base_model,
                 model_type,
             )

         return self.model_infos[model_info_key]

     def get_model(
         self,
-        repo_id_or_path: Union[str, Path],
-        model_type: SDModelType = SDModelType.Diffusers,
-        submodel: Optional[SDModelType] = None,
-        revision: Optional[str] = None,
-        variant: Optional[str] = None,
+        model_path: Union[str, Path],
+        model_class: Type[ModelBase],
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
         gpu_load: bool = True,
     ) -> Any:
-        model_path = get_model_path(repo_id_or_path)
+        if not isinstance(model_path, Path):
+            model_path = Path(model_path)
+
+        if not os.path.exists(model_path):
+            raise Exception(f"Model not found: {model_path}")
+
         model_info = self._get_model_info(
             model_path=model_path,
-            model_type=model_type,
-            revision=revision,
+            model_class=model_class,
+            base_model=base_model,
+            model_type=model_type,
         )
-
-        # TODO: variant
         key = self.get_key(
             model_path=model_path,
             model_type=model_type,
-            revision=revision,
             submodel_type=submodel,
         )

         # TODO: lock for no copies on simultaneous calls?
         cache_entry = self._cached_models.get(key, None)
         if cache_entry is None:
-            self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
+            self.logger.info(f'Loading model {model_path}, type {model_type}:{submodel}')

             # this will remove older cached models until
             # there is sufficient room to load the requested model
@@ -662,7 +203,7 @@ class ModelCache(object):

             # clean memory to make MemoryUsage() more accurate
             gc.collect()
-            model = model_info.get_model(submodel, torch_dtype=self.precision)
+            # the new model-info classes take (torch_dtype, child_type), so pass by keyword
+            model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
             if mem_used := model_info.get_size(submodel):
                 self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
@@ -732,20 +273,14 @@ class ModelCache(object):

     def model_hash(
         self,
-        repo_id_or_path: Union[str, Path],
-        revision: str = "main",
+        model_path: Union[str, Path],
     ) -> str:
         '''
-        Given the HF repo id or path to a model on disk, returns a unique hash.
-        Works for legacy checkpoint files, HF models on disk, and HF repo IDs
-        :param repo_id_or_path: repo_id string or Path to model file/directory on disk.
-        :param revision: optional revision string (if fetching a HF repo_id)
+        Given the path to a model on disk, returns a unique hash.
+        Works for legacy checkpoint files and HF models on disk.
+        :param model_path: Path to model file/directory on disk.
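+
+        A minimal usage sketch (the path below is illustrative, not a real model):
+
+            hash = cache.model_hash("/path/to/models/SD-1/pipeline/my-model")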
''' - revision = revision or "main" - if Path(repo_id_or_path).is_dir(): - return self._local_model_hash(repo_id_or_path) - else: - return self._hf_commit_hash(repo_id_or_path,revision) + return self._local_model_hash(model_path) def cache_size(self) -> float: "Return the current size of the cache, in GB" @@ -840,17 +375,6 @@ class ModelCache(object): with open(hashpath, "w") as f: f.write(hash) return hash - - def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str: - api = HfApi() - info = api.list_repo_refs( - repo_id=repo_id, - repo_type='model', - ) - desired_revisions = [branch for branch in info.branches if branch.name==revision] - if not desired_revisions: - raise KeyError(f"Revision '{revision}' not found in {repo_id}") - return desired_revisions[0].target_commit class SilenceWarnings(object): def __init__(self): diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 5df8f6fd6e..e6fc990a56 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -163,7 +163,6 @@ import safetensors import safetensors.torch import torch from diffusers import AutoencoderKL -from diffusers.utils import is_safetensors_available from huggingface_hub import scan_cache_dir from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig @@ -172,8 +171,8 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import CUDA_DEVICE, download_with_resume from ..install.model_install_backend import Dataset_path, hf_download_with_resume -from .model_cache import (ModelCache, ModelLocker, SDModelType, - SilenceWarnings) +from .model_cache import ModelCache, ModelLocker, SilenceWarnings +from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES # We are only starting to number the config file with release 3. # The config file version doesn't have to start at release version, but it will help # reduce confusion. @@ -201,14 +200,6 @@ class InvalidModelError(Exception): "Raised when an invalid model is requested" pass -class SDLegacyType(Enum): - V1 = auto() - V1_INPAINT = auto() - V2 = auto() - V2_e = auto() - V2_v = auto() - UNKNOWN = auto() - MAX_CACHE_SIZE = 6.0 # GB @@ -280,32 +271,45 @@ class ModelManager(object): def model_exists( self, model_name: str, - model_type: SDModelType = SDModelType.Diffusers, + base_model: BaseModelType, + model_type: ModelType, ) -> bool: """ Given a model name, returns True if it is a valid identifier. 
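+
+        Example (the model name is illustrative):
+
+            manager.model_exists("my-model", BaseModelType.StableDiffusion1_5, ModelType.Pipeline)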
""" - model_key = self.create_key(model_name, model_type) + model_key = self.create_key(model_name, base_model, model_type) return model_key in self.config - def create_key(self, model_name: str, model_type: SDModelType) -> str: - return f"{model_type}/{model_name}" + def create_key( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + ) -> str: + return f"{base_model}/{model_type}/{model_name}" - def parse_key(self, model_key: str) -> Tuple[str, SDModelType]: - model_type_str, model_name = model_key.split('/', 1) + def parse_key(self, model_key: str) -> Tuple[str, BaseModelType, ModelType]: + base_model_str, model_type_str, model_name = model_key.split('/', 2) try: model_type = SDModelType(model_type_str) - return (model_name, model_type) except: raise Exception(f"Unknown model type: {model_type_str}") + try: + base_model = BaseModelType(base_model_str) + except: + raise Exception(f"Unknown base model: {base_model_str}") + + return (model_name, base_model, model_type) + def get_model( self, model_name: str, - model_type: SDModelType = SDModelType.Diffusers, - submodel: Optional[SDModelType] = None, - ) -> SDModelInfo: + base_model: BaseModelType, + model_type: ModelType, + submodel_type: Optional[SubModelType] = None + ): """Given a model named identified in models.yaml, return an SDModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml @@ -344,205 +348,181 @@ class ModelManager(object): # raises an InvalidModelError """ - model_key = self.create_key(model_name, model_type) - if model_key not in self.config: - raise InvalidModelError( - f'"{model_key}" is not a known model name. Please check your models.yaml file' - ) - - # get the required loading info out of the config file - mconfig = self.config[model_key] - - # type already checked as it's part of key - location = None - if model_type == SDModelType.Diffusers: - # intercept stanzas that point to checkpoint weights and replace them with the equivalent diffusers model - if mconfig.format in ["ckpt", "safetensors"]: - location = self.convert_ckpt_and_cache(mconfig) # TODO: Maybe don't do this any longer? - elif mconfig.get('path'): - location = self.globals.root_dir / mconfig.get('path') - elif p := mconfig.get('path'): - location = self.globals.root_dir / p - - revision = mconfig.get('revision') - if model_type in [SDModelType.Lora, SDModelType.TextualInversion]: - hash = "" # TODO: + model_class = MODEL_CLASSES[base_model][model_type] + + #if model_type in { + # ModelType.Lora, + # ModelType.ControlNet, + # ModelType.TextualInversion, + # ModelType.Vae, + #}: + if not model_class.has_config: + #if model_class.Config is None: + # skip config + # load from + # /models/{base_model}/{model_type}/{model_name} + # /models/{base_model}/{model_type}/{model_name}.{ext} + + model_config = None + + for ext in {"pt", "ckpt", "safetensors"}: + model_path = os.path.join(model_dir, base_model, model_type, f"{model_name}.{ext}") + if os.path.exists(model_path): + break + else: + model_path = os.path.join(model_dir, base_model, model_type, model_name) + if not os.path.exists(model_path): + raise InvalidModelError( + f"Model not found - \"{base_model}/{model_type}/{model_name}\" " + ) + else: - hash = self.cache.model_hash(location, revision) + # find in config + model_key = self.create_key(model_name, base_model, model_type) + if model_key not in self.config: + raise InvalidModelError( + f'"{model_key}" is not a known model name. 
Please check your models.yaml file' + ) - if not location: - return None + model_config = self.config[model_key] + + # /models/{base_model}/{model_type}/{name}.ckpt or .safentesors + # /models/{base_model}/{model_type}/{name}/ + model_path = model_config.path - # If the caller is asking for part of the model and the config indicates - # an external replacement for that field, then we fetch the replacement - if submodel and mconfig.get(submodel): - location = mconfig.get(submodel).get('path') \ - or mconfig.get(submodel).get('repo_id') - model_type = submodel - submodel = None + # vae/movq override + # TODO: + if submodel is not None and submodel in model_config: + model_path = model_config[submodel]["path"] + model_type = submodel + submodel = None - # to support the traditional way of attaching a VAE - # to a model, we hacked in `attach_model_part` - # TODO: - if model_type == SDModelType.Vae and "vae" in mconfig: - print("NOT_IMPLEMENTED - RETURN CUSTOM VAE") - - model_context = self.cache.get_model( - location, - model_type = model_type, - revision = revision, - submodel = submodel, + dst_convert_path = None # TODO: + model_path = model_class.convert_if_required( + model_path, + dst_convert_path, + model_config, ) - # in case we need to communicate information about this - # model to the cache manager, then we need to remember - # the cache key - self.cache_keys[model_key] = model_context.key - + model_context = self.cache.get_model( + model_path, + model_class, + submodel, + ) + + hash = "" # TODO: + return SDModelInfo( context = model_context, name = model_name, + base_model = base_model, type = submodel or model_type, hash = hash, - location = location, - revision = revision, + location = model_path, # TODO: precision = self.cache.precision, - _cache = self.cache + _cache = self.cache, ) - def default_model(self) -> Optional[Tuple[str, SDModelType]]: + def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]: """ Returns the name of the default model, or None if none is defined. """ - for model_name, model_type in self.model_names(): - model_key = self.create_key(model_name, model_type) - if self.config[model_key].get("default"): - return (model_name, model_type) - return self.model_names()[0][0] + for model_key, model_config in self.config.items(): + if model_config.get("default", False): + return self.parse_key(model_key) - def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.Diffusers) -> None: + for model_key, _ in self.config.items(): + return self.parse_key(model_key) + else: + return None # TODO: or redo as (None, None, None) + + def set_default_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + ) -> None: """ Set the default model. 
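+        Example (the model name is illustrative):
+
+            manager.set_default_model("my-model", BaseModelType.StableDiffusion1_5, ModelType.Pipeline)
+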
        The change will not take effect until you call model_manager.commit()
         """
-        assert self.model_exists(model_name, model_type), f"unknown model '{model_name}'"
-
-        config = self.config
-        for model_name, model_type in self.model_names():
-            key = self.create_key(model_name, model_type)
-            config[key].pop("default", None)
-        config[self.create_key(model_name, model_type)]["default"] = True
+        model_key = self.create_key(model_name, base_model, model_type)
+        if model_key not in self.config:
+            raise Exception(f"Unknown model: {model_key}")
+
+        for cur_model_key, config in self.config.items():
+            if cur_model_key == model_key:
+                config["default"] = True
+            else:
+                config.pop("default", None)

     def model_info(
         self,
         model_name: str,
-        model_type: SDModelType=SDModelType.Diffusers,
+        base_model: BaseModelType,
+        model_type: ModelType,
     ) -> dict:
         """
         Given a model name returns the OmegaConf (dict-like) object describing it.
         """
-        if not self.model_exists(model_name, model_type):
-            return None
-        return self.config[self.create_key(model_name, model_type)]
+        model_key = self.create_key(model_name, base_model, model_type)
+        return self.config.get(model_key, None)

-    def model_names(self) -> List[Tuple[str, SDModelType]]:
+    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
         """
-        Return a list of (str, SDModelType) corresponding to all models
+        Return a list of (str, BaseModelType, ModelType) corresponding to all models
         known to the configuration.
         """
         return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]

-    def is_legacy(self, model_name: str, model_type: SDModelType.Diffusers) -> bool:
+    def list_models(
+        self,
+        base_model: Optional[BaseModelType] = None,
+        model_type: Optional[ModelType] = None,
+    ) -> Dict[str, Dict[str, str]]:
         """
-        Return true if this is a legacy (.ckpt) model
-        """
-        # if we are converting legacy files automatically, then
-        # there are no legacy ckpts!
-        if self.globals.ckpt_convert:
-            return False
-        info = self.model_info(model_name, model_type)
-        if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
-            return True
-        return False
-
-    def list_models(self, model_type: SDModelType=None) -> dict[str,dict[str,str]]:
-        """
-        Return a dict of models, in format [model_type][model_name], with
-        following fields:
-               model_name
-               model_type
-               format
-               description
-               status
-               # for folders only
-               repo_id
-               path
-               subfolder
-               vae
-               # for ckpts only
-               config
-               weights
-               vae
+        Return a dict of models, in format [base_model][model_type][model_name]

         Please use model_manager.models() to get all the model names,
         model_manager.model_info('model-name') to get the stanza for the model
         named 'model-name', and model_manager.config to get the full OmegaConf
         object derived from models.yaml
         """
-        models = {}
+        assert not (model_type is not None and base_model is None), "base_model is required when model_type is specified"
+
+        models = dict()
         for model_key in sorted(self.config, key=str.casefold):
             stanza = self.config[model_key]

-            # don't include VAEs in listing (legacy style)
-            if "config" in stanza and "/VAE/" in stanza["config"]:
-                continue
+            if model_key.startswith('_'):
+                continue

-            model_name, stanza_type = self.parse_key(model_key)
+            model_name, m_base_model, stanza_type = self.parse_key(model_key)
+            if base_model is not None and m_base_model != base_model:
+                continue
             if model_type is not None and model_type != stanza_type:
                 continue

-            if stanza_type not in models:
-                models[stanza_type] = dict()
+            if m_base_model not in models:
+                models[m_base_model] = dict()
+            if stanza_type not in models[m_base_model]:
+                models[m_base_model][stanza_type] = dict()

-            models[stanza_type][model_name] = dict()
-
-            model_format = stanza.get('format')
-
-            # Common Attribs
-            description = stanza.get("description", None)
-            models[stanza_type][model_name].update(
-                model_name=model_name,
-                model_type=stanza_type,
-                format=model_format,
-                description=description,
-                status="unknown", # TODO: no more status as model loaded separately
-            )
-
-            # Checkpoint Config Parse
-            if model_format in ["ckpt","safetensors"]:
-                models[stanza_type][model_name].update(
-                    config = str(stanza.get("config", None)),
-                    weights = str(stanza.get("weights", None)),
-                    vae = str(stanza.get("vae", None)),
-                )
-
-            # Diffusers Config Parse
-            elif model_format == "folder":
-                if vae := stanza.get("vae", None):
-                    if isinstance(vae, DictConfig):
-                        vae = dict(
-                            repo_id = str(vae.get("repo_id", None)),
-                            path = str(vae.get("path", None)),
-                            subfolder = str(vae.get("subfolder", None)),
-                        )
-
-                models[stanza_type][model_name].update(
-                    vae = vae,
-                    repo_id = str(stanza.get("repo_id", None)),
-                    path = str(stanza.get("path", None)),
-                )
+            model_class = MODEL_CLASSES[m_base_model][stanza_type]
+            models[m_base_model][stanza_type][model_name] = model_class.build_config(
+                **stanza,
+                name=model_name,
+                base_model=m_base_model,
+                type=stanza_type,
+            )
+            #models[m_base_model][stanza_type][model_name] = model_class.Config(
+            #    **stanza,
+            #    name=model_name,
+            #    base_model=base_model,
+            #    type=stanza_type,
+            #).dict()

         return models

@@ -552,7 +532,7 @@ class ModelManager(object):
         """
-        for model_type, model_dict in self.list_models().items():
-            for model_name, model_info in model_dict.items():
-                line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
-                if model_info["status"] in ["in gpu","locked in gpu"]:
-                    line = f"\033[1m{line}\033[0m"
-                print(line)
+        # list_models() now nests by base model, then model type, then name
+        for base_model, model_types in self.list_models().items():
+            for model_type, model_dict in model_types.items():
+                for model_name, model_info in model_dict.items():
+                    line = f'{model_info["name"]:25s} {model_info.get("status", "unknown"):>15s} {model_info["type"]:10s} {model_info["description"]}'
+                    if model_info.get("status") in ["in gpu", "locked in gpu"]:
+                        line = f"\033[1m{line}\033[0m"
+                    print(line)
@@ -601,7 +581,8 @@ class ModelManager(object):
     def add_model(
         self,
         model_name: str,
-        model_type: SDModelType,
+        base_model: BaseModelType,
+        model_type: ModelType,
         model_attributes: dict,
         clobber: bool = False,
     ) -> None:
@@ -613,38 +594,31 @@ class ModelManager(object):
         attributes are incorrect or the model name is missing.
         """
-        if model_type == SDModelType.Fiffusers:
-            # TODO: automaticaly or manualy?
-            #assert "format" in model_attributes, 'missing required field "format"'
-            model_format = "ckpt" if "weights" in model_attributes else "diffusers"
+        model_class = MODEL_CLASSES[base_model][model_type]

-            if model_format == "diffusers":
-                assert (
-                    "description" in model_attributes
-                ), 'required field "description" is missing'
-                assert (
-                    "path" in model_attributes or "repo_id" in model_attributes
-                ), 'model must have either the "path" or "repo_id" fields defined'
+        # validate the attributes; build_config raises if a required field is missing
+        model_class.build_config(
+            **model_attributes,
+            name=model_name,
+            base_model=base_model,
+            type=model_type,
+        )
+        #model_cfg = model_class.Config(
+        #    **model_attributes,
+        #    name=model_name,
+        #    base_model=base_model,
+        #    type=model_type,
+        #)

-        elif model_format == "ckpt":
-            for field in ("description", "weights", "config"):
-                assert field in model_attributes, f"required field {field} is missing"
-
-        else:
-            assert "weights" in model_attributes and "description" in model_attributes
-
-        model_key = self.create_key(model_name, model_type)
+        model_key = self.create_key(model_name, base_model, model_type)
         assert (
             clobber or model_key not in self.config
         ), f'attempt to overwrite existing model definition "{model_key}"'

         self.config[model_key] = model_attributes
-
-        if "weights" in self.config[model_key]:
-            self.config[model_key]["weights"].replace("\\", "/")

         if clobber and model_key in self.cache_keys:
+            # TODO:
             self.cache.uncache_model(self.cache_keys[model_key])
             del self.cache_keys[model_key]
@@ -739,326 +713,6 @@ class ModelManager(object):
                 ), True
             )
-
-    @classmethod
-    def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
-        """
-        Given a pickle or safetensors model object, probes contents
-        of the object and returns an SDLegacyType indicating its
-        format.
Valid return values include: - SDLegacyType.V1 - SDLegacyType.V1_INPAINT - SDLegacyType.V2 (V2 prediction type unknown) - SDLegacyType.V2_e (V2 using 'epsilon' prediction type) - SDLegacyType.V2_v (V2 using 'v_prediction' prediction type) - SDLegacyType.UNKNOWN - """ - global_step = checkpoint.get("global_step") - state_dict = checkpoint.get("state_dict") or checkpoint - - try: - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if global_step == 220000: - return SDLegacyType.V2_e - elif global_step == 110000: - return SDLegacyType.V2_v - else: - return SDLegacyType.V2 - # otherwise we assume a V1 file - in_channels = state_dict[ - "model.diffusion_model.input_blocks.0.0.weight" - ].shape[1] - if in_channels == 9: - return SDLegacyType.V1_INPAINT - elif in_channels == 4: - return SDLegacyType.V1 - else: - return SDLegacyType.UNKNOWN - except KeyError: - return SDLegacyType.UNKNOWN - - def heuristic_import( - self, - path_url_or_repo: str, - model_name: Optional[str] = None, - description: Optional[str] = None, - model_config_file: Optional[Path] = None, - commit_to_conf: Optional[Path] = None, - config_file_callback: Optional[Callable[[Path], Path]] = None, - ) -> str: - """Accept a string which could be: - - a HF diffusers repo_id - - a URL pointing to a legacy .ckpt or .safetensors file - - a local path pointing to a legacy .ckpt or .safetensors file - - a local directory containing .ckpt and .safetensors files - - a local directory containing a diffusers model - - After determining the nature of the model and downloading it - (if necessary), the file is probed to determine the correct - configuration file (if needed) and it is imported. - - The model_name and/or description can be provided. If not, they will - be generated automatically. - - If commit_to_conf is provided, the newly loaded model will be written - to the `models.yaml` file at the indicated path. Otherwise, the changes - will only remain in memory. - - The routine will do its best to figure out the config file - needed to convert legacy checkpoint file, but if it can't it - will call the config_file_callback routine, if provided. The - callback accepts a single argument, the Path to the checkpoint - file, and returns a Path to the config file to use. - - The (potentially derived) name of the model is returned on - success, or None on failure. When multiple models are added - from a directory, only the last imported one is returned. - - """ - model_path: Path = None - thing = str(path_url_or_repo) # to save typing - - self.logger.info(f"Probing {thing} for import") - - if thing.startswith(("http:", "https:", "ftp:")): - self.logger.info(f"{thing} appears to be a URL") - model_path = self._resolve_path( - thing, "models/ldm/stable-diffusion-v1" - ) # _resolve_path does a download if needed - - elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")): - if Path(thing).stem in ["model", "diffusion_pytorch_model"]: - self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. 
Skipping import") - return - else: - self.logger.debug(f"{thing} appears to be a checkpoint file on disk") - model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1") - - elif Path(thing).is_dir() and Path(thing, "model_index.json").exists(): - self.logger.debug(f"{thing} appears to be a diffusers file on disk") - model_name = self.import_diffuser_model( - thing, - vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), - model_name=model_name, - description=description, - commit_to_conf=commit_to_conf, - ) - - elif Path(thing).is_dir(): - if (Path(thing) / "model_index.json").exists(): - self.logger.debug(f"{thing} appears to be a diffusers model.") - model_name = self.import_diffuser_model( - thing, commit_to_conf=commit_to_conf - ) - else: - self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import") - for m in list(Path(thing).rglob("*.ckpt")) + list( - Path(thing).rglob("*.safetensors") - ): - if model_name := self.heuristic_import( - str(m), - commit_to_conf=commit_to_conf, - config_file_callback=config_file_callback, - ): - self.logger.info(f"{model_name} successfully imported") - return model_name - - elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing): - self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id") - model_name = self.import_diffuser_model( - thing, commit_to_conf=commit_to_conf - ) - pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name]) - return model_name - else: - self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id") - - # Model_path is set in the event of a legacy checkpoint file. - # If not set, we're all done - if not model_path: - return - - if model_path.stem in self.config: # already imported - self.logger.debug("Already imported. Skipping") - return model_path.stem - - # another round of heuristics to guess the correct config file. - checkpoint = None - if model_path.suffix in [".ckpt", ".pt"]: - self.cache.scan_model(model_path, model_path) - checkpoint = torch.load(model_path) - else: - checkpoint = safetensors.torch.load_file(model_path) - - # additional probing needed if no config file provided - if model_config_file is None: - # look for a like-named .yaml file in same directory - if model_path.with_suffix(".yaml").exists(): - model_config_file = model_path.with_suffix(".yaml") - self.logger.debug(f"Using config file {model_config_file.name}") - - else: - model_type = self.probe_model_type(checkpoint) - if model_type == SDLegacyType.V1: - self.logger.debug("SD-v1 model detected") - model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml" - elif model_type == SDLegacyType.V1_INPAINT: - self.logger.debug("SD-v1 inpainting model detected") - model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml" - elif model_type == SDLegacyType.V2_v: - self.logger.debug("SD-v2-v model detected") - model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml" - elif model_type == SDLegacyType.V2_e: - self.logger.debug("SD-v2-e model detected") - model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml" - elif model_type == SDLegacyType.V2: - self.logger.warning( - f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined." - ) - else: - self.logger.warning( - f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model." 
- ) - - if not model_config_file and config_file_callback: - model_config_file = config_file_callback(model_path) - - # despite our best efforts, we could not find a model config file, so give up - if not model_config_file: - return - - # look for a custom vae, a like-named file ending with .vae in the same directory - vae_path = None - for suffix in ["pt", "ckpt", "safetensors"]: - if (model_path.with_suffix(f".vae.{suffix}")).exists(): - vae_path = model_path.with_suffix(f".vae.{suffix}") - self.logger.debug(f"Using VAE file {vae_path.name}") - vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse") - - diffuser_path = self.globals.converted_ckpts_dir / model_path.stem - with SilenceWarnings(): - model_name = self.convert_and_import( - model_path, - diffusers_path=diffuser_path, - vae=vae, - vae_path=str(vae_path), - model_name=model_name, - model_description=description, - original_config_file=model_config_file, - commit_to_conf=commit_to_conf, - scan_needed=False, - ) - return model_name - - def convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path: - """ - Convert the checkpoint model indicated in mconfig into a - diffusers, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. - """ - weights = self.globals.root_dir / mconfig.weights - config_file = self.globals.root_dir / mconfig.config - diffusers_path = self.globals.converted_ckpts_dir / weights.stem - - # return cached version if it exists - if diffusers_path.exists(): - return diffusers_path - - vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig) - - # to avoid circular import errors - from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers - with SilenceWarnings(): - convert_ckpt_to_diffusers( - weights, - diffusers_path, - extract_ema=True, - original_config_file=config_file, - vae=vae_model, - vae_path=str(self.globals.root_dir / vae_ckpt_path) if vae_ckpt_path else None, - scan_needed=True, - ) - return diffusers_path - - def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path: - """ - Convert the VAE indicated in mconfig into a diffusers AutoencoderKL - object, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. - """ - root = self.globals.root_dir - weights_file = root / mconfig.weights - config_file = root / mconfig.config - diffusers_path = self.globals.converted_ckpts_dir / weights_file.stem - image_size = mconfig.get('width') or mconfig.get('height') or 512 - - # return cached version if it exists - if diffusers_path.exists(): - return diffusers_path - - # this avoids circular import error - from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers - checkpoint = torch.load(weights_file, map_location="cpu")\ - if weights_file.suffix in ['.ckpt','.pt'] \ - else safetensors.torch.load_file(weights_file) - - # sometimes weights are hidden under "state_dict", and sometimes not - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - config = OmegaConf.load(config_file) - - vae_model = convert_ldm_vae_to_diffusers( - checkpoint = checkpoint, - vae_config = config, - image_size = image_size - ) - vae_model.save_pretrained( - diffusers_path, - safe_serialization=is_safetensors_available() - ) - return diffusers_path - - def _get_vae_for_conversion( - self, - weights: Path, - mconfig: DictConfig - ) -> Tuple[Path, AutoencoderKL]: - # VAE handling is convoluted - # 1. 
If there is a .vae.ckpt file sharing same stem as weights, then use
-        #    it as the vae_path passed to convert
-        vae_ckpt_path = None
-        vae_diffusers_location = None
-        vae_model = None
-        for suffix in ["pt", "ckpt", "safetensors"]:
-            if (weights.with_suffix(f".vae.{suffix}")).exists():
-                vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
-                self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
-        if vae_ckpt_path:
-            return (vae_ckpt_path, None)
-
-        # 2. If mconfig has a vae weights path, then we use that as vae_path
-        vae_config = mconfig.get('vae')
-        if vae_config and isinstance(vae_config,str):
-            vae_ckpt_path = vae_config
-            return (vae_ckpt_path, None)
-
-        # 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
-        if vae_config and isinstance(vae_config,DictConfig):
-            vae_diffusers_location = self.globals.root_dir / vae_config.get('path') \
-                if vae_config.get('path') \
-                else vae_config.get('repo_id')
-
-        # 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
-        else:
-            vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
-
-        if vae_diffusers_location:
-            vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model
-            return (None, vae_model)
-
-        return (None, None)

     def convert_and_import(
         self,
diff --git a/invokeai/backend/model_management/models.py b/invokeai/backend/model_management/models.py
new file mode 100644
index 0000000000..8fff2a61d2
--- /dev/null
+++ b/invokeai/backend/model_management/models.py
@@ -0,0 +1,726 @@
+import os
+import sys
+import json
+from enum import Enum
+from pathlib import Path
+from typing import Optional, Dict, List, Type, Union
+
+import torch
+import safetensors.torch
+import transformers
+from diffusers import ConfigMixin, DiffusionPipeline
+from diffusers.utils import is_safetensors_available
+from pydantic import BaseModel, Field, root_validator
+
+# raw checkpoint loaders from this package, aliased so they do not clash
+# with the wrapper classes of the same names defined below
+from .lora import LoRAModel as LoRAModelRaw, TextualInversionModel as TextualInversionModelRaw
+
+class BaseModelType(str, Enum):
+    #StableDiffusion1_5 = "stable_diffusion_1_5"
+    #StableDiffusion2 = "stable_diffusion_2"
+    #StableDiffusion2Base = "stable_diffusion_2_base"
+    # TODO: maybe then add sample size(512/768)?
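+    # NOTE: the string values below double as the models.yaml key prefix and
+    # the on-disk folder name under the models directory (see create_key()
+    # and get_model() in model_manager.py)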
+    StableDiffusion1_5 = "SD-1"
+    StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
+    StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
+    #Kandinsky2_1 = "kandinsky_2_1"
+
+class ModelType(str, Enum):
+    Pipeline = "pipeline"
+    Classifier = "classifier"
+    Vae = "vae"
+
+    Lora = "lora"
+    ControlNet = "controlnet"
+    TextualInversion = "embedding"
+
+class SubModelType(str, Enum):
+    UNet = "unet"
+    TextEncoder = "text_encoder"
+    Tokenizer = "tokenizer"
+    Vae = "vae"
+    Scheduler = "scheduler"
+    SafetyChecker = "safety_checker"
+    FeatureExtractor = "feature_extractor" # referenced by ClassifierModel below
+    #MoVQ = "movq"
+
+MODEL_CLASSES = {
+    BaseModelType.StableDiffusion1_5: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    BaseModelType.StableDiffusion2: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    BaseModelType.StableDiffusion2Base: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    #BaseModelType.Kandinsky2_1: {
+    #    ModelType.Pipeline: Kandinsky2_1Model,
+    #    ModelType.Classifier: ClassifierModel,
+    #    ModelType.MoVQ: MoVQModel,
+    #    ModelType.Lora: LoRAModel,
+    #    ModelType.ControlNet: ControlNetModel,
+    #    ModelType.TextualInversion: TextualInversionModel,
+    #},
+}
+
+class EmptyConfigLoader(ConfigMixin):
+    @classmethod
+    def load_config(cls, *args, **kwargs):
+        cls.config_name = kwargs.pop("config_name")
+        return super().load_config(*args, **kwargs)
+
+class ModelBase:
+    #model_path: str
+    #base_model: BaseModelType
+    #model_type: ModelType
+
+    # subclasses that carry a models.yaml stanza override this
+    # (see StableDiffusionDModel.has_config below)
+    has_config = False
+
+    def __init__(
+        self,
+        model_path: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+    ):
+        self.model_path = model_path
+        self.base_model = base_model
+        self.model_type = model_type
+
+    def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
+        if len(subtypes) < 2:
+            raise Exception("Invalid subfolder definition!")
+        if subtypes[0] in ["diffusers", "transformers"]:
+            res_type = sys.modules[subtypes[0]]
+            subtypes = subtypes[1:]
+        else:
+            res_type = sys.modules["diffusers"]
+            res_type = getattr(res_type, "pipelines")
+
+        for subtype in subtypes:
+            res_type = getattr(res_type, subtype)
+        return res_type
+
+
+class DiffusersModel(ModelBase):
+    #child_types: Dict[str, Type]
+    #child_sizes: Dict[str, int]
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        super().__init__(model_path, base_model, model_type)
+
+        self.child_types: Dict[str, Type] = dict()
+        self.child_sizes: Dict[str, int] = dict()
+
+        try:
+            config_data = DiffusionPipeline.load_config(self.model_path)
+            #config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
+        except:
+            raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
+
+        config_data.pop("_ignore_files", None)
+
+        # retrieve all folder_names that contain relevant files
+        child_components = [k for k, v in config_data.items() if isinstance(v, list)]
+
+        for child_name in child_components:
+            child_type = self._hf_definition_to_type(config_data[child_name])
+            self.child_types[child_name] = child_type
+            self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
+
+
+    def get_size(self, child_type: Optional[SubModelType] = None):
+        if child_type is None:
+            return sum(self.child_sizes.values())
+        else:
+            return self.child_sizes[child_type]
+
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SubModelType] = None,
+    ):
+        # return pipeline in different function to pass more arguments
+        if child_type is None:
+            raise Exception("Child model type can't be null on diffusers model")
+        if child_type not in self.child_types:
+            return None # TODO: or raise
+
+        if torch_dtype == torch.float16:
+            variants = ["fp16", None]
+        else:
+            variants = [None, "fp16"]
+
+        # TODO: better error handling(differentiate not found from others)
+        for variant in variants:
+            try:
+                # TODO: set cache_dir to /dev/null to be sure that cache not used?
+                model = self.child_types[child_type].from_pretrained(
+                    self.model_path,
+                    subfolder=child_type.value,
+                    torch_dtype=torch_dtype,
+                    variant=variant,
+                    local_files_only=True,
+                )
+                break
+            except Exception as e:
+                print("====ERR LOAD====")
+                print(f"{variant}: {e}")
+        else:
+            # no variant could be loaded; fail loudly rather than with a NameError below
+            raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
+
+        # calc more accurate size
+        self.child_sizes[child_type] = calc_model_size_by_data(model)
+        return model
+
+    #def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
+
+
+class StableDiffusionModel(DiffusersModel):
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert base_model in {
+            BaseModelType.StableDiffusion1_5,
+            BaseModelType.StableDiffusion2,
+            BaseModelType.StableDiffusion2Base,
+        }
+        assert model_type == ModelType.Pipeline
+        super().__init__(model_path, base_model, model_type)
+
+    @staticmethod
+    def convert_if_required(model_path: Union[str, Path], dst_path: str, config: Optional[dict]) -> Path:
+        if not isinstance(model_path, Path):
+            model_path = Path(model_path)
+
+        # TODO: args
+        # TODO: set model_path, to config? pass dst_path as arg?
+        # TODO: check
+        return _convert_ckpt_and_cache(config)
+
+class classproperty(object): # pylint: disable=invalid-name
+    """Class property decorator.
+ + Example usage: + + class MyClass(object): + + @classproperty + def value(cls): + return '123' + + > print MyClass.value + 123 + """ + + def __init__(self, func): + self._func = func + + def __get__(self, owner_self, owner_cls): + return self._func(owner_cls) + +class ModelConfigBase(BaseModel): + path: str # or Path + name: str + description: Optional[str] + + +class StableDiffusionDModel(DiffusersModel): + class Config(ModelConfigBase): + format: str + vae: Optional[str] = Field(None) + config: Optional[str] = Field(None) + + @root_validator + def validator(cls, values): + if values["format"] not in {"checkpoint", "diffusers"}: + raise ValueError(f"Unkown stable diffusion model format: {values['format']}") + if values["config"] is not None and values["format"] != "checkpoint": + raise ValueError(f"Custom config field allowed only in checkpoint stable diffusion model") + return values + + # return config only for checkpoint format + def dict(self, *args, **kwargs): + result = super().dict(*args, **kwargs) + if self.format != "checkpoint": + result.pop("config", None) + return result + + @classproperty + def has_config(self): + return True + + def build_config(self, **kwargs) -> dict: + try: + res = dict( + path=kwargs["path"], + name=kwargs["name"], + description=kwargs.get("description", None), + + format=kwargs["format"], + vae=kwargs.get("vae", None), + ) + if res["format"] not in {"checkpoint", "diffusers"}: + raise Exception(f"Unkonwn stable diffusion model format: {res['format']}") + if res["format"] == "checkpoint": + res["config"] = kwargs.get("config", None) + # TODO: raise if config specified for diffusers? + + return res + + except KeyError as e: + raise Exception(f"Field \"{e.args[0]}\" not found!") + + + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): + assert base_model == BaseModelType.StableDiffusion1_5 + assert model_type == ModelType.Pipeline + super().__init__(model_path, base_model, model_type) + + @classmethod + def convert_if_required(cls, model_path: str, dst_path: str, config: Optional[dict]) -> str: + model_config = cls.Config( + **config, + path=model_path, + name="", + ) + + if hasattr(model_config, "config"): + convert_ckpt_and_cache( + model_path=model_path, + dst_path=dst_path, + config=config, + ) + return dst_path + + else: + return model_path + +class StableDiffusion15CheckpointModel(DiffusersModel): + class Cnfig(ModelConfigBase): + vae: Optional[str] = Field(None) + config: Optional[str] = Field(None) + +class StableDiffusion2BaseDiffusersModel(DiffusersModel): + class Config(ModelConfigBase): + vae: Optional[str] = Field(None) + +class StableDiffusion2BaseCheckpointModel(DiffusersModel): + class Cnfig(ModelConfigBase): + vae: Optional[str] = Field(None) + config: Optional[str] = Field(None) + +class StableDiffusion2DiffusersModel(DiffusersModel): + class Config(ModelConfigBase): + vae: Optional[str] = Field(None) + attention_upscale: bool = Field(True) + +class StableDiffusion2CheckpointModel(DiffusersModel): + class Config(ModelConfigBase): + vae: Optional[str] = Field(None) + config: Optional[str] = Field(None) + attention_upscale: bool = Field(True) + + +class ClassifierModel(ModelBase): + #child_types: Dict[str, Type] + #child_sizes: Dict[str, int] + + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): + assert model_type == SDModelType.Classifier + super().__init__(model_path, base_model, model_type) + + self.child_types: Dict[str, Type] = dict() + self.child_sizes: 
Dict[str, int] = dict() + + try: + main_config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") + #main_config = json.loads(os.path.join(self.model_path, "config.json")) + except: + raise Exception("Invalid classifier model! (config.json not found or invalid)") + + self._load_tokenizer(main_config) + self._load_text_encoder(main_config) + self._load_feature_extractor(main_config) + + + def _load_tokenizer(self, main_config: dict): + try: + tokenizer_config = EmptyConfigLoader.load_config(self.model_path, config_name="tokenizer_config.json") + #tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json")) + except: + raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)") + + if "tokenizer_class" in tokenizer_config: + tokenizer_class_name = tokenizer_config["tokenizer_class"] + elif "model_type" in main_config: + tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]] + else: + raise Exception("Invalid classifier model! (Failed to detect tokenizer type)") + + self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name]) + self.child_sizes[SDModelType.Tokenizer] = 0 + + + def _load_text_encoder(self, main_config: dict): + if "architectures" in main_config and len(main_config["architectures"]) > 0: + text_encoder_class_name = main_config["architectures"][0] + elif "model_type" in main_config: + text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]] + else: + raise Exception("Invalid classifier model! (Failed to detect text_encoder type)") + + self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name]) + self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path) + + + def _load_feature_extractor(self, main_config: dict): + self.child_sizes[SDModelType.FeatureExtractor] = 0 + try: + feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json") + except: + return # feature extractor not passed with t5 + + try: + feature_extractor_class_name = feature_extractor_config["feature_extractor_type"] + self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name]) + except: + raise Exception("Invalid classifier model! 
+
+
+    def get_size(self, child_type: Optional[SDModelType] = None):
+        if child_type is None:
+            return sum(self.child_sizes.values())
+        else:
+            return self.child_sizes[child_type]
+
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SDModelType] = None,
+    ):
+        if child_type is None:
+            raise Exception("Child model type can't be null on classifier model")
+        if child_type not in self.child_types:
+            return None  # TODO: or raise
+
+        model = self.child_types[child_type].from_pretrained(
+            self.model_path,
+            subfolder=child_type.value,
+            torch_dtype=torch_dtype,
+        )
+        # calc a more accurate size
+        self.child_sizes[child_type] = calc_model_size_by_data(model)
+        return model
+
+    @staticmethod
+    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
+        if not isinstance(model_path, Path):
+            model_path = Path(model_path)
+        return model_path
+
+
+
+class VaeModel(ModelBase):
+    #vae_class: Type
+    #model_size: int
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert model_type == ModelType.Vae
+        super().__init__(model_path, base_model, model_type)
+
+        try:
+            config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
+            #config = json.loads(os.path.join(self.model_path, "config.json"))
+        except Exception:
+            raise Exception("Invalid vae model! (config.json not found or invalid)")
+
+        try:
+            vae_class_name = config.get("_class_name", "AutoencoderKL")
+            self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
+            self.model_size = calc_model_size_by_fs(self.model_path)
+        except Exception:
+            raise Exception("Invalid vae model! (Unknown vae type)")
+
+    def get_size(self, child_type: Optional[SDModelType] = None):
+        if child_type is not None:
+            raise Exception("There are no child models in a vae model")
+        return self.model_size
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SDModelType] = None,
+    ):
+        if child_type is not None:
+            raise Exception("There are no child models in a vae model")
+
+        model = self.vae_class.from_pretrained(
+            self.model_path,
+            torch_dtype=torch_dtype,
+        )
+        # calc a more accurate size
+        self.model_size = calc_model_size_by_data(model)
+        return model
+
+    @staticmethod
+    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
+        if not isinstance(model_path, Path):
+            model_path = Path(model_path)
+        # TODO:
+        #_convert_vae_ckpt_and_cache
+        raise Exception("TODO: ")
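+
+# Usage sketch (hypothetical path; not part of this module):
+#
+#     info = VaeModel("/path/to/sd-vae-ft-mse", BaseModelType.StableDiffusion1_5, ModelType.Vae)
+#     info.get_size()                        # fast estimate from the files on disk
+#     vae = info.get_model(torch.float16)    # loads the AutoencoderKL and refines the size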
+
+class LoRAModel(ModelBase):
+    #model_size: int
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert model_type == ModelType.Lora
+        super().__init__(model_path, base_model, model_type)
+
+        self.model_size = os.path.getsize(self.model_path)
+
+    def get_size(self, child_type: Optional[SDModelType] = None):
+        if child_type is not None:
+            raise Exception("There are no child models in a lora")
+        return self.model_size
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SDModelType] = None,
+    ):
+        if child_type is not None:
+            raise Exception("There are no child models in a lora")
+
+        # this wrapper class shadows the raw model class imported at the top
+        # of the module, so re-import it locally under an alias
+        from .lora import LoRAModel as _RawLoRAModel
+        model = _RawLoRAModel.from_checkpoint(
+            file_path=self.model_path,
+            dtype=torch_dtype,
+        )
+
+        self.model_size = model.calc_size()
+        return model
+
+    @staticmethod
+    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
+        if not isinstance(model_path, Path):
+            model_path = Path(model_path)
+
+        # TODO: add diffusers lora when it stabilizes a bit
+        return model_path
+
+
+class TextualInversionModel(ModelBase):
+    #model_size: int
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert model_type == ModelType.TextualInversion
+        super().__init__(model_path, base_model, model_type)
+
+        self.model_size = os.path.getsize(self.model_path)
+
+    def get_size(self, child_type: Optional[SDModelType] = None):
+        if child_type is not None:
+            raise Exception("There are no child models in a textual inversion")
+        return self.model_size
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SDModelType] = None,
+    ):
+        if child_type is not None:
+            raise Exception("There are no child models in a textual inversion")
+
+        # same shadowing issue as in LoRAModel.get_model above
+        from .lora import TextualInversionModel as _RawTextualInversionModel
+        model = _RawTextualInversionModel.from_checkpoint(
+            file_path=self.model_path,
+            dtype=torch_dtype,
+        )
+
+        self.model_size = model.embedding.nelement() * model.embedding.element_size()
+        return model
+
+    @staticmethod
+    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
+        if not isinstance(model_path, Path):
+            model_path = Path(model_path)
+        return model_path
+
+
+def calc_model_size_by_fs(
+    model_path: str,
+    subfolder: Optional[str] = None,
+    variant: Optional[str] = None
+):
+    if subfolder is not None:
+        model_path = os.path.join(model_path, subfolder)
+
+    # this can happen when, for example, the safety checker
+    # is not downloaded.
+    if not os.path.exists(model_path):
+        return 0
+
+    all_files = os.listdir(model_path)
+    all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
+
+    fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
+    bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
+    other_files = set(all_files) - fp16_files - bit8_files
+
+    if variant is None:
+        files = other_files
+    elif variant == "fp16":
+        files = fp16_files
+    elif variant == "8bit":
+        files = bit8_files
+    else:
+        raise NotImplementedError(f"Unknown variant: {variant}")
+
+    # try to read the size from the index file, if one exists
+    index_postfix = ".index.json"
+    if variant is not None:
+        index_postfix = f".index.{variant}.json"
+
+    for file in files:
+        if not file.endswith(index_postfix):
+            continue
+        try:
+            with open(os.path.join(model_path, file), "r") as f:
+                index_data = json.loads(f.read())
+            return int(index_data["metadata"]["total_size"])
+        except Exception:
+            pass
+
+    # sum the file sizes if there is no index file
+    formats = [
+        (".safetensors",), # safetensors
+        (".bin",), # torch
+        (".onnx", ".pb"), # onnx
+        (".msgpack",), # flax
+        (".ckpt",), # tf
+        (".h5",), # tf2
+    ]
+
+    for file_format in formats:
+        model_files = [f for f in files if f.endswith(file_format)]
+        if len(model_files) == 0:
+            continue
+
+        model_size = 0
+        for model_file in model_files:
+            file_stats = os.stat(os.path.join(model_path, model_file))
+            model_size += file_stats.st_size
+        return model_size
+
+    #raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
+    return 0  # scheduler/feature_extractor/tokenizer - models that are not loaded to the gpu
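+
+# For reference, the variant naming matched above follows diffusers' file
+# layout, e.g. (hypothetical listing):
+#
+#     diffusion_pytorch_model.safetensors        -> variant=None
+#     diffusion_pytorch_model.fp16.safetensors   -> variant="fp16"
+#     diffusion_pytorch_model.8bit.bin           -> variant="8bit"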
+
+
+def calc_model_size_by_data(model) -> int:
+    if isinstance(model, DiffusionPipeline):
+        return _calc_pipeline_by_data(model)
+    elif isinstance(model, torch.nn.Module):
+        return _calc_model_by_data(model)
+    else:
+        return 0
+
+
+def _calc_pipeline_by_data(pipeline) -> int:
+    res = 0
+    for submodel_key in pipeline.components.keys():
+        submodel = getattr(pipeline, submodel_key)
+        if submodel is not None and isinstance(submodel, torch.nn.Module):
+            res += _calc_model_by_data(submodel)
+    return res
+
+
+def _calc_model_by_data(model) -> int:
+    mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
+    mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
+    mem = mem_params + mem_bufs  # in bytes
+    return mem
+
+
+def _convert_ckpt_and_cache(mconfig: DictConfig) -> Path:
+    """
+    Convert the checkpoint model indicated in mconfig into a diffusers
+    model, cache it to disk, and return the Path to the converted model.
+    If it is already on disk, just return the Path.
+    """
+    app_config = InvokeAIAppConfig.get_config()
+    weights = app_config.root_dir / mconfig.path
+    config_file = app_config.root_dir / mconfig.config
+    diffusers_path = app_config.converted_ckpts_dir / weights.stem
+
+    # return the cached version if it exists
+    if diffusers_path.exists():
+        return diffusers_path
+
+    # TODO: it would arguably be more correct to convert with the embedded vae,
+    # otherwise a user who deletes the custom vae is left with neither the
+    # embedded nor the custom one
+    #vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
+    vae_ckpt_path, vae_model = None, None
+
+    # to avoid circular import errors
+    from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+    with SilenceWarnings():
+        convert_ckpt_to_diffusers(
+            weights,
+            diffusers_path,
+            extract_ema=True,
+            original_config_file=config_file,
+            vae=vae_model,
+            vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
+            scan_needed=True,
+        )
+    return diffusers_path
+
+def _convert_vae_ckpt_and_cache(mconfig: DictConfig) -> Path:
+    """
+    Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
+    object, cache it to disk, and return the Path to the converted model.
+    If it is already on disk, just return the Path.
+    """
+    app_config = InvokeAIAppConfig.get_config()
+    root = app_config.root_dir
+    weights_file = root / mconfig.path
+    config_file = root / mconfig.config
+    diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
+    image_size = mconfig.get('width') or mconfig.get('height') or 512
+
+    # return the cached version if it exists
+    if diffusers_path.exists():
+        return diffusers_path
+
+    # this avoids a circular import error
+    from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
+    if weights_file.suffix == '.safetensors':
+        checkpoint = safetensors.torch.load_file(weights_file)
+    else:
+        checkpoint = torch.load(weights_file, map_location="cpu")
+
+    # sometimes the weights are hidden under "state_dict", and sometimes not
+    if "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
+
+    config = OmegaConf.load(config_file)
+
+    vae_model = convert_ldm_vae_to_diffusers(
+        checkpoint=checkpoint,
+        vae_config=config,
+        image_size=image_size,
+    )
+    vae_model.save_pretrained(
+        diffusers_path,
+        safe_serialization=is_safetensors_available()
+    )
+    return diffusers_path
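+
+
+# Usage sketch for the conversion helpers; the mconfig field names are those
+# the functions read above, and the paths are hypothetical:
+#
+#     mconfig = OmegaConf.create({
+#         "path": "models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
+#         "config": "configs/stable-diffusion/v1-inference.yaml",
+#     })
+#     diffusers_dir = _convert_ckpt_and_cache(mconfig)  # converts once, cached thereafter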