mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP - A bunch of boilerplate to support Spandrel Image-to-Image models throughout the model manager and the frontend.
This commit is contained in:
@ -373,6 +373,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
"""Model config for Spandrel Image to Image models."""
|
||||
|
||||
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
@ -409,6 +420,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
|
@ -243,10 +243,14 @@ class ModelProbe(object):
|
||||
|
||||
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
||||
try:
|
||||
_ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
|
||||
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
|
||||
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
|
||||
_ = SpandrelImageToImageModel.load_from_file(model_path)
|
||||
return ModelType.SpandrelImageToImage
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# TODO(ryand): Catch a more specific exception type here if we can.
|
||||
# TODO(ryand): Delete this print statement.
|
||||
print(e)
|
||||
pass
|
||||
|
||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||
@ -579,9 +583,9 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SpandrelImageToImageModelProbe(CheckpointProbeBase):
|
||||
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
########################################################
|
||||
@ -791,6 +795,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
@ -820,6 +829,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
@ -829,5 +839,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
|
Reference in New Issue
Block a user