fix(mm): handle depth and inpainting models when converting to diffusers

"Normal" models have 4 in-channels, while "Depth" models have 5 and "Inpaint" models have 9.

We need to explicitly tell diffusers the channel count when converting models.

Closes #6058
This commit is contained in:
psychedelicious 2024-03-27 19:01:04 +11:00 committed by Kent Keirsey
parent 536bb4f053
commit eb33303e79

View File

@ -14,12 +14,18 @@ from invokeai.backend.model_manager import (
SchedulerPredictionType, SchedulerPredictionType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
# In-channel count for each model variant: "Normal" models have 4,
# "Depth" models have 5 and "Inpaint" models have 9. The value is passed as
# `num_in_channels` to convert_ckpt_to_diffusers so the converter is told the
# channel count explicitly when converting a checkpoint.
# NOTE(review): the lookup indexes this dict directly, so a variant missing
# from this map would raise KeyError — confirm that every ModelVariantType
# member is listed here.
VARIANT_TO_IN_CHANNEL_MAP = {
ModelVariantType.Normal: 4,
ModelVariantType.Depth: 5,
ModelVariantType.Inpaint: 9,
}
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ -87,6 +93,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
) )
self._logger.info(f"Converting {model_path} to diffusers format") self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(
model_path, model_path,
output_path, output_path,
@ -99,5 +106,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
image_size=image_size, image_size=image_size,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
load_safety_checker=False, load_safety_checker=False,
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
) )
return output_path return output_path