From 3d52656176e66635f6593a2d98eb93bc91640a0d Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 13 Sep 2023 17:14:20 -0400
Subject: [PATCH] Add CLIPVisionModel to model management.

---
 .../backend/model_management/model_probe.py   | 35 ++++++--
 .../model_management/models/__init__.py       |  5 ++
 .../backend/model_management/models/base.py   |  1 +
 .../model_management/models/clip_vision.py    | 82 +++++++++++++++++++
 4 files changed, 115 insertions(+), 8 deletions(-)
 create mode 100644 invokeai/backend/model_management/models/clip_vision.py

diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index 6750e0fe6a..71e5b89e0f 100644
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -8,6 +8,8 @@ import torch
 from diffusers import ConfigMixin, ModelMixin
 from picklescan.scanner import scan_file_path
 
+from invokeai.backend.model_management.models import BaseModelType
+
 from .models import (
     BaseModelType,
     InvalidModelException,
@@ -53,6 +55,7 @@ class ModelProbe(object):
         "AutoencoderKL": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
         "IPAdapterModel": ModelType.IPAdapter,
+        "CLIPVision": ModelType.CLIPVision,
     }
 
     @classmethod
@@ -119,14 +122,18 @@ class ModelProbe(object):
                     and prediction_type == SchedulerPredictionType.VPrediction
                 ),
                 format=format,
-                image_size=1024
-                if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
-                else 768
-                if (
-                    base_type == BaseModelType.StableDiffusion2
-                    and prediction_type == SchedulerPredictionType.VPrediction
-                )
-                else 512,
+                image_size=(
+                    1024
+                    if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
+                    else (
+                        768
+                        if (
+                            base_type == BaseModelType.StableDiffusion2
+                            and prediction_type == SchedulerPredictionType.VPrediction
+                        )
+                        else 512
+                    )
+                ),
             )
         except Exception:
             raise
@@ -372,6 +379,11 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
         raise NotImplementedError()
 
 
+class CLIPVisionCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()
+
+
 ########################################################
 # classes for probing folders
 #######################################################
@@ -520,6 +532,11 @@ class IPAdapterFolderProbe(FolderProbeBase):
         raise NotImplementedError()
 
 
+class CLIPVisionFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()
+
+
 ############## register probe classes ######
 ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
@@ -527,6 +544,7 @@ ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
 
 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
@@ -534,5 +552,6 @@ ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
 
 ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py
index cc9d1f4055..deeaf33bca 100644
--- a/invokeai/backend/model_management/models/__init__.py
+++ b/invokeai/backend/model_management/models/__init__.py
@@ -18,6 +18,7 @@ from .base import (  # noqa: F401
     SilenceWarnings,
     SubModelType,
 )
+from .clip_vision import CLIPVisionModel
 from .controlnet import ControlNetModel  # TODO:
 from .ip_adapter import IPAdapterModel
 from .lora import LoRAModel
@@ -36,6 +37,7 @@ MODEL_CLASSES = {
         ModelType.ControlNet: ControlNetModel,
         ModelType.TextualInversion: TextualInversionModel,
         ModelType.IPAdapter: IPAdapterModel,
+        ModelType.CLIPVision: CLIPVisionModel,
     },
     BaseModelType.StableDiffusion2: {
         ModelType.ONNX: ONNXStableDiffusion2Model,
@@ -45,6 +47,7 @@ MODEL_CLASSES = {
         ModelType.ControlNet: ControlNetModel,
         ModelType.TextualInversion: TextualInversionModel,
         ModelType.IPAdapter: IPAdapterModel,
+        ModelType.CLIPVision: CLIPVisionModel,
     },
     BaseModelType.StableDiffusionXL: {
         ModelType.Main: StableDiffusionXLModel,
@@ -55,6 +58,7 @@ MODEL_CLASSES = {
         ModelType.TextualInversion: TextualInversionModel,
         ModelType.ONNX: ONNXStableDiffusion2Model,
         ModelType.IPAdapter: IPAdapterModel,
+        ModelType.CLIPVision: CLIPVisionModel,
     },
     BaseModelType.StableDiffusionXLRefiner: {
         ModelType.Main: StableDiffusionXLModel,
@@ -65,6 +69,7 @@ MODEL_CLASSES = {
         ModelType.TextualInversion: TextualInversionModel,
         ModelType.ONNX: ONNXStableDiffusion2Model,
         ModelType.IPAdapter: IPAdapterModel,
+        ModelType.CLIPVision: CLIPVisionModel,
     },
     # BaseModelType.Kandinsky2_1: {
     #     ModelType.Main: Kandinsky2_1Model,
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index 0bff479412..f69baf50fd 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -62,6 +62,7 @@ class ModelType(str, Enum):
     ControlNet = "controlnet"  # used by model_probe
     TextualInversion = "embedding"
     IPAdapter = "ip_adapter"
+    CLIPVision = "clip_vision"
 
 
 class SubModelType(str, Enum):
diff --git a/invokeai/backend/model_management/models/clip_vision.py b/invokeai/backend/model_management/models/clip_vision.py
new file mode 100644
index 0000000000..7df3119f9c
--- /dev/null
+++ b/invokeai/backend/model_management/models/clip_vision.py
@@ -0,0 +1,82 @@
+import os
+from enum import Enum
+from typing import Literal, Optional
+
+import torch
+from transformers import CLIPVisionModelWithProjection
+
+from invokeai.backend.model_management.models.base import (
+    BaseModelType,
+    InvalidModelException,
+    ModelBase,
+    ModelConfigBase,
+    ModelType,
+    SubModelType,
+    calc_model_size_by_data,
+    calc_model_size_by_fs,
+    classproperty,
+)
+
+
+class CLIPVisionModelFormat(str, Enum):
+    Diffusers = "diffusers"
+
+
+class CLIPVisionModel(ModelBase):
+    class DiffusersConfig(ModelConfigBase):
+        model_format: Literal[CLIPVisionModelFormat.Diffusers]
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert model_type == ModelType.CLIPVision
+        super().__init__(model_path, base_model, model_type)
+
+        self.model_size = calc_model_size_by_fs(self.model_path)
+
+    @classmethod
+    def detect_format(cls, path: str) -> str:
+        if not os.path.exists(path):
+            raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
+
+        if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
+            return CLIPVisionModelFormat.Diffusers
+
+        raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
+
+    @classproperty
+    def save_to_config(cls) -> bool:
+        return True
+
+    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
+        if child_type is not None:
+            raise ValueError("There are no child models in a CLIP Vision model.")
+
+        return self.model_size
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SubModelType] = None,
+    ) -> CLIPVisionModelWithProjection:
+        if child_type is not None:
+            raise ValueError("There are no child models in a CLIP Vision model.")
+
+        model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
+
+        # Calculate a more accurate model size.
+        self.model_size = calc_model_size_by_data(model)
+
+        return model
+
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
+        format = cls.detect_format(model_path)
+        if format == CLIPVisionModelFormat.Diffusers:
+            return model_path
+        else:
+            raise ValueError(f"Unsupported format: '{format}'.")
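
For context, a minimal usage sketch of the model class added by this patch. It is not part of the diff; it assumes the patch is applied, that InvokeAI is importable, and that a diffusers-format CLIP Vision encoder already exists on disk at the hypothetical path used below.

# Usage sketch only (not part of the patch). The model path is hypothetical;
# any diffusers-format CLIP Vision folder (config.json plus weights) works.
import torch

from invokeai.backend.model_management.models import BaseModelType, ModelType
from invokeai.backend.model_management.models.clip_vision import CLIPVisionModel

model_path = "/path/to/models/any/clip_vision/image_encoder"  # hypothetical

# detect_format() accepts only diffusers-style folders; anything else raises
# InvalidModelException.
assert CLIPVisionModel.detect_format(model_path) == "diffusers"

# The wrapper estimates its size from the files on disk at construction time.
clip_vision = CLIPVisionModel(model_path, BaseModelType.StableDiffusion1, ModelType.CLIPVision)
print(clip_vision.get_size())

# get_model() loads the underlying transformers CLIPVisionModelWithProjection
# and then refines the size estimate from the loaded weights.
image_encoder = clip_vision.get_model(torch_dtype=torch.float16)
print(type(image_encoder).__name__, clip_vision.get_size())

Note that the new CLIPVisionCheckpointProbe and CLIPVisionFolderProbe classes still raise NotImplementedError from get_base_type(), so this model type is registered with the probe machinery but cannot yet have its base model inferred automatically.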