mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fixup ip adapter handling
This commit is contained in:
parent
9b7b182cf7
commit
5d6a77d336
@ -226,6 +226,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||||
resized_mask = tf(mask)
|
resized_mask = tf(mask)
|
||||||
|
assert isinstance(resized_mask, torch.Tensor)
|
||||||
return resized_mask
|
return resized_mask
|
||||||
|
|
||||||
def _concat_regional_text_embeddings(
|
def _concat_regional_text_embeddings(
|
||||||
|
@ -25,6 +25,7 @@ from enum import Enum
|
|||||||
from typing import Literal, Optional, Type, TypeAlias, Union
|
from typing import Literal, Optional, Type, TypeAlias, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
@ -37,7 +38,7 @@ from ..raw_model import RawModel
|
|||||||
|
|
||||||
# ModelMixin is the base class for all diffusers and transformers models
|
# ModelMixin is the base class for all diffusers and transformers models
|
||||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
|
AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
class InvalidModelConfigException(Exception):
|
class InvalidModelConfigException(Exception):
|
||||||
@ -177,6 +178,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
"""Extend the pydantic schema from a json."""
|
||||||
schema["required"].extend(["key", "type", "format"])
|
schema["required"].extend(["key", "type", "format"])
|
||||||
|
|
||||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||||
@ -443,7 +445,7 @@ class ModelConfigFactory(object):
|
|||||||
model = dest_class.model_validate(model_data)
|
model = dest_class.model_validate(model_data)
|
||||||
else:
|
else:
|
||||||
# mypy doesn't typecheck TypeAdapters well?
|
# mypy doesn't typecheck TypeAdapters well?
|
||||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
model = AnyModelConfigValidator.validate_python(model_data)
|
||||||
assert model is not None
|
assert model is not None
|
||||||
if key:
|
if key:
|
||||||
model.key = key
|
model.key = key
|
||||||
|
@ -188,6 +188,11 @@ class ModelCacheBase(ABC, Generic[T]):
|
|||||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
|
||||||
|
"""Move a copy of the model into the indicated device and return it."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cache_size(self) -> int:
|
def cache_size(self) -> int:
|
||||||
"""Get the total size of the models currently cached."""
|
"""Get the total size of the models currently cached."""
|
||||||
|
@ -18,6 +18,7 @@ context. Use like this:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
@ -29,6 +30,7 @@ from threading import BoundedSemaphore
|
|||||||
from typing import Dict, Generator, List, Optional, Set
|
from typing import Dict, Generator, List, Optional, Set
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
|
|
||||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||||
@ -294,12 +296,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
|
|
||||||
May raise a torch.cuda.OutOfMemoryError
|
May raise a torch.cuda.OutOfMemoryError
|
||||||
"""
|
"""
|
||||||
self.logger.info(f"Called to move {cache_entry.key} to {target_device}")
|
self.logger.info(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
|
||||||
|
|
||||||
# Some models don't have a state dictionary, in which case the
|
# Some models don't have a state dictionary, in which case the
|
||||||
# stored model will still reside in CPU
|
# stored model will still reside in CPU
|
||||||
if cache_entry.state_dict is None:
|
if cache_entry.state_dict is None:
|
||||||
return cache_entry.model
|
if hasattr(cache_entry.model, "to"):
|
||||||
|
model_in_gpu = copy.deepcopy(cache_entry.model)
|
||||||
|
assert hasattr(model_in_gpu, "to")
|
||||||
|
model_in_gpu.to(target_device)
|
||||||
|
return model_in_gpu
|
||||||
|
else:
|
||||||
|
return cache_entry.model # what happens in CPU stays in CPU
|
||||||
|
|
||||||
# This roundabout method for moving the model around is done to avoid
|
# This roundabout method for moving the model around is done to avoid
|
||||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||||
@ -317,7 +325,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
template = cache_entry.model
|
template = cache_entry.model
|
||||||
cls = template.__class__
|
cls = template.__class__
|
||||||
with skip_torch_weight_init():
|
with skip_torch_weight_init():
|
||||||
if hasattr(cls, "from_config"):
|
if isinstance(cls, ConfigMixin) or hasattr(cls, "from_config"):
|
||||||
working_model = template.__class__.from_config(template.config) # diffusers style
|
working_model = template.__class__.from_config(template.config) # diffusers style
|
||||||
else:
|
else:
|
||||||
working_model = template.__class__(config=template.config) # transformers style (sigh)
|
working_model = template.__class__(config=template.config) # transformers style (sigh)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user