mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement StALKeR7779 requested API for fetching submodels
This commit is contained in:
parent
fd63e36822
commit
c15b49c805
@ -19,10 +19,9 @@ context. Use like this:
|
|||||||
import contextlib
|
import contextlib
|
||||||
import gc
|
import gc
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from enum import Enum
|
from enum import Enum,auto
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from psutil import Process
|
from psutil import Process
|
||||||
from typing import Dict, Sequence, Union, Tuple, types
|
from typing import Dict, Sequence, Union, Tuple, types
|
||||||
@ -52,9 +51,15 @@ DEFAULT_MAX_CACHE_SIZE = 6.0
|
|||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
|
|
||||||
# This is the mapping from the stable diffusion submodel dict key to the class
|
# This is the mapping from the stable diffusion submodel dict key to the class
|
||||||
|
class LoraType(dict):
|
||||||
|
pass
|
||||||
|
class TIType(dict):
|
||||||
|
pass
|
||||||
|
class CkptType(dict):
|
||||||
|
pass
|
||||||
|
|
||||||
class SDModelType(Enum):
|
class SDModelType(Enum):
|
||||||
diffusion_pipeline=StableDiffusionGeneratorPipeline # whole thing
|
diffusers=StableDiffusionGeneratorPipeline # whole pipeline
|
||||||
diffusers=StableDiffusionGeneratorPipeline # same thing, different name
|
|
||||||
vae=AutoencoderKL # diffusers parts
|
vae=AutoencoderKL # diffusers parts
|
||||||
text_encoder=CLIPTextModel
|
text_encoder=CLIPTextModel
|
||||||
tokenizer=CLIPTokenizer
|
tokenizer=CLIPTokenizer
|
||||||
@ -62,10 +67,11 @@ class SDModelType(Enum):
|
|||||||
scheduler=SchedulerMixin
|
scheduler=SchedulerMixin
|
||||||
safety_checker=StableDiffusionSafetyChecker
|
safety_checker=StableDiffusionSafetyChecker
|
||||||
feature_extractor=CLIPFeatureExtractor
|
feature_extractor=CLIPFeatureExtractor
|
||||||
# These are all loaded as dicts of tensors
|
# These are all loaded as dicts of tensors, and we
|
||||||
lora=dict
|
# distinguish them by class
|
||||||
textual_inversion=dict
|
lora=LoraType
|
||||||
ckpt=dict
|
textual_inversion=TIType
|
||||||
|
ckpt=CkptType
|
||||||
|
|
||||||
class ModelStatus(Enum):
|
class ModelStatus(Enum):
|
||||||
unknown='unknown'
|
unknown='unknown'
|
||||||
@ -78,17 +84,16 @@ class ModelStatus(Enum):
|
|||||||
# After loading, we will know it exactly.
|
# After loading, we will know it exactly.
|
||||||
# Sizes are in Gigs, estimated for float16; double for float32
|
# Sizes are in Gigs, estimated for float16; double for float32
|
||||||
SIZE_GUESSTIMATE = {
|
SIZE_GUESSTIMATE = {
|
||||||
SDModelType.diffusion_pipeline: 2.5,
|
|
||||||
SDModelType.diffusers: 2.5,
|
SDModelType.diffusers: 2.5,
|
||||||
SDModelType.vae: 0.35,
|
SDModelType.vae: 0.35,
|
||||||
SDModelType.text_encoder: 0.5,
|
SDModelType.text_encoder: 0.5,
|
||||||
SDModelType.tokenizer: 0.0001,
|
SDModelType.tokenizer: 0.001,
|
||||||
SDModelType.unet: 3.4,
|
SDModelType.unet: 3.4,
|
||||||
SDModelType.scheduler: 0.0001,
|
SDModelType.scheduler: 0.001,
|
||||||
SDModelType.safety_checker: 1.2,
|
SDModelType.safety_checker: 1.2,
|
||||||
SDModelType.feature_extractor: 0.0001,
|
SDModelType.feature_extractor: 0.001,
|
||||||
SDModelType.lora: 0.1,
|
SDModelType.lora: 0.1,
|
||||||
SDModelType.textual_inversion: 0.0001,
|
SDModelType.textual_inversion: 0.001,
|
||||||
SDModelType.ckpt: 4.2,
|
SDModelType.ckpt: 4.2,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,7 +157,7 @@ class ModelCache(object):
|
|||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
repo_id_or_path: Union[str,Path],
|
repo_id_or_path: Union[str,Path],
|
||||||
model_type: SDModelType=SDModelType.diffusion_pipeline,
|
model_type: SDModelType=SDModelType.diffusers,
|
||||||
subfolder: Path=None,
|
subfolder: Path=None,
|
||||||
submodel: SDModelType=None,
|
submodel: SDModelType=None,
|
||||||
revision: str=None,
|
revision: str=None,
|
||||||
@ -263,7 +268,7 @@ class ModelCache(object):
|
|||||||
self.current_cache_size += usage.mem_used # increment size of the cache
|
self.current_cache_size += usage.mem_used # increment size of the cache
|
||||||
|
|
||||||
# this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
|
# this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
|
||||||
if model_type==SDModelType.diffusion_pipeline and attach_model_part[0]:
|
if model_type==SDModelType.diffusers and attach_model_part[0]:
|
||||||
self.attach_part(model,*attach_model_part)
|
self.attach_part(model,*attach_model_part)
|
||||||
|
|
||||||
self.stack.append(key) # add to LRU cache
|
self.stack.append(key) # add to LRU cache
|
||||||
@ -301,8 +306,10 @@ class ModelCache(object):
|
|||||||
cache.locked_models[key] += 1
|
cache.locked_models[key] += 1
|
||||||
if cache.lazy_offloading:
|
if cache.lazy_offloading:
|
||||||
cache._offload_unlocked_models()
|
cache._offload_unlocked_models()
|
||||||
cache.logger.debug(f'Loading {key} into {cache.execution_device}')
|
if model.device != cache.execution_device:
|
||||||
|
cache.logger.debug(f'Moving {key} into {cache.execution_device}')
|
||||||
model.to(cache.execution_device) # move into GPU
|
model.to(cache.execution_device) # move into GPU
|
||||||
|
cache.logger.debug(f'Locking {key} in {cache.execution_device}')
|
||||||
cache._print_cuda_stats()
|
cache._print_cuda_stats()
|
||||||
else:
|
else:
|
||||||
# in the event that the caller wants the model in RAM, we
|
# in the event that the caller wants the model in RAM, we
|
||||||
@ -345,7 +352,7 @@ class ModelCache(object):
|
|||||||
|
|
||||||
def status(self,
|
def status(self,
|
||||||
repo_id_or_path: Union[str,Path],
|
repo_id_or_path: Union[str,Path],
|
||||||
model_type: SDModelType=SDModelType.diffusion_pipeline,
|
model_type: SDModelType=SDModelType.diffusers,
|
||||||
revision: str=None,
|
revision: str=None,
|
||||||
subfolder: Path=None,
|
subfolder: Path=None,
|
||||||
)->ModelStatus:
|
)->ModelStatus:
|
||||||
@ -428,7 +435,7 @@ class ModelCache(object):
|
|||||||
def _make_cache_room(self, key, model_type):
|
def _make_cache_room(self, key, model_type):
|
||||||
# calculate how much memory this model will require
|
# calculate how much memory this model will require
|
||||||
multiplier = 2 if self.precision==torch.float32 else 1
|
multiplier = 2 if self.precision==torch.float32 else 1
|
||||||
bytes_needed = int(self.model_sizes.get(key,0) or SIZE_GUESSTIMATE[model_type]*GIG*multiplier)
|
bytes_needed = int(self.model_sizes.get(key,0) or SIZE_GUESSTIMATE.get(model_type,0.5)*GIG*multiplier)
|
||||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||||
current_size = self.current_cache_size
|
current_size = self.current_cache_size
|
||||||
|
|
||||||
@ -473,7 +480,7 @@ class ModelCache(object):
|
|||||||
# silence transformer and diffuser warnings
|
# silence transformer and diffuser warnings
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
if self.is_legacy_ckpt(repo_id_or_path):
|
if self.is_legacy_ckpt(repo_id_or_path):
|
||||||
model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
|
model = model_class(self._load_ckpt_from_storage(repo_id_or_path, legacy_info))
|
||||||
else:
|
else:
|
||||||
model = self._load_diffusers_from_storage(
|
model = self._load_diffusers_from_storage(
|
||||||
repo_id_or_path,
|
repo_id_or_path,
|
||||||
|
@ -20,18 +20,37 @@ return a SDModelInfo object that contains the following attributes:
|
|||||||
Typical usage:
|
Typical usage:
|
||||||
|
|
||||||
from invokeai.backend import ModelManager
|
from invokeai.backend import ModelManager
|
||||||
manager = ModelManager(config_path='./configs/models.yaml',max_models=4)
|
|
||||||
|
manager = ModelManager(
|
||||||
|
config='./configs/models.yaml',
|
||||||
|
max_cache_size=8
|
||||||
|
) # gigabytes
|
||||||
|
|
||||||
model_info = manager.get_model('stable-diffusion-1.5')
|
model_info = manager.get_model('stable-diffusion-1.5')
|
||||||
with model_info.context as my_model:
|
with model_info.context as my_model:
|
||||||
my_model.latents_from_embeddings(...)
|
my_model.latents_from_embeddings(...)
|
||||||
|
|
||||||
The manager uses the underlying ModelCache class to keep
|
The manager uses the underlying ModelCache class to keep
|
||||||
frequently-used models in RAM and move them into GPU as needed for
|
frequently-used models in RAM and move them into GPU as needed for
|
||||||
generation operations. The ModelCache object can be accessed using
|
generation operations. The optional `max_cache_size` argument
|
||||||
the manager's "cache" attribute.
|
indicates the maximum size the cache can grow to, in gigabytes. The
|
||||||
|
underlying ModelCache object can be accessed using the manager's "cache"
|
||||||
|
attribute.
|
||||||
|
|
||||||
Other methods provided by ModelManager support importing, editing,
|
Because the model manager can return multiple different types of
|
||||||
converting and deleting models.
|
models, you may wish to add additional type checking on the class
|
||||||
|
of model returned. To do this, provide the option `model_type`
|
||||||
|
parameter:
|
||||||
|
|
||||||
|
model_info = manager.get_model(
|
||||||
|
'clip-tokenizer',
|
||||||
|
model_type=SDModelType.tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
This will raise an InvalidModelError if the format defined in the
|
||||||
|
config file doesn't match the requested model type.
|
||||||
|
|
||||||
|
MODELS.YAML
|
||||||
|
|
||||||
The general format of a models.yaml section is:
|
The general format of a models.yaml section is:
|
||||||
|
|
||||||
@ -40,7 +59,6 @@ The general format of a models.yaml section is:
|
|||||||
repo_id: owner/repo
|
repo_id: owner/repo
|
||||||
path: /path/to/local/file/or/directory
|
path: /path/to/local/file/or/directory
|
||||||
subfolder: subfolder-name
|
subfolder: subfolder-name
|
||||||
submodel: vae|text_encoder|tokenizer...
|
|
||||||
|
|
||||||
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
|
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
|
||||||
unet, scheduler, safety_checker, feature_extractor}, and correspond to
|
unet, scheduler, safety_checker, feature_extractor}, and correspond to
|
||||||
@ -54,11 +72,7 @@ If subfolder is provided, then the model exists in a subdirectory of
|
|||||||
the main model. These are usually named after the model type, such as
|
the main model. These are usually named after the model type, such as
|
||||||
"unet".
|
"unet".
|
||||||
|
|
||||||
Finally, if submodel is provided, then the path/repo_id is treated as
|
This example summarizes the two ways of getting a non-diffuser model:
|
||||||
a diffusers model, the whole thing is ready into memory, and then the
|
|
||||||
requested part (e.g. "unet") is retrieved.
|
|
||||||
|
|
||||||
This summarizes the three ways of getting a non-diffuser model:
|
|
||||||
|
|
||||||
clip-test-1:
|
clip-test-1:
|
||||||
format: text_encoder
|
format: text_encoder
|
||||||
@ -66,21 +80,48 @@ This summarizes the three ways of getting a non-diffuser model:
|
|||||||
description: Returns standalone CLIPTextModel
|
description: Returns standalone CLIPTextModel
|
||||||
|
|
||||||
clip-test-2:
|
clip-test-2:
|
||||||
format: diffusers
|
|
||||||
repo_id: stabilityai/stable-diffusion-2
|
|
||||||
submodel: text_encoder
|
|
||||||
description: Returns the text_encoder part of whole diffusers model (whole thing in RAM)
|
|
||||||
|
|
||||||
clip-test-3:
|
|
||||||
format: text_encoder
|
format: text_encoder
|
||||||
repo_id: stabilityai/stable-diffusion-2
|
repo_id: stabilityai/stable-diffusion-2
|
||||||
subfolder: text_encoder
|
subfolder: text_encoder
|
||||||
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
||||||
|
|
||||||
clip-token:
|
SUBMODELS:
|
||||||
|
|
||||||
|
It is also possible to fetch an isolated submodel from a diffusers
|
||||||
|
model. Use the `submodel` parameter to select which part:
|
||||||
|
|
||||||
|
vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.vae)
|
||||||
|
with vae.context as my_vae:
|
||||||
|
print(type(my_vae))
|
||||||
|
# "AutoencoderKL"
|
||||||
|
|
||||||
|
DISAMBIGUATION:
|
||||||
|
|
||||||
|
You may wish to use the same name for a related family of models. To
|
||||||
|
do this, disambiguate the stanza key with the model and and format
|
||||||
|
separated by "/". Example:
|
||||||
|
|
||||||
|
clip-large/tokenizer:
|
||||||
format: tokenizer
|
format: tokenizer
|
||||||
repo_id: openai/clip-vit-large-patch14
|
repo_id: openai/clip-vit-large-patch14
|
||||||
description: Returns standalone tokenizer
|
description: Returns standalone tokenizer
|
||||||
|
|
||||||
|
clip-large/text_encoder:
|
||||||
|
format: text_encoder
|
||||||
|
repo_id: openai/clip-vit-large-patch14
|
||||||
|
description: Returns standalone text encoder
|
||||||
|
|
||||||
|
You can now use the `model_type` argument to indicate which model you
|
||||||
|
want:
|
||||||
|
|
||||||
|
tokenizer = mgr.get('clip-large',model_type=SDModelType.tokenizer)
|
||||||
|
encoder = mgr.get('clip-large',model_type=SDModelType.text_encoder)
|
||||||
|
|
||||||
|
OTHER FUNCTIONS:
|
||||||
|
|
||||||
|
Other methods provided by ModelManager support importing, editing,
|
||||||
|
converting and deleting models.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -152,7 +193,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config_path: Path,
|
config: Union[Path, DictConfig, str],
|
||||||
device_type: torch.device = CUDA_DEVICE,
|
device_type: torch.device = CUDA_DEVICE,
|
||||||
precision: torch.dtype = torch.float16,
|
precision: torch.dtype = torch.float16,
|
||||||
max_cache_size=MAX_CACHE_SIZE,
|
max_cache_size=MAX_CACHE_SIZE,
|
||||||
@ -165,8 +206,15 @@ class ModelManager(object):
|
|||||||
and sequential_offload boolean. Note that the default device
|
and sequential_offload boolean. Note that the default device
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
"""
|
"""
|
||||||
self.config_path = config_path
|
if isinstance(config, DictConfig):
|
||||||
|
self.config = config
|
||||||
|
self.config_path = None
|
||||||
|
elif type(config) in [str,DictConfig]:
|
||||||
|
self.config_path = config
|
||||||
self.config = OmegaConf.load(self.config_path)
|
self.config = OmegaConf.load(self.config_path)
|
||||||
|
else:
|
||||||
|
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||||
|
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
max_cache_size=max_cache_size,
|
max_cache_size=max_cache_size,
|
||||||
execution_device = device_type,
|
execution_device = device_type,
|
||||||
@ -185,28 +233,64 @@ class ModelManager(object):
|
|||||||
return model_name in self.config
|
return model_name in self.config
|
||||||
|
|
||||||
def get_model(self,
|
def get_model(self,
|
||||||
model_name: str = None,
|
model_name: str,
|
||||||
|
model_type: SDModelType=None,
|
||||||
submodel: SDModelType=None,
|
submodel: SDModelType=None,
|
||||||
) -> SDModelInfo:
|
) -> SDModelInfo:
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
an SDModelInfo object describing it.
|
an SDModelInfo object describing it.
|
||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
|
:param model_type: SDModelType enum indicating the type of model to return
|
||||||
:param submodel: an SDModelType enum indicating the portion of
|
:param submodel: an SDModelType enum indicating the portion of
|
||||||
the model to retrieve (e.g. SDModelType.vae)
|
the model to retrieve (e.g. SDModelType.vae)
|
||||||
|
|
||||||
|
If not provided, the model_type will be read from the `format` field
|
||||||
|
of the corresponding stanza. If provided, the model_type will be used
|
||||||
|
to disambiguate stanzas in the configuration file. The default is to
|
||||||
|
assume a diffusers pipeline. The behavior is illustrated here:
|
||||||
|
|
||||||
|
[models.yaml]
|
||||||
|
test1/diffusers:
|
||||||
|
repo_id: foo/bar
|
||||||
|
format: diffusers
|
||||||
|
description: Typical diffusers pipeline
|
||||||
|
|
||||||
|
test1/lora:
|
||||||
|
repo_id: /tmp/loras/test1.safetensors
|
||||||
|
format: lora
|
||||||
|
description: Typical lora file
|
||||||
|
|
||||||
|
test1_pipeline = mgr.get_model('test1')
|
||||||
|
# returns a StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
|
test1_vae1 = mgr.get_model('test1',submodel=SDModelType.vae)
|
||||||
|
# returns the VAE part of a diffusers model as an AutoencoderKL
|
||||||
|
|
||||||
|
test1_vae2 = mgr.get_model('test1',model_type=SDModelType.diffusers,submodel=SDModelType.vae)
|
||||||
|
# does the same thing as the previous statement. Note that model_type
|
||||||
|
# is for the parent model, and submodel is for the part
|
||||||
|
|
||||||
|
test1_lora = mgr.get_model('test1',model_type=SDModelType.lora)
|
||||||
|
# returns a LoRA embed (as a 'dict' of tensors)
|
||||||
|
|
||||||
|
test1_encoder = mgr.get_modelI('test1',model_type=SDModelType.textencoder)
|
||||||
|
# raises an InvalidModelError
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not model_name:
|
if not model_name:
|
||||||
model_name = self.default_model()
|
model_name = self.default_model()
|
||||||
|
|
||||||
if not self.valid_model(model_name):
|
model_key = self._disambiguate_name(model_name, model_type)
|
||||||
raise InvalidModelError(
|
|
||||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
|
||||||
)
|
|
||||||
|
|
||||||
# get the required loading info out of the config file
|
# get the required loading info out of the config file
|
||||||
mconfig = self.config[model_name]
|
mconfig = self.config[model_key]
|
||||||
|
|
||||||
format = mconfig.get('format','diffusers')
|
format = mconfig.get('format','diffusers')
|
||||||
model_type = SDModelType.diffusion_pipeline
|
if model_type and model_type.name != format:
|
||||||
|
raise InvalidModelError(
|
||||||
|
f'Inconsistent model definition; {model_key} has format {format}, but type {model_type.name} was requested'
|
||||||
|
)
|
||||||
|
|
||||||
model_parts = dict([(x.name,x) for x in SDModelType])
|
model_parts = dict([(x.name,x) for x in SDModelType])
|
||||||
legacy = None
|
legacy = None
|
||||||
|
|
||||||
@ -219,16 +303,14 @@ class ModelManager(object):
|
|||||||
legacy.vae_file = global_resolve_path(mconfig.vae)
|
legacy.vae_file = global_resolve_path(mconfig.vae)
|
||||||
elif format=='diffusers':
|
elif format=='diffusers':
|
||||||
location = mconfig.get('repo_id') or mconfig.get('path')
|
location = mconfig.get('repo_id') or mconfig.get('path')
|
||||||
if sm := mconfig.get('submodel'):
|
|
||||||
submodel = model_parts[sm]
|
|
||||||
elif format in model_parts:
|
elif format in model_parts:
|
||||||
location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
|
location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
|
||||||
model_type = model_parts[format]
|
|
||||||
else:
|
else:
|
||||||
raise InvalidModelError(
|
raise InvalidModelError(
|
||||||
f'"{model_name}" has an unknown format {format}'
|
f'"{model_key}" has an unknown format {format}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_type = model_parts[format]
|
||||||
subfolder = mconfig.get('subfolder')
|
subfolder = mconfig.get('subfolder')
|
||||||
revision = mconfig.get('revision')
|
revision = mconfig.get('revision')
|
||||||
hash = self.cache.model_hash(location,revision)
|
hash = self.cache.model_hash(location,revision)
|
||||||
@ -254,7 +336,7 @@ class ModelManager(object):
|
|||||||
# in case we need to communicate information about this
|
# in case we need to communicate information about this
|
||||||
# model to the cache manager, then we need to remember
|
# model to the cache manager, then we need to remember
|
||||||
# the cache key
|
# the cache key
|
||||||
self.cache_keys[model_name] = model_context.key
|
self.cache_keys[model_key] = model_context.key
|
||||||
|
|
||||||
return SDModelInfo(
|
return SDModelInfo(
|
||||||
context = model_context,
|
context = model_context,
|
||||||
@ -449,18 +531,20 @@ class ModelManager(object):
|
|||||||
else:
|
else:
|
||||||
assert "weights" in model_attributes and "description" in model_attributes
|
assert "weights" in model_attributes and "description" in model_attributes
|
||||||
|
|
||||||
|
model_key = f'{model_name}/{format}'
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
clobber or model_name not in omega
|
clobber or model_key not in omega
|
||||||
), f'attempt to overwrite existing model definition "{model_name}"'
|
), f'attempt to overwrite existing model definition "{model_key}"'
|
||||||
|
|
||||||
omega[model_name] = model_attributes
|
omega[model_key] = model_attributes
|
||||||
|
|
||||||
if "weights" in omega[model_name]:
|
if "weights" in omega[model_key]:
|
||||||
omega[model_name]["weights"].replace("\\", "/")
|
omega[model_key]["weights"].replace("\\", "/")
|
||||||
|
|
||||||
if clobber and model_name in self.cache_keys:
|
if clobber and model_key in self.cache_keys:
|
||||||
self.cache.uncache_model(self.cache_keys[model_name])
|
self.cache.uncache_model(self.cache_keys[model_key])
|
||||||
del self.cache_keys[model_name]
|
del self.cache_keys[model_key]
|
||||||
|
|
||||||
def import_diffuser_model(
|
def import_diffuser_model(
|
||||||
self,
|
self,
|
||||||
@ -482,6 +566,7 @@ class ModelManager(object):
|
|||||||
models.yaml file.
|
models.yaml file.
|
||||||
"""
|
"""
|
||||||
model_name = model_name or Path(repo_or_path).stem
|
model_name = model_name or Path(repo_or_path).stem
|
||||||
|
model_key = f'{model_name}/diffusers'
|
||||||
model_description = description or f"Imported diffusers model {model_name}"
|
model_description = description or f"Imported diffusers model {model_name}"
|
||||||
new_config = dict(
|
new_config = dict(
|
||||||
description=model_description,
|
description=model_description,
|
||||||
@ -493,10 +578,10 @@ class ModelManager(object):
|
|||||||
else:
|
else:
|
||||||
new_config.update(repo_id=repo_or_path)
|
new_config.update(repo_id=repo_or_path)
|
||||||
|
|
||||||
self.add_model(model_name, new_config, True)
|
self.add_model(model_key, new_config, True)
|
||||||
if commit_to_conf:
|
if commit_to_conf:
|
||||||
self.commit(commit_to_conf)
|
self.commit(commit_to_conf)
|
||||||
return model_name
|
return model_key
|
||||||
|
|
||||||
def import_lora(
|
def import_lora(
|
||||||
self,
|
self,
|
||||||
@ -511,7 +596,7 @@ class ModelManager(object):
|
|||||||
path = Path(path)
|
path = Path(path)
|
||||||
model_name = model_name or path.stem
|
model_name = model_name or path.stem
|
||||||
model_description = description or f"LoRA model {model_name}"
|
model_description = description or f"LoRA model {model_name}"
|
||||||
self.add_model(model_name,
|
self.add_model(f'{model_name}/{SDModelType.lora.name}',
|
||||||
dict(
|
dict(
|
||||||
format="lora",
|
format="lora",
|
||||||
weights=str(path),
|
weights=str(path),
|
||||||
@ -538,7 +623,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
model_name = model_name or path.stem
|
model_name = model_name or path.stem
|
||||||
model_description = description or f"Textual embedding model {model_name}"
|
model_description = description or f"Textual embedding model {model_name}"
|
||||||
self.add_model(model_name,
|
self.add_model(f'{model_name}/{SDModelType.textual_inversion.name}',
|
||||||
dict(
|
dict(
|
||||||
format="textual_inversion",
|
format="textual_inversion",
|
||||||
weights=str(weights),
|
weights=str(weights),
|
||||||
@ -871,6 +956,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
yaml_str = OmegaConf.to_yaml(self.config)
|
yaml_str = OmegaConf.to_yaml(self.config)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
|
assert config_file_path is not None,'no config file path to write to'
|
||||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||||
outfile.write(self.preamble())
|
outfile.write(self.preamble())
|
||||||
@ -893,6 +979,18 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
|
||||||
|
model_type = model_type or SDModelType.diffusers
|
||||||
|
full_name = f"{model_name}/{model_type.name}"
|
||||||
|
if self.valid_model(full_name):
|
||||||
|
return full_name
|
||||||
|
if self.valid_model(model_name):
|
||||||
|
return model_name
|
||||||
|
raise InvalidModelError(
|
||||||
|
f'Neither "{model_name}" nor "{full_name}" are known model names. Please check your models.yaml file'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _delete_model_from_cache(cls,repo_id):
|
def _delete_model_from_cache(cls,repo_id):
|
||||||
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
||||||
|
@ -40,7 +40,7 @@ class DMType(Enum):
|
|||||||
type1 = DummyModelType1
|
type1 = DummyModelType1
|
||||||
type2 = DummyModelType2
|
type2 = DummyModelType2
|
||||||
|
|
||||||
cache = ModelCache(max_models=4)
|
cache = ModelCache(max_cache_size=4)
|
||||||
|
|
||||||
def test_pipeline_fetch():
|
def test_pipeline_fetch():
|
||||||
assert cache.cache_size()==0
|
assert cache.cache_size()==0
|
||||||
@ -53,12 +53,10 @@ def test_pipeline_fetch():
|
|||||||
assert type(pipeline1)==DMType.dummy_pipeline.value,'get_model() did not return model of expected type'
|
assert type(pipeline1)==DMType.dummy_pipeline.value,'get_model() did not return model of expected type'
|
||||||
assert pipeline1==pipeline1a,'pipelines with the same repo_id should be the same'
|
assert pipeline1==pipeline1a,'pipelines with the same repo_id should be the same'
|
||||||
assert pipeline1!=pipeline2,'pipelines with different repo_ids should not be the same'
|
assert pipeline1!=pipeline2,'pipelines with different repo_ids should not be the same'
|
||||||
assert cache.cache_size()==2,'cache should uniquely cache models with same identity'
|
assert len(cache.models)==2,'cache should uniquely cache models with same identity'
|
||||||
with cache.get_model('dummy/pipeline3',DMType.dummy_pipeline) as pipeline3,\
|
with cache.get_model('dummy/pipeline3',DMType.dummy_pipeline) as pipeline3,\
|
||||||
cache.get_model('dummy/pipeline4',DMType.dummy_pipeline) as pipeline4:
|
cache.get_model('dummy/pipeline4',DMType.dummy_pipeline) as pipeline4:
|
||||||
assert cache.cache_size()==4,'cache did not grow as expected'
|
assert len(cache.models)==4,'cache did not grow as expected'
|
||||||
with cache.get_model('dummy/pipeline5',DMType.dummy_pipeline) as pipeline5:
|
|
||||||
assert cache.cache_size()==4,'cache did not free space as expected'
|
|
||||||
|
|
||||||
def test_signatures():
|
def test_signatures():
|
||||||
with cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main') as pipeline1,\
|
with cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main') as pipeline1,\
|
||||||
|
Loading…
Reference in New Issue
Block a user