mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Change SDModelType enum to string; fixes: model unload negative locks count, scheduler load error, safetensors convert, wrong logic in del_model, wrong metadata parsing in web
This commit is contained in:
parent 2204e47596
commit 039fa73269
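The sketch below is not part of the diff; it only illustrates the central idea of the change, assuming pydantic v1-style `.dict()` as used in this codebase: once SDModelType is a str-valued enum and ModelInfo's fields are typed with it, the fields survive a dict() round trip and call sites no longer need SDModelType[...] name lookups.

from enum import Enum
from typing import Optional

from pydantic import BaseModel


class SDModelType(str, Enum):
    Diffusers = "diffusers"
    UNet = "unet"


class ModelInfo(BaseModel):
    model_name: str
    model_type: SDModelType
    submodel: Optional[SDModelType] = None


# "some-model" is a placeholder name for illustration only
info = ModelInfo(model_name="some-model", model_type=SDModelType.Diffusers, submodel=SDModelType.UNet)
payload = info.dict()              # values are str-valued enum members
restored = ModelInfo(**payload)    # pydantic coerces them back to SDModelType
assert restored.model_type is SDModelType.Diffusers
assert SDModelType.UNet == "unet"  # plain string comparison also works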
@@ -59,19 +59,14 @@ class CompelInvocation(BaseInvocation):
 
     def invoke(self, context: InvocationContext) -> CompelOutput:
 
-        # TODO: load without model
         text_encoder_info = context.services.model_manager.get_model(
-            model_name=self.clip.text_encoder.model_name,
-            model_type=SDModelType[self.clip.text_encoder.model_type],
-            submodel=SDModelType[self.clip.text_encoder.submodel],
+            **self.clip.text_encoder.dict(),
         )
         tokenizer_info = context.services.model_manager.get_model(
-            model_name=self.clip.tokenizer.model_name,
-            model_type=SDModelType[self.clip.tokenizer.model_type],
-            submodel=SDModelType[self.clip.tokenizer.submodel],
+            **self.clip.tokenizer.dict(),
        )
-        with text_encoder_info.context as text_encoder,\
-             tokenizer_info.context as tokenizer:
+        with text_encoder_info as text_encoder,\
+             tokenizer_info as tokenizer:
 
            # TODO: global? input?
            #use_full_precision = precision == "float32" or precision == "autocast"
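The hunk above now forwards the ModelInfo fields with **dict() and treats the object returned by get_model() itself as a context manager, instead of entering its .context attribute. A rough sketch of the interface those call sites assume follows; the class and function names here are hypothetical stand-ins, not the real InvokeAI implementation.

from typing import Any


class FakeModelInfo:
    """Hypothetical stand-in for the object get_model() returns."""

    def __init__(self, model: Any):
        self._model = model

    def __enter__(self):
        # the real implementation would lock the model into VRAM here
        return self._model

    def __exit__(self, exc_type, exc, tb):
        # ...and release the lock here
        return False


def get_model(**kwargs) -> FakeModelInfo:
    # stand-in: look the model up from kwargs and wrap it in the locker object
    return FakeModelInfo(model=object())


text_encoder_info = get_model(model_name="demo", model_type="diffusers", submodel="text_encoder")
with text_encoder_info as text_encoder:   # no .context indirection any more
    pass                                  # use the locked model here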
@@ -79,12 +79,8 @@ def get_scheduler(
     scheduler_info: ModelInfo,
     scheduler_name: str,
 ) -> Scheduler:
-    orig_scheduler_info = context.services.model_manager.get_model(
-        model_name=scheduler_info.model_name,
-        model_type=SDModelType[scheduler_info.model_type],
-        submodel=SDModelType[scheduler_info.submodel],
-    )
-    with orig_scheduler_info.context as orig_scheduler:
+    orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
+    with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
     scheduler_class = scheduler_map.get(scheduler_name,'ddim')
@@ -243,14 +239,8 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        unet_info = context.services.model_manager.get_model(
-            model_name=self.unet.unet.model_name,
-            model_type=SDModelType[self.unet.unet.model_type],
-            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
-        )
-
-        with unet_info.context as unet:
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        with unet_info as unet:
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -309,12 +299,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
         #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
         unet_info = context.services.model_manager.get_model(
-            model_name=self.unet.unet.model_name,
-            model_type=SDModelType[self.unet.unet.model_type],
-            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
+            **self.unet.unet.dict(),
         )
 
-        with unet_info.context as unet:
+        with unet_info as unet:
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -379,18 +367,18 @@ class LatentsToImageInvocation(BaseInvocation):
 
         #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            model_name=self.vae.vae.model_name,
-            model_type=SDModelType[self.vae.vae.model_type],
-            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+            **self.vae.vae.dict(),
        )
 
-        with vae_info.context as vae:
-            # TODO: check if it works
+        with vae_info as vae:
            if self.tiled:
                vae.enable_tiling()
            else:
                vae.disable_tiling()
 
+            # clear memory as vae decode can request a lot
+            torch.cuda.empty_cache()
+
            with torch.inference_mode():
                # copied from diffusers pipeline
                latents = latents / vae.config.scaling_factor
@@ -509,36 +497,29 @@ class ImageToLatentsInvocation(BaseInvocation):
 
         #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            model_name=self.vae.vae.model_name,
-            model_type=SDModelType[self.vae.vae.model_type],
-            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+            **self.vae.vae.dict(),
        )
 
        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
-        with vae_info.context as vae:
-            # TODO: check if it works
+        with vae_info as vae:
            if self.tiled:
                vae.enable_tiling()
            else:
                vae.disable_tiling()
 
-            latents = self.non_noised_latents_from_image(vae, image_tensor)
+            # non_noised_latents_from_image
+            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+            with torch.inference_mode():
+                image_tensor_dist = vae.encode(image_tensor).latent_dist
+                latents = image_tensor_dist.sample().to(
+                    dtype=vae.dtype
+                ) # FIXME: uses torch.randn. make reproducible!
+
+                latents = 0.18215 * latents
 
        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.set(name, latents)
        return LatentsOutput(latents=LatentsField(latents_name=name))
-
-
-    def non_noised_latents_from_image(self, vae, init_image):
-        init_image = init_image.to(device=vae.device, dtype=vae.dtype)
-        with torch.inference_mode():
-            init_latent_dist = vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample().to(
-                dtype=vae.dtype
-            ) # FIXME: uses torch.randn. make reproducible!
-
-            init_latents = 0.18215 * init_latents
-            return init_latents
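The ImageToLatentsInvocation hunk above inlines the removed non_noised_latents_from_image() helper. A hedged sketch of that encode path as a standalone function, assuming a diffusers AutoencoderKL; the 0.18215 constant is the SD-1.x latent scaling factor that the decode side reads from vae.config.scaling_factor.

import torch
from diffusers import AutoencoderKL


@torch.inference_mode()
def image_to_latents(vae: AutoencoderKL, image_tensor: torch.Tensor) -> torch.Tensor:
    """Encode a (1, 3, H, W) image tensor into SD latents (illustrative sketch)."""
    image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(image_tensor).latent_dist
    latents = latent_dist.sample().to(dtype=vae.dtype)  # stochastic: uses torch.randn
    return 0.18215 * latents  # SD-1.x scaling factor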
@@ -8,8 +8,8 @@ from ...backend.model_management import SDModelType
 
 class ModelInfo(BaseModel):
     model_name: str = Field(description="Info to load unet submodel")
-    model_type: str = Field(description="Info to load unet submodel")
-    submodel: Optional[str] = Field(description="Info to load unet submodel")
+    model_type: SDModelType = Field(description="Info to load unet submodel")
+    submodel: Optional[SDModelType] = Field(description="Info to load unet submodel")
 
 class UNetField(BaseModel):
     unet: ModelInfo = Field(description="Info to load unet submodel")
@@ -62,15 +62,15 @@ class ModelLoaderInvocation(BaseInvocation):
         # TODO: not found exceptions
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
+            model_type=SDModelType.Diffusers,
         ):
             raise Exception(f"Unkown model name: {self.model_name}!")
 
         """
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
-            submodel=SDModelType.tokenizer,
+            model_type=SDModelType.Diffusers,
+            submodel=SDModelType.Tokenizer,
         ):
             raise Exception(
                 f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
@@ -78,8 +78,8 @@ class ModelLoaderInvocation(BaseInvocation):
 
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
-            submodel=SDModelType.text_encoder,
+            model_type=SDModelType.Diffusers,
+            submodel=SDModelType.TextEncoder,
         ):
             raise Exception(
                 f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
@@ -87,8 +87,8 @@ class ModelLoaderInvocation(BaseInvocation):
 
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
-            submodel=SDModelType.unet,
+            model_type=SDModelType.Diffusers,
+            submodel=SDModelType.UNet,
         ):
             raise Exception(
                 f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
@@ -100,32 +100,32 @@ class ModelLoaderInvocation(BaseInvocation):
             unet=UNetField(
                 unet=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.unet.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.scheduler.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.Scheduler,
                 ),
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.tokenizer.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.text_encoder.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.TextEncoder,
                 ),
             ),
             vae=VaeField(
                 vae=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.vae.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.Vae,
                 ),
             )
         )
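With the string-valued enum, the loader output above and the EventServiceBase payloads in the hunks below can carry SDModelType members directly (model_type=model_type rather than model_type=model_type.name), which is what the web metadata parsing fix relies on. A minimal, self-contained illustration of that serialization behaviour, not taken from the diff:

import json
from enum import Enum


class SDModelType(str, Enum):
    Diffusers = "diffusers"
    UNet = "unet"


# "some-model" is a placeholder; a str-mixin enum member serializes as its value
event_payload = {"model_name": "some-model", "model_type": SDModelType.Diffusers, "submodel": SDModelType.UNet}
print(json.dumps(event_payload))
# -> {"model_name": "some-model", "model_type": "diffusers", "submodel": "unet"}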
@@ -120,7 +120,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=model_name,
-                model_type=model_type.name,
+                model_type=model_type,
                 submodel=submodel,
             ),
         )
@@ -143,7 +143,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=model_name,
-                model_type=model_type.name,
+                model_type=model_type,
                 submodel=submodel,
                 model_info=model_info,
             ),
@@ -1,21 +1,25 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 
+from __future__ import annotations
+
+import torch
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union, Callable, List, Tuple, types
+from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
 from dataclasses import dataclass
 
 from invokeai.backend.model_management.model_manager import (
     ModelManager,
     SDModelType,
     SDModelInfo,
-    torch,
 )
 from invokeai.app.models.exceptions import CanceledException
-from ...backend import Args,Globals # this must go when pr 3340 merged
+from ...backend import Args, Globals # this must go when pr 3340 merged
 from ...backend.util import choose_precision, choose_torch_device
 
+if TYPE_CHECKING:
+    from ..invocations.baseinvocation import BaseInvocation, InvocationContext
+
 @dataclass
 class LastUsedModel:
     model_name: str=None
@@ -28,9 +32,9 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def __init__(
         self,
         config: Args,
-        logger: types.ModuleType
+        logger: types.ModuleType,
     ):
         """
         Initialize with the path to the models.yaml config file.
@@ -41,13 +45,14 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def get_model(self,
-                  model_name: str,
-                  model_type: SDModelType=SDModelType.diffusers,
-                  submodel: SDModelType=None,
-                  node=None, # circular dependency issues, so untyped at moment
-                  context=None,
-                  )->SDModelInfo:
+    def get_model(
+        self,
+        model_name: str,
+        model_type: SDModelType = SDModelType.Diffusers,
+        submodel: Optional[SDModelType] = None,
+        node: Optional[BaseInvocation] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> SDModelInfo:
         """Retrieve the indicated model with name and type.
         submodel can be used to get a part (such as the vae)
         of a diffusers pipeline."""
@@ -60,14 +65,14 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def model_exists(
         self,
         model_name: str,
-        model_type: SDModelType
+        model_type: SDModelType,
     ) -> bool:
         pass
 
     @abstractmethod
-    def default_model(self) -> Union[Tuple[str, SDModelType],None]:
+    def default_model(self) -> Optional[Tuple[str, SDModelType]]:
         """
         Returns the name and typeof the default model, or None
         if none is defined.
@@ -80,21 +85,21 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def model_info(self, model_name: str, model_type: SDModelType)->dict:
+    def model_info(self, model_name: str, model_type: SDModelType) -> dict:
         """
         Given a model name returns a dict-like (OmegaConf) object describing it.
         """
         pass
 
     @abstractmethod
-    def model_names(self)->List[Tuple[str, SDModelType]]:
+    def model_names(self) -> List[Tuple[str, SDModelType]]:
         """
         Returns a list of all the model names known.
         """
         pass
 
     @abstractmethod
-    def list_models(self)->dict:
+    def list_models(self) -> dict:
         """
         Return a dict of models in the format:
         { model_key1: {'status': 'active'|'cached'|'not loaded',
@@ -110,12 +115,12 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def add_model(
         self,
         model_name: str,
         model_type: SDModelType,
         model_attributes: dict,
         clobber: bool = False
-    )->None:
+    ) -> None:
         """
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -126,10 +131,12 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def del_model(self,
-                  model_name: str,
-                  model_type: SDModelType,
-                  delete_files: bool = False):
+    def del_model(
+        self,
+        model_name: str,
+        model_type: SDModelType,
+        delete_files: bool = False,
+    ):
         """
         Delete the named model from configuration. If delete_files is true,
         then the underlying weight file or diffusers directory will be deleted
@@ -140,9 +147,9 @@ class ModelManagerServiceBase(ABC):
     @abstractmethod
     def import_diffuser_model(
         repo_or_path: Union[str, Path],
-        model_name: str = None,
-        description: str = None,
-        vae: dict = None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
+        vae: Optional[dict] = None,
     ) -> bool:
         """
         Install the indicated diffuser model and returns True if successful.
@@ -157,10 +164,10 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def import_lora(
         self,
         path: Path,
-        model_name: str=None,
-        description: str=None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
     ):
         """
         Creates an entry for the indicated lora file. Call
@@ -170,10 +177,10 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def import_embedding(
         self,
         path: Path,
         model_name: str=None,
         description: str=None,
     ):
         """
         Creates an entry for the indicated textual inversion embedding file.
@@ -223,7 +230,7 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def commit(self, conf_file: Path=None) -> None:
+    def commit(self, conf_file: Path = None) -> None:
         """
         Write current configuration out to the indicated file.
         If no conf_file is provided, then replaces the
@@ -235,10 +242,10 @@ class ModelManagerService(ModelManagerServiceBase):
 class ModelManagerService(ModelManagerServiceBase):
     """Responsible for managing models on disk and in memory"""
     def __init__(
         self,
         config: Args,
-        logger: types.ModuleType
+        logger: types.ModuleType,
     ):
         """
         Initialize with the path to the models.yaml config file.
         Optional parameters are the torch device type, precision, max_models,
@@ -255,7 +262,7 @@ class ModelManagerService(ModelManagerServiceBase):
         logger.debug(f'config file={config_file}')
 
         device = torch.device(choose_torch_device())
-        if config.precision=="auto":
+        if config.precision == "auto":
             precision = choose_precision(device)
         dtype = torch.float32 if precision=='float32' \
             else torch.float16
@@ -272,22 +279,24 @@ class ModelManagerService(ModelManagerServiceBase):
 
         sequential_offload = config.sequential_guidance
 
-        self.mgr = ModelManager(config=config_file,
-                                device_type=device,
-                                precision=dtype,
-                                max_cache_size=max_cache_size,
-                                sequential_offload=sequential_offload,
-                                logger=logger
-                                )
+        self.mgr = ModelManager(
+            config=config_file,
+            device_type=device,
+            precision=dtype,
+            max_cache_size=max_cache_size,
+            sequential_offload=sequential_offload,
+            logger=logger,
+        )
         logger.info('Model manager service initialized')
 
-    def get_model(self,
-                  model_name: str,
-                  model_type: SDModelType=SDModelType.diffusers,
-                  submodel: SDModelType=None,
-                  node=None,
-                  context=None,
-                  )->SDModelInfo:
+    def get_model(
+        self,
+        model_name: str,
+        model_type: SDModelType = SDModelType.Diffusers,
+        submodel: Optional[SDModelType] = None,
+        node: Optional[BaseInvocation] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> SDModelInfo:
         """
         Retrieve the indicated model. submodel can be used to get a
         part (such as the vae) of a diffusers mode.
@@ -340,9 +349,9 @@ class ModelManagerService(ModelManagerServiceBase):
         return model_info
 
     def model_exists(
         self,
         model_name: str,
-        model_type: SDModelType
+        model_type: SDModelType,
     ) -> bool:
         """
         Given a model name, returns True if it is a valid
@@ -350,32 +359,33 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         return self.mgr.model_exists(
             model_name,
-            model_type)
+            model_type,
+        )
 
-    def default_model(self) -> Union[Tuple[str, SDModelType],None]:
+    def default_model(self) -> Optional[Tuple[str, SDModelType]]:
         """
         Returns the name of the default model, or None
         if none is defined.
         """
         return self.mgr.default_model()
 
-    def set_default_model(self, model_name:str, model_type: SDModelType):
+    def set_default_model(self, model_name: str, model_type: SDModelType):
         """Sets the default model to the indicated name."""
         self.mgr.set_default_model(model_name)
 
-    def model_info(self, model_name: str, model_type: SDModelType)->dict:
+    def model_info(self, model_name: str, model_type: SDModelType) -> dict:
         """
         Given a model name returns a dict-like (OmegaConf) object describing it.
         """
         return self.mgr.model_info(model_name)
 
-    def model_names(self)->List[Tuple[str, SDModelType]]:
+    def model_names(self) -> List[Tuple[str, SDModelType]]:
         """
         Returns a list of all the model names known.
         """
         return self.mgr.model_names()
 
-    def list_models(self)->dict:
+    def list_models(self) -> dict:
         """
         Return a dict of models in the format:
         { model_key: {'status': 'active'|'cached'|'not loaded',
@@ -388,11 +398,12 @@ class ModelManagerService(ModelManagerServiceBase):
         return self.mgr.list_models()
 
     def add_model(
         self,
         model_name: str,
         model_type: SDModelType,
         model_attributes: dict,
-        clobber: bool = False)->None:
+        clobber: bool = False,
+    )->None:
         """
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -400,14 +411,15 @@ class ModelManagerService(ModelManagerServiceBase):
         with an assertion error if provided attributes are incorrect or
         the model name is missing. Call commit() to write changes to disk.
         """
-        return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
+        return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
 
 
-    def del_model(self,
-                  model_name: str,
-                  model_type: SDModelType=SDModelType.diffusers,
-                  delete_files: bool = False
-                  ):
+    def del_model(
+        self,
+        model_name: str,
+        model_type: SDModelType = SDModelType.Diffusers,
+        delete_files: bool = False,
+    ):
         """
         Delete the named model from configuration. If delete_files is true,
         then the underlying weight file or diffusers directory will be deleted
@@ -416,11 +428,11 @@ class ModelManagerService(ModelManagerServiceBase):
         self.mgr.del_model(model_name, model_type, delete_files)
 
     def import_diffuser_model(
         self,
         repo_or_path: Union[str, Path],
-        model_name: str = None,
-        description: str = None,
-        vae: dict = None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
+        vae: Optional[dict] = None,
     ) -> bool:
         """
         Install the indicated diffuser model and returns True if successful.
@@ -431,13 +443,13 @@ class ModelManagerService(ModelManagerServiceBase):
         You can optionally provide a model name and/or description. If not provided,
         then these will be derived from the repo name. Call commit() to write to disk.
         """
         return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
 
     def import_lora(
         self,
         path: Path,
-        model_name: str=None,
-        description: str=None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
     ):
         """
         Creates an entry for the indicated lora file. Call
@@ -446,10 +458,10 @@ class ModelManagerService(ModelManagerServiceBase):
         self.mgr.import_lora(path, model_name, description)
 
     def import_embedding(
         self,
         path: Path,
-        model_name: str=None,
-        description: str=None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
     ):
         """
         Creates an entry for the indicated textual inversion embedding file.
@@ -462,9 +474,9 @@ class ModelManagerService(ModelManagerServiceBase):
         path_url_or_repo: str,
         model_name: str = None,
         description: str = None,
-        model_config_file: Path = None,
-        commit_to_conf: Path = None,
-        config_file_callback: Callable[[Path], Path] = None,
+        model_config_file: Optional[Path] = None,
+        commit_to_conf: Optional[Path] = None,
+        config_file_callback: Optional[Callable[[Path], Path]] = None,
     ) -> str:
         """Accept a string which could be:
            - a HF diffusers repo_id
@@ -505,7 +517,7 @@ class ModelManagerService(ModelManagerServiceBase):
         )
 
 
-    def commit(self, conf_file: Path=None):
+    def commit(self, conf_file: Optional[Path]=None):
         """
         Write current configuration out to the indicated file.
         If no conf_file is provided, then replaces the
@@ -514,16 +526,16 @@ class ModelManagerService(ModelManagerServiceBase):
         return self.mgr.commit(conf_file)
 
     def _emit_load_event(
         self,
         node,
         context,
         model_name: str,
         model_type: SDModelType,
         submodel: SDModelType,
-        model_info: SDModelInfo=None,
+        model_info: Optional[SDModelInfo] = None,
    ):
         if context.services.queue.is_canceled(context.graph_execution_state_id):
-            raise CanceledException
+            raise CanceledException()
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
         source_node_id = graph_execution_state.prepared_source_mapping[node.id]
         if context:
@@ -536,7 +548,7 @@ class ModelManagerService(ModelManagerServiceBase):
                 submodel=submodel,
             )
         else:
-            context.services.events.emit_model_load_completed (
+            context.services.events.emit_model_load_completed(
                 graph_execution_state_id=context.graph_execution_state_id,
                 node=node.dict(),
                 source_node_id=source_node_id,
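The service module above now imports BaseInvocation and InvocationContext only for type checking. A small sketch of that pattern, assuming the same module paths as the diff; the runtime-free import is guarded by TYPE_CHECKING and the function body is a placeholder:

from __future__ import annotations

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # resolved only by the type checker, so no circular import at runtime
    from ..invocations.baseinvocation import BaseInvocation, InvocationContext


def get_model(
    model_name: str,
    node: Optional[BaseInvocation] = None,
    context: Optional[InvocationContext] = None,
):
    ...  # placeholder body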
@ -23,12 +23,12 @@ import warnings
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Sequence, Union, Tuple, types
|
from typing import Dict, Sequence, Union, Tuple, types, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
|
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel, ConfigMixin
|
||||||
from diffusers import logging as diffusers_logging
|
from diffusers import logging as diffusers_logging
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
||||||
StableDiffusionSafetyChecker
|
StableDiffusionSafetyChecker
|
||||||
@ -55,20 +55,38 @@ class LoraType(dict):
|
|||||||
class TIType(dict):
|
class TIType(dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class SDModelType(Enum):
|
class SDModelType(str, Enum):
|
||||||
diffusers=StableDiffusionGeneratorPipeline # whole pipeline
|
Diffusers="diffusers" # whole pipeline
|
||||||
vae=AutoencoderKL # diffusers parts
|
Vae="vae" # diffusers parts
|
||||||
text_encoder=CLIPTextModel
|
TextEncoder="text_encoder"
|
||||||
tokenizer=CLIPTokenizer
|
Tokenizer="tokenizer"
|
||||||
unet=UNet2DConditionModel
|
UNet="unet"
|
||||||
scheduler=SchedulerMixin
|
Scheduler="scheduler"
|
||||||
safety_checker=StableDiffusionSafetyChecker
|
SafetyChecker="safety_checker"
|
||||||
feature_extractor=CLIPFeatureExtractor
|
FeatureExtractor="feature_extractor"
|
||||||
# These are all loaded as dicts of tensors, and we
|
# These are all loaded as dicts of tensors, and we
|
||||||
# distinguish them by class
|
# distinguish them by class
|
||||||
lora=LoraType
|
Lora="lora"
|
||||||
textual_inversion=TIType
|
TextualInversion="textual_inversion"
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
class EmptyScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
MODEL_CLASSES = {
|
||||||
|
SDModelType.Diffusers: StableDiffusionGeneratorPipeline,
|
||||||
|
SDModelType.Vae: AutoencoderKL,
|
||||||
|
SDModelType.TextEncoder: CLIPTextModel, # TODO: t5
|
||||||
|
SDModelType.Tokenizer: CLIPTokenizer, # TODO: t5
|
||||||
|
SDModelType.UNet: UNet2DConditionModel,
|
||||||
|
SDModelType.Scheduler: EmptyScheduler,
|
||||||
|
SDModelType.SafetyChecker: StableDiffusionSafetyChecker,
|
||||||
|
SDModelType.FeatureExtractor: CLIPFeatureExtractor,
|
||||||
|
|
||||||
|
SDModelType.Lora: LoraType,
|
||||||
|
SDModelType.TextualInversion: TIType,
|
||||||
|
}
|
||||||
|
|
||||||
class ModelStatus(Enum):
|
class ModelStatus(Enum):
|
||||||
unknown='unknown'
|
unknown='unknown'
|
||||||
not_loaded='not loaded'
|
not_loaded='not loaded'
|
||||||
@ -80,21 +98,21 @@ class ModelStatus(Enum):
|
|||||||
# After loading, we will know it exactly.
|
# After loading, we will know it exactly.
|
||||||
# Sizes are in Gigs, estimated for float16; double for float32
|
# Sizes are in Gigs, estimated for float16; double for float32
|
||||||
SIZE_GUESSTIMATE = {
|
SIZE_GUESSTIMATE = {
|
||||||
SDModelType.diffusers: 2.2,
|
SDModelType.Diffusers: 2.2,
|
||||||
SDModelType.vae: 0.35,
|
SDModelType.Vae: 0.35,
|
||||||
SDModelType.text_encoder: 0.5,
|
SDModelType.TextEncoder: 0.5,
|
||||||
SDModelType.tokenizer: 0.001,
|
SDModelType.Tokenizer: 0.001,
|
||||||
SDModelType.unet: 3.4,
|
SDModelType.UNet: 3.4,
|
||||||
SDModelType.scheduler: 0.001,
|
SDModelType.Scheduler: 0.001,
|
||||||
SDModelType.safety_checker: 1.2,
|
SDModelType.SafetyChecker: 1.2,
|
||||||
SDModelType.feature_extractor: 0.001,
|
SDModelType.FeatureExtractor: 0.001,
|
||||||
SDModelType.lora: 0.1,
|
SDModelType.Lora: 0.1,
|
||||||
SDModelType.textual_inversion: 0.001,
|
SDModelType.TextualInversion: 0.001,
|
||||||
}
|
}
|
||||||
|
|
||||||
# The list of model classes we know how to fetch, for typechecking
|
# The list of model classes we know how to fetch, for typechecking
|
||||||
ModelClass = Union[tuple([x.value for x in SDModelType])]
|
ModelClass = Union[tuple([x for x in MODEL_CLASSES.values()])]
|
||||||
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)
|
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, EmptyScheduler, UNet2DConditionModel)
|
||||||
|
|
||||||
class UnsafeModelException(Exception):
|
class UnsafeModelException(Exception):
|
||||||
"Raised when a legacy model file fails the picklescan test"
|
"Raised when a legacy model file fails the picklescan test"
|
||||||
@ -110,15 +128,15 @@ class ModelLocker(object):
|
|||||||
|
|
||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
|
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
|
||||||
execution_device: torch.device=torch.device('cuda'),
|
execution_device: torch.device=torch.device('cuda'),
|
||||||
storage_device: torch.device=torch.device('cpu'),
|
storage_device: torch.device=torch.device('cpu'),
|
||||||
precision: torch.dtype=torch.float16,
|
precision: torch.dtype=torch.float16,
|
||||||
sequential_offload: bool=False,
|
sequential_offload: bool=False,
|
||||||
lazy_offloading: bool=True,
|
lazy_offloading: bool=True,
|
||||||
sha_chunksize: int = 16777216,
|
sha_chunksize: int = 16777216,
|
||||||
logger: types.ModuleType = logger
|
logger: types.ModuleType = logger
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
:param max_models: Maximum number of models to cache in CPU RAM [4]
|
:param max_models: Maximum number of models to cache in CPU RAM [4]
|
||||||
@ -145,15 +163,15 @@ class ModelCache(object):
|
|||||||
self.model_sizes: Dict[str,int] = dict()
|
self.model_sizes: Dict[str,int] = dict()
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
repo_id_or_path: Union[str,Path],
|
repo_id_or_path: Union[str, Path],
|
||||||
model_type: SDModelType=SDModelType.diffusers,
|
model_type: SDModelType = SDModelType.Diffusers,
|
||||||
subfolder: Path=None,
|
subfolder: Path = None,
|
||||||
submodel: SDModelType=None,
|
submodel: SDModelType = None,
|
||||||
revision: str=None,
|
revision: str = None,
|
||||||
attach_model_part: Tuple[SDModelType, str] = (None,None),
|
attach_model_part: Tuple[SDModelType, str] = (None, None),
|
||||||
gpu_load: bool=True,
|
gpu_load: bool = True,
|
||||||
)->ModelLocker: # ?? what does it return
|
) -> ModelLocker: # ?? what does it return
|
||||||
'''
|
'''
|
||||||
Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
|
Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
|
||||||
Use like this:
|
Use like this:
|
||||||
@ -178,14 +196,14 @@ class ModelCache(object):
|
|||||||
|
|
||||||
vae_context = cache.get_model(
|
vae_context = cache.get_model(
|
||||||
'stabilityai/sd-stable-diffusion-2',
|
'stabilityai/sd-stable-diffusion-2',
|
||||||
submodel=SDModelType.vae
|
submodel=SDModelType.Vae
|
||||||
)
|
)
|
||||||
|
|
||||||
This is equivalent to:
|
This is equivalent to:
|
||||||
|
|
||||||
vae_context = cache.get_model(
|
vae_context = cache.get_model(
|
||||||
'stabilityai/sd-stable-diffusion-2',
|
'stabilityai/sd-stable-diffusion-2',
|
||||||
model_type = SDModelType.vae,
|
model_type = SDModelType.Vae,
|
||||||
subfolder='vae'
|
subfolder='vae'
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -195,14 +213,14 @@ class ModelCache(object):
|
|||||||
|
|
||||||
pipeline_context = cache.get_model(
|
pipeline_context = cache.get_model(
|
||||||
'runwayml/stable-diffusion-v1-5',
|
'runwayml/stable-diffusion-v1-5',
|
||||||
attach_model_part=(SDModelType.vae,'stabilityai/sd-vae-ft-mse')
|
attach_model_part=(SDModelType.Vae,'stabilityai/sd-vae-ft-mse')
|
||||||
)
|
)
|
||||||
|
|
||||||
The model will be locked into GPU VRAM for the duration of the context.
|
The model will be locked into GPU VRAM for the duration of the context.
|
||||||
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
||||||
:param model_type: An SDModelType enum indicating the type of the (parent) model
|
:param model_type: An SDModelType enum indicating the type of the (parent) model
|
||||||
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
|
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
|
||||||
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
|
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.Vae
|
||||||
:param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
|
:param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
|
||||||
:param revision: model revision
|
:param revision: model revision
|
||||||
:param gpu_load: load the model into GPU [default True]
|
:param gpu_load: load the model into GPU [default True]
|
||||||
@ -211,7 +229,7 @@ class ModelCache(object):
|
|||||||
repo_id_or_path,
|
repo_id_or_path,
|
||||||
revision,
|
revision,
|
||||||
subfolder,
|
subfolder,
|
||||||
model_type.value,
|
model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# optimization: if caller is asking to load a submodel of a diffusers pipeline, then
|
# optimization: if caller is asking to load a submodel of a diffusers pipeline, then
|
||||||
@ -221,11 +239,11 @@ class ModelCache(object):
|
|||||||
repo_id_or_path,
|
repo_id_or_path,
|
||||||
None,
|
None,
|
||||||
revision,
|
revision,
|
||||||
SDModelType.diffusers.value
|
SDModelType.Diffusers
|
||||||
)
|
)
|
||||||
if possible_parent_key in self.models:
|
if possible_parent_key in self.models:
|
||||||
key = possible_parent_key
|
key = possible_parent_key
|
||||||
submodel=model_type
|
submodel = model_type
|
||||||
|
|
||||||
# Look for the model in the cache RAM
|
# Look for the model in the cache RAM
|
||||||
if key in self.models: # cached - move to bottom of stack (most recently used)
|
if key in self.models: # cached - move to bottom of stack (most recently used)
|
||||||
@ -256,24 +274,24 @@ class ModelCache(object):
|
|||||||
self.current_cache_size += mem_used # increment size of the cache
|
self.current_cache_size += mem_used # increment size of the cache
|
||||||
|
|
||||||
# this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
|
# this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
|
||||||
if model_type==SDModelType.diffusers and attach_model_part[0]:
|
if model_type == SDModelType.Diffusers and attach_model_part[0]:
|
||||||
self.attach_part(model,*attach_model_part)
|
self.attach_part(model, *attach_model_part)
|
||||||
|
|
||||||
self.stack.append(key) # add to LRU cache
|
self.stack.append(key) # add to LRU cache
|
||||||
self.models[key]=model # keep copy of model in dict
|
self.models[key] = model # keep copy of model in dict
|
||||||
|
|
||||||
if submodel:
|
if submodel:
|
||||||
model = getattr(model, submodel.name)
|
model = getattr(model, submodel)
|
||||||
|
|
||||||
return self.ModelLocker(self, key, model, gpu_load)
|
return self.ModelLocker(self, key, model, gpu_load)
|
||||||
|
|
||||||
def uncache_model(self, key: str):
|
def uncache_model(self, key: str):
|
||||||
'''Remove corresponding model from the cache'''
|
'''Remove corresponding model from the cache'''
|
||||||
if key is not None and key in self.models:
|
if key is not None and key in self.models:
|
||||||
with contextlib.suppress(ValueError), contextlib.suppress(KeyError):
|
self.models.pop(key, None)
|
||||||
del self.models[key]
|
self.locked_models.pop(key, None)
|
||||||
del self.locked_models[key]
|
self.loaded_models.discard(key)
|
||||||
self.loaded_models.remove(key)
|
with contextlib.suppress(ValueError):
|
||||||
self.stack.remove(key)
|
self.stack.remove(key)
|
||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
@ -302,7 +320,7 @@ class ModelCache(object):
|
|||||||
if model.device != cache.execution_device:
|
if model.device != cache.execution_device:
|
||||||
cache.logger.debug(f'Moving {key} into {cache.execution_device}')
|
cache.logger.debug(f'Moving {key} into {cache.execution_device}')
|
||||||
with VRAMUsage() as mem:
|
with VRAMUsage() as mem:
|
||||||
model.to(cache.execution_device) # move into GPU
|
model.to(cache.execution_device, dtype=cache.precision) # move into GPU
|
||||||
cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
|
cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
|
||||||
cache.model_sizes[key] = mem.vram_used # more accurate size
|
cache.model_sizes[key] = mem.vram_used # more accurate size
|
||||||
|
|
||||||
@ -312,13 +330,16 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
# in the event that the caller wants the model in RAM, we
|
# in the event that the caller wants the model in RAM, we
|
||||||
# move it into CPU if it is in GPU and not locked
|
# move it into CPU if it is in GPU and not locked
|
||||||
if hasattr(model,'to') and (key in cache.loaded_models
|
if hasattr(model, 'to') and (key in cache.loaded_models
|
||||||
and cache.locked_models[key] == 0):
|
and cache.locked_models[key] == 0):
|
||||||
model.to(cache.storage_device)
|
model.to(cache.storage_device)
|
||||||
cache.loaded_models.remove(key)
|
cache.loaded_models.remove(key)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
|
if not hasattr(self.model, 'to'):
|
||||||
|
return
|
||||||
|
|
||||||
key = self.key
|
key = self.key
|
||||||
cache = self.cache
|
cache = self.cache
|
||||||
cache.locked_models[key] -= 1
|
cache.locked_models[key] -= 1
|
||||||
@ -326,11 +347,12 @@ class ModelCache(object):
|
|||||||
cache._offload_unlocked_models()
|
cache._offload_unlocked_models()
|
||||||
cache._print_cuda_stats()
|
cache._print_cuda_stats()
|
||||||
|
|
||||||
def attach_part(self,
|
def attach_part(
|
||||||
diffusers_model: StableDiffusionPipeline,
|
self,
|
||||||
part_type: SDModelType,
|
diffusers_model: StableDiffusionPipeline,
|
||||||
part_id: str
|
part_type: SDModelType,
|
||||||
):
|
part_id: str,
|
||||||
|
):
|
||||||
'''
|
'''
|
||||||
Attach a diffusers model part to a diffusers model. This can be
|
Attach a diffusers model part to a diffusers model. This can be
|
||||||
used to replace the VAE, tokenizer, textencoder, unet, etc.
|
used to replace the VAE, tokenizer, textencoder, unet, etc.
|
||||||
@ -338,27 +360,26 @@ class ModelCache(object):
|
|||||||
:param part_type: An SD ModelType indicating the part
|
:param part_type: An SD ModelType indicating the part
|
||||||
:param part_id: A HF repo_id for the part
|
:param part_id: A HF repo_id for the part
|
||||||
'''
|
'''
|
||||||
part_key = part_type.name
|
|
||||||
part_class = part_type.value
|
|
||||||
part = self._load_diffusers_from_storage(
|
part = self._load_diffusers_from_storage(
|
||||||
part_id,
|
part_id,
|
||||||
model_class=part_class,
|
model_class=MODEL_CLASSES[part_type],
|
||||||
)
|
)
|
||||||
part.to(diffusers_model.device)
|
part.to(diffusers_model.device)
|
||||||
setattr(diffusers_model,part_key,part)
|
setattr(diffusers_model, part_type, part)
|
||||||
self.logger.debug(f'Attached {part_key} {part_id}')
|
self.logger.debug(f'Attached {part_type} {part_id}')
|
||||||
|
|
||||||
def status(self,
|
def status(
|
||||||
repo_id_or_path: Union[str,Path],
|
self,
|
||||||
model_type: SDModelType=SDModelType.diffusers,
|
repo_id_or_path: Union[str, Path],
|
||||||
revision: str=None,
|
model_type: SDModelType = SDModelType.Diffusers,
|
||||||
subfolder: Path=None,
|
revision: str = None,
|
||||||
)->ModelStatus:
|
subfolder: Path = None,
|
||||||
|
) -> ModelStatus:
|
||||||
key = self._model_key(
|
key = self._model_key(
|
||||||
repo_id_or_path,
|
repo_id_or_path,
|
||||||
revision,
|
revision,
|
||||||
subfolder,
|
subfolder,
|
||||||
model_type.value,
|
model_type,
|
||||||
)
|
)
|
||||||
if key not in self.models:
|
if key not in self.models:
|
||||||
return ModelStatus.not_loaded
|
return ModelStatus.not_loaded
|
||||||
@ -370,9 +391,11 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
return ModelStatus.in_ram
|
return ModelStatus.in_ram
|
||||||
|
|
||||||
def model_hash(self,
|
def model_hash(
|
||||||
repo_id_or_path: Union[str,Path],
|
self,
|
||||||
revision: str="main")->str:
|
repo_id_or_path: Union[str, Path],
|
||||||
|
revision: str = "main",
|
||||||
|
) -> str:
|
||||||
'''
|
'''
|
||||||
Given the HF repo id or path to a model on disk, returns a unique
|
Given the HF repo id or path to a model on disk, returns a unique
|
||||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||||
@ -385,7 +408,7 @@ class ModelCache(object):
|
|||||||
else:
|
else:
|
||||||
return self._hf_commit_hash(repo_id_or_path,revision)
|
return self._hf_commit_hash(repo_id_or_path,revision)
|
||||||
|
|
||||||
def cache_size(self)->float:
|
def cache_size(self) -> float:
|
||||||
"Return the current size of the cache, in GB"
|
"Return the current size of the cache, in GB"
|
||||||
return self.current_cache_size / GIG
|
return self.current_cache_size / GIG
|
||||||
|
|
||||||
@ -407,10 +430,15 @@ class ModelCache(object):
|
|||||||
logger.debug("Model scanned ok")
|
logger.debug("Model scanned ok")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _model_key(path,revision,subfolder,model_class)->str:
|
def _model_key(path, revision, subfolder, model_class) -> str:
|
||||||
return ':'.join([str(path),str(revision or ''),str(subfolder or ''),model_class.__name__])
|
return ':'.join([
|
||||||
|
str(path),
|
||||||
|
str(revision or ''),
|
||||||
|
str(subfolder or ''),
|
||||||
|
model_class,
|
||||||
|
])
|
||||||
|
|
||||||
def _has_cuda(self)->bool:
|
def _has_cuda(self) -> bool:
|
||||||
return self.execution_device.type == 'cuda'
|
return self.execution_device.type == 'cuda'
|
||||||
|
|
||||||
def _print_cuda_stats(self):
|
def _print_cuda_stats(self):
|
||||||
@@ -450,43 +478,43 @@ class ModelCache(object):
 self.loaded_models.remove(key)

 def _load_model_from_storage(
 self,
-repo_id_or_path: Union[str,Path],
+repo_id_or_path: Union[str, Path],
-subfolder: Path=None,
+subfolder: Optional[Path] = None,
-revision: str=None,
+revision: Optional[str] = None,
-model_type: SDModelType=SDModelType.diffusers,
+model_type: SDModelType = SDModelType.Diffusers,
-)->ModelClass:
+) -> ModelClass:
 '''
 Load and return a HuggingFace model.
 :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
 :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
 :param revision: model revision
-:param model_type: type of model to return, defaults to SDModelType.diffusers
+:param model_type: type of model to return, defaults to SDModelType.Diffusers
 '''
 # silence transformer and diffuser warnings
 with SilenceWarnings():
-if model_type==SDModelType.lora:
+if model_type==SDModelType.Lora:
 model = self._load_lora_from_storage(repo_id_or_path)
-elif model_type==SDModelType.textual_inversion:
+elif model_type==SDModelType.TextualInversion:
 model = self._load_ti_from_storage(repo_id_or_path)
 else:
 model = self._load_diffusers_from_storage(
 repo_id_or_path,
 subfolder,
 revision,
-model_type.value,
+model_type,
 )
-if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
+if self.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline):
 model.enable_offload_submodels(self.execution_device)
 return model

 def _load_diffusers_from_storage(
 self,
-repo_id_or_path: Union[str,Path],
+repo_id_or_path: Union[str, Path],
-subfolder: Path=None,
+subfolder: Optional[Path] = None,
-revision: str=None,
+revision: Optional[str] = None,
-model_class: ModelClass=StableDiffusionGeneratorPipeline,
+model_type: ModelClass = StableDiffusionGeneratorPipeline,
-)->ModelClass:
+) -> ModelClass:
 '''
 Load and return a HuggingFace model using from_pretrained().
 :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
@@ -494,17 +522,26 @@ class ModelCache(object):
 :param revision: model revision
 :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
 '''
-revisions = [revision] if revision \
-else ['fp16','main'] if self.precision==torch.float16 \
-else ['main']
-extra_args = {'torch_dtype': self.precision,
-'safety_checker': None}\
-if model_class in DiffusionClasses\
-else {}
+model_class = MODEL_CLASSES[model_type]
+if revision is not None:
+    revisions = [revision]
+elif self.precision == torch.float16:
+    revisions = ['fp16', 'main']
+else:
+    revisions = ['main']
+
+extra_args = dict()
+if model_class in DiffusionClasses:
+    extra_args = dict(
+        torch_dtype=self.precision,
+        safety_checker=None,
+    )

 for rev in revisions:
 try:
 model = model_class.from_pretrained(
 repo_id_or_path,
 revision=rev,
 subfolder=subfolder or '.',
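Note: the rewritten body above replaces the chained conditional expressions with an explicit if/elif/else. It resolves the concrete class from MODEL_CLASSES, prefers the 'fp16' revision when running at half precision before falling back to 'main', and only passes torch_dtype/safety_checker for diffusers-style classes. A small self-contained sketch of that fallback pattern; the loader callable and its arguments are assumptions for illustration, not this module's API.

# Sketch only: try candidate revisions in order, as the loop above does.
import torch

def load_with_fallback(load_fn, repo_id, precision, revision=None):
    if revision is not None:
        revisions = [revision]
    elif precision == torch.float16:
        revisions = ['fp16', 'main']   # prefer half-precision weights if published
    else:
        revisions = ['main']

    last_error = None
    for rev in revisions:
        try:
            return load_fn(repo_id, revision=rev)
        except Exception as err:       # e.g. the repo has no fp16 branch
            last_error = err
    raise last_error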
@@ -517,13 +554,13 @@ class ModelCache(object):
 pass
 return model

-def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
+def _load_lora_from_storage(self, lora_path: Path) -> LoraType:
-assert False,"_load_lora_from_storage() is not yet implemented"
+assert False, "_load_lora_from_storage() is not yet implemented"

-def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value:
+def _load_ti_from_storage(self, lora_path: Path) -> TIType:
-assert False,"_load_ti_from_storage() is not yet implemented"
+assert False, "_load_ti_from_storage() is not yet implemented"

-def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
+def _legacy_model_hash(self, checkpoint_path: Union[str, Path]) -> str:
 sha = hashlib.sha256()
 path = Path(checkpoint_path)
 assert path.is_file(),f"File {checkpoint_path} not found"
@@ -544,7 +581,7 @@ class ModelCache(object):
 f.write(hash)
 return hash

-def _local_model_hash(self, model_path: Union[str,Path])->str:
+def _local_model_hash(self, model_path: Union[str, Path]) -> str:
 sha = hashlib.sha256()
 path = Path(model_path)

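Note: both hash helpers above boil down to a SHA-256 digest that is written out once and reused on later lookups. A rough sketch of that pattern; the sidecar file name and helper name are assumptions, not the module's actual layout.

# Sketch only: chunked SHA-256 with a cached sidecar file.
import hashlib
from pathlib import Path

def cached_sha256(checkpoint_path: Path, chunk_size: int = 2**20) -> str:
    sidecar = checkpoint_path.with_suffix(checkpoint_path.suffix + '.sha256')
    if sidecar.exists():
        return sidecar.read_text().strip()
    sha = hashlib.sha256()
    with open(checkpoint_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            sha.update(chunk)
    digest = sha.hexdigest()
    sidecar.write_text(digest)
    return digest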
@@ -566,7 +603,7 @@ class ModelCache(object):
 f.write(hash)
 return hash

-def _hf_commit_hash(self, repo_id: str, revision: str='main')->str:
+def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str:
 api = HfApi()
 info = api.list_repo_refs(
 repo_id=repo_id,
@@ -578,7 +615,7 @@ class ModelCache(object):
 return desired_revisions[0].target_commit

 @staticmethod
-def calc_model_size(model)->int:
+def calc_model_size(model) -> int:
 if isinstance(model,DiffusionPipeline):
 return ModelCache._calc_pipeline(model)
 elif isinstance(model,torch.nn.Module):
@@ -587,7 +624,7 @@ class ModelCache(object):
 return None

 @staticmethod
-def _calc_pipeline(pipeline)->int:
+def _calc_pipeline(pipeline) -> int:
 res = 0
 for submodel_key in pipeline.components.keys():
 submodel = getattr(pipeline, submodel_key)
@@ -596,7 +633,7 @@ class ModelCache(object):
 return res

 @staticmethod
-def _calc_model(model)->int:
+def _calc_model(model) -> int:
 mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
 mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
 mem = mem_params + mem_bufs # in bytes
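Note: calc_model_size() dispatches on the object it is handed. A whole DiffusionPipeline is summed component by component, while a bare torch.nn.Module is measured directly as parameter bytes plus buffer bytes. A self-contained illustration of that per-module arithmetic; the example layer is chosen arbitrarily.

# Sketch only: parameter bytes + buffer bytes, the same arithmetic as _calc_model().
import torch

def module_size_bytes(model: torch.nn.Module) -> int:
    mem_params = sum(p.nelement() * p.element_size() for p in model.parameters())
    mem_bufs = sum(b.nelement() * b.element_size() for b in model.buffers())
    return mem_params + mem_bufs

layer = torch.nn.Linear(1024, 1024)             # ~1M float32 parameters
print(module_size_bytes(layer) / 2**20, 'MiB')  # roughly 4 MiB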
@@ -27,7 +27,7 @@ Typical usage:
 max_cache_size=8
 ) # gigabytes

-model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
+model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers)
 with model_info.context as my_model:
 my_model.latents_from_embeddings(...)

@@ -45,7 +45,7 @@ parameter:

 model_info = manager.get_model(
 'clip-tokenizer',
-model_type=SDModelType.tokenizer
+model_type=SDModelType.Tokenizer
 )

 This will raise an InvalidModelError if the format defined in the
@@ -96,7 +96,7 @@ SUBMODELS:
 It is also possible to fetch an isolated submodel from a diffusers
 model. Use the `submodel` parameter to select which part:

-vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.vae)
+vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae)
 with vae.context as my_vae:
 print(type(my_vae))
 # "AutoencoderKL"
@@ -120,8 +120,8 @@ separated by "/". Example:
 You can now use the `model_type` argument to indicate which model you
 want:

-tokenizer = mgr.get('clip-large',model_type=SDModelType.tokenizer)
+tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer)
-encoder = mgr.get('clip-large',model_type=SDModelType.text_encoder)
+encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder)

 OTHER FUNCTIONS:

@@ -254,7 +254,7 @@ class ModelManager(object):
 def model_exists(
 self,
 model_name: str,
-model_type: SDModelType = SDModelType.diffusers,
+model_type: SDModelType = SDModelType.Diffusers,
 ) -> bool:
 """
 Given a model name, returns True if it is a valid
@@ -264,28 +264,28 @@ class ModelManager(object):
 return model_key in self.config

 def create_key(self, model_name: str, model_type: SDModelType) -> str:
-return f"{model_type.name}/{model_name}"
+return f"{model_type}/{model_name}"

 def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
 model_type_str, model_name = model_key.split('/', 1)
-if model_type_str not in SDModelType.__members__:
-# TODO:
+try:
+    model_type = SDModelType(model_type_str)
+    return (model_name, model_type)
+except:
 raise Exception(f"Unknown model type: {model_type_str}")

-return (model_name, SDModelType[model_type_str])

 def get_model(
 self,
 model_name: str,
-model_type: SDModelType=SDModelType.diffusers,
+model_type: SDModelType = SDModelType.Diffusers,
-submodel: SDModelType=None,
+submodel: Optional[SDModelType] = None,
 ) -> SDModelInfo:
 """Given a model named identified in models.yaml, return
 an SDModelInfo object describing it.
 :param model_name: symbolic name of the model in models.yaml
 :param model_type: SDModelType enum indicating the type of model to return
 :param submodel: an SDModelType enum indicating the portion of
-the model to retrieve (e.g. SDModelType.vae)
+the model to retrieve (e.g. SDModelType.Vae)

 If not provided, the model_type will be read from the `format` field
 of the corresponding stanza. If provided, the model_type will be used
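Note: with the enum now string-valued, create_key() can interpolate the model type straight into the "type/name" key and parse_key() can rebuild it with a plain value lookup, SDModelType(model_type_str). A minimal stand-in sketch of that round trip follows; the member names and values are assumptions, and .value is spelled out because how a str-mixin enum renders in an f-string differs across Python versions.

# Sketch only: a string-valued enum round-trips through "type/name" keys.
from enum import Enum
from typing import Tuple

class ModelKind(str, Enum):            # stand-in for the backend's SDModelType
    Diffusers = "diffusers"
    Vae = "vae"
    Tokenizer = "tokenizer"

def create_key(model_name: str, model_type: ModelKind) -> str:
    return f"{model_type.value}/{model_name}"

def parse_key(model_key: str) -> Tuple[str, ModelKind]:
    type_str, name = model_key.split('/', 1)
    return name, ModelKind(type_str)   # raises ValueError for an unknown type

key = create_key('stable-diffusion-1.5', ModelKind.Diffusers)
assert parse_key(key) == ('stable-diffusion-1.5', ModelKind.Diffusers)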
@@ -304,17 +304,17 @@ class ModelManager(object):
 test1_pipeline = mgr.get_model('test1')
 # returns a StableDiffusionGeneratorPipeline

-test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae)
+test1_vae1 = mgr.get_model('test1', submodel=SDModelType.Vae)
 # returns the VAE part of a diffusers model as an AutoencoderKL

-test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae)
+test1_vae2 = mgr.get_model('test1', model_type=SDModelType.Diffusers, submodel=SDModelType.Vae)
 # does the same thing as the previous statement. Note that model_type
 # is for the parent model, and submodel is for the part

-test1_lora = mgr.get_model('test1', model_type=SDModelType.lora)
+test1_lora = mgr.get_model('test1', model_type=SDModelType.Lora)
 # returns a LoRA embed (as a 'dict' of tensors)

-test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder)
+test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.TextEncoder)
 # raises an InvalidModelError

 """
@@ -332,10 +332,10 @@ class ModelManager(object):
 mconfig = self.config[model_key]

 # type already checked as it's part of key
-if model_type == SDModelType.diffusers:
+if model_type == SDModelType.Diffusers:
 # intercept stanzas that point to checkpoint weights and replace them
 # with the equivalent diffusers model
-if mconfig.format in ["ckpt", "diffusers"]:
+if mconfig.format in ["ckpt", "safetensors"]:
 location = self.convert_ckpt_and_cache(mconfig)
 else:
 location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
@@ -355,13 +355,13 @@ class ModelManager(object):
 vae = (None, None)
 with suppress(Exception):
 vae_id = mconfig.vae.repo_id
-vae = (SDModelType.vae, vae_id)
+vae = (SDModelType.Vae, vae_id)

 # optimization - don't load whole model if the user
 # is asking for just a piece of it
-if model_type == SDModelType.diffusers and submodel and not subfolder:
+if model_type == SDModelType.Diffusers and submodel and not subfolder:
 model_type = submodel
-subfolder = submodel.name
+subfolder = submodel.value
 submodel = None

 model_context = self.cache.get_model(
@@ -390,7 +390,7 @@ class ModelManager(object):
 _cache = self.cache
 )

-def default_model(self) -> Union[Tuple[str, SDModelType],None]:
+def default_model(self) -> Optional[Tuple[str, SDModelType]]:
 """
 Returns the name of the default model, or None
 if none is defined.
@@ -401,7 +401,7 @@ class ModelManager(object):
 return (model_name, model_type)
 return self.model_names()[0][0]

-def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
+def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.Diffusers) -> None:
 """
 Set the default model. The change will not take
 effect until you call model_manager.commit()
@@ -415,25 +415,25 @@ class ModelManager(object):
 config[self.create_key(model_name, model_type)]["default"] = True

 def model_info(
 self,
 model_name: str,
-model_type: SDModelType=SDModelType.diffusers
+model_type: SDModelType=SDModelType.Diffusers,
 ) -> dict:
 """
 Given a model name returns the OmegaConf (dict-like) object describing it.
 """
 if not self.exists(model_name, model_type):
 return None
-return self.config[self.create_key(model_name,model_type)]
+return self.config[self.create_key(model_name, model_type)]

 def model_names(self) -> List[Tuple[str, SDModelType]]:
 """
 Return a list of (str, SDModelType) corresponding to all models
 known to the configuration.
 """
-return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)]
+return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]

-def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool:
+def is_legacy(self, model_name: str, model_type: SDModelType.Diffusers) -> bool:
 """
 Return true if this is a legacy (.ckpt) model
 """
@@ -461,14 +461,14 @@ class ModelManager(object):
 # don't include VAEs in listing (legacy style)
 if "config" in stanza and "/VAE/" in stanza["config"]:
 continue
-if model_key=='config_file_version':
+if model_key == 'config_file_version':
 continue

 model_name, model_type = self.parse_key(model_key)
 models[model_key] = dict()

 # TODO: return all models in future
-if model_type != SDModelType.diffusers:
+if model_type != SDModelType.Diffusers:
 continue

 model_format = stanza.get('format')
@@ -477,15 +477,15 @@ class ModelManager(object):
 status = self.cache.status(
 stanza.get('weights') or stanza.get('repo_id'),
 revision=stanza.get('revision'),
-subfolder=stanza.get('subfolder')
+subfolder=stanza.get('subfolder'),
 )
 description = stanza.get("description", None)
 models[model_key].update(
 model_name=model_name,
-model_type=model_type.name,
+model_type=model_type,
 format=model_format,
 description=description,
-status=status.value
+status=status.value,
 )

@@ -528,8 +528,8 @@ class ModelManager(object):
 def del_model(
 self,
 model_name: str,
-model_type: SDModelType.diffusers,
+model_type: SDModelType.Diffusers,
-delete_files: bool = False
+delete_files: bool = False,
 ):
 """
 Delete the named model.
@@ -539,9 +539,9 @@ class ModelManager(object):

 if model_cfg is None:
 self.logger.error(
 f"Unknown model {model_key}"
 )
 return

 # TODO: some legacy?
 #if model_name in self.stack:
@@ -571,7 +571,7 @@ class ModelManager(object):
 model_name: str,
 model_type: SDModelType,
 model_attributes: dict,
-clobber: bool = False
+clobber: bool = False,
 ) -> None:
 """
 Update the named model with a dictionary of attributes. Will fail with an
@@ -581,7 +581,7 @@ class ModelManager(object):
 attributes are incorrect or the model name is missing.
 """

-if model_type == SDModelType.diffusers:
+if model_type == SDModelType.Diffusers:
 # TODO: automatically or manually?
 #assert "format" in model_attributes, 'missing required field "format"'
 model_format = "ckpt" if "weights" in model_attributes else "diffusers"
@@ -647,16 +647,16 @@ class ModelManager(object):
 else:
 new_config.update(repo_id=repo_or_path)

-self.add_model(model_name, SDModelType.diffusers, new_config, True)
+self.add_model(model_name, SDModelType.Diffusers, new_config, True)
 if commit_to_conf:
 self.commit(commit_to_conf)
-return self.create_key(model_name, SDModelType.diffusers)
+return self.create_key(model_name, SDModelType.Diffusers)

 def import_lora(
 self,
 path: Path,
-model_name: str=None,
+model_name: Optional[str] = None,
-description: str=None,
+description: Optional[str] = None,
 ):
 """
 Creates an entry for the indicated lora file. Call
@@ -667,7 +667,7 @@ class ModelManager(object):
 model_description = description or f"LoRA model {model_name}"
 self.add_model(
 model_name,
-SDModelType.lora,
+SDModelType.Lora,
 dict(
 format="lora",
 weights=str(path),
@@ -679,8 +679,8 @@ class ModelManager(object):
 def import_embedding(
 self,
 path: Path,
-model_name: str=None,
+model_name: Optional[str] = None,
-description: str=None,
+description: Optional[str] = None,
 ):
 """
 Creates an entry for the indicated lora file. Call
@@ -696,7 +696,7 @@ class ModelManager(object):
 model_description = description or f"Textual embedding model {model_name}"
 self.add_model(
 model_name,
-SDModelType.textual_inversion,
+SDModelType.TextualInversion,
 dict(
 format="textual_inversion",
 weights=str(weights),
@@ -746,11 +746,11 @@ class ModelManager(object):
 def heuristic_import(
 self,
 path_url_or_repo: str,
-model_name: str = None,
+model_name: Optional[str] = None,
-description: str = None,
+description: Optional[str] = None,
-model_config_file: Path = None,
+model_config_file: Optional[Path] = None,
-commit_to_conf: Path = None,
+commit_to_conf: Optional[Path] = None,
-config_file_callback: Callable[[Path], Path] = None,
+config_file_callback: Optional[Callable[[Path], Path]] = None,
 ) -> str:
 """Accept a string which could be:
 - a HF diffusers repo_id
@@ -927,7 +927,7 @@ class ModelManager(object):
 )
 return model_name

-def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path:
+def convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
 """
 Convert the checkpoint model indicated in mconfig into a
 diffusers, cache it to disk, and return Path to converted
@@ -961,7 +961,7 @@ class ModelManager(object):
 self,
 weights: Path,
 mconfig: DictConfig
-) -> Tuple[Path, SDModelType.vae]:
+) -> Tuple[Path, AutoencoderKL]:
 # VAE handling is convoluted
 # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
 # it as the vae_path passed to convert
@@ -990,7 +990,7 @@ class ModelManager(object):
 vae_diffusers_location = "stabilityai/sd-vae-ft-mse"

 if vae_diffusers_location:
-vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
+vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model
 return (None, vae_model)

 return (None, None)
@@ -1038,7 +1038,7 @@ class ModelManager(object):
 vae_model = None
 if vae:
 vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
-vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
+vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model
 vae_path = None
 convert_ckpt_to_diffusers(
 ckpt_path,
@@ -1058,11 +1058,11 @@ class ModelManager(object):
 description=model_description,
 format="diffusers",
 )
-if self.model_exists(model_name, SDModelType.diffusers):
+if self.model_exists(model_name, SDModelType.Diffusers):
-self.del_model(model_name, SDModelType.diffusers)
+self.del_model(model_name, SDModelType.Diffusers)
 self.add_model(
 model_name,
-SDModelType.diffusers,
+SDModelType.Diffusers,
 new_config,
 True
 )
@@ -263,7 +263,7 @@ export const parseNodeMetadata = (
 return;
 }

-if ('unet' in nodeItem && 'tokenizer' in nodeItem) {
+if ('unet' in nodeItem && 'scheduler' in nodeItem) {
 const unetField = parseUNetField(nodeItem);
 if (unetField) {
 parsed[nodeKey] = unetField;