mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix bug in IPAdapter.to(...).
This commit is contained in:
parent
a22c8cb3a1
commit
94c186bb4c
@ -18,6 +18,8 @@ from diffusers.models import UNet2DConditionModel
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.models.base import calc_model_size_by_data
|
||||||
|
|
||||||
from .attention_processor import AttnProcessor, IPAttnProcessor
|
from .attention_processor import AttnProcessor, IPAttnProcessor
|
||||||
from .resampler import Resampler
|
from .resampler import Resampler
|
||||||
|
|
||||||
@ -94,7 +96,7 @@ class IPAdapter:
|
|||||||
|
|
||||||
self._image_proj_model.to(device=self.device, dtype=self.dtype)
|
self._image_proj_model.to(device=self.device, dtype=self.dtype)
|
||||||
if self._attn_processors is not None:
|
if self._attn_processors is not None:
|
||||||
torch.nn.ModuleList(self._attn_processors).to(device=self.device, dtype=self.dtype)
|
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
def _init_image_proj_model(self, state_dict):
|
def _init_image_proj_model(self, state_dict):
|
||||||
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
|
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
|
||||||
|
@ -72,7 +72,6 @@ class IPAdapterModel(ModelBase):
|
|||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||||
|
|
||||||
# TODO(ryand): Update self.model_size when the model is loaded from disk.
|
|
||||||
return self.model_size
|
return self.model_size
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
|
Loading…
Reference in New Issue
Block a user