mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
This commit is contained in:
@ -3,15 +3,24 @@ from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import (
|
||||
BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase,
|
||||
ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings,
|
||||
ModelNotFoundException, InvalidModelException, DuplicateModelException
|
||||
)
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
ModelError,
|
||||
SilenceWarnings,
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
DuplicateModelException,
|
||||
)
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .sdxl import StableDiffusionXLModel
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .textual_inversion import TextualInversionModel
|
||||
|
||||
MODEL_CLASSES = {
|
||||
@ -45,18 +54,19 @@ MODEL_CLASSES = {
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
#BaseModelType.Kandinsky2_1: {
|
||||
# BaseModelType.Kandinsky2_1: {
|
||||
# ModelType.Main: Kandinsky2_1Model,
|
||||
# ModelType.MoVQ: MoVQModel,
|
||||
# ModelType.Lora: LoRAModel,
|
||||
# ModelType.ControlNet: ControlNetModel,
|
||||
# ModelType.TextualInversion: TextualInversionModel,
|
||||
#},
|
||||
# },
|
||||
}
|
||||
|
||||
MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
model_name: str
|
||||
base_model: BaseModelType
|
||||
@ -72,27 +82,31 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
# LS: sort to get the checkpoint configs first, which makes
|
||||
# for a better template in the Swagger docs
|
||||
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
continue
|
||||
|
||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||
__annotations__ = dict(
|
||||
model_type=Literal[model_type.value],
|
||||
api_wrapper = type(
|
||||
openapi_cfg_name,
|
||||
(cfg, OpenAPIModelInfoBase),
|
||||
dict(
|
||||
__annotations__=dict(
|
||||
model_type=Literal[model_type.value],
|
||||
),
|
||||
),
|
||||
))
|
||||
)
|
||||
|
||||
#globals()[openapi_cfg_name] = api_wrapper
|
||||
# globals()[openapi_cfg_name] = api_wrapper
|
||||
vars()[openapi_cfg_name] = api_wrapper
|
||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = list()
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
|
||||
if hasattr(inspect,'get_annotations'):
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
fields = inspect.get_annotations(model_config)
|
||||
else:
|
||||
fields = model_config.__annotations__
|
||||
@ -109,7 +123,9 @@ def get_model_config_enums():
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
enums.append(field)
|
||||
|
||||
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
elif get_origin(field) is Literal and all(
|
||||
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
||||
):
|
||||
enums.append(type(field.__args__[0]))
|
||||
|
||||
elif field is None:
|
||||
@ -119,4 +135,3 @@ def get_model_config_enums():
|
||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||
|
||||
return enums
|
||||
|
||||
|
@ -15,29 +15,35 @@ from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
#Kandinsky2_1 = "kandinsky-2.1"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
@ -47,23 +53,27 @@ class SubModelType(str, Enum):
|
||||
Vae = "vae"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
#MoVQ = "movq"
|
||||
# MoVQ = "movq"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
|
||||
class ModelError(str, Enum):
|
||||
NotFound = "not_found"
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
error: Optional[ModelError] = Field(None)
|
||||
@ -71,13 +81,17 @@ class ModelConfigBase(BaseModel):
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class EmptyConfigLoader(ConfigMixin):
|
||||
@classmethod
|
||||
def load_config(cls, *args, **kwargs):
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
return super().load_config(*args, **kwargs)
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class classproperty(Generic[T_co]):
|
||||
def __init__(self, fget: Callable[[Any], T_co]) -> None:
|
||||
self.fget = fget
|
||||
@ -86,12 +100,13 @@ class classproperty(Generic[T_co]):
|
||||
return self.fget(owner)
|
||||
|
||||
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
||||
raise AttributeError('cannot set attribute')
|
||||
raise AttributeError("cannot set attribute")
|
||||
|
||||
|
||||
class ModelBase(metaclass=ABCMeta):
|
||||
#model_path: str
|
||||
#base_model: BaseModelType
|
||||
#model_type: ModelType
|
||||
# model_path: str
|
||||
# base_model: BaseModelType
|
||||
# model_type: ModelType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -110,7 +125,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
return None
|
||||
elif any(t is None for t in subtypes):
|
||||
raise Exception(f"Unsupported definition: {subtypes}")
|
||||
|
||||
|
||||
if subtypes[0] in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[subtypes[0]]
|
||||
subtypes = subtypes[1:]
|
||||
@ -119,7 +134,6 @@ class ModelBase(metaclass=ABCMeta):
|
||||
res_type = sys.modules["diffusers"]
|
||||
res_type = getattr(res_type, "pipelines")
|
||||
|
||||
|
||||
for subtype in subtypes:
|
||||
res_type = getattr(res_type, subtype)
|
||||
return res_type
|
||||
@ -128,7 +142,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
def _get_configs(cls):
|
||||
with suppress(Exception):
|
||||
return cls.__configs
|
||||
|
||||
|
||||
configs = dict()
|
||||
for name in dir(cls):
|
||||
if name.startswith("__"):
|
||||
@ -138,7 +152,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||
continue
|
||||
|
||||
if hasattr(inspect,'get_annotations'):
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
fields = inspect.get_annotations(value)
|
||||
else:
|
||||
fields = value.__annotations__
|
||||
@ -151,7 +165,9 @@ class ModelBase(metaclass=ABCMeta):
|
||||
for model_format in field:
|
||||
configs[model_format.value] = value
|
||||
|
||||
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
elif typing.get_origin(field) is Literal and all(
|
||||
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
||||
):
|
||||
for model_format in field.__args__:
|
||||
configs[model_format.value] = value
|
||||
|
||||
@ -203,8 +219,8 @@ class ModelBase(metaclass=ABCMeta):
|
||||
|
||||
|
||||
class DiffusersModel(ModelBase):
|
||||
#child_types: Dict[str, Type]
|
||||
#child_sizes: Dict[str, int]
|
||||
# child_types: Dict[str, Type]
|
||||
# child_sizes: Dict[str, int]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
@ -214,7 +230,7 @@ class DiffusersModel(ModelBase):
|
||||
|
||||
try:
|
||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
||||
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
||||
except:
|
||||
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
||||
|
||||
@ -228,14 +244,12 @@ class DiffusersModel(ModelBase):
|
||||
self.child_types[child_name] = child_type
|
||||
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
||||
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is None:
|
||||
return sum(self.child_sizes.values())
|
||||
else:
|
||||
return self.child_sizes[child_type]
|
||||
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
@ -245,7 +259,7 @@ class DiffusersModel(ModelBase):
|
||||
if child_type is None:
|
||||
raise Exception("Child model type can't be null on diffusers model")
|
||||
if child_type not in self.child_types:
|
||||
return None # TODO: or raise
|
||||
return None # TODO: or raise
|
||||
|
||||
if torch_dtype == torch.float16:
|
||||
variants = ["fp16", None]
|
||||
@ -265,8 +279,8 @@ class DiffusersModel(ModelBase):
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
#print("====ERR LOAD====")
|
||||
#print(f"{variant}: {e}")
|
||||
# print("====ERR LOAD====")
|
||||
# print(f"{variant}: {e}")
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||
@ -275,15 +289,10 @@ class DiffusersModel(ModelBase):
|
||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||
|
||||
|
||||
|
||||
def calc_model_size_by_fs(
|
||||
model_path: str,
|
||||
subfolder: Optional[str] = None,
|
||||
variant: Optional[str] = None
|
||||
):
|
||||
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
|
||||
if subfolder is not None:
|
||||
model_path = os.path.join(model_path, subfolder)
|
||||
|
||||
@ -325,12 +334,12 @@ def calc_model_size_by_fs(
|
||||
|
||||
# calculate files size if there is no index file
|
||||
formats = [
|
||||
(".safetensors",), # safetensors
|
||||
(".bin",), # torch
|
||||
(".onnx", ".pb"), # onnx
|
||||
(".msgpack",), # flax
|
||||
(".ckpt",), # tf
|
||||
(".h5",), # tf2
|
||||
(".safetensors",), # safetensors
|
||||
(".bin",), # torch
|
||||
(".onnx", ".pb"), # onnx
|
||||
(".msgpack",), # flax
|
||||
(".ckpt",), # tf
|
||||
(".h5",), # tf2
|
||||
]
|
||||
|
||||
for file_format in formats:
|
||||
@ -343,9 +352,9 @@ def calc_model_size_by_fs(
|
||||
file_stats = os.stat(os.path.join(model_path, model_file))
|
||||
model_size += file_stats.st_size
|
||||
return model_size
|
||||
|
||||
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||
|
||||
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||
|
||||
|
||||
def calc_model_size_by_data(model) -> int:
|
||||
@ -364,12 +373,12 @@ def _calc_pipeline_by_data(pipeline) -> int:
|
||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||
res += _calc_model_by_data(submodel)
|
||||
return res
|
||||
|
||||
|
||||
|
||||
def _calc_model_by_data(model) -> int:
|
||||
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
||||
mem = mem_params + mem_bufs # in bytes
|
||||
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
|
||||
mem = mem_params + mem_bufs # in bytes
|
||||
return mem
|
||||
|
||||
|
||||
@ -377,11 +386,15 @@ def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
definition_len = int.from_bytes(f.read(8), 'little')
|
||||
definition_len = int.from_bytes(f.read(8), "little")
|
||||
definition_json = f.read(definition_len)
|
||||
definition = json.loads(definition_json)
|
||||
|
||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
|
||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
|
||||
"pt",
|
||||
"torch",
|
||||
"pytorch",
|
||||
}:
|
||||
raise Exception("Supported only pytorch safetensors files")
|
||||
definition.pop("__metadata__", None)
|
||||
|
||||
@ -400,6 +413,7 @@ def _fast_safetensors_reader(path: str):
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
if str(path).endswith(".safetensors"):
|
||||
try:
|
||||
@ -411,25 +425,27 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
if scan:
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
|
||||
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
|
||||
|
||||
import warnings
|
||||
from diffusers import logging as diffusers_logging
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
class SilenceWarnings(object):
|
||||
def __init__(self):
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
transformers_logging.set_verbosity_error()
|
||||
diffusers_logging.set_verbosity_error()
|
||||
warnings.simplefilter('ignore')
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter('default')
|
||||
warnings.simplefilter("default")
|
||||
|
@ -18,13 +18,15 @@ from .base import (
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class ControlNetModel(ModelBase):
|
||||
#model_class: Type
|
||||
#model_size: int
|
||||
# model_class: Type
|
||||
# model_size: int
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[ControlNetModelFormat.Diffusers]
|
||||
@ -39,7 +41,7 @@ class ControlNetModel(ModelBase):
|
||||
|
||||
try:
|
||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||
#config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
except:
|
||||
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
|
||||
|
||||
@ -67,7 +69,7 @@ class ControlNetModel(ModelBase):
|
||||
raise Exception("There is no child models in controlnet model")
|
||||
|
||||
model = None
|
||||
for variant in ['fp16',None]:
|
||||
for variant in ["fp16", None]:
|
||||
try:
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
@ -79,7 +81,7 @@ class ControlNetModel(ModelBase):
|
||||
pass
|
||||
if not model:
|
||||
raise ModelNotFoundException()
|
||||
|
||||
|
||||
# calc more accurate size
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
@ -105,29 +107,30 @@ class ControlNetModel(ModelBase):
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
|
||||
return _convert_controlnet_ckpt_and_cache(
|
||||
model_path = model_path,
|
||||
model_config = config.config,
|
||||
output_path = output_path,
|
||||
base_model = base_model,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
@classmethod
|
||||
def _convert_controlnet_ckpt_and_cache(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
model_config: ControlNetModel.CheckpointConfig,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
|
||||
return _convert_controlnet_ckpt_and_cache(
|
||||
model_path=model_path,
|
||||
model_config=config.config,
|
||||
output_path=output_path,
|
||||
base_model=base_model,
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
@classmethod
|
||||
def _convert_controlnet_ckpt_and_cache(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_config: ControlNetModel.CheckpointConfig,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the controlnet from checkpoint format to diffusers format,
|
||||
@ -144,12 +147,13 @@ def _convert_controlnet_ckpt_and_cache(
|
||||
|
||||
# to avoid circular import errors
|
||||
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
convert_controlnet_to_diffusers(
|
||||
weights,
|
||||
output_path,
|
||||
original_config_file = app_config.root_path / model_config,
|
||||
image_size = 512,
|
||||
scan_needed = True,
|
||||
from_safetensors = weights.suffix == ".safetensors"
|
||||
original_config_file=app_config.root_path / model_config,
|
||||
image_size=512,
|
||||
scan_needed=True,
|
||||
from_safetensors=weights.suffix == ".safetensors",
|
||||
)
|
||||
return output_path
|
||||
|
@ -12,18 +12,21 @@ from .base import (
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
|
||||
|
||||
class LoRAModelFormat(str, Enum):
|
||||
LyCORIS = "lycoris"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
|
@ -15,12 +15,13 @@ from .base import (
|
||||
)
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class StableDiffusionXLModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusionXLModel(DiffusersModel):
|
||||
|
||||
|
||||
class StableDiffusionXLModel(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
||||
@ -53,7 +54,7 @@ class StableDiffusionXLModel(DiffusersModel):
|
||||
|
||||
else:
|
||||
checkpoint = read_checkpoint_meta(path)
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == StableDiffusionXLModelFormat.Diffusers:
|
||||
@ -61,7 +62,7 @@ class StableDiffusionXLModel(DiffusersModel):
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config['in_channels']
|
||||
in_channels = unet_config["in_channels"]
|
||||
|
||||
else:
|
||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||
@ -81,11 +82,10 @@ class StableDiffusionXLModel(DiffusersModel):
|
||||
if ckpt_config_path is None:
|
||||
# TO DO: implement picking
|
||||
pass
|
||||
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
)
|
||||
@ -114,11 +114,12 @@ class StableDiffusionXLModel(DiffusersModel):
|
||||
# source code changes, we simply translate here
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
||||
|
||||
return _convert_ckpt_and_cache(
|
||||
version=base_model,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
use_safetensors=False, # corrupts sdxl models for some reason
|
||||
use_safetensors=False, # corrupts sdxl models for some reason
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
@ -26,8 +26,8 @@ class StableDiffusion1ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
@ -38,7 +38,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
assert model_type == ModelType.Main
|
||||
@ -59,7 +59,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
else:
|
||||
checkpoint = read_checkpoint_meta(path)
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
||||
@ -67,7 +67,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config['in_channels']
|
||||
in_channels = unet_config["in_channels"]
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
|
||||
@ -88,7 +88,6 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
)
|
||||
@ -125,16 +124,17 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
output_path=output_path,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
class StableDiffusion2ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||
@ -167,7 +167,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
else:
|
||||
checkpoint = read_checkpoint_meta(path)
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
||||
@ -175,7 +175,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
unet_config = json.loads(f.read())
|
||||
in_channels = unet_config['in_channels']
|
||||
in_channels = unet_config["in_channels"]
|
||||
|
||||
else:
|
||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||
@ -198,7 +198,6 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
)
|
||||
@ -239,17 +238,19 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
# pass precision - currently defaulting to fp16
|
||||
def _convert_ckpt_and_cache(
|
||||
version: BaseModelType,
|
||||
model_config: Union[StableDiffusion1Model.CheckpointConfig,
|
||||
StableDiffusion2Model.CheckpointConfig,
|
||||
StableDiffusionXLModel.CheckpointConfig,
|
||||
],
|
||||
output_path: str,
|
||||
use_save_model: bool=False,
|
||||
**kwargs,
|
||||
version: BaseModelType,
|
||||
model_config: Union[
|
||||
StableDiffusion1Model.CheckpointConfig,
|
||||
StableDiffusion2Model.CheckpointConfig,
|
||||
StableDiffusionXLModel.CheckpointConfig,
|
||||
],
|
||||
output_path: str,
|
||||
use_save_model: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Convert the checkpoint model indicated in mconfig into a
|
||||
@ -270,13 +271,14 @@ def _convert_ckpt_and_cache(
|
||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
from ...util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
model_base_to_model_type = {BaseModelType.StableDiffusion1: 'FrozenCLIPEmbedder',
|
||||
BaseModelType.StableDiffusion2: 'FrozenOpenCLIPEmbedder',
|
||||
BaseModelType.StableDiffusionXL: 'SDXL',
|
||||
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
|
||||
}
|
||||
logger.info(f'Converting {weights} to diffusers format')
|
||||
with SilenceWarnings():
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
logger.info(f"Converting {weights} to diffusers format")
|
||||
with SilenceWarnings():
|
||||
convert_ckpt_to_diffusers(
|
||||
weights,
|
||||
output_path,
|
||||
@ -286,12 +288,13 @@ def _convert_ckpt_and_cache(
|
||||
original_config_file=config_file,
|
||||
extract_ema=True,
|
||||
scan_needed=True,
|
||||
from_safetensors = weights.suffix == ".safetensors",
|
||||
precision = torch_dtype(choose_torch_device()),
|
||||
from_safetensors=weights.suffix == ".safetensors",
|
||||
precision=torch_dtype(choose_torch_device()),
|
||||
**kwargs,
|
||||
)
|
||||
return output_path
|
||||
|
||||
|
||||
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||
ckpt_configs = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
@ -299,7 +302,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
|
||||
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
|
||||
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
|
||||
ModelVariantType.Depth: "v2-midas-inference.yaml",
|
||||
},
|
||||
@ -321,8 +324,6 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||
if config_path.is_relative_to(app_config.root_path):
|
||||
config_path = config_path.relative_to(app_config.root_path)
|
||||
return str(config_path)
|
||||
|
||||
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
|
@ -11,11 +11,13 @@ from .base import (
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
)
|
||||
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
|
||||
|
||||
class TextualInversionModel(ModelBase):
|
||||
#model_size: int
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: None
|
||||
@ -65,7 +67,7 @@ class TextualInversionModel(ModelBase):
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
|
||||
return None # diffusers-ti
|
||||
return None # diffusers-ti
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
|
||||
|
@ -22,13 +22,15 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class VaeModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class VaeModel(ModelBase):
|
||||
#vae_class: Type
|
||||
#model_size: int
|
||||
# vae_class: Type
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: VaeModelFormat
|
||||
@ -39,7 +41,7 @@ class VaeModel(ModelBase):
|
||||
|
||||
try:
|
||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||
#config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
except:
|
||||
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
||||
|
||||
@ -95,7 +97,7 @@ class VaeModel(ModelBase):
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||
@ -108,6 +110,7 @@ class VaeModel(ModelBase):
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
# TODO: rework
|
||||
def _convert_vae_ckpt_and_cache(
|
||||
weights_path: str,
|
||||
@ -138,13 +141,14 @@ def _convert_vae_ckpt_and_cache(
|
||||
2.1 - 768
|
||||
"""
|
||||
image_size = 512
|
||||
|
||||
|
||||
# return cached version if it exists
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
|
||||
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
from .stable_diffusion import _select_ckpt_config
|
||||
|
||||
# all sd models use same vae settings
|
||||
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
|
||||
else:
|
||||
@ -152,7 +156,8 @@ def _convert_vae_ckpt_and_cache(
|
||||
|
||||
# this avoids circular import error
|
||||
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
if weights_path.suffix == '.safetensors':
|
||||
|
||||
if weights_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(weights_path, map_location="cpu")
|
||||
@ -161,15 +166,12 @@ def _convert_vae_ckpt_and_cache(
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
config = OmegaConf.load(app_config.root_path/config_file)
|
||||
config = OmegaConf.load(app_config.root_path / config_file)
|
||||
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
checkpoint = checkpoint,
|
||||
vae_config = config,
|
||||
image_size = image_size,
|
||||
)
|
||||
vae_model.save_pretrained(
|
||||
output_path,
|
||||
safe_serialization=is_safetensors_available()
|
||||
checkpoint=checkpoint,
|
||||
vae_config=config,
|
||||
image_size=image_size,
|
||||
)
|
||||
vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available())
|
||||
return output_path
|
||||
|
Reference in New Issue
Block a user