Rewrite model configs, separate models

This commit is contained in:
Sergey Borisov 2023-06-11 04:49:09 +03:00
parent 2c056ead42
commit 3ce3a7ee72
9 changed files with 843 additions and 1034 deletions

View File

@ -30,7 +30,6 @@ from typing import Dict, Union, types, Optional, List, Type, Any
import torch import torch
import transformers import transformers
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
from diffusers import logging as diffusers_logging from diffusers import logging as diffusers_logging
from huggingface_hub import HfApi, scan_cache_dir from huggingface_hub import HfApi, scan_cache_dir
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
@ -40,7 +39,7 @@ from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel from .lora import LoRAModel, TextualInversionModel
from .models import MODEL_CLASSES from .models import BaseModelType, ModelType, SubModelType
# Maximum size of the cache, in gigs # Maximum size of the cache, in gigs
@ -129,11 +128,12 @@ class ModelCache(object):
def get_key( def get_key(
self, self,
model_path: str, model_path: str,
model_type: SDModelType, base_model: BaseModelType,
submodel_type: Optional[SDModelType] = None, model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
): ):
key = f"{model_path}:{model_type}" key = f"{model_path}:{base_model}:{model_type}"
if submodel_type: if submodel_type:
key += f":{submodel_type}" key += f":{submodel_type}"
return key return key
@ -152,9 +152,12 @@ class ModelCache(object):
self, self,
model_path: str, model_path: str,
model_class: Type[ModelBase], model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
): ):
model_info_key = self.get_key( model_info_key = self.get_key(
model_path=model_path, model_path=model_path,
base_model=base_model,
model_type=model_type, model_type=model_type,
submodel_type=None, submodel_type=None,
) )
@ -172,6 +175,8 @@ class ModelCache(object):
self, self,
model_path: Union[str, Path], model_path: Union[str, Path],
model_class: Type[ModelBase], model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
gpu_load: bool = True, gpu_load: bool = True,
) -> Any: ) -> Any:
@ -185,17 +190,20 @@ class ModelCache(object):
model_info = self._get_model_info( model_info = self._get_model_info(
model_path=model_path, model_path=model_path,
model_class=model_class, model_class=model_class,
base_model=base_model,
model_type=model_type,
) )
key = self.get_key( key = self.get_key(
model_path=model_path, model_path=model_path,
model_type=model_type, # TODO: base_model=base_model,
model_type=model_type,
submodel_type=submodel, submodel_type=submodel,
) )
# TODO: lock for no copies on simultaneous calls? # TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None) cache_entry = self._cached_models.get(key, None)
if cache_entry is None: if cache_entry is None:
self.logger.info(f'Loading model {model_path}, type {model_type}:{submodel}') self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
# this will remove older cached models until # this will remove older cached models until
# there is sufficient room to load the requested model # there is sufficient room to load the requested model
@ -203,7 +211,7 @@ class ModelCache(object):
# clean memory to make MemoryUsage() more accurate # clean memory to make MemoryUsage() more accurate
gc.collect() gc.collect()
model = model_info.get_model(submodel, torch_dtype=self.precision, variant=) model = model_info.get_model(submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel): if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB') self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')

View File

@ -221,6 +221,9 @@ MAX_CACHE_SIZE = 6.0 # GB
# └── realesrgan # └── realesrgan
class ConfigMeta(BaseModel):
version: str
class ModelManager(object): class ModelManager(object):
""" """
High-level interface to model management. High-level interface to model management.
@ -243,15 +246,24 @@ class ModelManager(object):
and sequential_offload boolean. Note that the default device and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision. type and precision are set up for a CUDA system running at half precision.
""" """
if isinstance(config, DictConfig):
self.config_path = None self.config_path = None
self.config = config if isinstance(config, (str, Path)):
elif isinstance(config,(str,Path)): self.config_path = Path(config)
self.config_path = config config = OmegaConf.load(self.config_path)
self.config = OmegaConf.load(self.config_path)
else: elif not isinstance(config, DictConfig):
raise ValueError('config argument must be an OmegaConf object, a Path or a string') raise ValueError('config argument must be an OmegaConf object, a Path or a string')
config_meta = ConfigMeta(config.pop("__metadata__")) # TODO: naming
# TODO: metadata not found
self.models = dict()
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
self.models[model_key] = model_class.build_config(**model_config)
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
self.globals = InvokeAIAppConfig.get_config() self.globals = InvokeAIAppConfig.get_config()
self._update_config_file_version() self._update_config_file_version()
@ -279,7 +291,7 @@ class ModelManager(object):
identifier. identifier.
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.config return model_key in self.models
def create_key( def create_key(
self, self,
@ -351,41 +363,38 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
#if model_type in {
# ModelType.Lora,
# ModelType.ControlNet,
# ModelType.TextualInversion,
# ModelType.Vae,
#}:
if not model_class.has_config:
#if model_class.Config is None:
# skip config
# load from
# /models/{base_model}/{model_type}/{model_name}
# /models/{base_model}/{model_type}/{model_name}.{ext}
model_config = None
for ext in {"pt", "ckpt", "safetensors"}:
model_path = os.path.join(model_dir, base_model, model_type, f"{model_name}.{ext}")
if os.path.exists(model_path):
break
else:
model_path = os.path.join(model_dir, base_model, model_type, model_name)
if not os.path.exists(model_path):
raise InvalidModelError(
f"Model not found - \"{base_model}/{model_type}/{model_name}\" "
)
else:
# find in config
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
)
model_config = self.config[model_key] # if model not found try to find it (maybe file just pasted)
if model_key not in self.models:
# TODO: find by mask or try rescan?
path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
if False: # model_path = next(find_by_mask(path_mask)):
model_path = None # TODO:
model_config = model_class.build_config(
path=model_path,
)
self.models[model_key] = model_config
else:
raise Exception(f"Model not found - {model_key}")
# if it known model check that target path exists (if manualy deleted)
else:
# logic repeated twice(in rescan too) any way to optimize?
if not os.path.exists(self.models[model_key].path):
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f"Files for model \"{model_key}\" not found")
else:
self.models.pop(model_key, None)
raise Exception(f"Model not found - {model_key}")
# reset model errors?
model_config = self.models[model_key]
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors # /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
# /models/{base_model}/{model_type}/{name}/ # /models/{base_model}/{model_type}/{name}/
@ -394,7 +403,7 @@ class ModelManager(object):
# vae/movq override # vae/movq override
# TODO: # TODO:
if submodel is not None and submodel in model_config: if submodel is not None and submodel in model_config:
model_path = model_config[submodel]["path"] model_path = model_config[submodel]
model_type = submodel model_type = submodel
submodel = None submodel = None
@ -429,11 +438,11 @@ class ModelManager(object):
Returns the name of the default model, or None Returns the name of the default model, or None
if none is defined. if none is defined.
""" """
for model_key, model_config in self.config.items(): for model_key, model_config in self.models.items():
if model_config.get("default", False): if model_config.default:
return self.parse_key(model_key) return self.parse_key(model_key)
for model_key, _ in self.config.items(): for model_key, _ in self.models.items():
return self.parse_key(model_key) return self.parse_key(model_key)
else: else:
return None # TODO: or redo as (None, None, None) return None # TODO: or redo as (None, None, None)
@ -450,14 +459,11 @@ class ModelManager(object):
""" """
model_key = self.model_key(model_name, base_model, model_type) model_key = self.model_key(model_name, base_model, model_type)
if model_key not in self.config: if model_key not in self.models:
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
for cur_model_key, config in self.config.items(): for cur_model_key, config in self.models.items():
if cur_model_key == model_key: config.default = cur_model_key == model_key
config["default"] = True
else:
config.pop("default", None)
def model_info( def model_info(
self, self,
@ -469,14 +475,17 @@ class ModelManager(object):
Given a model name returns the OmegaConf (dict-like) object describing it. Given a model name returns the OmegaConf (dict-like) object describing it.
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
return self.config.get(model_key, None) if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True)
else:
return None # TODO: None or empty dict on not found
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
""" """
Return a list of (str, BaseModelType, ModelType) corresponding to all models Return a list of (str, BaseModelType, ModelType) corresponding to all models
known to the configuration. known to the configuration.
""" """
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)] return [(self.parse_key(x)) for x in self.models.keys()]
def list_models( def list_models(
self, self,
@ -494,48 +503,37 @@ class ModelManager(object):
assert not(model_type is not None and base_model is None), "model_type must be provided with base_model" assert not(model_type is not None and base_model is None), "model_type must be provided with base_model"
models = dict() models = dict()
for model_key in sorted(self.config, key=str.casefold): for model_key in sorted(self.models, key=str.casefold):
stanza = self.config[model_key] model_config = self.models[model_key]
if model_key.startswith('_'): cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
continue
if model_type is not None and cur_model_type != model_type:
continue continue
model_name, m_base_model, stanza_type = self.parse_key(model_key) if cur_base_model not in models:
if base_model is not None and m_base_model != base_model: models[cur_base_model] = dict()
continue if cur_model_type not in models[cur_base_model]:
if model_type is not None and model_type != stanza_type: models[cur_base_model][cur_model_type] = dict()
continue
if m_base_model not in models: models[m_base_model][stanza_type][model_name] = dict(
models[m_base_model] = dict() **model_config.dict(exclude_defaults=True),
if stanza_type not in models:
models[m_base_model][stanza_type] = dict()
model_class = MODEL_CLASSES[m_base_model][stanza_type]
models[m_base_model][stanza_type][model_name] = model_class.build_config(
**stanza,
name=model_name, name=model_name,
base_model=base_model, base_model=cur_base_model,
type=stanza_type, type=cur_model_type,
) )
#models[m_base_model][stanza_type][model_name] = model_class.Config(
# **stanza,
# name=model_name,
# base_model=base_model,
# type=stanza_type,
#).dict()
return models return models
def print_models(self) -> None: def print_models(self) -> None:
""" """
Print a table of models, their descriptions, and load status Print a table of models, their descriptions
""" """
# TODO: redo
for model_type, model_dict in self.list_models().items(): for model_type, model_dict in self.list_models().items():
for model_name, model_info in model_dict.items(): for model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["status"]:>15s} {model_info["type"]:10s} {model_info["description"]}' line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m"
print(line) print(line)
def del_model( def del_model(
@ -596,27 +594,14 @@ class ModelManager(object):
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.build_config(**model_attributes)
model_class.build_config(
**model_attributes,
name=model_name,
base_model=base_model,
type=model_type,
)
#model_cfg = model_class.Config(
# **model_attributes,
# name=model_name,
# base_model=base_model,
# type=model_type,
#)
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
assert ( assert (
clobber or model_key not in self.config clobber or model_key not in self.models
), f'attempt to overwrite existing model definition "{model_key}"' ), f'attempt to overwrite existing model definition "{model_key}"'
self.config[model_key] = model_attributes self.models[model_key] = model_config
if clobber and model_key in self.cache_keys: if clobber and model_key in self.cache_keys:
# TODO: # TODO:
@ -822,7 +807,15 @@ class ModelManager(object):
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
""" """
yaml_str = OmegaConf.to_yaml(self.config) data_to_save = dict()
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path config_file_path = conf_file or self.config_path
assert config_file_path is not None,'no config file path to write to' assert config_file_path is not None,'no config file path to write to'
config_file_path = self.globals.root_dir / config_file_path config_file_path = self.globals.root_dir / config_file_path
@ -887,146 +880,41 @@ class ModelManager(object):
return resolved_path return resolved_path
def _update_config_file_version(self): def _update_config_file_version(self):
""" # TODO:
This gets called at object init time and will update raise Exception("TODO: ")
from older versions of the config file to new ones
as necessary.
"""
current_version = self.config.get("_version","1.0.0")
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
self.logger.warning('The original file will be renamed models.yaml.orig')
if self.config_path:
old_file = Path(self.config_path)
new_name = old_file.parent / 'models.yaml.orig'
old_file.replace(new_name)
new_config = OmegaConf.create() def scan_models_directory(self):
new_config["_version"] = CONFIG_FILE_VERSION
for model_key in self.config: for model_key in list(self.models.keys()):
model_name, base_model, model_type = self.parse_key(model_key)
old_stanza = self.config[model_key] if not os.path.exists(model_config.path):
if not isinstance(old_stanza,DictConfig): if model_class.save_to_config:
continue self.models[model_key].error = ModelError.NotFound
# ignore old and ugly way of associating a legacy
# vae with a legacy checkpont model
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
continue
# bare keys are updated to be prefixed with 'diffusers/'
if '/' not in model_key:
new_key = f'diffusers/{model_key}'
else: else:
new_key = model_key self.models.pop(model_key, None)
if old_stanza.get('format')=='diffusers':
model_format = 'folder'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
model_format = 'ckpt'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
model_format = 'safetensors'
else:
model_format = old_stanza.get('format')
# copy fields over manually rather than doing a copy() or deepcopy() for base_model in BaseModelType:
# in order to avoid bringing in unwanted fields. for model_type in ModelType:
new_config[new_key] = dict(
description = old_stanza.get('description'),
format = model_format,
)
for field in ["repo_id", "path", "weights", "config", "vae"]:
if field_value := old_stanza.get(field):
new_config[new_key].update({field: field_value})
self.config = new_config model_class = MODEL_CLASSES[base_model][model_type]
if self.config_path: models_dir = os.path.join(self.globals.models_path, base_model, model_type)
self.commit()
def _delete_defunct_models(self): for entry_name in os.listdir(models_dir):
''' model_path = os.path.join(models_dir, entry_name)
Remove models no longer on disk. model_name = Path(model_path).stem
''' model_config: ModelConfigBase = model_class.build_config(
config = self.config path=model_path,
to_delete = set()
for key in config:
if 'path' not in config[key]:
continue
path = self.globals.root_dir / config[key].path
if path.exists():
continue
to_delete.add(key)
for key in to_delete:
self.logger.warn(f'Removing model {key} from in-memory config because its path is no longer on disk')
config.pop(key)
def scan_models_directory(self, include_diffusers:bool=False):
'''
Scan the models directory for loras, textual_inversions and controlnets
and create appropriate entries in the in-memory omegaconf. Diffusers
will not be added unless include_diffusers is true.
'''
self._delete_defunct_models()
model_directory = self.globals.models_path
config = self.config
for root, dirs, files in os.walk(model_directory):
parents = root.split('/')
subpaths = parents[parents.index('models')+1:]
if len(subpaths) < 2:
continue
base, model_type, *_ = subpaths
if model_type == "diffusers" and not include_diffusers:
continue
for d in dirs:
config[f'{model_type}/{d}'] = dict(
path = os.path.join(root,d),
description = f'{model_type} model {d}',
format = 'folder',
base = base,
) )
for f in files: model_key = self.create_key(model_name, base_model, model_type)
basename = Path(f).stem if model_key not in self.models:
format = Path(f).suffix[1:] self.models[model_key] = model_config
config[f'{model_type}/{basename}'] = dict(
path = os.path.join(root,f),
description = f'{model_type} model {basename}',
format = format,
base = base,
)
##### NONE OF THE METHODS BELOW WORK NOW BECAUSE OF MODEL DIRECTORY REORGANIZATION ##### NONE OF THE METHODS BELOW WORK NOW BECAUSE OF MODEL DIRECTORY REORGANIZATION
##### AND NEED TO BE REWRITTEN ##### AND NEED TO BE REWRITTEN
def list_lora_models(self)->Dict[str,bool]:
'''Return a dict of installed lora models; key is either the shortname
defined in INITIAL_MODELS, or the basename of the file in the LoRA
directory. Value is True if installed'''
models = OmegaConf.load(Dataset_path).get('lora') or {}
installed_models = {x: False for x in models.keys()}
dir = self.globals.lora_path
installed_models = dict()
for root, dirs, files in os.walk(dir):
for name in files:
if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
continue
if name == 'pytorch_lora_weights.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_lora_models(self, model_names: list[str], access_token:str=None): def install_lora_models(self, model_names: list[str], access_token:str=None):
'''Download list of LoRA/LyCORIS models''' '''Download list of LoRA/LyCORIS models'''
@ -1052,38 +940,6 @@ class ModelManager(object):
else: else:
self.logger.error(f"Unknown repo_id or URL: {name}") self.logger.error(f"Unknown repo_id or URL: {name}")
def delete_lora_models(self, model_names: List[str]):
'''Remove the list of lora models'''
for name in model_names:
file_or_directory = self.globals.lora_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging LoRA/LyCORIS {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.lora_path.glob(f'{name}.*'):
self.logger.info(f'Purging LoRA/LyCORIS {name}')
path.unlink()
def list_ti_models(self)->Dict[str,bool]:
'''Return a dict of installed textual models; key is either the shortname
defined in INITIAL_MODELS, or the basename of the file in the LoRA
directory. Value is True if installed'''
models = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
installed_models = {x: False for x in models.keys()}
dir = self.globals.embedding_path
for root, dirs, files in os.walk(dir):
for name in files:
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
continue
if name == 'learned_embeds.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_ti_models(self, model_names: list[str], access_token: str=None): def install_ti_models(self, model_names: list[str], access_token: str=None):
'''Download list of textual inversion embeddings''' '''Download list of textual inversion embeddings'''
@ -1105,31 +961,6 @@ class ModelManager(object):
else: else:
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL') self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
def delete_ti_models(self, model_names: list[str]):
'''Remove TI embeddings from disk'''
for name in model_names:
file_or_directory = self.globals.embedding_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging textual inversion embedding {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.embedding_path.glob(f'{name}.*'):
self.logger.info(f'Purging textual inversion embedding {name}')
path.unlink()
def list_controlnet_models(self)->Dict[str,bool]:
'''Return a dict of installed controlnet models; key is repo_id or short name
of model (defined in INITIAL_MODELS), and value is True if installed'''
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
installed_models = {x: False for x in cn_models.keys()}
cn_dir = self.globals.controlnet_path
for root, dirs, files in os.walk(cn_dir):
for name in dirs:
if Path(root, name, '.download_complete').exists():
installed_models.update({name.replace('--','/'): True})
return installed_models
def install_controlnet_models(self, model_names: list[str], access_token: str=None): def install_controlnet_models(self, model_names: list[str], access_token: str=None):
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml''' '''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
@ -1175,12 +1006,4 @@ class ModelManager(object):
(path.parent / '.download_complete').touch() (path.parent / '.download_complete').touch()
break break
def delete_controlnet_models(self, model_names: List[str]):
'''Remove the list of controlnet models'''
for name in model_names:
safe_name = name.replace('/','--')
directory = self.globals.controlnet_path / safe_name
if directory.exists():
self.logger.info(f'Purging controlnet model {name}')
shutil.rmtree(str(directory))

View File

@ -1,726 +0,0 @@
import sys
from enum import Enum
import torch
import safetensors.torch
from diffusers.utils import is_safetensors_available
class BaseModelType(str, Enum):
#StableDiffusion1_5 = "stable_diffusion_1_5"
#StableDiffusion2 = "stable_diffusion_2"
#StableDiffusion2Base = "stable_diffusion_2_base"
# TODO: maybe then add sample size(512/768)?
StableDiffusion1_5 = "SD-1"
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
#Kandinsky2_1 = "kandinsky_2_1"
class ModelType(str, Enum):
Pipeline = "pipeline"
Classifier = "classifier"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet"
TextualInversion = "embedding"
class SubModelType:
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
MODEL_CLASSES = {
BaseModel.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModel.StableDiffusion2: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModel.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModel.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.Classifier: ClassifierModel,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoraModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ModelBase:
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
print("====ERR LOAD====")
print(f"{variant}: {e}")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
class StableDiffusionModel(DiffusersModel):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {
BaseModelType.StableDiffusion1_5,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusion2Base,
}
assert model_type == ModelType.Pipeline
super().__init__(model_path, base_model, model_type)
@staticmethod
def convert_if_required(model_path: Union[str, Path], dst_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
# TODO: args
# TODO: set model_path, to config? pass dst_path as arg?
# TODO: check
return _convert_ckpt_and_cache(config)
class classproperty(object): # pylint: disable=invalid-name
"""Class property decorator.
Example usage:
class MyClass(object):
@classproperty
def value(cls):
return '123'
> print MyClass.value
123
"""
def __init__(self, func):
self._func = func
def __get__(self, owner_self, owner_cls):
return self._func(owner_cls)
class ModelConfigBase(BaseModel):
path: str # or Path
name: str
description: Optional[str]
class StableDiffusionDModel(DiffusersModel):
class Config(ModelConfigBase):
format: str
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
@root_validator
def validator(cls, values):
if values["format"] not in {"checkpoint", "diffusers"}:
raise ValueError(f"Unkown stable diffusion model format: {values['format']}")
if values["config"] is not None and values["format"] != "checkpoint":
raise ValueError(f"Custom config field allowed only in checkpoint stable diffusion model")
return values
# return config only for checkpoint format
def dict(self, *args, **kwargs):
result = super().dict(*args, **kwargs)
if self.format != "checkpoint":
result.pop("config", None)
return result
@classproperty
def has_config(self):
return True
def build_config(self, **kwargs) -> dict:
try:
res = dict(
path=kwargs["path"],
name=kwargs["name"],
description=kwargs.get("description", None),
format=kwargs["format"],
vae=kwargs.get("vae", None),
)
if res["format"] not in {"checkpoint", "diffusers"}:
raise Exception(f"Unkonwn stable diffusion model format: {res['format']}")
if res["format"] == "checkpoint":
res["config"] = kwargs.get("config", None)
# TODO: raise if config specified for diffusers?
return res
except KeyError as e:
raise Exception(f"Field \"{e.args[0]}\" not found!")
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1_5
assert model_type == ModelType.Pipeline
super().__init__(model_path, base_model, model_type)
@classmethod
def convert_if_required(cls, model_path: str, dst_path: str, config: Optional[dict]) -> str:
model_config = cls.Config(
**config,
path=model_path,
name="",
)
if hasattr(model_config, "config"):
convert_ckpt_and_cache(
model_path=model_path,
dst_path=dst_path,
config=config,
)
return dst_path
else:
return model_path
class StableDiffusion15CheckpointModel(DiffusersModel):
class Cnfig(ModelConfigBase):
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
class StableDiffusion2BaseDiffusersModel(DiffusersModel):
class Config(ModelConfigBase):
vae: Optional[str] = Field(None)
class StableDiffusion2BaseCheckpointModel(DiffusersModel):
class Cnfig(ModelConfigBase):
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
class StableDiffusion2DiffusersModel(DiffusersModel):
class Config(ModelConfigBase):
vae: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
class StableDiffusion2CheckpointModel(DiffusersModel):
class Config(ModelConfigBase):
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
class ClassifierModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == SDModelType.Classifier
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
main_config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#main_config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid classifier model! (config.json not found or invalid)")
self._load_tokenizer(main_config)
self._load_text_encoder(main_config)
self._load_feature_extractor(main_config)
def _load_tokenizer(self, main_config: dict):
try:
tokenizer_config = EmptyConfigLoader.load_config(self.model_path, config_name="tokenizer_config.json")
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
except:
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
if "tokenizer_class" in tokenizer_config:
tokenizer_class_name = tokenizer_config["tokenizer_class"]
elif "model_type" in main_config:
tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")
self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
self.child_sizes[SDModelType.Tokenizer] = 0
def _load_text_encoder(self, main_config: dict):
if "architectures" in main_config and len(main_config["architectures"]) > 0:
text_encoder_class_name = main_config["architectures"][0]
elif "model_type" in main_config:
text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
def _load_feature_extractor(self, main_config: dict):
self.child_sizes[SDModelType.FeatureExtractor] = 0
try:
feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json")
except:
return # feature extractor not passed with t5
try:
feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
except:
raise Exception("Invalid classifier model! (Unknown feature_extrator type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is None:
raise Exception("Child model type can't be null on classififer model")
if child_type not in self.child_types:
return None # TODO: or raise
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
return model_path
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
# TODO:
#_convert_vae_ckpt_and_cache
raise Exception("TODO: ")
class LoRAModel(ModelBase):
#model_size: int
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModel.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
# TODO: add diffusers lora when it stabilizes a bit
return model_path
class TextualInversionModel(ModelBase):
#model_size: int
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
model = TextualInversionModel.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
return model_path
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_dir / mconfig.path
config_file = app_config.root_dir / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# TODO: I think that it more correctly to convert with embedded vae
# as if user will delete custom vae he will got not embedded but also custom vae
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
vae_ckpt_path, vae_model = None, None
# to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
diffusers_path,
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
scan_needed=True,
)
return diffusers_path
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
root = app_config.root_dir
weights_file = root / mconfig.path
config_file = root / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
image_size = mconfig.get('width') or mconfig.get('height') or 512
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# this avoids circular import error
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_file.suffix == '.safetensors':
checkpoint = safetensors.torch.load_file(weights_file)
else:
checkpoint = torch.load(weights_file, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size
)
vae_model.save_pretrained(
diffusers_path,
safe_serialization=is_safetensors_available()
)
return diffusers_path

View File

@ -0,0 +1,37 @@
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
from .vae import VaeModel
from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
MODEL_CLASSES = {
BaseModelType.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusion15Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusion2BaseModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}

View File

@ -0,0 +1,295 @@
import sys
from enum import Enum
import torch
from diffusers import DiffusionPipeline, ConfigMixin
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type
class BaseModelType(str, Enum):
#StableDiffusion1_5 = "stable_diffusion_1_5"
#StableDiffusion2 = "stable_diffusion_2"
#StableDiffusion2Base = "stable_diffusion_2_base"
# TODO: maybe then add sample size(512/768)?
StableDiffusion1_5 = "SD-1"
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
#Kandinsky2_1 = "kandinsky_2_1"
class ModelType(str, Enum):
Pipeline = "pipeline"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet"
TextualInversion = "embedding"
class SubModelType:
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
class ModelError(str, Enum):
NotFound = "not_found"
class ModelConfigBase(BaseModel):
path: str # or Path
#name: str # not included as present in model key
description: Optional[str] = Field(None)
format: Optional[str] = Field(None)
default: Optional[bool] = Field(False)
# do not save to config
error: Optional[ModelError] = Field(None, exclude=True)
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ModelBase:
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
@classmethod
def _get_configs(cls):
if not hasattr(cls, "__configs"):
configs = dict()
for name in dir(cls):
if name.startswith("__"):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
fields = inspect.get_annotations(value)
if "format" not in fields or typing.get_origin(fields["format"]) != Literal:
raise Exception("Invalid config definition - format field not found")
format_type = typing.get_origin(fields["format"])
if format_type not in {None, Literal}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if format_type is Literal:
format = fields["format"].__args__[0]
else:
format = None
configs[format] = value # TODO: error when override(multiple)?
cls.__configs = configs
return cls.__configs
@classmethod
def build_config(cls, **kwargs):
if "format" not in kwargs:
kwargs["format"] = cls.detect_format(kwargs["path"])
configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs)
@classmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
print("====ERR LOAD====")
print(f"{variant}: {e}")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem

View File

@ -0,0 +1,63 @@
import torch
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@classmethod
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
else:
return "lycoris"
@staticmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
if cls.detect_format(model_path) == "diffusers":
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path

View File

@ -0,0 +1,131 @@
import os
import torch
from pydantic import Field
from typing import Literal, Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
DiffusersModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
# TODO: how to name properly
class StableDiffusion15Model(DiffusersModel):
# TODO: str -> Path?
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1_5
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1_5,
model_type=ModelType.Pipeline,
)
@classmethod
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
else:
return "checkpoint"
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
cfg = cls.build_config(**config)
if isinstance(cfg, cls.CheckpointConfig):
return _convert_ckpt_and_cache(cfg) # TODO: args
else:
return model_path
# all same
class StableDiffusion2BaseModel(StableDiffusion15Model):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2Base
assert model_type == ModelType.Pipeline
super(StableDiffusion15Model, self).__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2Base,
model_type=ModelType.Pipeline,
)
class StableDiffusion2Model(DiffusersModel):
# TODO: str -> Path?
# overwrite configs
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline,
)
# TODO: rework
DictConfig = dict
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_dir / mconfig.path
config_file = app_config.root_dir / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# TODO: I think that it more correctly to convert with embedded vae
# as if user will delete custom vae he will got not embedded but also custom vae
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
vae_ckpt_path, vae_model = None, None
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
diffusers_path,
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
scan_needed=True,
)
return diffusers_path

View File

@ -0,0 +1,56 @@
import torch
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@classmethod
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
return None
@staticmethod
def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
return model_path

View File

@ -0,0 +1,122 @@
import os
import torch
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
)
from invokeai.app.services.config import InvokeAIAppConfig
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
else:
return "checkpoint"
@classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
if cls.detect_format(model_path) != "diffusers":
# TODO:
#_convert_vae_ckpt_and_cache
raise NotImplementedError("TODO: vae convert")
else:
return model_path
# TODO: rework
DictConfig = dict
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
root = app_config.root_dir
weights_file = root / mconfig.path
config_file = root / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
image_size = mconfig.get('width') or mconfig.get('height') or 512
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# this avoids circular import error
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_file.suffix == '.safetensors':
checkpoint = safetensors.torch.load_file(weights_file)
else:
checkpoint = torch.load(weights_file, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size
)
vae_model.save_pretrained(
diffusers_path,
safe_serialization=is_safetensors_available()
)
return diffusers_path