mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use CLIPVisionModel under model management for IP-Adapter.
This commit is contained in:
@ -60,7 +60,7 @@ class CLIPVisionModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in a CLIP Vision model.")
|
||||
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path, torch_dtype=torch_dtype)
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
|
||||
|
||||
# Calculate a more accurate model size.
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
|
@ -55,10 +55,6 @@ class IPAdapterModel(ModelBase):
|
||||
# TODO(ryand): Update self.model_size when the model is loaded from disk.
|
||||
return self.model_size
|
||||
|
||||
def _get_text_encoder_path(self) -> str:
|
||||
# TODO(ryand): Move the CLIP image encoder to its own model directory.
|
||||
return os.path.join(os.path.dirname(self.model_path), "image_encoder")
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
@ -72,13 +68,9 @@ class IPAdapterModel(ModelBase):
|
||||
# TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
|
||||
# "plus" variant by loading the state_dict.
|
||||
if "plus" in str(self.model_path):
|
||||
return IPAdapterPlus(
|
||||
image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu"
|
||||
)
|
||||
return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu")
|
||||
else:
|
||||
return IPAdapter(
|
||||
image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu"
|
||||
)
|
||||
return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
Reference in New Issue
Block a user