make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many, but not all, of the type-checking errors in the invocations. Most
  of them were unrelated to the model manager.

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered (see the request example below).

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md.
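
As a quick smoke test of the re-registered routes, the listing endpoint can be
exercised directly. This is a sketch for reviewers, not part of the diff; it
assumes a locally running server on the default port (9090) and uses the
third-party `requests` package.

```
import requests

# List all model records via the new v2 route prefix.
# The old /api/v1/model/record routes are no longer registered.
resp = requests.get("http://localhost:9090/api/v2/models/")
resp.raise_for_status()
for model in resp.json()["models"]:
    print(model["key"], model["name"])
```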
Author: Lincoln Stein
Date:   2024-02-10 18:09:45 -05:00
Parent: 1d724bca4a
Commit: 40a81c358d
41 changed files with 691 additions and 448 deletions
View File
@@ -28,7 +28,7 @@ model. These are the:
  Hugging Face, as well as discriminating among model versions in
  Civitai, but can be used for arbitrary content.
-* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
+* _ModelLoadServiceBase_
  Responsible for loading a model from disk
  into RAM and VRAM and getting it ready for inference.
@@ -41,10 +41,10 @@ The four main services can be found in
 * `invokeai/app/services/model_records/`
 * `invokeai/app/services/model_install/`
 * `invokeai/app/services/downloads/`
-* `invokeai/app/services/model_loader/` (**under development**)
+* `invokeai/app/services/model_load/`
 Code related to the FastAPI web API can be found in
-`invokeai/app/api/routers/model_records.py`.
+`invokeai/app/api/routers/model_manager_v2.py`.
 ***
@@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but
 `ModelType`, `ModelFormat` and `BaseModelType` are string enums that
 are defined in `invokeai.backend.model_manager.config`. They are also
 imported by, and can be reexported from,
-`invokeai.app.services.model_record_service`:
+`invokeai.app.services.model_manager.model_records`:
 ```
-from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType
+from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
 ```
 The `path` field can be absolute or relative. If relative, it is taken
@@ -123,7 +123,7 @@ taken to be the `models_dir` directory.
 `variant` is an enumerated string class with values `normal`,
 `inpaint` and `depth`. If needed, it can be imported if needed from
-either `invokeai.app.services.model_record_service` or
+either `invokeai.app.services.model_records` or
 `invokeai.backend.model_manager.config`.
 ### ONNXSD2Config
@@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or
 | `upcast_attention` | bool | Model requires its attention module to be upcast |
 The `SchedulerPredictionType` enum can be imported from either
-`invokeai.app.services.model_record_service` or
+`invokeai.app.services.model_records` or
 `invokeai.backend.model_manager.config`.
 ### Other config classes
@@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base
 models. This works OK for some models, such as the IP Adapter image
 encoders, but is an all-or-nothing proposition.
-Another issue is that the config class hierarchy is paralleled to some
-extent by a `ModelBase` class hierarchy defined in
-`invokeai.backend.model_manager.models.base` and its subclasses. These
-are classes representing the models after they are loaded into RAM and
-include runtime information such as load status and bytes used. Some
-of the fields, including `name`, `model_type` and `base_model`, are
-shared between `ModelConfigBase` and `ModelBase`, and this is a
-potential source of confusion.
 ## Reading and Writing Model Configuration Records
 The `ModelRecordService` provides the ability to retrieve model
@@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the
 `InvocationContext` object:
 ```
-store = context.services.model_record_store
+store = context.services.model_manager.store
 ```
 or from elsewhere in the code by accessing
-`ApiDependencies.invoker.services.model_record_store`.
+`ApiDependencies.invoker.services.model_manager.store`.
 ### Creating a `ModelRecordService`
@@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a
 `ModelRecordServiceFile` object:
 ```
-from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile
+from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile
 store = ModelRecordServiceSQL.from_connection(connection, lock)
 store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db')
@@ -252,7 +243,7 @@ So a typical startup pattern would be:
 ```
 import sqlite3
 from invokeai.app.services.thread import lock
-from invokeai.app.services.model_record_service import ModelRecordServiceBase
+from invokeai.app.services.model_records import ModelRecordServiceBase
 from invokeai.app.services.config import InvokeAIAppConfig
 config = InvokeAIAppConfig.get_config()
@@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False)
 store = ModelRecordServiceBase.open(config, db_conn, lock)
 ```
-_A note on simultaneous access to `invokeai.db`_: The current InvokeAI
-service architecture for the image and graph databases is careful to
-use a shared sqlite3 connection and a thread lock to ensure that two
-threads don't attempt to access the database simultaneously. However,
-the default `sqlite3` library used by Python reports using
-**Serialized** mode, which allows multiple threads to access the
-database simultaneously using multiple database connections (see
-https://www.sqlite.org/threadsafe.html and
-https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore
-it should be safe to allow the record service to open its own SQLite
-database connection. Opening a model record service should then be as
-simple as `ModelRecordServiceBase.open(config)`.
 ### Fetching a Model's Configuration from `ModelRecordServiceBase`
 Configurations can be retrieved in several ways.
@@ -1465,7 +1443,7 @@ create alternative instances if you wish.
 ### Creating a ModelLoadService object
 The class is defined in
-`invokeai.app.services.model_loader_service`. It is initialized with
+`invokeai.app.services.model_load`. It is initialized with
 an InvokeAIAppConfig object, from which it gets configuration
 information such as the user's desired GPU and precision, and with a
 previously-created `ModelRecordServiceBase` object, from which it
@@ -1475,8 +1453,8 @@ Here is a typical initialization pattern:
 ```
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.model_record_service import ModelRecordServiceBase
-from invokeai.app.services.model_loader_service import ModelLoadService
+from invokeai.app.services.model_records import ModelRecordServiceBase
+from invokeai.app.services.model_load import ModelLoadService
 config = InvokeAIAppConfig.get_config()
 store = ModelRecordServiceBase.open(config)
@@ -1487,14 +1465,11 @@ Note that we are relying on the contents of the application
 configuration to choose the implementation of
 `ModelRecordServiceBase`.
-### get_model(key, [submodel_type], [context]) -> ModelInfo:
+### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel
-*** TO DO: change to get_model(key, context=None, **kwargs)
-The `get_model()` method, like its similarly-named cousin in
-`ModelRecordService`, receives the unique key that identifies the
-model. It loads the model into memory, gets the model ready for use,
-and returns a `ModelInfo` object.
+The `load_model_by_key()` method receives the unique key that
+identifies the model. It loads the model into memory, gets the model
+ready for use, and returns a `LoadedModel` object.
 The optional second argument, `subtype` is a `SubModelType` string
 enum, such as "vae". It is mandatory when used with a main model, and
@@ -1504,46 +1479,64 @@ The optional third argument, `context` can be provided by
 an invocation to trigger model load event reporting. See below for
 details.
-The returned `ModelInfo` object shares some fields in common with
-`ModelConfigBase`, but is otherwise a completely different beast:
+The returned `LoadedModel` object contains a copy of the configuration
+record returned by the model record `get_model()` method, as well as
+the in-memory loaded model:
-| **Field Name** | **Type** | **Description** |
+| **Attribute Name** | **Type** | **Description** |
 |----------------|-----------------|------------------|
-| `key` | str | The model key derived from the ModelRecordService database |
-| `name` | str | Name of this model |
-| `base_model` | BaseModelType | Base model for this model |
-| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
-| `location` | Path or str | Location of the model on the filesystem |
-| `precision` | torch.dtype | The torch.precision to use for inference |
-| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
+| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
+| `model` | AnyModel | The instantiated model (details below) |
+| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
-The types for `ModelInfo` and `SubModelType` can be imported from
-`invokeai.app.services.model_loader_service`.
+Because the loader can return multiple model types, it is typed to
+return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are obvious.
-To use the model, you use the `ModelInfo` as a context manager using
-the following pattern:
+`LoadedModel` acts as a context manager. The context loads the model
+into the execution device (e.g. VRAM on CUDA systems), locks the model
+in the execution device for the duration of the context, and returns
+the model. Use it like this:
 ```
-model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
 with model_info as vae:
     image = vae.decode(latents)[0]
 ```
-The `vae` model will stay locked in the GPU during the period of time
-it is in the context manager's scope.
-`get_model()` may raise any of the following exceptions:
+`get_model_by_key()` may raise any of the following exceptions:
-- `UnknownModelException` -- key not in database
-- `ModelNotFoundException` -- key in database but model not found at path
-- `InvalidModelException` -- the model is guilty of a variety of sins
+- `UnknownModelException` -- key not in database
+- `ModelNotFoundException` -- key in database but model not found at path
+- `NotImplementedException` -- the loader doesn't know how to load this type of model
-** TO DO: ** Resolve discrepancy between ModelInfo.location and
-ModelConfig.path.
+### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
+
+This is similar to `load_model_by_key`, but instead it accepts the
+combination of the model's name, type and base, which it passes to the
+model record config store for retrieval. If successful, this method
+returns a `LoadedModel`. It can raise the following exceptions:
+
+```
+UnknownModelException -- model with these attributes not known
+NotImplementedException -- the loader doesn't know how to load this type of model
+ValueError -- more than one model matches this combination of base/type/name
+```
+
+### load_model_by_config(config, [submodel], [context]) -> LoadedModel
+
+This method takes an `AnyModelConfig` returned by
+ModelRecordService.get_model() and returns the corresponding loaded
+model. It may raise a `NotImplementedException`.
 ### Emitting model loading events
-When the `context` argument is passed to `get_model()`, it will
+When the `context` argument is passed to `load_model_*()`, it will
 retrieve the invocation event bus from the passed `InvocationContext`
 object to emit events on the invocation bus. The two events are
 "model_load_started" and "model_load_completed". Both carry the
@@ -1563,3 +1556,97 @@ payload=dict(
 )
 ```
+
+### Adding Model Loaders
+
+Model loaders are small classes that inherit from the `ModelLoader`
+base class. They typically implement one method `_load_model()` whose
+signature is:
+
+```
+def _load_model(
+    self,
+    model_path: Path,
+    model_variant: Optional[ModelRepoVariant] = None,
+    submodel_type: Optional[SubModelType] = None,
+) -> AnyModel:
+```
+
+`_load_model()` will be passed the path to the model on disk, an
+optional repository variant (used by the diffusers loaders to select,
+e.g. the `fp16` variant, and an optional submodel_type for main and
+onnx models.
+
+To install a new loader, place it in
+`invokeai/backend/model_manager/load/model_loaders`. Inherit from
+`ModelLoader` and use the `@AnyModelLoader.register()` decorator to
+indicate what type of models the loader can handle.
+
+Here is a complete example from `generic_diffusers.py`, which is able
+to load several different diffusers types:
+
+```
+from pathlib import Path
+from typing import Optional
+
+from invokeai.backend.model_manager import (
+    AnyModel,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from ..load_base import AnyModelLoader
+from ..load_default import ModelLoader
+
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
+class GenericDiffusersLoader(ModelLoader):
+    """Class to load simple diffusers models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        model_class = self._get_hf_load_class(model_path)
+        if submodel_type is not None:
+            raise Exception(f"There are no submodels in models of type {model_class}")
+        variant = model_variant.value if model_variant else None
+        result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)  # type: ignore
+        return result
+```
+
+Note that a loader can register itself to handle several different
+model types. An exception will be raised if more than one loader tries
+to register the same model type.
+
+#### Conversion
+
+Some models require conversion to diffusers format before they can be
+loaded. These loaders should override two additional methods:
+
+```
+_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool
+_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
+```
+
+The first method accepts the model configuration, the path to where
+the unmodified model is currently installed, and a proposed
+destination for the converted model. This method returns True if the
+model needs to be converted. It typically does this by comparing the
+last modification time of the original model file to the modification
+time of the converted model. In some cases you will also want to check
+the modification date of the configuration record, in the event that
+the user has changed something like the scheduler prediction type that
+will require the model to be re-converted. See `controlnet.py` for an
+example of this logic.
+
+The second method accepts the model configuration, the path to the
+original model on disk, and the desired output path for the converted
+model. It does whatever it needs to do to get the model into diffusers
+format, and returns the Path of the resulting model. (The path should
+ordinarily be the same as `output_path`.)
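
Reviewer aside (not part of the file): to make the two conversion hooks concrete, here is a rough sketch of a loader that implements the mtime comparison described above. Only the `_needs_conversion()`/`_convert_model()` signatures, the `ModelLoader` base and the import paths follow the documentation; the class itself and the `convert_to_diffusers()` helper are hypothetical.

```
from pathlib import Path

from invokeai.backend.model_manager import AnyModelConfig

from ..load_default import ModelLoader


def convert_to_diffusers(model_path: Path, output_path: Path) -> None:
    """Hypothetical conversion routine; the real work would live here."""
    raise NotImplementedError


class SketchCheckpointLoader(ModelLoader):
    """Illustrative only: re-converts a checkpoint when the original is newer."""

    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
        # Convert if no diffusers copy exists yet, or if the original file has
        # been modified since the copy was made.
        if not dest_path.exists():
            return True
        return model_path.stat().st_mtime > dest_path.stat().st_mtime

    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
        convert_to_diffusers(model_path, output_path)
        return output_path
```

A real loader would also implement `_load_model()` and register itself with `@AnyModelLoader.register()` as shown in the example above.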
View File
@@ -4,9 +4,6 @@ from logging import Logger
 from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
-from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
-from invokeai.backend.model_manager.load.model_cache import ModelCache
-from invokeai.backend.model_manager.metadata import ModelMetadataStore
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__
@@ -27,9 +24,7 @@ from ..services.invocation_stats.invocation_stats_default import InvocationStats
 from ..services.invoker import Invoker
 from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
 from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
-from ..services.model_install import ModelInstallService
 from ..services.model_manager.model_manager_default import ModelManagerService
-from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
 from ..services.session_processor.session_processor_default import DefaultSessionProcessor
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -87,26 +82,10 @@ class ApiDependencies:
        images = ImageService()
        invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
-        model_loader = AnyModelLoader(
-            app_config=config,
-            logger=logger,
-            ram_cache=ModelCache(
-                max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
-            ),
-            convert_cache=ModelConvertCache(
-                cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
-            ),
-        )
-        model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
        download_queue_service = DownloadQueueService(event_bus=events)
-        model_install_service = ModelInstallService(
-            app_config=config,
-            record_store=model_record_service,
-            download_queue=download_queue_service,
-            metadata_store=ModelMetadataStore(db=db),
-            event_bus=events,
+        model_manager = ModelManagerService.build_model_manager(
+            app_config=configuration, db=db, download_queue=download_queue_service, events=events
        )
-        model_manager = ModelManagerService(config, logger)  # TO DO: legacy model manager v1. Remove
        names = SimpleNameService()
        performance_statistics = InvocationStatsService()
        processor = DefaultInvocationProcessor()
@@ -131,9 +110,7 @@ class ApiDependencies:
            latents=latents,
            logger=logger,
            model_manager=model_manager,
-            model_records=model_record_service,
            download_queue=download_queue_service,
-            model_install=model_install_service,
            names=names,
            performance_statistics=performance_statistics,
            processor=processor,
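
Reviewer note: with this change the separate `model_records` and `model_install` services are gone, and routers and invocations reach everything through the consolidated `model_manager` service. A sketch of the access pattern used throughout this diff; the wrapper function is illustrative, while the attribute and method names (`store`, `install`, `load`, `get_model`, `list_jobs`, `load_model_by_key`) all appear in the changed files.

```
def sketch_model_manager_access(services, key: str) -> None:
    # `services` stands for ApiDependencies.invoker.services or context.services.
    config = services.model_manager.store.get_model(key)         # record store
    jobs = services.model_manager.install.list_jobs()            # installer
    loaded = services.model_manager.load.load_model_by_key(key)  # loader
    print(config.name, len(jobs), type(loaded.model))
```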
View File
@@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
 from ..dependencies import ApiDependencies
-model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
+model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
 class ModelsList(BaseModel):
@@ -52,7 +52,7 @@ class ModelTagSet(BaseModel):
    tags: Set[str]
-@model_records_router.get(
+@model_manager_v2_router.get(
    "/",
    operation_id="list_model_records",
 )
@@ -65,7 +65,7 @@ async def list_model_records(
    ),
 ) -> ModelsList:
    """Get a list of models."""
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
    found_models: list[AnyModelConfig] = []
    if base_models:
        for base_model in base_models:
@@ -81,7 +81,7 @@ async def list_model_records(
    return ModelsList(models=found_models)
-@model_records_router.get(
+@model_manager_v2_router.get(
    "/i/{key}",
    operation_id="get_model_record",
    responses={
@@ -94,24 +94,27 @@ async def get_model_record(
    key: str = Path(description="Key of the model record to fetch."),
 ) -> AnyModelConfig:
    """Get a model record"""
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
    try:
-        return record_store.get_model(key)
+        config: AnyModelConfig = record_store.get_model(key)
+        return config
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))
-@model_records_router.get("/meta", operation_id="list_model_summary")
+@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
 async def list_model_summary(
    page: int = Query(default=0, description="The page to get"),
    per_page: int = Query(default=10, description="The number of models per page"),
    order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
 ) -> PaginatedResults[ModelSummary]:
    """Gets a page of model summary data."""
-    return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
+    record_store = ApiDependencies.invoker.services.model_manager.store
+    results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
+    return results
-@model_records_router.get(
+@model_manager_v2_router.get(
    "/meta/i/{key}",
    operation_id="get_model_metadata",
    responses={
@@ -124,24 +127,25 @@ async def get_model_metadata(
    key: str = Path(description="Key of the model repo metadata to fetch."),
 ) -> Optional[AnyModelRepoMetadata]:
    """Get a model metadata object."""
-    record_store = ApiDependencies.invoker.services.model_records
-    result = record_store.get_metadata(key)
+    record_store = ApiDependencies.invoker.services.model_manager.store
+    result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
    if not result:
        raise HTTPException(status_code=404, detail="No metadata for a model with this key")
    return result
-@model_records_router.get(
+@model_manager_v2_router.get(
    "/tags",
    operation_id="list_tags",
 )
 async def list_tags() -> Set[str]:
    """Get a unique set of all the model tags."""
-    record_store = ApiDependencies.invoker.services.model_records
-    return record_store.list_tags()
+    record_store = ApiDependencies.invoker.services.model_manager.store
+    result: Set[str] = record_store.list_tags()
+    return result
-@model_records_router.get(
+@model_manager_v2_router.get(
    "/tags/search",
    operation_id="search_by_metadata_tags",
 )
@@ -149,12 +153,12 @@ async def search_by_metadata_tags(
    tags: Set[str] = Query(default=None, description="Tags to search for"),
 ) -> ModelsList:
    """Get a list of models."""
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
    results = record_store.search_by_metadata_tag(tags)
    return ModelsList(models=results)
-@model_records_router.patch(
+@model_manager_v2_router.patch(
    "/i/{key}",
    operation_id="update_model_record",
    responses={
@@ -172,9 +176,9 @@ async def update_model_record(
 ) -> AnyModelConfig:
    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
    logger = ApiDependencies.invoker.services.logger
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
    try:
-        model_response = record_store.update_model(key, config=info)
+        model_response: AnyModelConfig = record_store.update_model(key, config=info)
        logger.info(f"Updated model: {key}")
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))
@@ -184,7 +188,7 @@ async def update_model_record(
    return model_response
-@model_records_router.delete(
+@model_manager_v2_router.delete(
    "/i/{key}",
    operation_id="del_model_record",
    responses={
@@ -205,7 +209,7 @@ async def del_model_record(
    logger = ApiDependencies.invoker.services.logger
    try:
-        installer = ApiDependencies.invoker.services.model_install
+        installer = ApiDependencies.invoker.services.model_manager.install
        installer.delete(key)
        logger.info(f"Deleted model: {key}")
        return Response(status_code=204)
@@ -214,7 +218,7 @@ async def del_model_record(
        raise HTTPException(status_code=404, detail=str(e))
-@model_records_router.post(
+@model_manager_v2_router.post(
    "/i/",
    operation_id="add_model_record",
    responses={
@@ -229,7 +233,7 @@ async def add_model_record(
 ) -> AnyModelConfig:
    """Add a model using the configuration information appropriate for its type."""
    logger = ApiDependencies.invoker.services.logger
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
    if config.key == "<NOKEY>":
        config.key = sha1(randbytes(100)).hexdigest()
        logger.info(f"Created model {config.key} for {config.name}")
@@ -243,10 +247,11 @@ async def add_model_record(
        raise HTTPException(status_code=415)
    # now fetch it out
-    return record_store.get_model(config.key)
+    result: AnyModelConfig = record_store.get_model(config.key)
+    return result
-@model_records_router.post(
+@model_manager_v2_router.post(
    "/import",
    operation_id="import_model_record",
    responses={
@@ -322,7 +327,7 @@ async def import_model(
    logger = ApiDependencies.invoker.services.logger
    try:
-        installer = ApiDependencies.invoker.services.model_install
+        installer = ApiDependencies.invoker.services.model_manager.install
        result: ModelInstallJob = installer.import_model(
            source=source,
            config=config,
@@ -340,17 +345,17 @@ async def import_model(
    return result
-@model_records_router.get(
+@model_manager_v2_router.get(
    "/import",
    operation_id="list_model_install_jobs",
 )
 async def list_model_install_jobs() -> List[ModelInstallJob]:
    """Return list of model install jobs."""
-    jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
+    jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
    return jobs
-@model_records_router.get(
+@model_manager_v2_router.get(
    "/import/{id}",
    operation_id="get_model_install_job",
    responses={
@@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
 async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
    """Return model install job corresponding to the given source."""
    try:
-        return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
+        result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
+        return result
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
-@model_records_router.delete(
+@model_manager_v2_router.delete(
    "/import/{id}",
    operation_id="cancel_model_install_job",
    responses={
@@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
 )
 async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
    """Cancel the model install job(s) corresponding to the given job ID."""
-    installer = ApiDependencies.invoker.services.model_install
+    installer = ApiDependencies.invoker.services.model_manager.install
    try:
        job = installer.get_job_by_id(id)
    except ValueError as e:
@@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
    installer.cancel_job(job)
-@model_records_router.patch(
+@model_manager_v2_router.patch(
    "/import",
    operation_id="prune_model_install_jobs",
    responses={
@@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
 )
 async def prune_model_install_jobs() -> Response:
    """Prune all completed and errored jobs from the install job list."""
-    ApiDependencies.invoker.services.model_install.prune_jobs()
+    ApiDependencies.invoker.services.model_manager.install.prune_jobs()
    return Response(status_code=204)
-@model_records_router.patch(
+@model_manager_v2_router.patch(
    "/sync",
    operation_id="sync_models_to_config",
    responses={
@@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response:
    Model files without a corresponding
    record in the database are added. Orphan records without a models file are deleted.
    """
-    ApiDependencies.invoker.services.model_install.sync_to_config()
+    ApiDependencies.invoker.services.model_manager.install.sync_to_config()
    return Response(status_code=204)
-@model_records_router.put(
+@model_manager_v2_router.put(
    "/merge",
    operation_id="merge",
 )
@@ -451,7 +457,7 @@ async def merge(
    try:
        logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
-        installer = ApiDependencies.invoker.services.model_install
+        installer = ApiDependencies.invoker.services.model_manager.install
        merger = ModelMerger(installer)
        model_names = [installer.record_store.get_model(x).name for x in keys]
        response = merger.merge_diffusion_models_and_save(
View File
@@ -8,8 +8,7 @@ from fastapi.routing import APIRouter
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from starlette.exceptions import HTTPException
-from invokeai.backend import BaseModelType, ModelType
-from invokeai.backend.model_management import MergeInterpolationMethod
+from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
 from invokeai.backend.model_management.models import (
    OPENAPI_MODEL_CONFIGS,
    InvalidModelException,
View File
@@ -47,7 +47,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
    boards,
    download_queue,
    images,
-    model_records,
+    model_manager_v2,
    models,
    session_queue,
    sessions,
@@ -115,8 +115,7 @@ async def shutdown_event() -> None:
 app.include_router(sessions.session_router, prefix="/api")
 app.include_router(utilities.utilities_router, prefix="/api")
-app.include_router(models.models_router, prefix="/api")
-app.include_router(model_records.model_records_router, prefix="/api")
+app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api")
 app.include_router(download_queue.download_queue_router, prefix="/api")
 app.include_router(images.images_router, prefix="/api")
 app.include_router(boards.boards_router, prefix="/api")
View File
@@ -4,6 +4,7 @@ from typing import Iterator, List, Optional, Tuple, Union
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
+from transformers import CLIPTokenizer
 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
@@ -70,18 +71,18 @@ class CompelInvocation(BaseInvocation):
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_records.load_model(
+        tokenizer_info = context.services.model_manager.load.load_model_by_key(
            **self.clip.tokenizer.model_dump(),
            context=context,
        )
-        text_encoder_info = context.services.model_records.load_model(
+        text_encoder_info = context.services.model_manager.load.load_model_by_key(
            **self.clip.text_encoder.model_dump(),
            context=context,
        )
        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            for lora in self.clip.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                    **lora.model_dump(exclude={"weight"}), context=context
                )
                assert isinstance(lora_info.model, LoRAModelRaw)
@@ -95,7 +96,7 @@ class CompelInvocation(BaseInvocation):
        for trigger in extract_ti_triggers_from_prompt(self.prompt):
            name = trigger[1:-1]
            try:
-                loaded_model = context.services.model_records.load_model(
+                loaded_model = context.services.model_manager.load.load_model_by_key(
                    **self.clip.text_encoder.model_dump(),
                    context=context,
                ).model
@@ -171,11 +172,11 @@ class SDXLPromptInvocationBase:
        lora_prefix: str,
        zero_on_empty: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.services.model_records.load_model(
+        tokenizer_info = context.services.model_manager.load.load_model_by_key(
            **clip_field.tokenizer.model_dump(),
            context=context,
        )
-        text_encoder_info = context.services.model_records.load_model(
+        text_encoder_info = context.services.model_manager.load.load_model_by_key(
            **clip_field.text_encoder.model_dump(),
            context=context,
        )
@@ -203,7 +204,7 @@ class SDXLPromptInvocationBase:
        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            for lora in clip_field.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                    **lora.model_dump(exclude={"weight"}), context=context
                )
                lora_model = lora_info.model
@@ -218,7 +219,7 @@ class SDXLPromptInvocationBase:
        for trigger in extract_ti_triggers_from_prompt(prompt):
            name = trigger[1:-1]
            try:
-                ti_model = context.services.model_records.load_model_by_attr(
+                ti_model = context.services.model_manager.load.load_model_by_attr(
                    model_name=name,
                    base_model=text_encoder_info.config.base,
                    model_type=ModelType.TextualInversion,
@@ -465,9 +466,9 @@ class ClipSkipInvocation(BaseInvocation):
 def get_max_token_count(
-    tokenizer,
+    tokenizer: CLIPTokenizer,
    prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long=False,
+    truncate_if_too_long: bool = False,
 ) -> int:
    if type(prompt) is Blend:
        blend: Blend = prompt
@@ -479,7 +480,9 @@ def get_max_token_count(
    return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
-def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
+def get_tokens_for_prompt_object(
+    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
+) -> List[str]:
    if type(parsed_prompt) is Blend:
        raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@@ -492,24 +495,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
        for x in parsed_prompt.children
    ]
    text = " ".join(text_fragments)
-    tokens = tokenizer.tokenize(text)
+    tokens: List[str] = tokenizer.tokenize(text)
    if truncate_if_too_long:
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[0:max_tokens_length]
    return tokens
-def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
    display_label_prefix = display_label_prefix or ""
    for i, p in enumerate(c.prompts):
        if len(c.prompts) > 1:
            this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
        else:
+            assert display_label_prefix is not None
            this_display_label_prefix = display_label_prefix
        log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
-def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
+def log_tokenization_for_prompt_object(
+    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
    display_label_prefix = display_label_prefix or ""
    if type(p) is Blend:
        blend: Blend = p
@@ -549,7 +557,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
            log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
+def log_tokenization_for_text(
+    text: str,
+    tokenizer: CLIPTokenizer,
+    display_label: Optional[str] = None,
+    truncate_if_too_long: Optional[bool] = False,
+) -> None:
    """shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
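
Reviewer note: the TI-trigger path above now goes through `load_model_by_attr()`. Outside of an invocation, the same lookup can be made against a standalone `ModelLoadService`, following the initialization pattern from the updated MODEL_MANAGER.md. This is a sketch with assumptions: the embedding name is hypothetical, the `ModelLoadService(...)` constructor call is schematic (only its two inputs are documented), and a populated model database is assumed.

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_load import ModelLoadService
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import BaseModelType, ModelType

config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)
loader = ModelLoadService(config, store)  # constructor shown schematically

# Same lookup the TI-trigger code performs: name + base + type instead of a key.
ti_model = loader.load_model_by_attr(
    model_name="my-embedding",  # hypothetical embedding name
    base_model=BaseModelType.StableDiffusion1,
    model_type=ModelType.TextualInversion,
)
with ti_model as embedding:
    # The context manager locks the model on the execution device while in scope.
    print(type(embedding))
```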
View File
@ -3,13 +3,15 @@
import math import math
from contextlib import ExitStack from contextlib import ExitStack
from functools import singledispatchmethod from functools import singledispatchmethod
from typing import Iterator, List, Literal, Optional, Tuple, Union from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
import einops import einops
import numpy as np import numpy as np
import numpy.typing as npt
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter from diffusers.models.adapter import T2IAdapter
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
@ -18,8 +20,10 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image
from pydantic import field_validator from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
@ -38,9 +42,10 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.embeddings.model_patcher import ModelPatcher
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_manager import AnyModel, BaseModelType from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -76,7 +81,9 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device()) DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] # FIXME: "Invalid type alias" SAMPLER_NAME_VALUES = Literal[
tuple(SCHEDULER_MAP.keys())
] # FIXME: "Invalid type alias". This defeats static type checking.
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to # HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale # be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
@ -130,10 +137,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
ui_order=4, ui_order=4,
) )
def prep_mask_tensor(self, mask_image): def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
if mask_image.mode != "L": if mask_image.mode != "L":
mask_image = mask_image.convert("L") mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3: if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0) mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None: # if shape is not None:
@ -144,24 +151,24 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None: if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image.dim() == 3: if image_tensor.dim() == 3:
image = image.unsqueeze(0) image_tensor = image_tensor.unsqueeze(0)
else: else:
image = None image_tensor = None
mask = self.prep_mask_tensor( mask = self.prep_mask_tensor(
context.services.images.get_pil_image(self.mask.image_name), context.services.images.get_pil_image(self.mask.image_name),
) )
if image is not None: if image_tensor is not None:
vae_info = context.services.model_records.load_model( vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(), **self.vae.vae.model_dump(),
context=context, context=context,
) )
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO: # TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
@ -188,7 +195,7 @@ def get_scheduler(
seed: int, seed: int,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.services.model_records.load_model( orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
**scheduler_info.model_dump(), **scheduler_info.model_dump(),
context=context, context=context,
) )
@ -199,7 +206,7 @@ def get_scheduler(
scheduler_config = scheduler_config["_backup"] scheduler_config = scheduler_config["_backup"]
scheduler_config = { scheduler_config = {
**scheduler_config, **scheduler_config,
**scheduler_extra_config, **scheduler_extra_config, # FIXME
"_backup": scheduler_config, "_backup": scheduler_config,
} }
@ -212,6 +219,7 @@ def get_scheduler(
# hack copied over from generate.py # hack copied over from generate.py
if not hasattr(scheduler, "uses_inpainting_model"): if not hasattr(scheduler, "uses_inpainting_model"):
scheduler.uses_inpainting_model = lambda: False scheduler.uses_inpainting_model = lambda: False
assert isinstance(scheduler, Scheduler)
return scheduler return scheduler
@ -295,7 +303,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
@field_validator("cfg_scale") @field_validator("cfg_scale")
def ge_one(cls, v): def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1""" """validate that all cfg_scale values are >= 1"""
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
@ -325,9 +333,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
def get_conditioning_data( def get_conditioning_data(
self, self,
context: InvocationContext, context: InvocationContext,
scheduler, scheduler: Scheduler,
unet, unet: UNet2DConditionModel,
seed, seed: int,
) -> ConditioningData: ) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -350,7 +358,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
), ),
) )
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
scheduler, scheduler,
# for ddim scheduler # for ddim scheduler
eta=0.0, # ddim_eta eta=0.0, # ddim_eta
@ -362,8 +370,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
def create_pipeline( def create_pipeline(
self, self,
unet, unet: UNet2DConditionModel,
scheduler, scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline: ) -> StableDiffusionGeneratorPipeline:
# TODO: # TODO:
# configure_model_padding( # configure_model_padding(
@ -374,10 +382,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
class FakeVae: class FakeVae:
class FakeVaeConfig: class FakeVaeConfig:
def __init__(self): def __init__(self) -> None:
self.block_out_channels = [0] self.block_out_channels = [0]
def __init__(self): def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig() self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline( return StableDiffusionGeneratorPipeline(
@ -394,11 +402,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data( def prep_control_data(
self, self,
context: InvocationContext, context: InvocationContext,
control_input: Union[ControlField, List[ControlField]], control_input: Optional[Union[ControlField, List[ControlField]]],
latents_shape: List[int], latents_shape: List[int],
exit_stack: ExitStack, exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]: ) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
@ -421,7 +429,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data = [] controlnet_data = []
for control_info in control_list: for control_info in control_list:
control_model = exit_stack.enter_context( control_model = exit_stack.enter_context(
context.services.model_records.load_model( context.services.model_manager.load.load_model_by_key(
key=control_info.control_model.key, key=control_info.control_model.key,
context=context, context=context,
) )
@ -487,23 +495,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = [] conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter: for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_records.load_model( context.services.model_manager.load.load_model_by_key(
key=single_ip_adapter.ip_adapter_model.key, key=single_ip_adapter.ip_adapter_model.key,
context=context, context=context,
) )
) )
image_encoder_model_info = context.services.model_records.load_model( image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
key=single_ip_adapter.image_encoder_model.key, key=single_ip_adapter.image_encoder_model.key,
context=context, context=context,
) )
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_images = single_ip_adapter.image single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_images, list): if not isinstance(single_ipa_image_fields, list):
single_ipa_images = [single_ipa_images] single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] single_ipa_images = [
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
]
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
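For reviewers unfamiliar with the pattern: `load_model_by_key()` returns a `LoadedModel` whose context manager locks the model into the RAM/VRAM cache, and entering it on the shared `ExitStack` keeps every ControlNet / IP-Adapter model locked until denoising completes. A rough sketch of the idea (the keys are hypothetical):

```
from contextlib import ExitStack

# Sketch: keep several loaded models locked for the duration of one denoise call.
with ExitStack() as exit_stack:
    control_model = exit_stack.enter_context(
        context.services.model_manager.load.load_model_by_key(key=control_key, context=context)
    )
    ip_adapter_model = exit_stack.enter_context(
        context.services.model_manager.load.load_model_by_key(key=ip_adapter_key, context=context)
    )
    # ... denoise while both models remain locked in the cache ...
# leaving the ExitStack unlocks every model entered on it
```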
@ -547,21 +557,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = [] t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter: for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_records.load_model( t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
key=t2i_adapter_field.t2i_adapter_model.key, key=t2i_adapter_field.t2i_adapter_model.key,
context=context, context=context,
) )
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: if t2i_adapter_model_info.config.base == BaseModelType.StableDiffusion1:
max_unet_downscale = 8 max_unet_downscale = 8
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL: elif t2i_adapter_model_info.config.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4 max_unet_downscale = 4
else: else:
raise ValueError( raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.config.base}'.")
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
)
t2i_adapter_model: T2IAdapter t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model: with t2i_adapter_model_info as t2i_adapter_model:
@ -609,7 +617,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter # original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps # TODO: research more for second order schedulers timesteps
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False): if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu") scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device) timesteps = scheduler.timesteps.to(device=device)
@ -621,11 +637,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
_timesteps = timesteps[:: scheduler.order] _timesteps = timesteps[:: scheduler.order]
# get start timestep index # get start timestep index
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start))) t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
# get end timestep index # get end timestep index
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end))) t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
# apply order to indexes # apply order to indexes
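A quick worked example of the index math above, assuming `num_train_timesteps` is 1000 (illustrative values; the actual timesteps depend on the scheduler):

```
# denoising_start = 0.3, denoising_end = 1.0, num_train_timesteps = 1000
t_start_val = int(round(1000 * (1 - 0.3)))   # 700
# t_start_idx counts the timesteps >= 700, i.e. the early steps skipped when
# resuming denoising at 30% (image-to-image / refiner hand-off).
t_end_val = int(round(1000 * (1 - 1.0)))     # 0 -> every remaining timestep is kept
```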
@ -638,7 +654,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context, latents): def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.denoise_mask is None: if self.denoise_mask is None:
return None, None return None, None
@ -691,12 +709,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): # get the unet's config so that we can pass the base to dispatch_progress()
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
def _lora_loader() -> Iterator[Tuple[AnyModel, float]]: def step_callback(state: PipelineIntermediateState) -> None:
self.dispatch_progress(context, source_node_id, state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_records.load_model( lora_info = context.services.model_manager.load.load_model_by_key(
**lora.model_dump(exclude={"weight"}), **lora.model_dump(exclude={"weight"}),
context=context, context=context,
) )
@ -704,7 +725,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
del lora_info del lora_info
return return
unet_info = context.services.model_records.load_model( unet_info = context.services.model_manager.load.load_model_by_key(
**self.unet.unet.model_dump(), **self.unet.unet.model_dump(),
context=context, context=context,
) )
@ -815,7 +836,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_records.load_model( vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(), **self.vae.vae.model_dump(),
context=context, context=context,
) )
@ -1010,8 +1031,9 @@ class ImageToLatentsInvocation(BaseInvocation):
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
@staticmethod @staticmethod
def vae_encode(vae_info, upcast, tiled, image_tensor): def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae: with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype orig_dtype = vae.dtype
if upcast: if upcast:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
@ -1057,7 +1079,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
vae_info = context.services.model_records.load_model( vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(), **self.vae.vae.model_dump(),
context=context, context=context,
) )
@ -1076,14 +1098,19 @@ class ImageToLatentsInvocation(BaseInvocation):
@singledispatchmethod @singledispatchmethod
@staticmethod @staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor: def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible! latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents return latents
@_encode_to_tensor.register @_encode_to_tensor.register
@staticmethod @staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor: def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents
@invocation( @invocation(
@ -1116,7 +1143,12 @@ class BlendLatentsInvocation(BaseInvocation):
# TODO: # TODO:
device = choose_torch_device() device = choose_torch_device()
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
""" """
Spherical linear interpolation Spherical linear interpolation
Args: Args:
@ -1149,12 +1181,16 @@ class BlendLatentsInvocation(BaseInvocation):
v2 = s0 * v0 + s1 * v1 v2 = s0 * v0 + s1 * v1
if inputs_are_torch: if inputs_are_torch:
v2 = torch.from_numpy(v2).to(device) v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
return v2 else:
assert isinstance(v2, np.ndarray)
return v2
# blend # blend
blended_latents = slerp(self.alpha, latents_a, latents_b) bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu") blended_latents = blended_latents.to("cpu")
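For reference, `slerp()` is standard spherical linear interpolation: nearly-parallel inputs (dot product above `DOT_THRESHOLD`) fall back to plain lerp, otherwise the two latents are blended with sine weights, which is where the `s0 * v0 + s1 * v1` line above comes from. A compact NumPy-only restatement of the math (a sketch, not the invocation's exact code):

```
import numpy as np

def slerp_np(t: float, v0: np.ndarray, v1: np.ndarray, dot_threshold: float = 0.9995) -> np.ndarray:
    """Spherical linear interpolation between two arrays (illustrative sketch)."""
    dot = np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    if np.abs(dot) > dot_threshold:
        return (1 - t) * v0 + t * v1              # nearly colinear: ordinary lerp
    theta_0 = np.arccos(dot)                      # angle between the two inputs
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1
```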
@ -1250,15 +1286,19 @@ class IdealSizeInvocation(BaseInvocation):
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)", description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
) )
def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR): def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args) return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput: def invoke(self, context: InvocationContext) -> IdealSizeOutput:
# look up the unet's config record so that we can read its base type
unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
aspect = self.width / self.height aspect = self.width / self.height
dimension = 512 dimension: float = 512
if self.unet.unet.base_model == BaseModelType.StableDiffusion2: if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768 dimension = 768
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL: elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024 dimension = 1024
dimension = dimension * self.multiplier dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5) min_dimension = math.floor(dimension * 0.5)
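A worked example of the new config-based sizing, assuming an SDXL base, `multiplier = 1.0`, and a 1536x640 request (the rest of the width/height derivation is outside this hunk):

```
# Illustrative numbers only.
aspect = 1536 / 640                            # 2.4
dimension = 1024 * 1.0                         # SDXL base dimension * multiplier
min_dimension = math.floor(dimension * 0.5)    # 512
# trim_to_multiple_of() then snaps candidate sizes down to multiples of
# LATENT_SCALE_FACTOR (8), e.g. trim_to_multiple_of(1236, 514) -> (1232, 512)
```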

View File

@ -20,7 +20,7 @@ from .baseinvocation import (
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
key: str = Field(description="Info to load submodel") key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")

View File

@ -15,8 +15,8 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import ModelType, SubModelType
from invokeai.backend.embeddings.model_patcher import ONNXModelPatcher from invokeai.backend.embeddings.model_patcher import ONNXModelPatcher
from invokeai.backend.model_manager import ModelType, SubModelType
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device from ...backend.util import choose_torch_device
@ -62,16 +62,16 @@ class ONNXPromptInvocation(BaseInvocation):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_records.load_model( tokenizer_info = context.services.model_manager.load.load_model_by_key(
**self.clip.tokenizer.model_dump(), **self.clip.tokenizer.model_dump(),
) )
text_encoder_info = context.services.model_records.load_model( text_encoder_info = context.services.model_manager.load.load_model_by_key(
**self.clip.text_encoder.model_dump(), **self.clip.text_encoder.model_dump(),
) )
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
loras = [ loras = [
( (
context.services.model_records.load_model(**lora.model_dump(exclude={"weight"})).model, context.services.model_manager.load.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
lora.weight, lora.weight,
) )
for lora in self.clip.loras for lora in self.clip.loras
@ -84,7 +84,7 @@ class ONNXPromptInvocation(BaseInvocation):
ti_list.append( ti_list.append(
( (
name, name,
context.services.model_records.load_model_by_attr( context.services.model_manager.load.load_model_by_attr(
model_name=name, model_name=name,
base_model=text_encoder_info.config.base, base_model=text_encoder_info.config.base,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
@ -257,13 +257,13 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
eta=0.0, eta=0.0,
) )
unet_info = context.services.model_records.load_model(**self.unet.unet.model_dump()) unet_info = context.services.model_manager.load.load_model_by_key(**self.unet.unet.model_dump())
with unet_info as unet: # , ExitStack() as stack: with unet_info as unet: # , ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [ loras = [
( (
context.services.model_records.load_model(**lora.model_dump(exclude={"weight"})).model, context.services.model_manager.load.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
lora.weight, lora.weight,
) )
for lora in self.unet.loras for lora in self.unet.loras
@ -346,7 +346,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
if self.vae.vae.submodel != SubModelType.VaeDecoder: if self.vae.vae.submodel != SubModelType.VaeDecoder:
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}") raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")
vae_info = context.services.model_records.load_model( vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(), **self.vae.vae.model_dump(),
) )

View File

@ -368,7 +368,7 @@ class LatentsCollectionInvocation(BaseInvocation):
return LatentsCollectionOutput(collection=self.collection) return LatentsCollectionOutput(collection=self.collection)
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput:
return LatentsOutput( return LatentsOutput(
latents=LatentsField(latents_name=latents_name, seed=seed), latents=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8, width=latents.size()[3] * 8,

View File

@ -22,9 +22,7 @@ if TYPE_CHECKING:
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase from .session_queue.session_queue_base import SessionQueueBase
@ -50,9 +48,7 @@ class InvocationServices:
latents: "LatentsStorageBase" latents: "LatentsStorageBase"
logger: "Logger" logger: "Logger"
model_manager: "ModelManagerServiceBase" model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
download_queue: "DownloadQueueServiceBase" download_queue: "DownloadQueueServiceBase"
model_install: "ModelInstallServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase" performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC" queue: "InvocationQueueABC"
@ -78,9 +74,7 @@ class InvocationServices:
latents: "LatentsStorageBase", latents: "LatentsStorageBase",
logger: "Logger", logger: "Logger",
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase", download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase", performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC", queue: "InvocationQueueABC",
@ -104,9 +98,7 @@ class InvocationServices:
self.latents = latents self.latents = latents
self.logger = logger self.logger = logger
self.model_manager = model_manager self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue self.download_queue = download_queue
self.model_install = model_install
self.processor = processor self.processor = processor
self.performance_statistics = performance_statistics self.performance_statistics = performance_statistics
self.queue = queue self.queue = queue

View File

@ -43,8 +43,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
@contextmanager @contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
# This is to handle the case of the model manager not being initialized, which happens
# during some tests.
services = self._invoker.services services = self._invoker.services
if services.model_records is None or services.model_records.loader is None: if services.model_manager is None or services.model_manager.load is None:
yield None yield None
if not self._stats.get(graph_execution_state_id): if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id. # First time we're seeing this graph_execution_state_id.
@ -60,9 +62,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
# TO DO [LS]: clean up loader service - shouldn't be an attribute of model records assert services.model_manager.load is not None
assert services.model_records.loader is not None services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id]
try: try:
# Let the invocation run. # Let the invocation run.

View File

@ -5,7 +5,7 @@ from typing import Callable, Union
import torch import torch
from ..compel import ConditioningFieldData from invokeai.app.invocations.compel import ConditioningFieldData
class LatentsStorageBase(ABC): class LatentsStorageBase(ABC):

View File

@ -5,9 +5,9 @@ from typing import Union
import torch import torch
from invokeai.app.invocations.compel import ConditioningFieldData
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from ..compel import ConditioningFieldData
from .latents_storage_base import LatentsStorageBase from .latents_storage_base import LatentsStorageBase

View File

@ -5,9 +5,9 @@ from typing import Dict, Optional, Union
import torch import torch
from invokeai.app.invocations.compel import ConditioningFieldData
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from ..compel import ConditioningFieldData
from .latents_storage_base import LatentsStorageBase from .latents_storage_base import LatentsStorageBase

View File

@ -4,7 +4,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from invokeai.backend.model_manager import AnyModelConfig, SubModelType from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader.""" """Wrapper around AnyModelLoader."""
@abstractmethod @abstractmethod
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model_by_key(
"""Given a model's key, load it and return the LoadedModel object.""" self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline) models, the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass pass
@abstractmethod @abstractmethod
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model_by_config(
"""Given a model's configuration, load it and return the LoadedModel object.""" self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordServiceBase.get_model())
:param submodel_type: For main (pipeline) models, the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.
:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline) models, the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""

View File

@ -3,12 +3,14 @@
from typing import Optional from typing import Optional
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.backend.model_manager import AnyModelConfig, SubModelType from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase from .model_load_base import ModelLoadServiceBase
@ -21,7 +23,7 @@ class ModelLoadService(ModelLoadServiceBase):
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
ram_cache: Optional[ModelCacheBase] = None, ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
convert_cache: Optional[ModelConvertCacheBase] = None, convert_cache: Optional[ModelConvertCacheBase] = None,
): ):
"""Initialize the model load service.""" """Initialize the model load service."""
@ -44,11 +46,104 @@ class ModelLoadService(ModelLoadServiceBase):
), ),
) )
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: def load_model_by_key(
"""Given a model's key, load it and return the LoadedModel object.""" self,
config = self._store.get_model(key) key: str,
return self.load_model_by_config(config, submodel_type) submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: :param key: Key of model config to be fetched.
"""Given a model's configuration, load it and return the LoadedModel object.""" :param submodel: For main (pipeline models), the submodel to fetch.
return self._any_loader.load_model(config, submodel_type) :param context: Invocation context used for event reporting
"""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type, context)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.
:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline) models, the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self._store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model_by_key(configs[0].key, submodel)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordServiceBase.get_model())
:param submodel_type: For main (pipeline) models, the submodel to fetch.
:param context: Invocation context used for event reporting
"""
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._any_loader.load_model(model_config, submodel_type)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model
def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
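And the attribute-based lookup, which mirrors the legacy `get_model()` call; the two documented failure modes can be handled like this (a sketch; the model name is hypothetical):

```
# Sketch: look a model up by name/base/type instead of by key.
try:
    ti_info = context.services.model_manager.load.load_model_by_attr(
        model_name="easynegative",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.TextualInversion,
        context=context,
    )
except UnknownModelException:
    ti_info = None      # no record with these attributes
except ValueError:
    raise               # more than one record matched; the caller must disambiguate
```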

View File

@ -2,24 +2,26 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from ..config import InvokeAIAppConfig from ..config import InvokeAIAppConfig
from ..events.events_base import EventServiceBase
from ..download import DownloadQueueServiceBase from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallServiceBase from ..model_install import ModelInstallServiceBase
from ..model_load import ModelLoadServiceBase from ..model_load import ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase from ..model_records import ModelRecordServiceBase
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
class ModelManagerServiceBase(BaseModel, ABC): class ModelManagerServiceBase(ABC):
"""Abstract base class for the model manager service.""" """Abstract base class for the model manager service."""
store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") # attributes:
install: ModelInstallServiceBase = Field(description="An instance of the model install service.") # store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
load: ModelLoadServiceBase = Field(description="An instance of the model load service.") # install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
@classmethod @classmethod
@abstractmethod @abstractmethod
@ -37,3 +39,29 @@ class ModelManagerServiceBase(BaseModel, ABC):
method simplifies the construction considerably. method simplifies the construction considerably.
""" """
pass pass
@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
pass
@property
@abstractmethod
def load(self) -> ModelLoadServiceBase:
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
pass
@property
@abstractmethod
def install(self) -> ModelInstallServiceBase:
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
pass
@abstractmethod
def start(self, invoker: Invoker) -> None:
pass
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass

View File

@ -3,6 +3,7 @@
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -10,9 +11,9 @@ from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceSQL from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_manager_base import ModelManagerServiceBase from .model_manager_base import ModelManagerServiceBase
@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase):
model_manager.load -- Routines to load models into memory. model_manager.load -- Routines to load models into memory.
""" """
def __init__(
self,
store: ModelRecordServiceBase,
install: ModelInstallServiceBase,
load: ModelLoadServiceBase,
):
self._store = store
self._install = install
self._load = load
@property
def store(self) -> ModelRecordServiceBase:
return self._store
@property
def install(self) -> ModelInstallServiceBase:
return self._install
@property
def load(self) -> ModelLoadServiceBase:
return self._load
def start(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "start"):
service.start(invoker)
def stop(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "stop"):
service.stop(invoker)
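Wiring sketch for the new facade: the three sub-services are constructed elsewhere (normally by `build_model_manager()`) and simply composed here, with `start()`/`stop()` forwarded to whichever of them define those hooks. Illustrative usage, assuming the sub-services already exist:

```
# Sketch: composing and using the facade (record_store/installer/loader assumed built already).
model_manager = ModelManagerService(store=record_store, install=installer, load=loader)
model_manager.start(invoker)                               # forwarded to sub-services

config = model_manager.store.get_model(key)                # record lookup
loaded = model_manager.load.load_model_by_config(config)   # bring it into RAM/VRAM
model_manager.stop(invoker)
```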
@classmethod @classmethod
def build_model_manager( def build_model_manager(
cls, cls,

View File

@ -10,15 +10,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
LoadedModel,
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType,
) )
from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -111,52 +108,6 @@ class ModelRecordServiceBase(ABC):
""" """
pass pass
@abstractmethod
def load_model(
self,
key: str,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch
:param context: Invocation context, used for event issuing.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@property @property
@abstractmethod @abstractmethod
def metadata_store(self) -> ModelMetadataStore: def metadata_store(self) -> ModelMetadataStore:

View File

@ -46,8 +46,6 @@ from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Dict, List, Optional, Set, Tuple, Union
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
@ -55,9 +53,8 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory, ModelConfigFactory,
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType,
) )
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
@ -220,74 +217,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model return model
def load_model(
self,
key: str,
submodel: Optional[SubModelType],
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
if not self._loader:
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
# we can emit model loading events if we are executing with access to the invocation context
model_config = self.get_model(key)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._loader.load_model(model_config, submodel)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model(configs[0].key, submodel)
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
""" """
Return True if a model with the indicated key exists in the database. Return True if a model with the indicated key exists in the database.
@ -476,29 +405,3 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return PaginatedResults( return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
) )
def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)

View File

@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import
class Migration6Callback: class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None: def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor) self._recreate_model_triggers(cursor)
self._delete_ip_adapters(cursor)
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
""" """
@ -26,6 +27,22 @@ class Migration6Callback:
""" """
) )
def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
"""
Delete all the IP adapters.
The model manager will automatically find and re-add them after the migration
is done. This allows the manager to add the correct image encoder to their
configuration records.
"""
cursor.execute(
"""--sql
DELETE FROM model_config
WHERE type='ip_adapter';
"""
)
def build_migration_6() -> Migration: def build_migration_6() -> Migration:
""" """
@ -33,6 +50,8 @@ def build_migration_6() -> Migration:
This migration does the following: This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist - Adds the model_config_updated_at trigger if it does not exist
- Delete all ip_adapter models so that the model prober can find them and
re-add them with the correct image encoder model.
""" """
migration_6 = Migration( migration_6 = Migration(
from_version=5, from_version=5,

View File

@ -64,7 +64,7 @@ class ModelPatcher:
def apply_lora_unet( def apply_lora_unet(
cls, cls,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: ) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"): with cls.apply_lora(unet, loras, "lora_unet_"):
yield yield
@ -307,7 +307,7 @@ class ONNXModelPatcher:
def apply_lora_unet( def apply_lora_unet(
cls, cls,
unet: OnnxRuntimeModel, unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None: ) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"): with cls.apply_lora(unet, loras, "lora_unet_"):
yield yield
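The `List` → `Iterator` change matches the `_lora_loader()` generator added in latents.py earlier in this diff: LoRAs are now loaded lazily as the patcher walks the iterator, so only one needs to be resident at a time. Calling pattern (a sketch, assuming an invocation context):

```
# The loras argument may now be any iterator of (LoRAModelRaw, weight) pairs,
# e.g. the _lora_loader() generator shown above, rather than a pre-built list.
with ModelPatcher.apply_lora_unet(unet, _lora_loader()):
    ...  # denoise while the LoRA weights are patched into the UNet
```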

View File

@ -8,8 +8,8 @@ from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import SilenceWarnings
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.silence_warnings import SilenceWarnings
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()

View File

@ -8,7 +8,6 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from .resampler import Resampler from .resampler import Resampler
@ -124,6 +123,9 @@ class IPAdapter:
self.attn_weights.to(device=self.device, dtype=self.dtype) self.attn_weights.to(device=self.device, dtype=self.dtype)
def calc_size(self): def calc_size(self):
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict): def _init_image_proj_model(self, state_dict):

View File

@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
""" """
import time import time
from enum import Enum from enum import Enum
from typing import Literal, Optional, Type, Union, Class from typing import Literal, Optional, Type, Union
import torch import torch
from diffusers import ModelMixin from diffusers import ModelMixin
@ -335,7 +335,7 @@ class ModelConfigFactory(object):
cls, cls,
model_data: Union[Dict[str, Any], AnyModelConfig], model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None, key: Optional[str] = None,
dest_class: Optional[Type[Class]] = None, dest_class: Optional[Type[ModelConfigBase]] = None,
timestamp: Optional[float] = None, timestamp: Optional[float] = None,
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
@ -347,14 +347,17 @@ class ModelConfigFactory(object):
:param dest_class: The config class to be returned. If not provided, will :param dest_class: The config class to be returned. If not provided, will
be selected automatically. be selected automatically.
""" """
model: Optional[ModelConfigBase] = None
if isinstance(model_data, ModelConfigBase): if isinstance(model_data, ModelConfigBase):
model = model_data model = model_data
elif dest_class: elif dest_class:
model = dest_class.validate_python(model_data) model = dest_class.model_validate(model_data)
else: else:
model = AnyModelConfigValidator.validate_python(model_data) # mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key: if key:
model.key = key model.key = key
if timestamp: if timestamp:
model.last_modified = timestamp model.last_modified = timestamp
return model return model # type: ignore

View File

@ -18,8 +18,16 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type from typing import Any, Callable, Dict, Optional, Tuple, Type
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -32,7 +40,7 @@ class LoadedModel:
config: AnyModelConfig config: AnyModelConfig
locker: ModelLockerBase locker: ModelLockerBase
def __enter__(self) -> AnyModel: # I think load_file() always returns a dict def __enter__(self) -> AnyModel:
"""Context entry.""" """Context entry."""
self.locker.lock() self.locker.lock()
return self.model return self.model
@ -171,6 +179,10 @@ class AnyModelLoader:
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
key = cls._to_registry_key(base, type, format) key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass cls._registry[key] = subclass
return subclass return subclass
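Registration itself is unchanged; the new guard simply turns a silent overwrite into a hard error when two loaders claim the same (base, type, format) triple. Sketch (the class name is hypothetical; the decorator arguments mirror the CLIP Vision loader later in this diff):

```
# Sketch: registering a loader for one base/type/format combination.
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
class CLIPVisionDiffusersLoader(ModelLoader):
    """Hypothetical loader; a second registration for this triple now raises."""
    ...
```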

View File

@ -169,7 +169,7 @@ class ModelLoader(ModelLoaderBase):
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
# This needs to be implemented in subclasses that handle checkpoints # This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
raise NotImplementedError raise NotImplementedError
# This needs to be implemented in the subclass # This needs to be implemented in the subclass

View File

@ -246,7 +246,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.""" """Move model into the indicated device."""
# These attributes are not in the base ModelMixin class but in derived classes. # These attributes are not in the base ModelMixin class but in various derived classes.
# Some models don't have these attributes, in which case they run in RAM/CPU. # Some models don't have these attributes, in which case they run in RAM/CPU.
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
View File
@ -35,28 +35,28 @@ class ControlnetLoader(GenericDiffusersLoader):
else: else:
return True return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}") raise Exception(f"Vae conversion not supported for model type: {config.base}")
else: else:
assert hasattr(config, "config") assert hasattr(config, "config")
config_file = config.config config_file = config.config
if weights_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu") checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else: else:
checkpoint = torch.load(weights_path, map_location="cpu") checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not # sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
convert_controlnet_to_diffusers( convert_controlnet_to_diffusers(
weights_path, model_path,
output_path, output_path,
original_config_file=self._app_config.root_path / config_file, original_config_file=self._app_config.root_path / config_file,
image_size=512, image_size=512,
scan_needed=True, scan_needed=True,
from_safetensors=weights_path.suffix == ".safetensors", from_safetensors=model_path.suffix == ".safetensors",
) )
return output_path return output_path
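The safetensors-vs-pickle branch and the `state_dict` unwrapping above are repeated in the VAE converter later in this diff; if it is ever factored out, a shared helper might look roughly like this (hypothetical, not part of this PR):

```python
# Hypothetical helper capturing the checkpoint-reading pattern used by the
# ControlNet and VAE converters in this PR; not part of this change.
from pathlib import Path
from typing import Any, Dict

import safetensors.torch
import torch


def read_checkpoint(model_path: Path) -> Dict[str, Any]:
    """Load a .safetensors or pickled checkpoint and unwrap 'state_dict' if present."""
    if model_path.suffix == ".safetensors":
        checkpoint = safetensors.torch.load_file(model_path, device="cpu")
    else:
        checkpoint = torch.load(model_path, map_location="cpu")
    # sometimes weights are hidden under "state_dict", and sometimes not
    return checkpoint.get("state_dict", checkpoint)
```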
View File
@ -12,8 +12,9 @@ from invokeai.backend.model_manager import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader from ..load_base import AnyModelLoader
from ..load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
View File
@ -65,7 +65,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
else: else:
return True return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig) assert isinstance(config, MainCheckpointConfig)
variant = config.variant variant = config.variant
base = config.base base = config.base
@ -75,9 +75,9 @@ class StableDiffusionDiffusersModel(ModelLoader):
config_file = config.config config_file = config.config
self._logger.info(f"Converting {weights_path} to diffusers format") self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(
weights_path, model_path,
output_path, output_path,
model_type=self.model_base_to_model_type[base], model_type=self.model_base_to_model_type[base],
model_version=base, model_version=base,
@ -86,7 +86,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
extract_ema=True, extract_ema=True,
scan_needed=True, scan_needed=True,
pipeline_class=pipeline_class, pipeline_class=pipeline_class,
from_safetensors=weights_path.suffix == ".safetensors", from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype, precision=self._torch_dtype,
load_safety_checker=False, load_safety_checker=False,
) )
View File
@ -37,7 +37,7 @@ class VaeLoader(GenericDiffusersLoader):
else: else:
return True return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert. # TO DO: check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}") raise Exception(f"Vae conversion not supported for model type: {config.base}")
@ -46,10 +46,10 @@ class VaeLoader(GenericDiffusersLoader):
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
) )
if weights_path.suffix == ".safetensors": if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu") checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else: else:
checkpoint = torch.load(weights_path, map_location="cpu") checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not # sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
View File
@ -65,7 +65,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name} bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files other_files = set(all_files) - fp16_files - bit8_files
if variant is None: if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files files = other_files
elif variant == "fp16": elif variant == "fp16":
files = fp16_files files = fp16_files
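The switch from `is None` to `not variant` matters because the default repo variant is an empty string rather than `None`, as the new comment notes. A small illustration, assuming `ModelRepoVariant` is a `str`-backed enum (the definition below is an assumption, not the actual class):

```python
# Why "if not variant" catches both None and the default variant.
from enum import Enum
from typing import Optional, Union


class ModelRepoVariant(str, Enum):  # assumed shape of the real enum
    DEFAULT = ""   # empty string for compatibility with Hugging Face's "no variant"
    FP16 = "fp16"


def wants_plain_files(variant: Optional[Union[str, ModelRepoVariant]]) -> bool:
    return not variant  # True for None and for ModelRepoVariant.DEFAULT


assert wants_plain_files(None)
assert wants_plain_files(ModelRepoVariant.DEFAULT)
assert not wants_plain_files(ModelRepoVariant.FP16)
```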
View File
@ -22,11 +22,12 @@ Example usage:
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Set, Union from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from logging import Logger
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
default_logger: Logger = InvokeAILogger.get_logger() default_logger: Logger = InvokeAILogger.get_logger()
View File
@ -1 +1,3 @@
from .schedulers import SCHEDULER_MAP # noqa: F401 from .schedulers import SCHEDULER_MAP # noqa: F401
__all__ = ["SCHEDULER_MAP"]
View File
@ -513,7 +513,7 @@ def select_and_download_models(opt: Namespace) -> None:
"""Prompt user for install/delete selections and execute.""" """Prompt user for install/delete selections and execute."""
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
# unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
config.precision = precision # type: ignore config.precision = precision
install_helper = InstallHelper(config, logger) install_helper = InstallHelper(config, logger)
installer = install_helper.installer installer = install_helper.installer
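One possible (untried) way to silence the typing complaint mentioned in the comment above without `# type: ignore` is an explicit `cast` to the enumerated Literal; the `PRECISION` alias and its members below are assumptions, not the actual config type:

```python
# Hypothetical narrowing of the precision string to the Literal the config expects.
from typing import Literal, cast

PRECISION = Literal["auto", "float16", "float32"]  # assumed members

# `opt`, `config`, `torch`, `choose_precision`, and `choose_torch_device` come
# from the surrounding function in this diff.
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = cast(PRECISION, precision)
```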
View File
@ -64,9 +64,7 @@ def mock_services() -> InvocationServices:
latents=None, # type: ignore latents=None, # type: ignore
logger=logging, # type: ignore logger=logging, # type: ignore
model_manager=None, # type: ignore model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore names=None, # type: ignore
performance_statistics=InvocationStatsService(), performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
View File
@ -66,9 +66,7 @@ def mock_services() -> InvocationServices:
latents=None, # type: ignore latents=None, # type: ignore
logger=logging, # type: ignore logger=logging, # type: ignore
model_manager=None, # type: ignore model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore names=None, # type: ignore
performance_statistics=InvocationStatsService(), performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
View File
@ -0,0 +1,22 @@
"""
Test model loading
"""
from pathlib import Path
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager.load import AnyModelLoader
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path):
store = mm2_installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
loaded_model = mm2_loader.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)
View File
@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import ( from tests.backend.model_manager_2.model_metadata.metadata_examples import (
@ -89,6 +90,16 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
return app_config return app_config
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
ram_cache = ModelCache(
logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size
)
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)
@pytest.fixture @pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=mm2_app_config) logger = InvokeAILogger.get_logger(config=mm2_app_config)