make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.
- Update invocations to use new load interface.
- Fixed many but not all type checking errors in the invocations. Most were unrelated to model manager.
- Updated routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered.
- Added a pytest for the loader.
- Updated documentation in contributing/MODEL_MANAGER.md

This commit is contained in:
parent 7956602b19
commit a23dedd2ee
@@ -28,7 +28,7 @@ model. These are the:
   Hugging Face, as well as discriminating among model versions in
   Civitai, but can be used for arbitrary content.

-* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
+* _ModelLoadServiceBase_
   Responsible for loading a model from disk
   into RAM and VRAM and getting it ready for inference.

@@ -41,10 +41,10 @@ The four main services can be found in
 * `invokeai/app/services/model_records/`
 * `invokeai/app/services/model_install/`
 * `invokeai/app/services/downloads/`
-* `invokeai/app/services/model_loader/` (**under development**)
+* `invokeai/app/services/model_load/`

 Code related to the FastAPI web API can be found in
-`invokeai/app/api/routers/model_records.py`.
+`invokeai/app/api/routers/model_manager_v2.py`.

 ***

@@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but
 `ModelType`, `ModelFormat` and `BaseModelType` are string enums that
 are defined in `invokeai.backend.model_manager.config`. They are also
 imported by, and can be reexported from,
-`invokeai.app.services.model_record_service`:
+`invokeai.app.services.model_manager.model_records`:

 ```
-from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType
+from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
 ```

 The `path` field can be absolute or relative. If relative, it is taken
@@ -123,7 +123,7 @@ taken to be the `models_dir` directory.

 `variant` is an enumerated string class with values `normal`,
 `inpaint` and `depth`. If needed, it can be imported from
-either `invokeai.app.services.model_record_service` or
+either `invokeai.app.services.model_records` or
 `invokeai.backend.model_manager.config`.

 ### ONNXSD2Config
@@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or
 | `upcast_attention` | bool | Model requires its attention module to be upcast |

 The `SchedulerPredictionType` enum can be imported from either
-`invokeai.app.services.model_record_service` or
+`invokeai.app.services.model_records` or
 `invokeai.backend.model_manager.config`.

 ### Other config classes
@@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base
 models. This works OK for some models, such as the IP Adapter image
 encoders, but is an all-or-nothing proposition.

-Another issue is that the config class hierarchy is paralleled to some
-extent by a `ModelBase` class hierarchy defined in
-`invokeai.backend.model_manager.models.base` and its subclasses. These
-are classes representing the models after they are loaded into RAM and
-include runtime information such as load status and bytes used. Some
-of the fields, including `name`, `model_type` and `base_model`, are
-shared between `ModelConfigBase` and `ModelBase`, and this is a
-potential source of confusion.
-
 ## Reading and Writing Model Configuration Records

 The `ModelRecordService` provides the ability to retrieve model
@@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the
 `InvocationContext` object:

 ```
-store = context.services.model_record_store
+store = context.services.model_manager.store
 ```

 or from elsewhere in the code by accessing
-`ApiDependencies.invoker.services.model_record_store`.
+`ApiDependencies.invoker.services.model_manager.store`.
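An illustrative sketch of this pattern inside an invocation, assuming `get_model()` raises `UnknownModelException` for a missing key, as the router code later in this commit does (the key shown is hypothetical):

```
from invokeai.app.services.model_records import UnknownModelException

def invoke(self, context):
    store = context.services.model_manager.store
    try:
        config = store.get_model("f13dd932c0c35c22dcb8d6cda4203764")  # hypothetical key
    except UnknownModelException:
        config = None
```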

 ### Creating a `ModelRecordService`

@@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a
 `ModelRecordServiceFile` object:

 ```
-from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile
+from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile

 store = ModelRecordServiceSQL.from_connection(connection, lock)
 store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db')
@@ -252,7 +243,7 @@ So a typical startup pattern would be:
 ```
 import sqlite3
 from invokeai.app.services.thread import lock
-from invokeai.app.services.model_record_service import ModelRecordServiceBase
+from invokeai.app.services.model_records import ModelRecordServiceBase
 from invokeai.app.services.config import InvokeAIAppConfig

 config = InvokeAIAppConfig.get_config()
@@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False)
 store = ModelRecordServiceBase.open(config, db_conn, lock)
 ```

-_A note on simultaneous access to `invokeai.db`_: The current InvokeAI
-service architecture for the image and graph databases is careful to
-use a shared sqlite3 connection and a thread lock to ensure that two
-threads don't attempt to access the database simultaneously. However,
-the default `sqlite3` library used by Python reports using
-**Serialized** mode, which allows multiple threads to access the
-database simultaneously using multiple database connections (see
-https://www.sqlite.org/threadsafe.html and
-https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore
-it should be safe to allow the record service to open its own SQLite
-database connection. Opening a model record service should then be as
-simple as `ModelRecordServiceBase.open(config)`.
-
 ### Fetching a Model's Configuration from `ModelRecordServiceBase`

 Configurations can be retrieved in several ways.
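The hunk cuts off the body of this section; for orientation, a sketch of the main retrieval calls as they are used by the routers later in this commit (key and tag values are illustrative):

```
store = ApiDependencies.invoker.services.model_manager.store

config = store.get_model("f13dd932c0c35c22dcb8d6cda4203764")  # fetch by unique key
summaries = store.list_models(page=0, per_page=10, order_by=ModelRecordOrderBy.Default)  # paginated summaries
tags = store.list_tags()  # unique set of all metadata tags
matches = store.search_by_metadata_tag({"sd-1", "lora"})  # models whose metadata carries these tags
```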
@@ -1465,7 +1443,7 @@ create alternative instances if you wish.
 ### Creating a ModelLoadService object

 The class is defined in
-`invokeai.app.services.model_loader_service`. It is initialized with
+`invokeai.app.services.model_load`. It is initialized with
 an InvokeAIAppConfig object, from which it gets configuration
 information such as the user's desired GPU and precision, and with a
 previously-created `ModelRecordServiceBase` object, from which it
@@ -1475,8 +1453,8 @@ Here is a typical initialization pattern:

 ```
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.model_record_service import ModelRecordServiceBase
-from invokeai.app.services.model_loader_service import ModelLoadService
+from invokeai.app.services.model_records import ModelRecordServiceBase
+from invokeai.app.services.model_load import ModelLoadService

 config = InvokeAIAppConfig.get_config()
 store = ModelRecordServiceBase.open(config)
@@ -1487,14 +1465,11 @@ Note that we are relying on the contents of the application
 configuration to choose the implementation of
 `ModelRecordServiceBase`.

-### get_model(key, [submodel_type], [context]) -> ModelInfo:
-
-*** TO DO: change to get_model(key, context=None, **kwargs)
+### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel

-The `get_model()` method, like its similarly-named cousin in
-`ModelRecordService`, receives the unique key that identifies the
-model. It loads the model into memory, gets the model ready for use,
-and returns a `ModelInfo` object.
+The `load_model_by_key()` method receives the unique key that
+identifies the model. It loads the model into memory, gets the model
+ready for use, and returns a `LoadedModel` object.

 The optional second argument, `subtype` is a `SubModelType` string
 enum, such as "vae". It is mandatory when used with a main model, and
@@ -1504,46 +1479,64 @@ The optional third argument, `context` can be provided by
 an invocation to trigger model load event reporting. See below for
 details.

-The returned `ModelInfo` object shares some fields in common with
-`ModelConfigBase`, but is otherwise a completely different beast:
+The returned `LoadedModel` object contains a copy of the configuration
+record returned by the model record `get_model()` method, as well as
+the in-memory loaded model:

-| **Field Name** | **Type** | **Description** |
+| **Attribute Name** | **Type** | **Description** |
 |----------------|-----------------|------------------|
-| `key` | str | The model key derived from the ModelRecordService database |
-| `name` | str | Name of this model |
-| `base_model` | BaseModelType | Base model for this model |
-| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
-| `location` | Path or str | Location of the model on the filesystem |
-| `precision` | torch.dtype | The torch.precision to use for inference |
-| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
+| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
+| `model` | AnyModel | The instantiated model (details below) |
+| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |

-The types for `ModelInfo` and `SubModelType` can be imported from
-`invokeai.app.services.model_loader_service`.
+Because the loader can return multiple model types, it is typed to
+return `AnyModel`, a Union of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models; `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are obvious.

-To use the model, you use the `ModelInfo` as a context manager using
-the following pattern:
+`LoadedModel` acts as a context manager. The context loads the model
+into the execution device (e.g. VRAM on CUDA systems), locks the model
+in the execution device for the duration of the context, and returns
+the model. Use it like this:

 ```
-model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
 with model_info as vae:
     image = vae.decode(latents)[0]
 ```

-The `vae` model will stay locked in the GPU during the period of time
-it is in the context manager's scope.
+`get_model_by_key()` may raise any of the following exceptions:

-`get_model()` may raise any of the following exceptions:
-
-- `UnknownModelException` -- key not in database
-- `ModelNotFoundException` -- key in database but model not found at path
-- `InvalidModelException` -- the model is guilty of a variety of sins
+- `UnknownModelException` -- key not in database
+- `ModelNotFoundException` -- key in database but model not found at path
+- `NotImplementedException` -- the loader doesn't know how to load this type of model

-** TO DO: ** Resolve discrepancy between ModelInfo.location and
-ModelConfig.path.
+### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
+
+This is similar to `load_model_by_key`, but instead it accepts the
+combination of the model's name, type and base, which it passes to the
+model record config store for retrieval. If successful, this method
+returns a `LoadedModel`. It can raise the following exceptions:
+
+```
+UnknownModelException -- model with these attributes not known
+NotImplementedException -- the loader doesn't know how to load this type of model
+ValueError -- more than one model matches this combination of base/type/name
+```
+
+### load_model_by_config(config, [submodel], [context]) -> LoadedModel
+
+This method takes an `AnyModelConfig` returned by
+ModelRecordService.get_model() and returns the corresponding loaded
+model. It may raise a `NotImplementedException`.

 ### Emitting model loading events

-When the `context` argument is passed to `get_model()`, it will
+When the `context` argument is passed to `load_model_*()`, it will
 retrieve the invocation event bus from the passed `InvocationContext`
 object to emit events on the invocation bus. The two events are
 "model_load_started" and "model_load_completed". Both carry the
@@ -1563,3 +1556,97 @@ payload=dict(
 )
 ```

+### Adding Model Loaders
+
+Model loaders are small classes that inherit from the `ModelLoader`
+base class. They typically implement one method `_load_model()` whose
+signature is:
+
+```
+def _load_model(
+    self,
+    model_path: Path,
+    model_variant: Optional[ModelRepoVariant] = None,
+    submodel_type: Optional[SubModelType] = None,
+) -> AnyModel:
+```
+
+`_load_model()` will be passed the path to the model on disk, an
+optional repository variant (used by the diffusers loaders to select,
+e.g., the `fp16` variant), and an optional submodel_type for main and
+onnx models.
+
+To install a new loader, place it in
+`invokeai/backend/model_manager/load/model_loaders`. Inherit from
+`ModelLoader` and use the `@AnyModelLoader.register()` decorator to
+indicate what type of models the loader can handle.
+
+Here is a complete example from `generic_diffusers.py`, which is able
+to load several different diffusers types:
+
+```
+from pathlib import Path
+from typing import Optional
+
+from invokeai.backend.model_manager import (
+    AnyModel,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from ..load_base import AnyModelLoader
+from ..load_default import ModelLoader
+
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
+class GenericDiffusersLoader(ModelLoader):
+    """Class to load simple diffusers models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        model_class = self._get_hf_load_class(model_path)
+        if submodel_type is not None:
+            raise Exception(f"There are no submodels in models of type {model_class}")
+        variant = model_variant.value if model_variant else None
+        result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)  # type: ignore
+        return result
+```
+
+Note that a loader can register itself to handle several different
+model types. An exception will be raised if more than one loader tries
+to register the same model type.
+
+#### Conversion
+
+Some models require conversion to diffusers format before they can be
+loaded. These loaders should override two additional methods:
+
+```
+_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool
+_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
+```
+
+The first method accepts the model configuration, the path to where
+the unmodified model is currently installed, and a proposed
+destination for the converted model. This method returns True if the
+model needs to be converted. It typically does this by comparing the
+last modification time of the original model file to the modification
+time of the converted model. In some cases you will also want to check
+the modification date of the configuration record, in the event that
+the user has changed something like the scheduler prediction type that
+will require the model to be re-converted. See `controlnet.py` for an
+example of this logic.
+
+The second method accepts the model configuration, the path to the
+original model on disk, and the desired output path for the converted
+model. It does whatever it needs to do to get the model into diffusers
+format, and returns the Path of the resulting model. (The path should
+ordinarily be the same as `output_path`.)
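As a sketch only, the modification-time comparison described above might look like this (it assumes `dest_path` may not exist yet; the real `controlnet.py` logic also consults the configuration record):

```
from pathlib import Path

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
    if not dest_path.exists():
        return True  # no converted copy on disk yet
    # re-convert when the original file is newer than the converted copy
    return model_path.stat().st_mtime > dest_path.stat().st_mtime
```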
@@ -8,9 +8,6 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
-from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
-from invokeai.backend.model_manager.load.model_cache import ModelCache
-from invokeai.backend.model_manager.metadata import ModelMetadataStore
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__
@@ -30,9 +27,7 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
 from ..services.invocation_services import InvocationServices
 from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
-from ..services.model_install import ModelInstallService
 from ..services.model_manager.model_manager_default import ModelManagerService
-from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
 from ..services.session_processor.session_processor_default import DefaultSessionProcessor
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@@ -98,28 +93,10 @@ class ApiDependencies:
             conditioning = ObjectSerializerForwardCache(
                 ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
             )
-            model_manager = ModelManagerService(config, logger)
-            model_record_service = ModelRecordServiceSQL(db=db)
-            model_loader = AnyModelLoader(
-                app_config=config,
-                logger=logger,
-                ram_cache=ModelCache(
-                    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
-                ),
-                convert_cache=ModelConvertCache(
-                    cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
-                ),
-            )
-            model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
             download_queue_service = DownloadQueueService(event_bus=events)
-            model_install_service = ModelInstallService(
-                app_config=config,
-                record_store=model_record_service,
-                download_queue=download_queue_service,
-                metadata_store=ModelMetadataStore(db=db),
-                event_bus=events,
+            model_manager = ModelManagerService.build_model_manager(
+                app_config=configuration, db=db, download_queue=download_queue_service, events=events
             )
-            model_manager = ModelManagerService(config, logger)  # TO DO: legacy model manager v1. Remove
             names = SimpleNameService()
             performance_statistics = InvocationStatsService()
             processor = DefaultInvocationProcessor()
@@ -143,9 +120,7 @@ class ApiDependencies:
             invocation_cache=invocation_cache,
             logger=logger,
             model_manager=model_manager,
-            model_records=model_record_service,
             download_queue=download_queue_service,
-            model_install=model_install_service,
             names=names,
            performance_statistics=performance_statistics,
             processor=processor,
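After this change, the record store, the installer, and the loader are all reached through the single `model_manager` service. The resulting access pattern, as used by the routers and invocations below:

```
services = ApiDependencies.invoker.services

record_store = services.model_manager.store  # model configuration records
installer = services.model_manager.install   # download and install jobs
loader = services.model_manager.load         # loading models into RAM/VRAM
```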
@@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata

 from ..dependencies import ApiDependencies

-model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
+model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])


 class ModelsList(BaseModel):

@@ -52,7 +52,7 @@ class ModelTagSet(BaseModel):
     tags: Set[str]


-@model_records_router.get(
+@model_manager_v2_router.get(
     "/",
     operation_id="list_model_records",
 )

@@ -65,7 +65,7 @@ async def list_model_records(
     ),
 ) -> ModelsList:
     """Get a list of models."""
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
     found_models: list[AnyModelConfig] = []
     if base_models:
         for base_model in base_models:

@@ -81,7 +81,7 @@ async def list_model_records(
     return ModelsList(models=found_models)


-@model_records_router.get(
+@model_manager_v2_router.get(
     "/i/{key}",
     operation_id="get_model_record",
     responses={

@@ -94,24 +94,27 @@ async def get_model_record(
     key: str = Path(description="Key of the model record to fetch."),
 ) -> AnyModelConfig:
     """Get a model record"""
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
     try:
-        return record_store.get_model(key)
+        config: AnyModelConfig = record_store.get_model(key)
+        return config
     except UnknownModelException as e:
         raise HTTPException(status_code=404, detail=str(e))


-@model_records_router.get("/meta", operation_id="list_model_summary")
+@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
 async def list_model_summary(
     page: int = Query(default=0, description="The page to get"),
     per_page: int = Query(default=10, description="The number of models per page"),
     order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
 ) -> PaginatedResults[ModelSummary]:
     """Gets a page of model summary data."""
-    return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
+    record_store = ApiDependencies.invoker.services.model_manager.store
+    results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
+    return results


-@model_records_router.get(
+@model_manager_v2_router.get(
     "/meta/i/{key}",
     operation_id="get_model_metadata",
     responses={

@@ -124,24 +127,25 @@ async def get_model_metadata(
     key: str = Path(description="Key of the model repo metadata to fetch."),
 ) -> Optional[AnyModelRepoMetadata]:
     """Get a model metadata object."""
-    record_store = ApiDependencies.invoker.services.model_records
-    result = record_store.get_metadata(key)
+    record_store = ApiDependencies.invoker.services.model_manager.store
+    result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
     if not result:
         raise HTTPException(status_code=404, detail="No metadata for a model with this key")
     return result


-@model_records_router.get(
+@model_manager_v2_router.get(
     "/tags",
     operation_id="list_tags",
 )
 async def list_tags() -> Set[str]:
     """Get a unique set of all the model tags."""
-    record_store = ApiDependencies.invoker.services.model_records
-    return record_store.list_tags()
+    record_store = ApiDependencies.invoker.services.model_manager.store
+    result: Set[str] = record_store.list_tags()
+    return result


-@model_records_router.get(
+@model_manager_v2_router.get(
     "/tags/search",
     operation_id="search_by_metadata_tags",
 )

@@ -149,12 +153,12 @@ async def search_by_metadata_tags(
     tags: Set[str] = Query(default=None, description="Tags to search for"),
 ) -> ModelsList:
     """Get a list of models."""
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
     results = record_store.search_by_metadata_tag(tags)
     return ModelsList(models=results)


-@model_records_router.patch(
+@model_manager_v2_router.patch(
     "/i/{key}",
     operation_id="update_model_record",
     responses={

@@ -172,9 +176,9 @@ async def update_model_record(
 ) -> AnyModelConfig:
     """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
     try:
-        model_response = record_store.update_model(key, config=info)
+        model_response: AnyModelConfig = record_store.update_model(key, config=info)
         logger.info(f"Updated model: {key}")
     except UnknownModelException as e:
         raise HTTPException(status_code=404, detail=str(e))

@@ -184,7 +188,7 @@ async def update_model_record(
     return model_response


-@model_records_router.delete(
+@model_manager_v2_router.delete(
     "/i/{key}",
     operation_id="del_model_record",
     responses={

@@ -205,7 +209,7 @@ async def del_model_record(
     logger = ApiDependencies.invoker.services.logger

     try:
-        installer = ApiDependencies.invoker.services.model_install
+        installer = ApiDependencies.invoker.services.model_manager.install
         installer.delete(key)
         logger.info(f"Deleted model: {key}")
         return Response(status_code=204)

@@ -214,7 +218,7 @@ async def del_model_record(
         raise HTTPException(status_code=404, detail=str(e))


-@model_records_router.post(
+@model_manager_v2_router.post(
     "/i/",
     operation_id="add_model_record",
     responses={

@@ -229,7 +233,7 @@ async def add_model_record(
 ) -> AnyModelConfig:
     """Add a model using the configuration information appropriate for its type."""
     logger = ApiDependencies.invoker.services.logger
-    record_store = ApiDependencies.invoker.services.model_records
+    record_store = ApiDependencies.invoker.services.model_manager.store
     if config.key == "<NOKEY>":
         config.key = sha1(randbytes(100)).hexdigest()
         logger.info(f"Created model {config.key} for {config.name}")

@@ -243,10 +247,11 @@ async def add_model_record(
         raise HTTPException(status_code=415)

     # now fetch it out
-    return record_store.get_model(config.key)
+    result: AnyModelConfig = record_store.get_model(config.key)
+    return result


-@model_records_router.post(
+@model_manager_v2_router.post(
     "/import",
     operation_id="import_model_record",
     responses={

@@ -322,7 +327,7 @@ async def import_model(
     logger = ApiDependencies.invoker.services.logger

     try:
-        installer = ApiDependencies.invoker.services.model_install
+        installer = ApiDependencies.invoker.services.model_manager.install
         result: ModelInstallJob = installer.import_model(
             source=source,
             config=config,

@@ -340,17 +345,17 @@ async def import_model(
     return result


-@model_records_router.get(
+@model_manager_v2_router.get(
     "/import",
     operation_id="list_model_install_jobs",
 )
 async def list_model_install_jobs() -> List[ModelInstallJob]:
     """Return list of model install jobs."""
-    jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
+    jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
     return jobs


-@model_records_router.get(
+@model_manager_v2_router.get(
     "/import/{id}",
     operation_id="get_model_install_job",
     responses={

@@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
 async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
     """Return model install job corresponding to the given source."""
     try:
-        return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
+        result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
+        return result
     except ValueError as e:
         raise HTTPException(status_code=404, detail=str(e))


-@model_records_router.delete(
+@model_manager_v2_router.delete(
     "/import/{id}",
     operation_id="cancel_model_install_job",
     responses={

@@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
 )
 async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
     """Cancel the model install job(s) corresponding to the given job ID."""
-    installer = ApiDependencies.invoker.services.model_install
+    installer = ApiDependencies.invoker.services.model_manager.install
     try:
         job = installer.get_job_by_id(id)
     except ValueError as e:

@@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
     installer.cancel_job(job)


-@model_records_router.patch(
+@model_manager_v2_router.patch(
     "/import",
     operation_id="prune_model_install_jobs",
     responses={

@@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
 )
 async def prune_model_install_jobs() -> Response:
     """Prune all completed and errored jobs from the install job list."""
-    ApiDependencies.invoker.services.model_install.prune_jobs()
+    ApiDependencies.invoker.services.model_manager.install.prune_jobs()
     return Response(status_code=204)


-@model_records_router.patch(
+@model_manager_v2_router.patch(
     "/sync",
     operation_id="sync_models_to_config",
     responses={

@@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response:
     Model files without a corresponding
     record in the database are added. Orphan records without a models file are deleted.
     """
-    ApiDependencies.invoker.services.model_install.sync_to_config()
+    ApiDependencies.invoker.services.model_manager.install.sync_to_config()
     return Response(status_code=204)


-@model_records_router.put(
+@model_manager_v2_router.put(
     "/merge",
     operation_id="merge",
 )
@@ -451,7 +457,7 @@ async def merge(
     try:
         logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
         dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
-        installer = ApiDependencies.invoker.services.model_install
+        installer = ApiDependencies.invoker.services.model_manager.install
         merger = ModelMerger(installer)
         model_names = [installer.record_store.get_model(x).name for x in keys]
         response = merger.merge_diffusion_models_and_save(
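For orientation, a sketch of driving the installer directly, using only calls that appear in the routes above (`source` and `config` stand for a valid model source and configuration):

```
installer = ApiDependencies.invoker.services.model_manager.install

job = installer.import_model(source=source, config=config)  # returns a ModelInstallJob
jobs = installer.list_jobs()  # all known install jobs
installer.prune_jobs()  # drop completed and errored jobs
```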
@@ -8,8 +8,7 @@ from fastapi.routing import APIRouter
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from starlette.exceptions import HTTPException

-from invokeai.backend import BaseModelType, ModelType
-from invokeai.backend.model_management import MergeInterpolationMethod
+from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
 from invokeai.backend.model_management.models import (
     OPENAPI_MODEL_CONFIGS,
     InvalidModelException,
@@ -48,7 +48,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the config
         boards,
         download_queue,
         images,
-        model_records,
+        model_manager_v2,
         models,
         session_queue,
         sessions,

@@ -114,8 +114,7 @@ async def shutdown_event() -> None:
 app.include_router(sessions.session_router, prefix="/api")

 app.include_router(utilities.utilities_router, prefix="/api")
-app.include_router(models.models_router, prefix="/api")
-app.include_router(model_records.model_records_router, prefix="/api")
+app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api")
 app.include_router(download_queue.download_queue_router, prefix="/api")
 app.include_router(images.images_router, prefix="/api")
 app.include_router(boards.boards_router, prefix="/api")
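With the old `models` and `model_records` routers de-registered, clients reach the new endpoints under `/api/v2/models`. A quick smoke test, assuming a local server on the default port (9090):

```
import requests

base = "http://localhost:9090/api/v2/models"
models = requests.get(f"{base}/").json()["models"]  # list_model_records
tags = requests.get(f"{base}/tags").json()  # list_tags
one = requests.get(f"{base}/i/" + models[0]["key"]).json()  # get_model_record
```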
@@ -3,6 +3,7 @@ from typing import Iterator, List, Optional, Tuple, Union
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
+from transformers import CLIPTokenizer

 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (

@@ -68,18 +69,18 @@ class CompelInvocation(BaseInvocation):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_records.load_model(
+        tokenizer_info = context.services.model_manager.load.load_model_by_key(
             **self.clip.tokenizer.model_dump(),
             context=context,
         )
-        text_encoder_info = context.services.model_records.load_model(
+        text_encoder_info = context.services.model_manager.load.load_model_by_key(
             **self.clip.text_encoder.model_dump(),
             context=context,
         )

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                     **lora.model_dump(exclude={"weight"}), context=context
                 )
                 assert isinstance(lora_info.model, LoRAModelRaw)

@@ -93,7 +94,7 @@ class CompelInvocation(BaseInvocation):
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                loaded_model = context.services.model_records.load_model(
+                loaded_model = context.services.model_manager.load.load_model_by_key(
                     **self.clip.text_encoder.model_dump(),
                     context=context,
                 ).model

@@ -164,11 +165,11 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.services.model_records.load_model(
+        tokenizer_info = context.services.model_manager.load.load_model_by_key(
             **clip_field.tokenizer.model_dump(),
             context=context,
         )
-        text_encoder_info = context.services.model_records.load_model(
+        text_encoder_info = context.services.model_manager.load.load_model_by_key(
             **clip_field.text_encoder.model_dump(),
             context=context,
         )

@@ -196,7 +197,7 @@ class SDXLPromptInvocationBase:

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                     **lora.model_dump(exclude={"weight"}), context=context
                 )
                 lora_model = lora_info.model

@@ -211,7 +212,7 @@ class SDXLPromptInvocationBase:
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_model = context.services.model_records.load_model_by_attr(
+                ti_model = context.services.model_manager.load.load_model_by_attr(
                     model_name=name,
                     base_model=text_encoder_info.config.base,
                     model_type=ModelType.TextualInversion,

@@ -448,9 +449,9 @@ class ClipSkipInvocation(BaseInvocation):


 def get_max_token_count(
-    tokenizer,
+    tokenizer: CLIPTokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long=False,
+    truncate_if_too_long: bool = False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt

@@ -462,7 +463,9 @@ def get_max_token_count(
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))


-def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
+def get_tokens_for_prompt_object(
+    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
+) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")

@@ -475,24 +478,29 @@ def get_tokens_for_prompt_object(
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens = tokenizer.tokenize(text)
+    tokens: List[str] = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens


-def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
+            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix
         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)


-def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
+def log_tokenization_for_prompt_object(
+    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p

@@ -532,7 +540,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
         log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)


-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
+def log_tokenization_for_text(
+    text: str,
+    tokenizer: CLIPTokenizer,
+    display_label: Optional[str] = None,
+    truncate_if_too_long: Optional[bool] = False,
+) -> None:
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
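The typed helpers above all take a `CLIPTokenizer`. A standalone sketch of the truncation they perform (the checkpoint name is illustrative):

```
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
tokens = tokenizer.tokenize("a photo of an astronaut riding a horse")
max_tokens_length = tokenizer.model_max_length - 2  # typically 75, as in get_tokens_for_prompt_object
tokens = tokens[:max_tokens_length]
```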
@@ -3,13 +3,15 @@
 import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import Iterator, List, Literal, Optional, Tuple, Union
+from typing import Any, Iterator, List, Literal, Optional, Tuple, Union

 import einops
 import numpy as np
+import numpy.typing as npt
 import torch
 import torchvision.transforms as T
-from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
+from diffusers import AutoencoderKL, AutoencoderTiny
+from diffusers.configuration_utils import ConfigMixin
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
 from diffusers.models.attention_processor import (

@@ -18,8 +20,10 @@ from diffusers.models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from PIL import Image
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize

@@ -46,9 +50,10 @@ from invokeai.app.invocations.primitives import (
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.embeddings.lora import LoRAModelRaw
+from invokeai.backend.embeddings.model_patcher import ModelPatcher
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
-from invokeai.backend.model_manager import AnyModel, BaseModelType
+from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
 from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -123,10 +128,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         ui_order=4,
     )

-    def prep_mask_tensor(self, mask_image):
+    def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
         if mask_image.mode != "L":
             mask_image = mask_image.convert("L")
-        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
         if mask_tensor.dim() == 3:
             mask_tensor = mask_tensor.unsqueeze(0)
         # if shape is not None:

@@ -136,25 +141,25 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
         if self.image is not None:
-            image = context.images.get_pil(self.image.image_name)
-            image = image_resized_to_grid_as_tensor(image.convert("RGB"))
-            if image.dim() == 3:
-                image = image.unsqueeze(0)
+            image = context.services.images.get_pil_image(self.image.image_name)
+            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+            if image_tensor.dim() == 3:
+                image_tensor = image_tensor.unsqueeze(0)
         else:
-            image = None
+            image_tensor = None

         mask = self.prep_mask_tensor(
             context.images.get_pil(self.mask.image_name),
         )

-        if image is not None:
-            vae_info = context.services.model_records.load_model(
+        if image_tensor is not None:
+            vae_info = context.services.model_manager.load.load_model_by_key(
                 **self.vae.vae.model_dump(),
                 context=context,
             )

-            img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
-            masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
+            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
+            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
             # TODO:
             masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

@@ -177,7 +182,7 @@ def get_scheduler(
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.services.model_records.load_model(
+    orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
         **scheduler_info.model_dump(),
         context=context,
     )

@@ -188,7 +193,7 @@ def get_scheduler(
         scheduler_config = scheduler_config["_backup"]
     scheduler_config = {
         **scheduler_config,
-        **scheduler_extra_config,
+        **scheduler_extra_config,  # FIXME
         "_backup": scheduler_config,
     }

@@ -201,6 +206,7 @@ def get_scheduler(
     # hack copied over from generate.py
     if not hasattr(scheduler, "uses_inpainting_model"):
         scheduler.uses_inpainting_model = lambda: False
+    assert isinstance(scheduler, Scheduler)
     return scheduler

@@ -284,7 +290,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
     )

     @field_validator("cfg_scale")
-    def ge_one(cls, v):
+    def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
         """validate that all cfg_scale values are >= 1"""
         if isinstance(v, list):
             for i in v:

@@ -298,9 +304,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
     def get_conditioning_data(
         self,
         context: InvocationContext,
-        scheduler,
-        unet,
-        seed,
+        scheduler: Scheduler,
+        unet: UNet2DConditionModel,
+        seed: int,
     ) -> ConditioningData:
         positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
         c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

@@ -323,7 +329,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ),
         )

-        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
+        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(  # FIXME
             scheduler,
             # for ddim scheduler
             eta=0.0,  # ddim_eta

@@ -335,8 +341,8 @@ class DenoiseLatentsInvocation(BaseInvocation):

     def create_pipeline(
         self,
-        unet,
-        scheduler,
+        unet: UNet2DConditionModel,
+        scheduler: Scheduler,
     ) -> StableDiffusionGeneratorPipeline:
         # TODO:
         # configure_model_padding(

@@ -347,10 +353,10 @@ class DenoiseLatentsInvocation(BaseInvocation):

         class FakeVae:
             class FakeVaeConfig:
-                def __init__(self):
+                def __init__(self) -> None:
                     self.block_out_channels = [0]

-            def __init__(self):
+            def __init__(self) -> None:
                 self.config = FakeVae.FakeVaeConfig()

         return StableDiffusionGeneratorPipeline(

@@ -367,11 +373,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
     def prep_control_data(
         self,
         context: InvocationContext,
-        control_input: Union[ControlField, List[ControlField]],
+        control_input: Optional[Union[ControlField, List[ControlField]]],
         latents_shape: List[int],
         exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
-    ) -> List[ControlNetData]:
+    ) -> Optional[List[ControlNetData]]:
         # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
         control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
         control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR

@@ -394,7 +400,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         controlnet_data = []
         for control_info in control_list:
             control_model = exit_stack.enter_context(
-                context.services.model_records.load_model(
+                context.services.model_manager.load.load_model_by_key(
                     key=control_info.control_model.key,
                     context=context,
                 )

@@ -460,23 +466,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
             conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.services.model_records.load_model(
+                context.services.model_manager.load.load_model_by_key(
                     key=single_ip_adapter.ip_adapter_model.key,
                     context=context,
                 )
             )

-            image_encoder_model_info = context.services.model_records.load_model(
+            image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
                 key=single_ip_adapter.image_encoder_model.key,
                 context=context,
             )

             # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
-            single_ipa_images = single_ip_adapter.image
-            if not isinstance(single_ipa_images, list):
-                single_ipa_images = [single_ipa_images]
+            single_ipa_image_fields = single_ip_adapter.image
+            if not isinstance(single_ipa_image_fields, list):
+                single_ipa_image_fields = [single_ipa_image_fields]

-            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
+            single_ipa_images = [
+                context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
+            ]

             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@@ -520,21 +528,19 @@ class DenoiseLatentsInvocation(BaseInvocation):

         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.services.model_records.load_model(
+            t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
                 key=t2i_adapter_field.t2i_adapter_model.key,
                 context=context,
             )
             image = context.images.get_pil(t2i_adapter_field.image.image_name)

             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
-            if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
+            if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1:
                 max_unet_downscale = 8
-            elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
+            elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL:
                 max_unet_downscale = 4
             else:
-                raise ValueError(
-                    f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
-                )
+                raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.")

             t2i_adapter_model: T2IAdapter
             with t2i_adapter_model_info as t2i_adapter_model:
@@ -582,7 +588,15 @@ class DenoiseLatentsInvocation(BaseInvocation):

     # original idea by https://github.com/AmericanPresidentJimmyCarter
     # TODO: research more for second order schedulers timesteps
-    def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
+    def init_scheduler(
+        self,
+        scheduler: Union[Scheduler, ConfigMixin],
+        device: torch.device,
+        steps: int,
+        denoising_start: float,
+        denoising_end: float,
+    ) -> Tuple[int, List[int], int]:
+        assert isinstance(scheduler, ConfigMixin)
         if scheduler.config.get("cpu_only", False):
             scheduler.set_timesteps(steps, device="cpu")
             timesteps = scheduler.timesteps.to(device=device)

@@ -594,11 +608,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
         _timesteps = timesteps[:: scheduler.order]

         # get start timestep index
-        t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
+        t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
         t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))

         # get end timestep index
-        t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
+        t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
         t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))

         # apply order to indexes

@@ -611,7 +625,9 @@ class DenoiseLatentsInvocation(BaseInvocation):

         return num_inference_steps, timesteps, init_timestep

-    def prep_inpaint_mask(self, context: InvocationContext, latents):
+    def prep_inpaint_mask(
+        self, context: InvocationContext, latents: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         if self.denoise_mask is None:
             return None, None

@@ -660,12 +676,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=True,
             )

-            def step_callback(state: PipelineIntermediateState):
-                context.util.sd_step_callback(state, self.unet.unet.base_model)
+            # Get the source node id (we are invoking the prepared node)
+            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+            source_node_id = graph_execution_state.prepared_source_mapping[self.id]

-            def _lora_loader() -> Iterator[Tuple[AnyModel, float]]:
+            # get the unet's config so that we can pass the base to dispatch_progress()
+            unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump())
+
+            def step_callback(state: PipelineIntermediateState) -> None:
+                self.dispatch_progress(context, source_node_id, state, unet_config.base)
+
+            def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
                 for lora in self.unet.loras:
-                    lora_info = context.services.model_records.load_model(
+                    lora_info = context.services.model_manager.load.load_model_by_key(
                         **lora.model_dump(exclude={"weight"}),
                         context=context,
                     )
@ -673,7 +696,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_records.load_model(
|
||||
unet_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)

@ -783,7 +806,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)

vae_info = context.services.model_records.load_model(
vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(),
context=context,
)
@ -961,8 +984,9 @@ class ImageToLatentsInvocation(BaseInvocation):
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)

@staticmethod
def vae_encode(vae_info, upcast, tiled, image_tensor):
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@ -1008,7 +1032,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)

vae_info = context.services.model_records.load_model(
vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(),
context=context,
)
@ -1026,14 +1050,19 @@ class ImageToLatentsInvocation(BaseInvocation):
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents

@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents


@invocation(
@ -1066,7 +1095,12 @@ class BlendLatentsInvocation(BaseInvocation):
# TODO:
device = choose_torch_device()

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
"""
Spherical linear interpolation
Args:
@ -1099,12 +1133,16 @@ class BlendLatentsInvocation(BaseInvocation):
v2 = s0 * v0 + s1 * v1

if inputs_are_torch:
v2 = torch.from_numpy(v2).to(device)

return v2
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2

# blend
blended_latents = slerp(self.alpha, latents_a, latents_b)
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
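As a sanity check on the newly typed `slerp`, here is a tiny self-contained NumPy restatement; at `t=0` and `t=1` it returns the endpoints, and nearly-parallel inputs fall back to plain lerp, which is the `DOT_THRESHOLD` branch above:

```python
import numpy as np

def slerp_np(t: float, v0: np.ndarray, v1: np.ndarray, dot_threshold: float = 0.9995) -> np.ndarray:
    """Minimal NumPy-only restatement of the slerp used by BlendLatentsInvocation."""
    dot = np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    if np.abs(dot) > dot_threshold:
        return (1 - t) * v0 + t * v1  # vectors nearly parallel: ordinary lerp
    theta_0 = np.arccos(dot)          # angle between the inputs
    theta_t = theta_0 * t             # angle swept at interpolation fraction t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1

a, b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
assert np.allclose(slerp_np(0.0, a, b), a)
assert np.allclose(slerp_np(1.0, a, b), b)
assert np.allclose(slerp_np(0.5, a, b), [np.sqrt(2) / 2, np.sqrt(2) / 2])
```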

@ -1197,15 +1235,19 @@ class IdealSizeInvocation(BaseInvocation):
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
)

def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)

def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.services.model_manager.load.load_model_by_key(
**self.unet.unet.model_dump(),
context=context,
)
aspect = self.width / self.height
dimension = 512
if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
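The dimension arithmetic in this hunk is compact, so a worked example helps; it assumes `LATENT_SCALE_FACTOR` is 8, matching the 1/8-scale latent grid used elsewhere in the codebase:

```python
import math

LATENT_SCALE_FACTOR = 8  # assumed value: the UNet works on 1/8-scale latents

def trim_to_multiple_of(*args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> tuple:
    # Snap each dimension down to the latent grid, exactly as in the hunk above.
    return tuple((x - x % multiple_of) for x in args)

# An SDXL config (base dimension 1024) with multiplier 1.2:
dimension = 1024 * 1.2                       # 1228.8
min_dimension = math.floor(dimension * 0.5)  # 614

print(trim_to_multiple_of(int(dimension), min_dimension))  # -> (1224, 608)
```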

@ -17,7 +17,7 @@ from .baseinvocation import (


class ModelInfo(BaseModel):
key: str = Field(description="Info to load submodel")
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")



@ -27,9 +27,7 @@ if TYPE_CHECKING:
from .invocation_queue.invocation_queue_base import InvocationQueueABC
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@ -55,9 +53,7 @@ class InvocationServices:
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -82,9 +78,7 @@ class InvocationServices:
self.image_records = image_records
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

@ -43,8 +43,10 @@ class InvocationStatsService(InvocationStatsServiceBase):

@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
# This is to handle case of the model manager not being initialized, which happens
# during some tests.
services = self._invoker.services
if services.model_records is None or services.model_records.loader is None:
if services.model_manager is None or services.model_manager.load is None:
yield None
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
@ -60,9 +62,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()

# TO DO [LS]: clean up loader service - shouldn't be an attribute of model records
assert services.model_records.loader is not None
services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id]
assert services.model_manager.load is not None
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]

try:
# Let the invocation run.

@ -4,7 +4,8 @@
from abc import ABC, abstractmethod
from typing import Optional

from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel


@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""

@abstractmethod
def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.

:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass

@abstractmethod
def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.

:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
pass

@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.

This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.

:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.

Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
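Together, the three abstract methods give callers a key-based, config-based, or attribute-based route to the same `LoadedModel`. A hedged sketch of how a caller exercises them (the key and name values are placeholders, and `loader` stands for any concrete implementation):

```python
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType

def demo(loader) -> None:
    """Sketch only: `loader` is any concrete ModelLoadServiceBase."""
    by_key = loader.load_model_by_key("abc123", submodel_type=SubModelType.UNet)
    by_attr = loader.load_model_by_attr(
        model_name="stable-diffusion-v1-5",  # placeholder name
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel=SubModelType.UNet,
    )
    with by_key as unet:
        ...  # the model stays locked in the RAM/VRAM cache inside this block
```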
@ -3,12 +3,14 @@

from typing import Optional

from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger

from .model_load_base import ModelLoadServiceBase
@ -21,7 +23,7 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
ram_cache: Optional[ModelCacheBase] = None,
ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
convert_cache: Optional[ModelConvertCacheBase] = None,
):
"""Initialize the model load service."""
@ -44,11 +46,104 @@ class ModelLoadService(ModelLoadServiceBase):
),
)

def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's key, load it and return the LoadedModel object."""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type)

def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Given a model's configuration, load it and return the LoadedModel object."""
return self._any_loader.load_model(config, submodel_type)

def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's key, load it and return the LoadedModel object.

:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type, context)

def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.

This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.

:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.

Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self._store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model_by_key(configs[0].key, submodel)

def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.

:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._any_loader.load_model(model_config, submodel_type)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model

def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()

if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)

@ -2,9 +2,10 @@

from abc import ABC, abstractmethod

from pydantic import BaseModel, Field
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker

from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
@ -14,12 +15,13 @@ from ..model_records import ModelRecordServiceBase
from ..shared.sqlite.sqlite_database import SqliteDatabase


class ModelManagerServiceBase(BaseModel, ABC):
class ModelManagerServiceBase(ABC):
"""Abstract base class for the model manager service."""

store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
# attributes:
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")

@classmethod
@abstractmethod
@ -37,3 +39,29 @@ class ModelManagerServiceBase(ABC):
method simplifies the construction considerably.
"""
pass

@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
pass

@property
@abstractmethod
def load(self) -> ModelLoadServiceBase:
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
pass

@property
@abstractmethod
def install(self) -> ModelInstallServiceBase:
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
pass

@abstractmethod
def start(self, invoker: Invoker) -> None:
pass

@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass

@ -3,6 +3,7 @@

from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
@ -10,9 +11,9 @@ from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService
from ..model_load import ModelLoadService
from ..model_records import ModelRecordServiceSQL
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_manager_base import ModelManagerServiceBase

@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase):
model_manager.load -- Routines to load models into memory.
"""

def __init__(
self,
store: ModelRecordServiceBase,
install: ModelInstallServiceBase,
load: ModelLoadServiceBase,
):
self._store = store
self._install = install
self._load = load

@property
def store(self) -> ModelRecordServiceBase:
return self._store

@property
def install(self) -> ModelInstallServiceBase:
return self._install

@property
def load(self) -> ModelLoadServiceBase:
return self._load

def start(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "start"):
service.start(invoker)

def stop(self, invoker: Invoker) -> None:
for service in [self._store, self._install, self._load]:
if hasattr(service, "stop"):
service.stop(invoker)
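With the pydantic base class gone, the manager is plain-constructed from its three sub-services. A hedged sketch of wiring it by hand; in the app the `build_model_manager()` classmethod below encapsulates this, and `record_store`, `installer`, `load_service`, and `invoker` are assumed to exist:

```python
# Sketch only: the sub-services and invoker are assumed to be constructed already.
manager = ModelManagerService(store=record_store, install=installer, load=load_service)
manager.start(invoker)  # forwards start() to each sub-service that defines one

config = manager.store.get_model("some-key")          # record lookup
with manager.load.load_model_by_config(config) as m:  # cached load
    ...

manager.stop(invoker)
```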

@classmethod
def build_model_manager(
cls,
@ -10,15 +10,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field

from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
LoadedModel,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
@ -111,52 +108,6 @@ class ModelRecordServiceBase(ABC):
"""
pass

@abstractmethod
def load_model(
self,
key: str,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.

:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch
:param context: Invocation context, used for event issuing.

Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass

@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.

This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.

:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.

Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
pass

@property
@abstractmethod
def metadata_store(self) -> ModelMetadataStore:
@ -46,8 +46,6 @@ from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@ -55,9 +53,8 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel
from invokeai.backend.model_manager.load import AnyModelLoader
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException

from ..shared.sqlite.sqlite_database import SqliteDatabase
@ -220,74 +217,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model

def load_model(
self,
key: str,
submodel: Optional[SubModelType],
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.

:param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting

Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
"""
if not self._loader:
raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader")
# we can emit model loading events if we are executing with access to the invocation context

model_config = self.get_model(key)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
)
loaded_model = self._loader.load_model(model_config, submodel)
if context:
self._emit_load_event(
context=context,
model_config=model_config,
loaded=True,
)
return loaded_model

def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> LoadedModel:
"""
Load the indicated model into memory and return a LoadedModel object.

This is provided for API compatibility with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that was returned.

:param model_name: Name of the model to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.

Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load_model(configs[0].key, submodel)

def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the database.
@ -476,29 +405,3 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)

def _emit_load_event(
self,
context: InvocationContext,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()

if not loaded:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
else:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_config=model_config,
)
@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import
class Migration6Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._recreate_model_triggers(cursor)
self._delete_ip_adapters(cursor)

def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
"""
@ -26,6 +27,22 @@ class Migration6Callback:
"""
)

def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
"""
Delete all the IP adapters.

The model manager will automatically find and re-add them after the migration
is done. This allows the manager to add the correct image encoder to their
configuration records.
"""

cursor.execute(
"""--sql
DELETE FROM model_config
WHERE type='ip_adapter';
"""
)


def build_migration_6() -> Migration:
"""
@ -33,6 +50,8 @@ def build_migration_6() -> Migration:

This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist
- Deletes all ip_adapter models so that the model prober can find and
update them with the correct image processor model.
"""
migration_6 = Migration(
from_version=5,
@ -64,7 +64,7 @@ class ModelPatcher:
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModelRaw, float]],
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@ -307,7 +307,7 @@ class ONNXModelPatcher:
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield

@ -8,8 +8,8 @@ from PIL import Image

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import SilenceWarnings
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.silence_warnings import SilenceWarnings

config = InvokeAIAppConfig.get_config()

@ -8,7 +8,6 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data

from .resampler import Resampler

@ -124,6 +123,9 @@ class IPAdapter:
self.attn_weights.to(device=self.device, dtype=self.dtype)

def calc_size(self):
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data

return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

def _init_image_proj_model(self, state_dict):

@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union, Class
from typing import Literal, Optional, Type, Union

import torch
from diffusers import ModelMixin
@ -335,7 +335,7 @@ class ModelConfigFactory(object):
cls,
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type[Class]] = None,
dest_class: Optional[Type[ModelConfigBase]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""
@ -347,14 +347,17 @@ class ModelConfigFactory(object):
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
model: Optional[ModelConfigBase] = None
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.validate_python(model_data)
model = dest_class.model_validate(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
return model
return model # type: ignore
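A quick illustration of the corrected factory path: `model_validate` is the pydantic-v2 class method, while `validate_python` belongs to `TypeAdapter`, which is what the `AnyModelConfigValidator` branch uses. A hedged sketch with placeholder field values rather than a complete record:

```python
# Sketch only: real records carry the full field set (hashes, description, etc.).
raw = {
    "key": "abc123",  # placeholder key
    "name": "my-vae",
    "type": "vae",
    "format": "diffusers",
    "base": "sd-1",
    "path": "sd-1/vae/my-vae",
}
config = ModelConfigFactory.make_config(raw, timestamp=1700000000.0)
assert config.key == "abc123" and config.last_modified == 1700000000.0
```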

@ -18,8 +18,16 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.logging import InvokeAILogger
@ -32,7 +40,7 @@ class LoadedModel:
config: AnyModelConfig
locker: ModelLockerBase

def __enter__(self) -> AnyModel: # I think load_file() always returns a dict
def __enter__(self) -> AnyModel:
"""Context entry."""
self.locker.lock()
return self.model
@ -171,6 +179,10 @@ class AnyModelLoader:
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
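The new duplicate-registration guard means each (base, type, format) triple resolves to exactly one loader class; registering another loader is a one-decorator affair, as the CLIPVision registration further down shows. A hedged sketch with a hypothetical loader class:

```python
# Hypothetical loader; a real one implements the ModelLoaderBase interface,
# and the chosen (base, type, format) triple must still be unclaimed,
# otherwise the guard above raises at import time.
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
class MyVaeLoader(ModelLoaderBase):
    """Sketch only: would load diffusers-format VAEs for any base model."""
```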

@ -169,7 +169,7 @@ class ModelLoader(ModelLoaderBase):
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
raise NotImplementedError

# This needs to be implemented in the subclass

@ -246,7 +246,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
# These attributes are not in the base ModelMixin class but in derived classes.
# These attributes are not in the base ModelMixin class but in various derived classes.
# Some models don't have these attributes, in which case they run in RAM/CPU.
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):

@ -35,28 +35,28 @@ class ControlnetLoader(GenericDiffusersLoader):
else:
return True

def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
assert hasattr(config, "config")
config_file = config.config

if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
checkpoint = torch.load(model_path, map_location="cpu")

# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]

convert_controlnet_to_diffusers(
weights_path,
model_path,
output_path,
original_config_file=self._app_config.root_path / config_file,
image_size=512,
scan_needed=True,
from_safetensors=weights_path.suffix == ".safetensors",
from_safetensors=model_path.suffix == ".safetensors",
)
return output_path

@ -12,8 +12,9 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader

from ..load_base import AnyModelLoader
from ..load_default import ModelLoader


@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)

@ -65,7 +65,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
else:
return True

def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig)
variant = config.variant
base = config.base
@ -75,9 +75,9 @@ class StableDiffusionDiffusersModel(ModelLoader):

config_file = config.config

self._logger.info(f"Converting {weights_path} to diffusers format")
self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers(
weights_path,
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
model_version=base,
@ -86,7 +86,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights_path.suffix == ".safetensors",
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
load_safety_checker=False,
)

@ -37,7 +37,7 @@ class VaeLoader(GenericDiffusersLoader):
else:
return True

def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
@ -46,10 +46,10 @@ class VaeLoader(GenericDiffusersLoader):
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)

if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
if model_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
checkpoint = torch.load(model_path, map_location="cpu")

# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:

@ -65,7 +65,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files

if variant is None:
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files
elif variant == "fp16":
files = fp16_files
@ -22,11 +22,12 @@ Example usage:

import os
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, Set, Union

from pydantic import BaseModel, Field
from logging import Logger

from invokeai.backend.util.logging import InvokeAILogger

default_logger: Logger = InvokeAILogger.get_logger()

@ -1 +1,3 @@
from .schedulers import SCHEDULER_MAP # noqa: F401

__all__ = ["SCHEDULER_MAP"]

@ -513,7 +513,7 @@ def select_and_download_models(opt: Namespace) -> None:
"""Prompt user for install/delete selections and execute."""
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
# unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal
config.precision = precision # type: ignore
config.precision = precision
install_helper = InstallHelper(config, logger)
installer = install_helper.installer

@ -62,9 +62,7 @@ def mock_services() -> InvocationServices:
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),

@ -65,9 +65,7 @@ def mock_services() -> InvocationServices:
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
processor=DefaultInvocationProcessor(),

@ -0,0 +1,22 @@
"""
Test model loading
"""

from pathlib import Path

from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager.load import AnyModelLoader
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403


def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path):
store = mm2_installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
loaded_model = mm2_loader.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)

@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
@ -89,6 +90,16 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
return app_config


@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
ram_cache = ModelCache(
logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size
)
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)


@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=mm2_app_config)