Update IPAdapterModel to respect requested torch_dtype.

This commit is contained in:
Ryan Dick 2023-09-13 21:06:42 -04:00
parent ebf26687cb
commit 77d135967f

View File

@ -63,14 +63,12 @@ class IPAdapterModel(ModelBase):
if child_type is not None: if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.") raise ValueError("There are no child models in an IP-Adapter model.")
# TODO(ryand): Update IPAdapter to accept a torch_dtype param.
# TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a # TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
# "plus" variant by loading the state_dict. # "plus" variant by loading the state_dict.
if "plus" in str(self.model_path): if "plus" in str(self.model_path):
return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu") return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
else: else:
return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu") return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
@classmethod @classmethod
def convert_if_required( def convert_if_required(