InvokeAI/invokeai/backend/model_management/models/base.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

307 lines
9.7 KiB
Python
Raw Normal View History

2023-06-12 13:14:09 +00:00
import os
2023-06-11 01:49:09 +00:00
import sys
2023-06-11 13:42:40 +00:00
import typing
import inspect
2023-06-11 01:49:09 +00:00
from enum import Enum
import torch
from diffusers import DiffusionPipeline, ConfigMixin
from pydantic import BaseModel, Field
2023-06-11 13:42:40 +00:00
from typing import List, Dict, Optional, Type, Literal
2023-06-11 01:49:09 +00:00
class BaseModelType(str, Enum):
    """Base model family a model is built on.

    Members subclass ``str``, so each one compares equal to its string value.
    """

    # TODO: maybe then add sample size(512/768)?
    # Earlier, more specific names kept for reference:
    #   StableDiffusion1_5 = "stable_diffusion_1_5"
    #   StableDiffusion2 = "stable_diffusion_2"
    #   StableDiffusion2Base = "stable_diffusion_2_base"
    StableDiffusion1 = "sd-1"
    # 768 pixels; this will have v-prediction parameterization
    StableDiffusion2 = "sd-2"
    # Kandinsky2_1 = "kandinsky_2_1"
class ModelType(str, Enum):
    """Top-level kind of a managed model.

    Members subclass ``str``, so each one compares equal to its string value.
    """

    Pipeline = "pipeline"
    Vae = "vae"
    Lora = "lora"
    # used by model_probe
    ControlNet = "controlnet"
    TextualInversion = "embedding"
2023-06-11 03:12:21 +00:00
class SubModelType(str, Enum):
    """Named component (sub-model) of a pipeline.

    Each value is the string identifier of the component; members subclass
    ``str`` and compare equal to that string.
    """

    UNet = "unet"
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    Vae = "vae"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
    # MoVQ = "movq"
class ModelVariantType(str, Enum):
    """Variant of a model: normal, inpaint or depth."""

    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"
class SchedulerPredictionType(str, Enum):
    """Prediction type identifiers for a scheduler."""

    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"
2023-06-11 01:49:09 +00:00
class ModelError(str, Enum):
    """Error states that can be attached to a model configuration."""

    NotFound = "not_found"
class ModelConfigBase(BaseModel):
    """Fields shared by every model configuration entry.

    The model name is intentionally not a field: it is already present in
    the model key.
    """

    # Location of the model (kept as a plain string; could become a Path).
    path: str
    description: Optional[str] = Field(default=None)
    format: Optional[str] = Field(default=None)
    default: Optional[bool] = Field(default=False)
    # Runtime-only status: excluded so that it is not saved to the config.
    error: Optional[ModelError] = Field(default=None, exclude=True)
class EmptyConfigLoader(ConfigMixin):
    """``ConfigMixin`` subclass whose config file name is chosen per call."""

    @classmethod
    def load_config(cls, *args, **kwargs):
        """Load a config after setting ``cls.config_name`` from the kwargs.

        The ``config_name`` keyword argument is required (a missing key
        raises ``KeyError``). It is removed from ``kwargs`` and stored on the
        class before the call is delegated to the parent ``load_config`` with
        the remaining arguments.
        """
        requested_name = kwargs.pop("config_name")
        cls.config_name = requested_name
        return super().load_config(*args, **kwargs)
class ModelBase:
    """Base class for model wrappers.

    Stores the model location together with its base model family and model
    type, and provides helpers to resolve Hugging Face class definitions and
    to build the matching config object for a model.

    Subclasses declare their configuration schemas as class attributes that
    are ``ModelConfigBase`` subclasses whose ``format`` field is annotated
    with a ``Literal[...]``; see ``_get_configs``.
    """

    def __init__(
        self,
        model_path: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        self.model_path = model_path
        self.base_model = base_model
        self.model_type = model_type

    def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
        """Resolve a ``[module, ..., class_name]`` definition to the object it names.

        If the first element is ``"diffusers"`` or ``"transformers"`` the
        remaining elements are looked up as attributes starting from that
        library; otherwise every element is looked up starting from
        ``diffusers.pipelines``.

        Raises:
            Exception: if ``subtypes`` has fewer than two elements.
        """
        # Local import keeps this block self-contained.
        import importlib

        if len(subtypes) < 2:
            raise Exception("Invalid subfolder definition!")

        # import_module returns the already-loaded module when there is one,
        # and (unlike a raw sys.modules lookup) also works when the library
        # has not been imported yet.
        if subtypes[0] in ["diffusers", "transformers"]:
            res_type = importlib.import_module(subtypes[0])
            subtypes = subtypes[1:]
        else:
            res_type = importlib.import_module("diffusers")
            res_type = getattr(res_type, "pipelines")

        for subtype in subtypes:
            res_type = getattr(res_type, subtype)
        return res_type

    @classmethod
    def _get_configs(cls):
        """Return the mapping ``format value -> config class`` for this class.

        Every non-dunder attribute of ``cls`` that is a ``ModelConfigBase``
        subclass is registered under the first argument of the ``Literal``
        annotating its ``format`` field. The mapping is computed once per
        class and cached.

        Raises:
            Exception: if a config class has no ``format`` annotation or the
                annotation is not a ``Literal``.
        """
        # Look in cls.__dict__ (not hasattr) so that a subclass never reuses
        # the cache of its parent. A single-underscore name is used on
        # purpose: a double-underscore attribute is name-mangled inside the
        # class body and can never be found through a plain string lookup.
        cached = cls.__dict__.get("_configs_cache")
        if cached is not None:
            return cached

        configs = dict()
        for name in dir(cls):
            if name.startswith("__"):
                continue

            value = getattr(cls, name)
            if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
                continue

            fields = inspect.get_annotations(value)
            if "format" not in fields or typing.get_origin(fields["format"]) is not Literal:
                raise Exception("Invalid config definition - format field not found")

            format_name = typing.get_args(fields["format"])[0]
            configs[format_name] = value  # TODO: error when override(multiple)?

        cls._configs_cache = configs
        return configs

    @classmethod
    def build_config(cls, **kwargs):
        """Instantiate the config class that matches ``kwargs["format"]``.

        When ``format`` is not given it is detected from ``kwargs["path"]``
        through ``detect_format``. A format with no registered config class
        raises ``KeyError``.
        """
        if "format" not in kwargs:
            kwargs["format"] = cls.detect_format(kwargs["path"])

        configs = cls._get_configs()
        return configs[kwargs["format"]](**kwargs)

    @classmethod
    def detect_format(cls, path: str) -> str:
        """Return the format of the model stored at ``path``.

        Not implemented here; always raises ``NotImplementedError``.
        """
        raise NotImplementedError()
class DiffusersModel(ModelBase):
    """Model stored in the diffusers folder layout.

    On construction the pipeline config is read and, for every entry whose
    value is a list, the component class and its on-disk size are recorded
    in ``child_types`` and ``child_sizes`` (both keyed by component name).
    """

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        super().__init__(model_path, base_model, model_type)

        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            config_data = DiffusionPipeline.load_config(self.model_path)
        except Exception as e:
            # Narrowed from a bare except so that KeyboardInterrupt/SystemExit
            # propagate; the original cause is kept for debugging.
            raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") from e

        config_data.pop("_ignore_files", None)

        # retrieve all folder_names that contain relevant files
        child_components = [k for k, v in config_data.items() if isinstance(v, list)]

        for child_name in child_components:
            child_type = self._hf_definition_to_type(config_data[child_name])
            self.child_types[child_name] = child_type
            self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)

    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the recorded size of one component, or of all of them.

        With ``child_type=None`` the sizes of all components are summed.
        An unknown ``child_type`` raises ``KeyError``.
        """
        if child_type is None:
            return sum(self.child_sizes.values())
        return self.child_sizes[child_type]

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load one component of the model from local files.

        For ``torch.float16`` the ``"fp16"`` weights variant is tried first
        and the default variant second; for any other dtype the order is
        reversed. After a successful load the recorded size of the component
        is replaced by the size computed from the loaded object.

        Returns:
            The loaded component, or ``None`` when the model has no
            component named ``child_type``.

        Raises:
            Exception: if ``child_type`` is ``None``, or if no variant could
                be loaded (chained to the last loading error).
        """
        # return pipeline in different function to pass more arguments
        if child_type is None:
            raise Exception("Child model type can't be null on diffusers model")
        if child_type not in self.child_types:
            return None  # TODO: or raise

        if torch_dtype == torch.float16:
            variants = ["fp16", None]
        else:
            variants = [None, "fp16"]

        # TODO: better error handling(differentiate not found from others)
        last_error = None
        for variant in variants:
            try:
                # TODO: set cache_dir to /dev/null to be sure that cache not used?
                model = self.child_types[child_type].from_pretrained(
                    self.model_path,
                    subfolder=child_type.value,
                    torch_dtype=torch_dtype,
                    variant=variant,
                    local_files_only=True,
                )
                break
            except Exception as e:
                print("====ERR LOAD====")
                print(f"{variant}: {e}")
                last_error = e
        else:
            # No variant loaded: fail with an explicit error instead of the
            # UnboundLocalError that referencing `model` would trigger.
            raise Exception(
                f"Failed to load submodel '{child_type.value}' from '{self.model_path}'"
            ) from last_error

        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(
    model_path: str,
    subfolder: Optional[str] = None,
    variant: Optional[str] = None
):
    """Estimate the size in bytes of a model from the files in its folder.

    Args:
        model_path: Folder of the model.
        subfolder: Optional sub-folder of ``model_path`` to inspect instead.
        variant: ``None`` selects files without an fp16/8bit marker in the
            name, ``"fp16"`` selects files containing ``.fp16.``/``.fp16-``
            and ``"8bit"`` files containing ``.8bit.``/``.8bit-``.

    Returns:
        ``metadata.total_size`` of an index file when one can be read;
        otherwise the summed size of the files of the first weight format
        found (safetensors, torch, onnx, flax, tf, tf2 - in that order);
        ``0`` when the folder is missing or holds no weight files.

    Raises:
        NotImplementedError: for an unknown ``variant``.
    """
    # Imported locally: the index lookup needs json, and this keeps the
    # function self-contained.
    import json

    if subfolder is not None:
        model_path = os.path.join(model_path, subfolder)

    # this can happen when, for example, the safety checker
    # is not downloaded.
    if not os.path.exists(model_path):
        return 0

    all_files = [f for f in os.listdir(model_path) if os.path.isfile(os.path.join(model_path, f))]

    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
    other_files = set(all_files) - fp16_files - bit8_files

    if variant is None:
        files = other_files
    elif variant == "fp16":
        files = fp16_files
    elif variant == "8bit":
        files = bit8_files
    else:
        raise NotImplementedError(f"Unknown variant: {variant}")

    # try read from index if exists
    index_postfix = ".index.json"
    if variant is not None:
        index_postfix = f".index.{variant}.json"

    for file in files:
        if not file.endswith(index_postfix):
            continue
        try:
            with open(os.path.join(model_path, file), "r") as f:
                index_data = json.loads(f.read())
            return int(index_data["metadata"]["total_size"])
        except (OSError, ValueError, KeyError, TypeError):
            # Best effort: an unreadable or malformed index falls back to
            # summing the file sizes below.
            continue

    # calculate files size if there is no index file
    formats = [
        (".safetensors",),  # safetensors
        (".bin",),  # torch
        (".onnx", ".pb"),  # onnx
        (".msgpack",),  # flax
        (".ckpt",),  # tf
        (".h5",),  # tf2
    ]

    for file_format in formats:
        model_files = [f for f in files if f.endswith(file_format)]
        if not model_files:
            continue
        return sum(os.stat(os.path.join(model_path, f)).st_size for f in model_files)

    # raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
    return 0  # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
    """Return the size in bytes of a loaded model.

    A ``DiffusionPipeline`` is measured through ``_calc_pipeline_by_data``
    and a ``torch.nn.Module`` through ``_calc_model_by_data``; any other
    object counts as ``0``.
    """
    if isinstance(model, DiffusionPipeline):
        return _calc_pipeline_by_data(model)
    if isinstance(model, torch.nn.Module):
        return _calc_model_by_data(model)
    return 0
def _calc_pipeline_by_data(pipeline) -> int:
    """Sum the sizes in bytes of the ``torch.nn.Module`` components of a pipeline.

    Each key of ``pipeline.components`` is read as an attribute of the
    pipeline; attributes that are not modules (including ``None``) are
    skipped.
    """
    submodels = (getattr(pipeline, key) for key in pipeline.components.keys())
    return sum(
        _calc_model_by_data(submodel)
        for submodel in submodels
        if isinstance(submodel, torch.nn.Module)
    )
def _calc_model_by_data(model) -> int:
    """Return the bytes taken by the parameters and buffers of ``model``.

    Each tensor contributes ``nelement() * element_size()``.
    """

    def _total_bytes(tensors) -> int:
        return sum(t.nelement() * t.element_size() for t in tensors)

    return _total_bytes(model.parameters()) + _total_bytes(model.buffers())