From 039fa73269f2decb55e75347e3b19dad534de9ac Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 14 May 2023 03:06:26 +0300 Subject: [PATCH] Change SDModelType enum to string, fixes(model unload negative locks count, scheduler load error, safetensors convert, wrong logic in del_model, wrong parse metadata in web) --- invokeai/app/invocations/compel.py | 13 +- invokeai/app/invocations/latent.py | 63 ++-- invokeai/app/invocations/model.py | 38 +-- invokeai/app/services/events.py | 4 +- .../app/services/model_manager_service.py | 228 +++++++------- .../backend/model_management/model_cache.py | 277 ++++++++++-------- .../backend/model_management/model_manager.py | 126 ++++---- .../web/src/common/util/parseMetadata.ts | 2 +- 8 files changed, 388 insertions(+), 363 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 584e14ea0f..633b53accd 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -59,19 +59,14 @@ class CompelInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> CompelOutput: - # TODO: load without model text_encoder_info = context.services.model_manager.get_model( - model_name=self.clip.text_encoder.model_name, - model_type=SDModelType[self.clip.text_encoder.model_type], - submodel=SDModelType[self.clip.text_encoder.submodel], + **self.clip.text_encoder.dict(), ) tokenizer_info = context.services.model_manager.get_model( - model_name=self.clip.tokenizer.model_name, - model_type=SDModelType[self.clip.tokenizer.model_type], - submodel=SDModelType[self.clip.tokenizer.submodel], + **self.clip.tokenizer.dict(), ) - with text_encoder_info.context as text_encoder,\ - tokenizer_info.context as tokenizer: + with text_encoder_info as text_encoder,\ + tokenizer_info as tokenizer: # TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 9ca1c81a3d..cedadde4d0 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -79,12 +79,8 @@ def get_scheduler( scheduler_info: ModelInfo, scheduler_name: str, ) -> Scheduler: - orig_scheduler_info = context.services.model_manager.get_model( - model_name=scheduler_info.model_name, - model_type=SDModelType[scheduler_info.model_type], - submodel=SDModelType[scheduler_info.submodel], - ) - with orig_scheduler_info.context as orig_scheduler: + orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) + with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config scheduler_class = scheduler_map.get(scheduler_name,'ddim') @@ -243,14 +239,8 @@ class TextToLatentsInvocation(BaseInvocation): def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - unet_info = context.services.model_manager.get_model( - model_name=self.unet.unet.model_name, - model_type=SDModelType[self.unet.unet.model_type], - submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None, - ) - - with unet_info.context as unet: + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) + with unet_info as unet: scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -309,12 +299,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model( - model_name=self.unet.unet.model_name, - model_type=SDModelType[self.unet.unet.model_type], - submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None, + **self.unet.unet.dict(), ) - with unet_info.context as unet: + with unet_info as unet: scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -379,18 +367,18 @@ class LatentsToImageInvocation(BaseInvocation): #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model( - model_name=self.vae.vae.model_name, - model_type=SDModelType[self.vae.vae.model_type], - submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None, + **self.vae.vae.dict(), ) - with vae_info.context as vae: - # TODO: check if it works + with vae_info as vae: if self.tiled: vae.enable_tiling() else: vae.disable_tiling() + # clear memory as vae decode can request a lot + torch.cuda.empty_cache() + with torch.inference_mode(): # copied from diffusers pipeline latents = latents / vae.config.scaling_factor @@ -509,36 +497,29 @@ class ImageToLatentsInvocation(BaseInvocation): #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model( - model_name=self.vae.vae.model_name, - model_type=SDModelType[self.vae.vae.model_type], - submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None, + **self.vae.vae.dict(), ) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") - with vae_info.context as vae: - # TODO: check if it works + with vae_info as vae: if self.tiled: 
vae.enable_tiling() else: vae.disable_tiling() - latents = self.non_noised_latents_from_image(vae, image_tensor) + # non_noised_latents_from_image + image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) + with torch.inference_mode(): + image_tensor_dist = vae.encode(image_tensor).latent_dist + latents = image_tensor_dist.sample().to( + dtype=vae.dtype + ) # FIXME: uses torch.randn. make reproducible! + + latents = 0.18215 * latents name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.set(name, latents) return LatentsOutput(latents=LatentsField(latents_name=name)) - - - def non_noised_latents_from_image(self, vae, init_image): - init_image = init_image.to(device=vae.device, dtype=vae.dtype) - with torch.inference_mode(): - init_latent_dist = vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample().to( - dtype=vae.dtype - ) # FIXME: uses torch.randn. make reproducible! - - init_latents = 0.18215 * init_latents - return init_latents \ No newline at end of file diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 74467d7ab3..c473451c6c 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -8,8 +8,8 @@ from ...backend.model_management import SDModelType class ModelInfo(BaseModel): model_name: str = Field(description="Info to load unet submodel") - model_type: str = Field(description="Info to load unet submodel") - submodel: Optional[str] = Field(description="Info to load unet submodel") + model_type: SDModelType = Field(description="Info to load unet submodel") + submodel: Optional[SDModelType] = Field(description="Info to load unet submodel") class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") @@ -62,15 +62,15 @@ class ModelLoaderInvocation(BaseInvocation): # TODO: not found exceptions if not context.services.model_manager.model_exists( model_name=self.model_name, - model_type=SDModelType.diffusers, + model_type=SDModelType.Diffusers, ): raise Exception(f"Unkown model name: {self.model_name}!") """ if not context.services.model_manager.model_exists( model_name=self.model_name, - model_type=SDModelType.diffusers, - submodel=SDModelType.tokenizer, + model_type=SDModelType.Diffusers, + submodel=SDModelType.Tokenizer, ): raise Exception( f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" @@ -78,8 +78,8 @@ class ModelLoaderInvocation(BaseInvocation): if not context.services.model_manager.model_exists( model_name=self.model_name, - model_type=SDModelType.diffusers, - submodel=SDModelType.text_encoder, + model_type=SDModelType.Diffusers, + submodel=SDModelType.TextEncoder, ): raise Exception( f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" @@ -87,8 +87,8 @@ class ModelLoaderInvocation(BaseInvocation): if not context.services.model_manager.model_exists( model_name=self.model_name, - model_type=SDModelType.diffusers, - submodel=SDModelType.unet, + model_type=SDModelType.Diffusers, + submodel=SDModelType.UNet, ): raise Exception( f"Failed to find unet submodel from {self.model_name}! 
Check if model corrupted" @@ -100,32 +100,32 @@ class ModelLoaderInvocation(BaseInvocation): unet=UNetField( unet=ModelInfo( model_name=self.model_name, - model_type=SDModelType.diffusers.name, - submodel=SDModelType.unet.name, + model_type=SDModelType.Diffusers, + submodel=SDModelType.UNet, ), scheduler=ModelInfo( model_name=self.model_name, - model_type=SDModelType.diffusers.name, - submodel=SDModelType.scheduler.name, + model_type=SDModelType.Diffusers, + submodel=SDModelType.Scheduler, ), ), clip=ClipField( tokenizer=ModelInfo( model_name=self.model_name, - model_type=SDModelType.diffusers.name, - submodel=SDModelType.tokenizer.name, + model_type=SDModelType.Diffusers, + submodel=SDModelType.Tokenizer, ), text_encoder=ModelInfo( model_name=self.model_name, - model_type=SDModelType.diffusers.name, - submodel=SDModelType.text_encoder.name, + model_type=SDModelType.Diffusers, + submodel=SDModelType.TextEncoder, ), ), vae=VaeField( vae=ModelInfo( model_name=self.model_name, - model_type=SDModelType.diffusers.name, - submodel=SDModelType.vae.name, + model_type=SDModelType.Diffusers, + submodel=SDModelType.Vae, ), ) ) diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index d3f61a70b8..9e2db4bfe8 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -120,7 +120,7 @@ class EventServiceBase: node=node, source_node_id=source_node_id, model_name=model_name, - model_type=model_type.name, + model_type=model_type, submodel=submodel, ), ) @@ -143,7 +143,7 @@ class EventServiceBase: node=node, source_node_id=source_node_id, model_name=model_name, - model_type=model_type.name, + model_type=model_type, submodel=submodel, model_info=model_info, ), diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 5c57c5462c..0d140511e0 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -1,21 +1,25 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team +from __future__ import annotations +import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Union, Callable, List, Tuple, types +from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING from dataclasses import dataclass from invokeai.backend.model_management.model_manager import ( ModelManager, SDModelType, SDModelInfo, - torch, ) from invokeai.app.models.exceptions import CanceledException -from ...backend import Args,Globals # this must go when pr 3340 merged +from ...backend import Args, Globals # this must go when pr 3340 merged from ...backend.util import choose_precision, choose_torch_device +if TYPE_CHECKING: + from ..invocations.baseinvocation import BaseInvocation, InvocationContext + @dataclass class LastUsedModel: model_name: str=None @@ -28,9 +32,9 @@ class ModelManagerServiceBase(ABC): @abstractmethod def __init__( - self, - config: Args, - logger: types.ModuleType + self, + config: Args, + logger: types.ModuleType, ): """ Initialize with the path to the models.yaml config file. 
@@ -41,13 +45,14 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def get_model(self, - model_name: str, - model_type: SDModelType=SDModelType.diffusers, - submodel: SDModelType=None, - node=None, # circular dependency issues, so untyped at moment - context=None, - )->SDModelInfo: + def get_model( + self, + model_name: str, + model_type: SDModelType = SDModelType.Diffusers, + submodel: Optional[SDModelType] = None, + node: Optional[BaseInvocation] = None, + context: Optional[InvocationContext] = None, + ) -> SDModelInfo: """Retrieve the indicated model with name and type. submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" @@ -60,14 +65,14 @@ class ModelManagerServiceBase(ABC): @abstractmethod def model_exists( - self, - model_name: str, - model_type: SDModelType + self, + model_name: str, + model_type: SDModelType, ) -> bool: pass @abstractmethod - def default_model(self) -> Union[Tuple[str, SDModelType],None]: + def default_model(self) -> Optional[Tuple[str, SDModelType]]: """ Returns the name and typeof the default model, or None if none is defined. @@ -80,21 +85,21 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def model_info(self, model_name: str, model_type: SDModelType)->dict: + def model_info(self, model_name: str, model_type: SDModelType) -> dict: """ Given a model name returns a dict-like (OmegaConf) object describing it. """ pass @abstractmethod - def model_names(self)->List[Tuple[str, SDModelType]]: + def model_names(self) -> List[Tuple[str, SDModelType]]: """ Returns a list of all the model names known. """ pass @abstractmethod - def list_models(self)->dict: + def list_models(self) -> dict: """ Return a dict of models in the format: { model_key1: {'status': 'active'|'cached'|'not loaded', @@ -110,12 +115,12 @@ class ModelManagerServiceBase(ABC): @abstractmethod def add_model( - self, - model_name: str, - model_type: SDModelType, - model_attributes: dict, - clobber: bool = False - )->None: + self, + model_name: str, + model_type: SDModelType, + model_attributes: dict, + clobber: bool = False + ) -> None: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. @@ -126,10 +131,12 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def del_model(self, - model_name: str, - model_type: SDModelType, - delete_files: bool = False): + def del_model( + self, + model_name: str, + model_type: SDModelType, + delete_files: bool = False, + ): """ Delete the named model from configuration. If delete_files is true, then the underlying weight file or diffusers directory will be deleted @@ -140,9 +147,9 @@ class ModelManagerServiceBase(ABC): @abstractmethod def import_diffuser_model( repo_or_path: Union[str, Path], - model_name: str = None, - description: str = None, - vae: dict = None, + model_name: Optional[str] = None, + description: Optional[str] = None, + vae: Optional[dict] = None, ) -> bool: """ Install the indicated diffuser model and returns True if successful. @@ -157,10 +164,10 @@ class ModelManagerServiceBase(ABC): @abstractmethod def import_lora( - self, - path: Path, - model_name: str=None, - description: str=None, + self, + path: Path, + model_name: Optional[str] = None, + description: Optional[str] = None, ): """ Creates an entry for the indicated lora file. 
Call @@ -170,10 +177,10 @@ class ModelManagerServiceBase(ABC): @abstractmethod def import_embedding( - self, - path: Path, - model_name: str=None, - description: str=None, + self, + path: Path, + model_name: str=None, + description: str=None, ): """ Creates an entry for the indicated textual inversion embedding file. @@ -223,7 +230,7 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def commit(self, conf_file: Path=None) -> None: + def commit(self, conf_file: Path = None) -> None: """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -235,10 +242,10 @@ class ModelManagerServiceBase(ABC): class ModelManagerService(ModelManagerServiceBase): """Responsible for managing models on disk and in memory""" def __init__( - self, - config: Args, - logger: types.ModuleType - ): + self, + config: Args, + logger: types.ModuleType, + ): """ Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, @@ -255,7 +262,7 @@ class ModelManagerService(ModelManagerServiceBase): logger.debug(f'config file={config_file}') device = torch.device(choose_torch_device()) - if config.precision=="auto": + if config.precision == "auto": precision = choose_precision(device) dtype = torch.float32 if precision=='float32' \ else torch.float16 @@ -272,22 +279,24 @@ class ModelManagerService(ModelManagerServiceBase): sequential_offload = config.sequential_guidance - self.mgr = ModelManager(config=config_file, - device_type=device, - precision=dtype, - max_cache_size=max_cache_size, - sequential_offload=sequential_offload, - logger=logger - ) + self.mgr = ModelManager( + config=config_file, + device_type=device, + precision=dtype, + max_cache_size=max_cache_size, + sequential_offload=sequential_offload, + logger=logger, + ) logger.info('Model manager service initialized') - def get_model(self, - model_name: str, - model_type: SDModelType=SDModelType.diffusers, - submodel: SDModelType=None, - node=None, - context=None, - )->SDModelInfo: + def get_model( + self, + model_name: str, + model_type: SDModelType = SDModelType.Diffusers, + submodel: Optional[SDModelType] = None, + node: Optional[BaseInvocation] = None, + context: Optional[InvocationContext] = None, + ) -> SDModelInfo: """ Retrieve the indicated model. submodel can be used to get a part (such as the vae) of a diffusers mode. @@ -340,9 +349,9 @@ class ModelManagerService(ModelManagerServiceBase): return model_info def model_exists( - self, - model_name: str, - model_type: SDModelType + self, + model_name: str, + model_type: SDModelType, ) -> bool: """ Given a model name, returns True if it is a valid @@ -350,32 +359,33 @@ class ModelManagerService(ModelManagerServiceBase): """ return self.mgr.model_exists( model_name, - model_type) + model_type, + ) - def default_model(self) -> Union[Tuple[str, SDModelType],None]: + def default_model(self) -> Optional[Tuple[str, SDModelType]]: """ Returns the name of the default model, or None if none is defined. """ return self.mgr.default_model() - def set_default_model(self, model_name:str, model_type: SDModelType): + def set_default_model(self, model_name: str, model_type: SDModelType): """Sets the default model to the indicated name.""" self.mgr.set_default_model(model_name) - def model_info(self, model_name: str, model_type: SDModelType)->dict: + def model_info(self, model_name: str, model_type: SDModelType) -> dict: """ Given a model name returns a dict-like (OmegaConf) object describing it. 
""" return self.mgr.model_info(model_name) - def model_names(self)->List[Tuple[str, SDModelType]]: + def model_names(self) -> List[Tuple[str, SDModelType]]: """ Returns a list of all the model names known. """ return self.mgr.model_names() - def list_models(self)->dict: + def list_models(self) -> dict: """ Return a dict of models in the format: { model_key: {'status': 'active'|'cached'|'not loaded', @@ -388,11 +398,12 @@ class ModelManagerService(ModelManagerServiceBase): return self.mgr.list_models() def add_model( - self, - model_name: str, - model_type: SDModelType, - model_attributes: dict, - clobber: bool = False)->None: + self, + model_name: str, + model_type: SDModelType, + model_attributes: dict, + clobber: bool = False, + )->None: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. @@ -400,14 +411,15 @@ class ModelManagerService(ModelManagerServiceBase): with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ - return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber) + return self.mgr.add_model(model_name, model_type, model_attributes, clobber) - def del_model(self, - model_name: str, - model_type: SDModelType=SDModelType.diffusers, - delete_files: bool = False - ): + def del_model( + self, + model_name: str, + model_type: SDModelType = SDModelType.Diffusers, + delete_files: bool = False, + ): """ Delete the named model from configuration. If delete_files is true, then the underlying weight file or diffusers directory will be deleted @@ -416,11 +428,11 @@ class ModelManagerService(ModelManagerServiceBase): self.mgr.del_model(model_name, model_type, delete_files) def import_diffuser_model( - self, - repo_or_path: Union[str, Path], - model_name: str = None, - description: str = None, - vae: dict = None, + self, + repo_or_path: Union[str, Path], + model_name: Optional[str] = None, + description: Optional[str] = None, + vae: Optional[dict] = None, ) -> bool: """ Install the indicated diffuser model and returns True if successful. @@ -431,13 +443,13 @@ class ModelManagerService(ModelManagerServiceBase): You can optionally provide a model name and/or description. If not provided, then these will be derived from the repo name. Call commit() to write to disk. """ - return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae) + return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae) def import_lora( - self, - path: Path, - model_name: str=None, - description: str=None, + self, + path: Path, + model_name: Optional[str] = None, + description: Optional[str] = None, ): """ Creates an entry for the indicated lora file. Call @@ -446,10 +458,10 @@ class ModelManagerService(ModelManagerServiceBase): self.mgr.import_lora(path, model_name, description) def import_embedding( - self, - path: Path, - model_name: str=None, - description: str=None, + self, + path: Path, + model_name: Optional[str] = None, + description: Optional[str] = None, ): """ Creates an entry for the indicated textual inversion embedding file. 
@@ -462,9 +474,9 @@ class ModelManagerService(ModelManagerServiceBase): path_url_or_repo: str, model_name: str = None, description: str = None, - model_config_file: Path = None, - commit_to_conf: Path = None, - config_file_callback: Callable[[Path], Path] = None, + model_config_file: Optional[Path] = None, + commit_to_conf: Optional[Path] = None, + config_file_callback: Optional[Callable[[Path], Path]] = None, ) -> str: """Accept a string which could be: - a HF diffusers repo_id @@ -505,7 +517,7 @@ class ModelManagerService(ModelManagerServiceBase): ) - def commit(self, conf_file: Path=None): + def commit(self, conf_file: Optional[Path]=None): """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -514,16 +526,16 @@ class ModelManagerService(ModelManagerServiceBase): return self.mgr.commit(conf_file) def _emit_load_event( - self, - node, - context, - model_name: str, - model_type: SDModelType, - submodel: SDModelType, - model_info: SDModelInfo=None, + self, + node, + context, + model_name: str, + model_type: SDModelType, + submodel: SDModelType, + model_info: Optional[SDModelInfo] = None, ): if context.services.queue.is_canceled(context.graph_execution_state_id): - raise CanceledException + raise CanceledException() graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[node.id] if context: @@ -536,7 +548,7 @@ class ModelManagerService(ModelManagerServiceBase): submodel=submodel, ) else: - context.services.events.emit_model_load_completed ( + context.services.events.emit_model_load_completed( graph_execution_state_id=context.graph_execution_state_id, node=node.dict(), source_node_id=source_node_id, diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 558de0f28c..5c062483db 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -23,12 +23,12 @@ import warnings from collections import Counter from enum import Enum from pathlib import Path -from typing import Dict, Sequence, Union, Tuple, types +from typing import Dict, Sequence, Union, Tuple, types, Optional import torch import safetensors.torch -from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel +from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel, ConfigMixin from diffusers import logging as diffusers_logging from diffusers.pipelines.stable_diffusion.safety_checker import \ StableDiffusionSafetyChecker @@ -55,20 +55,38 @@ class LoraType(dict): class TIType(dict): pass -class SDModelType(Enum): - diffusers=StableDiffusionGeneratorPipeline # whole pipeline - vae=AutoencoderKL # diffusers parts - text_encoder=CLIPTextModel - tokenizer=CLIPTokenizer - unet=UNet2DConditionModel - scheduler=SchedulerMixin - safety_checker=StableDiffusionSafetyChecker - feature_extractor=CLIPFeatureExtractor +class SDModelType(str, Enum): + Diffusers="diffusers" # whole pipeline + Vae="vae" # diffusers parts + TextEncoder="text_encoder" + Tokenizer="tokenizer" + UNet="unet" + Scheduler="scheduler" + SafetyChecker="safety_checker" + FeatureExtractor="feature_extractor" # These are all loaded as dicts of tensors, and we # distinguish them by class - lora=LoraType - textual_inversion=TIType - + Lora="lora" + TextualInversion="textual_inversion" 
+ +# TODO: +class EmptyScheduler(SchedulerMixin, ConfigMixin): + pass + +MODEL_CLASSES = { + SDModelType.Diffusers: StableDiffusionGeneratorPipeline, + SDModelType.Vae: AutoencoderKL, + SDModelType.TextEncoder: CLIPTextModel, # TODO: t5 + SDModelType.Tokenizer: CLIPTokenizer, # TODO: t5 + SDModelType.UNet: UNet2DConditionModel, + SDModelType.Scheduler: EmptyScheduler, + SDModelType.SafetyChecker: StableDiffusionSafetyChecker, + SDModelType.FeatureExtractor: CLIPFeatureExtractor, + + SDModelType.Lora: LoraType, + SDModelType.TextualInversion: TIType, +} + class ModelStatus(Enum): unknown='unknown' not_loaded='not loaded' @@ -80,21 +98,21 @@ class ModelStatus(Enum): # After loading, we will know it exactly. # Sizes are in Gigs, estimated for float16; double for float32 SIZE_GUESSTIMATE = { - SDModelType.diffusers: 2.2, - SDModelType.vae: 0.35, - SDModelType.text_encoder: 0.5, - SDModelType.tokenizer: 0.001, - SDModelType.unet: 3.4, - SDModelType.scheduler: 0.001, - SDModelType.safety_checker: 1.2, - SDModelType.feature_extractor: 0.001, - SDModelType.lora: 0.1, - SDModelType.textual_inversion: 0.001, + SDModelType.Diffusers: 2.2, + SDModelType.Vae: 0.35, + SDModelType.TextEncoder: 0.5, + SDModelType.Tokenizer: 0.001, + SDModelType.UNet: 3.4, + SDModelType.Scheduler: 0.001, + SDModelType.SafetyChecker: 1.2, + SDModelType.FeatureExtractor: 0.001, + SDModelType.Lora: 0.1, + SDModelType.TextualInversion: 0.001, } # The list of model classes we know how to fetch, for typechecking -ModelClass = Union[tuple([x.value for x in SDModelType])] -DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel) +ModelClass = Union[tuple([x for x in MODEL_CLASSES.values()])] +DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, EmptyScheduler, UNet2DConditionModel) class UnsafeModelException(Exception): "Raised when a legacy model file fails the picklescan test" @@ -110,15 +128,15 @@ class ModelLocker(object): class ModelCache(object): def __init__( - self, - max_cache_size: float=DEFAULT_MAX_CACHE_SIZE, - execution_device: torch.device=torch.device('cuda'), - storage_device: torch.device=torch.device('cpu'), - precision: torch.dtype=torch.float16, - sequential_offload: bool=False, - lazy_offloading: bool=True, - sha_chunksize: int = 16777216, - logger: types.ModuleType = logger + self, + max_cache_size: float=DEFAULT_MAX_CACHE_SIZE, + execution_device: torch.device=torch.device('cuda'), + storage_device: torch.device=torch.device('cpu'), + precision: torch.dtype=torch.float16, + sequential_offload: bool=False, + lazy_offloading: bool=True, + sha_chunksize: int = 16777216, + logger: types.ModuleType = logger ): ''' :param max_models: Maximum number of models to cache in CPU RAM [4] @@ -145,15 +163,15 @@ class ModelCache(object): self.model_sizes: Dict[str,int] = dict() def get_model( - self, - repo_id_or_path: Union[str,Path], - model_type: SDModelType=SDModelType.diffusers, - subfolder: Path=None, - submodel: SDModelType=None, - revision: str=None, - attach_model_part: Tuple[SDModelType, str] = (None,None), - gpu_load: bool=True, - )->ModelLocker: # ?? what does it return + self, + repo_id_or_path: Union[str, Path], + model_type: SDModelType = SDModelType.Diffusers, + subfolder: Path = None, + submodel: SDModelType = None, + revision: str = None, + attach_model_part: Tuple[SDModelType, str] = (None, None), + gpu_load: bool = True, + ) -> ModelLocker: # ?? 
what does it return ''' Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching. Use like this: @@ -178,14 +196,14 @@ class ModelCache(object): vae_context = cache.get_model( 'stabilityai/sd-stable-diffusion-2', - submodel=SDModelType.vae + submodel=SDModelType.Vae ) This is equivalent to: vae_context = cache.get_model( 'stabilityai/sd-stable-diffusion-2', - model_type = SDModelType.vae, + model_type = SDModelType.Vae, subfolder='vae' ) @@ -195,14 +213,14 @@ class ModelCache(object): pipeline_context = cache.get_model( 'runwayml/stable-diffusion-v1-5', - attach_model_part=(SDModelType.vae,'stabilityai/sd-vae-ft-mse') + attach_model_part=(SDModelType.Vae,'stabilityai/sd-vae-ft-mse') ) The model will be locked into GPU VRAM for the duration of the context. :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model :param model_type: An SDModelType enum indicating the type of the (parent) model :param subfolder: name of a subfolder in which the model can be found, e.g. "vae" - :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae + :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.Vae :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id) :param revision: model revision :param gpu_load: load the model into GPU [default True] @@ -211,7 +229,7 @@ class ModelCache(object): repo_id_or_path, revision, subfolder, - model_type.value, + model_type, ) # optimization: if caller is asking to load a submodel of a diffusers pipeline, then @@ -221,11 +239,11 @@ class ModelCache(object): repo_id_or_path, None, revision, - SDModelType.diffusers.value + SDModelType.Diffusers ) if possible_parent_key in self.models: key = possible_parent_key - submodel=model_type + submodel = model_type # Look for the model in the cache RAM if key in self.models: # cached - move to bottom of stack (most recently used) @@ -256,24 +274,24 @@ class ModelCache(object): self.current_cache_size += mem_used # increment size of the cache # this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE" - if model_type==SDModelType.diffusers and attach_model_part[0]: - self.attach_part(model,*attach_model_part) + if model_type == SDModelType.Diffusers and attach_model_part[0]: + self.attach_part(model, *attach_model_part) self.stack.append(key) # add to LRU cache - self.models[key]=model # keep copy of model in dict + self.models[key] = model # keep copy of model in dict if submodel: - model = getattr(model, submodel.name) + model = getattr(model, submodel) return self.ModelLocker(self, key, model, gpu_load) def uncache_model(self, key: str): '''Remove corresponding model from the cache''' if key is not None and key in self.models: - with contextlib.suppress(ValueError), contextlib.suppress(KeyError): - del self.models[key] - del self.locked_models[key] - self.loaded_models.remove(key) + self.models.pop(key, None) + self.locked_models.pop(key, None) + self.loaded_models.discard(key) + with contextlib.suppress(ValueError): self.stack.remove(key) class ModelLocker(object): @@ -302,7 +320,7 @@ class ModelCache(object): if model.device != cache.execution_device: cache.logger.debug(f'Moving {key} into {cache.execution_device}') with VRAMUsage() as mem: - model.to(cache.execution_device) # move into GPU + model.to(cache.execution_device, dtype=cache.precision) # move into GPU cache.logger.debug(f'GPU VRAM 
used for load: {(mem.vram_used/GIG):.2f} GB') cache.model_sizes[key] = mem.vram_used # more accurate size @@ -312,13 +330,16 @@ class ModelCache(object): else: # in the event that the caller wants the model in RAM, we # move it into CPU if it is in GPU and not locked - if hasattr(model,'to') and (key in cache.loaded_models + if hasattr(model, 'to') and (key in cache.loaded_models and cache.locked_models[key] == 0): model.to(cache.storage_device) cache.loaded_models.remove(key) return model def __exit__(self, type, value, traceback): + if not hasattr(self.model, 'to'): + return + key = self.key cache = self.cache cache.locked_models[key] -= 1 @@ -326,11 +347,12 @@ class ModelCache(object): cache._offload_unlocked_models() cache._print_cuda_stats() - def attach_part(self, - diffusers_model: StableDiffusionPipeline, - part_type: SDModelType, - part_id: str - ): + def attach_part( + self, + diffusers_model: StableDiffusionPipeline, + part_type: SDModelType, + part_id: str, + ): ''' Attach a diffusers model part to a diffusers model. This can be used to replace the VAE, tokenizer, textencoder, unet, etc. @@ -338,27 +360,26 @@ class ModelCache(object): :param part_type: An SD ModelType indicating the part :param part_id: A HF repo_id for the part ''' - part_key = part_type.name - part_class = part_type.value part = self._load_diffusers_from_storage( part_id, - model_class=part_class, + model_class=MODEL_CLASSES[part_type], ) part.to(diffusers_model.device) - setattr(diffusers_model,part_key,part) - self.logger.debug(f'Attached {part_key} {part_id}') + setattr(diffusers_model, part_type, part) + self.logger.debug(f'Attached {part_type} {part_id}') - def status(self, - repo_id_or_path: Union[str,Path], - model_type: SDModelType=SDModelType.diffusers, - revision: str=None, - subfolder: Path=None, - )->ModelStatus: + def status( + self, + repo_id_or_path: Union[str, Path], + model_type: SDModelType = SDModelType.Diffusers, + revision: str = None, + subfolder: Path = None, + ) -> ModelStatus: key = self._model_key( repo_id_or_path, revision, subfolder, - model_type.value, + model_type, ) if key not in self.models: return ModelStatus.not_loaded @@ -370,9 +391,11 @@ class ModelCache(object): else: return ModelStatus.in_ram - def model_hash(self, - repo_id_or_path: Union[str,Path], - revision: str="main")->str: + def model_hash( + self, + repo_id_or_path: Union[str, Path], + revision: str = "main", + ) -> str: ''' Given the HF repo id or path to a model on disk, returns a unique hash. 
Works for legacy checkpoint files, HF models on disk, and HF repo IDs @@ -385,7 +408,7 @@ class ModelCache(object): else: return self._hf_commit_hash(repo_id_or_path,revision) - def cache_size(self)->float: + def cache_size(self) -> float: "Return the current size of the cache, in GB" return self.current_cache_size / GIG @@ -407,10 +430,15 @@ class ModelCache(object): logger.debug("Model scanned ok") @staticmethod - def _model_key(path,revision,subfolder,model_class)->str: - return ':'.join([str(path),str(revision or ''),str(subfolder or ''),model_class.__name__]) + def _model_key(path, revision, subfolder, model_class) -> str: + return ':'.join([ + str(path), + str(revision or ''), + str(subfolder or ''), + model_class, + ]) - def _has_cuda(self)->bool: + def _has_cuda(self) -> bool: return self.execution_device.type == 'cuda' def _print_cuda_stats(self): @@ -450,43 +478,43 @@ class ModelCache(object): self.loaded_models.remove(key) def _load_model_from_storage( - self, - repo_id_or_path: Union[str,Path], - subfolder: Path=None, - revision: str=None, - model_type: SDModelType=SDModelType.diffusers, - )->ModelClass: + self, + repo_id_or_path: Union[str, Path], + subfolder: Optional[Path] = None, + revision: Optional[str] = None, + model_type: SDModelType = SDModelType.Diffusers, + ) -> ModelClass: ''' Load and return a HuggingFace model. :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model :param subfolder: name of a subfolder in which the model can be found, e.g. "vae" :param revision: model revision - :param model_type: type of model to return, defaults to SDModelType.diffusers + :param model_type: type of model to return, defaults to SDModelType.Diffusers ''' # silence transformer and diffuser warnings with SilenceWarnings(): - if model_type==SDModelType.lora: + if model_type==SDModelType.Lora: model = self._load_lora_from_storage(repo_id_or_path) - elif model_type==SDModelType.textual_inversion: + elif model_type==SDModelType.TextualInversion: model = self._load_ti_from_storage(repo_id_or_path) else: model = self._load_diffusers_from_storage( repo_id_or_path, subfolder, revision, - model_type.value, + model_type, ) - if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline): + if self.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline): model.enable_offload_submodels(self.execution_device) return model def _load_diffusers_from_storage( - self, - repo_id_or_path: Union[str,Path], - subfolder: Path=None, - revision: str=None, - model_class: ModelClass=StableDiffusionGeneratorPipeline, - )->ModelClass: + self, + repo_id_or_path: Union[str, Path], + subfolder: Optional[Path] = None, + revision: Optional[str] = None, + model_type: ModelClass = StableDiffusionGeneratorPipeline, + ) -> ModelClass: ''' Load and return a HuggingFace model using from_pretrained(). 
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model @@ -494,17 +522,26 @@ class ModelCache(object): :param revision: model revision :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline ''' - revisions = [revision] if revision \ - else ['fp16','main'] if self.precision==torch.float16 \ - else ['main'] - extra_args = {'torch_dtype': self.precision, - 'safety_checker': None}\ - if model_class in DiffusionClasses\ - else {} + + model_class = MODEL_CLASSES[model_type] + + if revision is not None: + revisions = [revision] + elif self.precision == torch.float16: + revisions = ['fp16', 'main'] + else: + revisions = ['main'] + + extra_args = dict() + if model_class in DiffusionClasses: + extra_args = dict( + torch_dtype=self.precision, + safety_checker=None, + ) for rev in revisions: try: - model = model_class.from_pretrained( + model = model_class.from_pretrained( repo_id_or_path, revision=rev, subfolder=subfolder or '.', @@ -517,13 +554,13 @@ class ModelCache(object): pass return model - def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value: - assert False,"_load_lora_from_storage() is not yet implemented" + def _load_lora_from_storage(self, lora_path: Path) -> LoraType: + assert False, "_load_lora_from_storage() is not yet implemented" - def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value: - assert False,"_load_ti_from_storage() is not yet implemented" + def _load_ti_from_storage(self, lora_path: Path) -> TIType: + assert False, "_load_ti_from_storage() is not yet implemented" - def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str: + def _legacy_model_hash(self, checkpoint_path: Union[str, Path]) -> str: sha = hashlib.sha256() path = Path(checkpoint_path) assert path.is_file(),f"File {checkpoint_path} not found" @@ -544,7 +581,7 @@ class ModelCache(object): f.write(hash) return hash - def _local_model_hash(self, model_path: Union[str,Path])->str: + def _local_model_hash(self, model_path: Union[str, Path]) -> str: sha = hashlib.sha256() path = Path(model_path) @@ -566,7 +603,7 @@ class ModelCache(object): f.write(hash) return hash - def _hf_commit_hash(self, repo_id: str, revision: str='main')->str: + def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str: api = HfApi() info = api.list_repo_refs( repo_id=repo_id, @@ -578,7 +615,7 @@ class ModelCache(object): return desired_revisions[0].target_commit @staticmethod - def calc_model_size(model)->int: + def calc_model_size(model) -> int: if isinstance(model,DiffusionPipeline): return ModelCache._calc_pipeline(model) elif isinstance(model,torch.nn.Module): @@ -587,7 +624,7 @@ class ModelCache(object): return None @staticmethod - def _calc_pipeline(pipeline)->int: + def _calc_pipeline(pipeline) -> int: res = 0 for submodel_key in pipeline.components.keys(): submodel = getattr(pipeline, submodel_key) @@ -596,7 +633,7 @@ class ModelCache(object): return res @staticmethod - def _calc_model(model)->int: + def _calc_model(model) -> int: mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()]) mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()]) mem = mem_params + mem_bufs # in bytes diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 87fb938338..05e0afff98 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ 
-27,7 +27,7 @@ Typical usage: max_cache_size=8 ) # gigabytes - model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers) + model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers) with model_info.context as my_model: my_model.latents_from_embeddings(...) @@ -45,7 +45,7 @@ parameter: model_info = manager.get_model( 'clip-tokenizer', - model_type=SDModelType.tokenizer + model_type=SDModelType.Tokenizer ) This will raise an InvalidModelError if the format defined in the @@ -96,7 +96,7 @@ SUBMODELS: It is also possible to fetch an isolated submodel from a diffusers model. Use the `submodel` parameter to select which part: - vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.vae) + vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae) with vae.context as my_vae: print(type(my_vae)) # "AutoencoderKL" @@ -120,8 +120,8 @@ separated by "/". Example: You can now use the `model_type` argument to indicate which model you want: - tokenizer = mgr.get('clip-large',model_type=SDModelType.tokenizer) - encoder = mgr.get('clip-large',model_type=SDModelType.text_encoder) + tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer) + encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder) OTHER FUNCTIONS: @@ -254,7 +254,7 @@ class ModelManager(object): def model_exists( self, model_name: str, - model_type: SDModelType = SDModelType.diffusers, + model_type: SDModelType = SDModelType.Diffusers, ) -> bool: """ Given a model name, returns True if it is a valid @@ -264,28 +264,28 @@ class ModelManager(object): return model_key in self.config def create_key(self, model_name: str, model_type: SDModelType) -> str: - return f"{model_type.name}/{model_name}" + return f"{model_type}/{model_name}" def parse_key(self, model_key: str) -> Tuple[str, SDModelType]: model_type_str, model_name = model_key.split('/', 1) - if model_type_str not in SDModelType.__members__: - # TODO: + try: + model_type = SDModelType(model_type_str) + return (model_name, model_type) + except: raise Exception(f"Unknown model type: {model_type_str}") - return (model_name, SDModelType[model_type_str]) - def get_model( self, model_name: str, - model_type: SDModelType=SDModelType.diffusers, - submodel: SDModelType=None, + model_type: SDModelType = SDModelType.Diffusers, + submodel: Optional[SDModelType] = None, ) -> SDModelInfo: """Given a model named identified in models.yaml, return an SDModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml :param model_type: SDModelType enum indicating the type of model to return :param submodel: an SDModelType enum indicating the portion of - the model to retrieve (e.g. SDModelType.vae) + the model to retrieve (e.g. SDModelType.Vae) If not provided, the model_type will be read from the `format` field of the corresponding stanza. If provided, the model_type will be used @@ -304,17 +304,17 @@ class ModelManager(object): test1_pipeline = mgr.get_model('test1') # returns a StableDiffusionGeneratorPipeline - test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae) + test1_vae1 = mgr.get_model('test1', submodel=SDModelType.Vae) # returns the VAE part of a diffusers model as an AutoencoderKL - test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae) + test1_vae2 = mgr.get_model('test1', model_type=SDModelType.Diffusers, submodel=SDModelType.Vae) # does the same thing as the previous statement. 
Note that model_type # is for the parent model, and submodel is for the part - test1_lora = mgr.get_model('test1', model_type=SDModelType.lora) + test1_lora = mgr.get_model('test1', model_type=SDModelType.Lora) # returns a LoRA embed (as a 'dict' of tensors) - test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder) + test1_encoder = mgr.get_model('test1', model_type=SDModelType.TextEncoder) # raises an InvalidModelError """ @@ -332,10 +332,10 @@ class ModelManager(object): mconfig = self.config[model_key] # type already checked as it's part of key - if model_type == SDModelType.diffusers: + if model_type == SDModelType.Diffusers: # intercept stanzas that point to checkpoint weights and replace them # with the equivalent diffusers model - if mconfig.format in ["ckpt", "diffusers"]: + if mconfig.format in ["ckpt", "safetensors"]: location = self.convert_ckpt_and_cache(mconfig) else: location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id') @@ -355,13 +355,13 @@ class ModelManager(object): vae = (None, None) with suppress(Exception): vae_id = mconfig.vae.repo_id - vae = (SDModelType.vae, vae_id) + vae = (SDModelType.Vae, vae_id) # optimization - don't load whole model if the user # is asking for just a piece of it - if model_type == SDModelType.diffusers and submodel and not subfolder: + if model_type == SDModelType.Diffusers and submodel and not subfolder: model_type = submodel - subfolder = submodel.name + subfolder = submodel.value submodel = None model_context = self.cache.get_model( @@ -390,7 +390,7 @@ class ModelManager(object): _cache = self.cache ) - def default_model(self) -> Union[Tuple[str, SDModelType],None]: + def default_model(self) -> Optional[Tuple[str, SDModelType]]: """ Returns the name of the default model, or None if none is defined. @@ -401,7 +401,7 @@ class ModelManager(object): return (model_name, model_type) return self.model_names()[0][0] - def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None: + def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.Diffusers) -> None: """ Set the default model. The change will not take effect until you call model_manager.commit() @@ -415,25 +415,25 @@ class ModelManager(object): config[self.create_key(model_name, model_type)]["default"] = True def model_info( - self, - model_name: str, - model_type: SDModelType=SDModelType.diffusers + self, + model_name: str, + model_type: SDModelType=SDModelType.Diffusers, ) -> dict: """ Given a model name returns the OmegaConf (dict-like) object describing it. """ if not self.exists(model_name, model_type): return None - return self.config[self.create_key(model_name,model_type)] + return self.config[self.create_key(model_name, model_type)] def model_names(self) -> List[Tuple[str, SDModelType]]: """ Return a list of (str, SDModelType) corresponding to all models known to the configuration.
""" - return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)] + return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)] - def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool: + def is_legacy(self, model_name: str, model_type: SDModelType.Diffusers) -> bool: """ Return true if this is a legacy (.ckpt) model """ @@ -461,14 +461,14 @@ class ModelManager(object): # don't include VAEs in listing (legacy style) if "config" in stanza and "/VAE/" in stanza["config"]: continue - if model_key=='config_file_version': + if model_key == 'config_file_version': continue model_name, model_type = self.parse_key(model_key) models[model_key] = dict() # TODO: return all models in future - if model_type != SDModelType.diffusers: + if model_type != SDModelType.Diffusers: continue model_format = stanza.get('format') @@ -477,15 +477,15 @@ class ModelManager(object): status = self.cache.status( stanza.get('weights') or stanza.get('repo_id'), revision=stanza.get('revision'), - subfolder=stanza.get('subfolder') + subfolder=stanza.get('subfolder'), ) description = stanza.get("description", None) models[model_key].update( model_name=model_name, - model_type=model_type.name, + model_type=model_type, format=model_format, description=description, - status=status.value + status=status.value, ) @@ -528,8 +528,8 @@ class ModelManager(object): def del_model( self, model_name: str, - model_type: SDModelType.diffusers, - delete_files: bool = False + model_type: SDModelType.Diffusers, + delete_files: bool = False, ): """ Delete the named model. @@ -539,9 +539,9 @@ class ModelManager(object): if model_cfg is None: self.logger.error( - f"Unknown model {model_key}" - ) - return + f"Unknown model {model_key}" + ) + return # TODO: some legacy? #if model_name in self.stack: @@ -571,7 +571,7 @@ class ModelManager(object): model_name: str, model_type: SDModelType, model_attributes: dict, - clobber: bool = False + clobber: bool = False, ) -> None: """ Update the named model with a dictionary of attributes. Will fail with an @@ -581,7 +581,7 @@ attributes are incorrect or the model name is missing. """ - if model_type == SDModelType.diffusers: + if model_type == SDModelType.Diffusers: # TODO: automaticaly or manualy? #assert "format" in model_attributes, 'missing required field "format"' model_format = "ckpt" if "weights" in model_attributes else "diffusers" @@ -647,16 +647,16 @@ class ModelManager(object): else: new_config.update(repo_id=repo_or_path) - self.add_model(model_name, SDModelType.diffusers, new_config, True) + self.add_model(model_name, SDModelType.Diffusers, new_config, True) if commit_to_conf: self.commit(commit_to_conf) - return self.create_key(model_name, SDModelType.diffusers) + return self.create_key(model_name, SDModelType.Diffusers) def import_lora( self, path: Path, - model_name: str=None, - description: str=None, + model_name: Optional[str] = None, + description: Optional[str] = None, ): """ Creates an entry for the indicated lora file.
Call @@ -667,7 +667,7 @@ class ModelManager(object): model_description = description or f"LoRA model {model_name}" self.add_model( model_name, - SDModelType.lora, + SDModelType.Lora, dict( format="lora", weights=str(path), @@ -679,8 +679,8 @@ class ModelManager(object): def import_embedding( self, path: Path, - model_name: str=None, - description: str=None, + model_name: Optional[str] = None, + description: Optional[str] = None, ): """ Creates an entry for the indicated lora file. Call @@ -696,7 +696,7 @@ class ModelManager(object): model_description = description or f"Textual embedding model {model_name}" self.add_model( model_name, - SDModelType.textual_inversion, + SDModelType.TextualInversion, dict( format="textual_inversion", weights=str(weights), @@ -746,11 +746,11 @@ class ModelManager(object): def heuristic_import( self, path_url_or_repo: str, - model_name: str = None, - description: str = None, - model_config_file: Path = None, - commit_to_conf: Path = None, - config_file_callback: Callable[[Path], Path] = None, + model_name: Optional[str] = None, + description: Optional[str] = None, + model_config_file: Optional[Path] = None, + commit_to_conf: Optional[Path] = None, + config_file_callback: Optional[Callable[[Path], Path]] = None, ) -> str: """Accept a string which could be: - a HF diffusers repo_id @@ -927,7 +927,7 @@ class ModelManager(object): ) return model_name - def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path: + def convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path: """ Convert the checkpoint model indicated in mconfig into a diffusers, cache it to disk, and return Path to converted @@ -961,7 +961,7 @@ class ModelManager(object): self, weights: Path, mconfig: DictConfig - ) -> Tuple[Path, SDModelType.vae]: + ) -> Tuple[Path, AutoencoderKL]: # VAE handling is convoluted # 1. 
If there is a .vae.ckpt file sharing same stem as weights, then use # it as the vae_path passed to convert @@ -990,7 +990,7 @@ class ModelManager(object): vae_diffusers_location = "stabilityai/sd-vae-ft-mse" if vae_diffusers_location: - vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model + vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model return (None, vae_model) return (None, None) @@ -1038,7 +1038,7 @@ class ModelManager(object): vae_model = None if vae: vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id') - vae_model = self.cache.get_model(vae_location,SDModelType.vae).model + vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model vae_path = None convert_ckpt_to_diffusers( ckpt_path, @@ -1058,11 +1058,11 @@ class ModelManager(object): description=model_description, format="diffusers", ) - if self.model_exists(model_name, SDModelType.diffusers): - self.del_model(model_name, SDModelType.diffusers) + if self.model_exists(model_name, SDModelType.Diffusers): + self.del_model(model_name, SDModelType.Diffusers) self.add_model( model_name, - SDModelType.diffusers, + SDModelType.Diffusers, new_config, True ) diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts index 95d08db3e0..f1828d95e7 100644 --- a/invokeai/frontend/web/src/common/util/parseMetadata.ts +++ b/invokeai/frontend/web/src/common/util/parseMetadata.ts @@ -263,7 +263,7 @@ export const parseNodeMetadata = ( return; } - if ('unet' in nodeItem && 'tokenizer' in nodeItem) { + if ('unet' in nodeItem && 'scheduler' in nodeItem) { const unetField = parseUNetField(nodeItem); if (unetField) { parsed[nodeKey] = unetField;
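
A note on the central change above: SDModelType now inherits from str, so every member is interchangeable with its plain string value. That is what lets the rest of the patch drop the SDModelType[...] indexing and the .name/.value conversions when building cache keys, serializing node fields and events, and resolving submodels via getattr(). The snippet below is a minimal, self-contained sketch of the behaviour being relied on; the trimmed enum copy and the FakePipeline stand-in are illustrative only and not part of the patch:

from enum import Enum

class SDModelType(str, Enum):
    # same shape as the enum introduced in model_cache.py, trimmed to three members
    Diffusers = "diffusers"
    Vae = "vae"
    TextEncoder = "text_encoder"

class FakePipeline:
    # stand-in with a 'vae' attribute, loosely like a diffusers pipeline
    vae = "an AutoencoderKL instance"

# Members behave like their string values:
assert SDModelType.Vae == "vae"                                           # plain string comparison
assert SDModelType("vae") is SDModelType.Vae                              # parse_key() rebuilds a member from "vae/<name>"
assert "/".join([SDModelType.Diffusers, "sd-1.5"]) == "diffusers/sd-1.5"  # "type/name" keys, cf. create_key()
assert getattr(FakePipeline(), SDModelType.Vae) == "an AutoencoderKL instance"  # getattr(model, submodel) in get_model()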