New models structure draft

This commit is contained in:
Sergey Borisov 2023-06-10 03:14:10 +03:00
parent 887576d217
commit 2c056ead42
4 changed files with 909 additions and 1010 deletions
invokeai
app/invocations
backend/model_management

View File

@ -248,7 +248,6 @@ class TextToLatentsInvocation(BaseInvocation):
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if unet.dtype == torch.float16 else "float32",
#precision="float16", # TODO:
)
def prep_control_data(self,

View File

@ -40,456 +40,7 @@ from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
def get_model_path(repo_id_or_path: str):
globals = get_invokeai_config()
if os.path.exists(repo_id_or_path):
return repo_id_or_path
cache = scan_cache_dir(globals.cache_dir)
for repo in cache.repos:
if repo.repo_id != repo_id_or_path:
continue
for rev in repo.revisions:
if "main" in rev.refs:
return rev.snapshot_path
raise Exception(f"{repo_id_or_path} - not found")
def calc_model_size_by_fs(
repo_id_or_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
model_path = get_model_path(repo_id_or_path)
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
class SDModelType(str, Enum):
Diffusers = "diffusers"
Classifier = "classifier"
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
Scheduler = "scheduler"
Lora = "lora"
TextualInversion = "textual_inversion"
ControlNet = "control_net"
class BaseModel(str, Enum):
StableDiffusion1_5 = "SD-1"
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
class ModelInfoBase:
#model_path: str
#model_type: SDModelType
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
self.repo_id_or_path = repo_id_or_path # TODO: or use allways path?
self.model_path = get_model_path(repo_id_or_path)
self.model_type = model_type
def _definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
class DiffusersModelInfo(ModelInfoBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
assert model_type == SDModelType.Diffusers
super().__init__(repo_id_or_path, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(repo_id_or_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(repo_id_or_path, subfolder=child_name)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
# TODO:
for variant in ["fp16", "main", None]:
try:
model = self.child_types[child_type].from_pretrained(
self.repo_id_or_path,
subfolder=child_type.value,
cache_dir=get_invokeai_config().cache_dir,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception as e:
print("====ERR LOAD====")
print(f"{variant}: {e}")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
def get_pipeline(self, **kwargs):
return DiffusionPipeline.from_pretrained(
self.repo_id_or_path,
**kwargs,
)
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ClassifierModelInfo(ModelInfoBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
assert model_type == SDModelType.Classifier
super().__init__(repo_id_or_path, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
main_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="config.json")
#main_config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid classifier model! (config.json not found or invalid)")
self._load_tokenizer(main_config)
self._load_text_encoder(main_config)
self._load_feature_extractor(main_config)
def _load_tokenizer(self, main_config: dict):
try:
tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json")
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
except:
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
if "tokenizer_class" in tokenizer_config:
tokenizer_class_name = tokenizer_config["tokenizer_class"]
elif "model_type" in main_config:
tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")
self.child_types[SDModelType.Tokenizer] = self._definition_to_type(["transformers", tokenizer_class_name])
self.child_sizes[SDModelType.Tokenizer] = 0
def _load_text_encoder(self, main_config: dict):
if "architectures" in main_config and len(main_config["architectures"]) > 0:
text_encoder_class_name = main_config["architectures"][0]
elif "model_type" in main_config:
text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name])
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path)
def _load_feature_extractor(self, main_config: dict):
self.child_sizes[SDModelType.FeatureExtractor] = 0
try:
feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json")
except:
return # feature extractor not passed with t5
try:
feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
self.child_types[SDModelType.FeatureExtractor] = self._definition_to_type(["transformers", feature_extractor_class_name])
except:
raise Exception("Invalid classifier model! (Unknown feature_extrator type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
if child_type is None:
raise Exception("Child model type can't be null on classififer model")
if child_type not in self.child_types:
return None # TODO: or raise
model = self.child_types[child_type].from_pretrained(
self.repo_id_or_path,
subfolder=child_type.value,
cache_dir=get_invokeai_config().cache_dir,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
class VaeModelInfo(ModelInfoBase):
#vae_class: Type
#model_size: int
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
assert model_type == SDModelType.Vae
super().__init__(repo_id_or_path, model_type)
try:
config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(repo_id_or_path)
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_class.from_pretrained(
self.repo_id_or_path,
cache_dir=get_invokeai_config().cache_dir,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
class LoRAModelInfo(ModelInfoBase):
#model_size: int
def __init__(self, file_path: str, model_type: SDModelType):
assert model_type == SDModelType.Lora
# check manualy as super().__init__ will try to resolve repo_id too
if not os.path.exists(file_path):
raise Exception("Model not found")
super().__init__(file_path, model_type)
self.model_size = os.path.getsize(file_path)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModel.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
class TextualInversionModelInfo(ModelInfoBase):
#model_size: int
def __init__(self, file_path: str, model_type: SDModelType):
assert model_type == SDModelType.TextualInversion
# check manualy as super().__init__ will try to resolve repo_id too
if not os.path.exists(file_path):
raise Exception("Model not found")
super().__init__(file_path, model_type)
self.model_size = os.path.getsize(file_path)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
model = TextualInversionModel.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
MODEL_TYPES = {
SDModelType.Diffusers: DiffusersModelInfo,
SDModelType.Classifier: ClassifierModelInfo,
SDModelType.Vae: VaeModelInfo,
SDModelType.Lora: LoRAModelInfo,
SDModelType.TextualInversion: TextualInversionModelInfo,
}
from .models import MODEL_CLASSES
# Maximum size of the cache, in gigs
@ -499,10 +50,6 @@ DEFAULT_MAX_CACHE_SIZE = 6.0
# actual size of a gig
GIG = 1073741824
# TODO:
class EmptyScheduler(SchedulerMixin, ConfigMixin):
pass
class ModelLocker(object):
"Forward declaration"
pass
@ -583,12 +130,10 @@ class ModelCache(object):
self,
model_path: str,
model_type: SDModelType,
revision: Optional[str] = None,
submodel_type: Optional[SDModelType] = None,
):
revision = revision or "main"
key = f"{model_path}:{model_type}:{revision}"
key = f"{model_path}:{model_type}"
if submodel_type:
key += f":{submodel_type}"
return key
@ -606,55 +151,51 @@ class ModelCache(object):
def _get_model_info(
self,
model_path: str,
model_type: SDModelType,
revision: str,
model_class: Type[ModelBase],
):
model_info_key = self.get_key(
model_path=model_path,
model_type=model_type,
revision=revision,
submodel_type=None,
)
if model_info_key not in self.model_infos:
if model_type not in MODEL_TYPES:
raise Exception(f"Unknown/unsupported model type: {model_type}")
self.model_infos[model_info_key] = MODEL_TYPES[model_type](
self.model_infos[model_info_key] = model_class(
model_path,
model_type,
)
return self.model_infos[model_info_key]
# TODO: args
def get_model(
self,
repo_id_or_path: Union[str, Path],
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
revision: Optional[str] = None,
variant: Optional[str] = None,
model_path: Union[str, Path],
model_class: Type[ModelBase],
submodel: Optional[SubModelType] = None,
gpu_load: bool = True,
) -> Any:
model_path = get_model_path(repo_id_or_path)
if not isinstance(model_path, Path):
model_path = Path(model_path)
if not os.path.exists(model_path):
raise Exception(f"Model not found: {model_path}")
model_info = self._get_model_info(
model_path=model_path,
model_type=model_type,
revision=revision,
model_class=model_class,
)
# TODO: variant
key = self.get_key(
model_path=model_path,
model_type=model_type,
revision=revision,
model_type=model_type, # TODO:
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
self.logger.info(f'Loading model {model_path}, type {model_type}:{submodel}')
# this will remove older cached models until
# there is sufficient room to load the requested model
@ -662,7 +203,7 @@ class ModelCache(object):
# clean memory to make MemoryUsage() more accurate
gc.collect()
model = model_info.get_model(submodel, torch_dtype=self.precision)
model = model_info.get_model(submodel, torch_dtype=self.precision, variant=)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
@ -732,20 +273,14 @@ class ModelCache(object):
def model_hash(
self,
repo_id_or_path: Union[str, Path],
revision: str = "main",
model_path: Union[str, Path],
) -> str:
'''
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param repo_id_or_path: repo_id string or Path to model file/directory on disk.
:param revision: optional revision string (if fetching a HF repo_id)
:param model_path: Path to model file/directory on disk.
'''
revision = revision or "main"
if Path(repo_id_or_path).is_dir():
return self._local_model_hash(repo_id_or_path)
else:
return self._hf_commit_hash(repo_id_or_path,revision)
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"Return the current size of the cache, in GB"
@ -840,17 +375,6 @@ class ModelCache(object):
with open(hashpath, "w") as f:
f.write(hash)
return hash
def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str:
api = HfApi()
info = api.list_repo_refs(
repo_id=repo_id,
repo_type='model',
)
desired_revisions = [branch for branch in info.branches if branch.name==revision]
if not desired_revisions:
raise KeyError(f"Revision '{revision}' not found in {repo_id}")
return desired_revisions[0].target_commit
class SilenceWarnings(object):
def __init__(self):

View File

@ -163,7 +163,6 @@ import safetensors
import safetensors.torch
import torch
from diffusers import AutoencoderKL
from diffusers.utils import is_safetensors_available
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@ -172,8 +171,8 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from ..install.model_install_backend import Dataset_path, hf_download_with_resume
from .model_cache import (ModelCache, ModelLocker, SDModelType,
SilenceWarnings)
from .model_cache import ModelCache, ModelLocker, SilenceWarnings
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
@ -201,14 +200,6 @@ class InvalidModelError(Exception):
"Raised when an invalid model is requested"
pass
class SDLegacyType(Enum):
V1 = auto()
V1_INPAINT = auto()
V2 = auto()
V2_e = auto()
V2_v = auto()
UNKNOWN = auto()
MAX_CACHE_SIZE = 6.0 # GB
@ -280,32 +271,45 @@ class ModelManager(object):
def model_exists(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
model_key = self.create_key(model_name, model_type)
model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.config
def create_key(self, model_name: str, model_type: SDModelType) -> str:
return f"{model_type}/{model_name}"
def create_key(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> str:
return f"{base_model}/{model_type}/{model_name}"
def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
model_type_str, model_name = model_key.split('/', 1)
def parse_key(self, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
base_model_str, model_type_str, model_name = model_key.split('/', 2)
try:
model_type = SDModelType(model_type_str)
return (model_name, model_type)
except:
raise Exception(f"Unknown model type: {model_type_str}")
try:
base_model = BaseModelType(base_model_str)
except:
raise Exception(f"Unknown base model: {base_model_str}")
return (model_name, base_model, model_type)
def get_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
) -> SDModelInfo:
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None
):
"""Given a model named identified in models.yaml, return
an SDModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml
@ -344,210 +348,182 @@ class ModelManager(object):
# raises an InvalidModelError
"""
model_key = self.create_key(model_name, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
)
# get the required loading info out of the config file
mconfig = self.config[model_key]
# type already checked as it's part of key
if model_type == SDModelType.Diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if mconfig.format in ["ckpt", "safetensors"]:
location = self.convert_ckpt_and_cache(mconfig)
elif mconfig.get('path'):
location = self.globals.root_dir / mconfig.get('path')
model_class = MODEL_CLASSES[base_model][model_type]
#if model_type in {
# ModelType.Lora,
# ModelType.ControlNet,
# ModelType.TextualInversion,
# ModelType.Vae,
#}:
if not model_class.has_config:
#if model_class.Config is None:
# skip config
# load from
# /models/{base_model}/{model_type}/{model_name}
# /models/{base_model}/{model_type}/{model_name}.{ext}
model_config = None
for ext in {"pt", "ckpt", "safetensors"}:
model_path = os.path.join(model_dir, base_model, model_type, f"{model_name}.{ext}")
if os.path.exists(model_path):
break
else:
location = mconfig.get('repo_id')
elif p := mconfig.get('path'):
location = self.globals.root_dir / p
elif r := mconfig.get('repo_id'):
location = r
elif w := mconfig.get('weights'):
location = self.globals.root_dir / w
model_path = os.path.join(model_dir, base_model, model_type, model_name)
if not os.path.exists(model_path):
raise InvalidModelError(
f"Model not found - \"{base_model}/{model_type}/{model_name}\" "
)
else:
location = None
revision = mconfig.get('revision')
if model_type in [SDModelType.Lora, SDModelType.TextualInversion]:
hash = "<NO_HASH>" # TODO:
else:
hash = self.cache.model_hash(location, revision)
# find in config
model_key = self.create_key(model_name, base_model, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
)
# If the caller is asking for part of the model and the config indicates
# an external replacement for that field, then we fetch the replacement
if submodel and mconfig.get(submodel):
location = mconfig.get(submodel).get('path') \
or mconfig.get(submodel).get('repo_id')
model_type = submodel
submodel = None
model_config = self.config[model_key]
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
# /models/{base_model}/{model_type}/{name}/
model_path = model_config.path
# to support the traditional way of attaching a VAE
# to a model, we hacked in `attach_model_part`
# TODO:
if model_type == SDModelType.Vae and "vae" in mconfig:
print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
# vae/movq override
# TODO:
if submodel is not None and submodel in model_config:
model_path = model_config[submodel]["path"]
model_type = submodel
submodel = None
model_context = self.cache.get_model(
location,
model_type = model_type,
revision = revision,
submodel = submodel,
dst_convert_path = None # TODO:
model_path = model_class.convert_if_required(
model_path,
dst_convert_path,
model_config,
)
# in case we need to communicate information about this
# model to the cache manager, then we need to remember
# the cache key
self.cache_keys[model_key] = model_context.key
model_context = self.cache.get_model(
model_path,
model_class,
submodel,
)
hash = "<NO_HASH>" # TODO:
return SDModelInfo(
context = model_context,
name = model_name,
base_model = base_model,
type = submodel or model_type,
hash = hash,
location = location,
revision = revision,
location = model_path, # TODO:
precision = self.cache.precision,
_cache = self.cache
_cache = self.cache,
)
def default_model(self) -> Optional[Tuple[str, SDModelType]]:
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
for model_name, model_type in self.model_names():
model_key = self.create_key(model_name, model_type)
if self.config[model_key].get("default"):
return (model_name, model_type)
return self.model_names()[0][0]
for model_key, model_config in self.config.items():
if model_config.get("default", False):
return self.parse_key(model_key)
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.Diffusers) -> None:
for model_key, _ in self.config.items():
return self.parse_key(model_key)
else:
return None # TODO: or redo as (None, None, None)
def set_default_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
"""
assert self.model_exists(model_name, model_type), f"unknown model '{model_name}'"
config = self.config
for model_name, model_type in self.model_names():
key = self.create_key(model_name, model_type)
config[key].pop("default", None)
config[self.create_key(model_name, model_type)]["default"] = True
model_key = self.model_key(model_name, base_model, model_type)
if model_key not in self.config:
raise Exception(f"Unknown model: {model_key}")
for cur_model_key, config in self.config.items():
if cur_model_key == model_key:
config["default"] = True
else:
config.pop("default", None)
def model_info(
self,
model_name: str,
model_type: SDModelType=SDModelType.Diffusers,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
"""
Given a model name returns the OmegaConf (dict-like) object describing it.
"""
if not self.model_exists(model_name, model_type):
return None
return self.config[self.create_key(model_name, model_type)]
model_key = self.create_key(model_name, base_model, model_type)
return self.config.get(model_key, None)
def model_names(self) -> List[Tuple[str, SDModelType]]:
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Return a list of (str, SDModelType) corresponding to all models
Return a list of (str, BaseModelType, ModelType) corresponding to all models
known to the configuration.
"""
return [(self.parse_key(x)) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]
def is_legacy(self, model_name: str, model_type: SDModelType.Diffusers) -> bool:
def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[SDModelType] = None,
) -> Dict[str, Dict[str, str]]:
"""
Return true if this is a legacy (.ckpt) model
"""
# if we are converting legacy files automatically, then
# there are no legacy ckpts!
if self.globals.ckpt_convert:
return False
info = self.model_info(model_name, model_type)
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
return True
return False
def list_models(self, model_type: SDModelType=None) -> dict[str,dict[str,str]]:
"""
Return a dict of models, in format [model_type][model_name], with
following fields:
model_name
model_type
format
description
status
# for folders only
repo_id
path
subfolder
vae
# for ckpts only
config
weights
vae
Return a dict of models, in format [base_model][model_type][model_name]
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
named 'model-name', and model_manager.config to get the full OmegaConf
object derived from models.yaml
"""
models = {}
assert not(model_type is not None and base_model is None), "model_type must be provided with base_model"
models = dict()
for model_key in sorted(self.config, key=str.casefold):
stanza = self.config[model_key]
# don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]:
continue
if model_key.startswith('_'):
continue
model_name, stanza_type = self.parse_key(model_key)
model_name, m_base_model, stanza_type = self.parse_key(model_key)
if base_model is not None and m_base_model != base_model:
continue
if model_type is not None and model_type != stanza_type:
continue
if m_base_model not in models:
models[m_base_model] = dict()
if stanza_type not in models:
models[stanza_type] = dict()
models[m_base_model][stanza_type] = dict()
models[stanza_type][model_name] = dict()
model_format = stanza.get('format')
# Common Attribs
description = stanza.get("description", None)
models[stanza_type][model_name].update(
model_name=model_name,
model_type=stanza_type,
format=model_format,
description=description,
status="unknown", # TODO: no more status as model loaded separately
model_class = MODEL_CLASSES[m_base_model][stanza_type]
models[m_base_model][stanza_type][model_name] = model_class.build_config(
**stanza,
name=model_name,
base_model=base_model,
type=stanza_type,
)
# Checkpoint Config Parse
if model_format in ["ckpt","safetensors"]:
models[stanza_type][model_name].update(
config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)),
)
# Diffusers Config Parse
elif model_format == "folder":
if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig):
vae = dict(
repo_id = str(vae.get("repo_id", None)),
path = str(vae.get("path", None)),
subfolder = str(vae.get("subfolder", None)),
)
models[stanza_type][model_name].update(
vae = vae,
repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)),
)
#models[m_base_model][stanza_type][model_name] = model_class.Config(
# **stanza,
# name=model_name,
# base_model=base_model,
# type=stanza_type,
#).dict()
return models
@ -557,7 +533,7 @@ class ModelManager(object):
"""
for model_type, model_dict in self.list_models().items():
for model_name, model_info in model_dict.items():
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
line = f'{model_info["name"]:25s} {model_info["status"]:>15s} {model_info["type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m"
print(line)
@ -606,7 +582,8 @@ class ModelManager(object):
def add_model(
self,
model_name: str,
model_type: SDModelType,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> None:
@ -618,38 +595,31 @@ class ModelManager(object):
attributes are incorrect or the model name is missing.
"""
if model_type == SDModelType.Fiffusers:
# TODO: automaticaly or manualy?
#assert "format" in model_attributes, 'missing required field "format"'
model_format = "ckpt" if "weights" in model_attributes else "diffusers"
model_class = MODEL_CLASSES[base_model][model_type]
if model_format == "diffusers":
assert (
"description" in model_attributes
), 'required field "description" is missing'
assert (
"path" in model_attributes or "repo_id" in model_attributes
), 'model must have either the "path" or "repo_id" fields defined'
model_class.build_config(
**model_attributes,
name=model_name,
base_model=base_model,
type=model_type,
)
#model_cfg = model_class.Config(
# **model_attributes,
# name=model_name,
# base_model=base_model,
# type=model_type,
#)
elif model_format == "ckpt":
for field in ("description", "weights", "config"):
assert field in model_attributes, f"required field {field} is missing"
else:
assert "weights" in model_attributes and "description" in model_attributes
model_key = self.create_key(model_name, model_type)
model_key = self.create_key(model_name, base_model, model_type)
assert (
clobber or model_key not in self.config
), f'attempt to overwrite existing model definition "{model_key}"'
self.config[model_key] = model_attributes
if "weights" in self.config[model_key]:
self.config[model_key]["weights"].replace("\\", "/")
if clobber and model_key in self.cache_keys:
# TODO:
self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key]
@ -741,326 +711,6 @@ class ModelManager(object):
),
True
)
@classmethod
def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
"""
Given a pickle or safetensors model object, probes contents
of the object and returns an SDLegacyType indicating its
format. Valid return values include:
SDLegacyType.V1
SDLegacyType.V1_INPAINT
SDLegacyType.V2 (V2 prediction type unknown)
SDLegacyType.V2_e (V2 using 'epsilon' prediction type)
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
SDLegacyType.UNKNOWN
"""
global_step = checkpoint.get("global_step")
state_dict = checkpoint.get("state_dict") or checkpoint
try:
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if global_step == 220000:
return SDLegacyType.V2_e
elif global_step == 110000:
return SDLegacyType.V2_v
else:
return SDLegacyType.V2
# otherwise we assume a V1 file
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
if in_channels == 9:
return SDLegacyType.V1_INPAINT
elif in_channels == 4:
return SDLegacyType.V1
else:
return SDLegacyType.UNKNOWN
except KeyError:
return SDLegacyType.UNKNOWN
def heuristic_import(
self,
path_url_or_repo: str,
model_name: Optional[str] = None,
description: Optional[str] = None,
model_config_file: Optional[Path] = None,
commit_to_conf: Optional[Path] = None,
config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> str:
"""Accept a string which could be:
- a HF diffusers repo_id
- a URL pointing to a legacy .ckpt or .safetensors file
- a local path pointing to a legacy .ckpt or .safetensors file
- a local directory containing .ckpt and .safetensors files
- a local directory containing a diffusers model
After determining the nature of the model and downloading it
(if necessary), the file is probed to determine the correct
configuration file (if needed) and it is imported.
The model_name and/or description can be provided. If not, they will
be generated automatically.
If commit_to_conf is provided, the newly loaded model will be written
to the `models.yaml` file at the indicated path. Otherwise, the changes
will only remain in memory.
The routine will do its best to figure out the config file
needed to convert legacy checkpoint file, but if it can't it
will call the config_file_callback routine, if provided. The
callback accepts a single argument, the Path to the checkpoint
file, and returns a Path to the config file to use.
The (potentially derived) name of the model is returned on
success, or None on failure. When multiple models are added
from a directory, only the last imported one is returned.
"""
model_path: Path = None
thing = str(path_url_or_repo) # to save typing
self.logger.info(f"Probing {thing} for import")
if thing.startswith(("http:", "https:", "ftp:")):
self.logger.info(f"{thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
return
else:
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
model_name = self.import_diffuser_model(
thing,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
model_name=model_name,
description=description,
commit_to_conf=commit_to_conf,
)
elif Path(thing).is_dir():
if (Path(thing) / "model_index.json").exists():
self.logger.debug(f"{thing} appears to be a diffusers model.")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
else:
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
for m in list(Path(thing).rglob("*.ckpt")) + list(
Path(thing).rglob("*.safetensors")
):
if model_name := self.heuristic_import(
str(m),
commit_to_conf=commit_to_conf,
config_file_callback=config_file_callback,
):
self.logger.info(f"{model_name} successfully imported")
return model_name
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
return model_name
else:
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
# Model_path is set in the event of a legacy checkpoint file.
# If not set, we're all done
if not model_path:
return
if model_path.stem in self.config: # already imported
self.logger.debug("Already imported. Skipping")
return model_path.stem
# another round of heuristics to guess the correct config file.
checkpoint = None
if model_path.suffix in [".ckpt", ".pt"]:
self.cache.scan_model(model_path, model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
# additional probing needed if no config file provided
if model_config_file is None:
# look for a like-named .yaml file in same directory
if model_path.with_suffix(".yaml").exists():
model_config_file = model_path.with_suffix(".yaml")
self.logger.debug(f"Using config file {model_config_file.name}")
else:
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
self.logger.debug("SD-v1 model detected")
model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
elif model_type == SDLegacyType.V1_INPAINT:
self.logger.debug("SD-v1 inpainting model detected")
model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml"
elif model_type == SDLegacyType.V2_v:
self.logger.debug("SD-v2-v model detected")
model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
elif model_type == SDLegacyType.V2_e:
self.logger.debug("SD-v2-e model detected")
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
elif model_type == SDLegacyType.V2:
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
)
else:
self.logger.warning(
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
)
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)
# despite our best efforts, we could not find a model config file, so give up
if not model_config_file:
return
# look for a custom vae, a like-named file ending with .vae in the same directory
vae_path = None
for suffix in ["pt", "ckpt", "safetensors"]:
if (model_path.with_suffix(f".vae.{suffix}")).exists():
vae_path = model_path.with_suffix(f".vae.{suffix}")
self.logger.debug(f"Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = self.globals.converted_ckpts_dir / model_path.stem
with SilenceWarnings():
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
vae=vae,
vae_path=str(vae_path),
model_name=model_name,
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
return model_name
def convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
weights = self.globals.root_dir / mconfig.weights
config_file = self.globals.root_dir / mconfig.config
diffusers_path = self.globals.converted_ckpts_dir / weights.stem
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
diffusers_path,
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(self.globals.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
scan_needed=True,
)
return diffusers_path
def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
root = self.globals.root_dir
weights_file = root / mconfig.weights
config_file = root / mconfig.config
diffusers_path = self.globals.converted_ckpts_dir / weights_file.stem
image_size = mconfig.get('width') or mconfig.get('height') or 512
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# this avoids circular import error
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
checkpoint = torch.load(weights_file, map_location="cpu")\
if weights_file.suffix in ['.ckpt','.pt'] \
else safetensors.torch.load_file(weights_file)
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size
)
vae_model.save_pretrained(
diffusers_path,
safe_serialization=is_safetensors_available()
)
return diffusers_path
def _get_vae_for_conversion(
self,
weights: Path,
mconfig: DictConfig
) -> Tuple[Path, AutoencoderKL]:
# VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert
vae_ckpt_path = None
vae_diffusers_location = None
vae_model = None
for suffix in ["pt", "ckpt", "safetensors"]:
if (weights.with_suffix(f".vae.{suffix}")).exists():
vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
if vae_ckpt_path:
return (vae_ckpt_path, None)
# 2. If mconfig has a vae weights path, then we use that as vae_path
vae_config = mconfig.get('vae')
if vae_config and isinstance(vae_config,str):
vae_ckpt_path = vae_config
return (vae_ckpt_path, None)
# 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
if vae_config and isinstance(vae_config,DictConfig):
vae_diffusers_location = self.globals.root_dir / vae_config.get('path') \
if vae_config.get('path') \
else vae_config.get('repo_id')
# 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
else:
vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
if vae_diffusers_location:
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model
return (None, vae_model)
return (None, None)
def convert_and_import(
self,

View File

@ -0,0 +1,726 @@
import sys
from enum import Enum
import torch
import safetensors.torch
from diffusers.utils import is_safetensors_available
class BaseModelType(str, Enum):
#StableDiffusion1_5 = "stable_diffusion_1_5"
#StableDiffusion2 = "stable_diffusion_2"
#StableDiffusion2Base = "stable_diffusion_2_base"
# TODO: maybe then add sample size(512/768)?
StableDiffusion1_5 = "SD-1"
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
#Kandinsky2_1 = "kandinsky_2_1"
class ModelType(str, Enum):
Pipeline = "pipeline"
Classifier = "classifier"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet"
TextualInversion = "embedding"
class SubModelType:
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
MODEL_CLASSES = {
BaseModel.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModel.StableDiffusion2: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModel.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModel.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.Classifier: ClassifierModel,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoraModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ModelBase:
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
print("====ERR LOAD====")
print(f"{variant}: {e}")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
class StableDiffusionModel(DiffusersModel):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {
BaseModelType.StableDiffusion1_5,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusion2Base,
}
assert model_type == ModelType.Pipeline
super().__init__(model_path, base_model, model_type)
@staticmethod
def convert_if_required(model_path: Union[str, Path], dst_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
# TODO: args
# TODO: set model_path, to config? pass dst_path as arg?
# TODO: check
return _convert_ckpt_and_cache(config)
class classproperty(object): # pylint: disable=invalid-name
"""Class property decorator.
Example usage:
class MyClass(object):
@classproperty
def value(cls):
return '123'
> print MyClass.value
123
"""
def __init__(self, func):
self._func = func
def __get__(self, owner_self, owner_cls):
return self._func(owner_cls)
class ModelConfigBase(BaseModel):
path: str # or Path
name: str
description: Optional[str]
class StableDiffusionDModel(DiffusersModel):
class Config(ModelConfigBase):
format: str
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
@root_validator
def validator(cls, values):
if values["format"] not in {"checkpoint", "diffusers"}:
raise ValueError(f"Unkown stable diffusion model format: {values['format']}")
if values["config"] is not None and values["format"] != "checkpoint":
raise ValueError(f"Custom config field allowed only in checkpoint stable diffusion model")
return values
# return config only for checkpoint format
def dict(self, *args, **kwargs):
result = super().dict(*args, **kwargs)
if self.format != "checkpoint":
result.pop("config", None)
return result
@classproperty
def has_config(self):
return True
def build_config(self, **kwargs) -> dict:
try:
res = dict(
path=kwargs["path"],
name=kwargs["name"],
description=kwargs.get("description", None),
format=kwargs["format"],
vae=kwargs.get("vae", None),
)
if res["format"] not in {"checkpoint", "diffusers"}:
raise Exception(f"Unkonwn stable diffusion model format: {res['format']}")
if res["format"] == "checkpoint":
res["config"] = kwargs.get("config", None)
# TODO: raise if config specified for diffusers?
return res
except KeyError as e:
raise Exception(f"Field \"{e.args[0]}\" not found!")
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1_5
assert model_type == ModelType.Pipeline
super().__init__(model_path, base_model, model_type)
@classmethod
def convert_if_required(cls, model_path: str, dst_path: str, config: Optional[dict]) -> str:
model_config = cls.Config(
**config,
path=model_path,
name="",
)
if hasattr(model_config, "config"):
convert_ckpt_and_cache(
model_path=model_path,
dst_path=dst_path,
config=config,
)
return dst_path
else:
return model_path
class StableDiffusion15CheckpointModel(DiffusersModel):
class Cnfig(ModelConfigBase):
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
class StableDiffusion2BaseDiffusersModel(DiffusersModel):
class Config(ModelConfigBase):
vae: Optional[str] = Field(None)
class StableDiffusion2BaseCheckpointModel(DiffusersModel):
class Cnfig(ModelConfigBase):
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
class StableDiffusion2DiffusersModel(DiffusersModel):
class Config(ModelConfigBase):
vae: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
class StableDiffusion2CheckpointModel(DiffusersModel):
class Config(ModelConfigBase):
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
attention_upscale: bool = Field(True)
class ClassifierModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == SDModelType.Classifier
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
main_config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#main_config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid classifier model! (config.json not found or invalid)")
self._load_tokenizer(main_config)
self._load_text_encoder(main_config)
self._load_feature_extractor(main_config)
def _load_tokenizer(self, main_config: dict):
try:
tokenizer_config = EmptyConfigLoader.load_config(self.model_path, config_name="tokenizer_config.json")
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
except:
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
if "tokenizer_class" in tokenizer_config:
tokenizer_class_name = tokenizer_config["tokenizer_class"]
elif "model_type" in main_config:
tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")
self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
self.child_sizes[SDModelType.Tokenizer] = 0
def _load_text_encoder(self, main_config: dict):
if "architectures" in main_config and len(main_config["architectures"]) > 0:
text_encoder_class_name = main_config["architectures"][0]
elif "model_type" in main_config:
text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
def _load_feature_extractor(self, main_config: dict):
self.child_sizes[SDModelType.FeatureExtractor] = 0
try:
feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json")
except:
return # feature extractor not passed with t5
try:
feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
except:
raise Exception("Invalid classifier model! (Unknown feature_extrator type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is None:
raise Exception("Child model type can't be null on classififer model")
if child_type not in self.child_types:
return None # TODO: or raise
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
return model_path
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
# TODO:
#_convert_vae_ckpt_and_cache
raise Exception("TODO: ")
class LoRAModel(ModelBase):
#model_size: int
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModel.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
# TODO: add diffusers lora when it stabilizes a bit
return model_path
class TextualInversionModel(ModelBase):
#model_size: int
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
model = TextualInversionModel.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@staticmethod
def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
if not isinstance(model_path, Path):
model_path = Path(model_path)
return model_path
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_dir / mconfig.path
config_file = app_config.root_dir / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# TODO: I think that it more correctly to convert with embedded vae
# as if user will delete custom vae he will got not embedded but also custom vae
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
vae_ckpt_path, vae_model = None, None
# to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
diffusers_path,
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
scan_needed=True,
)
return diffusers_path
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
root = app_config.root_dir
weights_file = root / mconfig.path
config_file = root / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
image_size = mconfig.get('width') or mconfig.get('height') or 512
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# this avoids circular import error
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_file.suffix == '.safetensors':
checkpoint = safetensors.torch.load_file(weights_file)
else:
checkpoint = torch.load(weights_file, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size
)
vae_model.save_pretrained(
diffusers_path,
safe_serialization=is_safetensors_available()
)
return diffusers_path