mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update IPAdapterModel to respect requested torch_dtype.
This commit is contained in:
parent
ebf26687cb
commit
77d135967f
@ -63,14 +63,12 @@ class IPAdapterModel(ModelBase):
|
|||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||||
|
|
||||||
# TODO(ryand): Update IPAdapter to accept a torch_dtype param.
|
|
||||||
|
|
||||||
# TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
|
# TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
|
||||||
# "plus" variant by loading the state_dict.
|
# "plus" variant by loading the state_dict.
|
||||||
if "plus" in str(self.model_path):
|
if "plus" in str(self.model_path):
|
||||||
return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu")
|
return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
|
||||||
else:
|
else:
|
||||||
return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu")
|
return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
|
Loading…
Reference in New Issue
Block a user