mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
New models structure draft
This commit is contained in:
parent
887576d217
commit
2c056ead42
@ -248,7 +248,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
precision="float16" if unet.dtype == torch.float16 else "float32",
|
precision="float16" if unet.dtype == torch.float16 else "float32",
|
||||||
#precision="float16", # TODO:
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def prep_control_data(self,
|
def prep_control_data(self,
|
||||||
|
@ -40,456 +40,7 @@ from invokeai.app.services.config import get_invokeai_config
|
|||||||
|
|
||||||
from .lora import LoRAModel, TextualInversionModel
|
from .lora import LoRAModel, TextualInversionModel
|
||||||
|
|
||||||
def get_model_path(repo_id_or_path: str):
|
from .models import MODEL_CLASSES
|
||||||
globals = get_invokeai_config()
|
|
||||||
|
|
||||||
if os.path.exists(repo_id_or_path):
|
|
||||||
return repo_id_or_path
|
|
||||||
|
|
||||||
cache = scan_cache_dir(globals.cache_dir)
|
|
||||||
for repo in cache.repos:
|
|
||||||
if repo.repo_id != repo_id_or_path:
|
|
||||||
continue
|
|
||||||
for rev in repo.revisions:
|
|
||||||
if "main" in rev.refs:
|
|
||||||
return rev.snapshot_path
|
|
||||||
raise Exception(f"{repo_id_or_path} - not found")
|
|
||||||
|
|
||||||
def calc_model_size_by_fs(
    repo_id_or_path: str,
    subfolder: Optional[str] = None,
    variant: Optional[str] = None
):
    """Estimate a model's size in bytes from the files on disk.

    The location is resolved with ``get_model_path`` and, when given,
    ``subfolder`` is appended. Files are split into fp16 / 8bit / other
    groups by name, and only the group matching ``variant`` is considered.

    :param repo_id_or_path: repo id or local path, passed to ``get_model_path``.
    :param subfolder: optional sub-directory inside the model directory.
    :param variant: ``None`` (files that are neither fp16 nor 8bit),
        ``"fp16"`` or ``"8bit"``.
    :return: ``total_size`` from a readable index file if one exists,
        otherwise the summed size of the first weight format found,
        otherwise 0 (also 0 when the directory does not exist).
    :raises NotImplementedError: for any other ``variant`` value.
    """
    model_path = get_model_path(repo_id_or_path)
    if subfolder is not None:
        model_path = os.path.join(model_path, subfolder)

    # this can happen when, for example, the safety checker
    # is not downloaded.
    if not os.path.exists(model_path):
        return 0

    all_files = [
        f for f in os.listdir(model_path)
        if os.path.isfile(os.path.join(model_path, f))
    ]

    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
    other_files = set(all_files) - fp16_files - bit8_files

    if variant is None:
        files = other_files
    elif variant == "fp16":
        files = fp16_files
    elif variant == "8bit":
        files = bit8_files
    else:
        raise NotImplementedError(f"Unknown variant: {variant}")

    # try read from index if exists
    index_postfix = ".index.json"
    if variant is not None:
        index_postfix = f".index.{variant}.json"

    for file in files:
        if not file.endswith(index_postfix):
            continue
        try:
            with open(os.path.join(model_path, file), "r") as f:
                index_data = json.loads(f.read())
            return int(index_data["metadata"]["total_size"])
        except (OSError, ValueError, KeyError, TypeError):
            # Best effort: an unreadable or malformed index just means we
            # fall back to summing file sizes below.
            pass

    # calculate files size if there is no index file
    formats = [
        (".safetensors",), # safetensors
        (".bin",), # torch
        (".onnx", ".pb"), # onnx
        (".msgpack",), # flax
        (".ckpt",), # tf
        (".h5",), # tf2
    ]

    # The first format (in the order above) that has any files wins; sizes of
    # different formats are never added together.
    for file_format in formats:
        model_files = [f for f in files if f.endswith(file_format)]
        if len(model_files) == 0:
            continue

        model_size = 0
        for model_file in model_files:
            file_stats = os.stat(os.path.join(model_path, model_file))
            model_size += file_stats.st_size
        return model_size

    return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
|
||||||
|
|
||||||
|
|
||||||
def calc_model_size_by_data(model) -> int:
    """Return the in-memory size of a loaded model, in bytes.

    Dispatches on the object's type: a ``DiffusionPipeline`` is measured
    component by component, a ``torch.nn.Module`` by its parameters and
    buffers, and anything else counts as 0.
    """
    if isinstance(model, DiffusionPipeline):
        return _calc_pipeline_by_data(model)
    elif isinstance(model, torch.nn.Module):
        return _calc_model_by_data(model)
    else:
        return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_pipeline_by_data(pipeline) -> int:
    """Sum the byte sizes of every ``torch.nn.Module`` component of a pipeline.

    Each key of ``pipeline.components`` is read back as an attribute of the
    pipeline; attributes that are not ``torch.nn.Module`` instances
    (including ``None``) contribute nothing.
    """
    res = 0
    for submodel_key in pipeline.components:
        submodel = getattr(pipeline, submodel_key)
        # isinstance() is already False for None, so no separate None check.
        if isinstance(submodel, torch.nn.Module):
            res += _calc_model_by_data(submodel)
    return res
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_model_by_data(model) -> int:
    """Return the bytes occupied by a module's parameters and buffers.

    Each tensor contributes ``nelement() * element_size()``.
    """
    mem_params = sum(param.nelement() * param.element_size() for param in model.parameters())
    mem_bufs = sum(buf.nelement() * buf.element_size() for buf in model.buffers())
    mem = mem_params + mem_bufs # in bytes
    return mem
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SDModelType(str, Enum):
    """Kinds of model (and pipeline sub-model) handled by the model cache.

    Members are also plain strings, so they compare and hash equal to their
    values.
    """
    Diffusers = "diffusers"
    Classifier = "classifier"
    UNet = "unet"
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    # Referenced by ClassifierModelInfo; without this member constructing a
    # classifier model info fails with AttributeError.
    FeatureExtractor = "feature_extractor"
    Vae = "vae"
    Scheduler = "scheduler"
    Lora = "lora"
    TextualInversion = "textual_inversion"
    ControlNet = "control_net"
|
|
||||||
|
|
||||||
class BaseModel(str, Enum):
    """Base Stable Diffusion architecture a model belongs to."""
    StableDiffusion1_5 = "SD-1"
    StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
    StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
|
|
||||||
|
|
||||||
class ModelInfoBase:
    """Common state and helpers for the per-model-type info classes."""
    #model_path: str
    #model_type: SDModelType

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        """Remember the identifier as given and resolve it with ``get_model_path``."""
        self.repo_id_or_path = repo_id_or_path # TODO: or use allways path?
        self.model_path = get_model_path(repo_id_or_path)
        self.model_type = model_type

    def _definition_to_type(self, subtypes: List[str]) -> Type:
        """Resolve a ``[module, ..., name]`` definition to the object it names.

        When the first entry is ``"diffusers"`` or ``"transformers"`` the
        lookup starts at that module and the remaining entries are followed
        as attributes. Otherwise the lookup starts at ``diffusers.pipelines``
        and every entry is followed as an attribute.

        The starting module is read from ``sys.modules``, so it must already
        have been imported.

        :raises Exception: if fewer than two entries are given.
        """
        if len(subtypes) < 2:
            raise Exception("Invalid subfolder definition!")
        if subtypes[0] in ["diffusers", "transformers"]:
            res_type = sys.modules[subtypes[0]]
            subtypes = subtypes[1:]

        else:
            res_type = sys.modules["diffusers"]
            res_type = getattr(res_type, "pipelines")

        for subtype in subtypes:
            res_type = getattr(res_type, subtype)
        return res_type
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusersModelInfo(ModelInfoBase):
    """Info for a diffusers pipeline: its components, their classes and sizes.

    ``child_types`` maps each component name from the pipeline config to the
    class resolved for it; ``child_sizes`` maps the same names to a size in
    bytes (estimated from disk at construction, replaced by the in-memory
    size once a component has been loaded).
    """
    #child_types: Dict[str, Type]
    #child_sizes: Dict[str, int]

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        """Read the pipeline config and register every list-valued entry as a component.

        :raises Exception: if the pipeline config cannot be loaded.
        """
        assert model_type == SDModelType.Diffusers
        super().__init__(repo_id_or_path, model_type)

        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            config_data = DiffusionPipeline.load_config(repo_id_or_path)
        except Exception as e:
            # Chain the cause so the real loading error is not lost.
            raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") from e

        config_data.pop("_ignore_files", None)

        # retrieve all folder_names that contain relevant files
        child_components = [k for k, v in config_data.items() if isinstance(v, list)]

        for child_name in child_components:
            child_type = self._definition_to_type(config_data[child_name])
            self.child_types[child_name] = child_type
            self.child_sizes[child_name] = calc_model_size_by_fs(repo_id_or_path, subfolder=child_name)

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the size of one component, or of all components when ``child_type`` is None."""
        if child_type is None:
            return sum(self.child_sizes.values())
        else:
            return self.child_sizes[child_type]

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load one component of the pipeline and record its in-memory size.

        The variants ``"fp16"``, ``"main"`` and ``None`` are tried in that
        order; the first one that loads is used.

        :return: the loaded component, or ``None`` if the pipeline has no
            component of that type.
        :raises Exception: if ``child_type`` is None, or if no variant could
            be loaded (the last loading error is chained as the cause).
        """
        # return pipeline in different function to pass more arguments
        if child_type is None:
            raise Exception("Child model type can't be null on diffusers model")
        if child_type not in self.child_types:
            return None # TODO: or raise

        # TODO:
        last_error = None
        for variant in ["fp16", "main", None]:
            try:
                model = self.child_types[child_type].from_pretrained(
                    self.repo_id_or_path,
                    subfolder=child_type.value,
                    cache_dir=get_invokeai_config().cache_dir,
                    torch_dtype=torch_dtype,
                    variant=variant,
                )
                break
            except Exception as e:
                last_error = e
                print("====ERR LOAD====")
                print(f"{variant}: {e}")
        else:
            # Every variant failed. Previously this fell through to an
            # UnboundLocalError on `model`; report the real problem instead.
            raise Exception(
                f"Failed to load {child_type} from {self.repo_id_or_path}"
            ) from last_error

        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model

    def get_pipeline(self, **kwargs):
        """Load the whole pipeline, forwarding ``kwargs`` to ``from_pretrained``."""
        return DiffusionPipeline.from_pretrained(
            self.repo_id_or_path,
            **kwargs,
        )
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyConfigLoader(ConfigMixin):
    """``ConfigMixin`` wrapper that loads a config file chosen by the caller."""

    @classmethod
    def load_config(cls, *args, **kwargs):
        """Load the config file named by the required ``config_name`` keyword.

        ``config_name`` is removed from ``kwargs`` and stored on the class
        before delegating to ``ConfigMixin.load_config``. Because it is a
        class attribute, the value persists until the next call.
        """
        cls.config_name = kwargs.pop("config_name")
        return super().load_config(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierModelInfo(ModelInfoBase):
    """Info for a classifier-style model: tokenizer, text encoder and optional feature extractor.

    ``child_types`` maps a sub-model type to the class resolved for it;
    ``child_sizes`` maps a sub-model type to its size in bytes.
    """
    #child_types: Dict[str, Type]
    #child_sizes: Dict[str, int]

    # Dictionary key for the feature extractor. Falls back to the plain
    # string when the enum has no such member, so that construction does not
    # fail with AttributeError (enum members are strings, so both hash alike).
    _FEATURE_EXTRACTOR_KEY = getattr(SDModelType, "FeatureExtractor", "feature_extractor")

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        """Read ``config.json`` and register the sub-models it implies.

        :raises Exception: if ``config.json`` cannot be loaded or a sub-model
            type cannot be determined.
        """
        assert model_type == SDModelType.Classifier
        super().__init__(repo_id_or_path, model_type)

        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            main_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="config.json")
        except Exception as e:
            raise Exception("Invalid classifier model! (config.json not found or invalid)") from e

        self._load_tokenizer(main_config)
        self._load_text_encoder(main_config)
        self._load_feature_extractor(main_config)

    def _load_tokenizer(self, main_config: dict):
        """Resolve the tokenizer class; its size is recorded as 0."""
        try:
            tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json")
        except Exception as e:
            raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)") from e

        if "tokenizer_class" in tokenizer_config:
            tokenizer_class_name = tokenizer_config["tokenizer_class"]
        elif "model_type" in main_config:
            # NOTE(review): this mapping may yield a (slow, fast) pair of
            # names rather than a single name - confirm against the installed
            # transformers version.
            tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]]
        else:
            raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")

        self.child_types[SDModelType.Tokenizer] = self._definition_to_type(["transformers", tokenizer_class_name])
        self.child_sizes[SDModelType.Tokenizer] = 0

    def _load_text_encoder(self, main_config: dict):
        """Resolve the text encoder class and estimate its size from disk."""
        if "architectures" in main_config and len(main_config["architectures"]) > 0:
            text_encoder_class_name = main_config["architectures"][0]
        elif "model_type" in main_config:
            text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]]
        else:
            raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")

        self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name])
        self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path)

    def _load_feature_extractor(self, main_config: dict):
        """Resolve the feature extractor class if ``preprocessor_config.json`` exists.

        The size is always recorded as 0; the type is only registered when
        the preprocessor config can be loaded.
        """
        self.child_sizes[self._FEATURE_EXTRACTOR_KEY] = 0
        try:
            feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json")
        except Exception:
            return # feature extractor not passed with t5

        try:
            feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
            self.child_types[self._FEATURE_EXTRACTOR_KEY] = self._definition_to_type(["transformers", feature_extractor_class_name])
        except Exception as e:
            raise Exception("Invalid classifier model! (Unknown feature_extrator type)") from e

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the size of one sub-model, or of all sub-models when ``child_type`` is None."""
        if child_type is None:
            return sum(self.child_sizes.values())
        else:
            return self.child_sizes[child_type]

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load one sub-model and record its in-memory size.

        :return: the loaded sub-model, or ``None`` if no class was registered
            for ``child_type``.
        :raises Exception: if ``child_type`` is None.
        """
        if child_type is None:
            raise Exception("Child model type can't be null on classififer model")
        if child_type not in self.child_types:
            return None # TODO: or raise

        model = self.child_types[child_type].from_pretrained(
            self.repo_id_or_path,
            # Accept the plain-string feature extractor key as well as enum members.
            subfolder=getattr(child_type, "value", child_type),
            cache_dir=get_invokeai_config().cache_dir,
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VaeModelInfo(ModelInfoBase):
    """Info for a standalone VAE: the class to load it with and its size in bytes."""
    #vae_class: Type
    #model_size: int

    def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        """Read ``config.json``, resolve the VAE class and estimate its size from disk.

        The class name comes from the config's ``_class_name`` entry and
        defaults to ``AutoencoderKL``.

        :raises Exception: if the config cannot be loaded, or if resolving
            the class or measuring the size fails.
        """
        assert model_type == SDModelType.Vae
        super().__init__(repo_id_or_path, model_type)

        try:
            config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
        except Exception as e:
            raise Exception("Invalid vae model! (config.json not found or invalid)") from e

        try:
            vae_class_name = config.get("_class_name", "AutoencoderKL")
            self.vae_class = self._definition_to_type(["diffusers", vae_class_name])
            self.model_size = calc_model_size_by_fs(repo_id_or_path)
        except Exception as e:
            raise Exception("Invalid vae model! (Unkown vae type)") from e

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the model size in bytes; a VAE has no child models."""
        if child_type is not None:
            raise Exception("There is no child models in vae model")
        return self.model_size

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load the VAE and replace the size estimate with the in-memory size.

        :raises Exception: if ``child_type`` is given.
        """
        if child_type is not None:
            raise Exception("There is no child models in vae model")

        model = self.vae_class.from_pretrained(
            self.repo_id_or_path,
            cache_dir=get_invokeai_config().cache_dir,
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.model_size = calc_model_size_by_data(model)
        return model
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelInfo(ModelInfoBase):
    """Info for a LoRA checkpoint file and its size in bytes."""
    #model_size: int

    def __init__(self, file_path: str, model_type: SDModelType):
        """Record the checkpoint path and its file size.

        :raises Exception: if ``file_path`` does not exist.
        """
        assert model_type == SDModelType.Lora
        # check manualy as super().__init__ will try to resolve repo_id too
        if not os.path.exists(file_path):
            raise Exception("Model not found")
        super().__init__(file_path, model_type)

        self.model_size = os.path.getsize(file_path)

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the model size in bytes; a LoRA has no child models."""
        if child_type is not None:
            raise Exception("There is no child models in lora")
        return self.model_size

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load the LoRA and replace the file size with the size the model reports.

        :raises Exception: if ``child_type`` is given.
        """
        if child_type is not None:
            raise Exception("There is no child models in lora")

        model = LoRAModel.from_checkpoint(
            file_path=self.model_path,
            dtype=torch_dtype,
        )

        self.model_size = model.calc_size()
        return model
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModelInfo(ModelInfoBase):
    """Info for a textual inversion embedding file and its size in bytes."""
    #model_size: int

    def __init__(self, file_path: str, model_type: SDModelType):
        """Record the embedding path and its file size.

        :raises Exception: if ``file_path`` does not exist.
        """
        assert model_type == SDModelType.TextualInversion
        # check manualy as super().__init__ will try to resolve repo_id too
        if not os.path.exists(file_path):
            raise Exception("Model not found")
        super().__init__(file_path, model_type)

        self.model_size = os.path.getsize(file_path)

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the model size in bytes; an embedding has no child models."""
        if child_type is not None:
            raise Exception("There is no child models in textual inversion")
        return self.model_size

    def get_model(
        self,
        child_type: Optional[SDModelType] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Load the embedding and replace the file size with the tensor's byte size.

        :raises Exception: if ``child_type`` is given.
        """
        if child_type is not None:
            raise Exception("There is no child models in textual inversion")

        model = TextualInversionModel.from_checkpoint(
            file_path=self.model_path,
            dtype=torch_dtype,
        )

        self.model_size = model.embedding.nelement() * model.embedding.element_size()
        return model
|
|
||||||
|
|
||||||
|
|
||||||
# Maps each supported model type to the info class that knows how to load it.
MODEL_TYPES = {
    SDModelType.Diffusers: DiffusersModelInfo,
    SDModelType.Classifier: ClassifierModelInfo,
    SDModelType.Vae: VaeModelInfo,
    SDModelType.Lora: LoRAModelInfo,
    SDModelType.TextualInversion: TextualInversionModelInfo,
}
|
|
||||||
|
|
||||||
|
|
||||||
# Maximum size of the cache, in gigs
|
# Maximum size of the cache, in gigs
|
||||||
@ -499,10 +50,6 @@ DEFAULT_MAX_CACHE_SIZE = 6.0
|
|||||||
# actual size of a gig
|
# actual size of a gig
|
||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
|
|
||||||
# TODO:
|
|
||||||
class EmptyScheduler(SchedulerMixin, ConfigMixin):
    """Scheduler placeholder that adds nothing to its two mixins."""
    pass
|
|
||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
pass
|
pass
|
||||||
@ -583,12 +130,10 @@ class ModelCache(object):
|
|||||||
self,
|
self,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
model_type: SDModelType,
|
model_type: SDModelType,
|
||||||
revision: Optional[str] = None,
|
|
||||||
submodel_type: Optional[SDModelType] = None,
|
submodel_type: Optional[SDModelType] = None,
|
||||||
):
|
):
|
||||||
revision = revision or "main"
|
|
||||||
|
|
||||||
key = f"{model_path}:{model_type}:{revision}"
|
key = f"{model_path}:{model_type}"
|
||||||
if submodel_type:
|
if submodel_type:
|
||||||
key += f":{submodel_type}"
|
key += f":{submodel_type}"
|
||||||
return key
|
return key
|
||||||
@ -606,55 +151,51 @@ class ModelCache(object):
|
|||||||
def _get_model_info(
|
def _get_model_info(
|
||||||
self,
|
self,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
model_type: SDModelType,
|
model_class: Type[ModelBase],
|
||||||
revision: str,
|
|
||||||
):
|
):
|
||||||
model_info_key = self.get_key(
|
model_info_key = self.get_key(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
revision=revision,
|
|
||||||
submodel_type=None,
|
submodel_type=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_info_key not in self.model_infos:
|
if model_info_key not in self.model_infos:
|
||||||
if model_type not in MODEL_TYPES:
|
self.model_infos[model_info_key] = model_class(
|
||||||
raise Exception(f"Unknown/unsupported model type: {model_type}")
|
|
||||||
|
|
||||||
self.model_infos[model_info_key] = MODEL_TYPES[model_type](
|
|
||||||
model_path,
|
model_path,
|
||||||
model_type,
|
model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.model_infos[model_info_key]
|
return self.model_infos[model_info_key]
|
||||||
|
|
||||||
|
# TODO: args
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
repo_id_or_path: Union[str, Path],
|
model_path: Union[str, Path],
|
||||||
model_type: SDModelType = SDModelType.Diffusers,
|
model_class: Type[ModelBase],
|
||||||
submodel: Optional[SDModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
revision: Optional[str] = None,
|
|
||||||
variant: Optional[str] = None,
|
|
||||||
gpu_load: bool = True,
|
gpu_load: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|
||||||
model_path = get_model_path(repo_id_or_path)
|
if not isinstance(model_path, Path):
|
||||||
|
model_path = Path(model_path)
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise Exception(f"Model not found: {model_path}")
|
||||||
|
|
||||||
model_info = self._get_model_info(
|
model_info = self._get_model_info(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
model_type=model_type,
|
model_class=model_class,
|
||||||
revision=revision,
|
|
||||||
)
|
)
|
||||||
# TODO: variant
|
|
||||||
key = self.get_key(
|
key = self.get_key(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
model_type=model_type,
|
model_type=model_type, # TODO:
|
||||||
revision=revision,
|
|
||||||
submodel_type=submodel,
|
submodel_type=submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: lock for no copies on simultaneous calls?
|
# TODO: lock for no copies on simultaneous calls?
|
||||||
cache_entry = self._cached_models.get(key, None)
|
cache_entry = self._cached_models.get(key, None)
|
||||||
if cache_entry is None:
|
if cache_entry is None:
|
||||||
self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
|
self.logger.info(f'Loading model {model_path}, type {model_type}:{submodel}')
|
||||||
|
|
||||||
# this will remove older cached models until
|
# this will remove older cached models until
|
||||||
# there is sufficient room to load the requested model
|
# there is sufficient room to load the requested model
|
||||||
@ -662,7 +203,7 @@ class ModelCache(object):
|
|||||||
|
|
||||||
# clean memory to make MemoryUsage() more accurate
|
# clean memory to make MemoryUsage() more accurate
|
||||||
gc.collect()
|
gc.collect()
|
||||||
model = model_info.get_model(submodel, torch_dtype=self.precision)
|
model = model_info.get_model(submodel, torch_dtype=self.precision, variant=)
|
||||||
if mem_used := model_info.get_size(submodel):
|
if mem_used := model_info.get_size(submodel):
|
||||||
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
||||||
|
|
||||||
@ -732,20 +273,14 @@ class ModelCache(object):
|
|||||||
|
|
||||||
def model_hash(
|
def model_hash(
|
||||||
self,
|
self,
|
||||||
repo_id_or_path: Union[str, Path],
|
model_path: Union[str, Path],
|
||||||
revision: str = "main",
|
|
||||||
) -> str:
|
) -> str:
|
||||||
'''
|
'''
|
||||||
Given the HF repo id or path to a model on disk, returns a unique
|
Given the HF repo id or path to a model on disk, returns a unique
|
||||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||||
:param repo_id_or_path: repo_id string or Path to model file/directory on disk.
|
:param model_path: Path to model file/directory on disk.
|
||||||
:param revision: optional revision string (if fetching a HF repo_id)
|
|
||||||
'''
|
'''
|
||||||
revision = revision or "main"
|
return self._local_model_hash(model_path)
|
||||||
if Path(repo_id_or_path).is_dir():
|
|
||||||
return self._local_model_hash(repo_id_or_path)
|
|
||||||
else:
|
|
||||||
return self._hf_commit_hash(repo_id_or_path,revision)
|
|
||||||
|
|
||||||
def cache_size(self) -> float:
|
def cache_size(self) -> float:
|
||||||
"Return the current size of the cache, in GB"
|
"Return the current size of the cache, in GB"
|
||||||
@ -840,17 +375,6 @@ class ModelCache(object):
|
|||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str:
    """Return the commit that branch ``revision`` of a model repo points to.

    Only the repo's branches are searched.

    :param repo_id: repo id queried with ``repo_type='model'``.
    :param revision: branch name to look up.
    :raises KeyError: if the repo has no branch with that name.
    """
    api = HfApi()
    info = api.list_repo_refs(
        repo_id=repo_id,
        repo_type='model',
    )
    # The first branch with a matching name wins.
    for branch in info.branches:
        if branch.name == revision:
            return branch.target_commit
    raise KeyError(f"Revision '{revision}' not found in {repo_id}")
|
|
||||||
|
|
||||||
class SilenceWarnings(object):
|
class SilenceWarnings(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -163,7 +163,6 @@ import safetensors
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
from diffusers.utils import is_safetensors_available
|
|
||||||
from huggingface_hub import scan_cache_dir
|
from huggingface_hub import scan_cache_dir
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
@ -172,8 +171,8 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
||||||
from ..install.model_install_backend import Dataset_path, hf_download_with_resume
|
from ..install.model_install_backend import Dataset_path, hf_download_with_resume
|
||||||
from .model_cache import (ModelCache, ModelLocker, SDModelType,
|
from .model_cache import ModelCache, ModelLocker, SilenceWarnings
|
||||||
SilenceWarnings)
|
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
# The config file version doesn't have to start at release version, but it will help
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
# reduce confusion.
|
# reduce confusion.
|
||||||
@ -201,14 +200,6 @@ class InvalidModelError(Exception):
|
|||||||
"Raised when an invalid model is requested"
|
"Raised when an invalid model is requested"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class SDLegacyType(Enum):
    """Variants of legacy (checkpoint-format) Stable Diffusion models."""
    V1 = auto()
    V1_INPAINT = auto()
    V2 = auto()
    V2_e = auto()
    V2_v = auto()
    UNKNOWN = auto()
|
|
||||||
|
|
||||||
MAX_CACHE_SIZE = 6.0 # GB
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
|
|
||||||
@ -280,32 +271,45 @@ class ModelManager(object):
|
|||||||
def model_exists(
|
def model_exists(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType = SDModelType.Diffusers,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Given a model name, returns True if it is a valid
|
Given a model name, returns True if it is a valid
|
||||||
identifier.
|
identifier.
|
||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
return model_key in self.config
|
return model_key in self.config
|
||||||
|
|
||||||
def create_key(self, model_name: str, model_type: SDModelType) -> str:
|
def create_key(
|
||||||
return f"{model_type}/{model_name}"
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
) -> str:
|
||||||
|
return f"{base_model}/{model_type}/{model_name}"
|
||||||
|
|
||||||
def parse_key(self, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
    """Split a ``"{base_model}/{model_type}/{model_name}"`` key into its parts.

    The key is split on the first two slashes only, so the model name may
    itself contain slashes.

    :return: ``(model_name, base_model, model_type)``.
    :raises Exception: if the model type or base model part is not a known
        enum value.
    :raises ValueError: if the key has fewer than three parts.
    """
    base_model_str, model_type_str, model_name = model_key.split('/', 2)
    try:
        # ModelType matches both the return annotation and the names this
        # module imports.
        model_type = ModelType(model_type_str)
    except ValueError as e:
        raise Exception(f"Unknown model type: {model_type_str}") from e

    try:
        base_model = BaseModelType(base_model_str)
    except ValueError as e:
        raise Exception(f"Unknown base model: {base_model_str}") from e

    return (model_name, base_model, model_type)
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType = SDModelType.Diffusers,
|
base_model: BaseModelType,
|
||||||
submodel: Optional[SDModelType] = None,
|
model_type: ModelType,
|
||||||
) -> SDModelInfo:
|
submodel_type: Optional[SubModelType] = None
|
||||||
|
):
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
an SDModelInfo object describing it.
|
an SDModelInfo object describing it.
|
||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
@ -344,210 +348,182 @@ class ModelManager(object):
|
|||||||
# raises an InvalidModelError
|
# raises an InvalidModelError
|
||||||
|
|
||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, model_type)
|
|
||||||
if model_key not in self.config:
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
raise InvalidModelError(
|
|
||||||
f'"{model_key}" is not a known model name. Please check your models.yaml file'
|
#if model_type in {
|
||||||
)
|
# ModelType.Lora,
|
||||||
|
# ModelType.ControlNet,
|
||||||
# get the required loading info out of the config file
|
# ModelType.TextualInversion,
|
||||||
mconfig = self.config[model_key]
|
# ModelType.Vae,
|
||||||
|
#}:
|
||||||
# type already checked as it's part of key
|
if not model_class.has_config:
|
||||||
if model_type == SDModelType.Diffusers:
|
#if model_class.Config is None:
|
||||||
# intercept stanzas that point to checkpoint weights and replace them
|
# skip config
|
||||||
# with the equivalent diffusers model
|
# load from
|
||||||
if mconfig.format in ["ckpt", "safetensors"]:
|
# /models/{base_model}/{model_type}/{model_name}
|
||||||
location = self.convert_ckpt_and_cache(mconfig)
|
# /models/{base_model}/{model_type}/{model_name}.{ext}
|
||||||
elif mconfig.get('path'):
|
|
||||||
location = self.globals.root_dir / mconfig.get('path')
|
model_config = None
|
||||||
|
|
||||||
|
for ext in {"pt", "ckpt", "safetensors"}:
|
||||||
|
model_path = os.path.join(model_dir, base_model, model_type, f"{model_name}.{ext}")
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
location = mconfig.get('repo_id')
|
model_path = os.path.join(model_dir, base_model, model_type, model_name)
|
||||||
elif p := mconfig.get('path'):
|
if not os.path.exists(model_path):
|
||||||
location = self.globals.root_dir / p
|
raise InvalidModelError(
|
||||||
elif r := mconfig.get('repo_id'):
|
f"Model not found - \"{base_model}/{model_type}/{model_name}\" "
|
||||||
location = r
|
)
|
||||||
elif w := mconfig.get('weights'):
|
|
||||||
location = self.globals.root_dir / w
|
|
||||||
else:
|
else:
|
||||||
location = None
|
# find in config
|
||||||
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
revision = mconfig.get('revision')
|
if model_key not in self.config:
|
||||||
if model_type in [SDModelType.Lora, SDModelType.TextualInversion]:
|
raise InvalidModelError(
|
||||||
hash = "<NO_HASH>" # TODO:
|
f'"{model_key}" is not a known model name. Please check your models.yaml file'
|
||||||
else:
|
)
|
||||||
hash = self.cache.model_hash(location, revision)
|
|
||||||
|
|
||||||
# If the caller is asking for part of the model and the config indicates
|
model_config = self.config[model_key]
|
||||||
# an external replacement for that field, then we fetch the replacement
|
|
||||||
if submodel and mconfig.get(submodel):
|
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
|
||||||
location = mconfig.get(submodel).get('path') \
|
# /models/{base_model}/{model_type}/{name}/
|
||||||
or mconfig.get(submodel).get('repo_id')
|
model_path = model_config.path
|
||||||
model_type = submodel
|
|
||||||
submodel = None
|
|
||||||
|
|
||||||
# to support the traditional way of attaching a VAE
|
# vae/movq override
|
||||||
# to a model, we hacked in `attach_model_part`
|
# TODO:
|
||||||
# TODO:
|
if submodel is not None and submodel in model_config:
|
||||||
if model_type == SDModelType.Vae and "vae" in mconfig:
|
model_path = model_config[submodel]["path"]
|
||||||
print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
|
model_type = submodel
|
||||||
|
submodel = None
|
||||||
|
|
||||||
model_context = self.cache.get_model(
|
dst_convert_path = None # TODO:
|
||||||
location,
|
model_path = model_class.convert_if_required(
|
||||||
model_type = model_type,
|
model_path,
|
||||||
revision = revision,
|
dst_convert_path,
|
||||||
submodel = submodel,
|
model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# in case we need to communicate information about this
|
model_context = self.cache.get_model(
|
||||||
# model to the cache manager, then we need to remember
|
model_path,
|
||||||
# the cache key
|
model_class,
|
||||||
self.cache_keys[model_key] = model_context.key
|
submodel,
|
||||||
|
)
|
||||||
|
|
||||||
|
hash = "<NO_HASH>" # TODO:
|
||||||
|
|
||||||
return SDModelInfo(
|
return SDModelInfo(
|
||||||
context = model_context,
|
context = model_context,
|
||||||
name = model_name,
|
name = model_name,
|
||||||
|
base_model = base_model,
|
||||||
type = submodel or model_type,
|
type = submodel or model_type,
|
||||||
hash = hash,
|
hash = hash,
|
||||||
location = location,
|
location = model_path, # TODO:
|
||||||
revision = revision,
|
|
||||||
precision = self.cache.precision,
|
precision = self.cache.precision,
|
||||||
_cache = self.cache
|
_cache = self.cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
    """
    Return the parsed key of the default model.

    The first config entry whose stanza has a truthy "default" field wins.
    If no entry is flagged, the first entry in the config is used instead.
    Returns None only when the config has no entries at all.
    """
    _missing = object()

    # Lazily scan for the first stanza explicitly flagged as default.
    flagged_keys = (
        key
        for key, stanza in self.config.items()
        if stanza.get("default", False)
    )
    chosen = next(flagged_keys, _missing)

    if chosen is _missing:
        # Nothing flagged: fall back to whichever entry comes first.
        first_item = next(iter(self.config.items()), _missing)
        if first_item is _missing:
            return None  # TODO: or redo as (None, None, None)
        chosen = first_item[0]

    return self.parse_key(chosen)
|
||||||
|
|
||||||
|
def set_default_model(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
) -> None:
    """
    Set the default model. The change will not take
    effect until you call model_manager.commit()

    :param model_name: symbolic name of the model
    :param base_model: base model the entry belongs to
    :param model_type: type of the model entry
    :raises Exception: if no config entry exists for the resulting key
    """
    # Build the key with create_key(), the same helper the other
    # ModelManager methods use for config lookups.
    model_key = self.create_key(model_name, base_model, model_type)
    if model_key not in self.config:
        raise Exception(f"Unknown model: {model_key}")

    # Exactly one stanza may carry the flag: set it on the requested
    # entry and strip it from every other one.
    for cur_model_key, config in self.config.items():
        if cur_model_key == model_key:
            config["default"] = True
        else:
            config.pop("default", None)
|
||||||
|
|
||||||
def model_info(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
) -> dict:
    """
    Look up the config stanza (OmegaConf, dict-like) for one model.

    Returns None when the config holds no entry for the given
    name / base model / model type combination.
    """
    return self.config.get(
        self.create_key(model_name, base_model, model_type),
        None,
    )
|
|
||||||
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
    """
    List every model known to the configuration.

    Each element is the parsed form of a config key, i.e. a
    (str, BaseModelType, ModelType) tuple. Config entries whose value
    is not a DictConfig are left out.
    """
    parsed = []
    for key in self.config.keys():
        if isinstance(self.config[key], DictConfig):
            parsed.append(self.parse_key(key))
    return parsed
|
||||||
|
|
||||||
def is_legacy(self, model_name: str, model_type: SDModelType.Diffusers) -> bool:
|
def list_models(
    self,
    base_model: Optional[BaseModelType] = None,
    model_type: Optional[ModelType] = None,
) -> Dict[str, Dict[str, dict]]:
    """
    Return a dict of models, in format [base_model][model_type][model_name]

    :param base_model: if given, only models of this base model are listed
    :param model_type: if given, only models of this type are listed;
        may only be used together with base_model

    Each leaf value is whatever the model class' build_config() returns
    for the stanza.

    Please use model_manager.models() to get all the model names,
    model_manager.model_info('model-name') to get the stanza for the model
    named 'model-name', and model_manager.config to get the full OmegaConf
    object derived from models.yaml
    """
    assert not (model_type is not None and base_model is None), "model_type must be provided with base_model"

    models = dict()
    for model_key in sorted(self.config, key=str.casefold):
        stanza = self.config[model_key]
        # Keys starting with an underscore are skipped, not listed as models.
        if model_key.startswith('_'):
            continue

        model_name, m_base_model, stanza_type = self.parse_key(model_key)
        if base_model is not None and m_base_model != base_model:
            continue
        if model_type is not None and model_type != stanza_type:
            continue

        # Create the nested [base_model][model_type] dicts on first use only.
        # The type-level membership test has to look inside the base-model
        # dict; testing the outer dict would recreate (and so empty) the
        # per-type dict for every model.
        models_of_type = models.setdefault(m_base_model, dict()).setdefault(stanza_type, dict())

        model_class = MODEL_CLASSES[m_base_model][stanza_type]
        models_of_type[model_name] = model_class.build_config(
            **stanza,
            name=model_name,
            # Report the model's own base model, not the (possibly None) filter.
            base_model=m_base_model,
            type=stanza_type,
        )
        # TODO: alternatively build via model_class.Config(...).dict()

    return models
|
||||||
|
|
||||||
@ -557,7 +533,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
for model_type, model_dict in self.list_models().items():
|
for model_type, model_dict in self.list_models().items():
|
||||||
for model_name, model_info in model_dict.items():
|
for model_name, model_info in model_dict.items():
|
||||||
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
|
line = f'{model_info["name"]:25s} {model_info["status"]:>15s} {model_info["type"]:10s} {model_info["description"]}'
|
||||||
if model_info["status"] in ["in gpu","locked in gpu"]:
|
if model_info["status"] in ["in gpu","locked in gpu"]:
|
||||||
line = f"\033[1m{line}\033[0m"
|
line = f"\033[1m{line}\033[0m"
|
||||||
print(line)
|
print(line)
|
||||||
@ -606,7 +582,8 @@ class ModelManager(object):
|
|||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_type: SDModelType,
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -618,38 +595,31 @@ class ModelManager(object):
|
|||||||
attributes are incorrect or the model name is missing.
|
attributes are incorrect or the model name is missing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if model_type == SDModelType.Fiffusers:
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
# TODO: automaticaly or manualy?
|
|
||||||
#assert "format" in model_attributes, 'missing required field "format"'
|
|
||||||
model_format = "ckpt" if "weights" in model_attributes else "diffusers"
|
|
||||||
|
|
||||||
if model_format == "diffusers":
|
model_class.build_config(
|
||||||
assert (
|
**model_attributes,
|
||||||
"description" in model_attributes
|
name=model_name,
|
||||||
), 'required field "description" is missing'
|
base_model=base_model,
|
||||||
assert (
|
type=model_type,
|
||||||
"path" in model_attributes or "repo_id" in model_attributes
|
)
|
||||||
), 'model must have either the "path" or "repo_id" fields defined'
|
#model_cfg = model_class.Config(
|
||||||
|
# **model_attributes,
|
||||||
|
# name=model_name,
|
||||||
|
# base_model=base_model,
|
||||||
|
# type=model_type,
|
||||||
|
#)
|
||||||
|
|
||||||
elif model_format == "ckpt":
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
for field in ("description", "weights", "config"):
|
|
||||||
assert field in model_attributes, f"required field {field} is missing"
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert "weights" in model_attributes and "description" in model_attributes
|
|
||||||
|
|
||||||
model_key = self.create_key(model_name, model_type)
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
clobber or model_key not in self.config
|
clobber or model_key not in self.config
|
||||||
), f'attempt to overwrite existing model definition "{model_key}"'
|
), f'attempt to overwrite existing model definition "{model_key}"'
|
||||||
|
|
||||||
self.config[model_key] = model_attributes
|
self.config[model_key] = model_attributes
|
||||||
|
|
||||||
if "weights" in self.config[model_key]:
|
|
||||||
self.config[model_key]["weights"].replace("\\", "/")
|
|
||||||
|
|
||||||
if clobber and model_key in self.cache_keys:
|
if clobber and model_key in self.cache_keys:
|
||||||
|
# TODO:
|
||||||
self.cache.uncache_model(self.cache_keys[model_key])
|
self.cache.uncache_model(self.cache_keys[model_key])
|
||||||
del self.cache_keys[model_key]
|
del self.cache_keys[model_key]
|
||||||
|
|
||||||
@ -741,326 +711,6 @@ class ModelManager(object):
|
|||||||
),
|
),
|
||||||
True
|
True
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
    """
    Inspect a loaded pickle/safetensors checkpoint and classify it.

    Possible results:
        SDLegacyType.V1
        SDLegacyType.V1_INPAINT
        SDLegacyType.V2      (V2 prediction type unknown)
        SDLegacyType.V2_e    (V2 using 'epsilon' prediction type)
        SDLegacyType.V2_v    (V2 using 'v_prediction' prediction type)
        SDLegacyType.UNKNOWN
    """
    v2_marker = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    v1_marker = "model.diffusion_model.input_blocks.0.0.weight"

    step = checkpoint.get("global_step")
    # Weights may sit under "state_dict" or directly at the top level.
    weights = checkpoint.get("state_dict") or checkpoint

    try:
        looks_like_v2 = v2_marker in weights and weights[v2_marker].shape[-1] == 1024
        if looks_like_v2:
            # The training step count is used to tell the V2 variants apart.
            if step == 220000:
                return SDLegacyType.V2_e
            if step == 110000:
                return SDLegacyType.V2_v
            return SDLegacyType.V2

        # otherwise we assume a V1 file
        channels = weights[v1_marker].shape[1]
        if channels == 9:
            return SDLegacyType.V1_INPAINT
        if channels == 4:
            return SDLegacyType.V1
        return SDLegacyType.UNKNOWN
    except KeyError:
        # The V1 marker tensor is absent: not a recognised layout.
        return SDLegacyType.UNKNOWN
|
|
||||||
|
|
||||||
def heuristic_import(
    self,
    path_url_or_repo: str,
    model_name: Optional[str] = None,
    description: Optional[str] = None,
    model_config_file: Optional[Path] = None,
    commit_to_conf: Optional[Path] = None,
    config_file_callback: Optional[Callable[[Path], Path]] = None,
) -> Optional[str]:
    """Accept a string which could be:
       - a HF diffusers repo_id
       - a URL pointing to a legacy .ckpt or .safetensors file
       - a local path pointing to a legacy .ckpt or .safetensors file
       - a local directory containing .ckpt and .safetensors files
       - a local directory containing a diffusers model

    After determining the nature of the model and downloading it
    (if necessary), the file is probed to determine the correct
    configuration file (if needed) and it is imported.

    The model_name and/or description can be provided. If not, they will
    be generated automatically.

    If commit_to_conf is provided, the newly loaded model will be written
    to the `models.yaml` file at the indicated path. Otherwise, the changes
    will only remain in memory.

    The routine will do its best to figure out the config file
    needed to convert legacy checkpoint file, but if it can't it
    will call the config_file_callback routine, if provided. The
    callback accepts a single argument, the Path to the checkpoint
    file, and returns a Path to the config file to use.

    The (potentially derived) name of the model is returned on
    success, or None on failure. When multiple models are added
    from a directory, only the last imported one is returned.
    """
    model_path: Path = None
    thing = str(path_url_or_repo)  # to save typing

    self.logger.info(f"Probing {thing} for import")

    if thing.startswith(("http:", "https:", "ftp:")):
        self.logger.info(f"{thing} appears to be a URL")
        model_path = self._resolve_path(
            thing, "models/ldm/stable-diffusion-v1"
        )  # _resolve_path does a download if needed

    elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
        if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
            self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
            return None
        self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
        model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")

    elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
        self.logger.debug(f"{thing} appears to be a diffusers file on disk")
        # Return the imported name directly: falling through would hit the
        # "no checkpoint path" exit below and report None for a successful
        # import, contradicting the documented return value.
        return self.import_diffuser_model(
            thing,
            vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
            model_name=model_name,
            description=description,
            commit_to_conf=commit_to_conf,
        )

    elif Path(thing).is_dir():
        # A directory holding model_index.json was already handled above,
        # so this can only be a plain directory of checkpoint files.
        self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
        for m in list(Path(thing).rglob("*.ckpt")) + list(
            Path(thing).rglob("*.safetensors")
        ):
            if model_name := self.heuristic_import(
                str(m),
                commit_to_conf=commit_to_conf,
                config_file_callback=config_file_callback,
            ):
                self.logger.info(f"{model_name} successfully imported")
        return model_name

    elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
        self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
        model_name = self.import_diffuser_model(
            thing, commit_to_conf=commit_to_conf
        )
        # The result is not used here; the call is kept because it may be
        # relied on for its side effects -- TODO confirm.
        self._load_diffusers_model(self.config[model_name])
        return model_name

    else:
        self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")

    # Model_path is set in the event of a legacy checkpoint file.
    # If not set, we're all done
    if not model_path:
        return None

    if model_path.stem in self.config:  # already imported
        self.logger.debug("Already imported. Skipping")
        return model_path.stem

    # another round of heuristics to guess the correct config file.
    checkpoint = None
    if model_path.suffix in [".ckpt", ".pt"]:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; scan_model() is called first, presumably as the
        # safety check -- confirm it rejects unsafe files.
        self.cache.scan_model(model_path, model_path)
        checkpoint = torch.load(model_path)
    else:
        checkpoint = safetensors.torch.load_file(model_path)

    # additional probing needed if no config file provided
    if model_config_file is None:
        # look for a like-named .yaml file in same directory
        if model_path.with_suffix(".yaml").exists():
            model_config_file = model_path.with_suffix(".yaml")
            self.logger.debug(f"Using config file {model_config_file.name}")

        else:
            model_type = self.probe_model_type(checkpoint)
            if model_type == SDLegacyType.V1:
                self.logger.debug("SD-v1 model detected")
                model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
            elif model_type == SDLegacyType.V1_INPAINT:
                self.logger.debug("SD-v1 inpainting model detected")
                model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml"
            elif model_type == SDLegacyType.V2_v:
                self.logger.debug("SD-v2-v model detected")
                model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
            elif model_type == SDLegacyType.V2_e:
                self.logger.debug("SD-v2-e model detected")
                model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
            elif model_type == SDLegacyType.V2:
                self.logger.warning(
                    f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
                )
            else:
                self.logger.warning(
                    f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
                )

    if not model_config_file and config_file_callback:
        model_config_file = config_file_callback(model_path)

    # despite our best efforts, we could not find a model config file, so give up
    if not model_config_file:
        return None

    # look for a custom vae, a like-named file ending with .vae in the same directory
    vae_path = None
    for suffix in ["pt", "ckpt", "safetensors"]:
        if (model_path.with_suffix(f".vae.{suffix}")).exists():
            vae_path = model_path.with_suffix(f".vae.{suffix}")
            self.logger.debug(f"Using VAE file {vae_path.name}")
    vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")

    diffuser_path = self.globals.converted_ckpts_dir / model_path.stem
    with SilenceWarnings():
        model_name = self.convert_and_import(
            model_path,
            diffusers_path=diffuser_path,
            vae=vae,
            # Pass None, not the string "None", when there is no VAE file.
            vae_path=str(vae_path) if vae_path else None,
            model_name=model_name,
            model_description=description,
            original_config_file=model_config_file,
            commit_to_conf=commit_to_conf,
            scan_needed=False,
        )
    return model_name
|
|
||||||
|
|
||||||
def convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
    """
    Turn the checkpoint described by mconfig into a diffusers model
    stored under the converted-checkpoints directory.

    The output directory is named after the stem of the weights file.
    If it already exists, no conversion is done and its Path is
    returned immediately; otherwise the Path is returned after
    conversion.
    """
    ckpt_file = self.globals.root_dir / mconfig.weights
    yaml_file = self.globals.root_dir / mconfig.config
    target_dir = self.globals.converted_ckpts_dir / ckpt_file.stem

    # A previous conversion is reused as-is.
    if target_dir.exists():
        return target_dir

    vae_ckpt_path, vae_model = self._get_vae_for_conversion(ckpt_file, mconfig)

    # Imported here rather than at module level to avoid circular imports.
    from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

    # A VAE checkpoint path, when present, is resolved against the root dir.
    if vae_ckpt_path:
        vae_arg = str(self.globals.root_dir / vae_ckpt_path)
    else:
        vae_arg = None

    with SilenceWarnings():
        convert_ckpt_to_diffusers(
            ckpt_file,
            target_dir,
            extract_ema=True,
            original_config_file=yaml_file,
            vae=vae_model,
            vae_path=vae_arg,
            scan_needed=True,
        )
    return target_dir
|
|
||||||
|
|
||||||
def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
    """
    Turn the VAE checkpoint described by mconfig into a diffusers
    AutoencoderKL saved under the converted-checkpoints directory.

    The output directory is named after the stem of the weights file.
    If it already exists, it is returned without converting again.
    """
    base_dir = self.globals.root_dir
    ckpt_file = base_dir / mconfig.weights
    yaml_file = base_dir / mconfig.config
    target_dir = self.globals.converted_ckpts_dir / ckpt_file.stem
    # First truthy of width, height; 512 when neither is set.
    size = mconfig.get('width') or mconfig.get('height') or 512

    # A previous conversion is reused as-is.
    if target_dir.exists():
        return target_dir

    # Imported here rather than at module level to avoid circular imports.
    from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers

    # Pick the loader from the file extension.
    if ckpt_file.suffix in ['.ckpt', '.pt']:
        loaded = torch.load(ckpt_file, map_location="cpu")
    else:
        loaded = safetensors.torch.load_file(ckpt_file)

    # sometimes weights are hidden under "state_dict", and sometimes not
    state = loaded["state_dict"] if "state_dict" in loaded else loaded

    vae_yaml = OmegaConf.load(yaml_file)

    converted = convert_ldm_vae_to_diffusers(
        checkpoint=state,
        vae_config=vae_yaml,
        image_size=size,
    )
    converted.save_pretrained(
        target_dir,
        safe_serialization=is_safetensors_available(),
    )
    return target_dir
|
|
||||||
|
|
||||||
def _get_vae_for_conversion(
|
|
||||||
self,
|
|
||||||
weights: Path,
|
|
||||||
mconfig: DictConfig
|
|
||||||
) -> Tuple[Path, AutoencoderKL]:
|
|
||||||
# VAE handling is convoluted
|
|
||||||
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
|
|
||||||
# it as the vae_path passed to convert
|
|
||||||
vae_ckpt_path = None
|
|
||||||
vae_diffusers_location = None
|
|
||||||
vae_model = None
|
|
||||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
|
||||||
if (weights.with_suffix(f".vae.{suffix}")).exists():
|
|
||||||
vae_ckpt_path = weights.with_suffix(f".vae.{suffix}")
|
|
||||||
self.logger.debug(f"Using VAE file {vae_ckpt_path.name}")
|
|
||||||
if vae_ckpt_path:
|
|
||||||
return (vae_ckpt_path, None)
|
|
||||||
|
|
||||||
# 2. If mconfig has a vae weights path, then we use that as vae_path
|
|
||||||
vae_config = mconfig.get('vae')
|
|
||||||
if vae_config and isinstance(vae_config,str):
|
|
||||||
vae_ckpt_path = vae_config
|
|
||||||
return (vae_ckpt_path, None)
|
|
||||||
|
|
||||||
# 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
|
|
||||||
if vae_config and isinstance(vae_config,DictConfig):
|
|
||||||
vae_diffusers_location = self.globals.root_dir / vae_config.get('path') \
|
|
||||||
if vae_config.get('path') \
|
|
||||||
else vae_config.get('repo_id')
|
|
||||||
|
|
||||||
# 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
|
|
||||||
else:
|
|
||||||
vae_diffusers_location = "stabilityai/sd-vae-ft-mse"
|
|
||||||
|
|
||||||
if vae_diffusers_location:
|
|
||||||
vae_model = self.cache.get_model(vae_diffusers_location, SDModelType.Vae).model
|
|
||||||
return (None, vae_model)
|
|
||||||
|
|
||||||
return (None, None)
|
|
||||||
|
|
||||||
def convert_and_import(
|
def convert_and_import(
|
||||||
self,
|
self,
|
||||||
|
726
invokeai/backend/model_management/models.py
Normal file
726
invokeai/backend/model_management/models.py
Normal file
@ -0,0 +1,726 @@
|
|||||||
|
import sys
|
||||||
|
from enum import Enum
|
||||||
|
import torch
|
||||||
|
import safetensors.torch
|
||||||
|
from diffusers.utils import is_safetensors_available
|
||||||
|
|
||||||
|
class BaseModelType(str, Enum):
    # Base model families. The string values are used as the first-level
    # keys of MODEL_CLASSES and as the base-model part of model keys.
    #StableDiffusion1_5 = "stable_diffusion_1_5"
    #StableDiffusion2 = "stable_diffusion_2"
    #StableDiffusion2Base = "stable_diffusion_2_base"
    # TODO: maybe then add sample size(512/768)?
    StableDiffusion1_5 = "SD-1"
    StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
    StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
    #Kandinsky2_1 = "kandinsky_2_1"
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
    # Kinds of model. The values are used as the second-level keys of
    # MODEL_CLASSES, selecting the implementing class per base model.
    Pipeline = "pipeline"
    Classifier = "classifier"
    Vae = "vae"

    Lora = "lora"
    ControlNet = "controlnet"
    TextualInversion = "embedding"
|
||||||
|
|
||||||
|
class SubModelType(str, Enum):
    """Names of the individual parts (submodels) of a model.

    Declared as a str-valued Enum like BaseModelType and ModelType, so
    members are real SubModelType instances (matching annotations such as
    Optional[SubModelType]) while still comparing equal to their plain
    string values.
    """
    UNet = "unet"
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    Vae = "vae"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
    #MoVQ = "movq"
|
||||||
|
|
||||||
|
# Lookup table: base model family -> model type -> implementing class.
#
# The keys use ``BaseModelType`` (the enum defined in this module) and the
# LoRA entry uses ``LoRAModel`` (the class name actually defined below); the
# previous spellings ``BaseModel.*`` and ``LoraModel`` do not match any
# definition in this module.
#
# NOTE(review): the classes referenced here are defined further down in this
# module, so this table has to be evaluated after them (move it below the
# class definitions or build it lazily) before the module can be imported.
# NOTE(review): ``ControlNetModel`` is not defined in the visible part of
# this module - confirm where it is supposed to come from.
MODEL_CLASSES = {
    BaseModelType.StableDiffusion1_5: {
        ModelType.Pipeline: StableDiffusionModel,
        ModelType.Classifier: ClassifierModel,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
    },
    BaseModelType.StableDiffusion2: {
        ModelType.Pipeline: StableDiffusionModel,
        ModelType.Classifier: ClassifierModel,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
    },
    BaseModelType.StableDiffusion2Base: {
        ModelType.Pipeline: StableDiffusionModel,
        ModelType.Classifier: ClassifierModel,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
    },
    #BaseModelType.Kandinsky2_1: {
    #    ModelType.Pipeline: Kandinsky2_1Model,
    #    ModelType.Classifier: ClassifierModel,
    #    ModelType.MoVQ: MoVQModel,
    #    ModelType.Lora: LoRAModel,
    #    ModelType.ControlNet: ControlNetModel,
    #    ModelType.TextualInversion: TextualInversionModel,
    #},
}
|
||||||
|
|
||||||
|
class EmptyConfigLoader(ConfigMixin):
    """``ConfigMixin`` subclass whose config file name is chosen per call."""

    @classmethod
    def load_config(cls, *args, **kwargs):
        """Load a config via ``ConfigMixin.load_config`` using a caller-chosen file name.

        The mandatory ``config_name`` keyword is removed from ``kwargs`` and
        stored on the class as ``config_name`` before delegating; a missing
        ``config_name`` raises ``KeyError``.  All remaining arguments are
        forwarded unchanged.
        """
        requested_name = kwargs.pop("config_name")
        cls.config_name = requested_name
        return super().load_config(*args, **kwargs)
|
||||||
|
|
||||||
|
class ModelBase:
    """Common state shared by all model wrappers: path, base model and model type."""

    #model_path: str
    #base_model: BaseModelType
    #model_type: ModelType

    def __init__(
        self,
        model_path: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        """Store the three identifying attributes without any validation."""
        self.model_path = model_path
        self.base_model = base_model
        self.model_type = model_type

    def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
        """Resolve a ``[package_or_module, ..., class_name]`` definition to an object.

        When the first entry is ``"diffusers"`` or ``"transformers"`` the
        lookup starts at that already-imported package and the remaining
        entries are attribute names.  Otherwise the lookup starts at
        ``diffusers.pipelines`` and *all* entries are attribute names.

        Raises:
            Exception: if fewer than two entries are given.
            KeyError: if the starting package is not present in ``sys.modules``.
            AttributeError: if an attribute along the path does not exist.
        """
        if len(subtypes) < 2:
            raise Exception("Invalid subfolder definition!")

        root_name = subtypes[0]
        if root_name in ["diffusers", "transformers"]:
            current = sys.modules[root_name]
            attr_path = subtypes[1:]
        else:
            current = getattr(sys.modules["diffusers"], "pipelines")
            attr_path = subtypes

        for attr_name in attr_path:
            current = getattr(current, attr_name)
        return current
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusersModel(ModelBase):
    """Wrapper around a multi-folder diffusers model described by ``model_index.json``.

    Attributes:
        child_types: sub-model name -> class resolved from the model index.
        child_sizes: sub-model name -> size in bytes (estimated from disk
            first, replaced by the in-memory size once a sub-model is loaded).
    """

    #child_types: Dict[str, Type]
    #child_sizes: Dict[str, int]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Read the model index and record type and on-disk size of every sub-model.

        Raises:
            Exception: if the model index cannot be loaded.
        """
        super().__init__(model_path, base_model, model_type)

        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            config_data = DiffusionPipeline.load_config(self.model_path)
        except Exception as e:
            # ``except Exception`` (not bare ``except``) so that
            # KeyboardInterrupt/SystemExit are not converted into this error.
            raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") from e

        config_data.pop("_ignore_files", None)

        # Only list-valued entries describe sub-models (folder names that
        # contain relevant files); everything else is skipped.
        child_components = [k for k, v in config_data.items() if isinstance(v, list)]

        for child_name in child_components:
            self.child_types[child_name] = self._hf_definition_to_type(config_data[child_name])
            self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)

    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the size of one sub-model, or of all sub-models when ``child_type`` is None."""
        if child_type is None:
            return sum(self.child_sizes.values())
        return self.child_sizes[child_type]

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load and return one sub-model from local files.

        Args:
            torch_dtype: dtype passed to ``from_pretrained``; ``torch.float16``
                makes the ``fp16`` weight variant the first one tried.
            child_type: which sub-model to load; must not be None.

        Returns:
            The loaded sub-model, or None when the model has no such sub-model.

        Raises:
            Exception: if ``child_type`` is None, or if no weight variant
                could be loaded (previously this surfaced as an
                ``UnboundLocalError`` on the unassigned ``model``).
        """
        # return pipeline in different function to pass more arguments
        if child_type is None:
            raise Exception("Child model type can't be null on diffusers model")
        if child_type not in self.child_types:
            return None # TODO: or raise

        if torch_dtype == torch.float16:
            variants = ["fp16", None]
        else:
            variants = [None, "fp16"]

        # Accept both enum members and plain strings as sub-model names.
        subfolder = getattr(child_type, "value", child_type)

        # TODO: better error handling(differentiate not found from others)
        last_error = None
        for variant in variants:
            try:
                # TODO: set cache_dir to /dev/null to be sure that cache not used?
                model = self.child_types[child_type].from_pretrained(
                    self.model_path,
                    subfolder=subfolder,
                    torch_dtype=torch_dtype,
                    variant=variant,
                    local_files_only=True,
                )
                break
            except Exception as e:
                last_error = e
                print("====ERR LOAD====")
                print(f"{variant}: {e}")
        else:
            # No ``break`` happened: every variant failed to load.
            raise Exception(
                f"Failed to load sub-model \"{subfolder}\" from {self.model_path}"
            ) from last_error

        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model
|
||||||
|
|
||||||
|
#def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionModel(DiffusersModel):
    """Diffusers pipeline wrapper restricted to the Stable Diffusion base models."""

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Check base model and model type, then initialise the diffusers wrapper."""
        assert base_model in {
            BaseModelType.StableDiffusion1_5,
            BaseModelType.StableDiffusion2,
            BaseModelType.StableDiffusion2Base,
        }
        assert model_type == ModelType.Pipeline
        super().__init__(model_path, base_model, model_type)

    @staticmethod
    def convert_if_required(model_path: Union[str, Path], dst_path: str, config: Optional[dict]) -> Path:
        """Convert a checkpoint to diffusers format and return the converted path.

        ``model_path`` is normalised to a ``Path`` but, like ``dst_path``, is
        not used yet (see the TODOs); the conversion is driven by ``config``.

        NOTE(review): ``_convert_ckpt_and_cache`` reads ``mconfig.path`` and
        ``mconfig.config`` as attributes, so ``config`` presumably has to be
        an OmegaConf-style object rather than a plain dict - confirm.
        """
        if not isinstance(model_path, Path):
            model_path = Path(model_path)

        # TODO: args
        # TODO: set model_path, to config? pass dst_path as arg?
        # TODO: check
        # The module-level helper is declared as ``(self, mconfig)`` and never
        # uses ``self``; calling it with ``config`` alone bound ``config`` to
        # ``self`` and failed with a missing-argument TypeError.
        return _convert_ckpt_and_cache(None, config)
|
||||||
|
|
||||||
|
class classproperty(object): # pylint: disable=invalid-name
    """Read-only descriptor exposing the result of a function of the class.

    Example usage:

    class MyClass(object):

        @classproperty
        def value(cls):
            return '123'

    > print(MyClass.value)
    123
    """

    def __init__(self, func):
        # The wrapped function; it is called with the owning class on access.
        self._func = func

    def __get__(self, owner_self, owner_cls):
        # Works for class and instance access alike: only the class is used.
        getter = self._func
        return getter(owner_cls)
|
||||||
|
|
||||||
|
class ModelConfigBase(BaseModel):
    """Fields shared by every per-model configuration schema in this module."""

    # Location of the model; declared as a string ("or Path" per the original note).
    path: str # or Path
    # Name of the model (required).
    name: str
    # Optional free-text description.
    description: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionDModel(DiffusersModel):
    """Stable Diffusion 1.5 pipeline wrapper with a validated configuration schema."""

    class Config(ModelConfigBase):
        """Configuration schema: ``format`` plus optional ``vae`` and ``config``."""

        format: str
        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)

        @root_validator
        def validator(cls, values):
            """Reject unknown formats and a custom ``config`` on non-checkpoint models."""
            # ``.get`` instead of indexing: a field can be absent from
            # ``values`` when its own validation failed (e.g. ``format`` was
            # not supplied), which would otherwise surface as a KeyError.
            model_format = values.get("format")
            if model_format not in {"checkpoint", "diffusers"}:
                raise ValueError(f"Unkown stable diffusion model format: {model_format}")
            if values.get("config") is not None and model_format != "checkpoint":
                raise ValueError("Custom config field allowed only in checkpoint stable diffusion model")
            return values

        # return config only for checkpoint format
        def dict(self, *args, **kwargs):
            """Serialise to a dict, dropping ``config`` unless the format is ``checkpoint``."""
            result = super().dict(*args, **kwargs)
            if self.format != "checkpoint":
                result.pop("config", None)
            return result

    @classproperty
    def has_config(self):
        """Always True for this model class."""
        return True

    def build_config(self, **kwargs) -> dict:
        """Build a plain configuration dict from keyword arguments.

        ``path``, ``name`` and ``format`` are required; ``description`` and
        ``vae`` default to None; ``config`` is only included for the
        ``checkpoint`` format.

        Raises:
            Exception: if a required field is missing or the format is unknown.
        """
        try:
            res = dict(
                path=kwargs["path"],
                name=kwargs["name"],
                description=kwargs.get("description", None),

                format=kwargs["format"],
                vae=kwargs.get("vae", None),
            )
            if res["format"] not in {"checkpoint", "diffusers"}:
                raise Exception(f"Unkonwn stable diffusion model format: {res['format']}")
            if res["format"] == "checkpoint":
                res["config"] = kwargs.get("config", None)
            # TODO: raise if config specified for diffusers?

            return res

        except KeyError as e:
            raise Exception(f"Field \"{e.args[0]}\" not found!") from e

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Check base model and model type, then initialise the diffusers wrapper."""
        assert base_model == BaseModelType.StableDiffusion1_5
        assert model_type == ModelType.Pipeline
        super().__init__(model_path, base_model, model_type)

    @classmethod
    def convert_if_required(cls, model_path: str, dst_path: str, config: Optional[dict]) -> str:
        """Convert checkpoint-format models; return the path of the usable model.

        Returns ``dst_path`` after conversion for the ``checkpoint`` format
        and ``model_path`` unchanged otherwise.
        """
        model_config = cls.Config(
            **config,
            path=model_path,
            name="",
        )

        # Decide on the declared format.  The previous test,
        # ``hasattr(model_config, "config")``, is true for every instance
        # because ``config`` is a declared field (defaulting to None), so
        # diffusers-format models were sent through conversion as well.
        if model_config.format == "checkpoint":
            # NOTE(review): ``convert_ckpt_and_cache`` is not defined in the
            # visible part of this module (only ``_convert_ckpt_and_cache``
            # with a different signature is) - confirm the intended helper.
            convert_ckpt_and_cache(
                model_path=model_path,
                dst_path=dst_path,
                config=config,
            )
            return dst_path

        else:
            return model_path
|
||||||
|
|
||||||
|
class StableDiffusion15CheckpointModel(DiffusersModel):
    """Placeholder wrapper; only its configuration schema is declared so far."""

    class Config(ModelConfigBase):
        """Configuration schema with optional ``vae`` and ``config`` strings."""

        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)

    # The schema was originally declared under the misspelled name ``Cnfig``,
    # unlike the sibling classes which use ``Config``; the old name is kept
    # as an alias so existing references keep working.
    Cnfig = Config
|
||||||
|
|
||||||
|
class StableDiffusion2BaseDiffusersModel(DiffusersModel):
    """Placeholder wrapper; only its configuration schema is declared so far."""

    class Config(ModelConfigBase):
        """Configuration schema with an optional ``vae`` string."""

        vae: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
class StableDiffusion2BaseCheckpointModel(DiffusersModel):
    """Placeholder wrapper; only its configuration schema is declared so far."""

    class Config(ModelConfigBase):
        """Configuration schema with optional ``vae`` and ``config`` strings."""

        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)

    # The schema was originally declared under the misspelled name ``Cnfig``,
    # unlike the sibling classes which use ``Config``; the old name is kept
    # as an alias so existing references keep working.
    Cnfig = Config
|
||||||
|
|
||||||
|
class StableDiffusion2DiffusersModel(DiffusersModel):
    """Placeholder wrapper; only its configuration schema is declared so far."""

    class Config(ModelConfigBase):
        """Configuration schema: optional ``vae`` string and ``attention_upscale`` flag."""

        vae: Optional[str] = Field(default=None)
        # Defaults to True.
        attention_upscale: bool = Field(default=True)
|
||||||
|
|
||||||
|
class StableDiffusion2CheckpointModel(DiffusersModel):
    """Placeholder wrapper; only its configuration schema is declared so far."""

    class Config(ModelConfigBase):
        """Configuration schema: optional ``vae``/``config`` strings and ``attention_upscale`` flag."""

        vae: Optional[str] = Field(default=None)
        config: Optional[str] = Field(default=None)
        # Defaults to True.
        attention_upscale: bool = Field(default=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierModel(ModelBase):
    """Wrapper around a transformers-style classifier stored in a single folder.

    The tokenizer, text encoder and (optional) feature extractor classes are
    detected from the JSON config files in ``model_path``.

    NOTE(review): the sub-model keys below still use ``SDModelType``, which is
    neither defined nor imported in the visible part of this module.  They
    presumably need to move to ``SubModelType``, but that type currently has
    no ``FeatureExtractor`` member - confirm and align.
    """

    #child_types: Dict[str, Type]
    #child_sizes: Dict[str, int]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Load ``config.json`` and detect the classes of all sub-models.

        Raises:
            Exception: if ``config.json`` cannot be loaded or a sub-model
                type cannot be detected.
        """
        # ``ModelType`` is the enum this module defines for model kinds (the
        # sibling wrappers assert against it in the same way).
        assert model_type == ModelType.Classifier
        super().__init__(model_path, base_model, model_type)

        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            main_config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
        except Exception as e:
            # ``except Exception`` (not bare ``except``) so that
            # KeyboardInterrupt/SystemExit are not converted into this error.
            raise Exception("Invalid classifier model! (config.json not found or invalid)") from e

        self._load_tokenizer(main_config)
        self._load_text_encoder(main_config)
        self._load_feature_extractor(main_config)

    def _load_tokenizer(self, main_config: dict):
        """Detect the tokenizer class; its size is recorded as 0."""
        try:
            tokenizer_config = EmptyConfigLoader.load_config(self.model_path, config_name="tokenizer_config.json")
        except Exception as e:
            raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)") from e

        # Prefer the explicit class name, fall back to the auto-mapping.
        if "tokenizer_class" in tokenizer_config:
            tokenizer_class_name = tokenizer_config["tokenizer_class"]
        elif "model_type" in main_config:
            tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]]
        else:
            raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")

        self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
        self.child_sizes[SDModelType.Tokenizer] = 0

    def _load_text_encoder(self, main_config: dict):
        """Detect the text encoder class and estimate its size from disk."""
        if "architectures" in main_config and len(main_config["architectures"]) > 0:
            text_encoder_class_name = main_config["architectures"][0]
        elif "model_type" in main_config:
            text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]]
        else:
            raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")

        self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
        self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)

    def _load_feature_extractor(self, main_config: dict):
        """Detect the feature extractor class, if a preprocessor config exists."""
        self.child_sizes[SDModelType.FeatureExtractor] = 0
        try:
            feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json")
        except Exception:
            return # feature extractor not passed with t5

        try:
            feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
            self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
        except Exception as e:
            raise Exception("Invalid classifier model! (Unknown feature_extrator type)") from e

    def get_size(self, child_type: Optional[SDModelType] = None):
        """Return the size of one sub-model, or of all sub-models when ``child_type`` is None."""
        if child_type is None:
            return sum(self.child_sizes.values())
        return self.child_sizes[child_type]

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SDModelType] = None,
    ):
        """Load and return one sub-model.

        Returns None when the model has no such sub-model.

        Raises:
            Exception: if ``child_type`` is None.
        """
        if child_type is None:
            raise Exception("Child model type can't be null on classififer model")
        if child_type not in self.child_types:
            return None # TODO: or raise

        model = self.child_types[child_type].from_pretrained(
            self.model_path,
            subfolder=child_type.value,
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model

    @staticmethod
    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
        """No conversion is performed; return ``model_path`` as a ``Path``."""
        if not isinstance(model_path, Path):
            model_path = Path(model_path)
        return model_path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class VaeModel(ModelBase):
    """Wrapper around a standalone VAE stored in diffusers format.

    Attributes:
        vae_class: class resolved from the ``_class_name`` entry of
            ``config.json`` (``AutoencoderKL`` when that entry is absent).
        model_size: size in bytes (estimated from disk first, replaced by the
            in-memory size once the model is loaded).
    """

    #vae_class: Type
    #model_size: int

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Load ``config.json`` and resolve the VAE class.

        Raises:
            Exception: if the config cannot be loaded or the class cannot be resolved.
        """
        assert model_type == ModelType.Vae
        super().__init__(model_path, base_model, model_type)

        try:
            config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
        except Exception as e:
            # ``except Exception`` (not bare ``except``) so that
            # KeyboardInterrupt/SystemExit are not converted into this error.
            raise Exception("Invalid vae model! (config.json not found or invalid)") from e

        try:
            vae_class_name = config.get("_class_name", "AutoencoderKL")
            self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
            self.model_size = calc_model_size_by_fs(self.model_path)
        except Exception as e:
            raise Exception("Invalid vae model! (Unkown vae type)") from e

    # The ``child_type`` annotations use ``SubModelType`` (as DiffusersModel
    # does); ``SDModelType`` is not defined in this module.
    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the model size; a VAE has no sub-models, so ``child_type`` must be None."""
        if child_type is not None:
            raise Exception("There is no child models in vae model")
        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load and return the VAE with the requested dtype.

        Raises:
            Exception: if ``child_type`` is not None.
        """
        if child_type is not None:
            raise Exception("There is no child models in vae model")

        model = self.vae_class.from_pretrained(
            self.model_path,
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.model_size = calc_model_size_by_data(model)
        return model

    @staticmethod
    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
        """Not implemented yet: always raises after normalising ``model_path``."""
        if not isinstance(model_path, Path):
            model_path = Path(model_path)
        # TODO:
        #_convert_vae_ckpt_and_cache
        raise Exception("TODO: ")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAModel(ModelBase):
    """Wrapper around a single-file LoRA checkpoint.

    Attributes:
        model_size: size in bytes (file size first, replaced by the value of
            ``calc_size()`` once the model is loaded).
    """

    #model_size: int

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Record the on-disk size of the LoRA file."""
        assert model_type == ModelType.Lora
        super().__init__(model_path, base_model, model_type)

        self.model_size = os.path.getsize(self.model_path)

    # The ``child_type`` annotations use ``SubModelType`` (as DiffusersModel
    # does); ``SDModelType`` is not defined in this module.
    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the model size; a LoRA has no sub-models, so ``child_type`` must be None."""
        if child_type is not None:
            raise Exception("There is no child models in lora")
        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load and return the LoRA weights with the requested dtype.

        Raises:
            Exception: if ``child_type`` is not None.
        """
        if child_type is not None:
            raise Exception("There is no child models in lora")

        # Inside this class the name ``LoRAModel`` refers to this wrapper,
        # which has no ``from_checkpoint``.  The loader is imported from the
        # sibling ``lora`` module under a distinct name instead.
        # NOTE(review): assumes ``.lora.LoRAModel`` is the implementation the
        # original call was meant to reach - confirm.
        from .lora import LoRAModel as LoRAModelRaw

        model = LoRAModelRaw.from_checkpoint(
            file_path=self.model_path,
            dtype=torch_dtype,
        )

        self.model_size = model.calc_size()
        return model

    @staticmethod
    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
        """No conversion is performed; return ``model_path`` as a ``Path``."""
        if not isinstance(model_path, Path):
            model_path = Path(model_path)

        # TODO: add diffusers lora when it stabilizes a bit
        return model_path
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionModel(ModelBase):
    """Wrapper around a single-file textual inversion embedding.

    Attributes:
        model_size: size in bytes (file size first, replaced by the size of
            the embedding tensor once the model is loaded).
    """

    #model_size: int

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Record the on-disk size of the embedding file."""
        assert model_type == ModelType.TextualInversion
        super().__init__(model_path, base_model, model_type)

        self.model_size = os.path.getsize(self.model_path)

    # The ``child_type`` annotations use ``SubModelType`` (as DiffusersModel
    # does); ``SDModelType`` is not defined in this module.
    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the model size; there are no sub-models, so ``child_type`` must be None."""
        if child_type is not None:
            raise Exception("There is no child models in textual inversion")
        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load and return the embedding with the requested dtype.

        Raises:
            Exception: if ``child_type`` is not None.
        """
        if child_type is not None:
            raise Exception("There is no child models in textual inversion")

        # Inside this class the name ``TextualInversionModel`` refers to this
        # wrapper, which has no ``from_checkpoint``.  The loader is imported
        # from the sibling ``lora`` module under a distinct name instead.
        # NOTE(review): assumes ``.lora.TextualInversionModel`` is the
        # implementation the original call was meant to reach - confirm.
        from .lora import TextualInversionModel as TextualInversionModelRaw

        model = TextualInversionModelRaw.from_checkpoint(
            file_path=self.model_path,
            dtype=torch_dtype,
        )

        self.model_size = model.embedding.nelement() * model.embedding.element_size()
        return model

    @staticmethod
    def convert_if_required(model_path: Union[str, Path], cache_path: str, config: Optional[dict]) -> Path:
        """No conversion is performed; return ``model_path`` as a ``Path``."""
        if not isinstance(model_path, Path):
            model_path = Path(model_path)
        return model_path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def calc_model_size_by_fs(
    model_path: str,
    subfolder: Optional[str] = None,
    variant: Optional[str] = None
):
    """Estimate the size in bytes of a model's weight files on disk.

    Args:
        model_path: directory of the model.
        subfolder: optional sub-directory of ``model_path`` to inspect instead.
        variant: weight variant to measure: None (files that are neither
            fp16 nor 8bit), ``"fp16"`` or ``"8bit"``.

    Returns:
        ``metadata.total_size`` of a matching ``*.index*.json`` file when one
        can be read; otherwise the summed size of the files of the first
        weight format found; 0 when the directory does not exist or holds no
        weight files.

    Raises:
        NotImplementedError: for an unknown ``variant``.
    """
    if subfolder is not None:
        model_path = os.path.join(model_path, subfolder)

    # this can happen when, for example, the safety checker
    # is not downloaded.
    if not os.path.exists(model_path):
        return 0

    all_files = [
        f for f in os.listdir(model_path)
        if os.path.isfile(os.path.join(model_path, f))
    ]

    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
    other_files = set(all_files) - fp16_files - bit8_files

    if variant is None:
        files = other_files
    elif variant == "fp16":
        files = fp16_files
    elif variant == "8bit":
        files = bit8_files
    else:
        raise NotImplementedError(f"Unknown variant: {variant}")

    # try read from index if exists
    index_postfix = ".index.json"
    if variant is not None:
        index_postfix = f".index.{variant}.json"

    # Sorted so the result does not depend on set iteration order when more
    # than one index file is present.
    for file in sorted(files):
        if not file.endswith(index_postfix):
            continue
        try:
            with open(os.path.join(model_path, file), "r") as f:
                index_data = json.load(f)
            return int(index_data["metadata"]["total_size"])
        except (OSError, ValueError, KeyError, TypeError):
            # Unreadable or malformed index: deliberately best-effort, fall
            # through to measuring the weight files themselves.
            pass

    # calculate files size if there is no index file
    formats = [
        (".safetensors",), # safetensors
        (".bin",), # torch
        (".onnx", ".pb"), # onnx
        (".msgpack",), # flax
        (".ckpt",), # tf
        (".h5",), # tf2
    ]

    # Only the first format that has any files is counted, so the same
    # weights stored in several formats are not counted twice.
    for file_format in formats:
        model_files = [f for f in files if f.endswith(file_format)]
        if not model_files:
            continue

        return sum(
            os.stat(os.path.join(model_path, model_file)).st_size
            for model_file in model_files
        )

    #raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
    return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||||
|
|
||||||
|
|
||||||
|
def calc_model_size_by_data(model) -> int:
    """Return the in-memory size in bytes of a loaded model.

    A ``DiffusionPipeline`` is measured through its components, a
    ``torch.nn.Module`` through its parameters and buffers; any other object
    counts as 0.
    """
    if isinstance(model, DiffusionPipeline):
        return _calc_pipeline_by_data(model)
    if isinstance(model, torch.nn.Module):
        return _calc_model_by_data(model)
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_pipeline_by_data(pipeline) -> int:
    """Sum the in-memory size of every ``torch.nn.Module`` component of a pipeline."""
    total = 0
    for component_name in pipeline.components.keys():
        component = getattr(pipeline, component_name)
        # Components that are None or not torch modules do not contribute.
        if component is None or not isinstance(component, torch.nn.Module):
            continue
        total += _calc_model_by_data(component)
    return total
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_model_by_data(model) -> int:
    """Return the number of bytes held by a module's parameters and buffers."""

    def _nbytes(tensors) -> int:
        # element count times bytes per element, summed over all tensors
        return sum(t.nelement() * t.element_size() for t in tensors)

    param_bytes = _nbytes(model.parameters())
    buffer_bytes = _nbytes(model.buffers())
    return param_bytes + buffer_bytes  # in bytes
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
    """Convert a checkpoint model to diffusers format, caching the result on disk.

    The weights (``mconfig.path``) and the original config file
    (``mconfig.config``) are resolved against the application root
    directory.  The converted model is written to
    ``<converted_ckpts_dir>/<weights file stem>``; when that path already
    exists it is returned without converting again.

    NOTE(review): this is a module-level function, yet it takes a ``self``
    parameter that the body never uses - looks like a leftover from a method.

    Returns:
        Path of the converted (or already cached) diffusers model.
    """
    settings = InvokeAIAppConfig.get_config()
    checkpoint_file = settings.root_dir / mconfig.path
    original_config = settings.root_dir / mconfig.config
    target_dir = settings.converted_ckpts_dir / checkpoint_file.stem

    # return cached version if it exists
    if target_dir.exists():
        return target_dir

    # TODO: I think that it more correctly to convert with embedded vae
    # as if user will delete custom vae he will got not embedded but also custom vae
    #vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
    vae_ckpt_path = None
    vae_model = None

    # to avoid circular import errors
    from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

    with SilenceWarnings():
        # No custom VAE is selected above, so this is currently always None.
        resolved_vae_path = str(settings.root_dir / vae_ckpt_path) if vae_ckpt_path else None
        convert_ckpt_to_diffusers(
            checkpoint_file,
            target_dir,
            extract_ema=True,
            original_config_file=original_config,
            vae=vae_model,
            vae_path=resolved_vae_path,
            scan_needed=True,
        )
    return target_dir
|
||||||
|
|
||||||
|
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
    """Convert a checkpoint VAE to a diffusers model, caching the result on disk.

    The weights (``mconfig.path``) and config file (``mconfig.config``) are
    resolved against the application root directory.  The converted model is
    saved to ``<converted_ckpts_dir>/<weights file stem>``; when that path
    already exists it is returned without converting again.  The image size
    passed to the converter is ``width``, else ``height``, else 512.

    NOTE(review): this is a module-level function, yet it takes a ``self``
    parameter that the body never uses - looks like a leftover from a method.

    Returns:
        Path of the converted (or already cached) diffusers VAE.
    """
    settings = InvokeAIAppConfig.get_config()
    base_dir = settings.root_dir
    checkpoint_file = base_dir / mconfig.path
    vae_config_file = base_dir / mconfig.config
    target_dir = settings.converted_ckpts_dir / checkpoint_file.stem
    image_size = mconfig.get('width') or mconfig.get('height') or 512

    # return cached version if it exists
    if target_dir.exists():
        return target_dir

    # this avoids circular import error
    from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers

    is_safetensors_file = checkpoint_file.suffix == '.safetensors'
    if is_safetensors_file:
        state = safetensors.torch.load_file(checkpoint_file)
    else:
        # NOTE(security): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        state = torch.load(checkpoint_file, map_location="cpu")

    # sometimes weights are hidden under "state_dict", and sometimes not
    if "state_dict" in state:
        state = state["state_dict"]

    ldm_config = OmegaConf.load(vae_config_file)

    converted_vae = convert_ldm_vae_to_diffusers(
        checkpoint=state,
        vae_config=ldm_config,
        image_size=image_size,
    )
    converted_vae.save_pretrained(
        target_dir,
        safe_serialization=is_safetensors_available(),
    )
    return target_dir
|
Loading…
Reference in New Issue
Block a user