Change SDModelType enum to string, fixes (model unload negative lock count, scheduler load error, safetensors conversion, wrong logic in del_model, wrong metadata parsing in web)

Sergey Borisov
2023-05-14 03:06:26 +03:00
parent 2204e47596
commit 039fa73269
8 changed files with 388 additions and 363 deletions

View File

@ -59,19 +59,14 @@ class CompelInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: load without model
text_encoder_info = context.services.model_manager.get_model(
model_name=self.clip.text_encoder.model_name,
model_type=SDModelType[self.clip.text_encoder.model_type],
submodel=SDModelType[self.clip.text_encoder.submodel],
**self.clip.text_encoder.dict(),
)
tokenizer_info = context.services.model_manager.get_model(
model_name=self.clip.tokenizer.model_name,
model_type=SDModelType[self.clip.tokenizer.model_type],
submodel=SDModelType[self.clip.tokenizer.submodel],
**self.clip.tokenizer.dict(),
)
with text_encoder_info.context as text_encoder,\
tokenizer_info.context as tokenizer:
with text_encoder_info as text_encoder,\
tokenizer_info as tokenizer:
# TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast"
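
The splat works because, later in this commit, ModelInfo's model_type and submodel fields become SDModelType members instead of raw strings, so get_model() can take the serialized fields directly. A minimal sketch of the idea, assuming SDModelType subclasses (str, Enum) as the commit title implies (pydantic v1 semantics; names illustrative):

    from enum import Enum
    from typing import Optional
    from pydantic import BaseModel

    class SDModelType(str, Enum):   # assumed shape of the backend enum
        Diffusers = "diffusers"
        TextEncoder = "text_encoder"

    class ModelInfo(BaseModel):
        model_name: str
        model_type: SDModelType
        submodel: Optional[SDModelType] = None

    info = ModelInfo(model_name="sd-1.5", model_type="diffusers",
                     submodel="text_encoder")   # pydantic coerces the strings
    # .dict() keeps the enum members, so get_model(**info.dict()) receives
    # SDModelType values directly -- no fragile SDModelType[...] name lookup
    assert info.dict()["model_type"] is SDModelType.Diffusers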

View File

@ -79,12 +79,8 @@ def get_scheduler(
scheduler_info: ModelInfo,
scheduler_name: str,
) -> Scheduler:
orig_scheduler_info = context.services.model_manager.get_model(
model_name=scheduler_info.model_name,
model_type=SDModelType[scheduler_info.model_type],
submodel=SDModelType[scheduler_info.submodel],
)
with orig_scheduler_info.context as orig_scheduler:
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
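
The usual diffusers pattern for the swap that follows is to rebuild the requested scheduler from the original one's config. A sketch under the assumption that scheduler_map maps names to scheduler classes:

    from diffusers import DDIMScheduler, EulerDiscreteScheduler

    scheduler_map = {   # assumed shape of the mapping used above
        "ddim": DDIMScheduler,
        "euler": EulerDiscreteScheduler,
    }

    def swap_scheduler(orig_scheduler, scheduler_name: str):
        # defaulting to a class (not the bare string "ddim") keeps from_config callable
        scheduler_class = scheduler_map.get(scheduler_name, DDIMScheduler)
        # from_config() carries the timestep/beta settings over from the original
        return scheduler_class.from_config(orig_scheduler.config)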
@ -243,14 +239,8 @@ class TextToLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
#unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_manager.get_model(
model_name=self.unet.unet.model_name,
model_type=SDModelType[self.unet.unet.model_type],
submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
)
with unet_info.context as unet:
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet:
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@ -309,12 +299,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
#unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_manager.get_model(
model_name=self.unet.unet.model_name,
model_type=SDModelType[self.unet.unet.model_type],
submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
**self.unet.unet.dict(),
)
with unet_info.context as unet:
with unet_info as unet:
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@ -379,18 +367,18 @@ class LatentsToImageInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
model_name=self.vae.vae.model_name,
model_type=SDModelType[self.vae.vae.model_type],
submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
**self.vae.vae.dict(),
)
with vae_info.context as vae:
# TODO: check if it works
with vae_info as vae:
if self.tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
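
For context, a condensed sketch of the decode path this hunk wraps, following the standard diffusers AutoencoderKL conventions:

    import torch

    def decode_latents(vae, latents: torch.Tensor, tiled: bool = False):
        if tiled:
            vae.enable_tiling()    # lower peak VRAM at some speed cost
        else:
            vae.disable_tiling()
        torch.cuda.empty_cache()   # decode can request a large allocation
        with torch.inference_mode():
            latents = latents / vae.config.scaling_factor  # undo SD latent scaling
            image = vae.decode(latents).sample             # pixels in [-1, 1]
        return image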
@ -509,36 +497,29 @@ class ImageToLatentsInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
model_name=self.vae.vae.model_name,
model_type=SDModelType[self.vae.vae.model_type],
submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
**self.vae.vae.dict(),
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
with vae_info.context as vae:
# TODO: check if it works
with vae_info as vae:
if self.tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
latents = self.non_noised_latents_from_image(vae, image_tensor)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
def non_noised_latents_from_image(self, vae, init_image):
init_image = init_image.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
init_latent_dist = vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents
return init_latents
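
On the FIXME: diffusers' DiagonalGaussianDistribution.sample() accepts a torch.Generator, which is the usual way to make the draw reproducible. A hedged sketch (the seed plumbing is hypothetical, not part of this commit):

    import torch

    def latents_from_image(vae, image_tensor: torch.Tensor, seed: int = 0):
        image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
        generator = torch.Generator(device=vae.device).manual_seed(seed)
        with torch.inference_mode():
            dist = vae.encode(image_tensor).latent_dist
            # passing a generator makes the torch.randn draw deterministic
            latents = dist.sample(generator=generator).to(dtype=vae.dtype)
        return 0.18215 * latents   # SD 1.x scaling factor, as above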

View File

@ -8,8 +8,8 @@ from ...backend.model_management import SDModelType
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load unet submodel")
model_type: str = Field(description="Info to load unet submodel")
submodel: Optional[str] = Field(description="Info to load unet submodel")
model_type: SDModelType = Field(description="Info to load unet submodel")
submodel: Optional[SDModelType] = Field(description="Info to load unet submodel")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
@ -62,15 +62,15 @@ class ModelLoaderInvocation(BaseInvocation):
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
model_type=SDModelType.Diffusers,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.tokenizer,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
@ -78,8 +78,8 @@ class ModelLoaderInvocation(BaseInvocation):
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.text_encoder,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
@ -87,8 +87,8 @@ class ModelLoaderInvocation(BaseInvocation):
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.unet,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
@ -100,32 +100,32 @@ class ModelLoaderInvocation(BaseInvocation):
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.unet.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.scheduler.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Scheduler,
),
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.tokenizer.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.text_encoder.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
),
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.vae.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Vae,
),
)
)

View File

@ -120,7 +120,7 @@ class EventServiceBase:
node=node,
source_node_id=source_node_id,
model_name=model_name,
model_type=model_type.name,
model_type=model_type,
submodel=submodel,
),
)
@ -143,7 +143,7 @@ class EventServiceBase:
node=node,
source_node_id=source_node_id,
model_name=model_name,
model_type=model_type.name,
model_type=model_type,
submodel=submodel,
model_info=model_info,
),
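
Dropping .name here is safe because a string-based enum serializes as its value; json (and hence the socket.io event payload) sees a plain string:

    import json
    from enum import Enum

    class SDModelType(str, Enum):   # assumed str-based enum, as above
        Diffusers = "diffusers"

    payload = {"model_name": "sd-1.5", "model_type": SDModelType.Diffusers}
    print(json.dumps(payload))  # {"model_name": "sd-1.5", "model_type": "diffusers"}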

View File

@ -1,21 +1,25 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, Callable, List, Tuple, types
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import (
ModelManager,
SDModelType,
SDModelInfo,
torch,
)
from invokeai.app.models.exceptions import CanceledException
from ...backend import Args,Globals # this must go when pr 3340 merged
from ...backend import Args, Globals # this must go when pr 3340 merged
from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@dataclass
class LastUsedModel:
model_name: str=None
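
The TYPE_CHECKING guard added above is the standard way to annotate against types that would otherwise create an import cycle: the import runs only under a type checker, and with postponed evaluation the annotations stay strings at runtime. A minimal sketch of the pattern:

    from __future__ import annotations  # annotations not evaluated at runtime
    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # seen by mypy/pyright only; never imported at runtime,
        # so the services module can reference invocation types safely
        from ..invocations.baseinvocation import BaseInvocation

    def get_model(node: Optional[BaseInvocation] = None) -> None:
        ...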
@ -28,9 +32,9 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def __init__(
self,
config: Args,
logger: types.ModuleType
self,
config: Args,
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@ -41,13 +45,14 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def get_model(self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
node=None, # circular dependency issues, so untyped at moment
context=None,
)->SDModelInfo:
def get_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
@ -60,14 +65,14 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def model_exists(
self,
model_name: str,
model_type: SDModelType
self,
model_name: str,
model_type: SDModelType,
) -> bool:
pass
@abstractmethod
def default_model(self) -> Union[Tuple[str, SDModelType],None]:
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
"""
Returns the name and type of the default model, or None
if none is defined.
@ -80,21 +85,21 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def model_info(self, model_name: str, model_type: SDModelType)->dict:
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
pass
@abstractmethod
def model_names(self)->List[Tuple[str, SDModelType]]:
def model_names(self) -> List[Tuple[str, SDModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def list_models(self)->dict:
def list_models(self) -> dict:
"""
Return a dict of models in the format:
{ model_key1: {'status': 'active'|'cached'|'not loaded',
@ -110,12 +115,12 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def add_model(
self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
)->None:
self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
) -> None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -126,10 +131,12 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def del_model(self,
model_name: str,
model_type: SDModelType,
delete_files: bool = False):
def del_model(
self,
model_name: str,
model_type: SDModelType,
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
@ -140,9 +147,9 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def import_diffuser_model(
repo_or_path: Union[str, Path],
model_name: str = None,
description: str = None,
vae: dict = None,
model_name: Optional[str] = None,
description: Optional[str] = None,
vae: Optional[dict] = None,
) -> bool:
"""
Install the indicated diffuser model and returns True if successful.
@ -157,10 +164,10 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def import_lora(
self,
path: Path,
model_name: str=None,
description: str=None,
self,
path: Path,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call
@ -170,10 +177,10 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def import_embedding(
self,
path: Path,
model_name: str=None,
description: str=None,
self,
path: Path,
model_name: str=None,
description: str=None,
):
"""
Creates an entry for the indicated textual inversion embedding file.
@ -223,7 +230,7 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def commit(self, conf_file: Path=None) -> None:
def commit(self, conf_file: Path = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
@ -235,10 +242,10 @@ class ModelManagerServiceBase(ABC):
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
def __init__(
self,
config: Args,
logger: types.ModuleType
):
self,
config: Args,
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
@ -255,7 +262,7 @@ class ModelManagerService(ModelManagerServiceBase):
logger.debug(f'config file={config_file}')
device = torch.device(choose_torch_device())
if config.precision=="auto":
if config.precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision=='float32' \
else torch.float16
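
A compact restatement of the dtype selection above; the real rules live in choose_precision, so the CPU/GPU split here is only an assumption:

    import torch

    def resolve_dtype(precision: str, device: torch.device) -> torch.dtype:
        if precision == "auto":
            # assumed heuristic: full precision on CPU, half precision elsewhere
            precision = "float32" if device.type == "cpu" else "float16"
        return torch.float32 if precision == "float32" else torch.float16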
@ -272,22 +279,24 @@ class ModelManagerService(ModelManagerServiceBase):
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger
)
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
)
logger.info('Model manager service initialized')
def get_model(self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
node=None,
context=None,
)->SDModelInfo:
def get_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers model.
@ -340,9 +349,9 @@ class ModelManagerService(ModelManagerServiceBase):
return model_info
def model_exists(
self,
model_name: str,
model_type: SDModelType
self,
model_name: str,
model_type: SDModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid
@ -350,32 +359,33 @@ class ModelManagerService(ModelManagerServiceBase):
"""
return self.mgr.model_exists(
model_name,
model_type)
model_type,
)
def default_model(self) -> Union[Tuple[str, SDModelType],None]:
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
return self.mgr.default_model()
def set_default_model(self, model_name:str, model_type: SDModelType):
def set_default_model(self, model_name: str, model_type: SDModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name)
def model_info(self, model_name: str, model_type: SDModelType)->dict:
def model_info(self, model_name: str, model_type: SDModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name)
def model_names(self)->List[Tuple[str, SDModelType]]:
def model_names(self) -> List[Tuple[str, SDModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
def list_models(self)->dict:
def list_models(self) -> dict:
"""
Return a dict of models in the format:
{ model_key: {'status': 'active'|'cached'|'not loaded',
@ -388,11 +398,12 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.list_models()
def add_model(
self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False)->None:
self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False,
)->None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -400,14 +411,15 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
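
The removed dict argument was a genuine bug: it shifted the positional parameters, so the call could never succeed. A toy reproduction with a simplified signature:

    def add_model(model_name, model_type, model_attributes, clobber=False):
        return clobber

    # old: add_model(name, type_, attrs, dict, clobber)
    #   -> TypeError: add_model() takes from 3 to 4 positional arguments but 5 were given
    # new: add_model(name, type_, attrs, clobber) forwards the flag itself
    assert add_model("name", "type", {}, True) is True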
def del_model(self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
delete_files: bool = False
):
def del_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
@ -416,11 +428,11 @@ class ModelManagerService(ModelManagerServiceBase):
self.mgr.del_model(model_name, model_type, delete_files)
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
model_name: str = None,
description: str = None,
vae: dict = None,
self,
repo_or_path: Union[str, Path],
model_name: Optional[str] = None,
description: Optional[str] = None,
vae: Optional[dict] = None,
) -> bool:
"""
Install the indicated diffuser model and returns True if successful.
@ -431,13 +443,13 @@ class ModelManagerService(ModelManagerServiceBase):
You can optionally provide a model name and/or description. If not provided,
then these will be derived from the repo name. Call commit() to write to disk.
"""
return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
def import_lora(
self,
path: Path,
model_name: str=None,
description: str=None,
self,
path: Path,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call
@ -446,10 +458,10 @@ class ModelManagerService(ModelManagerServiceBase):
self.mgr.import_lora(path, model_name, description)
def import_embedding(
self,
path: Path,
model_name: str=None,
description: str=None,
self,
path: Path,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated textual inversion embedding file.
@ -462,9 +474,9 @@ class ModelManagerService(ModelManagerServiceBase):
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
model_config_file: Optional[Path] = None,
commit_to_conf: Optional[Path] = None,
config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> str:
"""Accept a string which could be:
- a HF diffusers repo_id
@ -505,7 +517,7 @@ class ModelManagerService(ModelManagerServiceBase):
)
def commit(self, conf_file: Path=None):
def commit(self, conf_file: Optional[Path]=None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
@ -514,16 +526,16 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.commit(conf_file)
def _emit_load_event(
self,
node,
context,
model_name: str,
model_type: SDModelType,
submodel: SDModelType,
model_info: SDModelInfo=None,
self,
node,
context,
model_name: str,
model_type: SDModelType,
submodel: SDModelType,
model_info: Optional[SDModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if context:
@ -536,7 +548,7 @@ class ModelManagerService(ModelManagerServiceBase):
submodel=submodel,
)
else:
context.services.events.emit_model_load_completed (
context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,