fixup ip adapter handling

Lincoln Stein 2024-06-24 14:57:54 -04:00
parent 9b7b182cf7
commit 5d6a77d336
4 changed files with 21 additions and 5 deletions

View File

@@ -226,6 +226,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
         mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
         resized_mask = tf(mask)
+        assert isinstance(resized_mask, torch.Tensor)
         return resized_mask
 
     def _concat_regional_text_embeddings(
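
Note on the added assert: torchvision transforms are not annotated as returning torch.Tensor, so the assert both sanity-checks the call and narrows the type for mypy. A minimal standalone sketch of the same pattern (the Resize size and mask shape here are illustrative assumptions, not the invocation's actual values):

    import torch
    from torchvision.transforms import Resize

    tf = Resize((64, 64), antialias=True)
    mask = torch.ones(1, 1, 128, 128)  # (batch, channels, h, w)
    resized_mask = tf(mask)
    assert isinstance(resized_mask, torch.Tensor)  # narrows the inferred type for mypy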

View File

@@ -25,6 +25,7 @@ from enum import Enum
 from typing import Literal, Optional, Type, TypeAlias, Union
 
 import torch
+from diffusers.configuration_utils import ConfigMixin
 from diffusers.models.modeling_utils import ModelMixin
 from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
 from typing_extensions import Annotated, Any, Dict
@@ -37,7 +38,7 @@ from ..raw_model import RawModel
 
 # ModelMixin is the base class for all diffusers and transformers models
 # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
+AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
 
 
 class InvalidModelConfigException(Exception):
@@ -177,6 +178,7 @@ class ModelConfigBase(BaseModel):
     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
         """Extend the pydantic schema from a json."""
+        schema["required"].extend(["key", "type", "format"])
 
     model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
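
For context on json_schema_extra: pydantic v2 invokes this hook while building the model's JSON schema, which is how "key", "type", and "format" get forced into the required list even when the fields carry defaults. A self-contained sketch of the hook (DemoConfig is hypothetical; setdefault is used because a schema whose fields all have defaults has no "required" entry to extend):

    from typing import Any, Type

    from pydantic import BaseModel, ConfigDict

    class DemoConfig(BaseModel):
        key: str = ""
        type: str = ""
        format: str = ""

        @staticmethod
        def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
            """Mark the discriminating fields as required, as ModelConfigBase does above."""
            schema.setdefault("required", []).extend(["key", "type", "format"])

        model_config = ConfigDict(json_schema_extra=json_schema_extra)

    print(DemoConfig.model_json_schema()["required"])  # ['key', 'type', 'format']
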
@@ -443,7 +445,7 @@ class ModelConfigFactory(object):
             model = dest_class.model_validate(model_data)
         else:
             # mypy doesn't typecheck TypeAdapters well?
-            model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore
+            model = AnyModelConfigValidator.validate_python(model_data)
         assert model is not None
         if key:
             model.key = key
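
AnyModelConfigValidator is presumably a pydantic TypeAdapter over the tagged union of config classes (the Discriminator/Tag imports above point that way); with a discriminated union, validate_python returns the precise union type, which is what lets the # type: ignore be dropped. A hedged, self-contained sketch of that pattern with invented config classes:

    from typing import Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter
    from typing_extensions import Annotated

    class MainConfig(BaseModel):
        type: Literal["main"]
        key: str = ""

    class IPAdapterConfig(BaseModel):
        type: Literal["ip_adapter"]
        key: str = ""

    AnyConfig = Annotated[Union[MainConfig, IPAdapterConfig], Field(discriminator="type")]
    AnyConfigValidator = TypeAdapter(AnyConfig)

    # The "type" discriminator selects the concrete class at validation time.
    config = AnyConfigValidator.validate_python({"type": "ip_adapter", "key": "abc123"})
    assert isinstance(config, IPAdapterConfig)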

View File

@@ -188,6 +188,11 @@ class ModelCacheBase(ABC, Generic[T]):
         """Return true if the model identified by key and submodel_type is in the cache."""
         pass
 
+    @abstractmethod
+    def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
+        """Move a copy of the model into the indicated device and return it."""
+        pass
+
     @abstractmethod
     def cache_size(self) -> int:
         """Get the total size of the models currently cached."""

View File

@@ -18,6 +18,7 @@ context. Use like this:
 """
 
+import copy
 import gc
 import math
 import sys
@@ -29,6 +30,7 @@ from threading import BoundedSemaphore
 from typing import Dict, Generator, List, Optional, Set
 
 import torch
+from diffusers.configuration_utils import ConfigMixin
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
@@ -294,12 +296,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.info(f"Called to move {cache_entry.key} to {target_device}")
+        self.logger.info(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
 
         # Some models don't have a state dictionary, in which case the
         # stored model will still reside in CPU
         if cache_entry.state_dict is None:
-            return cache_entry.model
+            if hasattr(cache_entry.model, "to"):
+                model_in_gpu = copy.deepcopy(cache_entry.model)
+                assert hasattr(model_in_gpu, "to")
+                model_in_gpu.to(target_device)
+                return model_in_gpu
+            else:
+                return cache_entry.model  # what happens in CPU stays in CPU
 
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
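
The copy-then-move pattern above keeps the RAM-resident original intact, so later requests can reuse it without paying a VRAM-to-RAM transfer back; the hasattr guard covers AnyModel members that have no .to() method (for example a bare dict of tensors). A standalone sketch of the pattern under those assumptions:

    import copy

    import torch

    def copy_to_device(model, target_device: torch.device):
        """Return a copy of model on target_device, leaving the original in place."""
        if not hasattr(model, "to"):
            return model  # e.g. a plain state-dict payload: stays where it is
        model_copy = copy.deepcopy(model)
        model_copy.to(target_device)  # nn.Module.to() moves the copy in place
        return model_copy

    net = torch.nn.Linear(4, 4)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    moved = copy_to_device(net, device)
    assert next(net.parameters()).device.type == "cpu"  # original untouched
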
@@ -317,7 +325,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         template = cache_entry.model
         cls = template.__class__
         with skip_torch_weight_init():
-            if hasattr(cls, "from_config"):
+            if isinstance(cls, ConfigMixin) or hasattr(cls, "from_config"):
                 working_model = template.__class__.from_config(template.config) # diffusers style
             else:
                 working_model = template.__class__(config=template.config) # transformers style (sigh)
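
The widened test distinguishes two construction idioms: diffusers models (ConfigMixin subclasses) rebuild from a config dict via from_config(), while transformers-style models take a config object in their constructor and expose no from_config classmethod. An illustrative sketch with arbitrary small models (the clone's weights come out freshly initialized, which is why the surrounding code wraps this in skip_torch_weight_init() and loads a state dict afterwards):

    from diffusers import UNet2DModel
    from transformers import BertConfig, BertModel

    def clone_from_template(template):
        """Fresh, same-shaped instance of the template's class."""
        cls = template.__class__
        if hasattr(cls, "from_config"):
            return cls.from_config(template.config)  # diffusers style
        return cls(config=template.config)  # transformers style

    unet_clone = clone_from_template(UNet2DModel())  # diffusers path
    bert = BertModel(BertConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64))
    bert_clone = clone_from_template(bert)  # transformers path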