Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Change SDModelType enum to string; fixes: model unload negative locks count, scheduler load error, safetensors convert, wrong logic in del_model, wrong metadata parsing in web
This commit is contained in:
parent 2204e47596
commit 039fa73269
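The heart of the change is that SDModelType members no longer map to model classes; they are plain string values, so the type serializes directly into node fields, event payloads, and web metadata, and can be rebuilt from a string. A minimal sketch of the pattern, assuming only the member names shown in the diff below (the assertions are illustrative and not part of the commit):

from enum import Enum

class SDModelType(str, Enum):
    # Inheriting from str makes each member compare equal to, and
    # serialize as, its string value.
    Diffusers = "diffusers"
    Vae = "vae"
    UNet = "unet"

# Lookup by value, as the new parse_key() does:
assert SDModelType("vae") is SDModelType.Vae
# Direct comparison against plain strings coming from JSON metadata:
assert SDModelType.Vae == "vae"

The mapping from each type to its implementation class moves into the separate MODEL_CLASSES dict introduced further down in the diff.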
@@ -59,19 +59,14 @@ class CompelInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: load without model
text_encoder_info = context.services.model_manager.get_model(
model_name=self.clip.text_encoder.model_name,
model_type=SDModelType[self.clip.text_encoder.model_type],
submodel=SDModelType[self.clip.text_encoder.submodel],
**self.clip.text_encoder.dict(),
)
tokenizer_info = context.services.model_manager.get_model(
model_name=self.clip.tokenizer.model_name,
model_type=SDModelType[self.clip.tokenizer.model_type],
submodel=SDModelType[self.clip.tokenizer.submodel],
**self.clip.tokenizer.dict(),
)
with text_encoder_info.context as text_encoder,\
tokenizer_info.context as tokenizer:
with text_encoder_info as text_encoder,\
tokenizer_info as tokenizer:
# TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast"

@@ -79,12 +79,8 @@ def get_scheduler(
scheduler_info: ModelInfo,
scheduler_name: str,
) -> Scheduler:
orig_scheduler_info = context.services.model_manager.get_model(
model_name=scheduler_info.model_name,
model_type=SDModelType[scheduler_info.model_type],
submodel=SDModelType[scheduler_info.submodel],
)
with orig_scheduler_info.context as orig_scheduler:
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
scheduler_class = scheduler_map.get(scheduler_name,'ddim')

@@ -243,14 +239,8 @@ class TextToLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
#unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_manager.get_model(
model_name=self.unet.unet.model_name,
model_type=SDModelType[self.unet.unet.model_type],
submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
)
with unet_info.context as unet:
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet:
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,

@@ -309,12 +299,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
#unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_manager.get_model(
model_name=self.unet.unet.model_name,
model_type=SDModelType[self.unet.unet.model_type],
submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
**self.unet.unet.dict(),
)
with unet_info.context as unet:
with unet_info as unet:
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,

@@ -379,18 +367,18 @@ class LatentsToImageInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
model_name=self.vae.vae.model_name,
model_type=SDModelType[self.vae.vae.model_type],
submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
**self.vae.vae.dict(),
)
with vae_info.context as vae:
# TODO: check if it works
with vae_info as vae:
if self.tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor

@@ -509,36 +497,29 @@ class ImageToLatentsInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
model_name=self.vae.vae.model_name,
model_type=SDModelType[self.vae.vae.model_type],
submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
**self.vae.vae.dict(),
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
with vae_info.context as vae:
# TODO: check if it works
with vae_info as vae:
if self.tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
latents = self.non_noised_latents_from_image(vae, image_tensor)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
def non_noised_latents_from_image(self, vae, init_image):
init_image = init_image.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
init_latent_dist = vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents
return init_latents
@@ -8,8 +8,8 @@ from ...backend.model_management import SDModelType
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load unet submodel")
model_type: str = Field(description="Info to load unet submodel")
submodel: Optional[str] = Field(description="Info to load unet submodel")
model_type: SDModelType = Field(description="Info to load unet submodel")
submodel: Optional[SDModelType] = Field(description="Info to load unet submodel")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")

@@ -62,15 +62,15 @@ class ModelLoaderInvocation(BaseInvocation):
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
model_type=SDModelType.Diffusers,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.tokenizer,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"

@@ -78,8 +78,8 @@ class ModelLoaderInvocation(BaseInvocation):
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.text_encoder,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"

@@ -87,8 +87,8 @@ class ModelLoaderInvocation(BaseInvocation):
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.unet,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"

@@ -100,32 +100,32 @@ class ModelLoaderInvocation(BaseInvocation):
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.unet.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.scheduler.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Scheduler,
),
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.tokenizer.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.text_encoder.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
),
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.vae.name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Vae,
),
)
)
@@ -120,7 +120,7 @@ class EventServiceBase:
node=node,
source_node_id=source_node_id,
model_name=model_name,
model_type=model_type.name,
model_type=model_type,
submodel=submodel,
),
)

@@ -143,7 +143,7 @@ class EventServiceBase:
node=node,
source_node_id=source_node_id,
model_name=model_name,
model_type=model_type.name,
model_type=model_type,
submodel=submodel,
model_info=model_info,
),
@@ -1,21 +1,25 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, Callable, List, Tuple, types
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import (
ModelManager,
SDModelType,
SDModelInfo,
torch,
)
from invokeai.app.models.exceptions import CanceledException
from ...backend import Args, Globals # this must go when pr 3340 merged
from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@dataclass
class LastUsedModel:
model_name: str=None

@@ -30,7 +34,7 @@ class ModelManagerServiceBase(ABC):
def __init__(
self,
config: Args,
logger: types.ModuleType
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.

@@ -41,12 +45,13 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def get_model(self,
def get_model(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
node=None, # circular dependency issues, so untyped at moment
context=None,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)

@@ -62,12 +67,12 @@ class ModelManagerServiceBase(ABC):
def model_exists(
self,
model_name: str,
model_type: SDModelType
model_type: SDModelType,
) -> bool:
pass
@abstractmethod
def default_model(self) -> Union[Tuple[str, SDModelType],None]:
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
"""
Returns the name and typeof the default model, or None
if none is defined.

@@ -126,10 +131,12 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def del_model(self,
def del_model(
self,
model_name: str,
model_type: SDModelType,
delete_files: bool = False):
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted

@@ -140,9 +147,9 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def import_diffuser_model(
repo_or_path: Union[str, Path],
model_name: str = None,
description: str = None,
vae: dict = None,
model_name: Optional[str] = None,
description: Optional[str] = None,
vae: Optional[dict] = None,
) -> bool:
"""
Install the indicated diffuser model and returns True if successful.

@@ -159,8 +166,8 @@ class ModelManagerServiceBase(ABC):
def import_lora(
self,
path: Path,
model_name: str=None,
description: str=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call

@@ -237,7 +244,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__(
self,
config: Args,
logger: types.ModuleType
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@@ -272,21 +279,23 @@ class ModelManagerService(ModelManagerServiceBase):
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(config=config_file,
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger
logger=logger,
)
logger.info('Model manager service initialized')
def get_model(self,
def get_model(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
node=None,
context=None,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a

@@ -342,7 +351,7 @@ class ModelManagerService(ModelManagerServiceBase):
def model_exists(
self,
model_name: str,
model_type: SDModelType
model_type: SDModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid

@@ -350,9 +359,10 @@ class ModelManagerService(ModelManagerServiceBase):
"""
return self.mgr.model_exists(
model_name,
model_type)
model_type,
)
def default_model(self) -> Union[Tuple[str, SDModelType],None]:
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
"""
Returns the name of the default model, or None
if none is defined.

@@ -392,7 +402,8 @@ class ModelManagerService(ModelManagerServiceBase):
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False)->None:
clobber: bool = False,
)->None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.

@@ -400,13 +411,14 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
return self.mgr.add_model(model_name, model_type, model_attributes, dict, clobber)
return self.mgr.add_model(model_name, model_type, model_attributes, clobber)
def del_model(self,
def del_model(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
delete_files: bool = False
model_type: SDModelType = SDModelType.Diffusers,
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,

@@ -418,9 +430,9 @@ class ModelManagerService(ModelManagerServiceBase):
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
model_name: str = None,
description: str = None,
vae: dict = None,
model_name: Optional[str] = None,
description: Optional[str] = None,
vae: Optional[dict] = None,
) -> bool:
"""
Install the indicated diffuser model and returns True if successful.

@@ -436,8 +448,8 @@ class ModelManagerService(ModelManagerServiceBase):
def import_lora(
self,
path: Path,
model_name: str=None,
description: str=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call

@@ -448,8 +460,8 @@ class ModelManagerService(ModelManagerServiceBase):
def import_embedding(
self,
path: Path,
model_name: str=None,
description: str=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated textual inversion embedding file.

@@ -462,9 +474,9 @@ class ModelManagerService(ModelManagerServiceBase):
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
model_config_file: Optional[Path] = None,
commit_to_conf: Optional[Path] = None,
config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> str:
"""Accept a string which could be:
- a HF diffusers repo_id

@@ -505,7 +517,7 @@ class ModelManagerService(ModelManagerServiceBase):
)
def commit(self, conf_file: Path=None):
def commit(self, conf_file: Optional[Path]=None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the

@@ -520,10 +532,10 @@ class ModelManagerService(ModelManagerServiceBase):
model_name: str,
model_type: SDModelType,
submodel: SDModelType,
model_info: SDModelInfo=None,
model_info: Optional[SDModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if context:
@@ -23,12 +23,12 @@ import warnings
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import Dict, Sequence, Union, Tuple, types
from typing import Dict, Sequence, Union, Tuple, types, Optional
import torch
import safetensors.torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel, ConfigMixin
from diffusers import logging as diffusers_logging
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker

@@ -55,19 +55,37 @@ class LoraType(dict):
class TIType(dict):
pass
class SDModelType(Enum):
diffusers=StableDiffusionGeneratorPipeline # whole pipeline
vae=AutoencoderKL # diffusers parts
text_encoder=CLIPTextModel
tokenizer=CLIPTokenizer
unet=UNet2DConditionModel
scheduler=SchedulerMixin
safety_checker=StableDiffusionSafetyChecker
feature_extractor=CLIPFeatureExtractor
class SDModelType(str, Enum):
Diffusers="diffusers" # whole pipeline
Vae="vae" # diffusers parts
TextEncoder="text_encoder"
Tokenizer="tokenizer"
UNet="unet"
Scheduler="scheduler"
SafetyChecker="safety_checker"
FeatureExtractor="feature_extractor"
# These are all loaded as dicts of tensors, and we
# distinguish them by class
lora=LoraType
textual_inversion=TIType
Lora="lora"
TextualInversion="textual_inversion"
# TODO:
class EmptyScheduler(SchedulerMixin, ConfigMixin):
pass
MODEL_CLASSES = {
SDModelType.Diffusers: StableDiffusionGeneratorPipeline,
SDModelType.Vae: AutoencoderKL,
SDModelType.TextEncoder: CLIPTextModel, # TODO: t5
SDModelType.Tokenizer: CLIPTokenizer, # TODO: t5
SDModelType.UNet: UNet2DConditionModel,
SDModelType.Scheduler: EmptyScheduler,
SDModelType.SafetyChecker: StableDiffusionSafetyChecker,
SDModelType.FeatureExtractor: CLIPFeatureExtractor,
SDModelType.Lora: LoraType,
SDModelType.TextualInversion: TIType,
}
class ModelStatus(Enum):
unknown='unknown'

@@ -80,21 +98,21 @@ class ModelStatus(Enum):
# After loading, we will know it exactly.
# Sizes are in Gigs, estimated for float16; double for float32
SIZE_GUESSTIMATE = {
SDModelType.diffusers: 2.2,
SDModelType.vae: 0.35,
SDModelType.text_encoder: 0.5,
SDModelType.tokenizer: 0.001,
SDModelType.unet: 3.4,
SDModelType.scheduler: 0.001,
SDModelType.safety_checker: 1.2,
SDModelType.feature_extractor: 0.001,
SDModelType.lora: 0.1,
SDModelType.textual_inversion: 0.001,
SDModelType.Diffusers: 2.2,
SDModelType.Vae: 0.35,
SDModelType.TextEncoder: 0.5,
SDModelType.Tokenizer: 0.001,
SDModelType.UNet: 3.4,
SDModelType.Scheduler: 0.001,
SDModelType.SafetyChecker: 1.2,
SDModelType.FeatureExtractor: 0.001,
SDModelType.Lora: 0.1,
SDModelType.TextualInversion: 0.001,
}
# The list of model classes we know how to fetch, for typechecking
ModelClass = Union[tuple([x.value for x in SDModelType])]
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)
ModelClass = Union[tuple([x for x in MODEL_CLASSES.values()])]
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, EmptyScheduler, UNet2DConditionModel)
class UnsafeModelException(Exception):
"Raised when a legacy model file fails the picklescan test"
@@ -147,7 +165,7 @@ class ModelCache(object):
def get_model(
self,
repo_id_or_path: Union[str, Path],
model_type: SDModelType=SDModelType.diffusers,
model_type: SDModelType = SDModelType.Diffusers,
subfolder: Path = None,
submodel: SDModelType = None,
revision: str = None,

@@ -178,14 +196,14 @@ class ModelCache(object):
vae_context = cache.get_model(
'stabilityai/sd-stable-diffusion-2',
submodel=SDModelType.vae
submodel=SDModelType.Vae
)
This is equivalent to:
vae_context = cache.get_model(
'stabilityai/sd-stable-diffusion-2',
model_type = SDModelType.vae,
model_type = SDModelType.Vae,
subfolder='vae'
)

@@ -195,14 +213,14 @@ class ModelCache(object):
pipeline_context = cache.get_model(
'runwayml/stable-diffusion-v1-5',
attach_model_part=(SDModelType.vae,'stabilityai/sd-vae-ft-mse')
attach_model_part=(SDModelType.Vae,'stabilityai/sd-vae-ft-mse')
)
The model will be locked into GPU VRAM for the duration of the context.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param model_type: An SDModelType enum indicating the type of the (parent) model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.Vae
:param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
:param revision: model revision
:param gpu_load: load the model into GPU [default True]

@@ -211,7 +229,7 @@ class ModelCache(object):
repo_id_or_path,
revision,
subfolder,
model_type.value,
model_type,
)
# optimization: if caller is asking to load a submodel of a diffusers pipeline, then

@@ -221,7 +239,7 @@ class ModelCache(object):
repo_id_or_path,
None,
revision,
SDModelType.diffusers.value
SDModelType.Diffusers
)
if possible_parent_key in self.models:
key = possible_parent_key

@@ -256,24 +274,24 @@ class ModelCache(object):
self.current_cache_size += mem_used # increment size of the cache
# this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
if model_type==SDModelType.diffusers and attach_model_part[0]:
if model_type == SDModelType.Diffusers and attach_model_part[0]:
self.attach_part(model, *attach_model_part)
self.stack.append(key) # add to LRU cache
self.models[key] = model # keep copy of model in dict
if submodel:
model = getattr(model, submodel.name)
model = getattr(model, submodel)
return self.ModelLocker(self, key, model, gpu_load)
def uncache_model(self, key: str):
'''Remove corresponding model from the cache'''
if key is not None and key in self.models:
with contextlib.suppress(ValueError), contextlib.suppress(KeyError):
del self.models[key]
del self.locked_models[key]
self.loaded_models.remove(key)
self.models.pop(key, None)
self.locked_models.pop(key, None)
self.loaded_models.discard(key)
with contextlib.suppress(ValueError):
self.stack.remove(key)
class ModelLocker(object):

@@ -302,7 +320,7 @@ class ModelCache(object):
if model.device != cache.execution_device:
cache.logger.debug(f'Moving {key} into {cache.execution_device}')
with VRAMUsage() as mem:
model.to(cache.execution_device) # move into GPU
model.to(cache.execution_device, dtype=cache.precision) # move into GPU
cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
cache.model_sizes[key] = mem.vram_used # more accurate size

@@ -319,6 +337,9 @@ class ModelCache(object):
return model
def __exit__(self, type, value, traceback):
if not hasattr(self.model, 'to'):
return
key = self.key
cache = self.cache
cache.locked_models[key] -= 1

@@ -326,10 +347,11 @@ class ModelCache(object):
cache._offload_unlocked_models()
cache._print_cuda_stats()
def attach_part(self,
def attach_part(
self,
diffusers_model: StableDiffusionPipeline,
part_type: SDModelType,
part_id: str
part_id: str,
):
'''
Attach a diffusers model part to a diffusers model. This can be

@@ -338,19 +360,18 @@ class ModelCache(object):
:param part_type: An SD ModelType indicating the part
:param part_id: A HF repo_id for the part
'''
part_key = part_type.name
part_class = part_type.value
part = self._load_diffusers_from_storage(
part_id,
model_class=part_class,
model_class=MODEL_CLASSES[part_type],
)
part.to(diffusers_model.device)
setattr(diffusers_model,part_key,part)
self.logger.debug(f'Attached {part_key} {part_id}')
setattr(diffusers_model, part_type, part)
self.logger.debug(f'Attached {part_type} {part_id}')
def status(self,
def status(
self,
repo_id_or_path: Union[str, Path],
model_type: SDModelType=SDModelType.diffusers,
model_type: SDModelType = SDModelType.Diffusers,
revision: str = None,
subfolder: Path = None,
) -> ModelStatus:

@@ -358,7 +379,7 @@ class ModelCache(object):
repo_id_or_path,
revision,
subfolder,
model_type.value,
model_type,
)
if key not in self.models:
return ModelStatus.not_loaded

@@ -370,9 +391,11 @@ class ModelCache(object):
else:
return ModelStatus.in_ram
def model_hash(self,
def model_hash(
self,
repo_id_or_path: Union[str, Path],
revision: str="main")->str:
revision: str = "main",
) -> str:
'''
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs

@@ -408,7 +431,12 @@ class ModelCache(object):
@staticmethod
def _model_key(path, revision, subfolder, model_class) -> str:
return ':'.join([str(path),str(revision or ''),str(subfolder or ''),model_class.__name__])
return ':'.join([
str(path),
str(revision or ''),
str(subfolder or ''),
model_class,
])
def _has_cuda(self) -> bool:
return self.execution_device.type == 'cuda'

@@ -452,29 +480,29 @@ class ModelCache(object):
def _load_model_from_storage(
self,
repo_id_or_path: Union[str, Path],
subfolder: Path=None,
revision: str=None,
model_type: SDModelType=SDModelType.diffusers,
subfolder: Optional[Path] = None,
revision: Optional[str] = None,
model_type: SDModelType = SDModelType.Diffusers,
) -> ModelClass:
'''
Load and return a HuggingFace model.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param revision: model revision
:param model_type: type of model to return, defaults to SDModelType.diffusers
:param model_type: type of model to return, defaults to SDModelType.Diffusers
'''
# silence transformer and diffuser warnings
with SilenceWarnings():
if model_type==SDModelType.lora:
if model_type==SDModelType.Lora:
model = self._load_lora_from_storage(repo_id_or_path)
elif model_type==SDModelType.textual_inversion:
elif model_type==SDModelType.TextualInversion:
model = self._load_ti_from_storage(repo_id_or_path)
else:
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_type.value,
model_type,
)
if self.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.execution_device)

@@ -483,9 +511,9 @@ class ModelCache(object):
def _load_diffusers_from_storage(
self,
repo_id_or_path: Union[str, Path],
subfolder: Path=None,
revision: str=None,
model_class: ModelClass=StableDiffusionGeneratorPipeline,
subfolder: Optional[Path] = None,
revision: Optional[str] = None,
model_type: ModelClass = StableDiffusionGeneratorPipeline,
) -> ModelClass:
'''
Load and return a HuggingFace model using from_pretrained().

@@ -494,13 +522,22 @@ class ModelCache(object):
:param revision: model revision
:param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
'''
revisions = [revision] if revision \
else ['fp16','main'] if self.precision==torch.float16 \
else ['main']
extra_args = {'torch_dtype': self.precision,
'safety_checker': None}\
if model_class in DiffusionClasses\
else {}
model_class = MODEL_CLASSES[model_type]
if revision is not None:
revisions = [revision]
elif self.precision == torch.float16:
revisions = ['fp16', 'main']
else:
revisions = ['main']
extra_args = dict()
if model_class in DiffusionClasses:
extra_args = dict(
torch_dtype=self.precision,
safety_checker=None,
)
for rev in revisions:
try:

@@ -517,10 +554,10 @@ class ModelCache(object):
pass
return model
def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
def _load_lora_from_storage(self, lora_path: Path) -> LoraType:
assert False, "_load_lora_from_storage() is not yet implemented"
def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value:
def _load_ti_from_storage(self, lora_path: Path) -> TIType:
assert False, "_load_ti_from_storage() is not yet implemented"
def _legacy_model_hash(self, checkpoint_path: Union[str, Path]) -> str:
@@ -27,7 +27,7 @@ Typical usage:
max_cache_size=8
) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers)
with model_info.context as my_model:
my_model.latents_from_embeddings(...)

@@ -45,7 +45,7 @@ parameter:
model_info = manager.get_model(
'clip-tokenizer',
model_type=SDModelType.tokenizer
model_type=SDModelType.Tokenizer
)
This will raise an InvalidModelError if the format defined in the

@@ -96,7 +96,7 @@ SUBMODELS:
It is also possible to fetch an isolated submodel from a diffusers
model. Use the `submodel` parameter to select which part:
vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.vae)
vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae)
with vae.context as my_vae:
print(type(my_vae))
# "AutoencoderKL"

@@ -120,8 +120,8 @@ separated by "/". Example:
You can now use the `model_type` argument to indicate which model you
want:
tokenizer = mgr.get('clip-large',model_type=SDModelType.tokenizer)
encoder = mgr.get('clip-large',model_type=SDModelType.text_encoder)
tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer)
encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder)
OTHER FUNCTIONS:

@@ -254,7 +254,7 @@ class ModelManager(object):
def model_exists(
self,
model_name: str,
model_type: SDModelType = SDModelType.diffusers,
model_type: SDModelType = SDModelType.Diffusers,
) -> bool:
"""
Given a model name, returns True if it is a valid

@@ -264,28 +264,28 @@ class ModelManager(object):
return model_key in self.config
def create_key(self, model_name: str, model_type: SDModelType) -> str:
return f"{model_type.name}/{model_name}"
return f"{model_type}/{model_name}"
def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
model_type_str, model_name = model_key.split('/', 1)
if model_type_str not in SDModelType.__members__:
# TODO:
try:
model_type = SDModelType(model_type_str)
return (model_name, model_type)
except:
raise Exception(f"Unknown model type: {model_type_str}")
return (model_name, SDModelType[model_type_str])
def get_model(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
) -> SDModelInfo:
"""Given a model named identified in models.yaml, return
an SDModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml
:param model_type: SDModelType enum indicating the type of model to return
:param submodel: an SDModelType enum indicating the portion of
the model to retrieve (e.g. SDModelType.vae)
the model to retrieve (e.g. SDModelType.Vae)
If not provided, the model_type will be read from the `format` field
of the corresponding stanza. If provided, the model_type will be used

@@ -304,17 +304,17 @@ class ModelManager(object):
test1_pipeline = mgr.get_model('test1')
# returns a StableDiffusionGeneratorPipeline
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae)
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.Vae)
# returns the VAE part of a diffusers model as an AutoencoderKL
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae)
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.Diffusers, submodel=SDModelType.Vae)
# does the same thing as the previous statement. Note that model_type
# is for the parent model, and submodel is for the part
test1_lora = mgr.get_model('test1', model_type=SDModelType.lora)
test1_lora = mgr.get_model('test1', model_type=SDModelType.Lora)
# returns a LoRA embed (as a 'dict' of tensors)
test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder)
test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.TextEncoder)
# raises an InvalidModelError
"""

@@ -332,10 +332,10 @@ class ModelManager(object):
mconfig = self.config[model_key]
# type already checked as it's part of key
if model_type == SDModelType.diffusers:
if model_type == SDModelType.Diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if mconfig.format in ["ckpt", "diffusers"]:
if mconfig.format in ["ckpt", "safetensors"]:
location = self.convert_ckpt_and_cache(mconfig)
else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')

@@ -355,13 +355,13 @@ class ModelManager(object):
vae = (None, None)
with suppress(Exception):
vae_id = mconfig.vae.repo_id
vae = (SDModelType.vae, vae_id)
vae = (SDModelType.Vae, vae_id)
# optimization - don't load whole model if the user
# is asking for just a piece of it
if model_type == SDModelType.diffusers and submodel and not subfolder:
if model_type == SDModelType.Diffusers and submodel and not subfolder:
model_type = submodel
subfolder = submodel.name
subfolder = submodel.value
submodel = None
model_context = self.cache.get_model(

@@ -390,7 +390,7 @@ class ModelManager(object):
_cache = self.cache
)
def default_model(self) -> Union[Tuple[str, SDModelType],None]:
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
"""
Returns the name of the default model, or None
if none is defined.

@@ -401,7 +401,7 @@ class ModelManager(object):
return (model_name, model_type)
return self.model_names()[0][0]
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.Diffusers) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()

@@ -417,7 +417,7 @@ class ModelManager(object):
def model_info(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers
model_type: SDModelType=SDModelType.Diffusers,
) -> dict:
"""
Given a model name returns the OmegaConf (dict-like) object describing it.

@@ -433,7 +433,7 @@ class ModelManager(object):
"""
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]
def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool:
def is_legacy(self, model_name: str, model_type: SDModelType.Diffusers) -> bool:
"""
Return true if this is a legacy (.ckpt) model
"""

@@ -468,7 +468,7 @@ class ModelManager(object):
models[model_key] = dict()
# TODO: return all models in future
if model_type != SDModelType.diffusers:
if model_type != SDModelType.Diffusers:
continue
model_format = stanza.get('format')

@@ -477,15 +477,15 @@ class ModelManager(object):
status = self.cache.status(
stanza.get('weights') or stanza.get('repo_id'),
revision=stanza.get('revision'),
subfolder=stanza.get('subfolder')
subfolder=stanza.get('subfolder'),
)
description = stanza.get("description", None)
models[model_key].update(
model_name=model_name,
model_type=model_type.name,
model_type=model_type,
format=model_format,
description=description,
status=status.value
status=status.value,
)
@@ -528,8 +528,8 @@ class ModelManager(object):
def del_model(
self,
model_name: str,
model_type: SDModelType.diffusers,
delete_files: bool = False
model_type: SDModelType.Diffusers,
delete_files: bool = False,
):
"""
Delete the named model.

@@ -571,7 +571,7 @@ class ModelManager(object):
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
clobber: bool = False,
) -> None:
"""
Update the named model with a dictionary of attributes. Will fail with an

@@ -581,7 +581,7 @@ class ModelManager(object):
attributes are incorrect or the model name is missing.
"""
if model_type == SDModelType.diffusers:
if model_type == SDModelType.Fiffusers:
# TODO: automaticaly or manualy?
#assert "format" in model_attributes, 'missing required field "format"'
model_format = "ckpt" if "weights" in model_attributes else "diffusers"

@@ -647,16 +647,16 @@ class ModelManager(object):
else:
new_config.update(repo_id=repo_or_path)
self.add_model(model_name, SDModelType.diffusers, new_config, True)
self.add_model(model_name, SDModelType.Diffusers, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
return self.create_key(model_name, SDModelType.diffusers)
return self.create_key(model_name, SDModelType.Diffusers)
def import_lora(
self,
path: Path,
model_name: str=None,
description: str=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call

@@ -667,7 +667,7 @@ class ModelManager(object):
model_description = description or f"LoRA model {model_name}"
self.add_model(
model_name,
SDModelType.lora,
SDModelType.Lora,
dict(
format="lora",
weights=str(path),

@@ -679,8 +679,8 @@ class ModelManager(object):
def import_embedding(
self,
path: Path,
model_name: str=None,
description: str=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call

@@ -696,7 +696,7 @@ class ModelManager(object):
model_description = description or f"Textual embedding model {model_name}"
self.add_model(
model_name,
SDModelType.textual_inversion,
SDModelType.TextualInversion,
dict(
format="textual_inversion",
weights=str(weights),

@@ -746,11 +746,11 @@ class ModelManager(object):
def heuristic_import(
self,
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
model_name: Optional[str] = None,
description: Optional[str] = None,
model_config_file: Optional[Path] = None,
commit_to_conf: Optional[Path] = None,
config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> str:
"""Accept a string which could be:
- a HF diffusers repo_id

@@ -961,7 +961,7 @@ class ModelManager(object):
self,
weights: Path,
mconfig: DictConfig
) -> Tuple[Path, SDModelType.vae]:
) -> Tuple[Path, AutoencoderKL]:
# VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert

@@ -990,7 +990,7 @@ class ModelManager(object):
vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
if vae_diffusers_location:
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model
return (None, vae_model)
return (None, None)

@@ -1038,7 +1038,7 @@ class ModelManager(object):
vae_model = None
if vae:
vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model
vae_path = None
convert_ckpt_to_diffusers(
ckpt_path,

@@ -1058,11 +1058,11 @@ class ModelManager(object):
description=model_description,
format="diffusers",
)
if self.model_exists(model_name, SDModelType.diffusers):
self.del_model(model_name, SDModelType.diffusers)
if self.model_exists(model_name, SDModelType.Diffusers):
self.del_model(model_name, SDModelType.Diffusers)
self.add_model(
model_name,
SDModelType.diffusers,
SDModelType.Diffusers,
new_config,
True
)
@@ -263,7 +263,7 @@ export const parseNodeMetadata = (
return;
}
if ('unet' in nodeItem && 'tokenizer' in nodeItem) {
if ('unet' in nodeItem && 'scheduler' in nodeItem) {
const unetField = parseUNetField(nodeItem);
if (unetField) {
parsed[nodeKey] = unetField;