Apply black

Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions
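Black is deterministic, so the hunks below are pure reformatting with no behavioral change: long bracketed collections get a trailing comma and one element per line, string literals are normalized to double quotes, and inline comments get two spaces before the `#`. A minimal sketch of what the formatter does to the first import block, using black's Python API (the 120-character line length is an assumption, inferred from the long lines left intact below):

```python
import black

src = (
    "from .base import (\n"
    "    BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase,\n"
    "    ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings,\n"
    "    ModelNotFoundException, InvalidModelException, DuplicateModelException\n"
    ")\n"
)
# The joined form would exceed the line length, so black explodes the import
# to one name per line and adds the magic trailing comma, as in the hunk below.
print(black.format_str(src, mode=black.Mode(line_length=120)))
```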

View File: models/__init__.py

@@ -3,15 +3,24 @@ from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import (
BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase,
ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings,
ModelNotFoundException, InvalidModelException, DuplicateModelException
)
BaseModelType,
ModelType,
SubModelType,
ModelBase,
ModelConfigBase,
ModelVariantType,
SchedulerPredictionType,
ModelError,
SilenceWarnings,
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
)
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .sdxl import StableDiffusionXLModel
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
MODEL_CLASSES = {
@@ -45,18 +54,19 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
# },
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
model_name: str
base_model: BaseModelType
@@ -72,27 +82,31 @@ for base_model, models in MODEL_CLASSES.items():
# LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
model_type=Literal[model_type.value],
api_wrapper = type(
openapi_cfg_name,
(cfg, OpenAPIModelInfoBase),
dict(
__annotations__=dict(
model_type=Literal[model_type.value],
),
),
))
)
#globals()[openapi_cfg_name] = api_wrapper
# globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
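The loop above stamps out one OpenAPI wrapper class per model config at import time, narrowing `model_type` to a single-value `Literal` so each wrapper carries a fixed discriminator in the generated schema. A minimal sketch of the three-argument `type()` call, with plain classes and hypothetical names standing in for the pydantic models:

```python
from typing import Literal


class InfoBase:
    pass


# type(name, bases, namespace) builds a class at runtime; the namespace's
# __annotations__ become class annotations, which pydantic's metaclass (in
# the real code above) would turn into validated fields.
Wrapper = type(
    "MainExampleConfig",
    (InfoBase,),
    {"__annotations__": {"model_type": Literal["main"]}},
)

assert Wrapper.__name__ == "MainExampleConfig"
assert Wrapper.__annotations__["model_type"] == Literal["main"]
```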
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
if hasattr(inspect,'get_annotations'):
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
@@ -109,7 +123,9 @@ def get_model_config_enums():
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
elif get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
enums.append(type(field.__args__[0]))
elif field is None:
@@ -119,4 +135,3 @@ def get_model_config_enums():
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums
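`get_model_config_enums` must handle `model_format` fields declared either as a bare Enum subclass or as a `Literal[...]` whose arguments are Enum members. A small self-contained check of the second case, mirroring the test above (`get_args` is equivalent to the `__args__` access in the source):

```python
from enum import Enum
from typing import Literal, get_args, get_origin


class Fmt(str, Enum):
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


field = Literal[Fmt.Checkpoint, Fmt.Diffusers]

assert get_origin(field) is Literal
# every Literal argument is both a str and an Enum member...
assert all(isinstance(a, str) and isinstance(a, Enum) for a in get_args(field))
# ...so the declaring Enum class is recoverable from any argument
assert type(get_args(field)[0]) is Fmt
```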

View File: models/base.py

@@ -15,29 +15,35 @@ from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class DuplicateModelException(Exception):
pass
class InvalidModelException(Exception):
pass
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
#Kandinsky2_1 = "kandinsky-2.1"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
@@ -47,23 +53,27 @@ class SubModelType(str, Enum):
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
# MoVQ = "movq"
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"
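All of these enums mix in `str`, so members compare equal to their raw values, which is what lets them round-trip through pydantic models and JSON payloads unchanged. A quick illustration reusing one value from above:

```python
from enum import Enum


class BaseModelType(str, Enum):
    StableDiffusion1 = "sd-1"


# str-mixin members are interchangeable with their plain-string values
assert BaseModelType.StableDiffusion1 == "sd-1"
assert BaseModelType("sd-1") is BaseModelType.StableDiffusion1
assert BaseModelType.StableDiffusion1.value == "sd-1"
```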
class ModelConfigBase(BaseModel):
path: str # or Path
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
@@ -71,13 +81,17 @@ class ModelConfigBase(BaseModel):
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
@@ -86,12 +100,13 @@ class classproperty(Generic[T_co]):
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
raise AttributeError("cannot set attribute")
class ModelBase(metaclass=ABCMeta):
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
# model_path: str
# base_model: BaseModelType
# model_type: ModelType
def __init__(
self,
@@ -110,7 +125,7 @@ class ModelBase(metaclass=ABCMeta):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
@@ -119,7 +134,6 @@ class ModelBase(metaclass=ABCMeta):
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
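`_get_type` resolves a child model's class from the strings stored in `model_index.json` (e.g. `["diffusers", "AutoencoderKL"]`) by indexing `sys.modules` and walking attributes. A stripped-down sketch with a hypothetical `resolve` helper, assuming `diffusers` is installed:

```python
import sys

import diffusers  # must be imported before the sys.modules lookup


def resolve(subtypes: list) -> type:
    # walk e.g. ["diffusers", "AutoencoderKL"] the way _get_type does
    obj = sys.modules[subtypes[0]]
    for name in subtypes[1:]:
        obj = getattr(obj, name)
    return obj


assert resolve(["diffusers", "AutoencoderKL"]) is diffusers.AutoencoderKL
```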
@@ -128,7 +142,7 @@ class ModelBase(metaclass=ABCMeta):
def _get_configs(cls):
with suppress(Exception):
return cls.__configs
configs = dict()
for name in dir(cls):
if name.startswith("__"):
@@ -138,7 +152,7 @@ class ModelBase(metaclass=ABCMeta):
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
if hasattr(inspect,'get_annotations'):
if hasattr(inspect, "get_annotations"):
fields = inspect.get_annotations(value)
else:
fields = value.__annotations__
@@ -151,7 +165,9 @@ class ModelBase(metaclass=ABCMeta):
for model_format in field:
configs[model_format.value] = value
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
elif typing.get_origin(field) is Literal and all(
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
):
for model_format in field.__args__:
configs[model_format.value] = value
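`_get_configs` caches its scan in a name-mangled class attribute: `with suppress(Exception): return cls.__configs` is the cache-hit path, since the first access raises `AttributeError`, which is swallowed. The same memoization pattern in isolation, with a hypothetical `expensive_scan` stand-in:

```python
from contextlib import suppress

calls = 0


def expensive_scan() -> dict:  # stand-in for the annotation scan above
    global calls
    calls += 1
    return {"checkpoint": object}


class Registry:
    @classmethod
    def get(cls):
        with suppress(Exception):
            return cls.__cache  # AttributeError on the first call
        cls.__cache = expensive_scan()
        return cls.__cache


Registry.get()
Registry.get()
assert calls == 1  # the second call hit the cache
```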
@@ -203,8 +219,8 @@ class ModelBase(metaclass=ABCMeta):
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
# child_types: Dict[str, Type]
# child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
@@ -214,7 +230,7 @@ class DiffusersModel(ModelBase):
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
@@ -228,14 +244,12 @@ class DiffusersModel(ModelBase):
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
@@ -245,7 +259,7 @@ class DiffusersModel(ModelBase):
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
@@ -265,8 +279,8 @@ class DiffusersModel(ModelBase):
)
break
except Exception as e:
#print("====ERR LOAD====")
#print(f"{variant}: {e}")
# print("====ERR LOAD====")
# print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -275,15 +289,10 @@ class DiffusersModel(ModelBase):
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
@@ -325,12 +334,12 @@ def calc_model_size_by_fs(
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
@@ -343,9 +352,9 @@ def calc_model_size_by_fs(
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
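When no index file is present, `calc_model_size_by_fs` falls back to summing the on-disk size of whichever weight format it finds first. A condensed sketch of that fallback with a hypothetical helper and a shortened extension list:

```python
import os


def size_on_disk(model_path: str, exts=(".safetensors", ".bin", ".ckpt")) -> int:
    # sum the sizes of weight files directly inside model_path, as the
    # fallback branch above does when there is no index file
    total = 0
    for name in os.listdir(model_path):
        if name.endswith(exts):
            total += os.stat(os.path.join(model_path, name)).st_size
    return total
```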
def calc_model_size_by_data(model) -> int:
@@ -364,12 +373,12 @@ def _calc_pipeline_by_data(pipeline) -> int:
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
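`_calc_model_by_data` sizes a module as parameter bytes plus buffer bytes. A worked check of the arithmetic on a tiny `torch.nn.Linear`:

```python
import torch

lin = torch.nn.Linear(4, 2)  # 4*2 weights + 2 biases = 10 float32 values

mem_params = sum(p.nelement() * p.element_size() for p in lin.parameters())
mem_bufs = sum(b.nelement() * b.element_size() for b in lin.buffers())

assert mem_params == 10 * 4  # 40 bytes at 4 bytes per float32
assert mem_bufs == 0  # Linear registers no buffers
```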
@@ -377,11 +386,15 @@ def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), 'little')
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
@@ -400,6 +413,7 @@ def _fast_safetensors_reader(path: str):
return checkpoint
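The reader above exploits the safetensors layout: an 8-byte little-endian length prefix, then that many bytes of JSON mapping tensor names to dtype, shape, and data offsets, with an optional `__metadata__` entry. A minimal header-only reader as a sketch:

```python
import json
import struct


def read_safetensors_header(path: str) -> dict:
    # 8-byte little-endian header length, then the JSON header itself
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata block
    return header  # tensor name -> {"dtype", "shape", "data_offsets"}
```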
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
@@ -411,25 +425,27 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')
warnings.simplefilter("default")

View File: models/controlnet.py

@@ -18,13 +18,15 @@ from .base import (
)
from invokeai.app.services.config import InvokeAIAppConfig
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
# model_class: Type
# model_size: int
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
@@ -39,7 +41,7 @@ class ControlNetModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
@@ -67,7 +69,7 @@ class ControlNetModel(ModelBase):
raise Exception("There is no child models in controlnet model")
model = None
for variant in ['fp16',None]:
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
@@ -79,7 +81,7 @@ class ControlNetModel(ModelBase):
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@@ -105,29 +107,30 @@ class ControlNetModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path = model_path,
model_config = config.config,
output_path = output_path,
base_model = base_model,
)
else:
return model_path
@classmethod
def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
model_config: ControlNetModel.CheckpointConfig,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path=model_path,
model_config=config.config,
output_path=output_path,
base_model=base_model,
)
else:
return model_path
@classmethod
def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str,
output_path: str,
base_model: BaseModelType,
model_config: ControlNetModel.CheckpointConfig,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
@@ -144,12 +147,13 @@ def _convert_controlnet_ckpt_and_cache(
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
convert_controlnet_to_diffusers(
weights,
output_path,
original_config_file = app_config.root_path / model_config,
image_size = 512,
scan_needed = True,
from_safetensors = weights.suffix == ".safetensors"
original_config_file=app_config.root_path / model_config,
image_size=512,
scan_needed=True,
from_safetensors=weights.suffix == ".safetensors",
)
return output_path

View File: models/lora.py

@@ -12,18 +12,21 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
#model_size: int
# model_size: int
class Config(ModelConfigBase):
model_format: LoRAModelFormat # TODO:
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora

View File: models/sdxl.py

@@ -15,12 +15,13 @@ from .base import (
)
from omegaconf import OmegaConf
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusionXLModel(DiffusersModel):
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
@@ -53,7 +54,7 @@ class StableDiffusionXLModel(DiffusersModel):
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusionXLModelFormat.Diffusers:
@@ -61,7 +62,7 @@ class StableDiffusionXLModel(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
in_channels = unet_config["in_channels"]
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
@@ -81,11 +82,10 @@ class StableDiffusionXLModel(DiffusersModel):
if ckpt_config_path is None:
# TODO: implement picking
pass
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@@ -114,11 +114,12 @@ class StableDiffusionXLModel(DiffusersModel):
# source code changes, we simply translate here
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
use_safetensors=False, # corrupts sdxl models for some reason
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_path

View File: models/stable_diffusion.py

@@ -26,8 +26,8 @@ class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
@@ -38,7 +38,7 @@ class StableDiffusion1Model(DiffusersModel):
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main
@@ -59,7 +59,7 @@ class StableDiffusion1Model(DiffusersModel):
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion1ModelFormat.Diffusers:
@@ -67,7 +67,7 @@ class StableDiffusion1Model(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
in_channels = unet_config["in_channels"]
else:
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
@@ -88,7 +88,6 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@@ -125,16 +124,17 @@ class StableDiffusion1Model(DiffusersModel):
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
)
)
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
@@ -167,7 +167,7 @@ class StableDiffusion2Model(DiffusersModel):
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion2ModelFormat.Diffusers:
@@ -175,7 +175,7 @@ class StableDiffusion2Model(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
in_channels = unet_config["in_channels"]
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
@@ -198,7 +198,6 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@@ -239,17 +238,19 @@ class StableDiffusion2Model(DiffusersModel):
else:
return model_path
# TODO: rework
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
output_path: str,
use_save_model: bool=False,
**kwargs,
version: BaseModelType,
model_config: Union[
StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
output_path: str,
use_save_model: bool = False,
**kwargs,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
@@ -270,13 +271,14 @@ def _convert_ckpt_and_cache(
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from ...util.devices import choose_torch_device, torch_dtype
model_base_to_model_type = {BaseModelType.StableDiffusion1: 'FrozenCLIPEmbedder',
BaseModelType.StableDiffusion2: 'FrozenOpenCLIPEmbedder',
BaseModelType.StableDiffusionXL: 'SDXL',
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
}
logger.info(f'Converting {weights} to diffusers format')
with SilenceWarnings():
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
logger.info(f"Converting {weights} to diffusers format")
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
output_path,
@@ -286,12 +288,13 @@ def _convert_ckpt_and_cache(
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
from_safetensors = weights.suffix == ".safetensors",
precision = torch_dtype(choose_torch_device()),
from_safetensors=weights.suffix == ".safetensors",
precision=torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
@@ -299,7 +302,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
@@ -321,8 +324,6 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None
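`_select_ckpt_config` maps a (base model, variant) pair to a legacy YAML config, and the bare `except: return None` above turns any missing entry into `None`. The same lookup sketched with `.get()` chains (the `v1-inference.yaml` entry for the SD-1 normal variant is an assumption; that line is outside the hunk shown):

```python
ckpt_configs = {
    "sd-1": {
        "normal": "v1-inference.yaml",  # assumed; not visible in this hunk
        "inpaint": "v1-inpainting-inference.yaml",
    },
    "sd-2": {
        "normal": "v2-inference-v.yaml",
    },
}


def select(version: str, variant: str):
    # .get() chains turn any missing key into None, like the
    # try/except fallthrough in _select_ckpt_config
    return ckpt_configs.get(version, {}).get(variant)


assert select("sd-1", "inpaint") == "v1-inpainting-inference.yaml"
assert select("sd-2", "depth") is None
```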

View File: models/textual_inversion.py

@@ -11,11 +11,13 @@ from .base import (
ModelNotFoundException,
InvalidModelException,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase):
#model_size: int
# model_size: int
class Config(ModelConfigBase):
model_format: None
@@ -65,7 +67,7 @@ class TextualInversionModel(ModelBase):
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
return None # diffusers-ti
return None # diffusers-ti
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
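Format detection for embeddings is purely path-based: a directory holding `learned_embeds.bin` is a diffusers-style textual inversion, while single files are accepted by extension. A compact sketch of the same rule with a hypothetical helper name:

```python
import os


def looks_like_textual_inversion(path: str) -> bool:
    if os.path.isdir(path):
        # diffusers-style embedding folder
        return os.path.exists(os.path.join(path, "learned_embeds.bin"))
    # single-file embedding, matched by extension as above
    return path.endswith((".safetensors", ".ckpt", ".pt", ".bin"))
```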

View File: models/vae.py

@@ -22,13 +22,15 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
# vae_class: Type
# model_size: int
class Config(ModelConfigBase):
model_format: VaeModelFormat
@@ -39,7 +41,7 @@ class VaeModel(ModelBase):
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
# config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
@@ -95,7 +97,7 @@ class VaeModel(ModelBase):
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
@@ -108,6 +110,7 @@ class VaeModel(ModelBase):
else:
return model_path
# TODO: rework
def _convert_vae_ckpt_and_cache(
weights_path: str,
@@ -138,13 +141,14 @@ def _convert_vae_ckpt_and_cache(
2.1 - 768
"""
image_size = 512
# return cached version if it exists
if output_path.exists():
return output_path
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
from .stable_diffusion import _select_ckpt_config
# all sd models use same vae settings
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
else:
@@ -152,7 +156,8 @@ def _convert_vae_ckpt_and_cache(
# this avoids circular import error
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_path.suffix == '.safetensors':
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
@@ -161,15 +166,12 @@ def _convert_vae_ckpt_and_cache(
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(app_config.root_path/config_file)
config = OmegaConf.load(app_config.root_path / config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size,
)
vae_model.save_pretrained(
output_path,
safe_serialization=is_safetensors_available()
checkpoint=checkpoint,
vae_config=config,
image_size=image_size,
)
vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available())
return output_path
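The checkpoint load above branches on the file suffix: `.safetensors` goes through `safetensors.torch.load_file`, everything else through `torch.load`, and a nested `state_dict` is unwrapped afterwards. The same logic collected into one hypothetical helper:

```python
from pathlib import Path

import safetensors.torch
import torch


def load_state_dict(weights_path: Path) -> dict:
    if weights_path.suffix == ".safetensors":
        checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
    else:
        checkpoint = torch.load(weights_path, map_location="cpu")
    # some checkpoints nest the weights under a "state_dict" key
    return checkpoint.get("state_dict", checkpoint)
```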