Remove RawModel - it just added a weird layer of indirection to the AnyModel type without providing any value.

Ryan Dick 2024-03-15 14:43:17 -04:00
parent 73e5b08c1f
commit 827ac4b841
6 changed files with 10 additions and 30 deletions
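
For context, here is a minimal before/after sketch of the AnyModel alias that this commit changes (see the model_manager config diff below). The module paths are the ones shown in the diff; importing ModelMixin from the diffusers top level is my assumption about where it comes from in this file.

from typing import Union

import torch
from diffusers import ModelMixin  # "base class for all diffusers and transformers models"

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw

# Before: the concrete wrapper classes were hidden behind the empty RawModel base.
# AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]

# After: the union names the wrapper classes directly.
AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]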

View File

@@ -9,7 +9,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
-from ..raw_model import RawModel
 from .resampler import Resampler
@@ -92,7 +91,7 @@ class MLPProjModel(torch.nn.Module):
         return clip_extra_context_tokens
-class IPAdapter(RawModel):
+class IPAdapter(torch.nn.Module):
     """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
     def __init__(

View File

@@ -11,8 +11,6 @@ from typing_extensions import Self
 from invokeai.backend.model_manager import BaseModelType
-from .raw_model import RawModel
 class LoRALayerBase:
     # rank: Optional[int]
@@ -368,7 +366,7 @@ class IA3Layer(LoRALayerBase):
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
-class LoRAModelRaw(RawModel):  # (torch.nn.Module):
+class LoRAModelRaw(torch.nn.Module):
     _name: str
     layers: Dict[str, AnyLoRALayer]

View File

@@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict
 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
-from ..raw_model import RawModel
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
 # ModelMixin is the base class for all diffusers and transformers models
-# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
+AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]
 class InvalidModelConfigException(Exception):

View File

@@ -6,17 +6,16 @@ from typing import Any, List, Optional, Tuple, Union
 import numpy as np
 import onnx
+import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
-from ..raw_model import RawModel
 ONNX_WEIGHTS_NAME = "model.onnx"
 # NOTE FROM LS: This was copied from Stalker's original implementation.
 # I have not yet gone through and fixed all the type hints
-class IAIOnnxRuntimeModel(RawModel):
+class IAIOnnxRuntimeModel(torch.nn.Module):
     class _tensor_access:
         def __init__(self, model):  # type: ignore
             self.model = model

View File

@@ -1,15 +0,0 @@
-"""Base class for 'Raw' models.
-
-The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
-and is used for type checking of calls to the model patcher. Its main purpose
-is to avoid a circular import issues when lora.py tries to import BaseModelType
-from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
-from lora.py.
-
-The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
-that adds additional methods and attributes.
-"""
-
-
-class RawModel:
-    """Base class for 'Raw' model wrappers."""

View File

@@ -9,10 +9,8 @@ from safetensors.torch import load_file
 from transformers import CLIPTokenizer
 from typing_extensions import Self
-from .raw_model import RawModel
-class TextualInversionModelRaw(RawModel):
+class TextualInversionModelRaw(torch.nn.Module):
     embedding: torch.Tensor  # [n, 768]|[n, 1280]
     embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models