Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Change SDModelType enum to string; fixes: model unload negative locks count, scheduler load error, safetensors convert, wrong logic in del_model, wrong metadata parsing in web
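At the core of this change, SDModelType becomes a string-valued enum. A minimal sketch of the pattern (member names and values are illustrative, not copied from the InvokeAI source): because the enum also subclasses str, members compare and serialize as plain strings, which is what lets the pydantic fields later in this diff be typed as SDModelType and lets call sites drop the SDModelType[...] lookups and .name conversions.

    from enum import Enum
    from typing import Optional
    from pydantic import BaseModel, Field

    class SDModelType(str, Enum):
        # str-mixin enum: each member *is* its string value
        Diffusers = "diffusers"
        UNet = "unet"
        TextEncoder = "text_encoder"
        Tokenizer = "tokenizer"
        Vae = "vae"
        Scheduler = "scheduler"

    class ModelInfo(BaseModel):
        model_name: str = Field(description="Name of the model")
        model_type: SDModelType = Field(description="Type of the model")
        submodel: Optional[SDModelType] = Field(default=None, description="Submodel to load")

    info = ModelInfo(model_name="sd-1.5", model_type=SDModelType.Diffusers, submodel=SDModelType.UNet)
    print(SDModelType.Diffusers == "diffusers")  # True: compares as a string
    print(info.dict())  # field values are string-compatible enum members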
@@ -59,19 +59,14 @@ class CompelInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        # TODO: load without model
         text_encoder_info = context.services.model_manager.get_model(
-            model_name=self.clip.text_encoder.model_name,
-            model_type=SDModelType[self.clip.text_encoder.model_type],
-            submodel=SDModelType[self.clip.text_encoder.submodel],
+            **self.clip.text_encoder.dict(),
         )
         tokenizer_info = context.services.model_manager.get_model(
-            model_name=self.clip.tokenizer.model_name,
-            model_type=SDModelType[self.clip.tokenizer.model_type],
-            submodel=SDModelType[self.clip.tokenizer.submodel],
+            **self.clip.tokenizer.dict(),
         )
 
-        with text_encoder_info.context as text_encoder,\
-            tokenizer_info.context as tokenizer:
+        with text_encoder_info as text_encoder,\
+            tokenizer_info as tokenizer:
 
             # TODO: global? input?
             #use_full_precision = precision == "float32" or precision == "autocast"
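The repeated pattern in this hunk: since ModelInfo's field names line up with get_model's keyword parameters, the explicit keyword list collapses into a dict splat. A toy sketch of the mechanism (this get_model is a stand-in with the same keyword names, not the real service API):

    def get_model(model_name, model_type=None, submodel=None):
        return f"loaded {model_name} ({model_type}/{submodel})"

    info = {"model_name": "sd-1.5", "model_type": "diffusers", "submodel": "text_encoder"}
    # **info expands to model_name=..., model_type=..., submodel=...,
    # so no SDModelType[...] lookup is needed at the call site.
    print(get_model(**info))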
@@ -79,12 +79,8 @@ def get_scheduler(
     scheduler_info: ModelInfo,
     scheduler_name: str,
 ) -> Scheduler:
-    orig_scheduler_info = context.services.model_manager.get_model(
-        model_name=scheduler_info.model_name,
-        model_type=SDModelType[scheduler_info.model_type],
-        submodel=SDModelType[scheduler_info.submodel],
-    )
-    with orig_scheduler_info.context as orig_scheduler:
+    orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
+    with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
     scheduler_class = scheduler_map.get(scheduler_name,'ddim')
@@ -243,14 +239,8 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        unet_info = context.services.model_manager.get_model(
-            model_name=self.unet.unet.model_name,
-            model_type=SDModelType[self.unet.unet.model_type],
-            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
-        )
-
-        with unet_info.context as unet:
+        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        with unet_info as unet:
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
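The other recurring change here is `with unet_info as unet:` replacing `with unet_info.context as unet:` — the model-info object itself now implements the context-manager protocol. A hedged sketch of that protocol; the lock-count bookkeeping is illustrative of the "negative locks count" fix named in the commit message, not the actual InvokeAI cache implementation:

    class SDModelInfo:
        def __init__(self, model):
            self._model = model
            self._locks = 0

        def __enter__(self):
            # each nested `with` increments exactly once...
            self._locks += 1
            return self._model

        def __exit__(self, exc_type, exc, tb):
            # ...and decrements exactly once, so the count can never go
            # negative as long as enter/exit stay paired
            assert self._locks > 0, "unbalanced model lock release"
            self._locks -= 1
            return False

    info = SDModelInfo(model="fake-unet")
    with info as unet:
        print("using", unet)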
@@ -309,12 +299,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
-        #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
         unet_info = context.services.model_manager.get_model(
-            model_name=self.unet.unet.model_name,
-            model_type=SDModelType[self.unet.unet.model_type],
-            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
+            **self.unet.unet.dict(),
         )
 
-        with unet_info.context as unet:
+        with unet_info as unet:
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -379,18 +367,18 @@ class LatentsToImageInvocation(BaseInvocation):
 
-        #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            model_name=self.vae.vae.model_name,
-            model_type=SDModelType[self.vae.vae.model_type],
-            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+            **self.vae.vae.dict(),
         )
 
-        with vae_info.context as vae:
+        # TODO: check if it works
+        with vae_info as vae:
             if self.tiled:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
 
             # clear memory as vae decode can request a lot
             torch.cuda.empty_cache()
 
             with torch.inference_mode():
                 # copied from diffusers pipeline
                 latents = latents / vae.config.scaling_factor
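The `torch.cuda.empty_cache()` before decode frees cached allocator blocks so the VAE's large activations can fit. A minimal sketch of the same hygiene, guarded so it is a no-op on CPU-only machines (the decode call is a stand-in, not the diffusers API):

    import torch

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with torch.inference_mode():
        # stand-in for vae.decode(latents); inference_mode disables autograd
        # bookkeeping entirely, which is what you want for pure inference
        decoded = torch.randn(1, 3, 512, 512).clamp(-1, 1)
    print(decoded.shape)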
@@ -509,36 +497,29 @@ class ImageToLatentsInvocation(BaseInvocation):
 
-        #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            model_name=self.vae.vae.model_name,
-            model_type=SDModelType[self.vae.vae.model_type],
-            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+            **self.vae.vae.dict(),
         )
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
-        with vae_info.context as vae:
+        # TODO: check if it works
+        with vae_info as vae:
             if self.tiled:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
 
-            latents = self.non_noised_latents_from_image(vae, image_tensor)
+            # non_noised_latents_from_image
+            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+            with torch.inference_mode():
+                image_tensor_dist = vae.encode(image_tensor).latent_dist
+                latents = image_tensor_dist.sample().to(
+                    dtype=vae.dtype
+                ) # FIXME: uses torch.randn. make reproducible!
+
+                latents = 0.18215 * latents
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, latents)
         return LatentsOutput(latents=LatentsField(latents_name=name))
-
-
-    def non_noised_latents_from_image(self, vae, init_image):
-        init_image = init_image.to(device=vae.device, dtype=vae.dtype)
-        with torch.inference_mode():
-            init_latent_dist = vae.encode(init_image).latent_dist
-            init_latents = init_latent_dist.sample().to(
-                dtype=vae.dtype
-            ) # FIXME: uses torch.randn. make reproducible!
-
-            init_latents = 0.18215 * init_latents
-            return init_latents
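On the "FIXME: uses torch.randn. make reproducible!" comment: sampling from the VAE's latent distribution draws noise internally, and the usual fix is to thread a seeded torch.Generator through the sampling call. A hedged sketch under that assumption (function and shapes are illustrative; 0.18215 is the SD v1 latent scaling factor seen above):

    from typing import Optional
    import torch

    def sample_latents(mean, std, generator: Optional[torch.Generator] = None):
        noise = torch.randn(mean.shape, generator=generator, dtype=mean.dtype)
        return (mean + std * noise) * 0.18215

    g = torch.Generator().manual_seed(42)
    a = sample_latents(torch.zeros(1, 4, 8, 8), torch.ones(1, 4, 8, 8), g)
    g = torch.Generator().manual_seed(42)
    b = sample_latents(torch.zeros(1, 4, 8, 8), torch.ones(1, 4, 8, 8), g)
    print(torch.equal(a, b))  # True: same seed, same latents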
@@ -8,8 +8,8 @@ from ...backend.model_management import SDModelType
 
 class ModelInfo(BaseModel):
     model_name: str = Field(description="Info to load unet submodel")
-    model_type: str = Field(description="Info to load unet submodel")
-    submodel: Optional[str] = Field(description="Info to load unet submodel")
+    model_type: SDModelType = Field(description="Info to load unet submodel")
+    submodel: Optional[SDModelType] = Field(description="Info to load unet submodel")
 
 class UNetField(BaseModel):
     unet: ModelInfo = Field(description="Info to load unet submodel")
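Typing the field as SDModelType rather than plain str is what the str-enum makes practical: pydantic coerces incoming strings to enum members and rejects unknown values at validation time instead of at model-load time. A short hedged sketch (enum and field names illustrative):

    from enum import Enum
    from pydantic import BaseModel, ValidationError

    class SDModelType(str, Enum):
        Diffusers = "diffusers"
        UNet = "unet"

    class ModelInfo(BaseModel):
        model_name: str
        model_type: SDModelType

    print(ModelInfo(model_name="sd", model_type="unet").model_type)  # coerced to SDModelType.UNet

    try:
        ModelInfo(model_name="sd", model_type="unte")  # typo fails fast
    except ValidationError as err:
        print("rejected:", err.errors()[0]["loc"])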
@@ -62,15 +62,15 @@ class ModelLoaderInvocation(BaseInvocation):
         # TODO: not found exceptions
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
+            model_type=SDModelType.Diffusers,
         ):
             raise Exception(f"Unknown model name: {self.model_name}!")
 
         """
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
-            submodel=SDModelType.tokenizer,
+            model_type=SDModelType.Diffusers,
+            submodel=SDModelType.Tokenizer,
         ):
             raise Exception(
                 f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
@@ -78,8 +78,8 @@ class ModelLoaderInvocation(BaseInvocation):
 
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
-            submodel=SDModelType.text_encoder,
+            model_type=SDModelType.Diffusers,
+            submodel=SDModelType.TextEncoder,
         ):
             raise Exception(
                 f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
@@ -87,8 +87,8 @@ class ModelLoaderInvocation(BaseInvocation):
 
         if not context.services.model_manager.model_exists(
             model_name=self.model_name,
-            model_type=SDModelType.diffusers,
-            submodel=SDModelType.unet,
+            model_type=SDModelType.Diffusers,
+            submodel=SDModelType.UNet,
         ):
             raise Exception(
                 f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
@@ -100,32 +100,32 @@ class ModelLoaderInvocation(BaseInvocation):
             unet=UNetField(
                 unet=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.unet.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.UNet,
                 ),
                 scheduler=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.scheduler.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.Scheduler,
                 ),
             ),
             clip=ClipField(
                 tokenizer=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.tokenizer.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.Tokenizer,
                 ),
                 text_encoder=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.text_encoder.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.TextEncoder,
                 ),
             ),
             vae=VaeField(
                 vae=ModelInfo(
                     model_name=self.model_name,
-                    model_type=SDModelType.diffusers.name,
-                    submodel=SDModelType.vae.name,
+                    model_type=SDModelType.Diffusers,
+                    submodel=SDModelType.Vae,
                 ),
             )
         )
@@ -120,7 +120,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=model_name,
-                model_type=model_type.name,
+                model_type=model_type,
                 submodel=submodel,
             ),
         )
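The event payload can now carry the enum member itself instead of `model_type.name`, because a str-mixin enum JSON-serializes as its string value. A tiny hedged sketch (payload shape illustrative):

    import json
    from enum import Enum

    class SDModelType(str, Enum):
        Diffusers = "diffusers"

    payload = {"model_name": "sd-1.5", "model_type": SDModelType.Diffusers}
    print(json.dumps(payload))  # {"model_name": "sd-1.5", "model_type": "diffusers"}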
@@ -143,7 +143,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=model_name,
-                model_type=model_type.name,
+                model_type=model_type,
                 submodel=submodel,
                 model_info=model_info,
             ),
@@ -1,21 +1,25 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 
 from __future__ import annotations
 
+import torch
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union, Callable, List, Tuple, types
+from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
 from dataclasses import dataclass
 
 from invokeai.backend.model_management.model_manager import (
     ModelManager,
     SDModelType,
     SDModelInfo,
     torch,
 )
 from invokeai.app.models.exceptions import CanceledException
-from ...backend import Args,Globals # this must go when pr 3340 merged
+from ...backend import Args, Globals # this must go when pr 3340 merged
 from ...backend.util import choose_precision, choose_torch_device
 
+if TYPE_CHECKING:
+    from ..invocations.baseinvocation import BaseInvocation, InvocationContext
+
 @dataclass
 class LastUsedModel:
     model_name: str=None
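The TYPE_CHECKING block added here is what allows the previously untyped node/context parameters (old comment: "circular dependency issues, so untyped at moment") to gain annotations without importing at runtime. A hedged sketch of the pattern (import path is a hypothetical stand-in):

    from __future__ import annotations

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # evaluated only by static type checkers, never at runtime,
        # so the circular import disappears
        from app.invocations.baseinvocation import BaseInvocation  # hypothetical path

    def get_model(node: Optional[BaseInvocation] = None) -> None:
        print("node:", node)

    get_model()  # runs even though BaseInvocation was never imported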
@@ -28,9 +32,9 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def __init__(
-            self,
-            config: Args,
-            logger: types.ModuleType
+        self,
+        config: Args,
+        logger: types.ModuleType,
     ):
         """
         Initialize with the path to the models.yaml config file.
@@ -41,13 +45,14 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def get_model(self,
-                  model_name: str,
-                  model_type: SDModelType=SDModelType.diffusers,
-                  submodel: SDModelType=None,
-                  node=None,  # circular dependency issues, so untyped at moment
-                  context=None,
-                  )->SDModelInfo:
+    def get_model(
+        self,
+        model_name: str,
+        model_type: SDModelType = SDModelType.Diffusers,
+        submodel: Optional[SDModelType] = None,
+        node: Optional[BaseInvocation] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> SDModelInfo:
         """Retrieve the indicated model with name and type.
         submodel can be used to get a part (such as the vae)
         of a diffusers pipeline."""
@@ -60,14 +65,14 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def model_exists(
-            self,
-            model_name: str,
-            model_type: SDModelType
+        self,
+        model_name: str,
+        model_type: SDModelType,
     ) -> bool:
         pass
 
     @abstractmethod
-    def default_model(self) -> Union[Tuple[str, SDModelType],None]:
+    def default_model(self) -> Optional[Tuple[str, SDModelType]]:
         """
         Returns the name and type of the default model, or None
         if none is defined.
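The Union-to-Optional rewrite recurring below is purely cosmetic; Optional[T] is defined as Union[T, None]. A one-line check, for the record:

    from typing import Optional, Tuple, Union

    assert Optional[Tuple[str, str]] == Union[Tuple[str, str], None]
    print("Optional[T] is Union[T, None]")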
@@ -80,21 +85,21 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def model_info(self, model_name: str, model_type: SDModelType)->dict:
+    def model_info(self, model_name: str, model_type: SDModelType) -> dict:
         """
         Given a model name returns a dict-like (OmegaConf) object describing it.
         """
         pass
 
     @abstractmethod
-    def model_names(self)->List[Tuple[str, SDModelType]]:
+    def model_names(self) -> List[Tuple[str, SDModelType]]:
         """
         Returns a list of all the model names known.
         """
         pass
 
     @abstractmethod
-    def list_models(self)->dict:
+    def list_models(self) -> dict:
         """
         Return a dict of models in the format:
         { model_key1: {'status': 'active'|'cached'|'not loaded',
@@ -110,12 +115,12 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def add_model(
-            self,
-            model_name: str,
-            model_type: SDModelType,
-            model_attributes: dict,
-            clobber: bool = False
-    )->None:
+        self,
+        model_name: str,
+        model_type: SDModelType,
+        model_attributes: dict,
+        clobber: bool = False
+    ) -> None:
         """
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -126,10 +131,12 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def del_model(self,
-                  model_name: str,
-                  model_type: SDModelType,
-                  delete_files: bool = False):
+    def del_model(
+        self,
+        model_name: str,
+        model_type: SDModelType,
+        delete_files: bool = False,
+    ):
         """
         Delete the named model from configuration. If delete_files is true,
         then the underlying weight file or diffusers directory will be deleted
@@ -140,9 +147,9 @@ class ModelManagerServiceBase(ABC):
     @abstractmethod
     def import_diffuser_model(
         repo_or_path: Union[str, Path],
-        model_name: str = None,
-        description: str = None,
-        vae: dict = None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
+        vae: Optional[dict] = None,
     ) -> bool:
         """
         Install the indicated diffuser model and returns True if successful.
@@ -157,10 +164,10 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def import_lora(
-            self,
-            path: Path,
-            model_name: str=None,
-            description: str=None,
+        self,
+        path: Path,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
     ):
         """
         Creates an entry for the indicated lora file. Call
@@ -170,10 +177,10 @@ class ModelManagerServiceBase(ABC):
 
     @abstractmethod
     def import_embedding(
-            self,
-            path: Path,
-            model_name: str=None,
-            description: str=None,
+        self,
+        path: Path,
+        model_name: str=None,
+        description: str=None,
     ):
         """
         Creates an entry for the indicated textual inversion embedding file.
@@ -223,7 +230,7 @@ class ModelManagerServiceBase(ABC):
         pass
 
     @abstractmethod
-    def commit(self, conf_file: Path=None) -> None:
+    def commit(self, conf_file: Path = None) -> None:
         """
         Write current configuration out to the indicated file.
         If no conf_file is provided, then replaces the
@@ -235,10 +242,10 @@ class ModelManagerService(ModelManagerServiceBase):
 class ModelManagerService(ModelManagerServiceBase):
     """Responsible for managing models on disk and in memory"""
     def __init__(
-            self,
-            config: Args,
-            logger: types.ModuleType
-    ):
+        self,
+        config: Args,
+        logger: types.ModuleType,
+    ):
         """
         Initialize with the path to the models.yaml config file.
         Optional parameters are the torch device type, precision, max_models,
@@ -255,7 +262,7 @@ class ModelManagerService(ModelManagerServiceBase):
         logger.debug(f'config file={config_file}')
 
         device = torch.device(choose_torch_device())
-        if config.precision=="auto":
+        if config.precision == "auto":
             precision = choose_precision(device)
         dtype = torch.float32 if precision=='float32' \
             else torch.float16
@@ -272,22 +279,24 @@ class ModelManagerService(ModelManagerServiceBase):
 
         sequential_offload = config.sequential_guidance
 
-        self.mgr = ModelManager(config=config_file,
-                                device_type=device,
-                                precision=dtype,
-                                max_cache_size=max_cache_size,
-                                sequential_offload=sequential_offload,
-                                logger=logger
-                                )
+        self.mgr = ModelManager(
+            config=config_file,
+            device_type=device,
+            precision=dtype,
+            max_cache_size=max_cache_size,
+            sequential_offload=sequential_offload,
+            logger=logger,
+        )
         logger.info('Model manager service initialized')
 
-    def get_model(self,
-                  model_name: str,
-                  model_type: SDModelType=SDModelType.diffusers,
-                  submodel: SDModelType=None,
-                  node=None,
-                  context=None,
-                  )->SDModelInfo:
+    def get_model(
+        self,
+        model_name: str,
+        model_type: SDModelType = SDModelType.Diffusers,
+        submodel: Optional[SDModelType] = None,
+        node: Optional[BaseInvocation] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> SDModelInfo:
         """
         Retrieve the indicated model. submodel can be used to get a
         part (such as the vae) of a diffusers model.
@@ -340,9 +349,9 @@ class ModelManagerService(ModelManagerServiceBase):
         return model_info
 
     def model_exists(
-            self,
-            model_name: str,
-            model_type: SDModelType
+        self,
+        model_name: str,
+        model_type: SDModelType,
    ) -> bool:
         """
         Given a model name, returns True if it is a valid
@@ -350,32 +359,33 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         return self.mgr.model_exists(
             model_name,
-            model_type)
+            model_type,
+        )
 
-    def default_model(self) -> Union[Tuple[str, SDModelType],None]:
+    def default_model(self) -> Optional[Tuple[str, SDModelType]]:
         """
         Returns the name of the default model, or None
         if none is defined.
         """
         return self.mgr.default_model()
 
-    def set_default_model(self, model_name:str, model_type: SDModelType):
+    def set_default_model(self, model_name: str, model_type: SDModelType):
         """Sets the default model to the indicated name."""
         self.mgr.set_default_model(model_name)
 
-    def model_info(self, model_name: str, model_type: SDModelType)->dict:
+    def model_info(self, model_name: str, model_type: SDModelType) -> dict:
         """
         Given a model name returns a dict-like (OmegaConf) object describing it.
         """
         return self.mgr.model_info(model_name)
 
-    def model_names(self)->List[Tuple[str, SDModelType]]:
+    def model_names(self) -> List[Tuple[str, SDModelType]]:
         """
         Returns a list of all the model names known.
         """
         return self.mgr.model_names()
 
-    def list_models(self)->dict:
+    def list_models(self) -> dict:
         """
         Return a dict of models in the format:
         { model_key: {'status': 'active'|'cached'|'not loaded',
@@ -388,11 +398,12 @@ class ModelManagerService(ModelManagerServiceBase):
         return self.mgr.list_models()
 
     def add_model(
-            self,
-            model_name: str,
-            model_type: SDModelType,
-            model_attributes: dict,
-            clobber: bool = False)->None:
+        self,
+        model_name: str,
+        model_type: SDModelType,
+        model_attributes: dict,
+        clobber: bool = False,
+    )->None:
         """
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -400,14 +411,15 @@ class ModelManagerService(ModelManagerServiceBase):
         with an assertion error if provided attributes are incorrect or
         the model name is missing. Call commit() to write changes to disk.
         """
-        return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
+        return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
 
 
-    def del_model(self,
-                  model_name: str,
-                  model_type: SDModelType=SDModelType.diffusers,
-                  delete_files: bool = False
-                  ):
+    def del_model(
+        self,
+        model_name: str,
+        model_type: SDModelType = SDModelType.Diffusers,
+        delete_files: bool = False,
+    ):
         """
         Delete the named model from configuration. If delete_files is true,
         then the underlying weight file or diffusers directory will be deleted
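The add_model fix above is a real bug, not a style change: the old call passed the built-in `dict` as an extra positional argument. A hedged sketch of why that breaks (the signature is illustrative of the underlying ModelManager.add_model, not copied from it):

    def add_model(model_name, model_type, model_attributes, clobber=False):
        return clobber

    try:
        # old call: a stray 5th positional argument raises at call time
        add_model("sd", "diffusers", {}, dict, True)
    except TypeError as e:
        print("old call failed:", e)

    print(add_model("sd", "diffusers", {}, True))  # fixed call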
@@ -416,11 +428,11 @@ class ModelManagerService(ModelManagerServiceBase):
         self.mgr.del_model(model_name, model_type, delete_files)
 
     def import_diffuser_model(
-            self,
-            repo_or_path: Union[str, Path],
-            model_name: str = None,
-            description: str = None,
-            vae: dict = None,
+        self,
+        repo_or_path: Union[str, Path],
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
+        vae: Optional[dict] = None,
     ) -> bool:
         """
         Install the indicated diffuser model and returns True if successful.
@@ -431,13 +443,13 @@ class ModelManagerService(ModelManagerServiceBase):
         You can optionally provide a model name and/or description. If not provided,
         then these will be derived from the repo name. Call commit() to write to disk.
         """
-        return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
+        return self.mgr.import_diffuser_model(repo_or_path, model_name, description, vae)
 
     def import_lora(
-            self,
-            path: Path,
-            model_name: str=None,
-            description: str=None,
+        self,
+        path: Path,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
     ):
         """
         Creates an entry for the indicated lora file. Call
@@ -446,10 +458,10 @@ class ModelManagerService(ModelManagerServiceBase):
         self.mgr.import_lora(path, model_name, description)
 
     def import_embedding(
-            self,
-            path: Path,
-            model_name: str=None,
-            description: str=None,
+        self,
+        path: Path,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
     ):
         """
         Creates an entry for the indicated textual inversion embedding file.
@@ -462,9 +474,9 @@ class ModelManagerService(ModelManagerServiceBase):
         path_url_or_repo: str,
         model_name: str = None,
         description: str = None,
-        model_config_file: Path = None,
-        commit_to_conf: Path = None,
-        config_file_callback: Callable[[Path], Path] = None,
+        model_config_file: Optional[Path] = None,
+        commit_to_conf: Optional[Path] = None,
+        config_file_callback: Optional[Callable[[Path], Path]] = None,
     ) -> str:
         """Accept a string which could be:
         - a HF diffusers repo_id
@@ -505,7 +517,7 @@ class ModelManagerService(ModelManagerServiceBase):
         )
 
 
-    def commit(self, conf_file: Path=None):
+    def commit(self, conf_file: Optional[Path]=None):
         """
         Write current configuration out to the indicated file.
         If no conf_file is provided, then replaces the
@@ -514,16 +526,16 @@ class ModelManagerService(ModelManagerServiceBase):
         return self.mgr.commit(conf_file)
 
     def _emit_load_event(
-            self,
-            node,
-            context,
-            model_name: str,
-            model_type: SDModelType,
-            submodel: SDModelType,
-            model_info: SDModelInfo=None,
+        self,
+        node,
+        context,
+        model_name: str,
+        model_type: SDModelType,
+        submodel: SDModelType,
+        model_info: Optional[SDModelInfo] = None,
     ):
         if context.services.queue.is_canceled(context.graph_execution_state_id):
-            raise CanceledException
+            raise CanceledException()
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
         source_node_id = graph_execution_state.prepared_source_mapping[node.id]
         if context:
@@ -536,7 +548,7 @@ class ModelManagerService(ModelManagerServiceBase):
                 submodel=submodel,
             )
         else:
-            context.services.events.emit_model_load_completed (
+            context.services.events.emit_model_load_completed(
                 graph_execution_state_id=context.graph_execution_state_id,
                 node=node.dict(),
                 source_node_id=source_node_id,