Change SDModelType to a string enum; fixes: negative lock count on model unload, scheduler load error, safetensors conversion, wrong logic in del_model, wrong metadata parsing in the web UI

Sergey Borisov 2023-05-14 03:06:26 +03:00
parent 2204e47596
commit 039fa73269
8 changed files with 388 additions and 363 deletions
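The core of the change is switching SDModelType from a class-valued Enum to a string-valued one, so the type can round-trip through pydantic fields, event payloads, and "type/name" config keys. A minimal, self-contained sketch of the pattern with the member list shortened (the full enum is defined in the model cache diff below):

from enum import Enum

class SDModelType(str, Enum):
    Diffusers = "diffusers"
    Vae = "vae"

# A str-mixin Enum compares equal to its plain string value and can be
# reconstructed from it, which is what lets it pass through pydantic
# models, JSON event payloads and config keys unchanged.
assert SDModelType.Vae == "vae"
assert SDModelType("vae") is SDModelType.Vae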


@@ -59,19 +59,14 @@ class CompelInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        # TODO: load without model
         text_encoder_info = context.services.model_manager.get_model(
-            model_name=self.clip.text_encoder.model_name,
-            model_type=SDModelType[self.clip.text_encoder.model_type],
-            submodel=SDModelType[self.clip.text_encoder.submodel],
+            **self.clip.text_encoder.dict(),
         )
         tokenizer_info = context.services.model_manager.get_model(
-            model_name=self.clip.tokenizer.model_name,
-            model_type=SDModelType[self.clip.tokenizer.model_type],
-            submodel=SDModelType[self.clip.tokenizer.submodel],
+            **self.clip.tokenizer.dict(),
         )
-        with text_encoder_info.context as text_encoder,\
-             tokenizer_info.context as tokenizer:
+        with text_encoder_info as text_encoder,\
+             tokenizer_info as tokenizer:
             # TODO: global? input?
             #use_full_precision = precision == "float32" or precision == "autocast"
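Throughout the invocations, the object returned by get_model() is now entered directly ("with info as model:") instead of via its .context attribute. A minimal stand-in showing the shape of that contract; FakeModelLocker and the string model below are illustrative only, not InvokeAI classes:

class FakeModelLocker:
    """Illustrative stand-in for the locker object get_model() returns."""
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        # the real locker also moves the model onto the execution device here
        return self.model

    def __exit__(self, exc_type, exc, tb):
        # and here it decrements the lock count / offloads the model
        return False

with FakeModelLocker("text_encoder") as text_encoder:
    print(text_encoder)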


@@ -79,12 +79,8 @@ def get_scheduler(
     scheduler_info: ModelInfo,
     scheduler_name: str,
 ) -> Scheduler:
-    orig_scheduler_info = context.services.model_manager.get_model(
-        model_name=scheduler_info.model_name,
-        model_type=SDModelType[scheduler_info.model_type],
-        submodel=SDModelType[scheduler_info.submodel],
-    )
-    with orig_scheduler_info.context as orig_scheduler:
+    orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
+    with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
     scheduler_class = scheduler_map.get(scheduler_name,'ddim')
@@ -243,14 +239,8 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
-        #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        unet_info = context.services.model_manager.get_model(
-            model_name=self.unet.unet.model_name,
-            model_type=SDModelType[self.unet.unet.model_type],
-            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
-        )
-        with unet_info.context as unet:
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        with unet_info as unet:
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -309,12 +299,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
         unet_info = context.services.model_manager.get_model(
-            model_name=self.unet.unet.model_name,
-            model_type=SDModelType[self.unet.unet.model_type],
-            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
+            **self.unet.unet.dict(),
         )
-        with unet_info.context as unet:
+        with unet_info as unet:
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -379,18 +367,18 @@ class LatentsToImageInvocation(BaseInvocation):
         #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            model_name=self.vae.vae.model_name,
-            model_type=SDModelType[self.vae.vae.model_type],
-            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+            **self.vae.vae.dict(),
         )
-        with vae_info.context as vae:
-            # TODO: check if it works
+        with vae_info as vae:
             if self.tiled:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
+            # clear memory as vae decode can request a lot
+            torch.cuda.empty_cache()
             with torch.inference_mode():
                 # copied from diffusers pipeline
                 latents = latents / vae.config.scaling_factor
@@ -509,36 +497,29 @@ class ImageToLatentsInvocation(BaseInvocation):
         #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            model_name=self.vae.vae.model_name,
-            model_type=SDModelType[self.vae.vae.model_type],
-            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+            **self.vae.vae.dict(),
         )
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
-        with vae_info.context as vae:
-            # TODO: check if it works
+        with vae_info as vae:
             if self.tiled:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
-            latents = self.non_noised_latents_from_image(vae, image_tensor)
+            # non_noised_latents_from_image
+            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+            with torch.inference_mode():
+                image_tensor_dist = vae.encode(image_tensor).latent_dist
+                latents = image_tensor_dist.sample().to(
+                    dtype=vae.dtype
+                ) # FIXME: uses torch.randn. make reproducible!
+            latents = 0.18215 * latents
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, latents)
         return LatentsOutput(latents=LatentsField(latents_name=name))
-
-    def non_noised_latents_from_image(self, vae, init_image):
-        init_image = init_image.to(device=vae.device, dtype=vae.dtype)
-        with torch.inference_mode():
-            init_latent_dist = vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample().to(
-                dtype=vae.dtype
-            ) # FIXME: uses torch.randn. make reproducible!
-        init_latents = 0.18215 * init_latents
-        return init_latents
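The deleted non_noised_latents_from_image() helper is inlined above. A hedged, self-contained sketch of the same encode path using a diffusers VAE; the repo id is only an example and the input tensor is random rather than a real image:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

image_tensor = torch.rand(1, 3, 512, 512)                           # stand-in for the resized image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
    latents = vae.encode(image_tensor).latent_dist.sample().to(dtype=vae.dtype)
latents = 0.18215 * latents                                         # SD-1.x latent scaling factor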


@@ -8,8 +8,8 @@ from ...backend.model_management import SDModelType
 class ModelInfo(BaseModel):
     model_name: str = Field(description="Info to load unet submodel")
-    model_type: str = Field(description="Info to load unet submodel")
-    submodel: Optional[str] = Field(description="Info to load unet submodel")
+    model_type: SDModelType = Field(description="Info to load unet submodel")
+    submodel: Optional[SDModelType] = Field(description="Info to load unet submodel")

 class UNetField(BaseModel):
     unet: ModelInfo = Field(description="Info to load unet submodel")
@ -62,15 +62,15 @@ class ModelLoaderInvocation(BaseInvocation):
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.Diffusers,
): ):
raise Exception(f"Unkown model name: {self.model_name}!") raise Exception(f"Unkown model name: {self.model_name}!")
""" """
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.Diffusers,
submodel=SDModelType.tokenizer, submodel=SDModelType.Tokenizer,
): ):
raise Exception( raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
@ -78,8 +78,8 @@ class ModelLoaderInvocation(BaseInvocation):
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.Diffusers,
submodel=SDModelType.text_encoder, submodel=SDModelType.TextEncoder,
): ):
raise Exception( raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
@ -87,8 +87,8 @@ class ModelLoaderInvocation(BaseInvocation):
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers, model_type=SDModelType.Diffusers,
submodel=SDModelType.unet, submodel=SDModelType.UNet,
): ):
raise Exception( raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
@ -100,32 +100,32 @@ class ModelLoaderInvocation(BaseInvocation):
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers.name, model_type=SDModelType.Diffusers,
submodel=SDModelType.unet.name, submodel=SDModelType.UNet,
), ),
scheduler=ModelInfo( scheduler=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers.name, model_type=SDModelType.Diffusers,
submodel=SDModelType.scheduler.name, submodel=SDModelType.Scheduler,
), ),
), ),
clip=ClipField( clip=ClipField(
tokenizer=ModelInfo( tokenizer=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers.name, model_type=SDModelType.Diffusers,
submodel=SDModelType.tokenizer.name, submodel=SDModelType.Tokenizer,
), ),
text_encoder=ModelInfo( text_encoder=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers.name, model_type=SDModelType.Diffusers,
submodel=SDModelType.text_encoder.name, submodel=SDModelType.TextEncoder,
), ),
), ),
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
model_name=self.model_name, model_name=self.model_name,
model_type=SDModelType.diffusers.name, model_type=SDModelType.Diffusers,
submodel=SDModelType.vae.name, submodel=SDModelType.Vae,
), ),
) )
) )


@@ -120,7 +120,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=model_name,
-                model_type=model_type.name,
+                model_type=model_type,
                 submodel=submodel,
             ),
         )
@@ -143,7 +143,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=model_name,
-                model_type=model_type.name,
+                model_type=model_type,
                 submodel=submodel,
                 model_info=model_info,
             ),


@ -1,21 +1,25 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Union, Callable, List, Tuple, types from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management.model_manager import (
ModelManager, ModelManager,
SDModelType, SDModelType,
SDModelInfo, SDModelInfo,
torch,
) )
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from ...backend import Args,Globals # this must go when pr 3340 merged from ...backend import Args, Globals # this must go when pr 3340 merged
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@dataclass @dataclass
class LastUsedModel: class LastUsedModel:
model_name: str=None model_name: str=None
@ -28,9 +32,9 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,
config: Args, config: Args,
logger: types.ModuleType logger: types.ModuleType,
): ):
""" """
Initialize with the path to the models.yaml config file. Initialize with the path to the models.yaml config file.
@ -41,13 +45,14 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_model(self, def get_model(
model_name: str, self,
model_type: SDModelType=SDModelType.diffusers, model_name: str,
submodel: SDModelType=None, model_type: SDModelType = SDModelType.Diffusers,
node=None, # circular dependency issues, so untyped at moment submodel: Optional[SDModelType] = None,
context=None, node: Optional[BaseInvocation] = None,
)->SDModelInfo: context: Optional[InvocationContext] = None,
) -> SDModelInfo:
"""Retrieve the indicated model with name and type. """Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae) submodel can be used to get a part (such as the vae)
of a diffusers pipeline.""" of a diffusers pipeline."""
@ -60,14 +65,14 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
model_type: SDModelType model_type: SDModelType,
) -> bool: ) -> bool:
pass pass
@abstractmethod @abstractmethod
def default_model(self) -> Union[Tuple[str, SDModelType],None]: def default_model(self) -> Optional[Tuple[str, SDModelType]]:
""" """
Returns the name and typeof the default model, or None Returns the name and typeof the default model, or None
if none is defined. if none is defined.
@ -80,21 +85,21 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def model_info(self, model_name: str, model_type: SDModelType)->dict: def model_info(self, model_name: str, model_type: SDModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
pass pass
@abstractmethod @abstractmethod
def model_names(self)->List[Tuple[str, SDModelType]]: def model_names(self) -> List[Tuple[str, SDModelType]]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
pass pass
@abstractmethod @abstractmethod
def list_models(self)->dict: def list_models(self) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_key1: {'status': 'active'|'cached'|'not loaded', { model_key1: {'status': 'active'|'cached'|'not loaded',
@ -110,12 +115,12 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, model_type: SDModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False clobber: bool = False
)->None: ) -> None:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -126,10 +131,12 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def del_model(self, def del_model(
model_name: str, self,
model_type: SDModelType, model_name: str,
delete_files: bool = False): model_type: SDModelType,
delete_files: bool = False,
):
""" """
Delete the named model from configuration. If delete_files is true, Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
@ -140,9 +147,9 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def import_diffuser_model( def import_diffuser_model(
repo_or_path: Union[str, Path], repo_or_path: Union[str, Path],
model_name: str = None, model_name: Optional[str] = None,
description: str = None, description: Optional[str] = None,
vae: dict = None, vae: Optional[dict] = None,
) -> bool: ) -> bool:
""" """
Install the indicated diffuser model and returns True if successful. Install the indicated diffuser model and returns True if successful.
@ -157,10 +164,10 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def import_lora( def import_lora(
self, self,
path: Path, path: Path,
model_name: str=None, model_name: Optional[str] = None,
description: str=None, description: Optional[str] = None,
): ):
""" """
Creates an entry for the indicated lora file. Call Creates an entry for the indicated lora file. Call
@ -170,10 +177,10 @@ class ModelManagerServiceBase(ABC):
@abstractmethod @abstractmethod
def import_embedding( def import_embedding(
self, self,
path: Path, path: Path,
model_name: str=None, model_name: str=None,
description: str=None, description: str=None,
): ):
""" """
Creates an entry for the indicated textual inversion embedding file. Creates an entry for the indicated textual inversion embedding file.
@ -223,7 +230,7 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def commit(self, conf_file: Path=None) -> None: def commit(self, conf_file: Path = None) -> None:
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the If no conf_file is provided, then replaces the
@ -235,10 +242,10 @@ class ModelManagerServiceBase(ABC):
class ModelManagerService(ModelManagerServiceBase): class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory""" """Responsible for managing models on disk and in memory"""
def __init__( def __init__(
self, self,
config: Args, config: Args,
logger: types.ModuleType logger: types.ModuleType,
): ):
""" """
Initialize with the path to the models.yaml config file. Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models, Optional parameters are the torch device type, precision, max_models,
@ -255,7 +262,7 @@ class ModelManagerService(ModelManagerServiceBase):
logger.debug(f'config file={config_file}') logger.debug(f'config file={config_file}')
device = torch.device(choose_torch_device()) device = torch.device(choose_torch_device())
if config.precision=="auto": if config.precision == "auto":
precision = choose_precision(device) precision = choose_precision(device)
dtype = torch.float32 if precision=='float32' \ dtype = torch.float32 if precision=='float32' \
else torch.float16 else torch.float16
@ -272,22 +279,24 @@ class ModelManagerService(ModelManagerServiceBase):
sequential_offload = config.sequential_guidance sequential_offload = config.sequential_guidance
self.mgr = ModelManager(config=config_file, self.mgr = ModelManager(
device_type=device, config=config_file,
precision=dtype, device_type=device,
max_cache_size=max_cache_size, precision=dtype,
sequential_offload=sequential_offload, max_cache_size=max_cache_size,
logger=logger sequential_offload=sequential_offload,
) logger=logger,
)
logger.info('Model manager service initialized') logger.info('Model manager service initialized')
def get_model(self, def get_model(
model_name: str, self,
model_type: SDModelType=SDModelType.diffusers, model_name: str,
submodel: SDModelType=None, model_type: SDModelType = SDModelType.Diffusers,
node=None, submodel: Optional[SDModelType] = None,
context=None, node: Optional[BaseInvocation] = None,
)->SDModelInfo: context: Optional[InvocationContext] = None,
) -> SDModelInfo:
""" """
Retrieve the indicated model. submodel can be used to get a Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode. part (such as the vae) of a diffusers mode.
@ -340,9 +349,9 @@ class ModelManagerService(ModelManagerServiceBase):
return model_info return model_info
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
model_type: SDModelType model_type: SDModelType,
) -> bool: ) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid
@ -350,32 +359,33 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
return self.mgr.model_exists( return self.mgr.model_exists(
model_name, model_name,
model_type) model_type,
)
def default_model(self) -> Union[Tuple[str, SDModelType],None]: def default_model(self) -> Optional[Tuple[str, SDModelType]]:
""" """
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
""" """
return self.mgr.default_model() return self.mgr.default_model()
def set_default_model(self, model_name:str, model_type: SDModelType): def set_default_model(self, model_name: str, model_type: SDModelType):
"""Sets the default model to the indicated name.""" """Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name) self.mgr.set_default_model(model_name)
def model_info(self, model_name: str, model_type: SDModelType)->dict: def model_info(self, model_name: str, model_type: SDModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
return self.mgr.model_info(model_name) return self.mgr.model_info(model_name)
def model_names(self)->List[Tuple[str, SDModelType]]: def model_names(self) -> List[Tuple[str, SDModelType]]:
""" """
Returns a list of all the model names known. Returns a list of all the model names known.
""" """
return self.mgr.model_names() return self.mgr.model_names()
def list_models(self)->dict: def list_models(self) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_key: {'status': 'active'|'cached'|'not loaded', { model_key: {'status': 'active'|'cached'|'not loaded',
@ -388,11 +398,12 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.list_models() return self.mgr.list_models()
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType, model_type: SDModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False)->None: clobber: bool = False,
)->None:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -400,14 +411,15 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber) return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
def del_model(self, def del_model(
model_name: str, self,
model_type: SDModelType=SDModelType.diffusers, model_name: str,
delete_files: bool = False model_type: SDModelType = SDModelType.Diffusers,
): delete_files: bool = False,
):
""" """
Delete the named model from configuration. If delete_files is true, Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
@ -416,11 +428,11 @@ class ModelManagerService(ModelManagerServiceBase):
self.mgr.del_model(model_name, model_type, delete_files) self.mgr.del_model(model_name, model_type, delete_files)
def import_diffuser_model( def import_diffuser_model(
self, self,
repo_or_path: Union[str, Path], repo_or_path: Union[str, Path],
model_name: str = None, model_name: Optional[str] = None,
description: str = None, description: Optional[str] = None,
vae: dict = None, vae: Optional[dict] = None,
) -> bool: ) -> bool:
""" """
Install the indicated diffuser model and returns True if successful. Install the indicated diffuser model and returns True if successful.
@ -431,13 +443,13 @@ class ModelManagerService(ModelManagerServiceBase):
You can optionally provide a model name and/or description. If not provided, You can optionally provide a model name and/or description. If not provided,
then these will be derived from the repo name. Call commit() to write to disk. then these will be derived from the repo name. Call commit() to write to disk.
""" """
return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae) return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
def import_lora( def import_lora(
self, self,
path: Path, path: Path,
model_name: str=None, model_name: Optional[str] = None,
description: str=None, description: Optional[str] = None,
): ):
""" """
Creates an entry for the indicated lora file. Call Creates an entry for the indicated lora file. Call
@ -446,10 +458,10 @@ class ModelManagerService(ModelManagerServiceBase):
self.mgr.import_lora(path, model_name, description) self.mgr.import_lora(path, model_name, description)
def import_embedding( def import_embedding(
self, self,
path: Path, path: Path,
model_name: str=None, model_name: Optional[str] = None,
description: str=None, description: Optional[str] = None,
): ):
""" """
Creates an entry for the indicated textual inversion embedding file. Creates an entry for the indicated textual inversion embedding file.
@ -462,9 +474,9 @@ class ModelManagerService(ModelManagerServiceBase):
path_url_or_repo: str, path_url_or_repo: str,
model_name: str = None, model_name: str = None,
description: str = None, description: str = None,
model_config_file: Path = None, model_config_file: Optional[Path] = None,
commit_to_conf: Path = None, commit_to_conf: Optional[Path] = None,
config_file_callback: Callable[[Path], Path] = None, config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> str: ) -> str:
"""Accept a string which could be: """Accept a string which could be:
- a HF diffusers repo_id - a HF diffusers repo_id
@ -505,7 +517,7 @@ class ModelManagerService(ModelManagerServiceBase):
) )
def commit(self, conf_file: Path=None): def commit(self, conf_file: Optional[Path]=None):
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the If no conf_file is provided, then replaces the
@ -514,16 +526,16 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.commit(conf_file) return self.mgr.commit(conf_file)
def _emit_load_event( def _emit_load_event(
self, self,
node, node,
context, context,
model_name: str, model_name: str,
model_type: SDModelType, model_type: SDModelType,
submodel: SDModelType, submodel: SDModelType,
model_info: SDModelInfo=None, model_info: Optional[SDModelInfo] = None,
): ):
if context.services.queue.is_canceled(context.graph_execution_state_id): if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id] source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if context: if context:
@ -536,7 +548,7 @@ class ModelManagerService(ModelManagerServiceBase):
submodel=submodel, submodel=submodel,
) )
else: else:
context.services.events.emit_model_load_completed ( context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(), node=node.dict(),
source_node_id=source_node_id, source_node_id=source_node_id,


@@ -23,12 +23,12 @@ import warnings
 from collections import Counter
 from enum import Enum
 from pathlib import Path
-from typing import Dict, Sequence, Union, Tuple, types
+from typing import Dict, Sequence, Union, Tuple, types, Optional

 import torch
 import safetensors.torch
-from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
+from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel, ConfigMixin
 from diffusers import logging as diffusers_logging
 from diffusers.pipelines.stable_diffusion.safety_checker import \
     StableDiffusionSafetyChecker
@@ -55,20 +55,38 @@ class LoraType(dict):
 class TIType(dict):
     pass

-class SDModelType(Enum):
-    diffusers=StableDiffusionGeneratorPipeline  # whole pipeline
-    vae=AutoencoderKL                           # diffusers parts
-    text_encoder=CLIPTextModel
-    tokenizer=CLIPTokenizer
-    unet=UNet2DConditionModel
-    scheduler=SchedulerMixin
-    safety_checker=StableDiffusionSafetyChecker
-    feature_extractor=CLIPFeatureExtractor
+class SDModelType(str, Enum):
+    Diffusers="diffusers"                # whole pipeline
+    Vae="vae"                            # diffusers parts
+    TextEncoder="text_encoder"
+    Tokenizer="tokenizer"
+    UNet="unet"
+    Scheduler="scheduler"
+    SafetyChecker="safety_checker"
+    FeatureExtractor="feature_extractor"

     # These are all loaded as dicts of tensors, and we
     # distinguish them by class
-    lora=LoraType
-    textual_inversion=TIType
+    Lora="lora"
+    TextualInversion="textual_inversion"

+# TODO:
+class EmptyScheduler(SchedulerMixin, ConfigMixin):
+    pass
+
+MODEL_CLASSES = {
+    SDModelType.Diffusers: StableDiffusionGeneratorPipeline,
+    SDModelType.Vae: AutoencoderKL,
+    SDModelType.TextEncoder: CLIPTextModel,  # TODO: t5
+    SDModelType.Tokenizer: CLIPTokenizer,    # TODO: t5
+    SDModelType.UNet: UNet2DConditionModel,
+    SDModelType.Scheduler: EmptyScheduler,
+    SDModelType.SafetyChecker: StableDiffusionSafetyChecker,
+    SDModelType.FeatureExtractor: CLIPFeatureExtractor,
+    SDModelType.Lora: LoraType,
+    SDModelType.TextualInversion: TIType,
+}

 class ModelStatus(Enum):
     unknown='unknown'
     not_loaded='not loaded'
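With the classes moved out of the enum values and into MODEL_CLASSES, callers can carry the model type around as a plain string and only resolve a concrete class at load time; the Scheduler entry pointing at EmptyScheduler(SchedulerMixin, ConfigMixin) appears to be the "scheduler load error" fix, since scheduler loading needs the ConfigMixin machinery. A reduced sketch with stand-in classes (the real map points at the diffusers/transformers classes listed above):

from enum import Enum

class SDModelType(str, Enum):
    Vae = "vae"
    Tokenizer = "tokenizer"

class FakeAutoencoderKL: pass      # stand-ins for the real
class FakeCLIPTokenizer: pass      # diffusers / transformers classes

MODEL_CLASSES = {
    SDModelType.Vae: FakeAutoencoderKL,
    SDModelType.Tokenizer: FakeCLIPTokenizer,
}

model_type = SDModelType("vae")            # parsed straight from a string key
model_class = MODEL_CLASSES[model_type]    # concrete class resolved only at load time
print(model_class.__name__)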
@@ -80,21 +98,21 @@ class ModelStatus(Enum):
 # After loading, we will know it exactly.
 # Sizes are in Gigs, estimated for float16; double for float32
 SIZE_GUESSTIMATE = {
-    SDModelType.diffusers: 2.2,
-    SDModelType.vae: 0.35,
-    SDModelType.text_encoder: 0.5,
-    SDModelType.tokenizer: 0.001,
-    SDModelType.unet: 3.4,
-    SDModelType.scheduler: 0.001,
-    SDModelType.safety_checker: 1.2,
-    SDModelType.feature_extractor: 0.001,
-    SDModelType.lora: 0.1,
-    SDModelType.textual_inversion: 0.001,
+    SDModelType.Diffusers: 2.2,
+    SDModelType.Vae: 0.35,
+    SDModelType.TextEncoder: 0.5,
+    SDModelType.Tokenizer: 0.001,
+    SDModelType.UNet: 3.4,
+    SDModelType.Scheduler: 0.001,
+    SDModelType.SafetyChecker: 1.2,
+    SDModelType.FeatureExtractor: 0.001,
+    SDModelType.Lora: 0.1,
+    SDModelType.TextualInversion: 0.001,
 }

 # The list of model classes we know how to fetch, for typechecking
-ModelClass = Union[tuple([x.value for x in SDModelType])]
-DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)
+ModelClass = Union[tuple([x for x in MODEL_CLASSES.values()])]
+DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, EmptyScheduler, UNet2DConditionModel)
class UnsafeModelException(Exception): class UnsafeModelException(Exception):
"Raised when a legacy model file fails the picklescan test" "Raised when a legacy model file fails the picklescan test"
@ -110,15 +128,15 @@ class ModelLocker(object):
class ModelCache(object): class ModelCache(object):
def __init__( def __init__(
self, self,
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE, max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
execution_device: torch.device=torch.device('cuda'), execution_device: torch.device=torch.device('cuda'),
storage_device: torch.device=torch.device('cpu'), storage_device: torch.device=torch.device('cpu'),
precision: torch.dtype=torch.float16, precision: torch.dtype=torch.float16,
sequential_offload: bool=False, sequential_offload: bool=False,
lazy_offloading: bool=True, lazy_offloading: bool=True,
sha_chunksize: int = 16777216, sha_chunksize: int = 16777216,
logger: types.ModuleType = logger logger: types.ModuleType = logger
): ):
''' '''
:param max_models: Maximum number of models to cache in CPU RAM [4] :param max_models: Maximum number of models to cache in CPU RAM [4]
@ -145,15 +163,15 @@ class ModelCache(object):
self.model_sizes: Dict[str,int] = dict() self.model_sizes: Dict[str,int] = dict()
def get_model( def get_model(
self, self,
repo_id_or_path: Union[str,Path], repo_id_or_path: Union[str, Path],
model_type: SDModelType=SDModelType.diffusers, model_type: SDModelType = SDModelType.Diffusers,
subfolder: Path=None, subfolder: Path = None,
submodel: SDModelType=None, submodel: SDModelType = None,
revision: str=None, revision: str = None,
attach_model_part: Tuple[SDModelType, str] = (None,None), attach_model_part: Tuple[SDModelType, str] = (None, None),
gpu_load: bool=True, gpu_load: bool = True,
)->ModelLocker: # ?? what does it return ) -> ModelLocker: # ?? what does it return
''' '''
Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching. Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
Use like this: Use like this:
@ -178,14 +196,14 @@ class ModelCache(object):
vae_context = cache.get_model( vae_context = cache.get_model(
'stabilityai/sd-stable-diffusion-2', 'stabilityai/sd-stable-diffusion-2',
submodel=SDModelType.vae submodel=SDModelType.Vae
) )
This is equivalent to: This is equivalent to:
vae_context = cache.get_model( vae_context = cache.get_model(
'stabilityai/sd-stable-diffusion-2', 'stabilityai/sd-stable-diffusion-2',
model_type = SDModelType.vae, model_type = SDModelType.Vae,
subfolder='vae' subfolder='vae'
) )
@ -195,14 +213,14 @@ class ModelCache(object):
pipeline_context = cache.get_model( pipeline_context = cache.get_model(
'runwayml/stable-diffusion-v1-5', 'runwayml/stable-diffusion-v1-5',
attach_model_part=(SDModelType.vae,'stabilityai/sd-vae-ft-mse') attach_model_part=(SDModelType.Vae,'stabilityai/sd-vae-ft-mse')
) )
The model will be locked into GPU VRAM for the duration of the context. The model will be locked into GPU VRAM for the duration of the context.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param model_type: An SDModelType enum indicating the type of the (parent) model :param model_type: An SDModelType enum indicating the type of the (parent) model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae" :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.Vae
:param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id) :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
:param revision: model revision :param revision: model revision
:param gpu_load: load the model into GPU [default True] :param gpu_load: load the model into GPU [default True]
@ -211,7 +229,7 @@ class ModelCache(object):
repo_id_or_path, repo_id_or_path,
revision, revision,
subfolder, subfolder,
model_type.value, model_type,
) )
# optimization: if caller is asking to load a submodel of a diffusers pipeline, then # optimization: if caller is asking to load a submodel of a diffusers pipeline, then
@ -221,11 +239,11 @@ class ModelCache(object):
repo_id_or_path, repo_id_or_path,
None, None,
revision, revision,
SDModelType.diffusers.value SDModelType.Diffusers
) )
if possible_parent_key in self.models: if possible_parent_key in self.models:
key = possible_parent_key key = possible_parent_key
submodel=model_type submodel = model_type
# Look for the model in the cache RAM # Look for the model in the cache RAM
if key in self.models: # cached - move to bottom of stack (most recently used) if key in self.models: # cached - move to bottom of stack (most recently used)
@@ -256,24 +274,24 @@
             self.current_cache_size += mem_used   # increment size of the cache

             # this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
-            if model_type==SDModelType.diffusers and attach_model_part[0]:
-                self.attach_part(model,*attach_model_part)
+            if model_type == SDModelType.Diffusers and attach_model_part[0]:
+                self.attach_part(model, *attach_model_part)

             self.stack.append(key)          # add to LRU cache
-            self.models[key]=model          # keep copy of model in dict
+            self.models[key] = model        # keep copy of model in dict

         if submodel:
-            model = getattr(model, submodel.name)
+            model = getattr(model, submodel)

         return self.ModelLocker(self, key, model, gpu_load)

     def uncache_model(self, key: str):
         '''Remove corresponding model from the cache'''
         if key is not None and key in self.models:
-            with contextlib.suppress(ValueError), contextlib.suppress(KeyError):
-                del self.models[key]
-                del self.locked_models[key]
-                self.loaded_models.remove(key)
+            self.models.pop(key, None)
+            self.locked_models.pop(key, None)
+            self.loaded_models.discard(key)
+            with contextlib.suppress(ValueError):
                 self.stack.remove(key)

     class ModelLocker(object):
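The rewritten uncache_model() makes eviction idempotent: dict.pop() with a default and set.discard() are no-ops for missing keys, so only list.remove() still needs the suppress. A runnable sketch of just that bookkeeping, with module-level stand-ins for the cache attributes:

import contextlib

models, locked_models, loaded_models, stack = {}, {}, set(), []

def uncache_model(key: str) -> None:
    models.pop(key, None)            # no KeyError if the key is already gone
    locked_models.pop(key, None)
    loaded_models.discard(key)       # set.discard() never raises
    with contextlib.suppress(ValueError):
        stack.remove(key)            # list.remove() is the only call that can raise

uncache_model("never-cached-key")    # safe to call twice or for unknown keys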
@@ -302,7 +320,7 @@ class ModelCache(object):
                 if model.device != cache.execution_device:
                     cache.logger.debug(f'Moving {key} into {cache.execution_device}')
                     with VRAMUsage() as mem:
-                        model.to(cache.execution_device)  # move into GPU
+                        model.to(cache.execution_device, dtype=cache.precision)  # move into GPU
                     cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
                     cache.model_sizes[key] = mem.vram_used  # more accurate size
@@ -312,13 +330,16 @@
             else:
                 # in the event that the caller wants the model in RAM, we
                 # move it into CPU if it is in GPU and not locked
-                if hasattr(model,'to') and (key in cache.loaded_models
+                if hasattr(model, 'to') and (key in cache.loaded_models
                                             and cache.locked_models[key] == 0):
                     model.to(cache.storage_device)
                     cache.loaded_models.remove(key)

             return model

         def __exit__(self, type, value, traceback):
+            if not hasattr(self.model, 'to'):
+                return
+
             key = self.key
             cache = self.cache
             cache.locked_models[key] -= 1
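This early return is the "negative locks count" fix from the commit message: dict-like models (LoRA / textual inversion) have no .to() method and are presumably never lock-counted on __enter__, so __exit__ now skips the decrement for them instead of driving the counter below zero. A reduced sketch of the pairing, with the cache reduced to a bare counter; MiniLocker is illustrative, not the real ModelLocker:

from collections import Counter

class MiniLocker:
    """Reduced sketch; the real ModelLocker also handles device moves."""
    def __init__(self, locked_models: Counter, key: str, model):
        self.locked_models, self.key, self.model = locked_models, key, model

    def __enter__(self):
        if not hasattr(self.model, 'to'):
            return self.model                 # dict-like models are never locked
        self.locked_models[self.key] += 1
        return self.model

    def __exit__(self, exc_type, exc, tb):
        if not hasattr(self.model, 'to'):
            return                            # nothing was locked, nothing to undo
        self.locked_models[self.key] -= 1

locks = Counter()
with MiniLocker(locks, "lora-key", {"weights": None}):
    pass
assert locks["lora-key"] == 0                 # previously this could have gone to -1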
@ -326,11 +347,12 @@ class ModelCache(object):
cache._offload_unlocked_models() cache._offload_unlocked_models()
cache._print_cuda_stats() cache._print_cuda_stats()
def attach_part(self, def attach_part(
diffusers_model: StableDiffusionPipeline, self,
part_type: SDModelType, diffusers_model: StableDiffusionPipeline,
part_id: str part_type: SDModelType,
): part_id: str,
):
''' '''
Attach a diffusers model part to a diffusers model. This can be Attach a diffusers model part to a diffusers model. This can be
used to replace the VAE, tokenizer, textencoder, unet, etc. used to replace the VAE, tokenizer, textencoder, unet, etc.
@ -338,27 +360,26 @@ class ModelCache(object):
:param part_type: An SD ModelType indicating the part :param part_type: An SD ModelType indicating the part
:param part_id: A HF repo_id for the part :param part_id: A HF repo_id for the part
''' '''
part_key = part_type.name
part_class = part_type.value
part = self._load_diffusers_from_storage( part = self._load_diffusers_from_storage(
part_id, part_id,
model_class=part_class, model_class=MODEL_CLASSES[part_type],
) )
part.to(diffusers_model.device) part.to(diffusers_model.device)
setattr(diffusers_model,part_key,part) setattr(diffusers_model, part_type, part)
self.logger.debug(f'Attached {part_key} {part_id}') self.logger.debug(f'Attached {part_type} {part_id}')
def status(self, def status(
repo_id_or_path: Union[str,Path], self,
model_type: SDModelType=SDModelType.diffusers, repo_id_or_path: Union[str, Path],
revision: str=None, model_type: SDModelType = SDModelType.Diffusers,
subfolder: Path=None, revision: str = None,
)->ModelStatus: subfolder: Path = None,
) -> ModelStatus:
key = self._model_key( key = self._model_key(
repo_id_or_path, repo_id_or_path,
revision, revision,
subfolder, subfolder,
model_type.value, model_type,
) )
if key not in self.models: if key not in self.models:
return ModelStatus.not_loaded return ModelStatus.not_loaded
@ -370,9 +391,11 @@ class ModelCache(object):
else: else:
return ModelStatus.in_ram return ModelStatus.in_ram
def model_hash(self, def model_hash(
repo_id_or_path: Union[str,Path], self,
revision: str="main")->str: repo_id_or_path: Union[str, Path],
revision: str = "main",
) -> str:
''' '''
Given the HF repo id or path to a model on disk, returns a unique Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
@ -385,7 +408,7 @@ class ModelCache(object):
else: else:
return self._hf_commit_hash(repo_id_or_path,revision) return self._hf_commit_hash(repo_id_or_path,revision)
def cache_size(self)->float: def cache_size(self) -> float:
"Return the current size of the cache, in GB" "Return the current size of the cache, in GB"
return self.current_cache_size / GIG return self.current_cache_size / GIG
@ -407,10 +430,15 @@ class ModelCache(object):
logger.debug("Model scanned ok") logger.debug("Model scanned ok")
@staticmethod @staticmethod
def _model_key(path,revision,subfolder,model_class)->str: def _model_key(path, revision, subfolder, model_class) -> str:
return ':'.join([str(path),str(revision or ''),str(subfolder or ''),model_class.__name__]) return ':'.join([
str(path),
str(revision or ''),
str(subfolder or ''),
model_class,
])
def _has_cuda(self)->bool: def _has_cuda(self) -> bool:
return self.execution_device.type == 'cuda' return self.execution_device.type == 'cuda'
def _print_cuda_stats(self): def _print_cuda_stats(self):
@ -450,43 +478,43 @@ class ModelCache(object):
self.loaded_models.remove(key) self.loaded_models.remove(key)
def _load_model_from_storage( def _load_model_from_storage(
self, self,
repo_id_or_path: Union[str,Path], repo_id_or_path: Union[str, Path],
subfolder: Path=None, subfolder: Optional[Path] = None,
revision: str=None, revision: Optional[str] = None,
model_type: SDModelType=SDModelType.diffusers, model_type: SDModelType = SDModelType.Diffusers,
)->ModelClass: ) -> ModelClass:
''' '''
Load and return a HuggingFace model. Load and return a HuggingFace model.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae" :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param revision: model revision :param revision: model revision
:param model_type: type of model to return, defaults to SDModelType.diffusers :param model_type: type of model to return, defaults to SDModelType.Diffusers
''' '''
# silence transformer and diffuser warnings # silence transformer and diffuser warnings
with SilenceWarnings(): with SilenceWarnings():
if model_type==SDModelType.lora: if model_type==SDModelType.Lora:
model = self._load_lora_from_storage(repo_id_or_path) model = self._load_lora_from_storage(repo_id_or_path)
elif model_type==SDModelType.textual_inversion: elif model_type==SDModelType.TextualInversion:
model = self._load_ti_from_storage(repo_id_or_path) model = self._load_ti_from_storage(repo_id_or_path)
else: else:
model = self._load_diffusers_from_storage( model = self._load_diffusers_from_storage(
repo_id_or_path, repo_id_or_path,
subfolder, subfolder,
revision, revision,
model_type.value, model_type,
) )
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline): if self.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.execution_device) model.enable_offload_submodels(self.execution_device)
return model return model
def _load_diffusers_from_storage( def _load_diffusers_from_storage(
self, self,
repo_id_or_path: Union[str,Path], repo_id_or_path: Union[str, Path],
subfolder: Path=None, subfolder: Optional[Path] = None,
revision: str=None, revision: Optional[str] = None,
model_class: ModelClass=StableDiffusionGeneratorPipeline, model_type: ModelClass = StableDiffusionGeneratorPipeline,
)->ModelClass: ) -> ModelClass:
''' '''
Load and return a HuggingFace model using from_pretrained(). Load and return a HuggingFace model using from_pretrained().
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
@ -494,17 +522,26 @@ class ModelCache(object):
:param revision: model revision :param revision: model revision
:param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
''' '''
revisions = [revision] if revision \
else ['fp16','main'] if self.precision==torch.float16 \ model_class = MODEL_CLASSES[model_type]
else ['main']
extra_args = {'torch_dtype': self.precision, if revision is not None:
'safety_checker': None}\ revisions = [revision]
if model_class in DiffusionClasses\ elif self.precision == torch.float16:
else {} revisions = ['fp16', 'main']
else:
revisions = ['main']
extra_args = dict()
if model_class in DiffusionClasses:
extra_args = dict(
torch_dtype=self.precision,
safety_checker=None,
)
for rev in revisions: for rev in revisions:
try: try:
model = model_class.from_pretrained( model = model_class.from_pretrained(
repo_id_or_path, repo_id_or_path,
revision=rev, revision=rev,
subfolder=subfolder or '.', subfolder=subfolder or '.',
@ -517,13 +554,13 @@ class ModelCache(object):
pass pass
return model return model
def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value: def _load_lora_from_storage(self, lora_path: Path) -> LoraType:
assert False,"_load_lora_from_storage() is not yet implemented" assert False, "_load_lora_from_storage() is not yet implemented"
def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value: def _load_ti_from_storage(self, lora_path: Path) -> TIType:
assert False,"_load_ti_from_storage() is not yet implemented" assert False, "_load_ti_from_storage() is not yet implemented"
def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str: def _legacy_model_hash(self, checkpoint_path: Union[str, Path]) -> str:
sha = hashlib.sha256() sha = hashlib.sha256()
path = Path(checkpoint_path) path = Path(checkpoint_path)
assert path.is_file(),f"File {checkpoint_path} not found" assert path.is_file(),f"File {checkpoint_path} not found"
@ -544,7 +581,7 @@ class ModelCache(object):
f.write(hash) f.write(hash)
return hash return hash
def _local_model_hash(self, model_path: Union[str,Path])->str: def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256() sha = hashlib.sha256()
path = Path(model_path) path = Path(model_path)
@ -566,7 +603,7 @@ class ModelCache(object):
f.write(hash) f.write(hash)
return hash return hash
def _hf_commit_hash(self, repo_id: str, revision: str='main')->str: def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str:
api = HfApi() api = HfApi()
info = api.list_repo_refs( info = api.list_repo_refs(
repo_id=repo_id, repo_id=repo_id,
@ -578,7 +615,7 @@ class ModelCache(object):
return desired_revisions[0].target_commit return desired_revisions[0].target_commit
@staticmethod @staticmethod
def calc_model_size(model)->int: def calc_model_size(model) -> int:
if isinstance(model,DiffusionPipeline): if isinstance(model,DiffusionPipeline):
return ModelCache._calc_pipeline(model) return ModelCache._calc_pipeline(model)
elif isinstance(model,torch.nn.Module): elif isinstance(model,torch.nn.Module):
@ -587,7 +624,7 @@ class ModelCache(object):
return None return None
@staticmethod @staticmethod
def _calc_pipeline(pipeline)->int: def _calc_pipeline(pipeline) -> int:
res = 0 res = 0
for submodel_key in pipeline.components.keys(): for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key) submodel = getattr(pipeline, submodel_key)
@ -596,7 +633,7 @@ class ModelCache(object):
return res return res
@staticmethod @staticmethod
def _calc_model(model)->int: def _calc_model(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()]) mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()]) mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes mem = mem_params + mem_bufs # in bytes


@ -27,7 +27,7 @@ Typical usage:
max_cache_size=8 max_cache_size=8
) # gigabytes ) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers) model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers)
with model_info.context as my_model: with model_info.context as my_model:
my_model.latents_from_embeddings(...) my_model.latents_from_embeddings(...)
@ -45,7 +45,7 @@ parameter:
model_info = manager.get_model( model_info = manager.get_model(
'clip-tokenizer', 'clip-tokenizer',
model_type=SDModelType.tokenizer model_type=SDModelType.Tokenizer
) )
This will raise an InvalidModelError if the format defined in the This will raise an InvalidModelError if the format defined in the
@ -96,7 +96,7 @@ SUBMODELS:
It is also possible to fetch an isolated submodel from a diffusers It is also possible to fetch an isolated submodel from a diffusers
model. Use the `submodel` parameter to select which part: model. Use the `submodel` parameter to select which part:
vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.vae) vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae)
with vae.context as my_vae: with vae.context as my_vae:
print(type(my_vae)) print(type(my_vae))
# "AutoencoderKL" # "AutoencoderKL"
@ -120,8 +120,8 @@ separated by "/". Example:
You can now use the `model_type` argument to indicate which model you You can now use the `model_type` argument to indicate which model you
want: want:
tokenizer = mgr.get('clip-large',model_type=SDModelType.tokenizer) tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer)
encoder = mgr.get('clip-large',model_type=SDModelType.text_encoder) encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder)
OTHER FUNCTIONS: OTHER FUNCTIONS:
@ -254,7 +254,7 @@ class ModelManager(object):
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
model_type: SDModelType = SDModelType.diffusers, model_type: SDModelType = SDModelType.Diffusers,
) -> bool: ) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid
@@ -264,28 +264,28 @@ class ModelManager(object):
         return model_key in self.config

     def create_key(self, model_name: str, model_type: SDModelType) -> str:
-        return f"{model_type.name}/{model_name}"
+        return f"{model_type}/{model_name}"

     def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
         model_type_str, model_name = model_key.split('/', 1)
-        if model_type_str not in SDModelType.__members__:
-            # TODO:
-            raise Exception(f"Unknown model type: {model_type_str}")
-        return (model_name, SDModelType[model_type_str])
+        try:
+            model_type = SDModelType(model_type_str)
+            return (model_name, model_type)
+        except:
+            raise Exception(f"Unknown model type: {model_type_str}")

     def get_model(
         self,
         model_name: str,
-        model_type: SDModelType=SDModelType.diffusers,
-        submodel: SDModelType=None,
+        model_type: SDModelType = SDModelType.Diffusers,
+        submodel: Optional[SDModelType] = None,
     ) -> SDModelInfo:
         """Given a model named identified in models.yaml, return
         an SDModelInfo object describing it.

         :param model_name: symbolic name of the model in models.yaml
         :param model_type: SDModelType enum indicating the type of model to return
         :param submodel: an SDModelType enum indicating the portion of
-               the model to retrieve (e.g. SDModelType.vae)
+               the model to retrieve (e.g. SDModelType.Vae)

         If not provided, the model_type will be read from the `format` field
         of the corresponding stanza. If provided, the model_type will be used
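Because the enum is now string-valued, config keys can be built from and parsed back by value. A self-contained sketch of the create_key/parse_key round trip; it uses .value explicitly so the behaviour does not depend on how a given Python version formats str-mixin enums:

from enum import Enum
from typing import Tuple

class SDModelType(str, Enum):
    Diffusers = "diffusers"
    Vae = "vae"

def create_key(model_name: str, model_type: SDModelType) -> str:
    return f"{model_type.value}/{model_name}"

def parse_key(model_key: str) -> Tuple[str, SDModelType]:
    model_type_str, model_name = model_key.split('/', 1)
    return model_name, SDModelType(model_type_str)   # lookup by value; ValueError if unknown

assert parse_key(create_key("stable-diffusion-1.5", SDModelType.Diffusers)) == (
    "stable-diffusion-1.5",
    SDModelType.Diffusers,
)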
@ -304,17 +304,17 @@ class ModelManager(object):
test1_pipeline = mgr.get_model('test1') test1_pipeline = mgr.get_model('test1')
# returns a StableDiffusionGeneratorPipeline # returns a StableDiffusionGeneratorPipeline
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae) test1_vae1 = mgr.get_model('test1', submodel=SDModelType.Vae)
# returns the VAE part of a diffusers model as an AutoencoderKL # returns the VAE part of a diffusers model as an AutoencoderKL
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae) test1_vae2 = mgr.get_model('test1', model_type=SDModelType.Diffusers, submodel=SDModelType.Vae)
# does the same thing as the previous statement. Note that model_type # does the same thing as the previous statement. Note that model_type
# is for the parent model, and submodel is for the part # is for the parent model, and submodel is for the part
test1_lora = mgr.get_model('test1', model_type=SDModelType.lora) test1_lora = mgr.get_model('test1', model_type=SDModelType.Lora)
# returns a LoRA embed (as a 'dict' of tensors) # returns a LoRA embed (as a 'dict' of tensors)
test1_encoder = mgr.get_model('test1', model_type=SDModelType.text_encoder) test1_encoder = mgr.get_model('test1', model_type=SDModelType.TextEncoder)
# raises an InvalidModelError # raises an InvalidModelError
""" """
@ -332,10 +332,10 @@ class ModelManager(object):
mconfig = self.config[model_key] mconfig = self.config[model_key]
# type already checked as it's part of key # type already checked as it's part of key
if model_type == SDModelType.diffusers: if model_type == SDModelType.Diffusers:
# intercept stanzas that point to checkpoint weights and replace them # intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model # with the equivalent diffusers model
if mconfig.format in ["ckpt", "diffusers"]: if mconfig.format in ["ckpt", "safetensors"]:
location = self.convert_ckpt_and_cache(mconfig) location = self.convert_ckpt_and_cache(mconfig)
else: else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id') location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
@ -355,13 +355,13 @@ class ModelManager(object):
vae = (None, None) vae = (None, None)
with suppress(Exception): with suppress(Exception):
vae_id = mconfig.vae.repo_id vae_id = mconfig.vae.repo_id
vae = (SDModelType.vae, vae_id) vae = (SDModelType.Vae, vae_id)
# optimization - don't load whole model if the user # optimization - don't load whole model if the user
# is asking for just a piece of it # is asking for just a piece of it
if model_type == SDModelType.diffusers and submodel and not subfolder: if model_type == SDModelType.Diffusers and submodel and not subfolder:
model_type = submodel model_type = submodel
subfolder = submodel.name subfolder = submodel.value
submodel = None submodel = None
model_context = self.cache.get_model( model_context = self.cache.get_model(
@ -390,7 +390,7 @@ class ModelManager(object):
_cache = self.cache _cache = self.cache
) )
def default_model(self) -> Union[Tuple[str, SDModelType],None]: def default_model(self) -> Optional[Tuple[str, SDModelType]]:
""" """
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
@ -401,7 +401,7 @@ class ModelManager(object):
return (model_name, model_type) return (model_name, model_type)
return self.model_names()[0][0] return self.model_names()[0]
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None: def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.Diffusers) -> None:
""" """
Set the default model. The change will not take Set the default model. The change will not take
effect until you call model_manager.commit() effect until you call model_manager.commit()
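A short sketch of the default-model round trip described here (same `mgr` and 'test1' assumptions as in the earlier examples):

    mgr.set_default_model('test1', SDModelType.Diffusers)
    mgr.commit()                       # per the docstring, the change only takes effect after commit()
    print(mgr.default_model())         # expected: ('test1', SDModelType.Diffusers)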
@ -415,25 +415,25 @@ class ModelManager(object):
config[self.create_key(model_name, model_type)]["default"] = True config[self.create_key(model_name, model_type)]["default"] = True
def model_info( def model_info(
self, self,
model_name: str, model_name: str,
model_type: SDModelType=SDModelType.diffusers model_type: SDModelType=SDModelType.Diffusers,
) -> dict: ) -> dict:
""" """
Given a model name returns the OmegaConf (dict-like) object describing it. Given a model name returns the OmegaConf (dict-like) object describing it.
""" """
if not self.exists(model_name, model_type): if not self.exists(model_name, model_type):
return None return None
return self.config[self.create_key(model_name,model_type)] return self.config[self.create_key(model_name, model_type)]
def model_names(self) -> List[Tuple[str, SDModelType]]: def model_names(self) -> List[Tuple[str, SDModelType]]:
""" """
Return a list of (str, SDModelType) corresponding to all models Return a list of (str, SDModelType) corresponding to all models
known to the configuration. known to the configuration.
""" """
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)] return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]
def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool: def is_legacy(self, model_name: str, model_type: SDModelType = SDModelType.Diffusers) -> bool:
""" """
Return true if this is a legacy (.ckpt) model Return true if this is a legacy (.ckpt) model
""" """
@ -461,14 +461,14 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style) # don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]: if "config" in stanza and "/VAE/" in stanza["config"]:
continue continue
if model_key=='config_file_version': if model_key == 'config_file_version':
continue continue
model_name, model_type = self.parse_key(model_key) model_name, model_type = self.parse_key(model_key)
models[model_key] = dict() models[model_key] = dict()
# TODO: return all models in future # TODO: return all models in future
if model_type != SDModelType.diffusers: if model_type != SDModelType.Diffusers:
continue continue
model_format = stanza.get('format') model_format = stanza.get('format')
@ -477,15 +477,15 @@ class ModelManager(object):
status = self.cache.status( status = self.cache.status(
stanza.get('weights') or stanza.get('repo_id'), stanza.get('weights') or stanza.get('repo_id'),
revision=stanza.get('revision'), revision=stanza.get('revision'),
subfolder=stanza.get('subfolder') subfolder=stanza.get('subfolder'),
) )
description = stanza.get("description", None) description = stanza.get("description", None)
models[model_key].update( models[model_key].update(
model_name=model_name, model_name=model_name,
model_type=model_type.name, model_type=model_type,
format=model_format, format=model_format,
description=description, description=description,
status=status.value status=status.value,
) )
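The switch above from `model_type.name` to passing the member straight through leans on the new enum behaving like a plain string. A self-contained sketch of that property, assuming the (str, Enum) mixin reading of the commit title:

    import json
    from enum import Enum

    class SDModelTypeSketch(str, Enum):
        Diffusers = "diffusers"

    listing = dict(model_name="test1", model_type=SDModelTypeSketch.Diffusers, format="diffusers")
    print(json.dumps(listing))   # {"model_name": "test1", "model_type": "diffusers", "format": "diffusers"}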
@ -528,8 +528,8 @@ class ModelManager(object):
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType.diffusers, model_type: SDModelType = SDModelType.Diffusers,
delete_files: bool = False delete_files: bool = False,
): ):
""" """
Delete the named model. Delete the named model.
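A sketch of the corresponding call (names assumed as in the earlier examples):

    mgr.del_model('my-diffusers-model', SDModelType.Diffusers, delete_files=False)
    mgr.commit()   # persist the removal to models.yaml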
@ -539,9 +539,9 @@ class ModelManager(object):
if model_cfg is None: if model_cfg is None:
self.logger.error( self.logger.error(
f"Unknown model {model_key}" f"Unknown model {model_key}"
) )
return return
# TODO: some legacy? # TODO: some legacy?
#if model_name in self.stack: #if model_name in self.stack:
@ -571,7 +571,7 @@ class ModelManager(object):
model_name: str, model_name: str,
model_type: SDModelType, model_type: SDModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False clobber: bool = False,
) -> None: ) -> None:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
@ -581,7 +581,7 @@ class ModelManager(object):
attributes are incorrect or the model name is missing. attributes are incorrect or the model name is missing.
""" """
if model_type == SDModelType.diffusers: if model_type == SDModelType.Diffusers:
# TODO: automatically or manually? # TODO: automatically or manually?
#assert "format" in model_attributes, 'missing required field "format"' #assert "format" in model_attributes, 'missing required field "format"'
model_format = "ckpt" if "weights" in model_attributes else "diffusers" model_format = "ckpt" if "weights" in model_attributes else "diffusers"
@ -647,16 +647,16 @@ class ModelManager(object):
else: else:
new_config.update(repo_id=repo_or_path) new_config.update(repo_id=repo_or_path)
self.add_model(model_name, SDModelType.diffusers, new_config, True) self.add_model(model_name, SDModelType.Diffusers, new_config, True)
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
return self.create_key(model_name, SDModelType.diffusers) return self.create_key(model_name, SDModelType.Diffusers)
def import_lora( def import_lora(
self, self,
path: Path, path: Path,
model_name: str=None, model_name: Optional[str] = None,
description: str=None, description: Optional[str] = None,
): ):
""" """
Creates an entry for the indicated lora file. Call Creates an entry for the indicated lora file. Call
@ -667,7 +667,7 @@ class ModelManager(object):
model_description = description or f"LoRA model {model_name}" model_description = description or f"LoRA model {model_name}"
self.add_model( self.add_model(
model_name, model_name,
SDModelType.lora, SDModelType.Lora,
dict( dict(
format="lora", format="lora",
weights=str(path), weights=str(path),
@ -679,8 +679,8 @@ class ModelManager(object):
def import_embedding( def import_embedding(
self, self,
path: Path, path: Path,
model_name: str=None, model_name: Optional[str] = None,
description: str=None, description: Optional[str] = None,
): ):
""" """
Creates an entry for the indicated embedding file. Call Creates an entry for the indicated embedding file. Call
@ -696,7 +696,7 @@ class ModelManager(object):
model_description = description or f"Textual embedding model {model_name}" model_description = description or f"Textual embedding model {model_name}"
self.add_model( self.add_model(
model_name, model_name,
SDModelType.textual_inversion, SDModelType.TextualInversion,
dict( dict(
format="textual_inversion", format="textual_inversion",
weights=str(weights), weights=str(weights),
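Example calls for the two import helpers (file paths and names are hypothetical; the generated descriptions come from the code above when none is supplied):

    from pathlib import Path

    mgr.import_lora(Path('loras/detail-tweaker.safetensors'))
    mgr.import_embedding(Path('embeddings/bad-hands.pt'), model_name='bad-hands')
    mgr.commit()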
@ -746,11 +746,11 @@ class ModelManager(object):
def heuristic_import( def heuristic_import(
self, self,
path_url_or_repo: str, path_url_or_repo: str,
model_name: str = None, model_name: Optional[str] = None,
description: str = None, description: Optional[str] = None,
model_config_file: Path = None, model_config_file: Optional[Path] = None,
commit_to_conf: Path = None, commit_to_conf: Optional[Path] = None,
config_file_callback: Callable[[Path], Path] = None, config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> str: ) -> str:
"""Accept a string which could be: """Accept a string which could be:
- a HF diffusers repo_id - a HF diffusers repo_id
@ -927,7 +927,7 @@ class ModelManager(object):
) )
return model_name return model_name
def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path: def convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
""" """
Convert the checkpoint model indicated in mconfig into a Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted diffusers, cache it to disk, and return Path to converted
@ -961,7 +961,7 @@ class ModelManager(object):
self, self,
weights: Path, weights: Path,
mconfig: DictConfig mconfig: DictConfig
) -> Tuple[Path, SDModelType.vae]: ) -> Tuple[Path, AutoencoderKL]:
# VAE handling is convoluted # VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert # it as the vae_path passed to convert
@ -990,7 +990,7 @@ class ModelManager(object):
vae_diffusers_location = "stabilityai/sd-vae-ft-mse" vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
if vae_diffusers_location: if vae_diffusers_location:
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model
return (None, vae_model) return (None, vae_model)
return (None, None) return (None, None)
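A minimal, pure-Python sketch of the "sibling VAE by stem" check described in the comment above (file names hypothetical):

    from pathlib import Path

    weights = Path('models/my-model.ckpt')
    sibling_vae = weights.parent / f"{weights.stem}.vae.ckpt"   # models/my-model.vae.ckpt
    vae_path = sibling_vae if sibling_vae.exists() else None
    # when vae_path stays None, the code above falls back to the diffusers VAE
    # "stabilityai/sd-vae-ft-mse" loaded through the model cache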
@ -1038,7 +1038,7 @@ class ModelManager(object):
vae_model = None vae_model = None
if vae: if vae:
vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id') vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
vae_model = self.cache.get_model(vae_location,SDModelType.vae).model vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model
vae_path = None vae_path = None
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(
ckpt_path, ckpt_path,
@ -1058,11 +1058,11 @@ class ModelManager(object):
description=model_description, description=model_description,
format="diffusers", format="diffusers",
) )
if self.model_exists(model_name, SDModelType.diffusers): if self.model_exists(model_name, SDModelType.Diffusers):
self.del_model(model_name, SDModelType.diffusers) self.del_model(model_name, SDModelType.Diffusers)
self.add_model( self.add_model(
model_name, model_name,
SDModelType.diffusers, SDModelType.Diffusers,
new_config, new_config,
True True
) )

View File

@ -263,7 +263,7 @@ export const parseNodeMetadata = (
return; return;
} }
if ('unet' in nodeItem && 'tokenizer' in nodeItem) { if ('unet' in nodeItem && 'scheduler' in nodeItem) {
const unetField = parseUNetField(nodeItem); const unetField = parseUNetField(nodeItem);
if (unetField) { if (unetField) {
parsed[nodeKey] = unetField; parsed[nodeKey] = unetField;