mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rewrite model configs, separate models
This commit is contained in:
parent
2c056ead42
commit
3ce3a7ee72
@ -30,7 +30,6 @@ from typing import Dict, Union, types, Optional, List, Type, Any
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
|
|
||||||
from diffusers import logging as diffusers_logging
|
from diffusers import logging as diffusers_logging
|
||||||
from huggingface_hub import HfApi, scan_cache_dir
|
from huggingface_hub import HfApi, scan_cache_dir
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
@ -40,7 +39,7 @@ from invokeai.app.services.config import get_invokeai_config
|
|||||||
|
|
||||||
from .lora import LoRAModel, TextualInversionModel
|
from .lora import LoRAModel, TextualInversionModel
|
||||||
|
|
||||||
from .models import MODEL_CLASSES
|
from .models import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
|
|
||||||
# Maximum size of the cache, in gigs
|
# Maximum size of the cache, in gigs
|
||||||
@ -129,11 +128,12 @@ class ModelCache(object):
|
|||||||
def get_key(
|
def get_key(
|
||||||
self,
|
self,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
submodel_type: Optional[SDModelType] = None,
|
model_type: ModelType,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
key = f"{model_path}:{model_type}"
|
key = f"{model_path}:{base_model}:{model_type}"
|
||||||
if submodel_type:
|
if submodel_type:
|
||||||
key += f":{submodel_type}"
|
key += f":{submodel_type}"
|
||||||
return key
|
return key
|
||||||
@ -152,9 +152,12 @@ class ModelCache(object):
|
|||||||
self,
|
self,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
model_class: Type[ModelBase],
|
model_class: Type[ModelBase],
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
):
|
):
|
||||||
model_info_key = self.get_key(
|
model_info_key = self.get_key(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel_type=None,
|
submodel_type=None,
|
||||||
)
|
)
|
||||||
@ -172,6 +175,8 @@ class ModelCache(object):
|
|||||||
self,
|
self,
|
||||||
model_path: Union[str, Path],
|
model_path: Union[str, Path],
|
||||||
model_class: Type[ModelBase],
|
model_class: Type[ModelBase],
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
gpu_load: bool = True,
|
gpu_load: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -185,17 +190,20 @@ class ModelCache(object):
|
|||||||
model_info = self._get_model_info(
|
model_info = self._get_model_info(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
model_class=model_class,
|
model_class=model_class,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
key = self.get_key(
|
key = self.get_key(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
model_type=model_type, # TODO:
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
submodel_type=submodel,
|
submodel_type=submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: lock for no copies on simultaneous calls?
|
# TODO: lock for no copies on simultaneous calls?
|
||||||
cache_entry = self._cached_models.get(key, None)
|
cache_entry = self._cached_models.get(key, None)
|
||||||
if cache_entry is None:
|
if cache_entry is None:
|
||||||
self.logger.info(f'Loading model {model_path}, type {model_type}:{submodel}')
|
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
|
||||||
|
|
||||||
# this will remove older cached models until
|
# this will remove older cached models until
|
||||||
# there is sufficient room to load the requested model
|
# there is sufficient room to load the requested model
|
||||||
@ -203,7 +211,7 @@ class ModelCache(object):
|
|||||||
|
|
||||||
# clean memory to make MemoryUsage() more accurate
|
# clean memory to make MemoryUsage() more accurate
|
||||||
gc.collect()
|
gc.collect()
|
||||||
model = model_info.get_model(submodel, torch_dtype=self.precision, variant=)
|
model = model_info.get_model(submodel, torch_dtype=self.precision)
|
||||||
if mem_used := model_info.get_size(submodel):
|
if mem_used := model_info.get_size(submodel):
|
||||||
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
||||||
|
|
||||||
|
@ -221,6 +221,9 @@ MAX_CACHE_SIZE = 6.0 # GB
|
|||||||
# └── realesrgan
|
# └── realesrgan
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigMeta(BaseModel):
|
||||||
|
version: str
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
High-level interface to model management.
|
High-level interface to model management.
|
||||||
@ -243,15 +246,24 @@ class ModelManager(object):
|
|||||||
and sequential_offload boolean. Note that the default device
|
and sequential_offload boolean. Note that the default device
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DictConfig):
|
|
||||||
self.config_path = None
|
self.config_path = None
|
||||||
self.config = config
|
if isinstance(config, (str, Path)):
|
||||||
elif isinstance(config,(str,Path)):
|
self.config_path = Path(config)
|
||||||
self.config_path = config
|
config = OmegaConf.load(self.config_path)
|
||||||
self.config = OmegaConf.load(self.config_path)
|
|
||||||
else:
|
elif not isinstance(config, DictConfig):
|
||||||
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||||
|
|
||||||
|
config_meta = ConfigMeta(config.pop("__metadata__")) # TODO: naming
|
||||||
|
# TODO: metadata not found
|
||||||
|
|
||||||
|
self.models = dict()
|
||||||
|
for model_key, model_config in config.items():
|
||||||
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
self.models[model_key] = model_class.build_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
self._update_config_file_version()
|
self._update_config_file_version()
|
||||||
@ -279,7 +291,7 @@ class ModelManager(object):
|
|||||||
identifier.
|
identifier.
|
||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
return model_key in self.config
|
return model_key in self.models
|
||||||
|
|
||||||
def create_key(
|
def create_key(
|
||||||
self,
|
self,
|
||||||
@ -351,52 +363,49 @@ class ModelManager(object):
|
|||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
|
||||||
#if model_type in {
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
# ModelType.Lora,
|
|
||||||
# ModelType.ControlNet,
|
|
||||||
# ModelType.TextualInversion,
|
|
||||||
# ModelType.Vae,
|
|
||||||
#}:
|
|
||||||
if not model_class.has_config:
|
|
||||||
#if model_class.Config is None:
|
|
||||||
# skip config
|
|
||||||
# load from
|
|
||||||
# /models/{base_model}/{model_type}/{model_name}
|
|
||||||
# /models/{base_model}/{model_type}/{model_name}.{ext}
|
|
||||||
|
|
||||||
model_config = None
|
# if model not found try to find it (maybe file just pasted)
|
||||||
|
if model_key not in self.models:
|
||||||
for ext in {"pt", "ckpt", "safetensors"}:
|
# TODO: find by mask or try rescan?
|
||||||
model_path = os.path.join(model_dir, base_model, model_type, f"{model_name}.{ext}")
|
path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
|
||||||
if os.path.exists(model_path):
|
if False: # model_path = next(find_by_mask(path_mask)):
|
||||||
break
|
model_path = None # TODO:
|
||||||
else:
|
model_config = model_class.build_config(
|
||||||
model_path = os.path.join(model_dir, base_model, model_type, model_name)
|
path=model_path,
|
||||||
if not os.path.exists(model_path):
|
|
||||||
raise InvalidModelError(
|
|
||||||
f"Model not found - \"{base_model}/{model_type}/{model_name}\" "
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# find in config
|
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
|
||||||
if model_key not in self.config:
|
|
||||||
raise InvalidModelError(
|
|
||||||
f'"{model_key}" is not a known model name. Please check your models.yaml file'
|
|
||||||
)
|
)
|
||||||
|
self.models[model_key] = model_config
|
||||||
|
else:
|
||||||
|
raise Exception(f"Model not found - {model_key}")
|
||||||
|
|
||||||
model_config = self.config[model_key]
|
# if it known model check that target path exists (if manualy deleted)
|
||||||
|
else:
|
||||||
|
# logic repeated twice(in rescan too) any way to optimize?
|
||||||
|
if not os.path.exists(self.models[model_key].path):
|
||||||
|
if model_class.save_to_config:
|
||||||
|
self.models[model_key].error = ModelError.NotFound
|
||||||
|
raise Exception(f"Files for model \"{model_key}\" not found")
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.models.pop(model_key, None)
|
||||||
|
raise Exception(f"Model not found - {model_key}")
|
||||||
|
|
||||||
|
# reset model errors?
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
model_config = self.models[model_key]
|
||||||
|
|
||||||
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
|
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
|
||||||
# /models/{base_model}/{model_type}/{name}/
|
# /models/{base_model}/{model_type}/{name}/
|
||||||
model_path = model_config.path
|
model_path = model_config.path
|
||||||
|
|
||||||
# vae/movq override
|
# vae/movq override
|
||||||
# TODO:
|
# TODO:
|
||||||
if submodel is not None and submodel in model_config:
|
if submodel is not None and submodel in model_config:
|
||||||
model_path = model_config[submodel]["path"]
|
model_path = model_config[submodel]
|
||||||
model_type = submodel
|
model_type = submodel
|
||||||
submodel = None
|
submodel = None
|
||||||
|
|
||||||
dst_convert_path = None # TODO:
|
dst_convert_path = None # TODO:
|
||||||
model_path = model_class.convert_if_required(
|
model_path = model_class.convert_if_required(
|
||||||
@ -429,11 +438,11 @@ class ModelManager(object):
|
|||||||
Returns the name of the default model, or None
|
Returns the name of the default model, or None
|
||||||
if none is defined.
|
if none is defined.
|
||||||
"""
|
"""
|
||||||
for model_key, model_config in self.config.items():
|
for model_key, model_config in self.models.items():
|
||||||
if model_config.get("default", False):
|
if model_config.default:
|
||||||
return self.parse_key(model_key)
|
return self.parse_key(model_key)
|
||||||
|
|
||||||
for model_key, _ in self.config.items():
|
for model_key, _ in self.models.items():
|
||||||
return self.parse_key(model_key)
|
return self.parse_key(model_key)
|
||||||
else:
|
else:
|
||||||
return None # TODO: or redo as (None, None, None)
|
return None # TODO: or redo as (None, None, None)
|
||||||
@ -450,14 +459,11 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_key = self.model_key(model_name, base_model, model_type)
|
model_key = self.model_key(model_name, base_model, model_type)
|
||||||
if model_key not in self.config:
|
if model_key not in self.models:
|
||||||
raise Exception(f"Unknown model: {model_key}")
|
raise Exception(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
for cur_model_key, config in self.config.items():
|
for cur_model_key, config in self.models.items():
|
||||||
if cur_model_key == model_key:
|
config.default = cur_model_key == model_key
|
||||||
config["default"] = True
|
|
||||||
else:
|
|
||||||
config.pop("default", None)
|
|
||||||
|
|
||||||
def model_info(
|
def model_info(
|
||||||
self,
|
self,
|
||||||
@ -469,14 +475,17 @@ class ModelManager(object):
|
|||||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
return self.config.get(model_key, None)
|
if model_key in self.models:
|
||||||
|
return self.models[model_key].dict(exclude_defaults=True)
|
||||||
|
else:
|
||||||
|
return None # TODO: None or empty dict on not found
|
||||||
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
"""
|
"""
|
||||||
Return a list of (str, BaseModelType, ModelType) corresponding to all models
|
Return a list of (str, BaseModelType, ModelType) corresponding to all models
|
||||||
known to the configuration.
|
known to the configuration.
|
||||||
"""
|
"""
|
||||||
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]
|
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self,
|
self,
|
||||||
@ -494,48 +503,37 @@ class ModelManager(object):
|
|||||||
assert not(model_type is not None and base_model is None), "model_type must be provided with base_model"
|
assert not(model_type is not None and base_model is None), "model_type must be provided with base_model"
|
||||||
|
|
||||||
models = dict()
|
models = dict()
|
||||||
for model_key in sorted(self.config, key=str.casefold):
|
for model_key in sorted(self.models, key=str.casefold):
|
||||||
stanza = self.config[model_key]
|
model_config = self.models[model_key]
|
||||||
|
|
||||||
if model_key.startswith('_'):
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
|
if base_model is not None and cur_base_model != base_model:
|
||||||
|
continue
|
||||||
|
if model_type is not None and cur_model_type != model_type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
model_name, m_base_model, stanza_type = self.parse_key(model_key)
|
if cur_base_model not in models:
|
||||||
if base_model is not None and m_base_model != base_model:
|
models[cur_base_model] = dict()
|
||||||
continue
|
if cur_model_type not in models[cur_base_model]:
|
||||||
if model_type is not None and model_type != stanza_type:
|
models[cur_base_model][cur_model_type] = dict()
|
||||||
continue
|
|
||||||
|
|
||||||
if m_base_model not in models:
|
models[m_base_model][stanza_type][model_name] = dict(
|
||||||
models[m_base_model] = dict()
|
**model_config.dict(exclude_defaults=True),
|
||||||
if stanza_type not in models:
|
|
||||||
models[m_base_model][stanza_type] = dict()
|
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[m_base_model][stanza_type]
|
|
||||||
models[m_base_model][stanza_type][model_name] = model_class.build_config(
|
|
||||||
**stanza,
|
|
||||||
name=model_name,
|
name=model_name,
|
||||||
base_model=base_model,
|
base_model=cur_base_model,
|
||||||
type=stanza_type,
|
type=cur_model_type,
|
||||||
)
|
)
|
||||||
#models[m_base_model][stanza_type][model_name] = model_class.Config(
|
|
||||||
# **stanza,
|
|
||||||
# name=model_name,
|
|
||||||
# base_model=base_model,
|
|
||||||
# type=stanza_type,
|
|
||||||
#).dict()
|
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def print_models(self) -> None:
|
def print_models(self) -> None:
|
||||||
"""
|
"""
|
||||||
Print a table of models, their descriptions, and load status
|
Print a table of models, their descriptions
|
||||||
"""
|
"""
|
||||||
|
# TODO: redo
|
||||||
for model_type, model_dict in self.list_models().items():
|
for model_type, model_dict in self.list_models().items():
|
||||||
for model_name, model_info in model_dict.items():
|
for model_name, model_info in model_dict.items():
|
||||||
line = f'{model_info["name"]:25s} {model_info["status"]:>15s} {model_info["type"]:10s} {model_info["description"]}'
|
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||||
if model_info["status"] in ["in gpu","locked in gpu"]:
|
|
||||||
line = f"\033[1m{line}\033[0m"
|
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
def del_model(
|
def del_model(
|
||||||
@ -596,27 +594,14 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
model_config = model_class.build_config(**model_attributes)
|
||||||
model_class.build_config(
|
|
||||||
**model_attributes,
|
|
||||||
name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
type=model_type,
|
|
||||||
)
|
|
||||||
#model_cfg = model_class.Config(
|
|
||||||
# **model_attributes,
|
|
||||||
# name=model_name,
|
|
||||||
# base_model=base_model,
|
|
||||||
# type=model_type,
|
|
||||||
#)
|
|
||||||
|
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
clobber or model_key not in self.config
|
clobber or model_key not in self.models
|
||||||
), f'attempt to overwrite existing model definition "{model_key}"'
|
), f'attempt to overwrite existing model definition "{model_key}"'
|
||||||
|
|
||||||
self.config[model_key] = model_attributes
|
self.models[model_key] = model_config
|
||||||
|
|
||||||
if clobber and model_key in self.cache_keys:
|
if clobber and model_key in self.cache_keys:
|
||||||
# TODO:
|
# TODO:
|
||||||
@ -822,7 +807,15 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
yaml_str = OmegaConf.to_yaml(self.config)
|
data_to_save = dict()
|
||||||
|
for model_key, model_config in self.models.items():
|
||||||
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
if model_class.save_to_config:
|
||||||
|
# TODO: or exclude_unset better fits here?
|
||||||
|
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
|
||||||
|
|
||||||
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
assert config_file_path is not None,'no config file path to write to'
|
assert config_file_path is not None,'no config file path to write to'
|
||||||
config_file_path = self.globals.root_dir / config_file_path
|
config_file_path = self.globals.root_dir / config_file_path
|
||||||
@ -887,146 +880,41 @@ class ModelManager(object):
|
|||||||
return resolved_path
|
return resolved_path
|
||||||
|
|
||||||
def _update_config_file_version(self):
|
def _update_config_file_version(self):
|
||||||
"""
|
# TODO:
|
||||||
This gets called at object init time and will update
|
raise Exception("TODO: ")
|
||||||
from older versions of the config file to new ones
|
|
||||||
as necessary.
|
|
||||||
"""
|
|
||||||
current_version = self.config.get("_version","1.0.0")
|
|
||||||
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
|
|
||||||
self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
|
|
||||||
self.logger.warning('The original file will be renamed models.yaml.orig')
|
|
||||||
if self.config_path:
|
|
||||||
old_file = Path(self.config_path)
|
|
||||||
new_name = old_file.parent / 'models.yaml.orig'
|
|
||||||
old_file.replace(new_name)
|
|
||||||
|
|
||||||
new_config = OmegaConf.create()
|
|
||||||
new_config["_version"] = CONFIG_FILE_VERSION
|
|
||||||
|
|
||||||
for model_key in self.config:
|
|
||||||
|
|
||||||
old_stanza = self.config[model_key]
|
def scan_models_directory(self):
|
||||||
if not isinstance(old_stanza,DictConfig):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# ignore old and ugly way of associating a legacy
|
for model_key in list(self.models.keys()):
|
||||||
# vae with a legacy checkpont model
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
|
if not os.path.exists(model_config.path):
|
||||||
continue
|
if model_class.save_to_config:
|
||||||
|
self.models[model_key].error = ModelError.NotFound
|
||||||
# bare keys are updated to be prefixed with 'diffusers/'
|
|
||||||
if '/' not in model_key:
|
|
||||||
new_key = f'diffusers/{model_key}'
|
|
||||||
else:
|
else:
|
||||||
new_key = model_key
|
self.models.pop(model_key, None)
|
||||||
|
|
||||||
if old_stanza.get('format')=='diffusers':
|
|
||||||
model_format = 'folder'
|
|
||||||
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
|
|
||||||
model_format = 'ckpt'
|
|
||||||
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
|
|
||||||
model_format = 'safetensors'
|
|
||||||
else:
|
|
||||||
model_format = old_stanza.get('format')
|
|
||||||
|
|
||||||
# copy fields over manually rather than doing a copy() or deepcopy()
|
for base_model in BaseModelType:
|
||||||
# in order to avoid bringing in unwanted fields.
|
for model_type in ModelType:
|
||||||
new_config[new_key] = dict(
|
|
||||||
description = old_stanza.get('description'),
|
|
||||||
format = model_format,
|
|
||||||
)
|
|
||||||
for field in ["repo_id", "path", "weights", "config", "vae"]:
|
|
||||||
if field_value := old_stanza.get(field):
|
|
||||||
new_config[new_key].update({field: field_value})
|
|
||||||
|
|
||||||
self.config = new_config
|
|
||||||
if self.config_path:
|
|
||||||
self.commit()
|
|
||||||
|
|
||||||
def _delete_defunct_models(self):
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
'''
|
models_dir = os.path.join(self.globals.models_path, base_model, model_type)
|
||||||
Remove models no longer on disk.
|
|
||||||
'''
|
for entry_name in os.listdir(models_dir):
|
||||||
config = self.config
|
model_path = os.path.join(models_dir, entry_name)
|
||||||
|
model_name = Path(model_path).stem
|
||||||
|
model_config: ModelConfigBase = model_class.build_config(
|
||||||
|
path=model_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
if model_key not in self.models:
|
||||||
|
self.models[model_key] = model_config
|
||||||
|
|
||||||
to_delete = set()
|
|
||||||
for key in config:
|
|
||||||
if 'path' not in config[key]:
|
|
||||||
continue
|
|
||||||
path = self.globals.root_dir / config[key].path
|
|
||||||
if path.exists():
|
|
||||||
continue
|
|
||||||
to_delete.add(key)
|
|
||||||
|
|
||||||
for key in to_delete:
|
|
||||||
self.logger.warn(f'Removing model {key} from in-memory config because its path is no longer on disk')
|
|
||||||
config.pop(key)
|
|
||||||
|
|
||||||
def scan_models_directory(self, include_diffusers:bool=False):
|
|
||||||
'''
|
|
||||||
Scan the models directory for loras, textual_inversions and controlnets
|
|
||||||
and create appropriate entries in the in-memory omegaconf. Diffusers
|
|
||||||
will not be added unless include_diffusers is true.
|
|
||||||
'''
|
|
||||||
self._delete_defunct_models()
|
|
||||||
|
|
||||||
model_directory = self.globals.models_path
|
|
||||||
config = self.config
|
|
||||||
|
|
||||||
for root, dirs, files in os.walk(model_directory):
|
|
||||||
parents = root.split('/')
|
|
||||||
subpaths = parents[parents.index('models')+1:]
|
|
||||||
if len(subpaths) < 2:
|
|
||||||
continue
|
|
||||||
base, model_type, *_ = subpaths
|
|
||||||
|
|
||||||
if model_type == "diffusers" and not include_diffusers:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for d in dirs:
|
|
||||||
config[f'{model_type}/{d}'] = dict(
|
|
||||||
path = os.path.join(root,d),
|
|
||||||
description = f'{model_type} model {d}',
|
|
||||||
format = 'folder',
|
|
||||||
base = base,
|
|
||||||
)
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
basename = Path(f).stem
|
|
||||||
format = Path(f).suffix[1:]
|
|
||||||
config[f'{model_type}/{basename}'] = dict(
|
|
||||||
path = os.path.join(root,f),
|
|
||||||
description = f'{model_type} model {basename}',
|
|
||||||
format = format,
|
|
||||||
base = base,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
##### NONE OF THE METHODS BELOW WORK NOW BECAUSE OF MODEL DIRECTORY REORGANIZATION
|
##### NONE OF THE METHODS BELOW WORK NOW BECAUSE OF MODEL DIRECTORY REORGANIZATION
|
||||||
##### AND NEED TO BE REWRITTEN
|
##### AND NEED TO BE REWRITTEN
|
||||||
def list_lora_models(self)->Dict[str,bool]:
|
|
||||||
'''Return a dict of installed lora models; key is either the shortname
|
|
||||||
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
|
||||||
directory. Value is True if installed'''
|
|
||||||
|
|
||||||
models = OmegaConf.load(Dataset_path).get('lora') or {}
|
|
||||||
installed_models = {x: False for x in models.keys()}
|
|
||||||
|
|
||||||
dir = self.globals.lora_path
|
|
||||||
installed_models = dict()
|
|
||||||
for root, dirs, files in os.walk(dir):
|
|
||||||
for name in files:
|
|
||||||
if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
|
|
||||||
continue
|
|
||||||
if name == 'pytorch_lora_weights.bin':
|
|
||||||
name = Path(root,name).parent.stem #Path(root,name).stem
|
|
||||||
else:
|
|
||||||
name = Path(name).stem
|
|
||||||
installed_models.update({name: True})
|
|
||||||
|
|
||||||
return installed_models
|
|
||||||
|
|
||||||
def install_lora_models(self, model_names: list[str], access_token:str=None):
|
def install_lora_models(self, model_names: list[str], access_token:str=None):
|
||||||
'''Download list of LoRA/LyCORIS models'''
|
'''Download list of LoRA/LyCORIS models'''
|
||||||
|
|
||||||
@ -1051,38 +939,6 @@ class ModelManager(object):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"Unknown repo_id or URL: {name}")
|
self.logger.error(f"Unknown repo_id or URL: {name}")
|
||||||
|
|
||||||
def delete_lora_models(self, model_names: List[str]):
|
|
||||||
'''Remove the list of lora models'''
|
|
||||||
for name in model_names:
|
|
||||||
file_or_directory = self.globals.lora_path / name
|
|
||||||
if file_or_directory.is_dir():
|
|
||||||
self.logger.info(f'Purging LoRA/LyCORIS {name}')
|
|
||||||
shutil.rmtree(str(file_or_directory))
|
|
||||||
else:
|
|
||||||
for path in self.globals.lora_path.glob(f'{name}.*'):
|
|
||||||
self.logger.info(f'Purging LoRA/LyCORIS {name}')
|
|
||||||
path.unlink()
|
|
||||||
|
|
||||||
def list_ti_models(self)->Dict[str,bool]:
|
|
||||||
'''Return a dict of installed textual models; key is either the shortname
|
|
||||||
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
|
||||||
directory. Value is True if installed'''
|
|
||||||
|
|
||||||
models = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
|
|
||||||
installed_models = {x: False for x in models.keys()}
|
|
||||||
|
|
||||||
dir = self.globals.embedding_path
|
|
||||||
for root, dirs, files in os.walk(dir):
|
|
||||||
for name in files:
|
|
||||||
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
|
|
||||||
continue
|
|
||||||
if name == 'learned_embeds.bin':
|
|
||||||
name = Path(root,name).parent.stem #Path(root,name).stem
|
|
||||||
else:
|
|
||||||
name = Path(name).stem
|
|
||||||
installed_models.update({name: True})
|
|
||||||
return installed_models
|
|
||||||
|
|
||||||
def install_ti_models(self, model_names: list[str], access_token: str=None):
|
def install_ti_models(self, model_names: list[str], access_token: str=None):
|
||||||
'''Download list of textual inversion embeddings'''
|
'''Download list of textual inversion embeddings'''
|
||||||
@ -1104,32 +960,7 @@ class ModelManager(object):
|
|||||||
download_with_resume(name, self.globals.embedding_path)
|
download_with_resume(name, self.globals.embedding_path)
|
||||||
else:
|
else:
|
||||||
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
|
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
|
||||||
|
|
||||||
def delete_ti_models(self, model_names: list[str]):
|
|
||||||
'''Remove TI embeddings from disk'''
|
|
||||||
for name in model_names:
|
|
||||||
file_or_directory = self.globals.embedding_path / name
|
|
||||||
if file_or_directory.is_dir():
|
|
||||||
self.logger.info(f'Purging textual inversion embedding {name}')
|
|
||||||
shutil.rmtree(str(file_or_directory))
|
|
||||||
else:
|
|
||||||
for path in self.globals.embedding_path.glob(f'{name}.*'):
|
|
||||||
self.logger.info(f'Purging textual inversion embedding {name}')
|
|
||||||
path.unlink()
|
|
||||||
|
|
||||||
def list_controlnet_models(self)->Dict[str,bool]:
|
|
||||||
'''Return a dict of installed controlnet models; key is repo_id or short name
|
|
||||||
of model (defined in INITIAL_MODELS), and value is True if installed'''
|
|
||||||
|
|
||||||
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
|
||||||
installed_models = {x: False for x in cn_models.keys()}
|
|
||||||
|
|
||||||
cn_dir = self.globals.controlnet_path
|
|
||||||
for root, dirs, files in os.walk(cn_dir):
|
|
||||||
for name in dirs:
|
|
||||||
if Path(root, name, '.download_complete').exists():
|
|
||||||
installed_models.update({name.replace('--','/'): True})
|
|
||||||
return installed_models
|
|
||||||
|
|
||||||
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
|
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
|
||||||
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
|
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
|
||||||
@ -1175,12 +1006,4 @@ class ModelManager(object):
|
|||||||
(path.parent / '.download_complete').touch()
|
(path.parent / '.download_complete').touch()
|
||||||
break
|
break
|
||||||
|
|
||||||
def delete_controlnet_models(self, model_names: List[str]):
|
|
||||||
'''Remove the list of controlnet models'''
|
|
||||||
for name in model_names:
|
|
||||||
safe_name = name.replace('/','--')
|
|
||||||
directory = self.globals.controlnet_path / safe_name
|
|
||||||
if directory.exists():
|
|
||||||
self.logger.info(f'Purging controlnet model {name}')
|
|
||||||
shutil.rmtree(str(directory))
|
|
||||||
|
|
||||||
|
@ -1,726 +0,0 @@
|
|||||||
import sys
|
|
||||||
from enum import Enum
|
|
||||||
import torch
|
|
||||||
import safetensors.torch
|
|
||||||
from diffusers.utils import is_safetensors_available
|
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
|
||||||
#StableDiffusion1_5 = "stable_diffusion_1_5"
|
|
||||||
#StableDiffusion2 = "stable_diffusion_2"
|
|
||||||
#StableDiffusion2Base = "stable_diffusion_2_base"
|
|
||||||
# TODO: maybe then add sample size(512/768)?
|
|
||||||
StableDiffusion1_5 = "SD-1"
|
|
||||||
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
|
|
||||||
StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
|
|
||||||
#Kandinsky2_1 = "kandinsky_2_1"
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
|
||||||
Pipeline = "pipeline"
|
|
||||||
Classifier = "classifier"
|
|
||||||
Vae = "vae"
|
|
||||||
|
|
||||||
Lora = "lora"
|
|
||||||
ControlNet = "controlnet"
|
|
||||||
TextualInversion = "embedding"
|
|
||||||
|
|
||||||
class SubModelType:
|
|
||||||
UNet = "unet"
|
|
||||||
TextEncoder = "text_encoder"
|
|
||||||
Tokenizer = "tokenizer"
|
|
||||||
Vae = "vae"
|
|
||||||
Scheduler = "scheduler"
|
|
||||||
SafetyChecker = "safety_checker"
|
|
||||||
#MoVQ = "movq"
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
|
||||||
BaseModel.StableDiffusion1_5: {
|
|
||||||
ModelType.Pipeline: StableDiffusionModel,
|
|
||||||
ModelType.Classifier: ClassifierModel,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
ModelType.Lora: LoraModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
},
|
|
||||||
BaseModel.StableDiffusion2: {
|
|
||||||
ModelType.Pipeline: StableDiffusionModel,
|
|
||||||
ModelType.Classifier: ClassifierModel,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
ModelType.Lora: LoraModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
},
|
|
||||||
BaseModel.StableDiffusion2Base: {
|
|
||||||
ModelType.Pipeline: StableDiffusionModel,
|
|
||||||
ModelType.Classifier: ClassifierModel,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
ModelType.Lora: LoraModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
},
|
|
||||||
#BaseModel.Kandinsky2_1: {
|
|
||||||
# ModelType.Pipeline: Kandinsky2_1Model,
|
|
||||||
# ModelType.Classifier: ClassifierModel,
|
|
||||||
# ModelType.MoVQ: MoVQModel,
|
|
||||||
# ModelType.Lora: LoraModel,
|
|
||||||
# ModelType.ControlNet: ControlNetModel,
|
|
||||||
# ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
#},
|
|
||||||
}
|
|
||||||
|
|
||||||
class EmptyConfigLoader(ConfigMixin):
|
|
||||||
@classmethod
|
|
||||||
def load_config(cls, *args, **kwargs):
|
|
||||||
cls.config_name = kwargs.pop("config_name")
|
|
||||||
return super().load_config(*args, **kwargs)
|
|
||||||
|
|
||||||
class ModelBase:
|
|
||||||
#model_path: str
|
|
||||||
#base_model: BaseModelType
|
|
||||||
#model_type: ModelType
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_path: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
):
|
|
||||||
self.model_path = model_path
|
|
||||||
self.base_model = base_model
|
|
||||||
self.model_type = model_type
|
|
||||||
|
|
||||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
|
||||||
if len(subtypes) < 2:
|
|
||||||
raise Exception("Invalid subfolder definition!")
|
|
||||||
if subtypes[0] in ["diffusers", "transformers"]:
|
|
||||||
res_type = sys.modules[subtypes[0]]
|
|
||||||
subtypes = subtypes[1:]
|
|
||||||
|
|
||||||
else:
|
|
||||||
res_type = sys.modules["diffusers"]
|
|
||||||
res_type = getattr(res_type, "pipelines")
|
|
||||||
|
|
||||||
|
|
||||||
for subtype in subtypes:
|
|
||||||
res_type = getattr(res_type, subtype)
|
|
||||||
return res_type
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusersModel(ModelBase):
|
|
||||||
#child_types: Dict[str, Type]
|
|
||||||
#child_sizes: Dict[str, int]
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.child_types: Dict[str, Type] = dict()
|
|
||||||
self.child_sizes: Dict[str, int] = dict()
|
|
||||||
|
|
||||||
try:
|
|
||||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
|
||||||
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
|
||||||
except:
|
|
||||||
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
|
||||||
|
|
||||||
config_data.pop("_ignore_files", None)
|
|
||||||
|
|
||||||
# retrieve all folder_names that contain relevant files
|
|
||||||
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
|
|
||||||
|
|
||||||
for child_name in child_components:
|
|
||||||
child_type = self._hf_definition_to_type(config_data[child_name])
|
|
||||||
self.child_types[child_name] = child_type
|
|
||||||
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
|
||||||
if child_type is None:
|
|
||||||
return sum(self.child_sizes.values())
|
|
||||||
else:
|
|
||||||
return self.child_sizes[child_type]
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
):
|
|
||||||
# return pipeline in different function to pass more arguments
|
|
||||||
if child_type is None:
|
|
||||||
raise Exception("Child model type can't be null on diffusers model")
|
|
||||||
if child_type not in self.child_types:
|
|
||||||
return None # TODO: or raise
|
|
||||||
|
|
||||||
if torch_dtype == torch.float16:
|
|
||||||
variants = ["fp16", None]
|
|
||||||
else:
|
|
||||||
variants = [None, "fp16"]
|
|
||||||
|
|
||||||
# TODO: better error handling(differentiate not found from others)
|
|
||||||
for variant in variants:
|
|
||||||
try:
|
|
||||||
# TODO: set cache_dir to /dev/null to be sure that cache not used?
|
|
||||||
model = self.child_types[child_type].from_pretrained(
|
|
||||||
self.model_path,
|
|
||||||
subfolder=child_type.value,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
variant=variant,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
print("====ERR LOAD====")
|
|
||||||
print(f"{variant}: {e}")
|
|
||||||
|
|
||||||
# calc more accurate size
|
|
||||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
#def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionModel(DiffusersModel):
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model in {
|
|
||||||
BaseModelType.StableDiffusion1_5,
|
|
||||||
BaseModelType.StableDiffusion2,
|
|
||||||
BaseModelType.StableDiffusion2Base,
|
|
||||||
}
|
|
||||||
assert model_type == ModelType.Pipeline
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_if_required(model_path: Union[str, Path], dst_path: str, config: Optional[dict]) -> Path:
|
|
||||||
if not isinstance(model_path, Path):
|
|
||||||
model_path = Path(model_path)
|
|
||||||
|
|
||||||
# TODO: args
|
|
||||||
# TODO: set model_path, to config? pass dst_path as arg?
|
|
||||||
# TODO: check
|
|
||||||
return _convert_ckpt_and_cache(config)
|
|
||||||
|
|
||||||
class classproperty(object): # pylint: disable=invalid-name
|
|
||||||
"""Class property decorator.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
|
|
||||||
class MyClass(object):
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def value(cls):
|
|
||||||
return '123'
|
|
||||||
|
|
||||||
> print MyClass.value
|
|
||||||
123
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, func):
|
|
||||||
self._func = func
|
|
||||||
|
|
||||||
def __get__(self, owner_self, owner_cls):
|
|
||||||
return self._func(owner_cls)
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
|
||||||
path: str # or Path
|
|
||||||
name: str
|
|
||||||
description: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionDModel(DiffusersModel):
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
format: str
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: Optional[str] = Field(None)
|
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def validator(cls, values):
|
|
||||||
if values["format"] not in {"checkpoint", "diffusers"}:
|
|
||||||
raise ValueError(f"Unkown stable diffusion model format: {values['format']}")
|
|
||||||
if values["config"] is not None and values["format"] != "checkpoint":
|
|
||||||
raise ValueError(f"Custom config field allowed only in checkpoint stable diffusion model")
|
|
||||||
return values
|
|
||||||
|
|
||||||
# return config only for checkpoint format
|
|
||||||
def dict(self, *args, **kwargs):
|
|
||||||
result = super().dict(*args, **kwargs)
|
|
||||||
if self.format != "checkpoint":
|
|
||||||
result.pop("config", None)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def has_config(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def build_config(self, **kwargs) -> dict:
|
|
||||||
try:
|
|
||||||
res = dict(
|
|
||||||
path=kwargs["path"],
|
|
||||||
name=kwargs["name"],
|
|
||||||
description=kwargs.get("description", None),
|
|
||||||
|
|
||||||
format=kwargs["format"],
|
|
||||||
vae=kwargs.get("vae", None),
|
|
||||||
)
|
|
||||||
if res["format"] not in {"checkpoint", "diffusers"}:
|
|
||||||
raise Exception(f"Unkonwn stable diffusion model format: {res['format']}")
|
|
||||||
if res["format"] == "checkpoint":
|
|
||||||
res["config"] = kwargs.get("config", None)
|
|
||||||
# TODO: raise if config specified for diffusers?
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
except KeyError as e:
|
|
||||||
raise Exception(f"Field \"{e.args[0]}\" not found!")
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model == BaseModelType.StableDiffusion1_5
|
|
||||||
assert model_type == ModelType.Pipeline
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(cls, model_path: str, dst_path: str, config: Optional[dict]) -> str:
|
|
||||||
model_config = cls.Config(
|
|
||||||
**config,
|
|
||||||
path=model_path,
|
|
||||||
name="",
|
|
||||||
)
|
|
||||||
|
|
||||||
if hasattr(model_config, "config"):
|
|
||||||
convert_ckpt_and_cache(
|
|
||||||
model_path=model_path,
|
|
||||||
dst_path=dst_path,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
return dst_path
|
|
||||||
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
class StableDiffusion15CheckpointModel(DiffusersModel):
|
|
||||||
class Cnfig(ModelConfigBase):
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: Optional[str] = Field(None)
|
|
||||||
|
|
||||||
class StableDiffusion2BaseDiffusersModel(DiffusersModel):
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
|
|
||||||
class StableDiffusion2BaseCheckpointModel(DiffusersModel):
|
|
||||||
class Cnfig(ModelConfigBase):
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: Optional[str] = Field(None)
|
|
||||||
|
|
||||||
class StableDiffusion2DiffusersModel(DiffusersModel):
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
attention_upscale: bool = Field(True)
|
|
||||||
|
|
||||||
class StableDiffusion2CheckpointModel(DiffusersModel):
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: Optional[str] = Field(None)
|
|
||||||
attention_upscale: bool = Field(True)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierModel(ModelBase):
|
|
||||||
#child_types: Dict[str, Type]
|
|
||||||
#child_sizes: Dict[str, int]
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == SDModelType.Classifier
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.child_types: Dict[str, Type] = dict()
|
|
||||||
self.child_sizes: Dict[str, int] = dict()
|
|
||||||
|
|
||||||
try:
|
|
||||||
main_config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
|
||||||
#main_config = json.loads(os.path.join(self.model_path, "config.json"))
|
|
||||||
except:
|
|
||||||
raise Exception("Invalid classifier model! (config.json not found or invalid)")
|
|
||||||
|
|
||||||
self._load_tokenizer(main_config)
|
|
||||||
self._load_text_encoder(main_config)
|
|
||||||
self._load_feature_extractor(main_config)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_tokenizer(self, main_config: dict):
|
|
||||||
try:
|
|
||||||
tokenizer_config = EmptyConfigLoader.load_config(self.model_path, config_name="tokenizer_config.json")
|
|
||||||
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
|
|
||||||
except:
|
|
||||||
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
|
|
||||||
|
|
||||||
if "tokenizer_class" in tokenizer_config:
|
|
||||||
tokenizer_class_name = tokenizer_config["tokenizer_class"]
|
|
||||||
elif "model_type" in main_config:
|
|
||||||
tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]]
|
|
||||||
else:
|
|
||||||
raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")
|
|
||||||
|
|
||||||
self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
|
|
||||||
self.child_sizes[SDModelType.Tokenizer] = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _load_text_encoder(self, main_config: dict):
|
|
||||||
if "architectures" in main_config and len(main_config["architectures"]) > 0:
|
|
||||||
text_encoder_class_name = main_config["architectures"][0]
|
|
||||||
elif "model_type" in main_config:
|
|
||||||
text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]]
|
|
||||||
else:
|
|
||||||
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
|
|
||||||
|
|
||||||
self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
|
|
||||||
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_feature_extractor(self, main_config: dict):
|
|
||||||
self.child_sizes[SDModelType.FeatureExtractor] = 0
|
|
||||||
try:
|
|
||||||
feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json")
|
|
||||||
except:
|
|
||||||
return # feature extractor not passed with t5
|
|
||||||
|
|
||||||
try:
|
|
||||||
feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
|
|
||||||
self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
|
|
||||||
except:
|
|
||||||
raise Exception("Invalid classifier model! (Unknown feature_extrator type)")
|
|
||||||
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SDModelType] = None):
|
|
||||||
if child_type is None:
|
|
||||||
return sum(self.child_sizes.values())
|
|
||||||
else:
|
|
||||||
return self.child_sizes[child_type]
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SDModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is None:
|
|
||||||
raise Exception("Child model type can't be null on classififer model")
|
|
||||||
if child_type not in self.child_types:
|
|
||||||
return None # TODO: or raise
|
|
||||||
|
|
||||||
model = self.child_types[child_type].from_pretrained(
|
|
||||||
self.model_path,
|
|
||||||
subfolder=child_type.value,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
# calc more accurate size
|
|
||||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
|
|
||||||
if not isinstance(model_path, Path):
|
|
||||||
model_path = Path(model_path)
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VaeModel(ModelBase):
|
|
||||||
#vae_class: Type
|
|
||||||
#model_size: int
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.Vae
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
try:
|
|
||||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
|
||||||
#config = json.loads(os.path.join(self.model_path, "config.json"))
|
|
||||||
except:
|
|
||||||
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
|
||||||
|
|
||||||
try:
|
|
||||||
vae_class_name = config.get("_class_name", "AutoencoderKL")
|
|
||||||
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
|
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
|
||||||
except:
|
|
||||||
raise Exception("Invalid vae model! (Unkown vae type)")
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SDModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in vae model")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SDModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in vae model")
|
|
||||||
|
|
||||||
model = self.vae_class.from_pretrained(
|
|
||||||
self.model_path,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
# calc more accurate size
|
|
||||||
self.model_size = calc_model_size_by_data(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
|
|
||||||
if not isinstance(model_path, Path):
|
|
||||||
model_path = Path(model_path)
|
|
||||||
# TODO:
|
|
||||||
#_convert_vae_ckpt_and_cache
|
|
||||||
raise Exception("TODO: ")
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel(ModelBase):
|
|
||||||
#model_size: int
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.Lora
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.model_size = os.path.getsize(self.model_path)
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SDModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in lora")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SDModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in lora")
|
|
||||||
|
|
||||||
model = LoRAModel.from_checkpoint(
|
|
||||||
file_path=self.model_path,
|
|
||||||
dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_size = model.calc_size()
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
|
|
||||||
if not isinstance(model_path, Path):
|
|
||||||
model_path = Path(model_path)
|
|
||||||
|
|
||||||
# TODO: add diffusers lora when it stabilizes a bit
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModel(ModelBase):
|
|
||||||
#model_size: int
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.TextualInversion
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.model_size = os.path.getsize(self.model_path)
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SDModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in textual inversion")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SDModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in textual inversion")
|
|
||||||
|
|
||||||
model = TextualInversionModel.from_checkpoint(
|
|
||||||
file_path=self.model_path,
|
|
||||||
dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
|
|
||||||
if not isinstance(model_path, Path):
|
|
||||||
model_path = Path(model_path)
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def calc_model_size_by_fs(
|
|
||||||
model_path: str,
|
|
||||||
subfolder: Optional[str] = None,
|
|
||||||
variant: Optional[str] = None
|
|
||||||
):
|
|
||||||
if subfolder is not None:
|
|
||||||
model_path = os.path.join(model_path, subfolder)
|
|
||||||
|
|
||||||
# this can happen when, for example, the safety checker
|
|
||||||
# is not downloaded.
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
all_files = os.listdir(model_path)
|
|
||||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
|
||||||
|
|
||||||
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
|
|
||||||
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
|
|
||||||
other_files = set(all_files) - fp16_files - bit8_files
|
|
||||||
|
|
||||||
if variant is None:
|
|
||||||
files = other_files
|
|
||||||
elif variant == "fp16":
|
|
||||||
files = fp16_files
|
|
||||||
elif variant == "8bit":
|
|
||||||
files = bit8_files
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown variant: {variant}")
|
|
||||||
|
|
||||||
# try read from index if exists
|
|
||||||
index_postfix = ".index.json"
|
|
||||||
if variant is not None:
|
|
||||||
index_postfix = f".index.{variant}.json"
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if not file.endswith(index_postfix):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
with open(os.path.join(model_path, file), "r") as f:
|
|
||||||
index_data = json.loads(f.read())
|
|
||||||
return int(index_data["metadata"]["total_size"])
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# calculate files size if there is no index file
|
|
||||||
formats = [
|
|
||||||
(".safetensors",), # safetensors
|
|
||||||
(".bin",), # torch
|
|
||||||
(".onnx", ".pb"), # onnx
|
|
||||||
(".msgpack",), # flax
|
|
||||||
(".ckpt",), # tf
|
|
||||||
(".h5",), # tf2
|
|
||||||
]
|
|
||||||
|
|
||||||
for file_format in formats:
|
|
||||||
model_files = [f for f in files if f.endswith(file_format)]
|
|
||||||
if len(model_files) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
model_size = 0
|
|
||||||
for model_file in model_files:
|
|
||||||
file_stats = os.stat(os.path.join(model_path, model_file))
|
|
||||||
model_size += file_stats.st_size
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
|
||||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
|
||||||
|
|
||||||
|
|
||||||
def calc_model_size_by_data(model) -> int:
|
|
||||||
if isinstance(model, DiffusionPipeline):
|
|
||||||
return _calc_pipeline_by_data(model)
|
|
||||||
elif isinstance(model, torch.nn.Module):
|
|
||||||
return _calc_model_by_data(model)
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_pipeline_by_data(pipeline) -> int:
|
|
||||||
res = 0
|
|
||||||
for submodel_key in pipeline.components.keys():
|
|
||||||
submodel = getattr(pipeline, submodel_key)
|
|
||||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
|
||||||
res += _calc_model_by_data(submodel)
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_model_by_data(model) -> int:
|
|
||||||
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
|
|
||||||
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
|
||||||
mem = mem_params + mem_bufs # in bytes
|
|
||||||
return mem
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
|
|
||||||
"""
|
|
||||||
Convert the checkpoint model indicated in mconfig into a
|
|
||||||
diffusers, cache it to disk, and return Path to converted
|
|
||||||
file. If already on disk then just returns Path.
|
|
||||||
"""
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
weights = app_config.root_dir / mconfig.path
|
|
||||||
config_file = app_config.root_dir / mconfig.config
|
|
||||||
diffusers_path = app_config.converted_ckpts_dir / weights.stem
|
|
||||||
|
|
||||||
# return cached version if it exists
|
|
||||||
if diffusers_path.exists():
|
|
||||||
return diffusers_path
|
|
||||||
|
|
||||||
# TODO: I think that it more correctly to convert with embedded vae
|
|
||||||
# as if user will delete custom vae he will got not embedded but also custom vae
|
|
||||||
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
|
|
||||||
vae_ckpt_path, vae_model = None, None
|
|
||||||
|
|
||||||
# to avoid circular import errors
|
|
||||||
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
|
||||||
with SilenceWarnings():
|
|
||||||
convert_ckpt_to_diffusers(
|
|
||||||
weights,
|
|
||||||
diffusers_path,
|
|
||||||
extract_ema=True,
|
|
||||||
original_config_file=config_file,
|
|
||||||
vae=vae_model,
|
|
||||||
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
|
|
||||||
scan_needed=True,
|
|
||||||
)
|
|
||||||
return diffusers_path
|
|
||||||
|
|
||||||
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
|
|
||||||
"""
|
|
||||||
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
|
|
||||||
object, cache it to disk, and return Path to converted
|
|
||||||
file. If already on disk then just returns Path.
|
|
||||||
"""
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
root = app_config.root_dir
|
|
||||||
weights_file = root / mconfig.path
|
|
||||||
config_file = root / mconfig.config
|
|
||||||
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
|
|
||||||
image_size = mconfig.get('width') or mconfig.get('height') or 512
|
|
||||||
|
|
||||||
# return cached version if it exists
|
|
||||||
if diffusers_path.exists():
|
|
||||||
return diffusers_path
|
|
||||||
|
|
||||||
# this avoids circular import error
|
|
||||||
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
|
||||||
if weights_file.suffix == '.safetensors':
|
|
||||||
checkpoint = safetensors.torch.load_file(weights_file)
|
|
||||||
else:
|
|
||||||
checkpoint = torch.load(weights_file, map_location="cpu")
|
|
||||||
|
|
||||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
|
||||||
if "state_dict" in checkpoint:
|
|
||||||
checkpoint = checkpoint["state_dict"]
|
|
||||||
|
|
||||||
config = OmegaConf.load(config_file)
|
|
||||||
|
|
||||||
vae_model = convert_ldm_vae_to_diffusers(
|
|
||||||
checkpoint = checkpoint,
|
|
||||||
vae_config = config,
|
|
||||||
image_size = image_size
|
|
||||||
)
|
|
||||||
vae_model.save_pretrained(
|
|
||||||
diffusers_path,
|
|
||||||
safe_serialization=is_safetensors_available()
|
|
||||||
)
|
|
||||||
return diffusers_path
|
|
37
invokeai/backend/model_management/models/__init__.py
Normal file
37
invokeai/backend/model_management/models/__init__.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
|
||||||
|
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
|
||||||
|
from .vae import VaeModel
|
||||||
|
from .lora import LoRAModel
|
||||||
|
#from .controlnet import ControlNetModel # TODO:
|
||||||
|
from .textual_inversion import TextualInversionModel
|
||||||
|
|
||||||
|
MODEL_CLASSES = {
|
||||||
|
BaseModelType.StableDiffusion1_5: {
|
||||||
|
ModelType.Pipeline: StableDiffusion15Model,
|
||||||
|
ModelType.Vae: VaeModel,
|
||||||
|
ModelType.Lora: LoRAModel,
|
||||||
|
#ModelType.ControlNet: ControlNetModel,
|
||||||
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusion2: {
|
||||||
|
ModelType.Pipeline: StableDiffusion2Model,
|
||||||
|
ModelType.Vae: VaeModel,
|
||||||
|
ModelType.Lora: LoRAModel,
|
||||||
|
#ModelType.ControlNet: ControlNetModel,
|
||||||
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusion2Base: {
|
||||||
|
ModelType.Pipeline: StableDiffusion2BaseModel,
|
||||||
|
ModelType.Vae: VaeModel,
|
||||||
|
ModelType.Lora: LoRAModel,
|
||||||
|
#ModelType.ControlNet: ControlNetModel,
|
||||||
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
|
},
|
||||||
|
#BaseModelType.Kandinsky2_1: {
|
||||||
|
# ModelType.Pipeline: Kandinsky2_1Model,
|
||||||
|
# ModelType.MoVQ: MoVQModel,
|
||||||
|
# ModelType.Lora: LoRAModel,
|
||||||
|
# ModelType.ControlNet: ControlNetModel,
|
||||||
|
# ModelType.TextualInversion: TextualInversionModel,
|
||||||
|
#},
|
||||||
|
}
|
295
invokeai/backend/model_management/models/base.py
Normal file
295
invokeai/backend/model_management/models/base.py
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
import sys
|
||||||
|
from enum import Enum
|
||||||
|
import torch
|
||||||
|
from diffusers import DiffusionPipeline, ConfigMixin
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Dict, Optional, Type
|
||||||
|
|
||||||
|
class BaseModelType(str, Enum):
|
||||||
|
#StableDiffusion1_5 = "stable_diffusion_1_5"
|
||||||
|
#StableDiffusion2 = "stable_diffusion_2"
|
||||||
|
#StableDiffusion2Base = "stable_diffusion_2_base"
|
||||||
|
# TODO: maybe then add sample size(512/768)?
|
||||||
|
StableDiffusion1_5 = "SD-1"
|
||||||
|
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
|
||||||
|
StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
|
||||||
|
#Kandinsky2_1 = "kandinsky_2_1"
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
Pipeline = "pipeline"
|
||||||
|
Vae = "vae"
|
||||||
|
|
||||||
|
Lora = "lora"
|
||||||
|
ControlNet = "controlnet"
|
||||||
|
TextualInversion = "embedding"
|
||||||
|
|
||||||
|
class SubModelType:
|
||||||
|
UNet = "unet"
|
||||||
|
TextEncoder = "text_encoder"
|
||||||
|
Tokenizer = "tokenizer"
|
||||||
|
Vae = "vae"
|
||||||
|
Scheduler = "scheduler"
|
||||||
|
SafetyChecker = "safety_checker"
|
||||||
|
#MoVQ = "movq"
|
||||||
|
|
||||||
|
class ModelError(str, Enum):
|
||||||
|
NotFound = "not_found"
|
||||||
|
|
||||||
|
class ModelConfigBase(BaseModel):
|
||||||
|
path: str # or Path
|
||||||
|
#name: str # not included as present in model key
|
||||||
|
description: Optional[str] = Field(None)
|
||||||
|
format: Optional[str] = Field(None)
|
||||||
|
default: Optional[bool] = Field(False)
|
||||||
|
# do not save to config
|
||||||
|
error: Optional[ModelError] = Field(None, exclude=True)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyConfigLoader(ConfigMixin):
|
||||||
|
@classmethod
|
||||||
|
def load_config(cls, *args, **kwargs):
|
||||||
|
cls.config_name = kwargs.pop("config_name")
|
||||||
|
return super().load_config(*args, **kwargs)
|
||||||
|
|
||||||
|
class ModelBase:
|
||||||
|
#model_path: str
|
||||||
|
#base_model: BaseModelType
|
||||||
|
#model_type: ModelType
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
):
|
||||||
|
self.model_path = model_path
|
||||||
|
self.base_model = base_model
|
||||||
|
self.model_type = model_type
|
||||||
|
|
||||||
|
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||||
|
if len(subtypes) < 2:
|
||||||
|
raise Exception("Invalid subfolder definition!")
|
||||||
|
if subtypes[0] in ["diffusers", "transformers"]:
|
||||||
|
res_type = sys.modules[subtypes[0]]
|
||||||
|
subtypes = subtypes[1:]
|
||||||
|
|
||||||
|
else:
|
||||||
|
res_type = sys.modules["diffusers"]
|
||||||
|
res_type = getattr(res_type, "pipelines")
|
||||||
|
|
||||||
|
|
||||||
|
for subtype in subtypes:
|
||||||
|
res_type = getattr(res_type, subtype)
|
||||||
|
return res_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_configs(cls):
|
||||||
|
if not hasattr(cls, "__configs"):
|
||||||
|
configs = dict()
|
||||||
|
for name in dir(cls):
|
||||||
|
if name.startswith("__"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = getattr(cls, name)
|
||||||
|
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||||
|
continue
|
||||||
|
|
||||||
|
fields = inspect.get_annotations(value)
|
||||||
|
if "format" not in fields or typing.get_origin(fields["format"]) != Literal:
|
||||||
|
raise Exception("Invalid config definition - format field not found")
|
||||||
|
|
||||||
|
format_type = typing.get_origin(fields["format"])
|
||||||
|
if format_type not in {None, Literal}:
|
||||||
|
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||||
|
|
||||||
|
if format_type is Literal:
|
||||||
|
format = fields["format"].__args__[0]
|
||||||
|
else:
|
||||||
|
format = None
|
||||||
|
configs[format] = value # TODO: error when override(multiple)?
|
||||||
|
|
||||||
|
cls.__configs = configs
|
||||||
|
|
||||||
|
return cls.__configs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_config(cls, **kwargs):
|
||||||
|
if "format" not in kwargs:
|
||||||
|
kwargs["format"] = cls.detect_format(kwargs["path"])
|
||||||
|
|
||||||
|
configs = cls._get_configs()
|
||||||
|
return configs[kwargs["format"]](**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, path: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusersModel(ModelBase):
|
||||||
|
#child_types: Dict[str, Type]
|
||||||
|
#child_sizes: Dict[str, int]
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
super().__init__(model_path, base_model, model_type)
|
||||||
|
|
||||||
|
self.child_types: Dict[str, Type] = dict()
|
||||||
|
self.child_sizes: Dict[str, int] = dict()
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||||
|
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
||||||
|
except:
|
||||||
|
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
||||||
|
|
||||||
|
config_data.pop("_ignore_files", None)
|
||||||
|
|
||||||
|
# retrieve all folder_names that contain relevant files
|
||||||
|
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
|
||||||
|
|
||||||
|
for child_name in child_components:
|
||||||
|
child_type = self._hf_definition_to_type(config_data[child_name])
|
||||||
|
self.child_types[child_name] = child_type
|
||||||
|
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
|
if child_type is None:
|
||||||
|
return sum(self.child_sizes.values())
|
||||||
|
else:
|
||||||
|
return self.child_sizes[child_type]
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
child_type: Optional[SubModelType] = None,
|
||||||
|
):
|
||||||
|
# return pipeline in different function to pass more arguments
|
||||||
|
if child_type is None:
|
||||||
|
raise Exception("Child model type can't be null on diffusers model")
|
||||||
|
if child_type not in self.child_types:
|
||||||
|
return None # TODO: or raise
|
||||||
|
|
||||||
|
if torch_dtype == torch.float16:
|
||||||
|
variants = ["fp16", None]
|
||||||
|
else:
|
||||||
|
variants = [None, "fp16"]
|
||||||
|
|
||||||
|
# TODO: better error handling(differentiate not found from others)
|
||||||
|
for variant in variants:
|
||||||
|
try:
|
||||||
|
# TODO: set cache_dir to /dev/null to be sure that cache not used?
|
||||||
|
model = self.child_types[child_type].from_pretrained(
|
||||||
|
self.model_path,
|
||||||
|
subfolder=child_type.value,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
variant=variant,
|
||||||
|
local_files_only=True,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print("====ERR LOAD====")
|
||||||
|
print(f"{variant}: {e}")
|
||||||
|
|
||||||
|
# calc more accurate size
|
||||||
|
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def calc_model_size_by_fs(
|
||||||
|
model_path: str,
|
||||||
|
subfolder: Optional[str] = None,
|
||||||
|
variant: Optional[str] = None
|
||||||
|
):
|
||||||
|
if subfolder is not None:
|
||||||
|
model_path = os.path.join(model_path, subfolder)
|
||||||
|
|
||||||
|
# this can happen when, for example, the safety checker
|
||||||
|
# is not downloaded.
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
all_files = os.listdir(model_path)
|
||||||
|
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
||||||
|
|
||||||
|
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
|
||||||
|
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
|
||||||
|
other_files = set(all_files) - fp16_files - bit8_files
|
||||||
|
|
||||||
|
if variant is None:
|
||||||
|
files = other_files
|
||||||
|
elif variant == "fp16":
|
||||||
|
files = fp16_files
|
||||||
|
elif variant == "8bit":
|
||||||
|
files = bit8_files
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown variant: {variant}")
|
||||||
|
|
||||||
|
# try read from index if exists
|
||||||
|
index_postfix = ".index.json"
|
||||||
|
if variant is not None:
|
||||||
|
index_postfix = f".index.{variant}.json"
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
if not file.endswith(index_postfix):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
with open(os.path.join(model_path, file), "r") as f:
|
||||||
|
index_data = json.loads(f.read())
|
||||||
|
return int(index_data["metadata"]["total_size"])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# calculate files size if there is no index file
|
||||||
|
formats = [
|
||||||
|
(".safetensors",), # safetensors
|
||||||
|
(".bin",), # torch
|
||||||
|
(".onnx", ".pb"), # onnx
|
||||||
|
(".msgpack",), # flax
|
||||||
|
(".ckpt",), # tf
|
||||||
|
(".h5",), # tf2
|
||||||
|
]
|
||||||
|
|
||||||
|
for file_format in formats:
|
||||||
|
model_files = [f for f in files if f.endswith(file_format)]
|
||||||
|
if len(model_files) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_size = 0
|
||||||
|
for model_file in model_files:
|
||||||
|
file_stats = os.stat(os.path.join(model_path, model_file))
|
||||||
|
model_size += file_stats.st_size
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
||||||
|
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||||
|
|
||||||
|
|
||||||
|
def calc_model_size_by_data(model) -> int:
|
||||||
|
if isinstance(model, DiffusionPipeline):
|
||||||
|
return _calc_pipeline_by_data(model)
|
||||||
|
elif isinstance(model, torch.nn.Module):
|
||||||
|
return _calc_model_by_data(model)
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_pipeline_by_data(pipeline) -> int:
|
||||||
|
res = 0
|
||||||
|
for submodel_key in pipeline.components.keys():
|
||||||
|
submodel = getattr(pipeline, submodel_key)
|
||||||
|
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||||
|
res += _calc_model_by_data(submodel)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_model_by_data(model) -> int:
|
||||||
|
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
|
||||||
|
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
||||||
|
mem = mem_params + mem_bufs # in bytes
|
||||||
|
return mem
|
63
invokeai/backend/model_management/models/lora.py
Normal file
63
invokeai/backend/model_management/models/lora.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from .base import (
|
||||||
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
# TODO: naming
|
||||||
|
from ..lora import LoRAModel as LoRAModelRaw
|
||||||
|
|
||||||
|
class LoRAModel(ModelBase):
|
||||||
|
#model_size: int
|
||||||
|
|
||||||
|
class Config(ModelConfigBase):
|
||||||
|
format: None
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert model_type == ModelType.Lora
|
||||||
|
super().__init__(model_path, base_model, model_type)
|
||||||
|
|
||||||
|
self.model_size = os.path.getsize(self.model_path)
|
||||||
|
|
||||||
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in lora")
|
||||||
|
return self.model_size
|
||||||
|
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
child_type: Optional[SubModelType] = None,
|
||||||
|
):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in lora")
|
||||||
|
|
||||||
|
model = LoRAModelRaw.from_checkpoint(
|
||||||
|
file_path=self.model_path,
|
||||||
|
dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model_size = model.calc_size()
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, path: str):
|
||||||
|
if os.path.isdir(path):
|
||||||
|
return "diffusers"
|
||||||
|
else:
|
||||||
|
return "lycoris"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||||
|
if cls.detect_format(model_path) == "diffusers":
|
||||||
|
# TODO: add diffusers lora when it stabilizes a bit
|
||||||
|
raise NotImplementedError("Diffusers lora not supported")
|
||||||
|
else:
|
||||||
|
return model_path
|
131
invokeai/backend/model_management/models/stable_diffusion.py
Normal file
131
invokeai/backend/model_management/models/stable_diffusion.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from pydantic import Field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
from .base import (
|
||||||
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
DiffusersModel,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: how to name properly
|
||||||
|
class StableDiffusion15Model(DiffusersModel):
|
||||||
|
|
||||||
|
# TODO: str -> Path?
|
||||||
|
class DiffusersConfig(ModelConfigBase):
|
||||||
|
format: Literal["diffusers"]
|
||||||
|
vae: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
class CheckpointConfig(ModelConfigBase):
|
||||||
|
format: Literal["checkpoint"]
|
||||||
|
vae: Optional[str] = Field(None)
|
||||||
|
config: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert base_model == BaseModelType.StableDiffusion1_5
|
||||||
|
assert model_type == ModelType.Pipeline
|
||||||
|
super().__init__(
|
||||||
|
model_path=model_path,
|
||||||
|
base_model=BaseModelType.StableDiffusion1_5,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, model_path: str):
|
||||||
|
if os.path.isdir(model_path):
|
||||||
|
return "diffusers"
|
||||||
|
else:
|
||||||
|
return "checkpoint"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||||
|
cfg = cls.build_config(**config)
|
||||||
|
if isinstance(cfg, cls.CheckpointConfig):
|
||||||
|
return _convert_ckpt_and_cache(cfg) # TODO: args
|
||||||
|
else:
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
# all same
|
||||||
|
class StableDiffusion2BaseModel(StableDiffusion15Model):
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
# skip StableDiffusion15Model __init__
|
||||||
|
assert base_model == BaseModelType.StableDiffusion2Base
|
||||||
|
assert model_type == ModelType.Pipeline
|
||||||
|
super(StableDiffusion15Model, self).__init__(
|
||||||
|
model_path=model_path,
|
||||||
|
base_model=BaseModelType.StableDiffusion2Base,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
class StableDiffusion2Model(DiffusersModel):
|
||||||
|
|
||||||
|
# TODO: str -> Path?
|
||||||
|
# overwrite configs
|
||||||
|
class DiffusersConfig(ModelConfigBase):
|
||||||
|
format: Literal["diffusers"]
|
||||||
|
vae: Optional[str] = Field(None)
|
||||||
|
attention_upscale: bool = Field(True)
|
||||||
|
|
||||||
|
class CheckpointConfig(ModelConfigBase):
|
||||||
|
format: Literal["checkpoint"]
|
||||||
|
vae: Optional[str] = Field(None)
|
||||||
|
config: Optional[str] = Field(None)
|
||||||
|
attention_upscale: bool = Field(True)
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
# skip StableDiffusion15Model __init__
|
||||||
|
assert base_model == BaseModelType.StableDiffusion2
|
||||||
|
assert model_type == ModelType.Pipeline
|
||||||
|
super().__init__(
|
||||||
|
model_path=model_path,
|
||||||
|
base_model=BaseModelType.StableDiffusion2,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: rework
|
||||||
|
DictConfig = dict
|
||||||
|
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> str:
|
||||||
|
"""
|
||||||
|
Convert the checkpoint model indicated in mconfig into a
|
||||||
|
diffusers, cache it to disk, and return Path to converted
|
||||||
|
file. If already on disk then just returns Path.
|
||||||
|
"""
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
weights = app_config.root_dir / mconfig.path
|
||||||
|
config_file = app_config.root_dir / mconfig.config
|
||||||
|
diffusers_path = app_config.converted_ckpts_dir / weights.stem
|
||||||
|
|
||||||
|
# return cached version if it exists
|
||||||
|
if diffusers_path.exists():
|
||||||
|
return diffusers_path
|
||||||
|
|
||||||
|
# TODO: I think that it more correctly to convert with embedded vae
|
||||||
|
# as if user will delete custom vae he will got not embedded but also custom vae
|
||||||
|
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
|
||||||
|
vae_ckpt_path, vae_model = None, None
|
||||||
|
|
||||||
|
# to avoid circular import errors
|
||||||
|
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||||
|
with SilenceWarnings():
|
||||||
|
convert_ckpt_to_diffusers(
|
||||||
|
weights,
|
||||||
|
diffusers_path,
|
||||||
|
extract_ema=True,
|
||||||
|
original_config_file=config_file,
|
||||||
|
vae=vae_model,
|
||||||
|
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
|
||||||
|
scan_needed=True,
|
||||||
|
)
|
||||||
|
return diffusers_path
|
@ -0,0 +1,56 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from .base import (
|
||||||
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
# TODO: naming
|
||||||
|
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||||
|
|
||||||
|
class TextualInversionModel(ModelBase):
|
||||||
|
#model_size: int
|
||||||
|
|
||||||
|
class Config(ModelConfigBase):
|
||||||
|
format: None
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert model_type == ModelType.TextualInversion
|
||||||
|
super().__init__(model_path, base_model, model_type)
|
||||||
|
|
||||||
|
self.model_size = os.path.getsize(self.model_path)
|
||||||
|
|
||||||
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in textual inversion")
|
||||||
|
return self.model_size
|
||||||
|
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
child_type: Optional[SubModelType] = None,
|
||||||
|
):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in textual inversion")
|
||||||
|
|
||||||
|
model = TextualInversionModelRaw.from_checkpoint(
|
||||||
|
file_path=self.model_path,
|
||||||
|
dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, path: str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||||
|
return model_path
|
122
invokeai/backend/model_management/models/vae.py
Normal file
122
invokeai/backend/model_management/models/vae.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from .base import (
|
||||||
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
EmptyConfigLoader,
|
||||||
|
calc_model_size_by_fs,
|
||||||
|
calc_model_size_by_data,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
class VaeModel(ModelBase):
|
||||||
|
#vae_class: Type
|
||||||
|
#model_size: int
|
||||||
|
|
||||||
|
class Config(ModelConfigBase):
|
||||||
|
format: None
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert model_type == ModelType.Vae
|
||||||
|
super().__init__(model_path, base_model, model_type)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||||
|
#config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||||
|
except:
|
||||||
|
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
vae_class_name = config.get("_class_name", "AutoencoderKL")
|
||||||
|
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
|
||||||
|
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||||
|
except:
|
||||||
|
raise Exception("Invalid vae model! (Unkown vae type)")
|
||||||
|
|
||||||
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in vae model")
|
||||||
|
return self.model_size
|
||||||
|
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
child_type: Optional[SubModelType] = None,
|
||||||
|
):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in vae model")
|
||||||
|
|
||||||
|
model = self.vae_class.from_pretrained(
|
||||||
|
self.model_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
# calc more accurate size
|
||||||
|
self.model_size = calc_model_size_by_data(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, path: str):
|
||||||
|
if os.path.isdir(path):
|
||||||
|
return "diffusers"
|
||||||
|
else:
|
||||||
|
return "checkpoint"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
||||||
|
if cls.detect_format(model_path) != "diffusers":
|
||||||
|
# TODO:
|
||||||
|
#_convert_vae_ckpt_and_cache
|
||||||
|
raise NotImplementedError("TODO: vae convert")
|
||||||
|
else:
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
# TODO: rework
|
||||||
|
DictConfig = dict
|
||||||
|
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
|
||||||
|
"""
|
||||||
|
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
|
||||||
|
object, cache it to disk, and return Path to converted
|
||||||
|
file. If already on disk then just returns Path.
|
||||||
|
"""
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
root = app_config.root_dir
|
||||||
|
weights_file = root / mconfig.path
|
||||||
|
config_file = root / mconfig.config
|
||||||
|
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
|
||||||
|
image_size = mconfig.get('width') or mconfig.get('height') or 512
|
||||||
|
|
||||||
|
# return cached version if it exists
|
||||||
|
if diffusers_path.exists():
|
||||||
|
return diffusers_path
|
||||||
|
|
||||||
|
# this avoids circular import error
|
||||||
|
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||||
|
if weights_file.suffix == '.safetensors':
|
||||||
|
checkpoint = safetensors.torch.load_file(weights_file)
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(weights_file, map_location="cpu")
|
||||||
|
|
||||||
|
# sometimes weights are hidden under "state_dict", and sometimes not
|
||||||
|
if "state_dict" in checkpoint:
|
||||||
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
|
config = OmegaConf.load(config_file)
|
||||||
|
|
||||||
|
vae_model = convert_ldm_vae_to_diffusers(
|
||||||
|
checkpoint = checkpoint,
|
||||||
|
vae_config = config,
|
||||||
|
image_size = image_size
|
||||||
|
)
|
||||||
|
vae_model.save_pretrained(
|
||||||
|
diffusers_path,
|
||||||
|
safe_serialization=is_safetensors_available()
|
||||||
|
)
|
||||||
|
return diffusers_path
|
Loading…
Reference in New Issue
Block a user