WIP - A bunch of boilerplate to support Spandrel Image-to-Image models throughout the model manager and the frontend.

This commit is contained in:
Ryan Dick
2024-06-28 18:03:09 -04:00
parent 95079dc7d4
commit 29c8ddfb88
15 changed files with 287 additions and 19 deletions

View File

@ -373,6 +373,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@ -409,6 +420,7 @@ AnyModelConfig = Annotated[
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),

View File

@ -243,10 +243,14 @@ class ModelProbe(object):
# Check if the model can be loaded as a SpandrelImageToImageModel.
try:
_ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
_ = SpandrelImageToImageModel.load_from_file(model_path)
return ModelType.SpandrelImageToImage
except Exception:
except Exception as e:
# TODO(ryand): Catch a more specific exception type here if we can.
# TODO(ryand): Delete this print statement.
print(e)
pass
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@ -579,9 +583,9 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError()
class SpandrelImageToImageModelProbe(CheckpointProbeBase):
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
return BaseModelType.Any
########################################################
@ -791,6 +795,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@ -820,6 +829,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@ -829,5 +839,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)