Add CLIPVisionModel to model management.

This commit is contained in:
Ryan Dick 2023-09-13 17:14:20 -04:00
parent a2777decd4
commit 3d52656176
4 changed files with 115 additions and 8 deletions

View File

@ -8,6 +8,8 @@ import torch
from diffusers import ConfigMixin, ModelMixin from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models import BaseModelType
from .models import ( from .models import (
BaseModelType, BaseModelType,
InvalidModelException, InvalidModelException,
@ -53,6 +55,7 @@ class ModelProbe(object):
"AutoencoderKL": ModelType.Vae, "AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,
"IPAdapterModel": ModelType.IPAdapter, "IPAdapterModel": ModelType.IPAdapter,
"CLIPVision": ModelType.CLIPVision,
} }
@classmethod @classmethod
@ -119,14 +122,18 @@ class ModelProbe(object):
and prediction_type == SchedulerPredictionType.VPrediction and prediction_type == SchedulerPredictionType.VPrediction
), ),
format=format, format=format,
image_size=1024 image_size=(
1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else 768 else (
768
if ( if (
base_type == BaseModelType.StableDiffusion2 base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction and prediction_type == SchedulerPredictionType.VPrediction
) )
else 512, else 512
)
),
) )
except Exception: except Exception:
raise raise
@ -372,6 +379,11 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError() raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    """Probe for CLIP Vision models in "checkpoint" format.

    This is a stub: base-model detection is not implemented for this probe.
    """

    def get_base_type(self) -> BaseModelType:
        """Always raise ``NotImplementedError``; no base type is determined."""
        raise NotImplementedError()
######################################################## ########################################################
# classes for probing folders # classes for probing folders
####################################################### #######################################################
@ -520,6 +532,11 @@ class IPAdapterFolderProbe(FolderProbeBase):
raise NotImplementedError() raise NotImplementedError()
class CLIPVisionFolderProbe(FolderProbeBase):
    """Probe for CLIP Vision models stored as a "diffusers" folder.

    This is a stub: base-model detection is not implemented for this probe.
    """

    def get_base_type(self) -> BaseModelType:
        """Always raise ``NotImplementedError``; no base type is determined."""
        raise NotImplementedError()
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
@ -527,6 +544,7 @@ ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
@ -534,5 +552,6 @@ ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -18,6 +18,7 @@ from .base import ( # noqa: F401
SilenceWarnings, SilenceWarnings,
SubModelType, SubModelType,
) )
from .clip_vision import CLIPVisionModel
from .controlnet import ControlNetModel # TODO: from .controlnet import ControlNetModel # TODO:
from .ip_adapter import IPAdapterModel from .ip_adapter import IPAdapterModel
from .lora import LoRAModel from .lora import LoRAModel
@ -36,6 +37,7 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
@ -45,6 +47,7 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
}, },
BaseModelType.StableDiffusionXL: { BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel, ModelType.Main: StableDiffusionXLModel,
@ -55,6 +58,7 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
}, },
BaseModelType.StableDiffusionXLRefiner: { BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel, ModelType.Main: StableDiffusionXLModel,
@ -65,6 +69,7 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
}, },
# BaseModelType.Kandinsky2_1: { # BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model, # ModelType.Main: Kandinsky2_1Model,

View File

@ -62,6 +62,7 @@ class ModelType(str, Enum):
ControlNet = "controlnet" # used by model_probe ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding" TextualInversion = "embedding"
IPAdapter = "ip_adapter" IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
class SubModelType(str, Enum): class SubModelType(str, Enum):

View File

@ -0,0 +1,82 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class CLIPVisionModelFormat(str, Enum):
    """On-disk formats recognised for CLIP Vision models.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # The only format currently defined: a model directory.
    Diffusers = "diffusers"
class CLIPVisionModel(ModelBase):
    """Model-management wrapper for a CLIP Vision image encoder.

    Only the directory-based "diffusers" format is recognised (a folder that
    contains a ``config.json``). The model has no sub-models, so every method
    that accepts ``child_type`` rejects any value other than ``None``.
    """

    class DiffusersConfig(ModelConfigBase):
        # Configuration record for the diffusers-format variant of this model.
        model_format: Literal[CLIPVisionModelFormat.Diffusers]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Create the wrapper and record an initial size estimate.

        Args:
            model_path: Location of the model on disk.
            base_model: Base model family this model is registered under.
            model_type: Must be ``ModelType.CLIPVision``.
        """
        assert model_type == ModelType.CLIPVision
        super().__init__(model_path, base_model, model_type)

        # Initial estimate taken from the files on disk; get_model() replaces
        # it with a value computed from the loaded model.
        self.model_size = calc_model_size_by_fs(self.model_path)

    @classmethod
    def detect_format(cls, path: str) -> str:
        """Return the format of the model found at ``path``.

        Args:
            path: Filesystem path to inspect.

        Returns:
            ``CLIPVisionModelFormat.Diffusers`` when ``path`` is a directory
            containing ``config.json``.

        Raises:
            ModuleNotFoundError: If ``path`` does not exist.
            InvalidModelException: If ``path`` exists but is not a directory
                with a ``config.json`` file.
        """
        if not os.path.exists(path):
            raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")

        if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
            return CLIPVisionModelFormat.Diffusers

        raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")

    @classproperty
    def save_to_config(cls) -> bool:
        """Always ``True`` for this model class."""
        return True

    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        """Return the most recently computed model size.

        Args:
            child_type: Must be ``None``; this model has no sub-models.

        Raises:
            ValueError: If ``child_type`` is not ``None``.
        """
        if child_type is not None:
            raise ValueError("There are no child models in a CLIP Vision model.")

        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> CLIPVisionModelWithProjection:
        """Load the image encoder from ``self.model_path``.

        Args:
            torch_dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.
            child_type: Must be ``None``; this model has no sub-models.

        Returns:
            The loaded ``CLIPVisionModelWithProjection``.

        Raises:
            ValueError: If ``child_type`` is not ``None``.
        """
        if child_type is not None:
            raise ValueError("There are no child models in a CLIP Vision model.")

        # Load from self.model_path, the same path __init__ uses for the size
        # estimate. The previous code read self._image_encoder_path, an
        # attribute this class never assigns.
        model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)

        # Calculate a more accurate model size.
        self.model_size = calc_model_size_by_data(model)

        return model

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Return the path of a usable model; no conversion is performed.

        Args:
            model_path: Path of the model to inspect.
            output_path: Unused by this implementation.
            config: Unused by this implementation.
            base_model: Unused by this implementation.

        Returns:
            ``model_path`` unchanged when the model is in diffusers format.

        Raises:
            ModuleNotFoundError: Propagated from ``detect_format``.
            InvalidModelException: Propagated from ``detect_format``.
            ValueError: If ``detect_format`` returns any other format.
        """
        model_format = cls.detect_format(model_path)
        if model_format != CLIPVisionModelFormat.Diffusers:
            raise ValueError(f"Unsupported format: '{model_format}'.")
        return model_path