Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit e7db6d8120 (parent a6af7e8824)
Fix ckpt and vae conversion, migrate script, remove sd2-base
@@ -43,10 +43,10 @@ class ModelLoaderOutput(BaseInvocationOutput):
     #fmt: on


-class ModelLoaderInvocation(BaseInvocation):
+class SD1ModelLoaderInvocation(BaseInvocation):
     """Loading submodels of selected model."""

-    type: Literal["model_loader"] = "model_loader"
+    type: Literal["sd1_model_loader"] = "sd1_model_loader"

     model_name: str = Field(default="", description="Model to load")
     # TODO: precision?
@@ -64,7 +64,110 @@ class ModelLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
base_model = BaseModelType.StableDiffusion1_5 # TODO:
|
base_model = BaseModelType.StableDiffusion1 # TODO:
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.Tokenizer,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.TextEncoder,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.UNet,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
return ModelLoaderOutput(
|
||||||
|
unet=UNetField(
|
||||||
|
unet=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
submodel=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
submodel=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
submodel=SubModelType.Tokenizer,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
submodel=SubModelType.TextEncoder,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Pipeline,
|
||||||
|
submodel=SubModelType.Vae,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: optimize(less code copy)
|
||||||
|
class SD2ModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loading submodels of selected model."""
|
||||||
|
|
||||||
|
type: Literal["sd2_model_loader"] = "sd2_model_loader"
|
||||||
|
|
||||||
|
model_name: str = Field(default="", description="Model to load")
|
||||||
|
# TODO: precision?
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["model", "loader"],
|
||||||
|
"type_hints": {
|
||||||
|
"model_name": "model" # TODO: rename to model_name?
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
|
base_model = BaseModelType.StableDiffusion2 # TODO:
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
|
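A hedged sketch (not part of this diff) of how a downstream node could resolve one of the ModelInfo entries assembled above. The keyword names mirror the model_manager.model_exists(...) call in this hunk; get_model is assumed to take the same keywords plus submodel, so treat the exact signature as an assumption.

unet_info = context.services.model_manager.get_model(
    model_name=self.model_name,
    base_model=BaseModelType.StableDiffusion1,
    model_type=ModelType.Pipeline,
    submodel=SubModelType.UNet,  # any SubModelType referenced above resolves the same way
)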
@@ -28,8 +28,9 @@ from safetensors.torch import load_file
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig

-from .model_manager import ModelManager, SDLegacyType
+from .model_manager import ModelManager
 from .model_cache import ModelCache
+from .models import SchedulerPredictionType, BaseModelType, ModelVariantType

 try:
     from omegaconf import OmegaConf
@@ -58,10 +59,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
     LDMBertConfig,
     LDMBertModel,
 )
-from diffusers.pipelines.paint_by_example import (
-    PaintByExampleImageEncoder,
-    PaintByExamplePipeline,
-)
 from diffusers.pipelines.stable_diffusion.safety_checker import (
     StableDiffusionSafetyChecker,
 )
@@ -911,8 +908,10 @@ textenc_pattern = re.compile("|".join(protected.keys()))


 def convert_open_clip_checkpoint(checkpoint):
-    cache_dir = InvokeAIAppConfig.get_config().cache_dir
-    text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'stable-diffusion-2-text_encoder')
+    text_model = CLIPTextModel.from_pretrained(
+        MODEL_ROOT / 'stable-diffusion-2-clip',
+        subfolder='text_encoder',
+    )

     keys = list(checkpoint.keys())

@@ -1002,20 +1001,15 @@ def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size:
|
|||||||
|
|
||||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
original_config_file: str = None,
|
model_version: BaseModelType,
|
||||||
num_in_channels: int = None,
|
model_variant: ModelVariantType,
|
||||||
scheduler_type: str = "pndm",
|
original_config_file: str,
|
||||||
pipeline_type: str = None,
|
|
||||||
image_size: int = None,
|
|
||||||
prediction_type: str = None,
|
|
||||||
extract_ema: bool = True,
|
extract_ema: bool = True,
|
||||||
upcast_attn: bool = False,
|
|
||||||
vae: AutoencoderKL = None,
|
|
||||||
vae_path: str = None,
|
|
||||||
precision: torch.dtype = torch.float32,
|
precision: torch.dtype = torch.float32,
|
||||||
return_generator_pipeline: bool = False,
|
upcast_attention: bool = False,
|
||||||
scan_needed:bool=True,
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
|
||||||
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
|
scan_needed: bool = True,
|
||||||
|
) -> StableDiffusionPipeline:
|
||||||
"""
|
"""
|
||||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||||
config file.
|
config file.
|
||||||
@@ -1027,148 +1021,68 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
:param checkpoint_path: Path to `.ckpt` file.
|
:param checkpoint_path: Path to `.ckpt` file.
|
||||||
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
|
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
|
||||||
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
|
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||||
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
|
|
||||||
Base. Use 768 for Stable Diffusion v2.
|
|
||||||
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
|
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
|
||||||
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
|
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
|
||||||
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
|
|
||||||
inferred.
|
|
||||||
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
|
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
|
||||||
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
|
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
|
||||||
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
|
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
|
||||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
|
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
|
||||||
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
|
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
|
||||||
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
|
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
|
||||||
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
||||||
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
||||||
running stable diffusion 2.1.
|
running stable diffusion 2.1.
|
||||||
:param vae: A diffusers VAE to load into the pipeline.
|
|
||||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
|
||||||
"""
|
"""
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
cache_dir = config.cache_dir
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
dlogging.set_verbosity_error()
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
if Path(checkpoint_path).suffix == '.ckpt':
|
if str(checkpoint_path).endswith(".safetensors"):
|
||||||
if scan_needed:
|
|
||||||
ModelCache.scan_model(checkpoint_path,checkpoint_path)
|
|
||||||
checkpoint = torch.load(checkpoint_path)
|
|
||||||
else:
|
|
||||||
checkpoint = load_file(checkpoint_path)
|
checkpoint = load_file(checkpoint_path)
|
||||||
|
|
||||||
pipeline_class = (
|
|
||||||
StableDiffusionGeneratorPipeline
|
|
||||||
if return_generator_pipeline
|
|
||||||
else StableDiffusionPipeline
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sometimes models don't have the global_step item
|
|
||||||
if "global_step" in checkpoint:
|
|
||||||
global_step = checkpoint["global_step"]
|
|
||||||
else:
|
else:
|
||||||
logger.debug("global_step key not found in model")
|
if scan_needed:
|
||||||
global_step = None
|
ModelCache.scan_model(checkpoint_path, checkpoint_path)
|
||||||
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
|
||||||
# sometimes there is a state_dict key and sometimes not
|
# sometimes there is a state_dict key and sometimes not
|
||||||
if "state_dict" in checkpoint:
|
if "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
upcast_attention = False
|
|
||||||
if original_config_file is None:
|
|
||||||
model_type = ModelManager.probe_model_type(checkpoint)
|
|
||||||
|
|
||||||
if model_type == SDLegacyType.V2_v:
|
|
||||||
original_config_file = (
|
|
||||||
config.legacy_conf_path / "v2-inference-v.yaml"
|
|
||||||
)
|
|
||||||
if global_step == 110000:
|
|
||||||
# v2.1 needs to upcast attention
|
|
||||||
upcast_attention = True
|
|
||||||
elif model_type == SDLegacyType.V2_e:
|
|
||||||
original_config_file = (
|
|
||||||
config.legacy_conf_path / "v2-inference.yaml"
|
|
||||||
)
|
|
||||||
elif model_type == SDLegacyType.V1_INPAINT:
|
|
||||||
original_config_file = (
|
|
||||||
config.legacy_conf_path / "v1-inpainting-inference.yaml"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == SDLegacyType.V1:
|
|
||||||
original_config_file = (
|
|
||||||
config.legacy_conf_path / "v1-inference.yaml"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise Exception("Unknown checkpoint type")
|
|
||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
|
||||||
if num_in_channels is not None:
|
if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
|
||||||
original_config["model"]["params"]["unet_config"]["params"][
|
image_size = 768
|
||||||
"in_channels"
|
|
||||||
] = num_in_channels
|
|
||||||
|
|
||||||
if (
|
|
||||||
"parameterization" in original_config["model"]["params"]
|
|
||||||
and original_config["model"]["params"]["parameterization"] == "v"
|
|
||||||
):
|
|
||||||
if prediction_type is None:
|
|
||||||
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
|
|
||||||
# as it relies on a brittle global step parameter here
|
|
||||||
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
|
|
||||||
if image_size is None:
|
|
||||||
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
|
|
||||||
# as it relies on a brittle global step parameter here
|
|
||||||
image_size = 512 if global_step == 875000 else 768
|
|
||||||
else:
|
else:
|
||||||
if prediction_type is None:
|
image_size = 512
|
||||||
prediction_type = "epsilon"
|
|
||||||
if image_size is None:
|
#
|
||||||
image_size = 512
|
# convert scheduler
|
||||||
|
#
|
||||||
|
|
||||||
num_train_timesteps = original_config.model.params.timesteps
|
num_train_timesteps = original_config.model.params.timesteps
|
||||||
beta_start = original_config.model.params.linear_start
|
beta_start = original_config.model.params.linear_start
|
||||||
beta_end = original_config.model.params.linear_end
|
beta_end = original_config.model.params.linear_end
|
||||||
|
|
||||||
scheduler = DDIMScheduler(
|
scheduler = PNDMScheduler(
|
||||||
beta_end=beta_end,
|
beta_end=beta_end,
|
||||||
beta_schedule="scaled_linear",
|
beta_schedule="scaled_linear",
|
||||||
beta_start=beta_start,
|
beta_start=beta_start,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
steps_offset=1,
|
steps_offset=1,
|
||||||
clip_sample=False,
|
|
||||||
set_alpha_to_one=False,
|
set_alpha_to_one=False,
|
||||||
prediction_type=prediction_type,
|
prediction_type=prediction_type,
|
||||||
|
skip_prk_steps=True
|
||||||
)
|
)
|
||||||
# make sure scheduler works correctly with DDIM
|
# make sure scheduler works correctly with DDIM
|
||||||
scheduler.register_to_config(clip_sample=False)
|
scheduler.register_to_config(clip_sample=False)
|
||||||
|
|
||||||
if scheduler_type == "pndm":
|
#
|
||||||
config = dict(scheduler.config)
|
# convert unet
|
||||||
config["skip_prk_steps"] = True
|
#
|
||||||
scheduler = PNDMScheduler.from_config(config)
|
|
||||||
elif scheduler_type == "lms":
|
|
||||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
|
||||||
elif scheduler_type == "heun":
|
|
||||||
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
|
||||||
elif scheduler_type == "euler":
|
|
||||||
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
|
||||||
elif scheduler_type == "euler-ancestral":
|
|
||||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
|
||||||
elif scheduler_type == "dpm":
|
|
||||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
|
||||||
elif scheduler_type == 'unipc':
|
|
||||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
|
||||||
elif scheduler_type == "ddim":
|
|
||||||
scheduler = scheduler
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
|
||||||
|
|
||||||
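With the scheduler_type argument and the selection chain below removed, the converter now always builds a PNDMScheduler; a caller that wants a different sampler can swap it on the converted pipeline afterwards. A minimal sketch using the standard diffusers pattern (not part of this commit):

from diffusers import DDIMScheduler, StableDiffusionPipeline

def use_ddim(pipe: StableDiffusionPipeline) -> StableDiffusionPipeline:
    # replace the fixed PNDM scheduler on an already-converted pipeline
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    return pipe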
# Convert the UNet2DConditionModel model.
|
|
||||||
unet_config = create_unet_diffusers_config(
|
unet_config = create_unet_diffusers_config(
|
||||||
original_config, image_size=image_size
|
original_config, image_size=image_size
|
||||||
)
|
)
|
||||||
@@ -1181,35 +1095,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt
|
|||||||
|
|
||||||
unet.load_state_dict(converted_unet_checkpoint)
|
unet.load_state_dict(converted_unet_checkpoint)
|
||||||
|
|
||||||
# If a replacement VAE path was specified, we'll incorporate that into
|
#
|
||||||
# the checkpoint model and then convert it
|
# convert vae
|
||||||
if vae_path:
|
#
|
||||||
logger.debug(f"Converting VAE {vae_path}")
|
|
||||||
replace_checkpoint_vae(checkpoint,vae_path)
|
|
||||||
# otherwise we use the original VAE, provided that
|
|
||||||
# an externally loaded diffusers VAE was not passed
|
|
||||||
elif not vae:
|
|
||||||
logger.debug("Using checkpoint model's original VAE")
|
|
||||||
|
|
||||||
if vae:
|
vae = convert_ldm_vae_to_diffusers(
|
||||||
logger.debug("Using replacement diffusers VAE")
|
checkpoint,
|
||||||
else: # convert the original or replacement VAE
|
original_config,
|
||||||
vae = convert_ldm_vae_to_diffusers(
|
image_size,
|
||||||
checkpoint,
|
)
|
||||||
original_config,
|
|
||||||
image_size)
|
|
||||||
|
|
||||||
# Convert the text model.
|
# Convert the text model.
|
||||||
model_type = pipeline_type
|
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||||
if model_type is None:
|
|
||||||
model_type = original_config.model.params.cond_stage_config.target.split(
|
|
||||||
"."
|
|
||||||
)[-1]
|
|
||||||
|
|
||||||
if model_type == "FrozenOpenCLIPEmbedder":
|
if model_type == "FrozenOpenCLIPEmbedder":
|
||||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'stable-diffusion-2-tokenizer')
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
pipe = pipeline_class(
|
MODEL_ROOT / 'stable-diffusion-2-clip',
|
||||||
|
subfolder='tokenizer',
|
||||||
|
)
|
||||||
|
pipe = StableDiffusionPipeline(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@@ -1219,20 +1123,22 @@ def load_pipeline_from_original_stable_diffusion_ckpt
|
|||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
|
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
|
||||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker-extractor')
|
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||||
pipe = pipeline_class(
|
pipe = StableDiffusionPipeline(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
safety_checker=None if return_generator_pipeline else safety_checker.to(precision),
|
safety_checker=safety_checker.to(precision),
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
text_config = create_ldm_bert_config(original_config)
|
text_config = create_ldm_bert_config(original_config)
|
||||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||||
|
@@ -20,15 +20,12 @@ import gc
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import hashlib
|
import hashlib
|
||||||
import warnings
|
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Union, types, Optional, Type, Any
|
from typing import Dict, Union, types, Optional, Type, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import logging as diffusers_logging
|
|
||||||
from transformers import logging as transformers_logging
|
|
||||||
import logging
|
import logging
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
@@ -382,21 +379,6 @@ class ModelCache(object):
|
|||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
class SilenceWarnings(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
|
||||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
transformers_logging.set_verbosity_error()
|
|
||||||
diffusers_logging.set_verbosity_error()
|
|
||||||
warnings.simplefilter('ignore')
|
|
||||||
|
|
||||||
def __exit__(self,type,value,traceback):
|
|
||||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
|
||||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
|
||||||
warnings.simplefilter('default')
|
|
||||||
|
|
||||||
class VRAMUsage(object):
|
class VRAMUsage(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.vram = None
|
self.vram = None
|
||||||
|
@@ -148,6 +148,7 @@ into the model when downloaded or converted.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import hashlib
|
||||||
import textwrap
|
import textwrap
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -166,7 +167,7 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
||||||
from .model_cache import ModelCache, ModelLocker
|
from .model_cache import ModelCache, ModelLocker
|
||||||
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
|
from .models import BaseModelType, ModelType, SubModelType, ModelError, MODEL_CLASSES
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
# The config file version doesn't have to start at release version, but it will help
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
@@ -299,7 +300,7 @@ class ModelManager(object):
|
|||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
self.models[model_key] = model_class.build_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
@@ -406,9 +407,7 @@ class ModelManager(object):
|
|||||||
path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
|
path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
|
||||||
if False: # model_path = next(find_by_mask(path_mask)):
|
if False: # model_path = next(find_by_mask(path_mask)):
|
||||||
model_path = None # TODO:
|
model_path = None # TODO:
|
||||||
model_config = model_class.build_config(
|
model_config = model_class.probe_config(model_path)
|
||||||
path=model_path,
|
|
||||||
)
|
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise Exception(f"Model not found - {model_key}")
|
||||||
@@ -437,16 +436,22 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# vae/movq override
|
# vae/movq override
|
||||||
# TODO:
|
# TODO:
|
||||||
if submodel_type is not None and submodel_type in model_config:
|
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||||
model_path = model_config[submodel_type]
|
override_path = getattr(model_config, submodel_type)
|
||||||
model_type = submodel_type
|
if override_path:
|
||||||
submodel_type = None
|
model_path = override_path
|
||||||
|
model_type = submodel_type
|
||||||
|
submodel_type = None
|
||||||
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
|
||||||
dst_convert_path = None # TODO:
|
# TODO: path
|
||||||
|
# TODO: is it accurate to use path as id
|
||||||
|
dst_convert_path = self.globals.models_dir / ".cache" / hashlib.md5(model_path.encode()).hexdigest()
|
||||||
model_path = model_class.convert_if_required(
|
model_path = model_class.convert_if_required(
|
||||||
model_path,
|
base_model=base_model,
|
||||||
dst_convert_path,
|
model_path=model_path,
|
||||||
model_config,
|
output_path=dst_convert_path,
|
||||||
|
config=model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_context = self.cache.get_model(
|
model_context = self.cache.get_model(
|
||||||
@ -457,14 +462,14 @@ class ModelManager(object):
|
|||||||
submodel=submodel_type,
|
submodel=submodel_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
hash = "<NO_HASH>" # TODO:
|
model_hash = "<NO_HASH>" # TODO:
|
||||||
|
|
||||||
return ModelInfo(
|
return ModelInfo(
|
||||||
context = model_context,
|
context = model_context,
|
||||||
name = model_name,
|
name = model_name,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
type = submodel_type or model_type,
|
type = submodel_type or model_type,
|
||||||
hash = hash,
|
hash = model_hash,
|
||||||
location = model_path, # TODO:
|
location = model_path, # TODO:
|
||||||
precision = self.cache.precision,
|
precision = self.cache.precision,
|
||||||
_cache = self.cache,
|
_cache = self.cache,
|
||||||
@@ -633,7 +638,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_config = model_class.build_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@@ -749,13 +754,15 @@ class ModelManager(object):
|
|||||||
|
|
||||||
for model_key, model_config in list(self.models.items()):
|
for model_key, model_config in list(self.models.items()):
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
if not os.path.exists(model_config.path):
|
model_path = str(self.globals.root / model_config.path)
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
model_config.error = ModelError.NotFound
|
model_config.error = ModelError.NotFound
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
else:
|
else:
|
||||||
loaded_files.add(model_config.path)
|
loaded_files.add(model_path)
|
||||||
|
|
||||||
for base_model in BaseModelType:
|
for base_model in BaseModelType:
|
||||||
for model_type in ModelType:
|
for model_type in ModelType:
|
||||||
@@ -774,7 +781,5 @@ class ModelManager(object):
|
|||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
raise Exception(f"Model with key {model_key} added twice")
|
raise Exception(f"Model with key {model_key} added twice")
|
||||||
|
|
||||||
model_config: ModelConfigBase = model_class.build_config(
|
model_config: ModelConfigBase = model_class.probe_config(model_path)
|
||||||
path=model_path,
|
|
||||||
)
|
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
|
@@ -12,8 +12,7 @@ from typing import Callable, Literal, Union, Dict
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType
|
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
|
||||||
from .model_cache import SilenceWarnings
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelVariantInfo(object):
|
class ModelVariantInfo(object):
|
||||||
@@ -21,6 +20,7 @@ class ModelVariantInfo(object):
|
|||||||
base_type: BaseModelType
|
base_type: BaseModelType
|
||||||
variant_type: ModelVariantType
|
variant_type: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
|
upcast_attention: bool
|
||||||
format: Literal['folder','checkpoint']
|
format: Literal['folder','checkpoint']
|
||||||
image_size: int
|
image_size: int
|
||||||
|
|
||||||
@@ -95,10 +95,12 @@ class ModelProbe(object):
|
|||||||
base_type = base_type,
|
base_type = base_type,
|
||||||
variant_type = variant_type,
|
variant_type = variant_type,
|
||||||
prediction_type = prediction_type,
|
prediction_type = prediction_type,
|
||||||
|
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
|
||||||
|
and prediction_type==SchedulerPredictionType.VPrediction),
|
||||||
format = format,
|
format = format,
|
||||||
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
|
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
|
||||||
and prediction_type==SchedulerPredictionType.VPrediction \
|
and prediction_type==SchedulerPredictionType.VPrediction \
|
||||||
) else 512
|
) else 512,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
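The probe rule added above reduces to a few lines; a hedged sketch, assuming the BaseModelType and SchedulerPredictionType enums this module already imports:

def probe_size_and_upcast(base_type, prediction_type):
    # SD2 with v-prediction trains at 768px and needs upcast_attention; everything else is 512px
    sd2_v = (base_type == BaseModelType.StableDiffusion2
             and prediction_type == SchedulerPredictionType.VPrediction)
    return (768 if sd2_v else 512), sd2_v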
@@ -1,5 +1,5 @@
-from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType
+from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
-from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
+from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
 from .vae import VaeModel
 from .lora import LoRAModel
 #from .controlnet import ControlNetModel # TODO:
@@ -11,7 +11,7 @@ class ControlNetModel:

 MODEL_CLASSES = {
     BaseModelType.StableDiffusion1: {
-        ModelType.Pipeline: StableDiffusion15Model,
+        ModelType.Pipeline: StableDiffusion1Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
         ModelType.ControlNet: ControlNetModel,
@@ -10,13 +10,9 @@ from pydantic import BaseModel, Field
|
|||||||
from typing import List, Dict, Optional, Type, Literal
|
from typing import List, Dict, Optional, Type, Literal
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
#StableDiffusion1_5 = "stable_diffusion_1_5"
|
|
||||||
#StableDiffusion2 = "stable_diffusion_2"
|
|
||||||
#StableDiffusion2Base = "stable_diffusion_2_base"
|
|
||||||
# TODO: maybe then add sample size(512/768)?
|
|
||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
StableDiffusion2 = "sd-2" # 768 pixels; this will have v-prediction parameterization
|
StableDiffusion2 = "sd-2"
|
||||||
#Kandinsky2_1 = "kandinsky_2_1"
|
#Kandinsky2_1 = "kandinsky-2.1"
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
Pipeline = "pipeline"
|
Pipeline = "pipeline"
|
||||||
@@ -107,7 +103,7 @@ class ModelBase:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
fields = inspect.get_annotations(value)
|
fields = inspect.get_annotations(value)
|
||||||
if "format" not in fields or typing.get_origin(fields["format"]) != Literal:
|
if "format" not in fields:
|
||||||
raise Exception("Invalid config definition - format field not found")
|
raise Exception("Invalid config definition - format field not found")
|
||||||
|
|
||||||
format_type = typing.get_origin(fields["format"])
|
format_type = typing.get_origin(fields["format"])
|
||||||
@@ -125,13 +121,20 @@ class ModelBase:
|
|||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_config(cls, **kwargs):
|
def create_config(cls, **kwargs):
|
||||||
if "format" not in kwargs:
|
if "format" not in kwargs:
|
||||||
kwargs["format"] = cls.detect_format(kwargs["path"])
|
raise Exception("Field 'format' not found in model config")
|
||||||
|
|
||||||
configs = cls._get_configs()
|
configs = cls._get_configs()
|
||||||
return configs[kwargs["format"]](**kwargs)
|
return configs[kwargs["format"]](**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def probe_config(cls, path: str, **kwargs):
|
||||||
|
return cls.create_config(
|
||||||
|
path=path,
|
||||||
|
format=cls.detect_format(path),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str) -> str:
|
def detect_format(cls, path: str) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
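A minimal usage sketch of the create_config / probe_config split added here; SomeModel stands in for any concrete ModelBase subclass and the paths are illustrative:

config = SomeModel.create_config(path="/models/foo", format="checkpoint")  # explicit format required
probed = SomeModel.probe_config("/models/foo")  # format filled in via detect_format(path)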
@@ -304,3 +307,62 @@ def _calc_model_by_data(model) -> int:
|
|||||||
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
||||||
mem = mem_params + mem_bufs # in bytes
|
mem = mem_params + mem_bufs # in bytes
|
||||||
return mem
|
return mem
|
||||||
|
|
||||||
|
|
||||||
|
def _fast_safetensors_reader(path: str):
|
||||||
|
checkpoint = dict()
|
||||||
|
device = torch.device("meta")
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
definition_len = int.from_bytes(f.read(8), 'little')
|
||||||
|
definition_json = f.read(definition_len)
|
||||||
|
definition = json.loads(definition_json)
|
||||||
|
|
||||||
|
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
|
||||||
|
raise Exception("Supported only pytorch safetensors files")
|
||||||
|
definition.pop("__metadata__", None)
|
||||||
|
|
||||||
|
for key, info in definition.items():
|
||||||
|
dtype = {
|
||||||
|
"I8": torch.int8,
|
||||||
|
"I16": torch.int16,
|
||||||
|
"I32": torch.int32,
|
||||||
|
"I64": torch.int64,
|
||||||
|
"F16": torch.float16,
|
||||||
|
"F32": torch.float32,
|
||||||
|
"F64": torch.float64,
|
||||||
|
}[info["dtype"]]
|
||||||
|
|
||||||
|
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def read_checkpoint_meta(path: str):
|
||||||
|
if path.endswith(".safetensors"):
|
||||||
|
try:
|
||||||
|
checkpoint = _fast_safetensors_reader(path)
|
||||||
|
except:
|
||||||
|
# TODO: create issue for support "meta"?
|
||||||
|
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
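For reference, a self-contained sketch (not part of the commit) of the header layout _fast_safetensors_reader relies on: the first 8 bytes are a little-endian length, followed by a JSON header describing each tensor's dtype and shape, followed by the raw tensor data.

import io, json, struct

header = {
    "__metadata__": {"format": "pt"},
    "weight": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, 24]},
}
header_bytes = json.dumps(header).encode("utf-8")
blob = struct.pack("<Q", len(header_bytes)) + header_bytes + b"\x00" * 24  # fake tensor data

f = io.BytesIO(blob)
definition_len = int.from_bytes(f.read(8), "little")
definition = json.loads(f.read(definition_len))
definition.pop("__metadata__", None)
print({key: info["shape"] for key, info in definition.items()})  # {'weight': [2, 3]}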
|
import warnings
|
||||||
|
from diffusers import logging as diffusers_logging
|
||||||
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
|
class SilenceWarnings(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||||
|
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
transformers_logging.set_verbosity_error()
|
||||||
|
diffusers_logging.set_verbosity_error()
|
||||||
|
warnings.simplefilter('ignore')
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||||
|
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||||
|
warnings.simplefilter('default')
|
||||||
|
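SilenceWarnings, moved here from model_cache.py, is used as a context manager by the conversion code later in this commit; a minimal usage sketch (the conversion call is a hypothetical stand-in):

with SilenceWarnings():
    # transformers/diffusers logging and Python warnings are suppressed inside the block
    run_noisy_conversion()  # hypothetical call
# previous verbosity levels and warning filters are restored on exit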
@@ -54,8 +54,14 @@ class LoRAModel(ModelBase):
|
|||||||
else:
|
else:
|
||||||
return "lycoris"
|
return "lycoris"
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
if cls.detect_format(model_path) == "diffusers":
|
if cls.detect_format(model_path) == "diffusers":
|
||||||
# TODO: add diffusers lora when it stabilizes a bit
|
# TODO: add diffusers lora when it stabilizes a bit
|
||||||
raise NotImplementedError("Diffusers lora not supported")
|
raise NotImplementedError("Diffusers lora not supported")
|
||||||
|
@@ -3,7 +3,8 @@ import json
|
|||||||
import torch
|
import torch
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing import Literal, Optional
|
from pathlib import Path
|
||||||
|
from typing import Literal, Optional, Union
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
@@ -12,13 +13,16 @@ from .base import (
|
|||||||
SubModelType,
|
SubModelType,
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
DiffusersModel,
|
DiffusersModel,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
SilenceWarnings,
|
||||||
|
read_checkpoint_meta,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
# TODO: how to name properly
|
|
||||||
class StableDiffusion15Model(DiffusersModel):
|
|
||||||
|
|
||||||
# TODO: str -> Path?
|
class StableDiffusion1Model(DiffusersModel):
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
@@ -32,94 +36,56 @@ class StableDiffusion15Model(DiffusersModel):
|
|||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert base_model == BaseModelType.StableDiffusion1_5
|
assert base_model == BaseModelType.StableDiffusion1
|
||||||
assert model_type == ModelType.Pipeline
|
assert model_type == ModelType.Pipeline
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
base_model=BaseModelType.StableDiffusion1_5,
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=ModelType.Pipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _fast_safetensors_reader(path: str):
|
|
||||||
checkpoint = dict()
|
|
||||||
device = torch.device("meta")
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
definition_len = int.from_bytes(f.read(8), 'little')
|
|
||||||
definition_json = f.read(definition_len)
|
|
||||||
definition = json.loads(definition_json)
|
|
||||||
|
|
||||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
|
|
||||||
raise Exception("Supported only pytorch safetensors files")
|
|
||||||
definition.pop("__metadata__", None)
|
|
||||||
|
|
||||||
for key, info in definition.items():
|
|
||||||
dtype = {
|
|
||||||
"I8": torch.int8,
|
|
||||||
"I16": torch.int16,
|
|
||||||
"I32": torch.int32,
|
|
||||||
"I64": torch.int64,
|
|
||||||
"F16": torch.float16,
|
|
||||||
"F32": torch.float32,
|
|
||||||
"F64": torch.float64,
|
|
||||||
}[info["dtype"]]
|
|
||||||
|
|
||||||
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
|
||||||
|
|
||||||
return checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read_checkpoint_meta(cls, path: str):
|
def probe_config(cls, path: str, **kwargs):
|
||||||
if path.endswith(".safetensors"):
|
model_format = cls.detect_format(path)
|
||||||
try:
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
checkpoint = cls._fast_safetensors_reader(path)
|
if model_format == "checkpoint":
|
||||||
except:
|
if ckpt_config_path:
|
||||||
checkpoint = safetensors.torch.load_file(path, device="cpu") # TODO: create issue for support "meta"?
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
checkpoint = read_checkpoint_meta(path)
|
||||||
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
|
elif model_format == "diffusers":
|
||||||
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
|
if os.path.exists(unet_config_path):
|
||||||
|
with open(unet_config_path, "r") as f:
|
||||||
|
unet_config = json.loads(f.read())
|
||||||
|
in_channels = unet_config['in_channels']
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
|
||||||
return checkpoint
|
|
||||||
|
|
||||||
@classmethod
|
if in_channels == 9:
|
||||||
def build_config(cls, **kwargs):
|
variant = ModelVariantType.Inpaint
|
||||||
if "format" not in kwargs:
|
elif in_channels == 4:
|
||||||
kwargs["format"] = cls.detect_format(kwargs["path"])
|
variant = ModelVariantType.Normal
|
||||||
|
else:
|
||||||
if "variant" not in kwargs:
|
raise Exception("Unkown stable diffusion 1.* model format")
|
||||||
if kwargs["format"] == "checkpoint":
|
|
||||||
if "config" in kwargs:
|
|
||||||
ckpt_config = OmegaConf.load(kwargs["config"])
|
|
||||||
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
checkpoint = cls.read_checkpoint_meta(kwargs["path"])
|
|
||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
||||||
|
|
||||||
elif kwargs["format"] == "diffusers":
|
|
||||||
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
|
|
||||||
if os.path.exists(unet_config_path):
|
|
||||||
with open(unet_config_path, "r") as f:
|
|
||||||
unet_config = json.loads(f.read())
|
|
||||||
in_channels = unet_config['in_channels']
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")
|
|
||||||
|
|
||||||
if in_channels == 9:
|
|
||||||
kwargs["variant"] = ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 5:
|
|
||||||
kwargs["variant"] = ModelVariantType.Depth
|
|
||||||
elif in_channels == 4:
|
|
||||||
kwargs["variant"] = ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise Exception("Unkown stable diffusion model format")
|
|
||||||
|
|
||||||
|
|
||||||
return super().build_config(**kwargs)
|
return cls.create_config(
|
||||||
|
path=path,
|
||||||
|
format=model_format,
|
||||||
|
|
||||||
|
config=ckpt_config_path,
|
||||||
|
variant=variant,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save_to_config(cls) -> bool:
|
def save_to_config(cls) -> bool:
|
||||||
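The variant detection added to probe_config boils down to reading in_channels, either from the checkpoint's first input block or from unet/config.json, and mapping it to a variant; a condensed sketch of that mapping for SD 1.x (SD 2.x additionally maps 5 channels to Depth):

def variant_from_in_channels(in_channels: int) -> ModelVariantType:
    if in_channels == 9:
        return ModelVariantType.Inpaint
    if in_channels == 4:
        return ModelVariantType.Normal
    raise Exception(f"Unknown stable diffusion 1.* model format: {in_channels} input channels")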
@@ -133,128 +99,215 @@ class StableDiffusion15Model(DiffusersModel):
|
|||||||
return "checkpoint"
|
return "checkpoint"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
|
assert model_path == config.path
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion1_5,
|
version=BaseModelType.StableDiffusion1,
|
||||||
config=config.dict(),
|
model_config=config,
|
||||||
in_path=model_path,
|
output_path=output_path,
|
||||||
out_path=dst_cache_path,
|
|
||||||
) # TODO: args
|
) # TODO: args
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
# all same
|
|
||||||
class StableDiffusion2BaseModel(StableDiffusion15Model):
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
# skip StableDiffusion15Model __init__
|
|
||||||
assert base_model == BaseModelType.StableDiffusion2Base
|
|
||||||
assert model_type == ModelType.Pipeline
|
|
||||||
super(StableDiffusion15Model, self).__init__(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=BaseModelType.StableDiffusion2Base,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
class StableDiffusion2Model(DiffusersModel):
|
||||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
|
||||||
return _convert_ckpt_and_cache(
|
|
||||||
version=BaseModelType.StableDiffusion2Base,
|
|
||||||
config=config.dict(),
|
|
||||||
in_path=model_path,
|
|
||||||
out_path=dst_cache_path,
|
|
||||||
) # TODO: args
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
class StableDiffusion2Model(StableDiffusion15Model):
|
# TODO: check that configs are overwritten properly
|
||||||
|
|
||||||
# TODO: str -> Path?
|
|
||||||
# TODO: check that configs overwriten
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
attention_upscale: bool = Field(True)
|
prediction_type: SchedulerPredictionType
|
||||||
|
upcast_attention: bool
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
attention_upscale: bool = Field(True)
|
prediction_type: SchedulerPredictionType
|
||||||
|
upcast_attention: bool
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
# skip StableDiffusion15Model __init__
|
|
||||||
assert base_model == BaseModelType.StableDiffusion2
|
assert base_model == BaseModelType.StableDiffusion2
|
||||||
assert model_type == ModelType.Pipeline
|
assert model_type == ModelType.Pipeline
|
||||||
# skip StableDiffusion15Model __init__
|
super().__init__(
|
||||||
super(StableDiffusion15Model, self).__init__(
|
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
base_model=BaseModelType.StableDiffusion2,
|
base_model=BaseModelType.StableDiffusion2,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=ModelType.Pipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
|
def probe_config(cls, path: str, **kwargs):
|
||||||
|
model_format = cls.detect_format(path)
|
||||||
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
|
if model_format == "checkpoint":
|
||||||
|
if ckpt_config_path:
|
||||||
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
checkpoint = read_checkpoint_meta(path)
|
||||||
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
|
elif model_format == "diffusers":
|
||||||
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
|
if os.path.exists(unet_config_path):
|
||||||
|
with open(unet_config_path, "r") as f:
|
||||||
|
unet_config = json.loads(f.read())
|
||||||
|
in_channels = unet_config['in_channels']
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
|
||||||
|
|
||||||
|
if in_channels == 9:
|
||||||
|
variant = ModelVariantType.Inpaint
|
||||||
|
elif in_channels == 5:
|
||||||
|
variant = ModelVariantType.Depth
|
||||||
|
elif in_channels == 4:
|
||||||
|
variant = ModelVariantType.Normal
|
||||||
|
else:
|
||||||
|
raise Exception("Unkown stable diffusion 2.* model format")
|
||||||
|
|
||||||
|
if variant == ModelVariantType.Normal:
|
||||||
|
prediction_type = SchedulerPredictionType.VPrediction
|
||||||
|
upcast_attention = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
prediction_type = SchedulerPredictionType.Epsilon
|
||||||
|
upcast_attention = False
|
||||||
|
|
||||||
|
return cls.create_config(
|
||||||
|
path=path,
|
||||||
|
format=model_format,
|
||||||
|
|
||||||
|
config=ckpt_config_path,
|
||||||
|
variant=variant,
|
||||||
|
prediction_type=prediction_type,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, model_path: str):
|
||||||
|
if os.path.isdir(model_path):
|
||||||
|
return "diffusers"
|
||||||
|
else:
|
||||||
|
return "checkpoint"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
|
assert model_path == config.path
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion2,
|
version=BaseModelType.StableDiffusion2,
|
||||||
config=config.dict(),
|
model_config=config,
|
||||||
in_path=model_path,
|
output_path=output_path,
|
||||||
out_path=dst_cache_path,
|
|
||||||
) # TODO: args
|
) # TODO: args
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||||
|
ckpt_configs = {
|
||||||
|
BaseModelType.StableDiffusion1: {
|
||||||
|
ModelVariantType.Normal: "v1-inference.yaml",
|
||||||
|
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusion2: {
|
||||||
|
# code further will manually set upcast_attention and v_prediction
|
||||||
|
ModelVariantType.Normal: "v2-inference.yaml",
|
||||||
|
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
|
||||||
|
ModelVariantType.Depth: "v2-midas-inference.yaml",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO: path
|
||||||
|
#model_config.config = app_config.config_dir / "stable-diffusion" / ckpt_configs[version][model_config.variant]
|
||||||
|
#return InvokeAIAppConfig.get_config().legacy_conf_dir / ckpt_configs[version][variant]
|
||||||
|
return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant]
|
||||||
|
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
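A quick usage sketch of _select_ckpt_config as defined above (the resolved path depends on the local InvokeAI root, and None is returned for unknown combinations):

legacy_config = _select_ckpt_config(BaseModelType.StableDiffusion2, ModelVariantType.Inpaint)
# -> <root>/configs/stable-diffusion/v2-inpainting-inference.yaml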
# TODO: rework
|
# TODO: rework
|
||||||
DictConfig = dict
|
|
||||||
def _convert_ckpt_and_cache(
|
def _convert_ckpt_and_cache(
|
||||||
self,
|
|
||||||
version: BaseModelType,
|
version: BaseModelType,
|
||||||
mconfig: dict, # TODO:
|
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
|
||||||
in_path: str,
|
output_path: str,
|
||||||
out_path: str,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Convert the checkpoint model indicated in mconfig into a
|
Convert the checkpoint model indicated in mconfig into a
|
||||||
diffusers, cache it to disk, and return Path to converted
|
diffusers, cache it to disk, and return Path to converted
|
||||||
file. If already on disk then just returns Path.
|
file. If already on disk then just returns Path.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
#if "config" not in mconfig:
|
if model_config.config is None:
|
||||||
# if version == BaseModelType.StableDiffusion1_5:
|
model_config.config = _select_ckpt_config(version, model_config.variant)
|
||||||
#if
|
if model_config.config is None:
|
||||||
#mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
|
raise Exception(f"Model variant {model_config.variant} not supported for {version}")
|
||||||
|
|
||||||
|
|
||||||
weights = app_config.root_dir / mconfig.path
|
weights = app_config.root_dir / model_config.path
|
||||||
config_file = app_config.root_dir / mconfig.config
|
config_file = app_config.root_dir / model_config.config
|
||||||
diffusers_path = app_config.converted_ckpts_dir / weights.stem
|
output_path = Path(output_path)
|
||||||
|
|
||||||
|
if version == BaseModelType.StableDiffusion1:
|
||||||
|
upcast_attention = False
|
||||||
|
prediction_type = SchedulerPredictionType.Epsilon
|
||||||
|
|
||||||
|
elif version == BaseModelType.StableDiffusion2:
|
||||||
|
upcast_attention = config.upcast_attention
|
||||||
|
prediction_type = config.prediction_type
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown model provided: {version}")
|
||||||
|
|
||||||
|
|
||||||
# return cached version if it exists
|
# return cached version if it exists
|
||||||
if diffusers_path.exists():
|
if output_path.exists():
|
||||||
return diffusers_path
|
return output_path
|
||||||
|
|
||||||
# TODO: I think that it more correctly to convert with embedded vae
|
# TODO: I think that it more correctly to convert with embedded vae
|
||||||
# as if user will delete custom vae he will got not embedded but also custom vae
|
# as if user will delete custom vae he will got not embedded but also custom vae
|
||||||
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
|
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
|
||||||
vae_ckpt_path, vae_model = None, None
|
|
||||||
|
|
||||||
# to avoid circular import errors
|
# to avoid circular import errors
|
||||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
convert_ckpt_to_diffusers(
|
convert_ckpt_to_diffusers(
|
||||||
weights,
|
weights,
|
||||||
diffusers_path,
|
output_path,
|
||||||
extract_ema=True,
|
model_version=version,
|
||||||
|
model_variant=model_config.variant,
|
||||||
original_config_file=config_file,
|
original_config_file=config_file,
|
||||||
vae=vae_model,
|
extract_ema=True,
|
||||||
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
|
upcast_attention=upcast_attention,
|
||||||
|
prediction_type=prediction_type,
|
||||||
scan_needed=True,
|
scan_needed=True,
|
||||||
model_root=app_config.models_path,
|
model_root=app_config.models_path,
|
||||||
)
|
)
|
||||||
return diffusers_path
|
return output_path
|
||||||
|
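To make the call flow concrete, a minimal, hypothetical sketch of how the SD2 class delegates to this converter; the paths and the config instance are placeholders rather than values from the diff:

# Hypothetical sketch: on-demand conversion of a .ckpt pipeline into a cached diffusers folder.
# convert_if_required() returns model_path untouched for diffusers folders and only calls
# _convert_ckpt_and_cache() when the stored config is a CheckpointConfig.
converted_path = StableDiffusion2Model.convert_if_required(
    model_path="models/sd-2/pipeline/my-model.ckpt",   # must match config.path (asserted above)
    output_path="models/converted_ckpts/my-model",
    config=my_checkpoint_config,                       # a StableDiffusion2Model.CheckpointConfig instance
    base_model=BaseModelType.StableDiffusion2,
)
# A second call with the same output_path returns the cached folder immediately.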
@ -51,6 +51,12 @@ class TextualInversionModel(ModelBase):
     def detect_format(cls, path: str):
         return None

-    @staticmethod
-    def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
         return model_path
@ -1,5 +1,6 @@
 import os
 import torch
+from pathlib import Path
 from typing import Optional
 from .base import (
     ModelBase,
@ -7,11 +8,14 @@ from .base import (
     BaseModelType,
     ModelType,
     SubModelType,
+    ModelVariantType,
     EmptyConfigLoader,
     calc_model_size_by_fs,
     calc_model_size_by_data,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
+from diffusers.utils import is_safetensors_available
+from omegaconf import OmegaConf

 class VaeModel(ModelBase):
     #vae_class: Type
@ -70,39 +74,72 @@ class VaeModel(ModelBase):
         return "checkpoint"

     @classmethod
-    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,  # empty config or config of parent model
+        base_model: BaseModelType,
+    ) -> str:
         if cls.detect_format(model_path) != "diffusers":
-            # TODO:
-            #_convert_vae_ckpt_and_cache
-            raise NotImplementedError("TODO: vae convert")
+            return _convert_vae_ckpt_and_cache(
+                weights_path=model_path,
+                output_path=output_path,
+                base_model=base_model,
+                model_config=config,
+            )
         else:
             return model_path

 # TODO: rework
-DictConfig = dict
-def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
+def _convert_vae_ckpt_and_cache(
+    weights_path: str,
+    output_path: str,
+    base_model: BaseModelType,
+    model_config: ModelConfigBase,
+) -> str:
     """
     Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
     object, cache it to disk, and return Path to converted
     file. If already on disk then just returns Path.
     """
     app_config = InvokeAIAppConfig.get_config()
-    root = app_config.root_dir
-    weights_file = root / mconfig.path
-    config_file = root / mconfig.config
-    diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
-    image_size = mconfig.get('width') or mconfig.get('height') or 512
+    weights_path = app_config.root_dir / weights_path
+    output_path = Path(output_path)
+
+    """
+    this size is only used when tiling is enabled, to split the input into tiles;
+    the configs in the stable diffusion repos (1 and 2) set it to 256,
+    while on huggingface it is:
+    1.5 - 512
+    1.5-inpainting - 256
+    2-inpainting - 512
+    2-depth - 256
+    2-base - 512
+    2 - 768
+    2.1-base - 768
+    2.1 - 768
+    """
+    image_size = 512

     # return cached version if it exists
-    if diffusers_path.exists():
-        return diffusers_path
+    if output_path.exists():
+        return output_path

+    if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
+        from .stable_diffusion import _select_ckpt_config
+        # all sd models use same vae settings
+        config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
+    else:
+        raise Exception(f"Vae conversion not supported for model type: {base_model}")

     # this avoids circular import error
-    from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
-    if weights_file.suffix == '.safetensors':
-        checkpoint = safetensors.torch.load_file(weights_file)
+    from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
+    if weights_path.suffix == '.safetensors':
+        checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
     else:
-        checkpoint = torch.load(weights_file, map_location="cpu")
+        checkpoint = torch.load(weights_path, map_location="cpu")

     # sometimes weights are hidden under "state_dict", and sometimes not
     if "state_dict" in checkpoint:
@ -117,7 +154,7 @@ def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
         model_root = app_config.models_path,
     )
     vae_model.save_pretrained(
-        diffusers_path,
+        output_path,
         safe_serialization=is_safetensors_available()
     )
-    return diffusers_path
+    return output_path
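As a rough illustration of the new converter entry point, a hypothetical call might look like the following; the file names are placeholders and the config object is assumed to be the (possibly empty) config passed down from convert_if_required:

# Hypothetical sketch: convert a standalone VAE checkpoint into a cached diffusers AutoencoderKL folder.
vae_dir = _convert_vae_ckpt_and_cache(
    weights_path="models/sd-1/vae/my-vae.safetensors",   # resolved against the InvokeAI root dir
    output_path="models/converted_ckpts/my-vae",
    base_model=BaseModelType.StableDiffusion1,
    model_config=my_vae_config,                          # empty config or config of the parent model
)
# Subsequent calls return the already-converted folder because output_path exists.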
@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
     const response = await ModelsService.listModels();

     const deserializedModels = reduce(
-      response.models['sd-1.5']['pipeline'],
+      response.models['sd-1']['pipeline'],
       (modelsAccumulator, model, modelName) => {
         modelsAccumulator[modelName] = { ...model, name: modelName };

@ -25,7 +25,7 @@ export const receivedModels = createAppAsyncThunk(

     models.info(
       { response },
-      `Received ${size(response.models['sd-1.5']['pipeline'])} models`
+      `Received ${size(response.models['sd-1']['pipeline'])} models`
     );

     return deserializedModels;
@ -100,45 +100,45 @@ def migrate_conversion_models(dest_directory: Path):
     # These are needed for the conversion script
     kwargs = dict(
         cache_dir = Path('./models/hub'),
-        local_files_only = True
+        #local_files_only = True
     )
     try:
         logger.info('Migrating core tokenizers and text encoders')
-        target_dir = dest_directory/'core/convert'
+        target_dir = dest_directory / 'core' / 'convert'

         # bert
         bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
-        bert.save_pretrained(target_dir/'bert-base-uncased', safe_serialization=True)
+        bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)

         # sd-1
         repo_id = 'openai/clip-vit-large-patch14'
         pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
-        pipeline.save_pretrained(target_dir/'clip-vit-large-patch14', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)

         pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
-        pipeline.save_pretrained(target_dir/'clip-vit-large-patch14', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)

         # sd-2
         repo_id = "stabilityai/stable-diffusion-2"
-        pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder = "tokenizer", **kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-2-tokenizer', safe_serialization=True)
+        pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)

-        pipeline = CLIPTextModel.from_pretrained(repo_id,subfolder = "text_encoder", **kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-2-text_encoder', safe_serialization=True)
+        pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)

         # VAE
         logger.info('Migrating stable diffusion VAE')
         vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
-        vae.save_pretrained(target_dir/'sd-vae-ft-mse', safe_serialization=True)
+        vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)

         # safety checking
         logger.info('Migrating safety checker')
         repo_id = "CompVis/stable-diffusion-safety-checker"
         pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-safety-checker-extractor', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)

         pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-safety-checker', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
     except KeyboardInterrupt:
         raise
     except Exception as e:
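A small, hypothetical sanity check of the migrated layout; the dest_directory value is illustrative, and the subfolder names follow the destinations used in the hunk above:

# Hypothetical check that a migrated asset loads back from its new location.
from pathlib import Path
from transformers import CLIPTokenizer

dest_directory = Path("models")
tokenizer = CLIPTokenizer.from_pretrained(
    dest_directory / "core" / "convert" / "stable-diffusion-2-clip" / "tokenizer"
)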