make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many, but not all, type checking errors in the invocations. Most
  were unrelated to the model manager.

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md.
Lincoln Stein 2024-02-10 18:09:45 -05:00 committed by psychedelicious
parent 7956602b19
commit a23dedd2ee
36 changed files with 680 additions and 435 deletions

View File

@ -28,7 +28,7 @@ model. These are the:
Hugging Face, as well as discriminating among model versions in
Civitai, but can be used for arbitrary content.
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
* _ModelLoadServiceBase_
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
@ -41,10 +41,10 @@ The four main services can be found in
* `invokeai/app/services/model_records/`
* `invokeai/app/services/model_install/`
* `invokeai/app/services/downloads/`
* `invokeai/app/services/model_loader/` (**under development**)
* `invokeai/app/services/model_load/`
Code related to the FastAPI web API can be found in
`invokeai/app/api/routers/model_records.py`.
`invokeai/app/api/routers/model_manager_v2.py`.
***
@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
imported by, and can be reexported from,
`invokeai.app.services.model_record_service`:
`invokeai.app.services.model_manager.model_records`:
```
from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
```
The `path` field can be absolute or relative. If relative, it is taken
@ -123,7 +123,7 @@ taken to be the `models_dir` directory.
`variant` is an enumerated string class with values `normal`,
`inpaint` and `depth`. If needed, it can be imported from
either `invokeai.app.services.model_record_service` or
either `invokeai.app.services.model_records` or
`invokeai.backend.model_manager.config`.
### ONNXSD2Config
@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or
| `upcast_attention` | bool | Model requires its attention module to be upcast |
The `SchedulerPredictionType` enum can be imported from either
`invokeai.app.services.model_record_service` or
`invokeai.app.services.model_records` or
`invokeai.backend.model_manager.config`.
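For example (a minimal sketch; either of the modules mentioned above should provide the import):

```
from invokeai.app.services.model_records import SchedulerPredictionType
```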
### Other config classes
@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base
models. This works OK for some models, such as the IP Adapter image
encoders, but is an all-or-nothing proposition.
Another issue is that the config class hierarchy is paralleled to some
extent by a `ModelBase` class hierarchy defined in
`invokeai.backend.model_manager.models.base` and its subclasses. These
are classes representing the models after they are loaded into RAM and
include runtime information such as load status and bytes used. Some
of the fields, including `name`, `model_type` and `base_model`, are
shared between `ModelConfigBase` and `ModelBase`, and this is a
potential source of confusion.
## Reading and Writing Model Configuration Records
The `ModelRecordService` provides the ability to retrieve model
@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the
`InvocationContext` object:
```
store = context.services.model_record_store
store = context.services.model_manager.store
```
or from elsewhere in the code by accessing
`ApiDependencies.invoker.services.model_record_store`.
`ApiDependencies.invoker.services.model_manager.store`.
### Creating a `ModelRecordService`
@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a
`ModelRecordServiceFile` object:
```
from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile
store = ModelRecordServiceSQL.from_connection(connection, lock)
store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db')
@ -252,7 +243,7 @@ So a typical startup pattern would be:
```
import sqlite3
from invokeai.app.services.thread import lock
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False)
store = ModelRecordServiceBase.open(config, db_conn, lock)
```
_A note on simultaneous access to `invokeai.db`_: The current InvokeAI
service architecture for the image and graph databases is careful to
use a shared sqlite3 connection and a thread lock to ensure that two
threads don't attempt to access the database simultaneously. However,
the default `sqlite3` library used by Python reports using
**Serialized** mode, which allows multiple threads to access the
database simultaneously using multiple database connections (see
https://www.sqlite.org/threadsafe.html and
https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore
it should be safe to allow the record service to open its own SQLite
database connection. Opening a model record service should then be as
simple as `ModelRecordServiceBase.open(config)`.
### Fetching a Model's Configuration from `ModelRecordServiceBase`
Configurations can be retrieved in several ways.
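As a minimal sketch (the key is the placeholder used elsewhere in this document), a key-based lookup looks like this:

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase

config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)

# Fetch a single configuration record by its unique key.
model_config = store.get_model('f13dd932c0c35c22dcb8d6cda4203764')
print(model_config.name, model_config.base)
```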
@ -1465,7 +1443,7 @@ create alternative instances if you wish.
### Creating a ModelLoadService object
The class is defined in
`invokeai.app.services.model_loader_service`. It is initialized with
`invokeai.app.services.model_load`. It is initialized with
an InvokeAIAppConfig object, from which it gets configuration
information such as the user's desired GPU and precision, and with a
previously-created `ModelRecordServiceBase` object, from which it
@ -1475,8 +1453,8 @@ Here is a typical initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_loader_service import ModelLoadService
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.model_load import ModelLoadService
config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)
@ -1487,14 +1465,11 @@ Note that we are relying on the contents of the application
configuration to choose the implementation of
`ModelRecordServiceBase`.
### get_model(key, [submodel_type], [context]) -> ModelInfo:
### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel
*** TO DO: change to get_model(key, context=None, **kwargs)
The `get_model()` method, like its similarly-named cousin in
`ModelRecordService`, receives the unique key that identifies the
model. It loads the model into memory, gets the model ready for use,
and returns a `ModelInfo` object.
The `load_model_by_key()` method receives the unique key that
identifies the model. It loads the model into memory, gets the model
ready for use, and returns a `LoadedModel` object.
The optional second argument, `subtype`, is a `SubModelType` string
enum, such as "vae". It is mandatory when used with a main model, and
@ -1504,46 +1479,64 @@ The optional third argument, `context` can be provided by
an invocation to trigger model load event reporting. See below for
details.
The returned `ModelInfo` object shares some fields in common with
`ModelConfigBase`, but is otherwise a completely different beast:
The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as
the in-memory loaded model:
| **Field Name** | **Type** | **Description** |
| **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `key` | str | The model key derived from the ModelRecordService database |
| `name` | str | Name of this model |
| `base_model` | BaseModelType | Base model for this model |
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
| `location` | Path or str | Location of the model on the filesystem |
| `precision` | torch.dtype | The torch.precision to use for inference |
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
The types for `ModelInfo` and `SubModelType` can be imported from
`invokeai.app.services.model_loader_service`.
Because the loader can return multiple model types, it is typed to
return `AnyModel`, a Union of `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, and `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are self-explanatory.
To use the model, you use the `ModelInfo` as a context manager using
the following pattern:
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns
the model. Use it like this:
```
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
model_info = loader.load_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
    image = vae.decode(latents)[0]
```
The `vae` model will stay locked in the GPU for as long as it remains
within the context manager's scope.
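Because the loader is typed to return `AnyModel`, code that needs a narrower interface can narrow the type inside the context. A minimal sketch:

```
import torch

with model_info as vae:
    assert isinstance(vae, torch.nn.Module)
    image = vae.decode(latents)[0]
```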
`load_model_by_key()` may raise any of the following exceptions:
`get_model()` may raise any of the following exceptions:
- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `InvalidModelException` -- the model is guilty of a variety of sins
- `UnknownModelException` -- key not in database
- `ModelNotFoundException` -- key in database but model not found at path
- `NotImplementedException` -- the loader doesn't know how to load this type of model
** TO DO: ** Resolve discrepancy between ModelInfo.location and
ModelConfig.path.
### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
This is similar to `load_model_by_key`, but instead it accepts the
combination of the model's name, type and base, which it passes to the
model record config store for retrieval. If successful, this method
returns a `LoadedModel`. It can raise the following exceptions:
```
UnknownModelException -- model with these attributes not known
NotImplementedException -- the loader doesn't know how to load this type of model
ValueError -- more than one model matches this combination of base/type/name
```
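As a minimal sketch (the model name here is hypothetical; substitute one from your own model records):

```
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType

loaded_model = loader.load_model_by_attr(
    model_name='stable-diffusion-v1-5',  # hypothetical name
    base_model=BaseModelType.StableDiffusion1,
    model_type=ModelType.Main,
    submodel=SubModelType('vae'),
)
with loaded_model as vae:
    image = vae.decode(latents)[0]
```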
### load_model_by_config(config, [submodel], [context]) -> LoadedModel
This method takes an `AnyModelConfig` returned by
ModelRecordService.get_model() and returns the corresponding loaded
model. It may raise a `NotImplementedException`.
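A minimal sketch combining the record store and the loader (the key is the placeholder used above):

```
model_config = store.get_model('f13dd932c0c35c22dcb8d6cda4203764')
loaded_model = loader.load_model_by_config(model_config, SubModelType('vae'))
with loaded_model as vae:
    image = vae.decode(latents)[0]
```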
### Emitting model loading events
When the `context` argument is passed to `get_model()`, it will
When the `context` argument is passed to `load_model_*()`, it will
retrieve the invocation event bus from the passed `InvocationContext`
object to emit events on the invocation bus. The two events are
"model_load_started" and "model_load_completed". Both carry the
@ -1563,3 +1556,97 @@ payload=dict(
)
```
### Adding Model Loaders
Model loaders are small classes that inherit from the `ModelLoader`
base class. They typically implement one method, `_load_model()`, whose
signature is:
```
def _load_model(
    self,
    model_path: Path,
    model_variant: Optional[ModelRepoVariant] = None,
    submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
```
`_load_model()` will be passed the path to the model on disk, an
optional repository variant (used by the diffusers loaders to select,
e.g., the `fp16` variant), and an optional `submodel_type` for main and
ONNX models.
To install a new loader, place it in
`invokeai/backend/model_manager/load/model_loaders`. Inherit from
`ModelLoader` and use the `@AnyModelLoader.register()` decorator to
indicate what type of models the loader can handle.
Here is a complete example from `generic_diffusers.py`, which is able
to load several different diffusers types:
```
from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager import (
    AnyModel,
    BaseModelType,
    ModelFormat,
    ModelRepoVariant,
    ModelType,
    SubModelType,
)

from ..load_base import AnyModelLoader
from ..load_default import ModelLoader


@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
    """Class to load simple diffusers models."""

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        model_class = self._get_hf_load_class(model_path)
        if submodel_type is not None:
            raise Exception(f"There are no submodels in models of type {model_class}")
        variant = model_variant.value if model_variant else None
        result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)  # type: ignore
        return result
```
Note that a loader can register itself to handle several different
model types. An exception will be raised if more than one loader tries
to register the same model type.
#### Conversion
Some models require conversion to diffusers format before they can be
loaded. These loaders should override two additional methods:
```
_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool
_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
```
The first method accepts the model configuration, the path to where
the unmodified model is currently installed, and a proposed
destination for the converted model. This method returns True if the
model needs to be converted. It typically does this by comparing the
last modification time of the original model file to the modification
time of the converted model. In some cases you will also want to check
the modification date of the configuration record, in the event that
the user has changed something like the scheduler prediction type that
will require the model to be re-converted. See `controlnet.py` for an
example of this logic.
The second method accepts the model configuration, the path to the
original model on disk, and the desired output path for the converted
model. It does whatever it needs to do to get the model into diffusers
format, and returns the Path of the resulting model. (The path should
ordinarily be the same as `output_path`.)
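For illustration, here is a minimal sketch of the modification-time comparison described above (the real loaders, e.g. `controlnet.py`, add further checks such as the config record's last-modified time):

```
from pathlib import Path

from invokeai.backend.model_manager import AnyModelConfig


def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
    # Convert if no converted copy exists yet, or if the original
    # checkpoint is newer than the converted copy.
    if not dest_path.exists():
        return True
    return model_path.stat().st_mtime > dest_path.stat().st_mtime
```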

View File

@ -8,9 +8,6 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMe
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -30,9 +27,7 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -98,28 +93,10 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
model_loader = AnyModelLoader(
app_config=config,
logger=logger,
ram_cache=ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
),
convert_cache=ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
),
)
model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
download_queue_service = DownloadQueueService(event_bus=events)
model_install_service = ModelInstallService(
app_config=config,
record_store=model_record_service,
download_queue=download_queue_service,
metadata_store=ModelMetadataStore(db=db),
event_bus=events,
model_manager = ModelManagerService.build_model_manager(
app_config=configuration, db=db, download_queue=download_queue_service, events=events
)
model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. Remove
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -143,9 +120,7 @@ class ApiDependencies:
invocation_cache=invocation_cache,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
download_queue=download_queue_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
class ModelsList(BaseModel):
@ -52,7 +52,7 @@ class ModelTagSet(BaseModel):
tags: Set[str]
@model_records_router.get(
@model_manager_v2_router.get(
"/",
operation_id="list_model_records",
)
@ -65,7 +65,7 @@ async def list_model_records(
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
@ -81,7 +81,7 @@ async def list_model_records(
return ModelsList(models=found_models)
@model_records_router.get(
@model_manager_v2_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
@ -94,24 +94,27 @@ async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
try:
return record_store.get_model(key)
config: AnyModelConfig = record_store.get_model(key)
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.get("/meta", operation_id="list_model_summary")
@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
return results
@model_records_router.get(
@model_manager_v2_router.get(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
@ -124,24 +127,25 @@ async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_records
result = record_store.get_metadata(key)
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
if not result:
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_records
return record_store.list_tags()
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
@ -149,12 +153,12 @@ async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
@ -172,9 +176,9 @@ async def update_model_record(
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response = record_store.update_model(key, config=info)
model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -184,7 +188,7 @@ async def update_model_record(
return model_response
@model_records_router.delete(
@model_manager_v2_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
@ -205,7 +209,7 @@ async def del_model_record(
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
@ -214,7 +218,7 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
@model_manager_v2_router.post(
"/i/",
operation_id="add_model_record",
responses={
@ -229,7 +233,7 @@ async def add_model_record(
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
@ -243,10 +247,11 @@ async def add_model_record(
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)
result: AnyModelConfig = record_store.get_model(config.key)
return result
@model_records_router.post(
@model_manager_v2_router.post(
"/import",
operation_id="import_model_record",
responses={
@ -322,7 +327,7 @@ async def import_model(
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
@ -340,17 +345,17 @@ async def import_model(
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs."""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
return jobs
@model_records_router.get(
@model_manager_v2_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source."""
try:
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
return result
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.delete(
@model_manager_v2_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
)
async def prune_model_install_jobs() -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_install.prune_jobs()
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response:
Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
return Response(status_code=204)
@model_records_router.put(
@model_manager_v2_router.put(
"/merge",
operation_id="merge",
)
@ -451,7 +457,7 @@ async def merge(
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(

View File

@ -8,8 +8,7 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
InvalidModelException,

View File

@ -48,7 +48,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
boards,
download_queue,
images,
model_records,
model_manager_v2,
models,
session_queue,
sessions,
@ -114,8 +114,7 @@ async def shutdown_event() -> None:
app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")

View File

@ -3,6 +3,7 @@ from typing import Iterator, List, Optional, Tuple, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTokenizer
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.fields import (
@ -68,18 +69,18 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_records.load_model(
tokenizer_info = context.services.model_manager.load.load_model_by_key(
**self.clip.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_records.load_model(
text_encoder_info = context.services.model_manager.load.load_model_by_key(
**self.clip.text_encoder.model_dump(),
context=context,
)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.services.model_records.load_model(
lora_info = context.services.model_manager.load.load_model_by_key(
**lora.model_dump(exclude={"weight"}), context=context
)
assert isinstance(lora_info.model, LoRAModelRaw)
@ -93,7 +94,7 @@ class CompelInvocation(BaseInvocation):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
loaded_model = context.services.model_records.load_model(
loaded_model = context.services.model_manager.load.load_model_by_key(
**self.clip.text_encoder.model_dump(),
context=context,
).model
@ -164,11 +165,11 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.services.model_records.load_model(
tokenizer_info = context.services.model_manager.load.load_model_by_key(
**clip_field.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_records.load_model(
text_encoder_info = context.services.model_manager.load.load_model_by_key(
**clip_field.text_encoder.model_dump(),
context=context,
)
@ -196,7 +197,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.services.model_records.load_model(
lora_info = context.services.model_manager.load.load_model_by_key(
**lora.model_dump(exclude={"weight"}), context=context
)
lora_model = lora_info.model
@ -211,7 +212,7 @@ class SDXLPromptInvocationBase:
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_model = context.services.model_records.load_model_by_attr(
ti_model = context.services.model_manager.load.load_model_by_attr(
model_name=name,
base_model=text_encoder_info.config.base,
model_type=ModelType.TextualInversion,
@ -448,9 +449,9 @@ class ClipSkipInvocation(BaseInvocation):
def get_max_token_count(
tokenizer,
tokenizer: CLIPTokenizer,
prompt: Union[FlattenedPrompt, Blend, Conjunction],
truncate_if_too_long=False,
truncate_if_too_long: bool = False,
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
@ -462,7 +463,9 @@ def get_max_token_count(
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
def get_tokens_for_prompt_object(
tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
) -> List[str]:
if type(parsed_prompt) is Blend:
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@ -475,24 +478,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
for x in parsed_prompt.children
]
text = " ".join(text_fragments)
tokens = tokenizer.tokenize(text)
tokens: List[str] = tokenizer.tokenize(text)
if truncate_if_too_long:
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length]
return tokens
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
assert display_label_prefix is not None
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
display_label_prefix = display_label_prefix or ""
if type(p) is Blend:
blend: Blend = p
@ -532,7 +540,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
def log_tokenization_for_text(
text: str,
tokenizer: CLIPTokenizer,
display_label: Optional[str] = None,
truncate_if_too_long: Optional[bool] = False,
) -> None:
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '

View File

@ -3,13 +3,15 @@
import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import Iterator, List, Literal, Optional, Tuple, Union
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
import einops
import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter
from diffusers.models.attention_processor import (
@ -18,8 +20,10 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
@ -46,9 +50,10 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.embeddings.model_patcher import ModelPatcher
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_manager import AnyModel, BaseModelType
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -123,10 +128,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
ui_order=4,
)
def prep_mask_tensor(self, mask_image):
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
@ -136,25 +141,25 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.images.get_pil(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image.dim() == 3:
image = image.unsqueeze(0)
image = context.services.images.get_pil_image(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
else:
image = None
image_tensor = None
mask = self.prep_mask_tensor(
context.images.get_pil(self.mask.image_name),
)
if image is not None:
vae_info = context.services.model_records.load_model(
if image_tensor is not None:
vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(),
context=context,
)
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
@ -177,7 +182,7 @@ def get_scheduler(
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.services.model_records.load_model(
orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
**scheduler_info.model_dump(),
context=context,
)
@ -188,7 +193,7 @@ def get_scheduler(
scheduler_config = scheduler_config["_backup"]
scheduler_config = {
**scheduler_config,
**scheduler_extra_config,
**scheduler_extra_config, # FIXME
"_backup": scheduler_config,
}
@ -201,6 +206,7 @@ def get_scheduler(
# hack copied over from generate.py
if not hasattr(scheduler, "uses_inpainting_model"):
scheduler.uses_inpainting_model = lambda: False
assert isinstance(scheduler, Scheduler)
return scheduler
@ -284,7 +290,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
@field_validator("cfg_scale")
def ge_one(cls, v):
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
@ -298,9 +304,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
def get_conditioning_data(
self,
context: InvocationContext,
scheduler,
unet,
seed,
scheduler: Scheduler,
unet: UNet2DConditionModel,
seed: int,
) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -323,7 +329,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
),
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
scheduler,
# for ddim scheduler
eta=0.0, # ddim_eta
@ -335,8 +341,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
def create_pipeline(
self,
unet,
scheduler,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
@ -347,10 +353,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
class FakeVae:
class FakeVaeConfig:
def __init__(self):
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self):
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
@ -367,11 +373,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: InvocationContext,
control_input: Union[ControlField, List[ControlField]],
control_input: Optional[Union[ControlField, List[ControlField]]],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
@ -394,7 +400,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_records.load_model(
context.services.model_manager.load.load_model_by_key(
key=control_info.control_model.key,
context=context,
)
@ -460,23 +466,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_records.load_model(
context.services.model_manager.load.load_model_by_key(
key=single_ip_adapter.ip_adapter_model.key,
context=context,
)
)
image_encoder_model_info = context.services.model_records.load_model(
image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
key=single_ip_adapter.image_encoder_model.key,
context=context,
)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_images = single_ip_adapter.image
if not isinstance(single_ipa_images, list):
single_ipa_images = [single_ipa_images]
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
single_ipa_images = [
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
]
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@ -520,21 +528,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_records.load_model(
t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
key=t2i_adapter_field.t2i_adapter_model.key,
context=context,
)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
)
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.")
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
@ -582,7 +588,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device)
@ -594,11 +608,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
_timesteps = timesteps[:: scheduler.order]
# get start timestep index
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
# get end timestep index
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
# apply order to indexes
@ -611,7 +625,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context: InvocationContext, latents):
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.denoise_mask is None:
return None, None
@ -660,12 +676,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
)
def step_callback(state: PipelineIntermediateState):
context.util.sd_step_callback(state, self.unet.unet.base_model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def _lora_loader() -> Iterator[Tuple[AnyModel, float]]:
# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
def step_callback(state: PipelineIntermediateState) -> None:
self.dispatch_progress(context, source_node_id, state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.services.model_records.load_model(
lora_info = context.services.model_manager.load.load_model_by_key(
**lora.model_dump(exclude={"weight"}),
context=context,
)
@ -673,7 +696,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
del lora_info
return
unet_info = context.services.model_records.load_model(
unet_info = context.services.model_manager.load.load_model_by_key(
**self.unet.unet.model_dump(),
context=context,
)
@ -783,7 +806,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.services.model_records.load_model(
vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(),
context=context,
)
@ -961,8 +984,9 @@ class ImageToLatentsInvocation(BaseInvocation):
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info, upcast, tiled, image_tensor):
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@ -1008,7 +1032,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.services.model_records.load_model(
vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(),
context=context,
)
@ -1026,14 +1050,19 @@ class ImageToLatentsInvocation(BaseInvocation):
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents
@invocation(
@ -1066,7 +1095,12 @@ class BlendLatentsInvocation(BaseInvocation):
# TODO:
device = choose_torch_device()
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
"""
Spherical linear interpolation
Args:
@ -1099,12 +1133,16 @@ class BlendLatentsInvocation(BaseInvocation):
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(device)
return v2
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2
# blend
blended_latents = slerp(self.alpha, latents_a, latents_b)
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
@ -1197,15 +1235,19 @@ class IdealSizeInvocation(BaseInvocation):
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.services.model_manager.load.load_model_by_key(
**self.unet.unet.model_dump(),
context=context,
)
aspect = self.width / self.height
dimension = 512
if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)

View File

@ -17,7 +17,7 @@ from .baseinvocation import (
class ModelInfo(BaseModel):
key: str = Field(description="Info to load submodel")
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")

View File

@ -27,9 +27,7 @@ if TYPE_CHECKING:
from .invocation_queue.invocation_queue_base import InvocationQueueABC
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@ -55,9 +53,7 @@ class InvocationServices:
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -82,9 +78,7 @@ class InvocationServices:
self.image_records = image_records
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -43,8 +43,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
# This is to handle the case of the model manager not being initialized, which happens
# during some tests.
services = self._invoker.services
if services.model_records is None or services.model_records.loader is None:
if services.model_manager is None or services.model_manager.load is None:
yield None
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
@ -60,9 +62,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# TO DO [LS]: clean up loader service - shouldn't be an attribute of model records
assert services.model_records.loader is not None
services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id]
assert services.model_manager.load is not None
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
try:
# Let the invocation run.

View File

@ -4,7 +4,8 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline) models, the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass
@abstractmethod
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline) models, the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.
:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""

View File

@ -3,12 +3,14 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
@ -21,7 +23,7 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
ram_cache: Optional[ModelCacheBase] = None,
ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
convert_cache: Optional[ModelConvertCacheBase] = None,
):
"""Initialize the model load service."""
@ -44,11 +46,104 @@ class ModelLoadService(ModelLoadServiceBase):
),
)
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
return self._any_loader.load_model(config, submodel_type)
:param key: Key of model config to be fetched.
:param submodel_type: For main (pipeline) models, the submodel to fetch.
:param context: Invocation context used for event reporting
"""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type, context)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.
:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self._store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model_by_key(configs[0].key, submodel)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel_type: For main (pipeline) models, the submodel to fetch.
:param context: Invocation context used for event reporting
"""
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._any_loader.load_model(model_config, submodel_type)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model
def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
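
A hedged sketch of the attribute-based lookup path and its two failure modes; the model name is illustrative:

```
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, ModelType

# Sketch only: `load_service` is a ModelLoadService instance.
try:
    loaded = load_service.load_model_by_attr(
        model_name="stable-diffusion-v1-5",   # illustrative name
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
    )
except UnknownModelException:
    ...  # no record matched this name/base/type combination
except ValueError:
    ...  # more than one record matched; fall back to a key-based lookup
```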

View File

@ -2,9 +2,10 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
@ -14,12 +15,13 @@ from ..model_records import ModelRecordServiceBase
from ..shared.sqlite.sqlite_database import SqliteDatabase
class ModelManagerServiceBase(BaseModel, ABC):
class ModelManagerServiceBase(ABC):
"""Abstract base class for the model manager service."""
store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
# attributes:
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
@classmethod
@abstractmethod
@ -37,3 +39,29 @@ class ModelManagerServiceBase(BaseModel, ABC):
method simplifies the construction considerably.
"""
pass
@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
pass
@property
@abstractmethod
def load(self) -> ModelLoadServiceBase:
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
pass
@property
@abstractmethod
def install(self) -> ModelInstallServiceBase:
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
pass
@abstractmethod
def start(self, invoker: Invoker) -> None:
pass
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass
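
A sketch of how application code is expected to reach the three sub-services through this facade; the model key and file path are placeholders:

```
# Sketch only: `services.model_manager` is assumed to be a ModelManagerServiceBase.
mm = services.model_manager
config = mm.store.get_model("0ac3...")          # fetch a configuration record by key
loaded = mm.load.load_model_by_config(config)   # load it into RAM/VRAM
key = mm.install.register_path("/path/to/model.safetensors")  # register an in-place model
```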

View File

@ -3,6 +3,7 @@
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
@ -10,9 +11,9 @@ from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService
from ..model_load import ModelLoadService
from ..model_records import ModelRecordServiceSQL
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_manager_base import ModelManagerServiceBase
@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase):
model_manager.load -- Routines to load models into memory.
"""
def __init__(
self,
store: ModelRecordServiceBase,
install: ModelInstallServiceBase,
load: ModelLoadServiceBase,
):
self._store = store
self._install = install
self._load = load
@property
def store(self) -> ModelRecordServiceBase:
return self._store
@property
def install(self) -> ModelInstallServiceBase:
return self._install
@property
def load(self) -> ModelLoadServiceBase:
return self._load
def start(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "start"):
service.start(invoker)
def stop(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "stop"):
service.stop(invoker)
@classmethod
def build_model_manager(
cls,

View File

@ -10,15 +10,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
LoadedModel,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -111,52 +108,6 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def load_model(
self,
key: str,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch
:param context: Invocation context, used for event issuing.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass
@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:

View File

@ -46,8 +46,6 @@ from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@ -55,9 +53,8 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
from ..shared.sqlite.sqlite_database import SqliteDatabase
@ -220,74 +217,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def load_model(
self,
key: str,
submodel: Optional[SubModelType],
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
if not self._loader:
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
# we can emit model loading events if we are executing with access to the invocation context
model_config = self.get_model(key)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._loader.load_model(model_config, submodel)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Key of model config to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model(configs[0].key, submodel)
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the database.
@ -476,29 +405,3 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)
def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)

View File

@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import
class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor)
self._delete_ip_adapters(cursor)
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
"""
@ -26,6 +27,22 @@ class Migration6Callback:
"""
)
def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
"""
Delete all the IP adapters.
The model manager will automatically find and re-add them after the migration
is done. This allows the manager to add the correct image encoder to their
configuration records.
"""
cursor.execute(
"""--sql
DELETE FROM model_config
WHERE type='ip_adapter';
"""
)
def build_migration_6() -> Migration:
"""
@ -33,6 +50,8 @@ def build_migration_6() -> Migration:
This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist
- Deletes all ip_adapter models so that the model prober can find and
  update them with the correct image encoder model.
"""
migration_6 = Migration(
from_version=5,

View File

@ -64,7 +64,7 @@ class ModelPatcher:
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModelRaw, float]],
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@ -307,7 +307,7 @@ class ONNXModelPatcher:
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
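
A hedged sketch of the patched signature in use; any iterator of (LoRAModelRaw, weight) pairs can now be passed where a list was previously expected:

```
# Sketch only: `unet` is a UNet2DConditionModel and `lora_models` an iterable of
# (LoRAModelRaw, float) pairs supplied by the caller.
with ModelPatcher.apply_lora_unet(unet, iter(lora_models)):
    ...  # run denoising while the LoRA weights are patched into the UNet
```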

View File

@ -8,8 +8,8 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import SilenceWarnings
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.silence_warnings import SilenceWarnings
config = InvokeAIAppConfig.get_config()

View File

@ -8,7 +8,6 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from .resampler import Resampler
@ -124,6 +123,9 @@ class IPAdapter:
self.attn_weights.to(device=self.device, dtype=self.dtype)
def calc_size(self):
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict):

View File

@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union, Class
from typing import Literal, Optional, Type, Union
import torch
from diffusers import ModelMixin
@ -335,7 +335,7 @@ class ModelConfigFactory(object):
cls,
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type[Class]] = None,
dest_class: Optional[Type[ModelConfigBase]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""
@ -347,14 +347,17 @@ class ModelConfigFactory(object):
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
model: Optional[ModelConfigBase] = None
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.validate_python(model_data)
model = dest_class.model_validate(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
return model
return model # type: ignore
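
A hedged sketch of how the factory is typically used; `row`, `key` and `row_timestamp` stand for a raw configuration dict and values read back from the database, and the explicit `dest_class` form is optional:

```
# Sketch only: variable names are illustrative.
config = ModelConfigFactory.make_config(row, key=key, timestamp=row_timestamp)

# When the caller already knows the concrete config class, it can be forced:
vae_config = ModelConfigFactory.make_config(row, dest_class=VaeDiffusersConfig)
```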

View File

@ -18,8 +18,16 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.logging import InvokeAILogger
@ -32,7 +40,7 @@ class LoadedModel:
config: AnyModelConfig
locker: ModelLockerBase
def __enter__(self) -> AnyModel: # I think load_file() always returns a dict
def __enter__(self) -> AnyModel:
"""Context entry."""
self.locker.lock()
return self.model
@ -171,6 +179,10 @@ class AnyModelLoader:
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
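
A sketch of the registration decorator this guard protects; the base/type/format combination mirrors the CLIP Vision loader further down in this diff, and the class name is illustrative:

```
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
class ClipVisionLoader(ModelLoader):
    """Sketch only; registering a second loader for the same base/type/format now raises."""
```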

View File

@ -169,7 +169,7 @@ class ModelLoader(ModelLoaderBase):
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
raise NotImplementedError
# This needs to be implemented in the subclass

View File

@ -246,7 +246,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
# These attributes are not in the base ModelMixin class but in derived classes.
# These attributes are not in the base ModelMixin class but in various derived classes.
# Some models don't have these attributes, in which case they run in RAM/CPU.
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):

View File

@ -35,28 +35,28 @@ class ControlnetLoader(GenericDiffusersLoader):
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
assert hasattr(config, "config")
config_file = config.config
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
convert_controlnet_to_diffusers(
weights_path,
model_path,
output_path,
original_config_file=self._app_config.root_path / config_file,
image_size=512,
scan_needed=True,
from_safetensors=weights_path.suffix == ".safetensors",
from_safetensors=model_path.suffix == ".safetensors",
)
return output_path

View File

@ -12,8 +12,9 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from ..load_base import AnyModelLoader
from ..load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)

View File

@ -65,7 +65,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig)
variant = config.variant
base = config.base
@ -75,9 +75,9 @@ class StableDiffusionDiffusersModel(ModelLoader):
config_file = config.config
self._logger.info(f"Converting {weights_path} to diffusers format")
self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers(
weights_path,
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
model_version=base,
@ -86,7 +86,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights_path.suffix == ".safetensors",
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
load_safety_checker=False,
)

View File

@ -37,7 +37,7 @@ class VaeLoader(GenericDiffusersLoader):
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
@ -46,10 +46,10 @@ class VaeLoader(GenericDiffusersLoader):
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:

View File

@ -65,7 +65,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files
elif variant == "fp16":
files = fp16_files

View File

@ -22,11 +22,12 @@ Example usage:
import os
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from logging import Logger
from invokeai.backend.util.logging import InvokeAILogger
default_logger: Logger = InvokeAILogger.get_logger()

View File

@ -1 +1,3 @@
from .schedulers import SCHEDULER_MAP # noqa: F401
__all__ = ["SCHEDULER_MAP"]

View File

@ -513,7 +513,7 @@ def select_and_download_models(opt: Namespace) -> None:
"""Prompt user for install/delete selections and execute."""
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
# unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
config.precision = precision # type: ignore
config.precision = precision
install_helper = InstallHelper(config, logger)
installer = install_helper.installer

View File

@ -62,9 +62,7 @@ def mock_services() -> InvocationServices:
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),

View File

@ -65,9 +65,7 @@ def mock_services() -> InvocationServices:
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),

View File

@ -0,0 +1,22 @@
"""
Test model loading
"""
from pathlib import Path
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager.load import AnyModelLoader
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path):
store = mm2_installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
loaded_model = mm2_loader.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)

View File

@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
@ -89,6 +90,16 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
return app_config
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
ram_cache = ModelCache(
logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size
)
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=mm2_app_config)