Change SDModelType enum to string; fixes: model unload negative locks count, scheduler load error, safetensors convert, wrong logic in del_model, wrong parse metadata in web

This commit is contained in:
Sergey Borisov
2023-05-14 03:06:26 +03:00
parent 2204e47596
commit 039fa73269
8 changed files with 388 additions and 363 deletions

View File

@ -23,12 +23,12 @@ import warnings
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import Dict, Sequence, Union, Tuple, types
from typing import Dict, Sequence, Union, Tuple, types, Optional
import torch
import safetensors.torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel, ConfigMixin
from diffusers import logging as diffusers_logging
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
@ -55,20 +55,38 @@ class LoraType(dict):
class TIType(dict):
pass
class SDModelType(Enum):
diffusers=StableDiffusionGeneratorPipeline # whole pipeline
vae=AutoencoderKL # diffusers parts
text_encoder=CLIPTextModel
tokenizer=CLIPTokenizer
unet=UNet2DConditionModel
scheduler=SchedulerMixin
safety_checker=StableDiffusionSafetyChecker
feature_extractor=CLIPFeatureExtractor
class SDModelType(str, Enum):
Diffusers="diffusers" # whole pipeline
Vae="vae" # diffusers parts
TextEncoder="text_encoder"
Tokenizer="tokenizer"
UNet="unet"
Scheduler="scheduler"
SafetyChecker="safety_checker"
FeatureExtractor="feature_extractor"
# These are all loaded as dicts of tensors, and we
# distinguish them by class
lora=LoraType
textual_inversion=TIType
Lora="lora"
TextualInversion="textual_inversion"
# TODO:
class EmptyScheduler(SchedulerMixin, ConfigMixin):
pass
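# (Presumably the "scheduler load error" fix from the commit message: SchedulerMixin
#  on its own lacks the ConfigMixin machinery that from_pretrained() relies on, so this
#  concrete stub gives the cache a loadable class for the scheduler subfolder.)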
MODEL_CLASSES = {
SDModelType.Diffusers: StableDiffusionGeneratorPipeline,
SDModelType.Vae: AutoencoderKL,
SDModelType.TextEncoder: CLIPTextModel, # TODO: t5
SDModelType.Tokenizer: CLIPTokenizer, # TODO: t5
SDModelType.UNet: UNet2DConditionModel,
SDModelType.Scheduler: EmptyScheduler,
SDModelType.SafetyChecker: StableDiffusionSafetyChecker,
SDModelType.FeatureExtractor: CLIPFeatureExtractor,
SDModelType.Lora: LoraType,
SDModelType.TextualInversion: TIType,
}
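# A minimal sketch (illustrative, not part of this file) of what the str mixin buys:
# members compare equal to their plain string values and can be looked up by value,
# which the cache keys and the web metadata now rely on.
assert SDModelType.Vae == "vae"
assert SDModelType("vae") is SDModelType.Vae
assert MODEL_CLASSES[SDModelType.Vae] is AutoencoderKL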
class ModelStatus(Enum):
unknown='unknown'
not_loaded='not loaded'
@ -80,21 +98,21 @@ class ModelStatus(Enum):
# After loading, we will know it exactly.
# Sizes are in Gigs, estimated for float16; double for float32
SIZE_GUESSTIMATE = {
SDModelType.diffusers: 2.2,
SDModelType.vae: 0.35,
SDModelType.text_encoder: 0.5,
SDModelType.tokenizer: 0.001,
SDModelType.unet: 3.4,
SDModelType.scheduler: 0.001,
SDModelType.safety_checker: 1.2,
SDModelType.feature_extractor: 0.001,
SDModelType.lora: 0.1,
SDModelType.textual_inversion: 0.001,
SDModelType.Diffusers: 2.2,
SDModelType.Vae: 0.35,
SDModelType.TextEncoder: 0.5,
SDModelType.Tokenizer: 0.001,
SDModelType.UNet: 3.4,
SDModelType.Scheduler: 0.001,
SDModelType.SafetyChecker: 1.2,
SDModelType.FeatureExtractor: 0.001,
SDModelType.Lora: 0.1,
SDModelType.TextualInversion: 0.001,
}
# The list of model classes we know how to fetch, for typechecking
ModelClass = Union[tuple([x.value for x in SDModelType])]
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)
ModelClass = Union[tuple([x for x in MODEL_CLASSES.values()])]
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, EmptyScheduler, UNet2DConditionModel)
class UnsafeModelException(Exception):
"Raised when a legacy model file fails the picklescan test"
@ -110,15 +128,15 @@ class ModelLocker(object):
class ModelCache(object):
def __init__(
self,
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
execution_device: torch.device=torch.device('cuda'),
storage_device: torch.device=torch.device('cpu'),
precision: torch.dtype=torch.float16,
sequential_offload: bool=False,
lazy_offloading: bool=True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger
self,
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
execution_device: torch.device=torch.device('cuda'),
storage_device: torch.device=torch.device('cpu'),
precision: torch.dtype=torch.float16,
sequential_offload: bool=False,
lazy_offloading: bool=True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger
):
'''
:param max_cache_size: Maximum size of the model RAM cache, in GB
@ -145,15 +163,15 @@ class ModelCache(object):
self.model_sizes: Dict[str,int] = dict()
def get_model(
self,
repo_id_or_path: Union[str,Path],
model_type: SDModelType=SDModelType.diffusers,
subfolder: Path=None,
submodel: SDModelType=None,
revision: str=None,
attach_model_part: Tuple[SDModelType, str] = (None,None),
gpu_load: bool=True,
)->ModelLocker: # ?? what does it return
self,
repo_id_or_path: Union[str, Path],
model_type: SDModelType = SDModelType.Diffusers,
subfolder: Path = None,
submodel: SDModelType = None,
revision: str = None,
attach_model_part: Tuple[SDModelType, str] = (None, None),
gpu_load: bool = True,
) -> ModelLocker: # ?? what does it return
'''
Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
Use like this:
@ -178,14 +196,14 @@ class ModelCache(object):
vae_context = cache.get_model(
'stabilityai/sd-stable-diffusion-2',
submodel=SDModelType.vae
submodel=SDModelType.Vae
)
This is equivalent to:
vae_context = cache.get_model(
'stabilityai/sd-stable-diffusion-2',
model_type = SDModelType.vae,
model_type = SDModelType.Vae,
subfolder='vae'
)
@ -195,14 +213,14 @@ class ModelCache(object):
pipeline_context = cache.get_model(
'runwayml/stable-diffusion-v1-5',
attach_model_part=(SDModelType.vae,'stabilityai/sd-vae-ft-mse')
attach_model_part=(SDModelType.Vae,'stabilityai/sd-vae-ft-mse')
)
The model will be locked into GPU VRAM for the duration of the context.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param model_type: An SDModelType enum indicating the type of the (parent) model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.Vae
:param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
:param revision: model revision
:param gpu_load: load the model into GPU [default True]
@ -211,7 +229,7 @@ class ModelCache(object):
repo_id_or_path,
revision,
subfolder,
model_type.value,
model_type,
)
# optimization: if caller is asking to load a submodel of a diffusers pipeline, then
@ -221,11 +239,11 @@ class ModelCache(object):
repo_id_or_path,
None,
revision,
SDModelType.diffusers.value
SDModelType.Diffusers
)
if possible_parent_key in self.models:
key = possible_parent_key
submodel=model_type
submodel = model_type
# Look for the model in the cache RAM
if key in self.models: # cached - move to bottom of stack (most recently used)
@ -256,24 +274,24 @@ class ModelCache(object):
self.current_cache_size += mem_used # increment size of the cache
# this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
if model_type==SDModelType.diffusers and attach_model_part[0]:
self.attach_part(model,*attach_model_part)
if model_type == SDModelType.Diffusers and attach_model_part[0]:
self.attach_part(model, *attach_model_part)
self.stack.append(key) # add to LRU cache
self.models[key]=model # keep copy of model in dict
self.models[key] = model # keep copy of model in dict
if submodel:
model = getattr(model, submodel.name)
model = getattr(model, submodel)
return self.ModelLocker(self, key, model, gpu_load)
def uncache_model(self, key: str):
'''Remove corresponding model from the cache'''
if key is not None and key in self.models:
with contextlib.suppress(ValueError), contextlib.suppress(KeyError):
del self.models[key]
del self.locked_models[key]
self.loaded_models.remove(key)
self.models.pop(key, None)
self.locked_models.pop(key, None)
self.loaded_models.discard(key)
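# dict.pop(key, None) and set.discard() are no-ops for missing keys, so only the
# list removal below still needs exception suppression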
with contextlib.suppress(ValueError):
self.stack.remove(key)
class ModelLocker(object):
@ -302,7 +320,7 @@ class ModelCache(object):
if model.device != cache.execution_device:
cache.logger.debug(f'Moving {key} into {cache.execution_device}')
with VRAMUsage() as mem:
model.to(cache.execution_device) # move into GPU
model.to(cache.execution_device, dtype=cache.precision) # move into GPU
cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
cache.model_sizes[key] = mem.vram_used # more accurate size
@ -312,13 +330,16 @@ class ModelCache(object):
else:
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
if hasattr(model,'to') and (key in cache.loaded_models
if hasattr(model, 'to') and (key in cache.loaded_models
and cache.locked_models[key] == 0):
model.to(cache.storage_device)
cache.loaded_models.remove(key)
return model
def __exit__(self, type, value, traceback):
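# Presumably the "negative locks count" fix: models without a .to() method
# (the lora / textual-inversion dicts) never take a GPU lock in __enter__,
# so they must not decrement the lock count on exit.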
if not hasattr(self.model, 'to'):
return
key = self.key
cache = self.cache
cache.locked_models[key] -= 1
@ -326,11 +347,12 @@ class ModelCache(object):
cache._offload_unlocked_models()
cache._print_cuda_stats()
def attach_part(self,
diffusers_model: StableDiffusionPipeline,
part_type: SDModelType,
part_id: str
):
def attach_part(
self,
diffusers_model: StableDiffusionPipeline,
part_type: SDModelType,
part_id: str,
):
'''
Attach a diffusers model part to a diffusers model. This can be
used to replace the VAE, tokenizer, textencoder, unet, etc.
@ -338,27 +360,26 @@ class ModelCache(object):
:param part_type: An SD ModelType indicating the part
:param part_id: A HF repo_id for the part
'''
part_key = part_type.name
part_class = part_type.value
part = self._load_diffusers_from_storage(
part_id,
model_class=part_class,
model_type=part_type,
)
part.to(diffusers_model.device)
setattr(diffusers_model,part_key,part)
self.logger.debug(f'Attached {part_key} {part_id}')
setattr(diffusers_model, part_type, part)
self.logger.debug(f'Attached {part_type} {part_id}')
def status(self,
repo_id_or_path: Union[str,Path],
model_type: SDModelType=SDModelType.diffusers,
revision: str=None,
subfolder: Path=None,
)->ModelStatus:
def status(
self,
repo_id_or_path: Union[str, Path],
model_type: SDModelType = SDModelType.Diffusers,
revision: str = None,
subfolder: Path = None,
) -> ModelStatus:
key = self._model_key(
repo_id_or_path,
revision,
subfolder,
model_type.value,
model_type,
)
if key not in self.models:
return ModelStatus.not_loaded
@ -370,9 +391,11 @@ class ModelCache(object):
else:
return ModelStatus.in_ram
def model_hash(self,
repo_id_or_path: Union[str,Path],
revision: str="main")->str:
def model_hash(
self,
repo_id_or_path: Union[str, Path],
revision: str = "main",
) -> str:
'''
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
@ -385,7 +408,7 @@ class ModelCache(object):
else:
return self._hf_commit_hash(repo_id_or_path,revision)
def cache_size(self)->float:
def cache_size(self) -> float:
"Return the current size of the cache, in GB"
return self.current_cache_size / GIG
@ -407,10 +430,15 @@ class ModelCache(object):
logger.debug("Model scanned ok")
@staticmethod
def _model_key(path,revision,subfolder,model_class)->str:
return ':'.join([str(path),str(revision or ''),str(subfolder or ''),model_class.__name__])
def _model_key(path, revision, subfolder, model_class) -> str:
return ':'.join([
str(path),
str(revision or ''),
str(subfolder or ''),
model_class,
])
def _has_cuda(self)->bool:
def _has_cuda(self) -> bool:
return self.execution_device.type == 'cuda'
def _print_cuda_stats(self):
@ -450,43 +478,43 @@ class ModelCache(object):
self.loaded_models.remove(key)
def _load_model_from_storage(
self,
repo_id_or_path: Union[str,Path],
subfolder: Path=None,
revision: str=None,
model_type: SDModelType=SDModelType.diffusers,
)->ModelClass:
self,
repo_id_or_path: Union[str, Path],
subfolder: Optional[Path] = None,
revision: Optional[str] = None,
model_type: SDModelType = SDModelType.Diffusers,
) -> ModelClass:
'''
Load and return a HuggingFace model.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param revision: model revision
:param model_type: type of model to return, defaults to SDModelType.diffusers
:param model_type: type of model to return, defaults to SDModelType.Diffusers
'''
# silence transformer and diffuser warnings
with SilenceWarnings():
if model_type==SDModelType.lora:
if model_type==SDModelType.Lora:
model = self._load_lora_from_storage(repo_id_or_path)
elif model_type==SDModelType.textual_inversion:
elif model_type==SDModelType.TextualInversion:
model = self._load_ti_from_storage(repo_id_or_path)
else:
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_type.value,
model_type,
)
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
if self.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.execution_device)
return model
def _load_diffusers_from_storage(
self,
repo_id_or_path: Union[str,Path],
subfolder: Path=None,
revision: str=None,
model_class: ModelClass=StableDiffusionGeneratorPipeline,
)->ModelClass:
self,
repo_id_or_path: Union[str, Path],
subfolder: Optional[Path] = None,
revision: Optional[str] = None,
model_type: SDModelType = SDModelType.Diffusers,
) -> ModelClass:
'''
Load and return a HuggingFace model using from_pretrained().
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
@ -494,17 +522,26 @@ class ModelCache(object):
:param revision: model revision
:param model_type: type of model to return, defaults to SDModelType.Diffusers
'''
revisions = [revision] if revision \
else ['fp16','main'] if self.precision==torch.float16 \
else ['main']
extra_args = {'torch_dtype': self.precision,
'safety_checker': None}\
if model_class in DiffusionClasses\
else {}
model_class = MODEL_CLASSES[model_type]
if revision is not None:
revisions = [revision]
elif self.precision == torch.float16:
revisions = ['fp16', 'main']
else:
revisions = ['main']
extra_args = dict()
if model_class in DiffusionClasses:
extra_args = dict(
torch_dtype=self.precision,
safety_checker=None,
)
for rev in revisions:
try:
model = model_class.from_pretrained(
model = model_class.from_pretrained(
repo_id_or_path,
revision=rev,
subfolder=subfolder or '.',
@ -517,13 +554,13 @@ class ModelCache(object):
pass
return model
def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
assert False,"_load_lora_from_storage() is not yet implemented"
def _load_lora_from_storage(self, lora_path: Path) -> LoraType:
assert False, "_load_lora_from_storage() is not yet implemented"
def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value:
assert False,"_load_ti_from_storage() is not yet implemented"
def _load_ti_from_storage(self, lora_path: Path) -> TIType:
assert False, "_load_ti_from_storage() is not yet implemented"
def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
def _legacy_model_hash(self, checkpoint_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(checkpoint_path)
assert path.is_file(),f"File {checkpoint_path} not found"
@ -544,7 +581,7 @@ class ModelCache(object):
f.write(hash)
return hash
def _local_model_hash(self, model_path: Union[str,Path])->str:
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(model_path)
@ -566,7 +603,7 @@ class ModelCache(object):
f.write(hash)
return hash
def _hf_commit_hash(self, repo_id: str, revision: str='main')->str:
def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str:
api = HfApi()
info = api.list_repo_refs(
repo_id=repo_id,
@ -578,7 +615,7 @@ class ModelCache(object):
return desired_revisions[0].target_commit
@staticmethod
def calc_model_size(model)->int:
def calc_model_size(model) -> int:
if isinstance(model,DiffusionPipeline):
return ModelCache._calc_pipeline(model)
elif isinstance(model,torch.nn.Module):
@ -587,7 +624,7 @@ class ModelCache(object):
return None
@staticmethod
def _calc_pipeline(pipeline)->int:
def _calc_pipeline(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
@ -596,7 +633,7 @@ class ModelCache(object):
return res
@staticmethod
def _calc_model(model)->int:
def _calc_model(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes

View File

@ -27,7 +27,7 @@ Typical usage:
max_cache_size=8
) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers)
with model_info.context as my_model:
my_model.latents_from_embeddings(...)
@ -45,7 +45,7 @@ parameter:
model_info = manager.get_model(
'clip-tokenizer',
model_type=SDModelType.tokenizer
model_type=SDModelType.Tokenizer
)
This will raise an InvalidModelError if the format defined in the
@ -96,7 +96,7 @@ SUBMODELS:
It is also possible to fetch an isolated submodel from a diffusers
model. Use the `submodel` parameter to select which part:
vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.vae)
vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae)
with vae.context as my_vae:
print(type(my_vae))
# "AutoencoderKL"
@ -120,8 +120,8 @@ separated by "/". Example:
You can now use the `model_type` argument to indicate which model you
want:
tokenizer = mgr.get('clip-large',model_type=SDModelType.tokenizer)
encoder = mgr.get('clip-large',model_type=SDModelType.text_encoder)
tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer)
encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder)
OTHER FUNCTIONS:
@ -254,7 +254,7 @@ class ModelManager(object):
def model_exists(
self,
model_name: str,
model_type: SDModelType = SDModelType.diffusers,
model_type: SDModelType = SDModelType.Diffusers,
) -> bool:
"""
Given a model name, returns True if it is a valid
@ -264,28 +264,28 @@ class ModelManager(object):
return model_key in self.config
def create_key(self, model_name: str, model_type: SDModelType) -> str:
return f"{model_type.name}/{model_name}"
return f"{model_type}/{model_name}"
def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
model_type_str, model_name = model_key.split('/', 1)
if model_type_str not in SDModelType.__members__:
# TODO:
try:
model_type = SDModelType(model_type_str)
return (model_name, model_type)
except:
raise Exception(f"Unknown model type: {model_type_str}")
return (model_name, SDModelType[model_type_str])
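# Illustrative round-trip under the new string-valued keys (standalone sketch;
# mirrors create_key()/parse_key() without needing a ModelManager instance):
key = f"{SDModelType.Diffusers.value}/stable-diffusion-1.5"   # "diffusers/stable-diffusion-1.5"
type_str, name = key.split('/', 1)
assert SDModelType(type_str) is SDModelType.Diffusers and name == "stable-diffusion-1.5"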
def get_model(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
) -> SDModelInfo:
"""Given a model named identified in models.yaml, return
an SDModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml
:param model_type: SDModelType enum indicating the type of model to return
:param submodel: an SDModelType enum indicating the portion of
the model to retrieve (e.g. SDModelType.vae)
the model to retrieve (e.g. SDModelType.Vae)
If not provided, the model_type will be read from the `format` field
of the corresponding stanza. If provided, the model_type will be used
@ -304,17 +304,17 @@ class ModelManager(object):
test1_pipeline = mgr.get_model('test1')
# returns a StableDiffusionGeneratorPipeline
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae)
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.Vae)
# returns the VAE part of a diffusers model as an AutoencoderKL
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae)
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.Diffusers, submodel=SDModelType.Vae)
# does the same thing as the previous statement. Note that model_type
# is for the parent model, and submodel is for the part
test1_lora = mgr.get_model('test1', model_type=SDModelType.lora)
test1_lora = mgr.get_model('test1', model_type=SDModelType.Lora)
# returns a LoRA embed (as a 'dict' of tensors)
test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder)
test1_encoder = mgr.get_model('test1', model_type=SDModelType.TextEncoder)
# raises an InvalidModelError
"""
@ -332,10 +332,10 @@ class ModelManager(object):
mconfig = self.config[model_key]
# type already checked as it's part of key
if model_type == SDModelType.diffusers:
if model_type == SDModelType.Diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if mconfig.format in ["ckpt", "diffusers"]:
if mconfig.format in ["ckpt", "safetensors"]:
location = self.convert_ckpt_and_cache(mconfig)
else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
@ -355,13 +355,13 @@ class ModelManager(object):
vae = (None, None)
with suppress(Exception):
vae_id = mconfig.vae.repo_id
vae = (SDModelType.vae, vae_id)
vae = (SDModelType.Vae, vae_id)
# optimization - don't load whole model if the user
# is asking for just a piece of it
if model_type == SDModelType.diffusers and submodel and not subfolder:
if model_type == SDModelType.Diffusers and submodel and not subfolder:
model_type = submodel
subfolder = submodel.name
subfolder = submodel.value
submodel = None
model_context = self.cache.get_model(
@ -390,7 +390,7 @@ class ModelManager(object):
_cache = self.cache
)
def default_model(self) -> Union[Tuple[str, SDModelType],None]:
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
@ -401,7 +401,7 @@ class ModelManager(object):
return (model_name, model_type)
return self.model_names()[0][0]
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.Diffusers) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
@ -415,25 +415,25 @@ class ModelManager(object):
config[self.create_key(model_name, model_type)]["default"] = True
def model_info(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers
self,
model_name: str,
model_type: SDModelType=SDModelType.Diffusers,
) -> dict:
"""
Given a model name returns the OmegaConf (dict-like) object describing it.
"""
if not self.exists(model_name, model_type):
return None
return self.config[self.create_key(model_name,model_type)]
return self.config[self.create_key(model_name, model_type)]
def model_names(self) -> List[Tuple[str, SDModelType]]:
"""
Return a list of (str, SDModelType) corresponding to all models
known to the configuration.
"""
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x],DictConfig)]
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]
def is_legacy(self, model_name: str, model_type: SDModelType.diffusers) -> bool:
def is_legacy(self, model_name: str, model_type: SDModelType = SDModelType.Diffusers) -> bool:
"""
Return true if this is a legacy (.ckpt) model
"""
@ -461,14 +461,14 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]:
continue
if model_key=='config_file_version':
if model_key == 'config_file_version':
continue
model_name, model_type = self.parse_key(model_key)
models[model_key] = dict()
# TODO: return all models in future
if model_type != SDModelType.diffusers:
if model_type != SDModelType.Diffusers:
continue
model_format = stanza.get('format')
@ -477,15 +477,15 @@ class ModelManager(object):
status = self.cache.status(
stanza.get('weights') or stanza.get('repo_id'),
revision=stanza.get('revision'),
subfolder=stanza.get('subfolder')
subfolder=stanza.get('subfolder'),
)
description = stanza.get("description", None)
models[model_key].update(
model_name=model_name,
model_type=model_type.name,
model_type=model_type,
format=model_format,
description=description,
status=status.value
status=status.value,
)
@ -528,8 +528,8 @@ class ModelManager(object):
def del_model(
self,
model_name: str,
model_type: SDModelType.diffusers,
delete_files: bool = False
model_type: SDModelType = SDModelType.Diffusers,
delete_files: bool = False,
):
"""
Delete the named model.
@ -539,9 +539,9 @@ class ModelManager(object):
if model_cfg is None:
self.logger.error(
f"Unknown model {model_key}"
)
return
f"Unknown model {model_key}"
)
return
# TODO: some legacy?
#if model_name in self.stack:
@ -571,7 +571,7 @@ class ModelManager(object):
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
clobber: bool = False,
) -> None:
"""
Update the named model with a dictionary of attributes. Will fail with an
@ -581,7 +581,7 @@ class ModelManager(object):
attributes are incorrect or the model name is missing.
"""
if model_type == SDModelType.diffusers:
if model_type == SDModelType.Diffusers:
# TODO: automatically or manually?
#assert "format" in model_attributes, 'missing required field "format"'
model_format = "ckpt" if "weights" in model_attributes else "diffusers"
@ -647,16 +647,16 @@ class ModelManager(object):
else:
new_config.update(repo_id=repo_or_path)
self.add_model(model_name, SDModelType.diffusers, new_config, True)
self.add_model(model_name, SDModelType.Diffusers, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
return self.create_key(model_name, SDModelType.diffusers)
return self.create_key(model_name, SDModelType.Diffusers)
def import_lora(
self,
path: Path,
model_name: str=None,
description: str=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call
@ -667,7 +667,7 @@ class ModelManager(object):
model_description = description or f"LoRA model {model_name}"
self.add_model(
model_name,
SDModelType.lora,
SDModelType.Lora,
dict(
format="lora",
weights=str(path),
@ -679,8 +679,8 @@ class ModelManager(object):
def import_embedding(
self,
path: Path,
model_name: str=None,
description: str=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated textual inversion embedding file. Call
@ -696,7 +696,7 @@ class ModelManager(object):
model_description = description or f"Textual embedding model {model_name}"
self.add_model(
model_name,
SDModelType.textual_inversion,
SDModelType.TextualInversion,
dict(
format="textual_inversion",
weights=str(weights),
@ -746,11 +746,11 @@ class ModelManager(object):
def heuristic_import(
self,
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
model_name: Optional[str] = None,
description: Optional[str] = None,
model_config_file: Optional[Path] = None,
commit_to_conf: Optional[Path] = None,
config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> str:
"""Accept a string which could be:
- a HF diffusers repo_id
@ -927,7 +927,7 @@ class ModelManager(object):
)
return model_name
def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path:
def convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
@ -961,7 +961,7 @@ class ModelManager(object):
self,
weights: Path,
mconfig: DictConfig
) -> Tuple[Path, SDModelType.vae]:
) -> Tuple[Path, AutoencoderKL]:
# VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert
@ -990,7 +990,7 @@ class ModelManager(object):
vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
if vae_diffusers_location:
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.vae).model
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model
return (None, vae_model)
return (None, None)
@ -1038,7 +1038,7 @@ class ModelManager(object):
vae_model = None
if vae:
vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
vae_model = self.cache.get_model(vae_location,SDModelType.vae).model
vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model
vae_path = None
convert_ckpt_to_diffusers(
ckpt_path,
@ -1058,11 +1058,11 @@ class ModelManager(object):
description=model_description,
format="diffusers",
)
if self.model_exists(model_name, SDModelType.diffusers):
self.del_model(model_name, SDModelType.diffusers)
if self.model_exists(model_name, SDModelType.Diffusers):
self.del_model(model_name, SDModelType.Diffusers)
self.add_model(
model_name,
SDModelType.diffusers,
SDModelType.Diffusers,
new_config,
True
)