wip: Initial implementation of safetensor support for IP Adapter

This commit is contained in:
blessedcoolant
2024-03-24 01:40:28 +05:30
parent 8584171a49
commit 67afb1763e
6 changed files with 104 additions and 61 deletions

View File

@ -230,9 +230,10 @@ class ModelProbe(object):
return ModelType.LoRA
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
return ModelType.ControlNet
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@ -527,8 +528,15 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
class IPAdapterCheckpointProbe(CheckpointProbeBase):
"""Class for probing IP Adapters"""
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
checkpoint = self.checkpoint
for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")):
continue
return BaseModelType.StableDiffusionXL
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@ -689,9 +697,7 @@ class ControlNetFolderProbe(FolderProbeBase):
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
else BaseModelType.StableDiffusionXL if dimension == 2048 else None
)
)
if not base_model:
@ -768,7 +774,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
)
############## register probe classes ######
# Register probe classes
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)