diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index e51966c779..4ab3a6f0c0 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -9,7 +9,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from ..raw_model import RawModel from .resampler import Resampler @@ -92,7 +91,7 @@ class MLPProjModel(torch.nn.Module): return clip_extra_context_tokens -class IPAdapter(RawModel): +class IPAdapter(torch.nn.Module): """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf""" def __init__( diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index 0b7128034a..15fa423978 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -11,8 +11,6 @@ from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType -from .raw_model import RawModel - class LoRALayerBase: # rank: Optional[int] @@ -368,7 +366,7 @@ class IA3Layer(LoRALayerBase): AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] -class LoRAModelRaw(RawModel): # (torch.nn.Module): +class LoRAModelRaw(torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 524e39b2a1..9836ee3167 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.util.misc import uuid_string - -from ..raw_model import RawModel +from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.lora import LoRAModelRaw +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.textual_inversion import TextualInversionModelRaw # ModelMixin is the base class for all diffusers and transformers models -# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] +AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel] class InvalidModelConfigException(Exception): diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py index 8916865dd5..3f41c92c6e 100644 --- a/invokeai/backend/onnx/onnx_runtime.py +++ b/invokeai/backend/onnx/onnx_runtime.py @@ -6,17 +6,16 @@ from typing import Any, List, Optional, Tuple, Union import numpy as np import onnx +import torch from onnx import numpy_helper from onnxruntime import InferenceSession, SessionOptions, get_available_providers -from ..raw_model import RawModel - ONNX_WEIGHTS_NAME = "model.onnx" # NOTE FROM LS: This was copied from Stalker's original implementation. # I have not yet gone through and fixed all the type hints -class IAIOnnxRuntimeModel(RawModel): +class IAIOnnxRuntimeModel(torch.nn.Module): class _tensor_access: def __init__(self, model): # type: ignore self.model = model diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py deleted file mode 100644 index d0dc50c456..0000000000 --- a/invokeai/backend/raw_model.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Base class for 'Raw' models. - -The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw, -and is used for type checking of calls to the model patcher. Its main purpose -is to avoid a circular import issues when lora.py tries to import BaseModelType -from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw -from lora.py. - -The term 'raw' was introduced to describe a wrapper around a torch.nn.Module -that adds additional methods and attributes. -""" - - -class RawModel: - """Base class for 'Raw' model wrappers.""" diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py index f7390979bb..479596f341 100644 --- a/invokeai/backend/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -9,10 +9,8 @@ from safetensors.torch import load_file from transformers import CLIPTokenizer from typing_extensions import Self -from .raw_model import RawModel - -class TextualInversionModelRaw(RawModel): +class TextualInversionModelRaw(torch.nn.Module): embedding: torch.Tensor # [n, 768]|[n, 1280] embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models