mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry - Fix type check errors in multiple files - Remove apparently unneeded `get_model_config_enum()` method from model manager - Remove last vestiges of old model manager - Updated tests and documentation resolve conflict with seamless.py
This commit is contained in:
parent
ed2d9ae0d9
commit
4ffe672bc1
@ -1531,23 +1531,29 @@ Here is a typical initialization pattern:
|
|||||||
|
|
||||||
```
|
```
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry
|
||||||
from invokeai.app.services.model_load import ModelLoadService
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
store = ModelRecordServiceBase.open(config)
|
ram_cache = ModelCache(
|
||||||
loader = ModelLoadService(config, store)
|
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||||
|
)
|
||||||
|
convert_cache = ModelConvertCache(
|
||||||
|
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||||
|
)
|
||||||
|
loader = ModelLoadService(
|
||||||
|
app_config=config,
|
||||||
|
ram_cache=ram_cache,
|
||||||
|
convert_cache=convert_cache,
|
||||||
|
registry=ModelLoaderRegistry
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that we are relying on the contents of the application
|
### load_model(model_config, [submodel_type], [context]) -> LoadedModel
|
||||||
configuration to choose the implementation of
|
|
||||||
`ModelRecordServiceBase`.
|
|
||||||
|
|
||||||
### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel
|
The `load_model()` method takes an `AnyModelConfig` returned by
|
||||||
|
`ModelRecordService.get_model()` and returns the corresponding loaded
|
||||||
The `load_model_by_key()` method receives the unique key that
|
model. It loads the model into memory, gets the model ready for use,
|
||||||
identifies the model. It loads the model into memory, gets the model
|
and returns a `LoadedModel` object.
|
||||||
ready for use, and returns a `LoadedModel` object.
|
|
||||||
|
|
||||||
The optional second argument, `subtype` is a `SubModelType` string
|
The optional second argument, `subtype` is a `SubModelType` string
|
||||||
enum, such as "vae". It is mandatory when used with a main model, and
|
enum, such as "vae". It is mandatory when used with a main model, and
|
||||||
@ -1593,25 +1599,6 @@ with model_info as vae:
|
|||||||
- `ModelNotFoundException` -- key in database but model not found at path
|
- `ModelNotFoundException` -- key in database but model not found at path
|
||||||
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||||
|
|
||||||
### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
|
|
||||||
|
|
||||||
This is similar to `load_model_by_key`, but instead it accepts the
|
|
||||||
combination of the model's name, type and base, which it passes to the
|
|
||||||
model record config store for retrieval. If successful, this method
|
|
||||||
returns a `LoadedModel`. It can raise the following exceptions:
|
|
||||||
|
|
||||||
```
|
|
||||||
UnknownModelException -- model with these attributes not known
|
|
||||||
NotImplementedException -- the loader doesn't know how to load this type of model
|
|
||||||
ValueError -- more than one model matches this combination of base/type/name
|
|
||||||
```
|
|
||||||
|
|
||||||
### load_model_by_config(config, [submodel], [context]) -> LoadedModel
|
|
||||||
|
|
||||||
This method takes an `AnyModelConfig` returned by
|
|
||||||
ModelRecordService.get_model() and returns the corresponding loaded
|
|
||||||
model. It may raise a `NotImplementedException`.
|
|
||||||
|
|
||||||
### Emitting model loading events
|
### Emitting model loading events
|
||||||
|
|
||||||
When the `context` argument is passed to `load_model_*()`, it will
|
When the `context` argument is passed to `load_model_*()`, it will
|
||||||
@ -1656,7 +1643,7 @@ onnx models.
|
|||||||
|
|
||||||
To install a new loader, place it in
|
To install a new loader, place it in
|
||||||
`invokeai/backend/model_manager/load/model_loaders`. Inherit from
|
`invokeai/backend/model_manager/load/model_loaders`. Inherit from
|
||||||
`ModelLoader` and use the `@AnyModelLoader.register()` decorator to
|
`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to
|
||||||
indicate what type of models the loader can handle.
|
indicate what type of models the loader can handle.
|
||||||
|
|
||||||
Here is a complete example from `generic_diffusers.py`, which is able
|
Here is a complete example from `generic_diffusers.py`, which is able
|
||||||
@ -1674,12 +1661,11 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from ..load_base import AnyModelLoader
|
from .. import ModelLoader, ModelLoaderRegistry
|
||||||
from ..load_default import ModelLoader
|
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||||
class GenericDiffusersLoader(ModelLoader):
|
class GenericDiffusersLoader(ModelLoader):
|
||||||
"""Class to load simple diffusers models."""
|
"""Class to load simple diffusers models."""
|
||||||
|
|
||||||
@ -1728,3 +1714,74 @@ model. It does whatever it needs to do to get the model into diffusers
|
|||||||
format, and returns the Path of the resulting model. (The path should
|
format, and returns the Path of the resulting model. (The path should
|
||||||
ordinarily be the same as `output_path`.)
|
ordinarily be the same as `output_path`.)
|
||||||
|
|
||||||
|
## The ModelManagerService object
|
||||||
|
|
||||||
|
For convenience, the API provides a `ModelManagerService` object which
|
||||||
|
gives a single point of access to the major model manager
|
||||||
|
services. This object is created at initialization time and can be
|
||||||
|
found in the global `ApiDependencies.invoker.services.model_manager`
|
||||||
|
object, or in `context.services.model_manager` from within an
|
||||||
|
invocation.
|
||||||
|
|
||||||
|
In the examples below, we have retrieved the manager using:
|
||||||
|
```
|
||||||
|
mm = ApiDependencies.invoker.services.model_manager
|
||||||
|
```
|
||||||
|
|
||||||
|
The following properties and methods will be available:
|
||||||
|
|
||||||
|
### mm.store
|
||||||
|
|
||||||
|
This retrieves the `ModelRecordService` associated with the
|
||||||
|
manager. Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
|
||||||
|
```
|
||||||
|
|
||||||
|
### mm.install
|
||||||
|
|
||||||
|
This retrieves the `ModelInstallService` associated with the manager.
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||||
|
```
|
||||||
|
|
||||||
|
### mm.load
|
||||||
|
|
||||||
|
This retrieves the `ModelLoaderService` associated with the manager. Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
|
||||||
|
assert len(configs) > 0
|
||||||
|
|
||||||
|
loaded_model = mm.load.load_model(configs[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
The model manager also offers a few convenience shortcuts for loading
|
||||||
|
models:
|
||||||
|
|
||||||
|
### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel
|
||||||
|
|
||||||
|
Same as `mm.load.load_model()`.
|
||||||
|
|
||||||
|
### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
|
||||||
|
|
||||||
|
This accepts the combination of the model's name, type and base, which
|
||||||
|
it passes to the model record config store for retrieval. If a unique
|
||||||
|
model config is found, this method returns a `LoadedModel`. It can
|
||||||
|
raise the following exceptions:
|
||||||
|
|
||||||
|
```
|
||||||
|
UnknownModelException -- model with these attributes not known
|
||||||
|
NotImplementedException -- the loader doesn't know how to load this type of model
|
||||||
|
ValueError -- more than one model matches this combination of base/type/name
|
||||||
|
```
|
||||||
|
|
||||||
|
### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel
|
||||||
|
|
||||||
|
This method takes a model key, looks it up using the
|
||||||
|
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
|
||||||
|
model configuration to `load_model_by_config()`. It may raise a
|
||||||
|
`NotImplementedException`.
|
||||||
|
@ -35,7 +35,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
|||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
|
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
@ -135,7 +135,7 @@ example_model_metadata = {
|
|||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get(
|
@model_manager_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_model_records",
|
operation_id="list_model_records",
|
||||||
)
|
)
|
||||||
@ -164,7 +164,7 @@ async def list_model_records(
|
|||||||
return ModelsList(models=found_models)
|
return ModelsList(models=found_models)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get(
|
@model_manager_router.get(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="get_model_record",
|
operation_id="get_model_record",
|
||||||
responses={
|
responses={
|
||||||
@ -188,7 +188,7 @@ async def get_model_record(
|
|||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get("/summary", operation_id="list_model_summary")
|
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||||
async def list_model_summary(
|
async def list_model_summary(
|
||||||
page: int = Query(default=0, description="The page to get"),
|
page: int = Query(default=0, description="The page to get"),
|
||||||
per_page: int = Query(default=10, description="The number of models per page"),
|
per_page: int = Query(default=10, description="The number of models per page"),
|
||||||
@ -200,7 +200,7 @@ async def list_model_summary(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get(
|
@model_manager_router.get(
|
||||||
"/meta/i/{key}",
|
"/meta/i/{key}",
|
||||||
operation_id="get_model_metadata",
|
operation_id="get_model_metadata",
|
||||||
responses={
|
responses={
|
||||||
@ -223,7 +223,7 @@ async def get_model_metadata(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get(
|
@model_manager_router.get(
|
||||||
"/tags",
|
"/tags",
|
||||||
operation_id="list_tags",
|
operation_id="list_tags",
|
||||||
)
|
)
|
||||||
@ -234,7 +234,7 @@ async def list_tags() -> Set[str]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get(
|
@model_manager_router.get(
|
||||||
"/tags/search",
|
"/tags/search",
|
||||||
operation_id="search_by_metadata_tags",
|
operation_id="search_by_metadata_tags",
|
||||||
)
|
)
|
||||||
@ -247,7 +247,7 @@ async def search_by_metadata_tags(
|
|||||||
return ModelsList(models=results)
|
return ModelsList(models=results)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.patch(
|
@model_manager_router.patch(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="update_model_record",
|
operation_id="update_model_record",
|
||||||
responses={
|
responses={
|
||||||
@ -281,7 +281,7 @@ async def update_model_record(
|
|||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.delete(
|
@model_manager_router.delete(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="del_model_record",
|
operation_id="del_model_record",
|
||||||
responses={
|
responses={
|
||||||
@ -311,7 +311,7 @@ async def del_model_record(
|
|||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.post(
|
@model_manager_router.post(
|
||||||
"/i/",
|
"/i/",
|
||||||
operation_id="add_model_record",
|
operation_id="add_model_record",
|
||||||
responses={
|
responses={
|
||||||
@ -349,7 +349,7 @@ async def add_model_record(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.post(
|
@model_manager_router.post(
|
||||||
"/heuristic_import",
|
"/heuristic_import",
|
||||||
operation_id="heuristic_import_model",
|
operation_id="heuristic_import_model",
|
||||||
responses={
|
responses={
|
||||||
@ -416,7 +416,7 @@ async def heuristic_import(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.post(
|
@model_manager_router.post(
|
||||||
"/install",
|
"/install",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses={
|
responses={
|
||||||
@ -516,7 +516,7 @@ async def import_model(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get(
|
@model_manager_router.get(
|
||||||
"/import",
|
"/import",
|
||||||
operation_id="list_model_install_jobs",
|
operation_id="list_model_install_jobs",
|
||||||
)
|
)
|
||||||
@ -544,7 +544,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
|||||||
return jobs
|
return jobs
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.get(
|
@model_manager_router.get(
|
||||||
"/import/{id}",
|
"/import/{id}",
|
||||||
operation_id="get_model_install_job",
|
operation_id="get_model_install_job",
|
||||||
responses={
|
responses={
|
||||||
@ -564,7 +564,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
|||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.delete(
|
@model_manager_router.delete(
|
||||||
"/import/{id}",
|
"/import/{id}",
|
||||||
operation_id="cancel_model_install_job",
|
operation_id="cancel_model_install_job",
|
||||||
responses={
|
responses={
|
||||||
@ -583,7 +583,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
|||||||
installer.cancel_job(job)
|
installer.cancel_job(job)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.patch(
|
@model_manager_router.patch(
|
||||||
"/import",
|
"/import",
|
||||||
operation_id="prune_model_install_jobs",
|
operation_id="prune_model_install_jobs",
|
||||||
responses={
|
responses={
|
||||||
@ -597,7 +597,7 @@ async def prune_model_install_jobs() -> Response:
|
|||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.patch(
|
@model_manager_router.patch(
|
||||||
"/sync",
|
"/sync",
|
||||||
operation_id="sync_models_to_config",
|
operation_id="sync_models_to_config",
|
||||||
responses={
|
responses={
|
||||||
@ -616,7 +616,7 @@ async def sync_models_to_config() -> Response:
|
|||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.put(
|
@model_manager_router.put(
|
||||||
"/convert/{key}",
|
"/convert/{key}",
|
||||||
operation_id="convert_model",
|
operation_id="convert_model",
|
||||||
responses={
|
responses={
|
||||||
@ -694,7 +694,7 @@ async def convert_model(
|
|||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.put(
|
@model_manager_router.put(
|
||||||
"/merge",
|
"/merge",
|
||||||
operation_id="merge",
|
operation_id="merge",
|
||||||
responses={
|
responses={
|
@ -1,426 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
|
||||||
|
|
||||||
import pathlib
|
|
||||||
from typing import Annotated, List, Literal, Optional, Union
|
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
|
||||||
from fastapi.routing import APIRouter
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
|
||||||
from starlette.exceptions import HTTPException
|
|
||||||
|
|
||||||
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
|
|
||||||
from invokeai.backend.model_management.models import (
|
|
||||||
OPENAPI_MODEL_CONFIGS,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelNotFoundException,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
|
||||||
|
|
||||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|
||||||
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
|
|
||||||
|
|
||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|
||||||
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
|
|
||||||
|
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|
||||||
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
|
|
||||||
|
|
||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|
||||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
|
||||||
|
|
||||||
model_config = ConfigDict(use_enum_values=True)
|
|
||||||
|
|
||||||
|
|
||||||
ModelsListValidator = TypeAdapter(ModelsList)
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
|
||||||
"/",
|
|
||||||
operation_id="list_models",
|
|
||||||
responses={200: {"model": ModelsList}},
|
|
||||||
)
|
|
||||||
async def list_models(
|
|
||||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
|
||||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
|
||||||
) -> ModelsList:
|
|
||||||
"""Gets a list of models"""
|
|
||||||
if base_models and len(base_models) > 0:
|
|
||||||
models_raw = []
|
|
||||||
for base_model in base_models:
|
|
||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
|
||||||
else:
|
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
|
||||||
models = ModelsListValidator.validate_python({"models": models_raw})
|
|
||||||
return models
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.patch(
|
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
|
||||||
operation_id="update_model",
|
|
||||||
responses={
|
|
||||||
200: {"description": "The model was updated successfully"},
|
|
||||||
400: {"description": "Bad request"},
|
|
||||||
404: {"description": "The model could not be found"},
|
|
||||||
409: {"description": "There is already a model corresponding to the new name"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
response_model=UpdateModelResponse,
|
|
||||||
)
|
|
||||||
async def update_model(
|
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
|
||||||
model_name: str = Path(description="model name"),
|
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
|
||||||
) -> UpdateModelResponse:
|
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
# rename operation requested
|
|
||||||
if info.model_name != model_name or info.base_model != base_model:
|
|
||||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
model_name=model_name,
|
|
||||||
new_name=info.model_name,
|
|
||||||
new_base=info.base_model,
|
|
||||||
)
|
|
||||||
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
|
||||||
# update information to support an update of attributes
|
|
||||||
model_name = info.model_name
|
|
||||||
base_model = info.base_model
|
|
||||||
new_info = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
)
|
|
||||||
if new_info.get("path") != previous_info.get(
|
|
||||||
"path"
|
|
||||||
): # model manager moved model path during rename - don't overwrite it
|
|
||||||
info.path = new_info.get("path")
|
|
||||||
|
|
||||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
|
||||||
info_dict = info.model_dump()
|
|
||||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.update_model(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
model_attributes=info_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
)
|
|
||||||
model_response = UpdateModelResponseValidator.validate_python(model_raw)
|
|
||||||
except ModelNotFoundException as e:
|
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
|
|
||||||
return model_response
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
|
||||||
"/import",
|
|
||||||
operation_id="import_model",
|
|
||||||
responses={
|
|
||||||
201: {"description": "The model imported successfully"},
|
|
||||||
404: {"description": "The model could not be found"},
|
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
|
||||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
|
||||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
|
||||||
},
|
|
||||||
status_code=201,
|
|
||||||
response_model=ImportModelResponse,
|
|
||||||
)
|
|
||||||
async def import_model(
|
|
||||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
|
||||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
|
||||||
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
|
|
||||||
default=None,
|
|
||||||
),
|
|
||||||
) -> ImportModelResponse:
|
|
||||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
|
||||||
|
|
||||||
location = location.strip("\"' ")
|
|
||||||
items_to_import = {location}
|
|
||||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
|
||||||
items_to_import=items_to_import,
|
|
||||||
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
|
|
||||||
)
|
|
||||||
info = installed_models.get(location)
|
|
||||||
|
|
||||||
if not info:
|
|
||||||
logger.error("Import failed")
|
|
||||||
raise HTTPException(status_code=415)
|
|
||||||
|
|
||||||
logger.info(f"Successfully imported {location}, got {info}")
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
|
||||||
)
|
|
||||||
return ImportModelResponseValidator.validate_python(model_raw)
|
|
||||||
|
|
||||||
except ModelNotFoundException as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
except InvalidModelException as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=415)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
|
||||||
"/add",
|
|
||||||
operation_id="add_model",
|
|
||||||
responses={
|
|
||||||
201: {"description": "The model added successfully"},
|
|
||||||
404: {"description": "The model could not be found"},
|
|
||||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
|
||||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
|
||||||
},
|
|
||||||
status_code=201,
|
|
||||||
response_model=ImportModelResponse,
|
|
||||||
)
|
|
||||||
async def add_model(
|
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
|
||||||
) -> ImportModelResponse:
|
|
||||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
|
||||||
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
|
||||||
info.model_name,
|
|
||||||
info.base_model,
|
|
||||||
info.model_type,
|
|
||||||
model_attributes=info.model_dump(),
|
|
||||||
)
|
|
||||||
logger.info(f"Successfully added {info.model_name}")
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name=info.model_name,
|
|
||||||
base_model=info.base_model,
|
|
||||||
model_type=info.model_type,
|
|
||||||
)
|
|
||||||
return ImportModelResponseValidator.validate_python(model_raw)
|
|
||||||
except ModelNotFoundException as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.delete(
|
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
|
||||||
operation_id="del_model",
|
|
||||||
responses={
|
|
||||||
204: {"description": "Model deleted successfully"},
|
|
||||||
404: {"description": "Model not found"},
|
|
||||||
},
|
|
||||||
status_code=204,
|
|
||||||
response_model=None,
|
|
||||||
)
|
|
||||||
async def delete_model(
|
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
|
||||||
model_name: str = Path(description="model name"),
|
|
||||||
) -> Response:
|
|
||||||
"""Delete Model"""
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
ApiDependencies.invoker.services.model_manager.del_model(
|
|
||||||
model_name, base_model=base_model, model_type=model_type
|
|
||||||
)
|
|
||||||
logger.info(f"Deleted model: {model_name}")
|
|
||||||
return Response(status_code=204)
|
|
||||||
except ModelNotFoundException as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
|
||||||
"/convert/{base_model}/{model_type}/{model_name}",
|
|
||||||
operation_id="convert_model",
|
|
||||||
responses={
|
|
||||||
200: {"description": "Model converted successfully"},
|
|
||||||
400: {"description": "Bad request"},
|
|
||||||
404: {"description": "Model not found"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
response_model=ConvertModelResponse,
|
|
||||||
)
|
|
||||||
async def convert_model(
|
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
|
||||||
model_name: str = Path(description="model name"),
|
|
||||||
convert_dest_directory: Optional[str] = Query(
|
|
||||||
default=None, description="Save the converted model to the designated directory"
|
|
||||||
),
|
|
||||||
) -> ConvertModelResponse:
|
|
||||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
try:
|
|
||||||
logger.info(f"Converting model: {model_name}")
|
|
||||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
|
||||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
|
||||||
model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
convert_dest_directory=dest,
|
|
||||||
)
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
model_name, base_model=base_model, model_type=model_type
|
|
||||||
)
|
|
||||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
|
||||||
except ModelNotFoundException as e:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
|
||||||
"/search",
|
|
||||||
operation_id="search_for_models",
|
|
||||||
responses={
|
|
||||||
200: {"description": "Directory searched successfully"},
|
|
||||||
404: {"description": "Invalid directory path"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
response_model=List[pathlib.Path],
|
|
||||||
)
|
|
||||||
async def search_for_models(
|
|
||||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
|
||||||
) -> List[pathlib.Path]:
|
|
||||||
if not search_path.is_dir():
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"The search path '{search_path}' does not exist or is not directory",
|
|
||||||
)
|
|
||||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
|
||||||
"/ckpt_confs",
|
|
||||||
operation_id="list_ckpt_configs",
|
|
||||||
responses={
|
|
||||||
200: {"description": "paths retrieved successfully"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
response_model=List[pathlib.Path],
|
|
||||||
)
|
|
||||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
|
||||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
|
||||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
|
||||||
"/sync",
|
|
||||||
operation_id="sync_to_config",
|
|
||||||
responses={
|
|
||||||
201: {"description": "synchronization successful"},
|
|
||||||
},
|
|
||||||
status_code=201,
|
|
||||||
response_model=bool,
|
|
||||||
)
|
|
||||||
async def sync_to_config() -> bool:
|
|
||||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
|
||||||
in-memory data structures with disk data structures."""
|
|
||||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
|
|
||||||
# TODO: After a few updates, see if it works inside the route operation handler?
|
|
||||||
class MergeModelsBody(BaseModel):
|
|
||||||
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
|
|
||||||
merged_model_name: Optional[str] = Field(description="Name of destination model")
|
|
||||||
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
|
|
||||||
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
|
|
||||||
force: Optional[bool] = Field(
|
|
||||||
description="Force merging of models created with different versions of diffusers",
|
|
||||||
default=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
merge_dest_directory: Optional[str] = Field(
|
|
||||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
|
||||||
"/merge/{base_model}",
|
|
||||||
operation_id="merge_models",
|
|
||||||
responses={
|
|
||||||
200: {"description": "Model converted successfully"},
|
|
||||||
400: {"description": "Incompatible models"},
|
|
||||||
404: {"description": "One or more models not found"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
response_model=MergeModelResponse,
|
|
||||||
)
|
|
||||||
async def merge_models(
|
|
||||||
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
|
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
|
||||||
) -> MergeModelResponse:
|
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
try:
|
|
||||||
logger.info(
|
|
||||||
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
|
|
||||||
)
|
|
||||||
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
|
|
||||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
|
||||||
model_names=body.model_names,
|
|
||||||
base_model=base_model,
|
|
||||||
merged_model_name=body.merged_model_name or "+".join(body.model_names),
|
|
||||||
alpha=body.alpha,
|
|
||||||
interp=body.interp,
|
|
||||||
force=body.force,
|
|
||||||
merge_dest_directory=dest,
|
|
||||||
)
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
|
||||||
result.name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Main,
|
|
||||||
)
|
|
||||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
|
||||||
except ModelNotFoundException:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"One or more of the models '{body.model_names}' not found",
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
return response
|
|
@ -47,7 +47,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
boards,
|
boards,
|
||||||
download_queue,
|
download_queue,
|
||||||
images,
|
images,
|
||||||
model_manager_v2,
|
model_manager,
|
||||||
session_queue,
|
session_queue,
|
||||||
sessions,
|
sessions,
|
||||||
utilities,
|
utilities,
|
||||||
@ -114,7 +114,7 @@ async def shutdown_event() -> None:
|
|||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api")
|
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
app.include_router(boards.boards_router, prefix="/api")
|
app.include_router(boards.boards_router, prefix="/api")
|
||||||
@ -176,21 +176,23 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
invoker_schema["class"] = "invocation"
|
invoker_schema["class"] = "invocation"
|
||||||
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||||
|
|
||||||
from invokeai.backend.model_management.models import get_model_config_enums
|
# This code no longer seems to be necessary?
|
||||||
|
# Leave it here just in case
|
||||||
|
#
|
||||||
|
# from invokeai.backend.model_manager import get_model_config_formats
|
||||||
|
# formats = get_model_config_formats()
|
||||||
|
# for model_config_name, enum_set in formats.items():
|
||||||
|
|
||||||
for model_config_format_enum in set(get_model_config_enums()):
|
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||||
name = model_config_format_enum.__qualname__
|
# # print(f"Config with name {name} already defined")
|
||||||
|
# continue
|
||||||
|
|
||||||
if name in openapi_schema["components"]["schemas"]:
|
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||||
# print(f"Config with name {name} already defined")
|
# "title": model_config_name,
|
||||||
continue
|
# "description": "An enumeration.",
|
||||||
|
# "type": "string",
|
||||||
openapi_schema["components"]["schemas"][name] = {
|
# "enum": [v.value for v in enum_set],
|
||||||
"title": name,
|
# }
|
||||||
"description": "An enumeration.",
|
|
||||||
"type": "string",
|
|
||||||
"enum": [v.value for v in model_config_format_enum],
|
|
||||||
}
|
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
@ -12,14 +12,14 @@ from invokeai.app.services.model_records import UnknownModelException
|
|||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.shared.fields import FieldDescriptions
|
||||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
|
||||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
|
||||||
from invokeai.backend.model_manager import ModelType
|
from invokeai.backend.model_manager import ModelType
|
||||||
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
ExtraConditioningInfo,
|
ExtraConditioningInfo,
|
||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||||
from invokeai.backend.util.devices import torch_dtype
|
from invokeai.backend.util.devices import torch_dtype
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
@ -71,18 +71,18 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.load.load_model_by_key(
|
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.clip.tokenizer.model_dump(),
|
**self.clip.tokenizer.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.load.load_model_by_key(
|
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.clip.text_encoder.model_dump(),
|
**self.clip.text_encoder.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.services.model_manager.load.load_model_by_key(
|
lora_info = context.services.model_manager.load_model_by_key(
|
||||||
**lora.model_dump(exclude={"weight"}), context=context
|
**lora.model_dump(exclude={"weight"}), context=context
|
||||||
)
|
)
|
||||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||||
@ -96,7 +96,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
loaded_model = context.services.model_manager.load.load_model_by_key(
|
loaded_model = context.services.model_manager.load_model_by_key(
|
||||||
**self.clip.text_encoder.model_dump(),
|
**self.clip.text_encoder.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
).model
|
).model
|
||||||
@ -172,11 +172,11 @@ class SDXLPromptInvocationBase:
|
|||||||
lora_prefix: str,
|
lora_prefix: str,
|
||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||||
tokenizer_info = context.services.model_manager.load.load_model_by_key(
|
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||||
**clip_field.tokenizer.model_dump(),
|
**clip_field.tokenizer.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.load.load_model_by_key(
|
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||||
**clip_field.text_encoder.model_dump(),
|
**clip_field.text_encoder.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -204,7 +204,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.load.load_model_by_key(
|
lora_info = context.services.model_manager.load_model_by_key(
|
||||||
**lora.model_dump(exclude={"weight"}), context=context
|
**lora.model_dump(exclude={"weight"}), context=context
|
||||||
)
|
)
|
||||||
lora_model = lora_info.model
|
lora_model = lora_info.model
|
||||||
@ -219,7 +219,7 @@ class SDXLPromptInvocationBase:
|
|||||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_model = context.services.model_manager.load.load_model_by_attr(
|
ti_model = context.services.model_manager.load_model_by_attr(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=text_encoder_info.config.base,
|
base_model=text_encoder_info.config.base,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
|
@ -42,10 +42,10 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
|||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.shared.fields import FieldDescriptions
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||||
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
@ -162,7 +162,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if image_tensor is not None:
|
if image_tensor is not None:
|
||||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
vae_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.vae.vae.model_dump(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -195,7 +195,7 @@ def get_scheduler(
|
|||||||
seed: int,
|
seed: int,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
|
orig_scheduler_info = context.services.model_manager.load_model_by_key(
|
||||||
**scheduler_info.model_dump(),
|
**scheduler_info.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -429,7 +429,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
controlnet_data = []
|
controlnet_data = []
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
control_model = exit_stack.enter_context(
|
control_model = exit_stack.enter_context(
|
||||||
context.services.model_manager.load.load_model_by_key(
|
context.services.model_manager.load_model_by_key(
|
||||||
key=control_info.control_model.key,
|
key=control_info.control_model.key,
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -495,13 +495,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data.ip_adapter_conditioning = []
|
conditioning_data.ip_adapter_conditioning = []
|
||||||
for single_ip_adapter in ip_adapter:
|
for single_ip_adapter in ip_adapter:
|
||||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||||
context.services.model_manager.load.load_model_by_key(
|
context.services.model_manager.load_model_by_key(
|
||||||
key=single_ip_adapter.ip_adapter_model.key,
|
key=single_ip_adapter.ip_adapter_model.key,
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
|
image_encoder_model_info = context.services.model_manager.load_model_by_key(
|
||||||
key=single_ip_adapter.image_encoder_model.key,
|
key=single_ip_adapter.image_encoder_model.key,
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -557,7 +557,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
t2i_adapter_data = []
|
t2i_adapter_data = []
|
||||||
for t2i_adapter_field in t2i_adapter:
|
for t2i_adapter_field in t2i_adapter:
|
||||||
t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
|
t2i_adapter_model_info = context.services.model_manager.load_model_by_key(
|
||||||
key=t2i_adapter_field.t2i_adapter_model.key,
|
key=t2i_adapter_field.t2i_adapter_model.key,
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -717,7 +717,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.load.load_model_by_key(
|
lora_info = context.services.model_manager.load_model_by_key(
|
||||||
**lora.model_dump(exclude={"weight"}),
|
**lora.model_dump(exclude={"weight"}),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -725,7 +725,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.load.load_model_by_key(
|
unet_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.unet.unet.model_dump(),
|
**self.unet.unet.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -836,7 +836,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
vae_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.vae.vae.model_dump(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -1079,7 +1079,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
vae_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.vae.vae.model_dump(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
@ -1290,7 +1290,7 @@ class IdealSizeInvocation(BaseInvocation):
|
|||||||
return tuple((x - x % multiple_of) for x in args)
|
return tuple((x - x % multiple_of) for x in args)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||||
unet_config = context.services.model_manager.load.load_model_by_key(
|
unet_config = context.services.model_manager.load_model_by_key(
|
||||||
**self.unet.unet.model_dump(),
|
**self.unet.unet.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
@ -15,8 +15,8 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
|
|||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.shared.fields import FieldDescriptions
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_patcher import ONNXModelPatcher
|
|
||||||
from invokeai.backend.model_manager import ModelType, SubModelType
|
from invokeai.backend.model_manager import ModelType, SubModelType
|
||||||
|
from invokeai.backend.model_patcher import ONNXModelPatcher
|
||||||
|
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.util import choose_torch_device
|
from ...backend.util import choose_torch_device
|
||||||
@ -62,16 +62,16 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.load.load_model_by_key(
|
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.clip.tokenizer.model_dump(),
|
**self.clip.tokenizer.model_dump(),
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.load.load_model_by_key(
|
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.clip.text_encoder.model_dump(),
|
**self.clip.text_encoder.model_dump(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||||
loras = [
|
loras = [
|
||||||
(
|
(
|
||||||
context.services.model_manager.load.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
||||||
lora.weight,
|
lora.weight,
|
||||||
)
|
)
|
||||||
for lora in self.clip.loras
|
for lora in self.clip.loras
|
||||||
@ -84,7 +84,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
ti_list.append(
|
ti_list.append(
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
context.services.model_manager.load.load_model_by_attr(
|
context.services.model_manager.load_model_by_attr(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=text_encoder_info.config.base,
|
base_model=text_encoder_info.config.base,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
@ -257,13 +257,13 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
eta=0.0,
|
eta=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
unet_info = context.services.model_manager.load.load_model_by_key(**self.unet.unet.model_dump())
|
unet_info = context.services.model_manager.load_model_by_key(**self.unet.unet.model_dump())
|
||||||
|
|
||||||
with unet_info as unet: # , ExitStack() as stack:
|
with unet_info as unet: # , ExitStack() as stack:
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
loras = [
|
loras = [
|
||||||
(
|
(
|
||||||
context.services.model_manager.load.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
||||||
lora.weight,
|
lora.weight,
|
||||||
)
|
)
|
||||||
for lora in self.unet.loras
|
for lora in self.unet.loras
|
||||||
@ -346,7 +346,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")
|
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")
|
||||||
|
|
||||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
vae_info = context.services.model_manager.load_model_by_key(
|
||||||
**self.vae.vae.model_dump(),
|
**self.vae.vae.model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
return OmegaConf.to_yaml(conf)
|
return OmegaConf.to_yaml(conf)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_parser_arguments(cls, parser) -> None:
|
def add_parser_arguments(cls, parser: ArgumentParser) -> None:
|
||||||
"""Dynamically create arguments for a settings parser."""
|
"""Dynamically create arguments for a settings parser."""
|
||||||
if "type" in get_type_hints(cls):
|
if "type" in get_type_hints(cls):
|
||||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||||
|
@ -29,8 +29,8 @@ writes to the system log is stored in InvocationServices.performance_statistics.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import AbstractContextManager
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
|
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
|
||||||
@ -40,18 +40,17 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
"Abstract base class for recording node memory/time performance statistics"
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the InvocationStatsService and reset counters to zero
|
Initialize the InvocationStatsService and reset counters to zero
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def collect_stats(
|
def collect_stats(
|
||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
) -> AbstractContextManager:
|
) -> Iterator[None]:
|
||||||
"""
|
"""
|
||||||
Return a context object that will capture the statistics on the execution
|
Return a context object that will capture the statistics on the execution
|
||||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||||
@ -61,7 +60,7 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset_stats(self, graph_execution_state_id: str):
|
def reset_stats(self, graph_execution_state_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Reset all statistics for the indicated graph.
|
Reset all statistics for the indicated graph.
|
||||||
:param graph_execution_state_id: The id of the session whose stats to reset.
|
:param graph_execution_state_id: The id of the session whose stats to reset.
|
||||||
@ -70,7 +69,7 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log_stats(self, graph_execution_state_id: str):
|
def log_stats(self, graph_execution_state_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Write out the accumulated statistics to the log or somewhere else.
|
Write out the accumulated statistics to the log or somewhere else.
|
||||||
:param graph_execution_state_id: The id of the session whose stats to log.
|
:param graph_execution_state_id: The id of the session whose stats to log.
|
||||||
|
@ -14,7 +14,7 @@ from typing_extensions import Annotated
|
|||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||||
|
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||||
from invokeai.backend.model_manager.load import LoadedModel
|
from invokeai.backend.model_manager.load import LoadedModel
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
@ -15,23 +15,7 @@ class ModelLoadServiceBase(ABC):
|
|||||||
"""Wrapper around AnyModelLoader."""
|
"""Wrapper around AnyModelLoader."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_model_by_key(
|
def load_model(
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Given a model's key, load it and return the LoadedModel object.
|
|
||||||
|
|
||||||
:param key: Key of model config to be fetched.
|
|
||||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
|
||||||
:param context: Invocation context used for event reporting
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_model_by_config(
|
|
||||||
self,
|
self,
|
||||||
model_config: AnyModelConfig,
|
model_config: AnyModelConfig,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
@ -44,34 +28,6 @@ class ModelLoadServiceBase(ABC):
|
|||||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||||
:param context: Invocation context used for event reporting
|
:param context: Invocation context used for event reporting
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_model_by_attr(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
|
||||||
|
|
||||||
This is provided for API compatability with the get_model() method
|
|
||||||
in the original model manager. However, note that LoadedModel is
|
|
||||||
not the same as the original ModelInfo that ws returned.
|
|
||||||
|
|
||||||
:param model_name: Name of to be fetched.
|
|
||||||
:param base_model: Base model
|
|
||||||
:param model_type: Type of the model
|
|
||||||
:param submodel: For main (pipeline models), the submodel to fetch
|
|
||||||
:param context: The invocation context.
|
|
||||||
|
|
||||||
Exceptions: UnknownModelException -- model with these attributes not known
|
|
||||||
NotImplementedException -- a model loader was not provided at initialization time
|
|
||||||
ValueError -- more than one model matches this combination
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of model loader service."""
|
"""Implementation of model loader service."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_manager.load import LoadedModel, ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||||
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
|
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -17,86 +16,35 @@ from .model_load_base import ModelLoadServiceBase
|
|||||||
|
|
||||||
|
|
||||||
class ModelLoadService(ModelLoadServiceBase):
|
class ModelLoadService(ModelLoadServiceBase):
|
||||||
"""Wrapper around AnyModelLoader."""
|
"""Wrapper around ModelLoaderRegistry."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
record_store: ModelRecordServiceBase,
|
|
||||||
ram_cache: ModelCacheBase[AnyModel],
|
ram_cache: ModelCacheBase[AnyModel],
|
||||||
convert_cache: ModelConvertCacheBase,
|
convert_cache: ModelConvertCacheBase,
|
||||||
|
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||||
):
|
):
|
||||||
"""Initialize the model load service."""
|
"""Initialize the model load service."""
|
||||||
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
||||||
logger.setLevel(app_config.log_level.upper())
|
logger.setLevel(app_config.log_level.upper())
|
||||||
self._store = record_store
|
self._logger = logger
|
||||||
self._any_loader = AnyModelLoader(
|
self._app_config = app_config
|
||||||
app_config=app_config,
|
self._ram_cache = ram_cache
|
||||||
logger=logger,
|
self._convert_cache = convert_cache
|
||||||
ram_cache=ram_cache,
|
self._registry = registry
|
||||||
convert_cache=convert_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||||
"""Return the RAM cache used by this loader."""
|
"""Return the RAM cache used by this loader."""
|
||||||
return self._any_loader.ram_cache
|
return self._ram_cache
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def convert_cache(self) -> ModelConvertCacheBase:
|
def convert_cache(self) -> ModelConvertCacheBase:
|
||||||
"""Return the checkpoint convert cache used by this loader."""
|
"""Return the checkpoint convert cache used by this loader."""
|
||||||
return self._any_loader.convert_cache
|
return self._convert_cache
|
||||||
|
|
||||||
def load_model_by_key(
|
def load_model(
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Given a model's key, load it and return the LoadedModel object.
|
|
||||||
|
|
||||||
:param key: Key of model config to be fetched.
|
|
||||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
|
||||||
:param context: Invocation context used for event reporting
|
|
||||||
"""
|
|
||||||
config = self._store.get_model(key)
|
|
||||||
return self.load_model_by_config(config, submodel_type, context)
|
|
||||||
|
|
||||||
def load_model_by_attr(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
|
||||||
|
|
||||||
This is provided for API compatability with the get_model() method
|
|
||||||
in the original model manager. However, note that LoadedModel is
|
|
||||||
not the same as the original ModelInfo that ws returned.
|
|
||||||
|
|
||||||
:param model_name: Name of to be fetched.
|
|
||||||
:param base_model: Base model
|
|
||||||
:param model_type: Type of the model
|
|
||||||
:param submodel: For main (pipeline models), the submodel to fetch
|
|
||||||
:param context: The invocation context.
|
|
||||||
|
|
||||||
Exceptions: UnknownModelException -- model with this key not known
|
|
||||||
NotImplementedException -- a model loader was not provided at initialization time
|
|
||||||
ValueError -- more than one model matches this combination
|
|
||||||
"""
|
|
||||||
configs = self._store.search_by_attr(model_name, base_model, model_type)
|
|
||||||
if len(configs) == 0:
|
|
||||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
|
||||||
elif len(configs) > 1:
|
|
||||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
|
||||||
else:
|
|
||||||
return self.load_model_by_key(configs[0].key, submodel)
|
|
||||||
|
|
||||||
def load_model_by_config(
|
|
||||||
self,
|
self,
|
||||||
model_config: AnyModelConfig,
|
model_config: AnyModelConfig,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
@ -114,7 +62,15 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
context=context,
|
context=context,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
loaded_model = self._any_loader.load_model(model_config, submodel_type)
|
|
||||||
|
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||||
|
loaded_model: LoadedModel = implementation(
|
||||||
|
app_config=self._app_config,
|
||||||
|
logger=self._logger,
|
||||||
|
ram_cache=self._ram_cache,
|
||||||
|
convert_cache=self._convert_cache,
|
||||||
|
).load_model(model_config, submodel_type)
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
context=context,
|
context=context,
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load import LoadedModel
|
from invokeai.backend.model_manager.load import LoadedModel
|
||||||
|
|
||||||
from .model_manager_default import ModelManagerServiceBase, ModelManagerService
|
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModelManagerServiceBase",
|
"ModelManagerServiceBase",
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of ModelManagerServiceBase."""
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||||
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from ..config import InvokeAIAppConfig
|
from ..config import InvokeAIAppConfig
|
||||||
@ -12,7 +16,7 @@ from ..download import DownloadQueueServiceBase
|
|||||||
from ..events.events_base import EventServiceBase
|
from ..events.events_base import EventServiceBase
|
||||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||||
from ..model_records import ModelRecordServiceBase
|
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||||
from .model_manager_base import ModelManagerServiceBase
|
from .model_manager_base import ModelManagerServiceBase
|
||||||
|
|
||||||
|
|
||||||
@ -58,6 +62,56 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
if hasattr(service, "stop"):
|
if hasattr(service, "stop"):
|
||||||
service.stop(invoker)
|
service.stop(invoker)
|
||||||
|
|
||||||
|
def load_model_by_config(
|
||||||
|
self,
|
||||||
|
model_config: AnyModelConfig,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
context: Optional[InvocationContext] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
return self.load.load_model(model_config, submodel_type, context)
|
||||||
|
|
||||||
|
def load_model_by_key(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
context: Optional[InvocationContext] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
config = self.store.get_model(key)
|
||||||
|
return self.load.load_model(config, submodel_type, context)
|
||||||
|
|
||||||
|
def load_model_by_attr(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
submodel: Optional[SubModelType] = None,
|
||||||
|
context: Optional[InvocationContext] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
"""
|
||||||
|
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||||
|
|
||||||
|
This is provided for API compatability with the get_model() method
|
||||||
|
in the original model manager. However, note that LoadedModel is
|
||||||
|
not the same as the original ModelInfo that ws returned.
|
||||||
|
|
||||||
|
:param model_name: Name of to be fetched.
|
||||||
|
:param base_model: Base model
|
||||||
|
:param model_type: Type of the model
|
||||||
|
:param submodel: For main (pipeline models), the submodel to fetch
|
||||||
|
:param context: The invocation context.
|
||||||
|
|
||||||
|
Exceptions: UnknownModelException -- model with this key not known
|
||||||
|
NotImplementedException -- a model loader was not provided at initialization time
|
||||||
|
ValueError -- more than one model matches this combination
|
||||||
|
"""
|
||||||
|
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||||
|
if len(configs) == 0:
|
||||||
|
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||||
|
elif len(configs) > 1:
|
||||||
|
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||||
|
else:
|
||||||
|
return self.load.load_model(configs[0], submodel, context)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_model_manager(
|
def build_model_manager(
|
||||||
cls,
|
cls,
|
||||||
@ -82,9 +136,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
)
|
)
|
||||||
loader = ModelLoadService(
|
loader = ModelLoadService(
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
record_store=model_record_service,
|
|
||||||
ram_cache=ram_cache,
|
ram_cache=ram_cache,
|
||||||
convert_cache=convert_cache,
|
convert_cache=convert_cache,
|
||||||
|
registry=ModelLoaderRegistry,
|
||||||
)
|
)
|
||||||
installer = ModelInstallService(
|
installer = ModelInstallService(
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
|
@ -1,591 +0,0 @@
|
|||||||
"""
|
|
||||||
Migrate the models directory and models.yaml file from an existing
|
|
||||||
InvokeAI 2.3 installation to 3.0.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import diffusers
|
|
||||||
import transformers
|
|
||||||
import yaml
|
|
||||||
from diffusers import AutoencoderKL, StableDiffusionPipeline
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.backend.model_management import ModelManager
|
|
||||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
diffusers.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
|
|
||||||
# holder for paths that we will migrate
|
|
||||||
@dataclass
|
|
||||||
class ModelPaths:
|
|
||||||
models: Path
|
|
||||||
embeddings: Path
|
|
||||||
loras: Path
|
|
||||||
controlnets: Path
|
|
||||||
|
|
||||||
|
|
||||||
class MigrateTo3(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
from_root: Path,
|
|
||||||
to_models: Path,
|
|
||||||
model_manager: ModelManager,
|
|
||||||
src_paths: ModelPaths,
|
|
||||||
):
|
|
||||||
self.root_directory = from_root
|
|
||||||
self.dest_models = to_models
|
|
||||||
self.mgr = model_manager
|
|
||||||
self.src_paths = src_paths
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def initialize_yaml(cls, yaml_file: Path):
|
|
||||||
with open(yaml_file, "w") as file:
|
|
||||||
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
|
||||||
|
|
||||||
def create_directory_structure(self):
|
|
||||||
"""
|
|
||||||
Create the basic directory structure for the models folder.
|
|
||||||
"""
|
|
||||||
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
|
||||||
for model_type in [
|
|
||||||
ModelType.Main,
|
|
||||||
ModelType.Vae,
|
|
||||||
ModelType.Lora,
|
|
||||||
ModelType.ControlNet,
|
|
||||||
ModelType.TextualInversion,
|
|
||||||
]:
|
|
||||||
path = self.dest_models / model_base.value / model_type.value
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
path = self.dest_models / "core"
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def copy_file(src: Path, dest: Path):
|
|
||||||
"""
|
|
||||||
copy a single file with logging
|
|
||||||
"""
|
|
||||||
if dest.exists():
|
|
||||||
logger.info(f"Skipping existing {str(dest)}")
|
|
||||||
return
|
|
||||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
|
||||||
try:
|
|
||||||
shutil.copy(src, dest)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"COPY FAILED: {str(e)}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def copy_dir(src: Path, dest: Path):
|
|
||||||
"""
|
|
||||||
Recursively copy a directory with logging
|
|
||||||
"""
|
|
||||||
if dest.exists():
|
|
||||||
logger.info(f"Skipping existing {str(dest)}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
|
||||||
try:
|
|
||||||
shutil.copytree(src, dest)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"COPY FAILED: {str(e)}")
|
|
||||||
|
|
||||||
def migrate_models(self, src_dir: Path):
|
|
||||||
"""
|
|
||||||
Recursively walk through src directory, probe anything
|
|
||||||
that looks like a model, and copy the model into the
|
|
||||||
appropriate location within the destination models directory.
|
|
||||||
"""
|
|
||||||
directories_scanned = set()
|
|
||||||
for root, dirs, files in os.walk(src_dir, followlinks=True):
|
|
||||||
for d in dirs:
|
|
||||||
try:
|
|
||||||
model = Path(root, d)
|
|
||||||
info = ModelProbe().heuristic_probe(model)
|
|
||||||
if not info:
|
|
||||||
continue
|
|
||||||
dest = self._model_probe_to_path(info) / model.name
|
|
||||||
self.copy_dir(model, dest)
|
|
||||||
directories_scanned.add(model)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
for f in files:
|
|
||||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
|
||||||
# let them be copied as part of a tree copy operation
|
|
||||||
try:
|
|
||||||
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
|
||||||
continue
|
|
||||||
model = Path(root, f)
|
|
||||||
if model.parent in directories_scanned:
|
|
||||||
continue
|
|
||||||
info = ModelProbe().heuristic_probe(model)
|
|
||||||
if not info:
|
|
||||||
continue
|
|
||||||
dest = self._model_probe_to_path(info) / f
|
|
||||||
self.copy_file(model, dest)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def migrate_support_models(self):
|
|
||||||
"""
|
|
||||||
Copy the clipseg, upscaler, and restoration models to their new
|
|
||||||
locations.
|
|
||||||
"""
|
|
||||||
dest_directory = self.dest_models
|
|
||||||
if (self.root_directory / "models/clipseg").exists():
|
|
||||||
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
|
|
||||||
if (self.root_directory / "models/realesrgan").exists():
|
|
||||||
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
|
|
||||||
for d in ["codeformer", "gfpgan"]:
|
|
||||||
path = self.root_directory / "models" / d
|
|
||||||
if path.exists():
|
|
||||||
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
|
|
||||||
|
|
||||||
def migrate_tuning_models(self):
|
|
||||||
"""
|
|
||||||
Migrate the embeddings, loras and controlnets directories to their new homes.
|
|
||||||
"""
|
|
||||||
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
|
||||||
if not src:
|
|
||||||
continue
|
|
||||||
if src.is_dir():
|
|
||||||
logger.info(f"Scanning {src}")
|
|
||||||
self.migrate_models(src)
|
|
||||||
else:
|
|
||||||
logger.info(f"{src} directory not found; skipping")
|
|
||||||
continue
|
|
||||||
|
|
||||||
def migrate_conversion_models(self):
|
|
||||||
"""
|
|
||||||
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
|
||||||
script.
|
|
||||||
"""
|
|
||||||
|
|
||||||
dest_directory = self.dest_models
|
|
||||||
kwargs = {
|
|
||||||
"cache_dir": self.root_directory / "models/hub",
|
|
||||||
# local_files_only = True
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
logger.info("Migrating core tokenizers and text encoders")
|
|
||||||
target_dir = dest_directory / "core" / "convert"
|
|
||||||
|
|
||||||
self._migrate_pretrained(
|
|
||||||
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# sd-1
|
|
||||||
repo_id = "openai/clip-vit-large-patch14"
|
|
||||||
self._migrate_pretrained(
|
|
||||||
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
|
|
||||||
)
|
|
||||||
self._migrate_pretrained(
|
|
||||||
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# sd-2
|
|
||||||
repo_id = "stabilityai/stable-diffusion-2"
|
|
||||||
self._migrate_pretrained(
|
|
||||||
CLIPTokenizer,
|
|
||||||
repo_id=repo_id,
|
|
||||||
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
|
|
||||||
**{"subfolder": "tokenizer", **kwargs},
|
|
||||||
)
|
|
||||||
self._migrate_pretrained(
|
|
||||||
CLIPTextModel,
|
|
||||||
repo_id=repo_id,
|
|
||||||
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
|
|
||||||
**{"subfolder": "text_encoder", **kwargs},
|
|
||||||
)
|
|
||||||
|
|
||||||
# VAE
|
|
||||||
logger.info("Migrating stable diffusion VAE")
|
|
||||||
self._migrate_pretrained(
|
|
||||||
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# safety checking
|
|
||||||
logger.info("Migrating safety checker")
|
|
||||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
|
||||||
self._migrate_pretrained(
|
|
||||||
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
|
|
||||||
)
|
|
||||||
self._migrate_pretrained(
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
repo_id=repo_id,
|
|
||||||
dest=target_dir / "stable-diffusion-safety-checker",
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
|
|
||||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
|
||||||
|
|
||||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
|
|
||||||
if dest.exists() and not force:
|
|
||||||
logger.info(f"Skipping existing {dest}")
|
|
||||||
return
|
|
||||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
|
||||||
self._save_pretrained(model, dest, overwrite=force)
|
|
||||||
|
|
||||||
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
|
|
||||||
model_name = dest.name
|
|
||||||
if overwrite:
|
|
||||||
model.save_pretrained(dest, safe_serialization=True)
|
|
||||||
else:
|
|
||||||
download_path = dest.with_name(f"{model_name}.downloading")
|
|
||||||
model.save_pretrained(download_path, safe_serialization=True)
|
|
||||||
download_path.replace(dest)
|
|
||||||
|
|
||||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
|
||||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
|
||||||
info = ModelProbe().heuristic_probe(vae)
|
|
||||||
_, model_name = repo_id.split("/")
|
|
||||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
|
||||||
vae.save_pretrained(dest, safe_serialization=True)
|
|
||||||
return dest
|
|
||||||
|
|
||||||
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
|
||||||
"""
|
|
||||||
Convert 2.3 VAE stanza to a straight path.
|
|
||||||
"""
|
|
||||||
vae_path = None
|
|
||||||
|
|
||||||
# First get a path
|
|
||||||
if isinstance(vae, str):
|
|
||||||
vae_path = vae
|
|
||||||
|
|
||||||
elif isinstance(vae, DictConfig):
|
|
||||||
if p := vae.get("path"):
|
|
||||||
vae_path = p
|
|
||||||
elif repo_id := vae.get("repo_id"):
|
|
||||||
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
|
||||||
vae_path = "models/core/convert/sd-vae-ft-mse"
|
|
||||||
return vae_path
|
|
||||||
else:
|
|
||||||
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
|
||||||
|
|
||||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
|
||||||
|
|
||||||
# if the VAE is in the old models directory, then we must move it into the new
|
|
||||||
# one. VAEs outside of this directory can stay where they are.
|
|
||||||
vae_path = Path(vae_path)
|
|
||||||
if vae_path.is_relative_to(self.src_paths.models):
|
|
||||||
info = ModelProbe().heuristic_probe(vae_path)
|
|
||||||
dest = self._model_probe_to_path(info) / vae_path.name
|
|
||||||
if not dest.exists():
|
|
||||||
if vae_path.is_dir():
|
|
||||||
self.copy_dir(vae_path, dest)
|
|
||||||
else:
|
|
||||||
self.copy_file(vae_path, dest)
|
|
||||||
vae_path = dest
|
|
||||||
|
|
||||||
if vae_path.is_relative_to(self.dest_models):
|
|
||||||
rel_path = vae_path.relative_to(self.dest_models)
|
|
||||||
return Path("models", rel_path)
|
|
||||||
else:
|
|
||||||
return vae_path
|
|
||||||
|
|
||||||
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
|
|
||||||
"""
|
|
||||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
|
||||||
"""
|
|
||||||
dest_dir = self.dest_models
|
|
||||||
|
|
||||||
cache = self.root_directory / "models/hub"
|
|
||||||
kwargs = {
|
|
||||||
"cache_dir": cache,
|
|
||||||
"safety_checker": None,
|
|
||||||
# local_files_only = True,
|
|
||||||
}
|
|
||||||
|
|
||||||
owner, repo_name = repo_id.split("/")
|
|
||||||
model_name = model_name or repo_name
|
|
||||||
model = cache / "--".join(["models", owner, repo_name])
|
|
||||||
|
|
||||||
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
|
|
||||||
return
|
|
||||||
revisions = [x.name for x in model.glob("refs/*")]
|
|
||||||
|
|
||||||
# if an fp16 is available we use that
|
|
||||||
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
|
|
||||||
|
|
||||||
info = ModelProbe().heuristic_probe(pipeline)
|
|
||||||
if not info:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
|
||||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
|
||||||
return
|
|
||||||
|
|
||||||
dest = self._model_probe_to_path(info) / model_name
|
|
||||||
self._save_pretrained(pipeline, dest)
|
|
||||||
|
|
||||||
rel_path = Path("models", dest.relative_to(dest_dir))
|
|
||||||
self._add_model(model_name, info, rel_path, **extra_config)
|
|
||||||
|
|
||||||
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
|
|
||||||
"""
|
|
||||||
Migrate a model referred to using 'weights' or 'path'
|
|
||||||
"""
|
|
||||||
|
|
||||||
# handle relative paths
|
|
||||||
dest_dir = self.dest_models
|
|
||||||
location = self.root_directory / location
|
|
||||||
model_name = model_name or location.stem
|
|
||||||
|
|
||||||
info = ModelProbe().heuristic_probe(location)
|
|
||||||
if not info:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
|
||||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# uh oh, weights is in the old models directory - move it into the new one
|
|
||||||
if Path(location).is_relative_to(self.src_paths.models):
|
|
||||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
|
||||||
if location.is_dir():
|
|
||||||
self.copy_dir(location, dest)
|
|
||||||
else:
|
|
||||||
self.copy_file(location, dest)
|
|
||||||
location = Path("models", info.base_type.value, info.model_type.value, location.name)
|
|
||||||
|
|
||||||
self._add_model(model_name, info, location, **extra_config)
|
|
||||||
|
|
||||||
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
|
|
||||||
if info.model_type != ModelType.Main:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.mgr.add_model(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=info.base_type,
|
|
||||||
model_type=info.model_type,
|
|
||||||
clobber=True,
|
|
||||||
model_attributes={
|
|
||||||
"path": str(location),
|
|
||||||
"description": f"A {info.base_type.value} {info.model_type.value} model",
|
|
||||||
"model_format": info.format,
|
|
||||||
"variant": info.variant_type.value,
|
|
||||||
**extra_config,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def migrate_defined_models(self):
|
|
||||||
"""
|
|
||||||
Migrate models defined in models.yaml
|
|
||||||
"""
|
|
||||||
# find any models referred to in old models.yaml
|
|
||||||
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
|
|
||||||
|
|
||||||
for model_name, stanza in conf.items():
|
|
||||||
try:
|
|
||||||
passthru_args = {}
|
|
||||||
|
|
||||||
if vae := stanza.get("vae"):
|
|
||||||
try:
|
|
||||||
passthru_args["vae"] = str(self._vae_path(vae))
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
|
||||||
logger.warning(str(e))
|
|
||||||
|
|
||||||
if config := stanza.get("config"):
|
|
||||||
passthru_args["config"] = config
|
|
||||||
|
|
||||||
if description := stanza.get("description"):
|
|
||||||
passthru_args["description"] = description
|
|
||||||
|
|
||||||
if repo_id := stanza.get("repo_id"):
|
|
||||||
logger.info(f"Migrating diffusers model {model_name}")
|
|
||||||
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
|
||||||
|
|
||||||
elif location := stanza.get("weights"):
|
|
||||||
logger.info(f"Migrating checkpoint model {model_name}")
|
|
||||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
|
||||||
|
|
||||||
elif location := stanza.get("path"):
|
|
||||||
logger.info(f"Migrating diffusers model {model_name}")
|
|
||||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
def migrate(self):
|
|
||||||
self.create_directory_structure()
|
|
||||||
# the configure script is doing this
|
|
||||||
self.migrate_support_models()
|
|
||||||
self.migrate_conversion_models()
|
|
||||||
self.migrate_tuning_models()
|
|
||||||
self.migrate_defined_models()
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
|
|
||||||
"""
|
|
||||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
|
|
||||||
parser.add_argument(
|
|
||||||
"--embedding_directory",
|
|
||||||
"--embedding_path",
|
|
||||||
type=Path,
|
|
||||||
dest="embedding_path",
|
|
||||||
default=Path("embeddings"),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--lora_directory",
|
|
||||||
dest="lora_path",
|
|
||||||
type=Path,
|
|
||||||
default=Path("loras"),
|
|
||||||
)
|
|
||||||
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
|
|
||||||
return ModelPaths(
|
|
||||||
models=root / "models",
|
|
||||||
embeddings=root / str(opt.embedding_path).strip('"'),
|
|
||||||
loras=root / str(opt.lora_path).strip('"'),
|
|
||||||
controlnets=root / "controlnets",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
|
|
||||||
"""
|
|
||||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
|
||||||
"""
|
|
||||||
# Don't use the config object because it is unforgiving of version updates
|
|
||||||
# Just use omegaconf directly
|
|
||||||
opt = OmegaConf.load(initfile)
|
|
||||||
paths = opt.InvokeAI.Paths
|
|
||||||
models = paths.get("models_dir", "models")
|
|
||||||
embeddings = paths.get("embedding_dir", "embeddings")
|
|
||||||
loras = paths.get("lora_dir", "loras")
|
|
||||||
controlnets = paths.get("controlnet_dir", "controlnets")
|
|
||||||
return ModelPaths(
|
|
||||||
models=root / models if models else None,
|
|
||||||
embeddings=root / embeddings if embeddings else None,
|
|
||||||
loras=root / loras if loras else None,
|
|
||||||
controlnets=root / controlnets if controlnets else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
|
||||||
path = root / "invokeai.init"
|
|
||||||
if path.exists():
|
|
||||||
return _parse_legacy_initfile(root, path)
|
|
||||||
path = root / "invokeai.yaml"
|
|
||||||
if path.exists():
|
|
||||||
return _parse_legacy_yamlfile(root, path)
|
|
||||||
|
|
||||||
|
|
||||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
|
||||||
"""
|
|
||||||
Migrate models from src to dest InvokeAI root directories
|
|
||||||
"""
|
|
||||||
config_file = dest_directory / "configs" / "models.yaml.3"
|
|
||||||
dest_models = dest_directory / "models.3"
|
|
||||||
|
|
||||||
version_3 = (dest_directory / "models" / "core").exists()
|
|
||||||
|
|
||||||
# Here we create the destination models.yaml file.
|
|
||||||
# If we are writing into a version 3 directory and the
|
|
||||||
# file already exists, then we write into a copy of it to
|
|
||||||
# avoid deleting its previous customizations. Otherwise we
|
|
||||||
# create a new empty one.
|
|
||||||
if version_3: # write into the dest directory
|
|
||||||
try:
|
|
||||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
|
||||||
except Exception:
|
|
||||||
MigrateTo3.initialize_yaml(config_file)
|
|
||||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
|
||||||
(dest_directory / "models").replace(dest_models)
|
|
||||||
else:
|
|
||||||
MigrateTo3.initialize_yaml(config_file)
|
|
||||||
mgr = ModelManager(config_file)
|
|
||||||
|
|
||||||
paths = get_legacy_embeddings(src_directory)
|
|
||||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
|
||||||
migrator.migrate()
|
|
||||||
print("Migration successful.")
|
|
||||||
|
|
||||||
if not version_3:
|
|
||||||
(dest_directory / "models").replace(src_directory / "models.orig")
|
|
||||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
|
||||||
|
|
||||||
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
|
||||||
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
|
||||||
|
|
||||||
config_file.replace(config_file.with_suffix(""))
|
|
||||||
dest_models.replace(dest_models.with_suffix(""))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
prog="invokeai-migrate3",
|
|
||||||
description="""
|
|
||||||
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
|
||||||
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
|
|
||||||
|
|
||||||
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
|
||||||
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
|
||||||
script, which will perform a full upgrade in place.""",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--from-directory",
|
|
||||||
dest="src_root",
|
|
||||||
type=Path,
|
|
||||||
required=True,
|
|
||||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--to-directory",
|
|
||||||
dest="dest_root",
|
|
||||||
type=Path,
|
|
||||||
required=True,
|
|
||||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
src_root = args.src_root
|
|
||||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
|
||||||
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
|
||||||
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
|
|
||||||
assert (src_root / "invokeai.init").exists() or (
|
|
||||||
src_root / "invokeai.yaml"
|
|
||||||
).exists(), f"{src_root} does not contain an InvokeAI init file."
|
|
||||||
|
|
||||||
dest_root = args.dest_root
|
|
||||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
config.parse_args(["--root", str(dest_root)])
|
|
||||||
|
|
||||||
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
|
||||||
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
|
|
||||||
if not dest_is_setup:
|
|
||||||
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
|
||||||
|
|
||||||
initialize_rootdir(dest_root, True)
|
|
||||||
|
|
||||||
do_migrate(src_root, dest_root)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,637 +0,0 @@
|
|||||||
"""
|
|
||||||
Utility (backend) functions used by model_install.py
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from tempfile import TemporaryDirectory
|
|
||||||
from typing import Callable, Dict, List, Optional, Set, Union
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
from diffusers import logging as dlogging
|
|
||||||
from huggingface_hub import HfApi, HfFolder, hf_hub_url
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import invokeai.configs as configs
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
|
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
|
|
||||||
from invokeai.backend.util import download_with_resume
|
|
||||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
|
||||||
|
|
||||||
from ..util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
|
|
||||||
# --------------------------globals-----------------------
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
logger = InvokeAILogger.get_logger(name="InvokeAI")
|
|
||||||
|
|
||||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
|
||||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
|
||||||
|
|
||||||
Config_preamble = """
|
|
||||||
# This file describes the alternative machine learning models
|
|
||||||
# available to InvokeAI script.
|
|
||||||
#
|
|
||||||
# To add a new model, follow the examples below. Each
|
|
||||||
# model requires a model config file, a weights file,
|
|
||||||
# and the width and height of the images it
|
|
||||||
# was trained on.
|
|
||||||
"""
|
|
||||||
|
|
||||||
LEGACY_CONFIGS = {
|
|
||||||
BaseModelType.StableDiffusion1: {
|
|
||||||
ModelVariantType.Normal: {
|
|
||||||
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
|
|
||||||
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
|
|
||||||
},
|
|
||||||
ModelVariantType.Inpaint: {
|
|
||||||
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
|
|
||||||
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusion2: {
|
|
||||||
ModelVariantType.Normal: {
|
|
||||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
|
||||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
|
||||||
},
|
|
||||||
ModelVariantType.Inpaint: {
|
|
||||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
|
||||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusionXL: {
|
|
||||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
|
||||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InstallSelections:
|
|
||||||
install_models: List[str] = field(default_factory=list)
|
|
||||||
remove_models: List[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelLoadInfo:
|
|
||||||
name: str
|
|
||||||
model_type: ModelType
|
|
||||||
base_type: BaseModelType
|
|
||||||
path: Optional[Path] = None
|
|
||||||
repo_id: Optional[str] = None
|
|
||||||
subfolder: Optional[str] = None
|
|
||||||
description: str = ""
|
|
||||||
installed: bool = False
|
|
||||||
recommended: bool = False
|
|
||||||
default: bool = False
|
|
||||||
requires: Optional[List[str]] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInstall(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: InvokeAIAppConfig,
|
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
|
||||||
model_manager: Optional[ModelManager] = None,
|
|
||||||
access_token: Optional[str] = None,
|
|
||||||
civitai_api_key: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.config = config
|
|
||||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
|
||||||
self.datasets = OmegaConf.load(Dataset_path)
|
|
||||||
self.prediction_helper = prediction_type_helper
|
|
||||||
self.access_token = access_token or HfFolder.get_token()
|
|
||||||
self.civitai_api_key = civitai_api_key or config.civitai_api_key
|
|
||||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
|
||||||
|
|
||||||
def all_models(self) -> Dict[str, ModelLoadInfo]:
|
|
||||||
"""
|
|
||||||
Return dict of model_key=>ModelLoadInfo objects.
|
|
||||||
This method consolidates and simplifies the entries in both
|
|
||||||
models.yaml and INITIAL_MODELS.yaml so that they can
|
|
||||||
be treated uniformly. It also sorts the models alphabetically
|
|
||||||
by their name, to improve the display somewhat.
|
|
||||||
"""
|
|
||||||
model_dict = {}
|
|
||||||
|
|
||||||
# first populate with the entries in INITIAL_MODELS.yaml
|
|
||||||
for key, value in self.datasets.items():
|
|
||||||
name, base, model_type = ModelManager.parse_key(key)
|
|
||||||
value["name"] = name
|
|
||||||
value["base_type"] = base
|
|
||||||
value["model_type"] = model_type
|
|
||||||
model_info = ModelLoadInfo(**value)
|
|
||||||
if model_info.subfolder and model_info.repo_id:
|
|
||||||
model_info.repo_id += f":{model_info.subfolder}"
|
|
||||||
model_dict[key] = model_info
|
|
||||||
|
|
||||||
# supplement with entries in models.yaml
|
|
||||||
installed_models = list(self.mgr.list_models())
|
|
||||||
|
|
||||||
for md in installed_models:
|
|
||||||
base = md["base_model"]
|
|
||||||
model_type = md["model_type"]
|
|
||||||
name = md["model_name"]
|
|
||||||
key = ModelManager.create_key(name, base, model_type)
|
|
||||||
if key in model_dict:
|
|
||||||
model_dict[key].installed = True
|
|
||||||
else:
|
|
||||||
model_dict[key] = ModelLoadInfo(
|
|
||||||
name=name,
|
|
||||||
base_type=base,
|
|
||||||
model_type=model_type,
|
|
||||||
path=value.get("path"),
|
|
||||||
installed=True,
|
|
||||||
)
|
|
||||||
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
|
|
||||||
|
|
||||||
def _is_autoloaded(self, model_info: dict) -> bool:
|
|
||||||
path = model_info.get("path")
|
|
||||||
if not path:
|
|
||||||
return False
|
|
||||||
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
|
|
||||||
if autodir_path := getattr(self.config, autodir):
|
|
||||||
autodir_path = self.config.root_path / autodir_path
|
|
||||||
if Path(path).is_relative_to(autodir_path):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def list_models(self, model_type):
|
|
||||||
installed = self.mgr.list_models(model_type=model_type)
|
|
||||||
print()
|
|
||||||
print(f"Installed models of type `{model_type}`:")
|
|
||||||
print(f"{'Model Key':50} Model Path")
|
|
||||||
for i in installed:
|
|
||||||
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# logic here a little reversed to maintain backward compatibility
|
|
||||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
|
||||||
models = set()
|
|
||||||
for key, _value in self.datasets.items():
|
|
||||||
name, base, model_type = ModelManager.parse_key(key)
|
|
||||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
|
||||||
models.add(key)
|
|
||||||
return models
|
|
||||||
|
|
||||||
def recommended_models(self) -> Set[str]:
|
|
||||||
starters = self.starter_models(all_models=True)
|
|
||||||
return {x for x in starters if self.datasets[x].get("recommended", False)}
|
|
||||||
|
|
||||||
def default_model(self) -> str:
|
|
||||||
starters = self.starter_models()
|
|
||||||
defaults = [x for x in starters if self.datasets[x].get("default", False)]
|
|
||||||
return defaults[0]
|
|
||||||
|
|
||||||
def install(self, selections: InstallSelections):
|
|
||||||
verbosity = dlogging.get_verbosity() # quench NSFW nags
|
|
||||||
dlogging.set_verbosity_error()
|
|
||||||
|
|
||||||
job = 1
|
|
||||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
|
||||||
|
|
||||||
# remove requested models
|
|
||||||
for key in selections.remove_models:
|
|
||||||
name, base, mtype = self.mgr.parse_key(key)
|
|
||||||
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
|
|
||||||
try:
|
|
||||||
self.mgr.del_model(name, base, mtype)
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.warning(e)
|
|
||||||
job += 1
|
|
||||||
|
|
||||||
# add requested models
|
|
||||||
self._remove_installed(selections.install_models)
|
|
||||||
self._add_required_models(selections.install_models)
|
|
||||||
for path in selections.install_models:
|
|
||||||
logger.info(f"Installing {path} [{job}/{jobs}]")
|
|
||||||
try:
|
|
||||||
self.heuristic_import(path)
|
|
||||||
except (ValueError, KeyError) as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
job += 1
|
|
||||||
|
|
||||||
dlogging.set_verbosity(verbosity)
|
|
||||||
self.mgr.commit()
|
|
||||||
|
|
||||||
def heuristic_import(
|
|
||||||
self,
|
|
||||||
model_path_id_or_url: Union[str, Path],
|
|
||||||
models_installed: Set[Path] = None,
|
|
||||||
) -> Dict[str, AddModelResult]:
|
|
||||||
"""
|
|
||||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
|
||||||
:param models_installed: Set of installed models, used for recursive invocation
|
|
||||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not models_installed:
|
|
||||||
models_installed = {}
|
|
||||||
|
|
||||||
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
|
|
||||||
|
|
||||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
|
||||||
self.current_id = model_path_id_or_url
|
|
||||||
path = Path(model_path_id_or_url)
|
|
||||||
|
|
||||||
# fix relative paths
|
|
||||||
if path.exists() and not path.is_absolute():
|
|
||||||
path = path.absolute() # make relative to current WD
|
|
||||||
|
|
||||||
# checkpoint file, or similar
|
|
||||||
if path.is_file():
|
|
||||||
models_installed.update({str(path): self._install_path(path)})
|
|
||||||
|
|
||||||
# folders style or similar
|
|
||||||
elif path.is_dir() and any(
|
|
||||||
(path / x).exists()
|
|
||||||
for x in {
|
|
||||||
"config.json",
|
|
||||||
"model_index.json",
|
|
||||||
"learned_embeds.bin",
|
|
||||||
"pytorch_lora_weights.bin",
|
|
||||||
"pytorch_lora_weights.safetensors",
|
|
||||||
}
|
|
||||||
):
|
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
|
||||||
|
|
||||||
# recursive scan
|
|
||||||
elif path.is_dir():
|
|
||||||
for child in path.iterdir():
|
|
||||||
self.heuristic_import(child, models_installed=models_installed)
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
elif len(str(model_path_id_or_url).split("/")) == 2:
|
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
|
||||||
|
|
||||||
# a URL
|
|
||||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
|
|
||||||
|
|
||||||
return models_installed
|
|
||||||
|
|
||||||
def _remove_installed(self, model_list: List[str]):
|
|
||||||
all_models = self.all_models()
|
|
||||||
models_to_remove = []
|
|
||||||
|
|
||||||
for path in model_list:
|
|
||||||
key = self.reverse_paths.get(path)
|
|
||||||
if key and all_models[key].installed:
|
|
||||||
models_to_remove.append(path)
|
|
||||||
|
|
||||||
for path in models_to_remove:
|
|
||||||
logger.warning(f"{path} already installed. Skipping")
|
|
||||||
model_list.remove(path)
|
|
||||||
|
|
||||||
def _add_required_models(self, model_list: List[str]):
|
|
||||||
additional_models = []
|
|
||||||
all_models = self.all_models()
|
|
||||||
for path in model_list:
|
|
||||||
if not (key := self.reverse_paths.get(path)):
|
|
||||||
continue
|
|
||||||
for requirement in all_models[key].requires:
|
|
||||||
requirement_key = self.reverse_paths.get(requirement)
|
|
||||||
if not all_models[requirement_key].installed:
|
|
||||||
additional_models.append(requirement)
|
|
||||||
model_list.extend(additional_models)
|
|
||||||
|
|
||||||
# install a model from a local path. The optional info parameter is there to prevent
|
|
||||||
# the model from being probed twice in the event that it has already been probed.
|
|
||||||
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
|
|
||||||
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
|
|
||||||
if not info:
|
|
||||||
logger.warning(f"Unable to parse format of {path}")
|
|
||||||
return None
|
|
||||||
model_name = path.stem if path.is_file() else path.name
|
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
|
||||||
attributes = self._make_attributes(path, info)
|
|
||||||
return self.mgr.add_model(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=info.base_type,
|
|
||||||
model_type=info.model_type,
|
|
||||||
model_attributes=attributes,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _install_url(self, url: str) -> AddModelResult:
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
|
||||||
CIVITAI_RE = r".*civitai.com.*"
|
|
||||||
civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE)
|
|
||||||
location = download_with_resume(
|
|
||||||
url, Path(staging), access_token=self.civitai_api_key if civit_url else None
|
|
||||||
)
|
|
||||||
if not location:
|
|
||||||
logger.error(f"Unable to download {url}. Skipping.")
|
|
||||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
|
||||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
models_path = shutil.move(location, dest)
|
|
||||||
|
|
||||||
# staged version will be garbage-collected at this time
|
|
||||||
return self._install_path(Path(models_path), info)
|
|
||||||
|
|
||||||
def _install_repo(self, repo_id: str) -> AddModelResult:
|
|
||||||
# hack to recover models stored in subfolders --
|
|
||||||
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
|
|
||||||
subfolder = None
|
|
||||||
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
|
|
||||||
repo_id = match.group(1)
|
|
||||||
subfolder = match.group(2)
|
|
||||||
|
|
||||||
hinfo = HfApi().model_info(repo_id)
|
|
||||||
|
|
||||||
# we try to figure out how to download this most economically
|
|
||||||
# list all the files in the repo
|
|
||||||
files = [x.rfilename for x in hinfo.siblings]
|
|
||||||
if subfolder:
|
|
||||||
files = [x for x in files if x.startswith(f"{subfolder}/")]
|
|
||||||
prefix = f"{subfolder}/" if subfolder else ""
|
|
||||||
|
|
||||||
location = None
|
|
||||||
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
|
||||||
staging = Path(staging)
|
|
||||||
if f"{prefix}model_index.json" in files:
|
|
||||||
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
|
|
||||||
elif f"{prefix}unet/model.onnx" in files:
|
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
|
||||||
else:
|
|
||||||
for suffix in ["safetensors", "bin"]:
|
|
||||||
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
|
|
||||||
location = self._download_hf_model(
|
|
||||||
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
|
|
||||||
) # LoRA
|
|
||||||
break
|
|
||||||
elif (
|
|
||||||
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
|
|
||||||
): # vae, controlnet or some other standalone
|
|
||||||
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
|
|
||||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
|
||||||
break
|
|
||||||
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
|
|
||||||
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
|
|
||||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
|
||||||
break
|
|
||||||
elif f"{prefix}learned_embeds.{suffix}" in files:
|
|
||||||
location = self._download_hf_model(
|
|
||||||
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
|
|
||||||
)
|
|
||||||
break
|
|
||||||
elif (
|
|
||||||
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
|
|
||||||
): # IP-Adapter
|
|
||||||
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
|
|
||||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
|
||||||
break
|
|
||||||
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
|
|
||||||
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
|
||||||
# by InvokeAI for use with IP-Adapters.
|
|
||||||
files = ["config.json", f"model.{suffix}"]
|
|
||||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
|
||||||
break
|
|
||||||
if not location:
|
|
||||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
|
||||||
if not info:
|
|
||||||
logger.warning(f"Could not probe {location}. Skipping install.")
|
|
||||||
return {}
|
|
||||||
dest = (
|
|
||||||
self.config.models_path
|
|
||||||
/ info.base_type.value
|
|
||||||
/ info.model_type.value
|
|
||||||
/ self._get_model_name(repo_id, location)
|
|
||||||
)
|
|
||||||
if dest.exists():
|
|
||||||
shutil.rmtree(dest)
|
|
||||||
shutil.copytree(location, dest)
|
|
||||||
return self._install_path(dest, info)
|
|
||||||
|
|
||||||
def _get_model_name(self, path_name: str, location: Path) -> str:
|
|
||||||
"""
|
|
||||||
Calculate a name for the model - primitive implementation.
|
|
||||||
"""
|
|
||||||
if key := self.reverse_paths.get(path_name):
|
|
||||||
(name, base, mtype) = ModelManager.parse_key(key)
|
|
||||||
return name
|
|
||||||
elif location.is_dir():
|
|
||||||
return location.name
|
|
||||||
else:
|
|
||||||
return location.stem
|
|
||||||
|
|
||||||
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
|
|
||||||
model_name = path.name if path.is_dir() else path.stem
|
|
||||||
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
|
|
||||||
if key := self.reverse_paths.get(self.current_id):
|
|
||||||
if key in self.datasets:
|
|
||||||
description = self.datasets[key].get("description") or description
|
|
||||||
|
|
||||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
|
||||||
|
|
||||||
attributes = {
|
|
||||||
"path": str(rel_path),
|
|
||||||
"description": str(description),
|
|
||||||
"model_format": info.format,
|
|
||||||
}
|
|
||||||
legacy_conf = None
|
|
||||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
|
||||||
attributes.update(
|
|
||||||
{
|
|
||||||
"variant": info.variant_type,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if info.format == "checkpoint":
|
|
||||||
try:
|
|
||||||
possible_conf = path.with_suffix(".yaml")
|
|
||||||
if possible_conf.exists():
|
|
||||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
|
||||||
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
|
||||||
legacy_conf = Path(
|
|
||||||
self.config.legacy_conf_dir,
|
|
||||||
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
legacy_conf = Path(
|
|
||||||
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
|
|
||||||
)
|
|
||||||
except KeyError:
|
|
||||||
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
|
|
||||||
|
|
||||||
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
|
|
||||||
possible_conf = path.with_suffix(".yaml")
|
|
||||||
if possible_conf.exists():
|
|
||||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
|
||||||
else:
|
|
||||||
legacy_conf = Path(
|
|
||||||
self.config.root_path,
|
|
||||||
"configs/controlnet",
|
|
||||||
("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if legacy_conf:
|
|
||||||
attributes.update({"config": str(legacy_conf)})
|
|
||||||
return attributes
|
|
||||||
|
|
||||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
|
||||||
root = root or self.config.root_path
|
|
||||||
if path.is_relative_to(root):
|
|
||||||
return path.relative_to(root)
|
|
||||||
else:
|
|
||||||
return path
|
|
||||||
|
|
||||||
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
|
|
||||||
"""
|
|
||||||
Retrieve a StableDiffusion model from cache or remote and then
|
|
||||||
does a save_pretrained() to the indicated staging area.
|
|
||||||
"""
|
|
||||||
_, name = repo_id.split("/")
|
|
||||||
precision = torch_dtype(choose_torch_device())
|
|
||||||
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
|
|
||||||
|
|
||||||
model = None
|
|
||||||
for variant in variants:
|
|
||||||
try:
|
|
||||||
model = DiffusionPipeline.from_pretrained(
|
|
||||||
repo_id,
|
|
||||||
variant=variant,
|
|
||||||
torch_dtype=precision,
|
|
||||||
safety_checker=None,
|
|
||||||
subfolder=subfolder,
|
|
||||||
)
|
|
||||||
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
|
|
||||||
if "fp16" not in str(e):
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
if model:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not model:
|
|
||||||
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
|
||||||
return None
|
|
||||||
model.save_pretrained(staging / name, safe_serialization=True)
|
|
||||||
return staging / name
|
|
||||||
|
|
||||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
|
||||||
_, name = repo_id.split("/")
|
|
||||||
location = staging / name
|
|
||||||
paths = []
|
|
||||||
for filename in files:
|
|
||||||
filePath = Path(filename)
|
|
||||||
p = hf_download_with_resume(
|
|
||||||
repo_id,
|
|
||||||
model_dir=location / filePath.parent,
|
|
||||||
model_name=filePath.name,
|
|
||||||
access_token=self.access_token,
|
|
||||||
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
|
|
||||||
)
|
|
||||||
if p:
|
|
||||||
paths.append(p)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Could not download {filename} from {repo_id}.")
|
|
||||||
|
|
||||||
return location if len(paths) > 0 else None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _reverse_paths(cls, datasets) -> dict:
|
|
||||||
"""
|
|
||||||
Reverse mapping from repo_id/path to destination name.
|
|
||||||
"""
|
|
||||||
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
def yes_or_no(prompt: str, default_yes=True):
|
|
||||||
default = "y" if default_yes else "n"
|
|
||||||
response = input(f"{prompt} [{default}] ") or default
|
|
||||||
if default_yes:
|
|
||||||
return response[0] not in ("n", "N")
|
|
||||||
else:
|
|
||||||
return response[0] in ("y", "Y")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
|
||||||
logger = InvokeAILogger.get_logger("InvokeAI")
|
|
||||||
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
resume_download=True,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
model.save_pretrained(destination, safe_serialization=True)
|
|
||||||
return destination
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def hf_download_with_resume(
|
|
||||||
repo_id: str,
|
|
||||||
model_dir: str,
|
|
||||||
model_name: str,
|
|
||||||
model_dest: Path = None,
|
|
||||||
access_token: str = None,
|
|
||||||
subfolder: str = None,
|
|
||||||
) -> Path:
|
|
||||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
|
||||||
|
|
||||||
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
|
||||||
|
|
||||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
|
||||||
open_mode = "wb"
|
|
||||||
exist_size = 0
|
|
||||||
|
|
||||||
if os.path.exists(model_dest):
|
|
||||||
exist_size = os.path.getsize(model_dest)
|
|
||||||
header["Range"] = f"bytes={exist_size}-"
|
|
||||||
open_mode = "ab"
|
|
||||||
|
|
||||||
resp = requests.get(url, headers=header, stream=True)
|
|
||||||
total = int(resp.headers.get("content-length", 0))
|
|
||||||
|
|
||||||
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
|
|
||||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
|
||||||
return model_dest
|
|
||||||
elif resp.status_code == 404:
|
|
||||||
logger.warning("File not found")
|
|
||||||
return None
|
|
||||||
elif resp.status_code != 200:
|
|
||||||
logger.warning(f"{model_name}: {resp.reason}")
|
|
||||||
elif exist_size > 0:
|
|
||||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
|
||||||
else:
|
|
||||||
logger.info(f"{model_name}: Downloading...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
with (
|
|
||||||
open(model_dest, open_mode) as file,
|
|
||||||
tqdm(
|
|
||||||
desc=model_name,
|
|
||||||
initial=exist_size,
|
|
||||||
total=total + exist_size,
|
|
||||||
unit="iB",
|
|
||||||
unit_scale=True,
|
|
||||||
unit_divisor=1000,
|
|
||||||
) as bar,
|
|
||||||
):
|
|
||||||
for data in resp.iter_content(chunk_size=1024):
|
|
||||||
size = file.write(data)
|
|
||||||
bar.update(size)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
|
||||||
return None
|
|
||||||
return model_dest
|
|
@ -9,8 +9,8 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|||||||
|
|
||||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||||
|
|
||||||
from .resampler import Resampler
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
|
from .resampler import Resampler
|
||||||
|
|
||||||
|
|
||||||
class ImageProjModel(torch.nn.Module):
|
class ImageProjModel(torch.nn.Module):
|
||||||
|
@ -10,6 +10,7 @@ from safetensors.torch import load_file
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
|
|
||||||
from .raw_model import RawModel
|
from .raw_model import RawModel
|
||||||
|
|
||||||
|
|
||||||
@ -366,6 +367,7 @@ class IA3Layer(LoRALayerBase):
|
|||||||
|
|
||||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||||
_name: str
|
_name: str
|
||||||
layers: Dict[str, AnyLoRALayer]
|
layers: Dict[str, AnyLoRALayer]
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
# Model Cache
|
|
||||||
|
|
||||||
## `glibc` Memory Allocator Fragmentation
|
|
||||||
|
|
||||||
Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM).
|
|
||||||
|
|
||||||
This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`.
|
|
||||||
|
|
||||||
To better understand how the `glibc` memory allocator works, see these references:
|
|
||||||
- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html
|
|
||||||
- Details: https://sourceware.org/glibc/wiki/MallocInternals
|
|
||||||
|
|
||||||
Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation.
|
|
||||||
|
|
||||||
We can work around this memory fragmentation issue by setting the following env var:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed.
|
|
||||||
MALLOC_MMAP_THRESHOLD_=1048576
|
|
||||||
```
|
|
||||||
|
|
||||||
See the following references for more information about the `malloc` tunable parameters:
|
|
||||||
- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html
|
|
||||||
- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html
|
|
||||||
- https://man7.org/linux/man-pages/man3/mallopt.3.html
|
|
||||||
|
|
||||||
The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected.
|
|
@ -1,20 +0,0 @@
|
|||||||
# ruff: noqa: I001, F401
|
|
||||||
"""
|
|
||||||
Initialization file for invokeai.backend.model_management
|
|
||||||
"""
|
|
||||||
# This import must be first
|
|
||||||
from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType
|
|
||||||
from .lora import ModelPatcher, ONNXModelPatcher
|
|
||||||
from .model_cache import ModelCache
|
|
||||||
|
|
||||||
from .models import (
|
|
||||||
BaseModelType,
|
|
||||||
DuplicateModelException,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
SubModelType,
|
|
||||||
)
|
|
||||||
|
|
||||||
# This import must be last
|
|
||||||
from .model_merge import MergeInterpolationMethod, ModelMerger
|
|
File diff suppressed because it is too large
Load Diff
@ -1,31 +0,0 @@
|
|||||||
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
|
|
||||||
"""
|
|
||||||
This module exports the function has_baked_in_sdxl_vae().
|
|
||||||
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
|
|
||||||
which doesn't work properly in fp16 mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
|
|
||||||
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
|
|
||||||
|
|
||||||
|
|
||||||
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
|
|
||||||
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
|
|
||||||
hash = _vae_hash(checkpoint_path)
|
|
||||||
return hash != SDXL_1_0_VAE_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def _vae_hash(checkpoint_path: Path) -> str:
|
|
||||||
checkpoint = load_file(checkpoint_path, device="cpu")
|
|
||||||
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
|
|
||||||
hash = hashlib.new("sha256")
|
|
||||||
for key in vae_keys:
|
|
||||||
value = checkpoint[key]
|
|
||||||
hash.update(bytes(key, "UTF-8"))
|
|
||||||
hash.update(bytes(str(value), "UTF-8"))
|
|
||||||
|
|
||||||
return hash.hexdigest()
|
|
@ -1,582 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
|
||||||
from diffusers.models import UNet2DConditionModel
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
|
||||||
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
|
|
||||||
|
|
||||||
from .models.lora import LoRAModel
|
|
||||||
|
|
||||||
"""
|
|
||||||
loras = [
|
|
||||||
(lora_model1, 0.7),
|
|
||||||
(lora_model2, 0.4),
|
|
||||||
]
|
|
||||||
with LoRAHelper.apply_lora_unet(unet, loras):
|
|
||||||
# unet with applied loras
|
|
||||||
# unmodified unet
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: rename smth like ModelPatcher and add TI method?
|
|
||||||
class ModelPatcher:
|
|
||||||
@staticmethod
|
|
||||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
|
||||||
assert "." not in lora_key
|
|
||||||
|
|
||||||
if not lora_key.startswith(prefix):
|
|
||||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
|
||||||
|
|
||||||
module = model
|
|
||||||
module_key = ""
|
|
||||||
key_parts = lora_key[len(prefix) :].split("_")
|
|
||||||
|
|
||||||
submodule_name = key_parts.pop(0)
|
|
||||||
|
|
||||||
while len(key_parts) > 0:
|
|
||||||
try:
|
|
||||||
module = module.get_submodule(submodule_name)
|
|
||||||
module_key += "." + submodule_name
|
|
||||||
submodule_name = key_parts.pop(0)
|
|
||||||
except Exception:
|
|
||||||
submodule_name += "_" + key_parts.pop(0)
|
|
||||||
|
|
||||||
module = module.get_submodule(submodule_name)
|
|
||||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
|
||||||
|
|
||||||
return (module_key, module)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora_unet(
|
|
||||||
cls,
|
|
||||||
unet: UNet2DConditionModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora_text_encoder(
|
|
||||||
cls,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_sdxl_lora_text_encoder(
|
|
||||||
cls,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_sdxl_lora_text_encoder2(
|
|
||||||
cls,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora(
|
|
||||||
cls,
|
|
||||||
model: torch.nn.Module,
|
|
||||||
loras: List[Tuple[LoRAModel, float]], # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw
|
|
||||||
prefix: str,
|
|
||||||
):
|
|
||||||
original_weights = {}
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
for lora, lora_weight in loras:
|
|
||||||
# assert lora.device.type == "cpu"
|
|
||||||
for layer_key, layer in lora.layers.items():
|
|
||||||
if not layer_key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
|
||||||
# should be improved in the following ways:
|
|
||||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
|
||||||
# LoRA model is applied.
|
|
||||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
|
||||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
|
||||||
# weights to have valid keys.
|
|
||||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
|
||||||
|
|
||||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
|
||||||
# (Performance will be best if this is a CUDA device.)
|
|
||||||
device = module.weight.device
|
|
||||||
dtype = module.weight.dtype
|
|
||||||
|
|
||||||
if module_key not in original_weights:
|
|
||||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
|
||||||
|
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
|
||||||
|
|
||||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
|
||||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
|
||||||
# same thing in a single call to '.to(...)'.
|
|
||||||
layer.to(device=device)
|
|
||||||
layer.to(dtype=torch.float32)
|
|
||||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
|
||||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
|
||||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
|
||||||
layer.to(device="cpu")
|
|
||||||
|
|
||||||
if module.weight.shape != layer_weight.shape:
|
|
||||||
# TODO: debug on lycoris
|
|
||||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
|
||||||
|
|
||||||
module.weight += layer_weight.to(dtype=dtype)
|
|
||||||
|
|
||||||
yield # wait for context manager exit
|
|
||||||
|
|
||||||
finally:
|
|
||||||
with torch.no_grad():
|
|
||||||
for module_key, weight in original_weights.items():
|
|
||||||
model.get_submodule(module_key).weight.copy_(weight)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_ti(
|
|
||||||
cls,
|
|
||||||
tokenizer: CLIPTokenizer,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
ti_list: List[Tuple[str, Any]],
|
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
|
||||||
init_tokens_count = None
|
|
||||||
new_tokens_added = None
|
|
||||||
|
|
||||||
# TODO: This is required since Transformers 4.32 see
|
|
||||||
# https://github.com/huggingface/transformers/pull/25088
|
|
||||||
# More information by NVIDIA:
|
|
||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
|
|
||||||
# This value might need to be changed in the future and take the GPUs model into account as there seem
|
|
||||||
# to be ideal values for different GPUS. This value is temporary!
|
|
||||||
# For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
|
|
||||||
pad_to_multiple_of = 8
|
|
||||||
|
|
||||||
try:
|
|
||||||
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
|
|
||||||
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
|
|
||||||
# exiting this `apply_ti(...)` context manager.
|
|
||||||
#
|
|
||||||
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
|
|
||||||
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
|
|
||||||
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
|
|
||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
|
||||||
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
|
|
||||||
|
|
||||||
def _get_trigger(ti_name, index):
|
|
||||||
trigger = ti_name
|
|
||||||
if index > 0:
|
|
||||||
trigger += f"-!pad-{i}"
|
|
||||||
return f"<{trigger}>"
|
|
||||||
|
|
||||||
def _get_ti_embedding(model_embeddings, ti):
|
|
||||||
print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}")
|
|
||||||
print(f"DEBUG: is it an nn.Module? {isinstance(model_embeddings, torch.nn.Module)}")
|
|
||||||
# for SDXL models, select the embedding that matches the text encoder's dimensions
|
|
||||||
if ti.embedding_2 is not None:
|
|
||||||
return (
|
|
||||||
ti.embedding_2
|
|
||||||
if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
|
|
||||||
else ti.embedding
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(f"DEBUG: ti.embedding={type(ti.embedding)}")
|
|
||||||
return ti.embedding
|
|
||||||
|
|
||||||
# modify tokenizer
|
|
||||||
new_tokens_added = 0
|
|
||||||
for ti_name, ti in ti_list:
|
|
||||||
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
|
|
||||||
|
|
||||||
for i in range(ti_embedding.shape[0]):
|
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
|
||||||
|
|
||||||
# Modify text_encoder.
|
|
||||||
# resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
|
|
||||||
# this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
|
|
||||||
# time.
|
|
||||||
with skip_torch_weight_init():
|
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
|
|
||||||
model_embeddings = text_encoder.get_input_embeddings()
|
|
||||||
|
|
||||||
for ti_name, ti in ti_list:
|
|
||||||
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
|
|
||||||
|
|
||||||
ti_tokens = []
|
|
||||||
for i in range(ti_embedding.shape[0]):
|
|
||||||
embedding = ti_embedding[i]
|
|
||||||
trigger = _get_trigger(ti_name, i)
|
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
|
||||||
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
|
||||||
|
|
||||||
if model_embeddings.weight.data[token_id].shape != embedding.shape:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
|
|
||||||
f" {embedding.shape[0]}, but the current model has token dimension"
|
|
||||||
f" {model_embeddings.weight.data[token_id].shape[0]}."
|
|
||||||
)
|
|
||||||
|
|
||||||
model_embeddings.weight.data[token_id] = embedding.to(
|
|
||||||
device=text_encoder.device, dtype=text_encoder.dtype
|
|
||||||
)
|
|
||||||
ti_tokens.append(token_id)
|
|
||||||
|
|
||||||
if len(ti_tokens) > 1:
|
|
||||||
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
|
||||||
|
|
||||||
yield ti_tokenizer, ti_manager
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if init_tokens_count and new_tokens_added:
|
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_clip_skip(
|
|
||||||
cls,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
clip_skip: int,
|
|
||||||
):
|
|
||||||
skipped_layers = []
|
|
||||||
try:
|
|
||||||
for _i in range(clip_skip):
|
|
||||||
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
while len(skipped_layers) > 0:
|
|
||||||
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_freeu(
|
|
||||||
cls,
|
|
||||||
unet: UNet2DConditionModel,
|
|
||||||
freeu_config: Optional[FreeUConfig] = None,
|
|
||||||
):
|
|
||||||
did_apply_freeu = False
|
|
||||||
try:
|
|
||||||
if freeu_config is not None:
|
|
||||||
unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
|
|
||||||
did_apply_freeu = True
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if did_apply_freeu:
|
|
||||||
unet.disable_freeu()
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModel:
|
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
|
||||||
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_checkpoint(
|
|
||||||
cls,
|
|
||||||
file_path: Union[str, Path],
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
if not isinstance(file_path, Path):
|
|
||||||
file_path = Path(file_path)
|
|
||||||
|
|
||||||
result = cls() # TODO:
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(file_path, map_location="cpu")
|
|
||||||
|
|
||||||
# both v1 and v2 format embeddings
|
|
||||||
# difference mostly in metadata
|
|
||||||
if "string_to_param" in state_dict:
|
|
||||||
if len(state_dict["string_to_param"]) > 1:
|
|
||||||
print(
|
|
||||||
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
|
|
||||||
" token will be used.",
|
|
||||||
)
|
|
||||||
|
|
||||||
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
|
||||||
|
|
||||||
# v3 (easynegative)
|
|
||||||
elif "emb_params" in state_dict:
|
|
||||||
result.embedding = state_dict["emb_params"]
|
|
||||||
|
|
||||||
# v5(sdxl safetensors file)
|
|
||||||
elif "clip_g" in state_dict and "clip_l" in state_dict:
|
|
||||||
result.embedding = state_dict["clip_g"]
|
|
||||||
result.embedding_2 = state_dict["clip_l"]
|
|
||||||
|
|
||||||
# v4(diffusers bin files)
|
|
||||||
else:
|
|
||||||
result.embedding = next(iter(state_dict.values()))
|
|
||||||
|
|
||||||
if len(result.embedding.shape) == 1:
|
|
||||||
result.embedding = result.embedding.unsqueeze(0)
|
|
||||||
|
|
||||||
if not isinstance(result.embedding, torch.Tensor):
|
|
||||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionManager(BaseTextualInversionManager):
|
|
||||||
pad_tokens: Dict[int, List[int]]
|
|
||||||
tokenizer: CLIPTokenizer
|
|
||||||
|
|
||||||
def __init__(self, tokenizer: CLIPTokenizer):
|
|
||||||
self.pad_tokens = {}
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
|
||||||
if len(self.pad_tokens) == 0:
|
|
||||||
return token_ids
|
|
||||||
|
|
||||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
|
||||||
raise ValueError("token_ids must not start with bos_token_id")
|
|
||||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
|
||||||
raise ValueError("token_ids must not end with eos_token_id")
|
|
||||||
|
|
||||||
new_token_ids = []
|
|
||||||
for token_id in token_ids:
|
|
||||||
new_token_ids.append(token_id)
|
|
||||||
if token_id in self.pad_tokens:
|
|
||||||
new_token_ids.extend(self.pad_tokens[token_id])
|
|
||||||
|
|
||||||
# Do not exceed the max model input size
|
|
||||||
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
|
|
||||||
# which first removes and then adds back the start and end tokens.
|
|
||||||
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
|
|
||||||
if len(new_token_ids) > max_length:
|
|
||||||
new_token_ids = new_token_ids[0:max_length]
|
|
||||||
|
|
||||||
return new_token_ids
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXModelPatcher:
|
|
||||||
from diffusers import OnnxRuntimeModel
|
|
||||||
|
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora_unet(
|
|
||||||
cls,
|
|
||||||
unet: OnnxRuntimeModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora_text_encoder(
|
|
||||||
cls,
|
|
||||||
text_encoder: OnnxRuntimeModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
# based on
|
|
||||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora(
|
|
||||||
cls,
|
|
||||||
model: IAIOnnxRuntimeModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
prefix: str,
|
|
||||||
):
|
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
|
||||||
|
|
||||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
|
||||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
|
||||||
|
|
||||||
orig_weights = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
blended_loras = {}
|
|
||||||
|
|
||||||
for lora, lora_weight in loras:
|
|
||||||
for layer_key, layer in lora.layers.items():
|
|
||||||
if not layer_key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
layer.to(dtype=torch.float32)
|
|
||||||
layer_key = layer_key.replace(prefix, "")
|
|
||||||
# TODO: rewrite to pass original tensor weight(required by ia3)
|
|
||||||
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
|
|
||||||
if layer_key is blended_loras:
|
|
||||||
blended_loras[layer_key] += layer_weight
|
|
||||||
else:
|
|
||||||
blended_loras[layer_key] = layer_weight
|
|
||||||
|
|
||||||
node_names = {}
|
|
||||||
for node in model.nodes.values():
|
|
||||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
|
||||||
|
|
||||||
for layer_key, lora_weight in blended_loras.items():
|
|
||||||
conv_key = layer_key + "_Conv"
|
|
||||||
gemm_key = layer_key + "_Gemm"
|
|
||||||
matmul_key = layer_key + "_MatMul"
|
|
||||||
|
|
||||||
if conv_key in node_names or gemm_key in node_names:
|
|
||||||
if conv_key in node_names:
|
|
||||||
conv_node = model.nodes[node_names[conv_key]]
|
|
||||||
else:
|
|
||||||
conv_node = model.nodes[node_names[gemm_key]]
|
|
||||||
|
|
||||||
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
|
||||||
orig_weight = model.tensors[weight_name]
|
|
||||||
|
|
||||||
if orig_weight.shape[-2:] == (1, 1):
|
|
||||||
if lora_weight.shape[-2:] == (1, 1):
|
|
||||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
|
|
||||||
else:
|
|
||||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
|
|
||||||
|
|
||||||
new_weight = np.expand_dims(new_weight, (2, 3))
|
|
||||||
else:
|
|
||||||
if orig_weight.shape != lora_weight.shape:
|
|
||||||
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
|
|
||||||
else:
|
|
||||||
new_weight = orig_weight + lora_weight
|
|
||||||
|
|
||||||
orig_weights[weight_name] = orig_weight
|
|
||||||
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
|
|
||||||
|
|
||||||
elif matmul_key in node_names:
|
|
||||||
weight_node = model.nodes[node_names[matmul_key]]
|
|
||||||
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
|
||||||
|
|
||||||
orig_weight = model.tensors[matmul_name]
|
|
||||||
new_weight = orig_weight + lora_weight.transpose()
|
|
||||||
|
|
||||||
orig_weights[matmul_name] = orig_weight
|
|
||||||
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# warn? err?
|
|
||||||
pass
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# restore original weights
|
|
||||||
for name, orig_weight in orig_weights.items():
|
|
||||||
model.tensors[name] = orig_weight
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_ti(
|
|
||||||
cls,
|
|
||||||
tokenizer: CLIPTokenizer,
|
|
||||||
text_encoder: IAIOnnxRuntimeModel,
|
|
||||||
ti_list: List[Tuple[str, Any]],
|
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
|
||||||
|
|
||||||
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
|
|
||||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
|
||||||
|
|
||||||
orig_embeddings = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
|
|
||||||
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
|
|
||||||
# exiting this `apply_ti(...)` context manager.
|
|
||||||
#
|
|
||||||
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
|
|
||||||
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
|
|
||||||
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
|
|
||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
|
||||||
|
|
||||||
def _get_trigger(ti_name, index):
|
|
||||||
trigger = ti_name
|
|
||||||
if index > 0:
|
|
||||||
trigger += f"-!pad-{i}"
|
|
||||||
return f"<{trigger}>"
|
|
||||||
|
|
||||||
# modify text_encoder
|
|
||||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
|
||||||
|
|
||||||
# modify tokenizer
|
|
||||||
new_tokens_added = 0
|
|
||||||
for ti_name, ti in ti_list:
|
|
||||||
if ti.embedding_2 is not None:
|
|
||||||
ti_embedding = (
|
|
||||||
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ti_embedding = ti.embedding
|
|
||||||
|
|
||||||
for i in range(ti_embedding.shape[0]):
|
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
|
||||||
|
|
||||||
embeddings = np.concatenate(
|
|
||||||
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
for ti_name, _ in ti_list:
|
|
||||||
ti_tokens = []
|
|
||||||
for i in range(ti_embedding.shape[0]):
|
|
||||||
embedding = ti_embedding[i].detach().numpy()
|
|
||||||
trigger = _get_trigger(ti_name, i)
|
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
|
||||||
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
|
||||||
|
|
||||||
if embeddings[token_id].shape != embedding.shape:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
|
|
||||||
f" {embedding.shape[0]}, but the current model has token dimension"
|
|
||||||
f" {embeddings[token_id].shape[0]}."
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings[token_id] = embedding
|
|
||||||
ti_tokens.append(token_id)
|
|
||||||
|
|
||||||
if len(ti_tokens) > 1:
|
|
||||||
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
|
||||||
|
|
||||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
|
|
||||||
orig_embeddings.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
yield ti_tokenizer, ti_manager
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# restore
|
|
||||||
if orig_embeddings is not None:
|
|
||||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
|
|
@ -1,99 +0,0 @@
|
|||||||
import gc
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
|
||||||
|
|
||||||
GB = 2**30 # 1 GB
|
|
||||||
|
|
||||||
|
|
||||||
class MemorySnapshot:
|
|
||||||
"""A snapshot of RAM and VRAM usage. All values are in bytes."""
|
|
||||||
|
|
||||||
def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]):
|
|
||||||
"""Initialize a MemorySnapshot.
|
|
||||||
|
|
||||||
Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
process_ram (int): CPU RAM used by the current process.
|
|
||||||
vram (Optional[int]): VRAM used by torch.
|
|
||||||
malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil.
|
|
||||||
"""
|
|
||||||
self.process_ram = process_ram
|
|
||||||
self.vram = vram
|
|
||||||
self.malloc_info = malloc_info
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def capture(cls, run_garbage_collector: bool = True):
|
|
||||||
"""Capture and return a MemorySnapshot.
|
|
||||||
|
|
||||||
Note: This function has significant overhead, particularly if `run_garbage_collector == True`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM
|
|
||||||
usage. Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MemorySnapshot
|
|
||||||
"""
|
|
||||||
if run_garbage_collector:
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
# According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is
|
|
||||||
# supported on all platforms.
|
|
||||||
process_ram = psutil.Process().memory_info().rss
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
vram = torch.cuda.memory_allocated()
|
|
||||||
else:
|
|
||||||
# TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have
|
|
||||||
# time to test it properly.
|
|
||||||
vram = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
malloc_info = LibcUtil().mallinfo2()
|
|
||||||
except (OSError, AttributeError):
|
|
||||||
# OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
|
|
||||||
# AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)
|
|
||||||
# TODO: Does `mallinfo` work?
|
|
||||||
malloc_info = None
|
|
||||||
|
|
||||||
return cls(process_ram, vram, malloc_info)
|
|
||||||
|
|
||||||
|
|
||||||
def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
|
|
||||||
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
|
|
||||||
|
|
||||||
def get_msg_line(prefix: str, val1: int, val2: int):
|
|
||||||
diff = val2 - val1
|
|
||||||
return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"
|
|
||||||
|
|
||||||
msg = ""
|
|
||||||
|
|
||||||
if snapshot_1 is None or snapshot_2 is None:
|
|
||||||
return msg
|
|
||||||
|
|
||||||
msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)
|
|
||||||
|
|
||||||
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
|
|
||||||
msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd)
|
|
||||||
|
|
||||||
msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks)
|
|
||||||
|
|
||||||
msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks)
|
|
||||||
|
|
||||||
libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd
|
|
||||||
libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd
|
|
||||||
msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2)
|
|
||||||
|
|
||||||
libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd
|
|
||||||
libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd
|
|
||||||
msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2)
|
|
||||||
|
|
||||||
if snapshot_1.vram is not None and snapshot_2.vram is not None:
|
|
||||||
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
|
|
||||||
|
|
||||||
return msg
|
|
@ -1,553 +0,0 @@
|
|||||||
"""
|
|
||||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
|
||||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
|
||||||
grows larger than a preset maximum, then the least recently used
|
|
||||||
model will be cleared and (re)loaded from disk when next needed.
|
|
||||||
|
|
||||||
The cache returns context manager generators designed to load the
|
|
||||||
model into the GPU within the context, and unload outside the
|
|
||||||
context. Use like this:
|
|
||||||
|
|
||||||
cache = ModelCache(max_cache_size=7.5)
|
|
||||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
|
||||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
|
||||||
do_something_in_GPU(SD1,SD2)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gc
|
|
||||||
import hashlib
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from contextlib import suppress
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Optional, Type, Union, types
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
|
||||||
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
|
|
||||||
|
|
||||||
from ..util.devices import choose_torch_device
|
|
||||||
from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
from torch import mps
|
|
||||||
|
|
||||||
# Maximum size of the cache, in gigs
|
|
||||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
|
||||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
|
||||||
|
|
||||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
|
||||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
|
||||||
|
|
||||||
# actual size of a gig
|
|
||||||
GIG = 1073741824
|
|
||||||
# Size of a MB in bytes.
|
|
||||||
MB = 2**20
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CacheStats(object):
|
|
||||||
hits: int = 0 # cache hits
|
|
||||||
misses: int = 0 # cache misses
|
|
||||||
high_watermark: int = 0 # amount of cache used
|
|
||||||
in_cache: int = 0 # number of models in cache
|
|
||||||
cleared: int = 0 # number of models cleared to make space
|
|
||||||
cache_size: int = 0 # total size of cache
|
|
||||||
# {submodel_key => size}
|
|
||||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelLocker(object):
|
|
||||||
"Forward declaration"
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelCache(object):
|
|
||||||
"Forward declaration"
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _CacheRecord:
|
|
||||||
size: int
|
|
||||||
model: Any
|
|
||||||
cache: ModelCache
|
|
||||||
_locks: int
|
|
||||||
|
|
||||||
def __init__(self, cache, model: Any, size: int):
|
|
||||||
self.size = size
|
|
||||||
self.model = model
|
|
||||||
self.cache = cache
|
|
||||||
self._locks = 0
|
|
||||||
|
|
||||||
def lock(self):
|
|
||||||
self._locks += 1
|
|
||||||
|
|
||||||
def unlock(self):
|
|
||||||
self._locks -= 1
|
|
||||||
assert self._locks >= 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def locked(self):
|
|
||||||
return self._locks > 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def loaded(self):
|
|
||||||
if self.model is not None and hasattr(self.model, "device"):
|
|
||||||
return self.model.device != self.cache.storage_device
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class ModelCache(object):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
|
||||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
|
||||||
execution_device: torch.device = torch.device("cuda"),
|
|
||||||
storage_device: torch.device = torch.device("cpu"),
|
|
||||||
precision: torch.dtype = torch.float16,
|
|
||||||
sequential_offload: bool = False,
|
|
||||||
lazy_offloading: bool = True,
|
|
||||||
sha_chunksize: int = 16777216,
|
|
||||||
logger: types.ModuleType = logger,
|
|
||||||
log_memory_usage: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
|
||||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
|
||||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
|
||||||
:param precision: Precision for loaded models [torch.float16]
|
|
||||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
|
||||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
|
||||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
|
||||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
|
||||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
|
||||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
|
||||||
behaviour.
|
|
||||||
"""
|
|
||||||
self.model_infos: Dict[str, ModelBase] = {}
|
|
||||||
# allow lazy offloading only when vram cache enabled
|
|
||||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
|
||||||
self.precision: torch.dtype = precision
|
|
||||||
self.max_cache_size: float = max_cache_size
|
|
||||||
self.max_vram_cache_size: float = max_vram_cache_size
|
|
||||||
self.execution_device: torch.device = execution_device
|
|
||||||
self.storage_device: torch.device = storage_device
|
|
||||||
self.sha_chunksize = sha_chunksize
|
|
||||||
self.logger = logger
|
|
||||||
self._log_memory_usage = log_memory_usage
|
|
||||||
|
|
||||||
# used for stats collection
|
|
||||||
self.stats = None
|
|
||||||
|
|
||||||
self._cached_models = {}
|
|
||||||
self._cache_stack = []
|
|
||||||
|
|
||||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
|
||||||
if self._log_memory_usage:
|
|
||||||
return MemorySnapshot.capture()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_key(
|
|
||||||
self,
|
|
||||||
model_path: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel_type: Optional[SubModelType] = None,
|
|
||||||
):
|
|
||||||
key = f"{model_path}:{base_model}:{model_type}"
|
|
||||||
if submodel_type:
|
|
||||||
key += f":{submodel_type}"
|
|
||||||
return key
|
|
||||||
|
|
||||||
def _get_model_info(
|
|
||||||
self,
|
|
||||||
model_path: str,
|
|
||||||
model_class: Type[ModelBase],
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
):
|
|
||||||
model_info_key = self.get_key(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel_type=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_info_key not in self.model_infos:
|
|
||||||
self.model_infos[model_info_key] = model_class(
|
|
||||||
model_path,
|
|
||||||
base_model,
|
|
||||||
model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.model_infos[model_info_key]
|
|
||||||
|
|
||||||
# TODO: args
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
model_path: Union[str, Path],
|
|
||||||
model_class: Type[ModelBase],
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
gpu_load: bool = True,
|
|
||||||
) -> Any:
|
|
||||||
if not isinstance(model_path, Path):
|
|
||||||
model_path = Path(model_path)
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
raise Exception(f"Model not found: {model_path}")
|
|
||||||
|
|
||||||
model_info = self._get_model_info(
|
|
||||||
model_path=model_path,
|
|
||||||
model_class=model_class,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
)
|
|
||||||
key = self.get_key(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel_type=submodel,
|
|
||||||
)
|
|
||||||
# TODO: lock for no copies on simultaneous calls?
|
|
||||||
cache_entry = self._cached_models.get(key, None)
|
|
||||||
if cache_entry is None:
|
|
||||||
self.logger.info(
|
|
||||||
f"Loading model {model_path}, type"
|
|
||||||
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
|
||||||
)
|
|
||||||
if self.stats:
|
|
||||||
self.stats.misses += 1
|
|
||||||
|
|
||||||
self_reported_model_size_before_load = model_info.get_size(submodel)
|
|
||||||
# Remove old models from the cache to make room for the new model.
|
|
||||||
self._make_cache_room(self_reported_model_size_before_load)
|
|
||||||
|
|
||||||
# Load the model from disk and capture a memory snapshot before/after.
|
|
||||||
start_load_time = time.time()
|
|
||||||
snapshot_before = self._capture_memory_snapshot()
|
|
||||||
with skip_torch_weight_init():
|
|
||||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
|
||||||
snapshot_after = self._capture_memory_snapshot()
|
|
||||||
end_load_time = time.time()
|
|
||||||
|
|
||||||
self_reported_model_size_after_load = model_info.get_size(submodel)
|
|
||||||
|
|
||||||
self.logger.debug(
|
|
||||||
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
|
|
||||||
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
|
|
||||||
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
|
|
||||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
|
|
||||||
self.logger.debug(
|
|
||||||
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
|
|
||||||
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
|
|
||||||
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
|
|
||||||
)
|
|
||||||
|
|
||||||
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
|
|
||||||
self._cached_models[key] = cache_entry
|
|
||||||
else:
|
|
||||||
if self.stats:
|
|
||||||
self.stats.hits += 1
|
|
||||||
|
|
||||||
if self.stats:
|
|
||||||
self.stats.cache_size = self.max_cache_size * GIG
|
|
||||||
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
|
|
||||||
self.stats.in_cache = len(self._cached_models)
|
|
||||||
self.stats.loaded_model_sizes[key] = max(
|
|
||||||
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
|
|
||||||
)
|
|
||||||
|
|
||||||
with suppress(Exception):
|
|
||||||
self._cache_stack.remove(key)
|
|
||||||
self._cache_stack.append(key)
|
|
||||||
|
|
||||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
|
||||||
|
|
||||||
def _move_model_to_device(self, key: str, target_device: torch.device):
|
|
||||||
cache_entry = self._cached_models[key]
|
|
||||||
|
|
||||||
source_device = cache_entry.model.device
|
|
||||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
|
|
||||||
# multi-GPU.
|
|
||||||
if torch.device(source_device).type == torch.device(target_device).type:
|
|
||||||
return
|
|
||||||
|
|
||||||
start_model_to_time = time.time()
|
|
||||||
snapshot_before = self._capture_memory_snapshot()
|
|
||||||
cache_entry.model.to(target_device)
|
|
||||||
snapshot_after = self._capture_memory_snapshot()
|
|
||||||
end_model_to_time = time.time()
|
|
||||||
self.logger.debug(
|
|
||||||
f"Moved model '{key}' from {source_device} to"
|
|
||||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
|
|
||||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
|
|
||||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
snapshot_before is not None
|
|
||||||
and snapshot_after is not None
|
|
||||||
and snapshot_before.vram is not None
|
|
||||||
and snapshot_after.vram is not None
|
|
||||||
):
|
|
||||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
|
||||||
|
|
||||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
|
||||||
if not math.isclose(
|
|
||||||
vram_change,
|
|
||||||
cache_entry.size,
|
|
||||||
rel_tol=0.1,
|
|
||||||
abs_tol=10 * MB,
|
|
||||||
):
|
|
||||||
self.logger.debug(
|
|
||||||
f"Moving model '{key}' from {source_device} to"
|
|
||||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
|
||||||
" estimated size may be incorrect. Estimated model size:"
|
|
||||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
|
||||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
class ModelLocker(object):
|
|
||||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
|
||||||
"""
|
|
||||||
:param cache: The model_cache object
|
|
||||||
:param key: The key of the model to lock in GPU
|
|
||||||
:param model: The model to lock
|
|
||||||
:param gpu_load: True if load into gpu
|
|
||||||
:param size_needed: Size of the model to load
|
|
||||||
"""
|
|
||||||
self.gpu_load = gpu_load
|
|
||||||
self.cache = cache
|
|
||||||
self.key = key
|
|
||||||
self.model = model
|
|
||||||
self.size_needed = size_needed
|
|
||||||
self.cache_entry = self.cache._cached_models[self.key]
|
|
||||||
|
|
||||||
def __enter__(self) -> Any:
|
|
||||||
if not hasattr(self.model, "to"):
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
# NOTE that the model has to have the to() method in order for this
|
|
||||||
# code to move it into GPU!
|
|
||||||
if self.gpu_load:
|
|
||||||
self.cache_entry.lock()
|
|
||||||
|
|
||||||
try:
|
|
||||||
if self.cache.lazy_offloading:
|
|
||||||
self.cache._offload_unlocked_models(self.size_needed)
|
|
||||||
|
|
||||||
self.cache._move_model_to_device(self.key, self.cache.execution_device)
|
|
||||||
|
|
||||||
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
|
||||||
self.cache._print_cuda_stats()
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
self.cache_entry.unlock()
|
|
||||||
raise
|
|
||||||
|
|
||||||
# TODO: not fully understand
|
|
||||||
# in the event that the caller wants the model in RAM, we
|
|
||||||
# move it into CPU if it is in GPU and not locked
|
|
||||||
elif self.cache_entry.loaded and not self.cache_entry.locked:
|
|
||||||
self.cache._move_model_to_device(self.key, self.cache.storage_device)
|
|
||||||
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
|
||||||
if not hasattr(self.model, "to"):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.cache_entry.unlock()
|
|
||||||
if not self.cache.lazy_offloading:
|
|
||||||
self.cache._offload_unlocked_models()
|
|
||||||
self.cache._print_cuda_stats()
|
|
||||||
|
|
||||||
# TODO: should it be called untrack_model?
|
|
||||||
def uncache_model(self, cache_id: str):
|
|
||||||
with suppress(ValueError):
|
|
||||||
self._cache_stack.remove(cache_id)
|
|
||||||
self._cached_models.pop(cache_id, None)
|
|
||||||
|
|
||||||
def model_hash(
|
|
||||||
self,
|
|
||||||
model_path: Union[str, Path],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Given the HF repo id or path to a model on disk, returns a unique
|
|
||||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
|
||||||
|
|
||||||
:param model_path: Path to model file/directory on disk.
|
|
||||||
"""
|
|
||||||
return self._local_model_hash(model_path)
|
|
||||||
|
|
||||||
def cache_size(self) -> float:
|
|
||||||
"""Return the current size of the cache, in GB."""
|
|
||||||
return self._cache_size() / GIG
|
|
||||||
|
|
||||||
def _has_cuda(self) -> bool:
|
|
||||||
return self.execution_device.type == "cuda"
|
|
||||||
|
|
||||||
def _print_cuda_stats(self):
|
|
||||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
|
||||||
ram = "%4.2fG" % self.cache_size()
|
|
||||||
|
|
||||||
cached_models = 0
|
|
||||||
loaded_models = 0
|
|
||||||
locked_models = 0
|
|
||||||
for model_info in self._cached_models.values():
|
|
||||||
cached_models += 1
|
|
||||||
if model_info.loaded:
|
|
||||||
loaded_models += 1
|
|
||||||
if model_info.locked:
|
|
||||||
locked_models += 1
|
|
||||||
|
|
||||||
self.logger.debug(
|
|
||||||
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
|
|
||||||
f" {cached_models}/{loaded_models}/{locked_models}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _cache_size(self) -> int:
|
|
||||||
return sum([m.size for m in self._cached_models.values()])
|
|
||||||
|
|
||||||
def _make_cache_room(self, model_size):
|
|
||||||
# calculate how much memory this model will require
|
|
||||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
|
||||||
bytes_needed = model_size
|
|
||||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
|
||||||
current_size = self._cache_size()
|
|
||||||
|
|
||||||
if current_size + bytes_needed > maximum_size:
|
|
||||||
self.logger.debug(
|
|
||||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
|
||||||
f" {(bytes_needed/GIG):.2f} GB"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
|
||||||
|
|
||||||
pos = 0
|
|
||||||
models_cleared = 0
|
|
||||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
|
||||||
model_key = self._cache_stack[pos]
|
|
||||||
cache_entry = self._cached_models[model_key]
|
|
||||||
|
|
||||||
refs = sys.getrefcount(cache_entry.model)
|
|
||||||
|
|
||||||
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
|
||||||
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
|
||||||
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
|
||||||
|
|
||||||
# manualy clear local variable references of just finished function calls
|
|
||||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
|
||||||
if refs > 2:
|
|
||||||
while True:
|
|
||||||
cleared = False
|
|
||||||
for referrer in gc.get_referrers(cache_entry.model):
|
|
||||||
if type(referrer).__name__ == "frame":
|
|
||||||
# RuntimeError: cannot clear an executing frame
|
|
||||||
with suppress(RuntimeError):
|
|
||||||
referrer.clear()
|
|
||||||
cleared = True
|
|
||||||
# break
|
|
||||||
|
|
||||||
# repeat if referrers changes(due to frame clear), else exit loop
|
|
||||||
if cleared:
|
|
||||||
gc.collect()
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
|
||||||
self.logger.debug(
|
|
||||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
|
||||||
f" refs: {refs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Expected refs:
|
|
||||||
# 1 from cache_entry
|
|
||||||
# 1 from getrefcount function
|
|
||||||
# 1 from onnx runtime object
|
|
||||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
|
||||||
self.logger.debug(
|
|
||||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
|
||||||
)
|
|
||||||
current_size -= cache_entry.size
|
|
||||||
models_cleared += 1
|
|
||||||
if self.stats:
|
|
||||||
self.stats.cleared += 1
|
|
||||||
del self._cache_stack[pos]
|
|
||||||
del self._cached_models[model_key]
|
|
||||||
del cache_entry
|
|
||||||
|
|
||||||
else:
|
|
||||||
pos += 1
|
|
||||||
|
|
||||||
if models_cleared > 0:
|
|
||||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
|
||||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
|
||||||
# is high even if no garbage gets collected.)
|
|
||||||
#
|
|
||||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
|
||||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
|
||||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
|
||||||
# collected.
|
|
||||||
#
|
|
||||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
|
||||||
# immediately when their reference count hits 0.
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
|
||||||
|
|
||||||
def _offload_unlocked_models(self, size_needed: int = 0):
|
|
||||||
reserved = self.max_vram_cache_size * GIG
|
|
||||||
vram_in_use = torch.cuda.memory_allocated()
|
|
||||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
|
||||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
|
||||||
if vram_in_use <= reserved:
|
|
||||||
break
|
|
||||||
if not cache_entry.locked and cache_entry.loaded:
|
|
||||||
self._move_model_to_device(model_key, self.storage_device)
|
|
||||||
|
|
||||||
vram_in_use = torch.cuda.memory_allocated()
|
|
||||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
|
||||||
sha = hashlib.sha256()
|
|
||||||
path = Path(model_path)
|
|
||||||
|
|
||||||
hashpath = path / "checksum.sha256"
|
|
||||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
|
||||||
with open(hashpath) as f:
|
|
||||||
hash = f.read()
|
|
||||||
return hash
|
|
||||||
|
|
||||||
self.logger.debug(f"computing hash of model {path.name}")
|
|
||||||
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
|
|
||||||
with open(file, "rb") as f:
|
|
||||||
while chunk := f.read(self.sha_chunksize):
|
|
||||||
sha.update(chunk)
|
|
||||||
hash = sha.hexdigest()
|
|
||||||
with open(hashpath, "w") as f:
|
|
||||||
f.write(hash)
|
|
||||||
return hash
|
|
@ -1,30 +0,0 @@
|
|||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def _no_op(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def skip_torch_weight_init():
|
|
||||||
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
|
|
||||||
to skip weight initialization.
|
|
||||||
|
|
||||||
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
|
|
||||||
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
|
|
||||||
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
|
|
||||||
monkey-patches common torch layers to skip the weight initialization step.
|
|
||||||
"""
|
|
||||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
|
|
||||||
saved_functions = [m.reset_parameters for m in torch_modules]
|
|
||||||
|
|
||||||
try:
|
|
||||||
for torch_module in torch_modules:
|
|
||||||
torch_module.reset_parameters = _no_op
|
|
||||||
|
|
||||||
yield None
|
|
||||||
finally:
|
|
||||||
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
|
|
||||||
torch_module.reset_parameters = saved_function
|
|
File diff suppressed because it is too large
Load Diff
@ -1,140 +0,0 @@
|
|||||||
"""
|
|
||||||
invokeai.backend.model_management.model_merge exports:
|
|
||||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
|
||||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
|
||||||
|
|
||||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
|
||||||
"""
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
from diffusers import logging as dlogging
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
|
|
||||||
|
|
||||||
|
|
||||||
class MergeInterpolationMethod(str, Enum):
|
|
||||||
WeightedSum = "weighted_sum"
|
|
||||||
Sigmoid = "sigmoid"
|
|
||||||
InvSigmoid = "inv_sigmoid"
|
|
||||||
AddDifference = "add_difference"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMerger(object):
|
|
||||||
def __init__(self, manager: ModelManager):
|
|
||||||
self.manager = manager
|
|
||||||
|
|
||||||
def merge_diffusion_models(
|
|
||||||
self,
|
|
||||||
model_paths: List[Path],
|
|
||||||
alpha: float = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> DiffusionPipeline:
|
|
||||||
"""
|
|
||||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
|
||||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
|
||||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
|
||||||
"""
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
verbosity = dlogging.get_verbosity()
|
|
||||||
dlogging.set_verbosity_error()
|
|
||||||
|
|
||||||
pipe = DiffusionPipeline.from_pretrained(
|
|
||||||
model_paths[0],
|
|
||||||
custom_pipeline="checkpoint_merger",
|
|
||||||
)
|
|
||||||
merged_pipe = pipe.merge(
|
|
||||||
pretrained_model_name_or_path_list=model_paths,
|
|
||||||
alpha=alpha,
|
|
||||||
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
|
|
||||||
force=force,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
dlogging.set_verbosity(verbosity)
|
|
||||||
return merged_pipe
|
|
||||||
|
|
||||||
def merge_diffusion_models_and_save(
|
|
||||||
self,
|
|
||||||
model_names: List[str],
|
|
||||||
base_model: Union[BaseModelType, str],
|
|
||||||
merged_model_name: str,
|
|
||||||
alpha: float = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: bool = False,
|
|
||||||
merge_dest_directory: Optional[Path] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
|
||||||
:param base_model: base model (must be the same for all merged models!)
|
|
||||||
:param merged_model_name: name for new model
|
|
||||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
|
||||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
|
||||||
"""
|
|
||||||
model_paths = []
|
|
||||||
config = self.manager.app_config
|
|
||||||
base_model = BaseModelType(base_model)
|
|
||||||
vae = None
|
|
||||||
|
|
||||||
for mod in model_names:
|
|
||||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
|
||||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
|
||||||
assert (
|
|
||||||
info["model_format"] == "diffusers"
|
|
||||||
), f"{mod} is not a diffusers model. It must be optimized before merging"
|
|
||||||
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
|
||||||
assert (
|
|
||||||
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
|
||||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
|
||||||
# pick up the first model's vae
|
|
||||||
if mod == model_names[0]:
|
|
||||||
vae = info.get("vae")
|
|
||||||
model_paths.extend([(config.root_path / info["path"]).as_posix()])
|
|
||||||
|
|
||||||
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
|
||||||
logger.debug(f"interp = {interp}, merge_method={merge_method}")
|
|
||||||
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs)
|
|
||||||
dump_path = (
|
|
||||||
Path(merge_dest_directory)
|
|
||||||
if merge_dest_directory
|
|
||||||
else config.models_path / base_model.value / ModelType.Main.value
|
|
||||||
)
|
|
||||||
dump_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
dump_path = (dump_path / merged_model_name).as_posix()
|
|
||||||
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
|
||||||
attributes = {
|
|
||||||
"path": dump_path,
|
|
||||||
"description": f"Merge of models {', '.join(model_names)}",
|
|
||||||
"model_format": "diffusers",
|
|
||||||
"variant": ModelVariantType.Normal.value,
|
|
||||||
"vae": vae,
|
|
||||||
}
|
|
||||||
return self.manager.add_model(
|
|
||||||
merged_model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Main,
|
|
||||||
model_attributes=attributes,
|
|
||||||
clobber=True,
|
|
||||||
)
|
|
@ -1,664 +0,0 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Dict, Literal, Optional, Union
|
|
||||||
|
|
||||||
import safetensors.torch
|
|
||||||
import torch
|
|
||||||
from diffusers import ConfigMixin, ModelMixin
|
|
||||||
from picklescan.scanner import scan_file_path
|
|
||||||
|
|
||||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
|
||||||
|
|
||||||
from .models import (
|
|
||||||
BaseModelType,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
SilenceWarnings,
|
|
||||||
)
|
|
||||||
from .models.base import read_checkpoint_meta
|
|
||||||
from .util import lora_token_vector_length
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelProbeInfo(object):
|
|
||||||
model_type: ModelType
|
|
||||||
base_type: BaseModelType
|
|
||||||
variant_type: ModelVariantType
|
|
||||||
prediction_type: SchedulerPredictionType
|
|
||||||
upcast_attention: bool
|
|
||||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
|
||||||
image_size: int
|
|
||||||
name: Optional[str] = None
|
|
||||||
description: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ProbeBase(object):
|
|
||||||
"""forward declaration"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProbe(object):
|
|
||||||
PROBES = {
|
|
||||||
"diffusers": {},
|
|
||||||
"checkpoint": {},
|
|
||||||
"onnx": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
CLASS2TYPE = {
|
|
||||||
"StableDiffusionPipeline": ModelType.Main,
|
|
||||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
|
||||||
"StableDiffusionXLPipeline": ModelType.Main,
|
|
||||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
|
||||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
|
||||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
|
||||||
"AutoencoderKL": ModelType.Vae,
|
|
||||||
"AutoencoderTiny": ModelType.Vae,
|
|
||||||
"ControlNetModel": ModelType.ControlNet,
|
|
||||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
|
||||||
"T2IAdapter": ModelType.T2IAdapter,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_probe(
|
|
||||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
|
||||||
):
|
|
||||||
cls.PROBES[format][model_type] = probe_class
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def heuristic_probe(
|
|
||||||
cls,
|
|
||||||
model: Union[Dict, ModelMixin, Path],
|
|
||||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
|
||||||
) -> ModelProbeInfo:
|
|
||||||
if isinstance(model, Path):
|
|
||||||
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
|
|
||||||
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
|
|
||||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
|
||||||
else:
|
|
||||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe(
|
|
||||||
cls,
|
|
||||||
model_path: Path,
|
|
||||||
model: Optional[Union[Dict, ModelMixin]] = None,
|
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
|
||||||
) -> ModelProbeInfo:
|
|
||||||
"""
|
|
||||||
Probe the model at model_path and return sufficient information about it
|
|
||||||
to place it somewhere in the models directory hierarchy. If the model is
|
|
||||||
already loaded into memory, you may provide it as model in order to avoid
|
|
||||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
|
||||||
the path to the model and returns the SchedulerPredictionType.
|
|
||||||
"""
|
|
||||||
if model_path:
|
|
||||||
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
|
||||||
else:
|
|
||||||
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
|
|
||||||
model_info = None
|
|
||||||
try:
|
|
||||||
model_type = (
|
|
||||||
cls.get_model_type_from_folder(model_path, model)
|
|
||||||
if format_type == "diffusers"
|
|
||||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
|
||||||
)
|
|
||||||
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
|
||||||
if not probe_class:
|
|
||||||
return None
|
|
||||||
probe = probe_class(model_path, model, prediction_type_helper)
|
|
||||||
base_type = probe.get_base_type()
|
|
||||||
variant_type = probe.get_variant_type()
|
|
||||||
prediction_type = probe.get_scheduler_prediction_type()
|
|
||||||
name = cls.get_model_name(model_path)
|
|
||||||
description = f"{base_type.value} {model_type.value} model {name}"
|
|
||||||
format = probe.get_format()
|
|
||||||
model_info = ModelProbeInfo(
|
|
||||||
model_type=model_type,
|
|
||||||
base_type=base_type,
|
|
||||||
variant_type=variant_type,
|
|
||||||
prediction_type=prediction_type,
|
|
||||||
name=name,
|
|
||||||
description=description,
|
|
||||||
upcast_attention=(
|
|
||||||
base_type == BaseModelType.StableDiffusion2
|
|
||||||
and prediction_type == SchedulerPredictionType.VPrediction
|
|
||||||
),
|
|
||||||
format=format,
|
|
||||||
image_size=(
|
|
||||||
1024
|
|
||||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
|
||||||
else (
|
|
||||||
768
|
|
||||||
if (
|
|
||||||
base_type == BaseModelType.StableDiffusion2
|
|
||||||
and prediction_type == SchedulerPredictionType.VPrediction
|
|
||||||
)
|
|
||||||
else 512
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return model_info
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_model_name(cls, model_path: Path) -> str:
|
|
||||||
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
|
||||||
return model_path.stem
|
|
||||||
else:
|
|
||||||
return model_path.name
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
|
||||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if model_path.name == "learned_embeds.bin":
|
|
||||||
return ModelType.TextualInversion
|
|
||||||
|
|
||||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
|
||||||
ckpt = ckpt.get("state_dict", ckpt)
|
|
||||||
|
|
||||||
for key in ckpt.keys():
|
|
||||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
|
||||||
return ModelType.Main
|
|
||||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
|
||||||
return ModelType.Vae
|
|
||||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
|
||||||
return ModelType.Lora
|
|
||||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
|
||||||
return ModelType.Lora
|
|
||||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
|
||||||
return ModelType.ControlNet
|
|
||||||
elif key in {"emb_params", "string_to_param"}:
|
|
||||||
return ModelType.TextualInversion
|
|
||||||
|
|
||||||
else:
|
|
||||||
# diffusers-ti
|
|
||||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
|
||||||
return ModelType.TextualInversion
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
|
|
||||||
"""
|
|
||||||
Get the model type of a hugging-face style folder.
|
|
||||||
"""
|
|
||||||
class_name = None
|
|
||||||
error_hint = None
|
|
||||||
if model:
|
|
||||||
class_name = model.__class__.__name__
|
|
||||||
else:
|
|
||||||
for suffix in ["bin", "safetensors"]:
|
|
||||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
|
||||||
return ModelType.TextualInversion
|
|
||||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
|
||||||
return ModelType.Lora
|
|
||||||
if (folder_path / "unet/model.onnx").exists():
|
|
||||||
return ModelType.ONNX
|
|
||||||
if (folder_path / "image_encoder.txt").exists():
|
|
||||||
return ModelType.IPAdapter
|
|
||||||
|
|
||||||
i = folder_path / "model_index.json"
|
|
||||||
c = folder_path / "config.json"
|
|
||||||
config_path = i if i.exists() else c if c.exists() else None
|
|
||||||
|
|
||||||
if config_path:
|
|
||||||
with open(config_path, "r") as file:
|
|
||||||
conf = json.load(file)
|
|
||||||
if "_class_name" in conf:
|
|
||||||
class_name = conf["_class_name"]
|
|
||||||
elif "architectures" in conf:
|
|
||||||
class_name = conf["architectures"][0]
|
|
||||||
else:
|
|
||||||
class_name = None
|
|
||||||
else:
|
|
||||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
|
||||||
|
|
||||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
|
||||||
return type
|
|
||||||
else:
|
|
||||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
|
||||||
|
|
||||||
# give up
|
|
||||||
raise InvalidModelException(
|
|
||||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
|
||||||
with SilenceWarnings():
|
|
||||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
|
||||||
cls._scan_model(model_path, model_path)
|
|
||||||
return torch.load(model_path, map_location="cpu")
|
|
||||||
else:
|
|
||||||
return safetensors.torch.load_file(model_path)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _scan_model(cls, model_name, checkpoint):
|
|
||||||
"""
|
|
||||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
|
||||||
and option to exit if an infected file is identified.
|
|
||||||
"""
|
|
||||||
# scan model
|
|
||||||
scan_result = scan_file_path(checkpoint)
|
|
||||||
if scan_result.infected_files != 0:
|
|
||||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
|
||||||
|
|
||||||
|
|
||||||
# ##################################################3
|
|
||||||
# Checkpoint probing
|
|
||||||
# ##################################################3
|
|
||||||
class ProbeBase(object):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_variant_type(self) -> ModelVariantType:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_format(self) -> str:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointProbeBase(ProbeBase):
|
|
||||||
def __init__(
|
|
||||||
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
|
|
||||||
) -> BaseModelType:
|
|
||||||
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
|
||||||
self.checkpoint_path = checkpoint_path
|
|
||||||
self.helper = helper
|
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_format(self) -> str:
|
|
||||||
return "checkpoint"
|
|
||||||
|
|
||||||
def get_variant_type(self) -> ModelVariantType:
|
|
||||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
|
|
||||||
if model_type != ModelType.Main:
|
|
||||||
return ModelVariantType.Normal
|
|
||||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
|
||||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
||||||
if in_channels == 9:
|
|
||||||
return ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 5:
|
|
||||||
return ModelVariantType.Depth
|
|
||||||
elif in_channels == 4:
|
|
||||||
return ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise InvalidModelException(
|
|
||||||
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
checkpoint = self.checkpoint
|
|
||||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
|
||||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
|
||||||
return BaseModelType.StableDiffusion2
|
|
||||||
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
|
||||||
return BaseModelType.StableDiffusionXLRefiner
|
|
||||||
else:
|
|
||||||
raise InvalidModelException("Cannot determine base type")
|
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
|
||||||
"""Return model prediction type."""
|
|
||||||
# if there is a .yaml associated with this checkpoint, then we do not need
|
|
||||||
# to probe for the prediction type as it will be ignored.
|
|
||||||
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
type = self.get_base_type()
|
|
||||||
if type == BaseModelType.StableDiffusion2:
|
|
||||||
checkpoint = self.checkpoint
|
|
||||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
|
||||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
|
||||||
if "global_step" in checkpoint:
|
|
||||||
if checkpoint["global_step"] == 220000:
|
|
||||||
return SchedulerPredictionType.Epsilon
|
|
||||||
elif checkpoint["global_step"] == 110000:
|
|
||||||
return SchedulerPredictionType.VPrediction
|
|
||||||
if self.helper and self.checkpoint_path:
|
|
||||||
if helper_guess := self.helper(self.checkpoint_path):
|
|
||||||
return helper_guess
|
|
||||||
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
|
|
||||||
|
|
||||||
elif type == BaseModelType.StableDiffusion1:
|
|
||||||
if self.helper and self.checkpoint_path:
|
|
||||||
if helper_guess := self.helper(self.checkpoint_path):
|
|
||||||
return helper_guess
|
|
||||||
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
# I can't find any standalone 2.X VAEs to test with!
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
|
|
||||||
|
|
||||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_format(self) -> str:
|
|
||||||
return "lycoris"
|
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
checkpoint = self.checkpoint
|
|
||||||
token_vector_length = lora_token_vector_length(checkpoint)
|
|
||||||
|
|
||||||
if token_vector_length == 768:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
elif token_vector_length == 1024:
|
|
||||||
return BaseModelType.StableDiffusion2
|
|
||||||
elif token_vector_length == 1280:
|
|
||||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
|
||||||
elif token_vector_length == 2048:
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
else:
|
|
||||||
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_format(self) -> str:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
checkpoint = self.checkpoint
|
|
||||||
if "string_to_token" in checkpoint:
|
|
||||||
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
|
||||||
elif "emb_params" in checkpoint:
|
|
||||||
token_dim = checkpoint["emb_params"].shape[-1]
|
|
||||||
elif "clip_g" in checkpoint:
|
|
||||||
token_dim = checkpoint["clip_g"].shape[-1]
|
|
||||||
else:
|
|
||||||
token_dim = list(checkpoint.values())[0].shape[-1]
|
|
||||||
if token_dim == 768:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
elif token_dim == 1024:
|
|
||||||
return BaseModelType.StableDiffusion2
|
|
||||||
elif token_dim == 1280:
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
checkpoint = self.checkpoint
|
|
||||||
for key_name in (
|
|
||||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
|
||||||
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
|
||||||
):
|
|
||||||
if key_name not in checkpoint:
|
|
||||||
continue
|
|
||||||
if checkpoint[key_name].shape[-1] == 768:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
elif checkpoint[key_name].shape[-1] == 1024:
|
|
||||||
return BaseModelType.StableDiffusion2
|
|
||||||
elif self.checkpoint_path and self.helper:
|
|
||||||
return self.helper(self.checkpoint_path)
|
|
||||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
|
||||||
# classes for probing folders
|
|
||||||
#######################################################
|
|
||||||
class FolderProbeBase(ProbeBase):
|
|
||||||
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
|
|
||||||
self.model = model
|
|
||||||
self.folder_path = folder_path
|
|
||||||
|
|
||||||
def get_variant_type(self) -> ModelVariantType:
|
|
||||||
return ModelVariantType.Normal
|
|
||||||
|
|
||||||
def get_format(self) -> str:
|
|
||||||
return "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
if self.model:
|
|
||||||
unet_conf = self.model.unet.config
|
|
||||||
else:
|
|
||||||
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
|
||||||
unet_conf = json.load(file)
|
|
||||||
if unet_conf["cross_attention_dim"] == 768:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
elif unet_conf["cross_attention_dim"] == 1024:
|
|
||||||
return BaseModelType.StableDiffusion2
|
|
||||||
elif unet_conf["cross_attention_dim"] == 1280:
|
|
||||||
return BaseModelType.StableDiffusionXLRefiner
|
|
||||||
elif unet_conf["cross_attention_dim"] == 2048:
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
else:
|
|
||||||
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
|
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
|
||||||
if self.model:
|
|
||||||
scheduler_conf = self.model.scheduler.config
|
|
||||||
else:
|
|
||||||
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
|
||||||
scheduler_conf = json.load(file)
|
|
||||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
|
||||||
return SchedulerPredictionType.VPrediction
|
|
||||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
|
||||||
return SchedulerPredictionType.Epsilon
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_variant_type(self) -> ModelVariantType:
|
|
||||||
# This only works for pipelines! Any kind of
|
|
||||||
# exception results in our returning the
|
|
||||||
# "normal" variant type
|
|
||||||
try:
|
|
||||||
if self.model:
|
|
||||||
conf = self.model.unet.config
|
|
||||||
else:
|
|
||||||
config_file = self.folder_path / "unet" / "config.json"
|
|
||||||
with open(config_file, "r") as file:
|
|
||||||
conf = json.load(file)
|
|
||||||
|
|
||||||
in_channels = conf["in_channels"]
|
|
||||||
if in_channels == 9:
|
|
||||||
return ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 5:
|
|
||||||
return ModelVariantType.Depth
|
|
||||||
elif in_channels == 4:
|
|
||||||
return ModelVariantType.Normal
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return ModelVariantType.Normal
|
|
||||||
|
|
||||||
|
|
||||||
class VaeFolderProbe(FolderProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
if self._config_looks_like_sdxl():
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
elif self._name_looks_like_sdxl():
|
|
||||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
|
||||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
else:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
|
|
||||||
def _config_looks_like_sdxl(self) -> bool:
|
|
||||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
|
||||||
config_file = self.folder_path / "config.json"
|
|
||||||
if not config_file.exists():
|
|
||||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
|
||||||
with open(config_file, "r") as file:
|
|
||||||
config = json.load(file)
|
|
||||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
|
||||||
|
|
||||||
def _name_looks_like_sdxl(self) -> bool:
|
|
||||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
|
||||||
|
|
||||||
def _guess_name(self) -> str:
|
|
||||||
name = self.folder_path.name
|
|
||||||
if name == "vae":
|
|
||||||
name = self.folder_path.parent.name
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionFolderProbe(FolderProbeBase):
|
|
||||||
def get_format(self) -> str:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
path = self.folder_path / "learned_embeds.bin"
|
|
||||||
if not path.exists():
|
|
||||||
return None
|
|
||||||
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
|
||||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXFolderProbe(FolderProbeBase):
|
|
||||||
def get_format(self) -> str:
|
|
||||||
return "onnx"
|
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
|
|
||||||
def get_variant_type(self) -> ModelVariantType:
|
|
||||||
return ModelVariantType.Normal
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
config_file = self.folder_path / "config.json"
|
|
||||||
if not config_file.exists():
|
|
||||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
|
||||||
with open(config_file, "r") as file:
|
|
||||||
config = json.load(file)
|
|
||||||
# no obvious way to distinguish between sd2-base and sd2-768
|
|
||||||
dimension = config["cross_attention_dim"]
|
|
||||||
base_model = (
|
|
||||||
BaseModelType.StableDiffusion1
|
|
||||||
if dimension == 768
|
|
||||||
else (
|
|
||||||
BaseModelType.StableDiffusion2
|
|
||||||
if dimension == 1024
|
|
||||||
else BaseModelType.StableDiffusionXL
|
|
||||||
if dimension == 2048
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not base_model:
|
|
||||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
|
||||||
return base_model
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAFolderProbe(FolderProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
model_file = None
|
|
||||||
for suffix in ["safetensors", "bin"]:
|
|
||||||
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
|
|
||||||
if base_file.exists():
|
|
||||||
model_file = base_file
|
|
||||||
break
|
|
||||||
if not model_file:
|
|
||||||
raise InvalidModelException("Unknown LoRA format encountered")
|
|
||||||
return LoRACheckpointProbe(model_file, None).get_base_type()
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterFolderProbe(FolderProbeBase):
|
|
||||||
def get_format(self) -> str:
|
|
||||||
return IPAdapterModelFormat.InvokeAI.value
|
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
model_file = self.folder_path / "ip_adapter.bin"
|
|
||||||
if not model_file.exists():
|
|
||||||
raise InvalidModelException("Unknown IP-Adapter model format.")
|
|
||||||
|
|
||||||
state_dict = torch.load(model_file, map_location="cpu")
|
|
||||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
|
||||||
if cross_attention_dim == 768:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
elif cross_attention_dim == 1024:
|
|
||||||
return BaseModelType.StableDiffusion2
|
|
||||||
elif cross_attention_dim == 2048:
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
else:
|
|
||||||
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
return BaseModelType.Any
|
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
config_file = self.folder_path / "config.json"
|
|
||||||
if not config_file.exists():
|
|
||||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
|
||||||
with open(config_file, "r") as file:
|
|
||||||
config = json.load(file)
|
|
||||||
|
|
||||||
adapter_type = config.get("adapter_type", None)
|
|
||||||
if adapter_type == "full_adapter_xl":
|
|
||||||
return BaseModelType.StableDiffusionXL
|
|
||||||
elif adapter_type == "full_adapter" or "light_adapter":
|
|
||||||
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
else:
|
|
||||||
raise InvalidModelException(
|
|
||||||
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
############## register probe classes ######
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
|
||||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
|
||||||
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
|
||||||
|
|
||||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
|
@ -1,112 +0,0 @@
|
|||||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
|
||||||
"""
|
|
||||||
Abstract base class for recursive directory search for models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Set, types
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
|
|
||||||
class ModelSearch(ABC):
|
|
||||||
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
|
|
||||||
"""
|
|
||||||
Initialize a recursive model directory search.
|
|
||||||
:param directories: List of directory Paths to recurse through
|
|
||||||
:param logger: Logger to use
|
|
||||||
"""
|
|
||||||
self.directories = directories
|
|
||||||
self.logger = logger
|
|
||||||
self._items_scanned = 0
|
|
||||||
self._models_found = 0
|
|
||||||
self._scanned_dirs = set()
|
|
||||||
self._scanned_paths = set()
|
|
||||||
self._pruned_paths = set()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def on_search_started(self):
|
|
||||||
"""
|
|
||||||
Called before the scan starts.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def on_model_found(self, model: Path):
|
|
||||||
"""
|
|
||||||
Process a found model. Raise an exception if something goes wrong.
|
|
||||||
:param model: Model to process - could be a directory or checkpoint.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def on_search_completed(self):
|
|
||||||
"""
|
|
||||||
Perform some activity when the scan is completed. May use instance
|
|
||||||
variables, items_scanned and models_found
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def search(self):
|
|
||||||
self.on_search_started()
|
|
||||||
for dir in self.directories:
|
|
||||||
self.walk_directory(dir)
|
|
||||||
self.on_search_completed()
|
|
||||||
|
|
||||||
def walk_directory(self, path: Path):
|
|
||||||
for root, dirs, files in os.walk(path, followlinks=True):
|
|
||||||
if str(Path(root).name).startswith("."):
|
|
||||||
self._pruned_paths.add(root)
|
|
||||||
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._items_scanned += len(dirs) + len(files)
|
|
||||||
for d in dirs:
|
|
||||||
path = Path(root) / d
|
|
||||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
|
||||||
self._scanned_dirs.add(path)
|
|
||||||
continue
|
|
||||||
if any(
|
|
||||||
(path / x).exists()
|
|
||||||
for x in {
|
|
||||||
"config.json",
|
|
||||||
"model_index.json",
|
|
||||||
"learned_embeds.bin",
|
|
||||||
"pytorch_lora_weights.bin",
|
|
||||||
"image_encoder.txt",
|
|
||||||
}
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
self.on_model_found(path)
|
|
||||||
self._models_found += 1
|
|
||||||
self._scanned_dirs.add(path)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
path = Path(root) / f
|
|
||||||
if path.parent in self._scanned_dirs:
|
|
||||||
continue
|
|
||||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
|
||||||
try:
|
|
||||||
self.on_model_found(path)
|
|
||||||
self._models_found += 1
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class FindModels(ModelSearch):
|
|
||||||
def on_search_started(self):
|
|
||||||
self.models_found: Set[Path] = set()
|
|
||||||
|
|
||||||
def on_model_found(self, model: Path):
|
|
||||||
self.models_found.add(model)
|
|
||||||
|
|
||||||
def on_search_completed(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def list_models(self) -> List[Path]:
|
|
||||||
self.search()
|
|
||||||
return list(self.models_found)
|
|
@ -1,167 +0,0 @@
|
|||||||
import inspect
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Literal, get_origin
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, create_model
|
|
||||||
|
|
||||||
from .base import ( # noqa: F401
|
|
||||||
BaseModelType,
|
|
||||||
DuplicateModelException,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelError,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
SilenceWarnings,
|
|
||||||
SubModelType,
|
|
||||||
)
|
|
||||||
from .clip_vision import CLIPVisionModel
|
|
||||||
from .controlnet import ControlNetModel # TODO:
|
|
||||||
from .ip_adapter import IPAdapterModel
|
|
||||||
from .lora import LoRAModel
|
|
||||||
from .sdxl import StableDiffusionXLModel
|
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
|
||||||
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
|
|
||||||
from .t2i_adapter import T2IAdapterModel
|
|
||||||
from .textual_inversion import TextualInversionModel
|
|
||||||
from .vae import VaeModel
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
|
||||||
BaseModelType.StableDiffusion1: {
|
|
||||||
ModelType.ONNX: ONNXStableDiffusion1Model,
|
|
||||||
ModelType.Main: StableDiffusion1Model,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
ModelType.Lora: LoRAModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
|
||||||
ModelType.T2IAdapter: T2IAdapterModel,
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusion2: {
|
|
||||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
|
||||||
ModelType.Main: StableDiffusion2Model,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
ModelType.Lora: LoRAModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
|
||||||
ModelType.T2IAdapter: T2IAdapterModel,
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusionXL: {
|
|
||||||
ModelType.Main: StableDiffusionXLModel,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
# will not work until support written
|
|
||||||
ModelType.Lora: LoRAModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
|
||||||
ModelType.T2IAdapter: T2IAdapterModel,
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
|
||||||
ModelType.Main: StableDiffusionXLModel,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
# will not work until support written
|
|
||||||
ModelType.Lora: LoRAModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
|
||||||
ModelType.T2IAdapter: T2IAdapterModel,
|
|
||||||
},
|
|
||||||
BaseModelType.Any: {
|
|
||||||
ModelType.CLIPVision: CLIPVisionModel,
|
|
||||||
# The following model types are not expected to be used with BaseModelType.Any.
|
|
||||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
|
||||||
ModelType.Main: StableDiffusion2Model,
|
|
||||||
ModelType.Vae: VaeModel,
|
|
||||||
ModelType.Lora: LoRAModel,
|
|
||||||
ModelType.ControlNet: ControlNetModel,
|
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
ModelType.IPAdapter: IPAdapterModel,
|
|
||||||
ModelType.T2IAdapter: T2IAdapterModel,
|
|
||||||
},
|
|
||||||
# BaseModelType.Kandinsky2_1: {
|
|
||||||
# ModelType.Main: Kandinsky2_1Model,
|
|
||||||
# ModelType.MoVQ: MoVQModel,
|
|
||||||
# ModelType.Lora: LoRAModel,
|
|
||||||
# ModelType.ControlNet: ControlNetModel,
|
|
||||||
# ModelType.TextualInversion: TextualInversionModel,
|
|
||||||
# },
|
|
||||||
}
|
|
||||||
|
|
||||||
MODEL_CONFIGS = []
|
|
||||||
OPENAPI_MODEL_CONFIGS = []
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIModelInfoBase(BaseModel):
|
|
||||||
model_name: str
|
|
||||||
base_model: BaseModelType
|
|
||||||
model_type: ModelType
|
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
|
|
||||||
for _base_model, models in MODEL_CLASSES.items():
|
|
||||||
for model_type, model_class in models.items():
|
|
||||||
model_configs = set(model_class._get_configs().values())
|
|
||||||
model_configs.discard(None)
|
|
||||||
MODEL_CONFIGS.extend(model_configs)
|
|
||||||
|
|
||||||
# LS: sort to get the checkpoint configs first, which makes
|
|
||||||
# for a better template in the Swagger docs
|
|
||||||
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
|
||||||
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
|
|
||||||
openapi_cfg_name = model_name + cfg_name
|
|
||||||
if openapi_cfg_name in vars():
|
|
||||||
continue
|
|
||||||
|
|
||||||
api_wrapper = create_model(
|
|
||||||
openapi_cfg_name,
|
|
||||||
__base__=(cfg, OpenAPIModelInfoBase),
|
|
||||||
model_type=(Literal[model_type], model_type), # type: ignore
|
|
||||||
)
|
|
||||||
vars()[openapi_cfg_name] = api_wrapper
|
|
||||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_config_enums():
|
|
||||||
enums = []
|
|
||||||
|
|
||||||
for model_config in MODEL_CONFIGS:
|
|
||||||
if hasattr(inspect, "get_annotations"):
|
|
||||||
fields = inspect.get_annotations(model_config)
|
|
||||||
else:
|
|
||||||
fields = model_config.__annotations__
|
|
||||||
try:
|
|
||||||
field = fields["model_format"]
|
|
||||||
except Exception:
|
|
||||||
raise Exception("format field not found")
|
|
||||||
|
|
||||||
# model_format: None
|
|
||||||
# model_format: SomeModelFormat
|
|
||||||
# model_format: Literal[SomeModelFormat.Diffusers]
|
|
||||||
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
|
||||||
|
|
||||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
|
||||||
enums.append(field)
|
|
||||||
|
|
||||||
elif get_origin(field) is Literal and all(
|
|
||||||
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
|
||||||
):
|
|
||||||
enums.append(type(field.__args__[0]))
|
|
||||||
|
|
||||||
elif field is None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
|
||||||
|
|
||||||
return enums
|
|
@ -1,681 +0,0 @@
|
|||||||
import inspect
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import typing
|
|
||||||
import warnings
|
|
||||||
from abc import ABCMeta, abstractmethod
|
|
||||||
from contextlib import suppress
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import onnx
|
|
||||||
import safetensors.torch
|
|
||||||
import torch
|
|
||||||
from diffusers import ConfigMixin, DiffusionPipeline
|
|
||||||
from diffusers import logging as diffusers_logging
|
|
||||||
from onnx import numpy_helper
|
|
||||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
|
||||||
from picklescan.scanner import scan_file_path
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
|
||||||
from transformers import logging as transformers_logging
|
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidModelException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelNotFoundException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
|
||||||
Any = "any" # For models that are not associated with any particular base model.
|
|
||||||
StableDiffusion1 = "sd-1"
|
|
||||||
StableDiffusion2 = "sd-2"
|
|
||||||
StableDiffusionXL = "sdxl"
|
|
||||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
|
||||||
# Kandinsky2_1 = "kandinsky-2.1"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
|
||||||
ONNX = "onnx"
|
|
||||||
Main = "main"
|
|
||||||
Vae = "vae"
|
|
||||||
Lora = "lora"
|
|
||||||
ControlNet = "controlnet" # used by model_probe
|
|
||||||
TextualInversion = "embedding"
|
|
||||||
IPAdapter = "ip_adapter"
|
|
||||||
CLIPVision = "clip_vision"
|
|
||||||
T2IAdapter = "t2i_adapter"
|
|
||||||
|
|
||||||
|
|
||||||
class SubModelType(str, Enum):
|
|
||||||
UNet = "unet"
|
|
||||||
TextEncoder = "text_encoder"
|
|
||||||
TextEncoder2 = "text_encoder_2"
|
|
||||||
Tokenizer = "tokenizer"
|
|
||||||
Tokenizer2 = "tokenizer_2"
|
|
||||||
Vae = "vae"
|
|
||||||
VaeDecoder = "vae_decoder"
|
|
||||||
VaeEncoder = "vae_encoder"
|
|
||||||
Scheduler = "scheduler"
|
|
||||||
SafetyChecker = "safety_checker"
|
|
||||||
# MoVQ = "movq"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelVariantType(str, Enum):
|
|
||||||
Normal = "normal"
|
|
||||||
Inpaint = "inpaint"
|
|
||||||
Depth = "depth"
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerPredictionType(str, Enum):
|
|
||||||
Epsilon = "epsilon"
|
|
||||||
VPrediction = "v_prediction"
|
|
||||||
Sample = "sample"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelError(str, Enum):
|
|
||||||
NotFound = "not_found"
|
|
||||||
|
|
||||||
|
|
||||||
def model_config_json_schema_extra(schema: dict[str, Any]) -> None:
|
|
||||||
if "required" not in schema:
|
|
||||||
schema["required"] = []
|
|
||||||
schema["required"].append("model_type")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
|
||||||
path: str # or Path
|
|
||||||
description: Optional[str] = Field(None)
|
|
||||||
model_format: Optional[str] = Field(None)
|
|
||||||
error: Optional[ModelError] = Field(None)
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
|
||||||
use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyConfigLoader(ConfigMixin):
|
|
||||||
@classmethod
|
|
||||||
def load_config(cls, *args, **kwargs):
|
|
||||||
cls.config_name = kwargs.pop("config_name")
|
|
||||||
return super().load_config(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
T_co = TypeVar("T_co", covariant=True)
|
|
||||||
|
|
||||||
|
|
||||||
class classproperty(Generic[T_co]):
|
|
||||||
def __init__(self, fget: Callable[[Any], T_co]) -> None:
|
|
||||||
self.fget = fget
|
|
||||||
|
|
||||||
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
|
|
||||||
return self.fget(owner)
|
|
||||||
|
|
||||||
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
|
||||||
raise AttributeError("cannot set attribute")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(metaclass=ABCMeta):
|
|
||||||
# model_path: str
|
|
||||||
# base_model: BaseModelType
|
|
||||||
# model_type: ModelType
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_path: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
):
|
|
||||||
self.model_path = model_path
|
|
||||||
self.base_model = base_model
|
|
||||||
self.model_type = model_type
|
|
||||||
|
|
||||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
|
||||||
if len(subtypes) < 2:
|
|
||||||
raise Exception("Invalid subfolder definition!")
|
|
||||||
if all(t is None for t in subtypes):
|
|
||||||
return None
|
|
||||||
elif any(t is None for t in subtypes):
|
|
||||||
raise Exception(f"Unsupported definition: {subtypes}")
|
|
||||||
|
|
||||||
if subtypes[0] in ["diffusers", "transformers"]:
|
|
||||||
res_type = sys.modules[subtypes[0]]
|
|
||||||
subtypes = subtypes[1:]
|
|
||||||
|
|
||||||
else:
|
|
||||||
res_type = sys.modules["diffusers"]
|
|
||||||
res_type = res_type.pipelines
|
|
||||||
|
|
||||||
for subtype in subtypes:
|
|
||||||
res_type = getattr(res_type, subtype)
|
|
||||||
return res_type
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_configs(cls):
|
|
||||||
with suppress(Exception):
|
|
||||||
return cls.__configs
|
|
||||||
|
|
||||||
configs = {}
|
|
||||||
for name in dir(cls):
|
|
||||||
if name.startswith("__"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
value = getattr(cls, name)
|
|
||||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if hasattr(inspect, "get_annotations"):
|
|
||||||
fields = inspect.get_annotations(value)
|
|
||||||
else:
|
|
||||||
fields = value.__annotations__
|
|
||||||
try:
|
|
||||||
field = fields["model_format"]
|
|
||||||
except Exception:
|
|
||||||
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
|
||||||
|
|
||||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
|
||||||
for model_format in field:
|
|
||||||
configs[model_format.value] = value
|
|
||||||
|
|
||||||
elif typing.get_origin(field) is Literal and all(
|
|
||||||
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
|
||||||
):
|
|
||||||
for model_format in field.__args__:
|
|
||||||
configs[model_format.value] = value
|
|
||||||
|
|
||||||
elif field is None:
|
|
||||||
configs[None] = value
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
|
|
||||||
|
|
||||||
cls.__configs = configs
|
|
||||||
return cls.__configs
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
|
||||||
if "model_format" not in kwargs:
|
|
||||||
raise Exception("Field 'model_format' not found in model config")
|
|
||||||
|
|
||||||
configs = cls._get_configs()
|
|
||||||
return configs[kwargs["model_format"]](**kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
|
||||||
return cls.create_config(
|
|
||||||
path=path,
|
|
||||||
model_format=cls.detect_format(path),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@abstractmethod
|
|
||||||
def detect_format(cls, path: str) -> str:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
@abstractmethod
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
) -> Any:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusersModel(ModelBase):
|
|
||||||
# child_types: Dict[str, Type]
|
|
||||||
# child_sizes: Dict[str, int]
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.child_types: Dict[str, Type] = {}
|
|
||||||
self.child_sizes: Dict[str, int] = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
|
||||||
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
|
||||||
|
|
||||||
config_data.pop("_ignore_files", None)
|
|
||||||
|
|
||||||
# retrieve all folder_names that contain relevant files
|
|
||||||
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
|
|
||||||
|
|
||||||
for child_name in child_components:
|
|
||||||
child_type = self._hf_definition_to_type(config_data[child_name])
|
|
||||||
self.child_types[child_name] = child_type
|
|
||||||
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
|
||||||
if child_type is None:
|
|
||||||
return sum(self.child_sizes.values())
|
|
||||||
else:
|
|
||||||
return self.child_sizes[child_type]
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
):
|
|
||||||
# return pipeline in different function to pass more arguments
|
|
||||||
if child_type is None:
|
|
||||||
raise Exception("Child model type can't be null on diffusers model")
|
|
||||||
if child_type not in self.child_types:
|
|
||||||
return None # TODO: or raise
|
|
||||||
|
|
||||||
if torch_dtype == torch.float16:
|
|
||||||
variants = ["fp16", None]
|
|
||||||
else:
|
|
||||||
variants = [None, "fp16"]
|
|
||||||
|
|
||||||
# TODO: better error handling(differentiate not found from others)
|
|
||||||
for variant in variants:
|
|
||||||
try:
|
|
||||||
# TODO: set cache_dir to /dev/null to be sure that cache not used?
|
|
||||||
model = self.child_types[child_type].from_pretrained(
|
|
||||||
self.model_path,
|
|
||||||
subfolder=child_type.value,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
variant=variant,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if not str(e).startswith("Error no file"):
|
|
||||||
print("====ERR LOAD====")
|
|
||||||
print(f"{variant}: {e}")
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
|
||||||
|
|
||||||
# calc more accurate size
|
|
||||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
|
||||||
|
|
||||||
|
|
||||||
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
|
|
||||||
if subfolder is not None:
|
|
||||||
model_path = os.path.join(model_path, subfolder)
|
|
||||||
|
|
||||||
# this can happen when, for example, the safety checker
|
|
||||||
# is not downloaded.
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
all_files = os.listdir(model_path)
|
|
||||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
|
||||||
|
|
||||||
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
|
|
||||||
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
|
|
||||||
other_files = set(all_files) - fp16_files - bit8_files
|
|
||||||
|
|
||||||
if variant is None:
|
|
||||||
files = other_files
|
|
||||||
elif variant == "fp16":
|
|
||||||
files = fp16_files
|
|
||||||
elif variant == "8bit":
|
|
||||||
files = bit8_files
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown variant: {variant}")
|
|
||||||
|
|
||||||
# try read from index if exists
|
|
||||||
index_postfix = ".index.json"
|
|
||||||
if variant is not None:
|
|
||||||
index_postfix = f".index.{variant}.json"
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if not file.endswith(index_postfix):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
with open(os.path.join(model_path, file), "r") as f:
|
|
||||||
index_data = json.loads(f.read())
|
|
||||||
return int(index_data["metadata"]["total_size"])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# calculate files size if there is no index file
|
|
||||||
formats = [
|
|
||||||
(".safetensors",), # safetensors
|
|
||||||
(".bin",), # torch
|
|
||||||
(".onnx", ".pb"), # onnx
|
|
||||||
(".msgpack",), # flax
|
|
||||||
(".ckpt",), # tf
|
|
||||||
(".h5",), # tf2
|
|
||||||
]
|
|
||||||
|
|
||||||
for file_format in formats:
|
|
||||||
model_files = [f for f in files if f.endswith(file_format)]
|
|
||||||
if len(model_files) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
model_size = 0
|
|
||||||
for model_file in model_files:
|
|
||||||
file_stats = os.stat(os.path.join(model_path, model_file))
|
|
||||||
model_size += file_stats.st_size
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
|
||||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
|
||||||
|
|
||||||
|
|
||||||
def calc_model_size_by_data(model) -> int:
|
|
||||||
if isinstance(model, DiffusionPipeline):
|
|
||||||
return _calc_pipeline_by_data(model)
|
|
||||||
elif isinstance(model, torch.nn.Module):
|
|
||||||
return _calc_model_by_data(model)
|
|
||||||
elif isinstance(model, IAIOnnxRuntimeModel):
|
|
||||||
return _calc_onnx_model_by_data(model)
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_pipeline_by_data(pipeline) -> int:
|
|
||||||
res = 0
|
|
||||||
for submodel_key in pipeline.components.keys():
|
|
||||||
submodel = getattr(pipeline, submodel_key)
|
|
||||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
|
||||||
res += _calc_model_by_data(submodel)
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_model_by_data(model) -> int:
|
|
||||||
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
|
|
||||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
|
|
||||||
mem = mem_params + mem_bufs # in bytes
|
|
||||||
return mem
|
|
||||||
|
|
||||||
|
|
||||||
def _calc_onnx_model_by_data(model) -> int:
|
|
||||||
tensor_size = model.tensors.size() * 2 # The session doubles this
|
|
||||||
mem = tensor_size # in bytes
|
|
||||||
return mem
|
|
||||||
|
|
||||||
|
|
||||||
def _fast_safetensors_reader(path: str):
|
|
||||||
checkpoint = {}
|
|
||||||
device = torch.device("meta")
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
definition_len = int.from_bytes(f.read(8), "little")
|
|
||||||
definition_json = f.read(definition_len)
|
|
||||||
definition = json.loads(definition_json)
|
|
||||||
|
|
||||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
|
|
||||||
"pt",
|
|
||||||
"torch",
|
|
||||||
"pytorch",
|
|
||||||
}:
|
|
||||||
raise Exception("Supported only pytorch safetensors files")
|
|
||||||
definition.pop("__metadata__", None)
|
|
||||||
|
|
||||||
for key, info in definition.items():
|
|
||||||
dtype = {
|
|
||||||
"I8": torch.int8,
|
|
||||||
"I16": torch.int16,
|
|
||||||
"I32": torch.int32,
|
|
||||||
"I64": torch.int64,
|
|
||||||
"F16": torch.float16,
|
|
||||||
"F32": torch.float32,
|
|
||||||
"F64": torch.float64,
|
|
||||||
}[info["dtype"]]
|
|
||||||
|
|
||||||
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
|
||||||
|
|
||||||
return checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
|
||||||
if str(path).endswith(".safetensors"):
|
|
||||||
try:
|
|
||||||
checkpoint = _fast_safetensors_reader(path)
|
|
||||||
except Exception:
|
|
||||||
# TODO: create issue for support "meta"?
|
|
||||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
|
||||||
else:
|
|
||||||
if scan:
|
|
||||||
scan_result = scan_file_path(path)
|
|
||||||
if scan_result.infected_files != 0:
|
|
||||||
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
|
||||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
|
||||||
return checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
class SilenceWarnings(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
|
||||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
transformers_logging.set_verbosity_error()
|
|
||||||
diffusers_logging.set_verbosity_error()
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
|
||||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
|
||||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
|
||||||
warnings.simplefilter("default")
|
|
||||||
|
|
||||||
|
|
||||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
|
||||||
|
|
||||||
|
|
||||||
class IAIOnnxRuntimeModel:
|
|
||||||
class _tensor_access:
|
|
||||||
def __init__(self, model):
|
|
||||||
self.model = model
|
|
||||||
self.indexes = {}
|
|
||||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
|
||||||
self.indexes[obj.name] = idx
|
|
||||||
|
|
||||||
def __getitem__(self, key: str):
|
|
||||||
value = self.model.proto.graph.initializer[self.indexes[key]]
|
|
||||||
return numpy_helper.to_array(value)
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: np.ndarray):
|
|
||||||
new_node = numpy_helper.from_array(value)
|
|
||||||
# set_external_data(new_node, location="in-memory-location")
|
|
||||||
new_node.name = key
|
|
||||||
# new_node.ClearField("raw_data")
|
|
||||||
del self.model.proto.graph.initializer[self.indexes[key]]
|
|
||||||
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
|
||||||
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
|
||||||
|
|
||||||
# __delitem__
|
|
||||||
|
|
||||||
def __contains__(self, key: str):
|
|
||||||
return self.indexes[key] in self.model.proto.graph.initializer
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
raise NotImplementedError("tensor.items")
|
|
||||||
# return [(obj.name, obj) for obj in self.raw_proto]
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
return self.indexes.keys()
|
|
||||||
|
|
||||||
def values(self):
|
|
||||||
raise NotImplementedError("tensor.values")
|
|
||||||
# return [obj for obj in self.raw_proto]
|
|
||||||
|
|
||||||
def size(self):
|
|
||||||
bytesSum = 0
|
|
||||||
for node in self.model.proto.graph.initializer:
|
|
||||||
bytesSum += sys.getsizeof(node.raw_data)
|
|
||||||
return bytesSum
|
|
||||||
|
|
||||||
class _access_helper:
|
|
||||||
def __init__(self, raw_proto):
|
|
||||||
self.indexes = {}
|
|
||||||
self.raw_proto = raw_proto
|
|
||||||
for idx, obj in enumerate(raw_proto):
|
|
||||||
self.indexes[obj.name] = idx
|
|
||||||
|
|
||||||
def __getitem__(self, key: str):
|
|
||||||
return self.raw_proto[self.indexes[key]]
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, value):
|
|
||||||
index = self.indexes[key]
|
|
||||||
del self.raw_proto[index]
|
|
||||||
self.raw_proto.insert(index, value)
|
|
||||||
|
|
||||||
# __delitem__
|
|
||||||
|
|
||||||
def __contains__(self, key: str):
|
|
||||||
return key in self.indexes
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
return [(obj.name, obj) for obj in self.raw_proto]
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
return self.indexes.keys()
|
|
||||||
|
|
||||||
def values(self):
|
|
||||||
return list(self.raw_proto)
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, provider: Optional[str]):
|
|
||||||
self.path = model_path
|
|
||||||
self.session = None
|
|
||||||
self.provider = provider
|
|
||||||
"""
|
|
||||||
self.data_path = self.path + "_data"
|
|
||||||
if not os.path.exists(self.data_path):
|
|
||||||
print(f"Moving model tensors to separate file: {self.data_path}")
|
|
||||||
tmp_proto = onnx.load(model_path, load_external_data=True)
|
|
||||||
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
|
|
||||||
del tmp_proto
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
self.proto = onnx.load(model_path, load_external_data=False)
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.proto = onnx.load(model_path, load_external_data=True)
|
|
||||||
# self.data = dict()
|
|
||||||
# for tensor in self.proto.graph.initializer:
|
|
||||||
# name = tensor.name
|
|
||||||
|
|
||||||
# if tensor.HasField("raw_data"):
|
|
||||||
# npt = numpy_helper.to_array(tensor)
|
|
||||||
# orv = OrtValue.ortvalue_from_numpy(npt)
|
|
||||||
# # self.data[name] = orv
|
|
||||||
# # set_external_data(tensor, location="in-memory-location")
|
|
||||||
# tensor.name = name
|
|
||||||
# # tensor.ClearField("raw_data")
|
|
||||||
|
|
||||||
self.nodes = self._access_helper(self.proto.graph.node)
|
|
||||||
# self.initializers = self._access_helper(self.proto.graph.initializer)
|
|
||||||
# print(self.proto.graph.input)
|
|
||||||
# print(self.proto.graph.initializer)
|
|
||||||
|
|
||||||
self.tensors = self._tensor_access(self)
|
|
||||||
|
|
||||||
# TODO: integrate with model manager/cache
|
|
||||||
def create_session(self, height=None, width=None):
|
|
||||||
if self.session is None or self.session_width != width or self.session_height != height:
|
|
||||||
# onnx.save(self.proto, "tmp.onnx")
|
|
||||||
# onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
|
||||||
# TODO: something to be able to get weight when they already moved outside of model proto
|
|
||||||
# (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
|
|
||||||
sess = SessionOptions()
|
|
||||||
# self._external_data.update(**external_data)
|
|
||||||
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
|
|
||||||
# sess.enable_profiling = True
|
|
||||||
|
|
||||||
# sess.intra_op_num_threads = 1
|
|
||||||
# sess.inter_op_num_threads = 1
|
|
||||||
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
|
||||||
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
||||||
# sess.enable_cpu_mem_arena = True
|
|
||||||
# sess.enable_mem_pattern = True
|
|
||||||
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
|
||||||
self.session_height = height
|
|
||||||
self.session_width = width
|
|
||||||
if height and width:
|
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
|
||||||
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
|
||||||
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
|
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
|
|
||||||
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
|
||||||
providers = []
|
|
||||||
if self.provider:
|
|
||||||
providers.append(self.provider)
|
|
||||||
else:
|
|
||||||
providers = get_available_providers()
|
|
||||||
if "TensorrtExecutionProvider" in providers:
|
|
||||||
providers.remove("TensorrtExecutionProvider")
|
|
||||||
try:
|
|
||||||
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
# self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
|
|
||||||
# self.io_binding = self.session.io_binding()
|
|
||||||
|
|
||||||
def release_session(self):
|
|
||||||
self.session = None
|
|
||||||
import gc
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
return
|
|
||||||
|
|
||||||
def __call__(self, **kwargs):
|
|
||||||
if self.session is None:
|
|
||||||
raise Exception("You should call create_session before running model")
|
|
||||||
|
|
||||||
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
|
||||||
# output_names = self.session.get_outputs()
|
|
||||||
# for k in inputs:
|
|
||||||
# self.io_binding.bind_cpu_input(k, inputs[k])
|
|
||||||
# for name in output_names:
|
|
||||||
# self.io_binding.bind_output(name.name)
|
|
||||||
# self.session.run_with_iobinding(self.io_binding, None)
|
|
||||||
# return self.io_binding.copy_outputs_to_cpu()
|
|
||||||
return self.session.run(None, inputs)
|
|
||||||
|
|
||||||
# compatability with diffusers load code
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(
|
|
||||||
cls,
|
|
||||||
model_id: Union[str, Path],
|
|
||||||
subfolder: Union[str, Path] = None,
|
|
||||||
file_name: Optional[str] = None,
|
|
||||||
provider: Optional[str] = None,
|
|
||||||
sess_options: Optional["SessionOptions"] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
file_name = file_name or ONNX_WEIGHTS_NAME
|
|
||||||
|
|
||||||
if os.path.isdir(model_id):
|
|
||||||
model_path = model_id
|
|
||||||
if subfolder is not None:
|
|
||||||
model_path = os.path.join(model_path, subfolder)
|
|
||||||
model_path = os.path.join(model_path, file_name)
|
|
||||||
|
|
||||||
else:
|
|
||||||
model_path = model_id
|
|
||||||
|
|
||||||
# load model from local directory
|
|
||||||
if not os.path.isfile(model_path):
|
|
||||||
raise Exception(f"Model not found: {model_path}")
|
|
||||||
|
|
||||||
# TODO: session options
|
|
||||||
return cls(model_path, provider=provider)
|
|
@ -1,82 +0,0 @@
|
|||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import CLIPVisionModelWithProjection
|
|
||||||
|
|
||||||
from invokeai.backend.model_management.models.base import (
|
|
||||||
BaseModelType,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
calc_model_size_by_data,
|
|
||||||
calc_model_size_by_fs,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionModelFormat(str, Enum):
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionModel(ModelBase):
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[CLIPVisionModelFormat.Diffusers]
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.CLIPVision
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, path: str) -> str:
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
|
|
||||||
|
|
||||||
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
|
|
||||||
return CLIPVisionModelFormat.Diffusers
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
|
||||||
if child_type is not None:
|
|
||||||
raise ValueError("There are no child models in a CLIP Vision model.")
|
|
||||||
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
) -> CLIPVisionModelWithProjection:
|
|
||||||
if child_type is not None:
|
|
||||||
raise ValueError("There are no child models in a CLIP Vision model.")
|
|
||||||
|
|
||||||
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
|
|
||||||
|
|
||||||
# Calculate a more accurate model size.
|
|
||||||
self.model_size = calc_model_size_by_data(model)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
format = cls.detect_format(model_path)
|
|
||||||
if format == CLIPVisionModelFormat.Diffusers:
|
|
||||||
return model_path
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported format: '{format}'.")
|
|
@ -1,162 +0,0 @@
|
|||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
from .base import (
|
|
||||||
BaseModelType,
|
|
||||||
EmptyConfigLoader,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
calc_model_size_by_data,
|
|
||||||
calc_model_size_by_fs,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetModelFormat(str, Enum):
|
|
||||||
Checkpoint = "checkpoint"
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetModel(ModelBase):
|
|
||||||
# model_class: Type
|
|
||||||
# model_size: int
|
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[ControlNetModelFormat.Diffusers]
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[ControlNetModelFormat.Checkpoint]
|
|
||||||
config: str
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.ControlNet
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
try:
|
|
||||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
|
||||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
|
|
||||||
|
|
||||||
model_class_name = config.get("_class_name", None)
|
|
||||||
if model_class_name not in {"ControlNetModel"}:
|
|
||||||
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Invalid ControlNet model!")
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in controlnet model")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There are no child models in controlnet model")
|
|
||||||
|
|
||||||
model = None
|
|
||||||
for variant in ["fp16", None]:
|
|
||||||
try:
|
|
||||||
model = self.model_class.from_pretrained(
|
|
||||||
self.model_path,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
variant=variant,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if not model:
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
# calc more accurate size
|
|
||||||
self.model_size = calc_model_size_by_data(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, path: str):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
if os.path.isdir(path):
|
|
||||||
if os.path.exists(os.path.join(path, "config.json")):
|
|
||||||
return ControlNetModelFormat.Diffusers
|
|
||||||
|
|
||||||
if os.path.isfile(path):
|
|
||||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
|
|
||||||
return ControlNetModelFormat.Checkpoint
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
|
|
||||||
return _convert_controlnet_ckpt_and_cache(
|
|
||||||
model_path=model_path,
|
|
||||||
model_config=config.config,
|
|
||||||
output_path=output_path,
|
|
||||||
base_model=base_model,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_controlnet_ckpt_and_cache(
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_config: str,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Convert the controlnet from checkpoint format to diffusers format,
|
|
||||||
cache it to disk, and return Path to converted
|
|
||||||
file. If already on disk then just returns Path.
|
|
||||||
"""
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
weights = app_config.root_path / model_path
|
|
||||||
output_path = Path(output_path)
|
|
||||||
|
|
||||||
logger.info(f"Converting {weights} to diffusers format")
|
|
||||||
# return cached version if it exists
|
|
||||||
if output_path.exists():
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
# to avoid circular import errors
|
|
||||||
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
|
||||||
|
|
||||||
convert_controlnet_to_diffusers(
|
|
||||||
weights,
|
|
||||||
output_path,
|
|
||||||
original_config_file=app_config.root_path / model_config,
|
|
||||||
image_size=512,
|
|
||||||
scan_needed=True,
|
|
||||||
from_safetensors=weights.suffix == ".safetensors",
|
|
||||||
)
|
|
||||||
return output_path
|
|
@ -1,98 +0,0 @@
|
|||||||
import os
|
|
||||||
import typing
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
|
|
||||||
from invokeai.backend.model_management.models.base import (
|
|
||||||
BaseModelType,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
calc_model_size_by_fs,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterModelFormat(str, Enum):
|
|
||||||
# The custom IP-Adapter model format defined by InvokeAI.
|
|
||||||
InvokeAI = "invokeai"
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterModel(ModelBase):
|
|
||||||
class InvokeAIConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[IPAdapterModelFormat.InvokeAI]
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.IPAdapter
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, path: str) -> str:
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
|
|
||||||
|
|
||||||
if os.path.isdir(path):
|
|
||||||
model_file = os.path.join(path, "ip_adapter.bin")
|
|
||||||
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
|
|
||||||
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
|
|
||||||
return IPAdapterModelFormat.InvokeAI
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
|
|
||||||
if child_type is not None:
|
|
||||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
|
||||||
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: torch.dtype,
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
) -> typing.Union[IPAdapter, IPAdapterPlus]:
|
|
||||||
if child_type is not None:
|
|
||||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
|
||||||
|
|
||||||
model = build_ip_adapter(
|
|
||||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"),
|
|
||||||
device=torch.device("cpu"),
|
|
||||||
dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_size = model.calc_size()
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
format = cls.detect_format(model_path)
|
|
||||||
if format == IPAdapterModelFormat.InvokeAI:
|
|
||||||
return model_path
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported format: '{format}'.")
|
|
||||||
|
|
||||||
|
|
||||||
def get_ip_adapter_image_encoder_model_id(model_path: str):
|
|
||||||
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
|
|
||||||
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
|
|
||||||
|
|
||||||
with open(image_encoder_config_file, "r") as f:
|
|
||||||
image_encoder_model = f.readline().strip()
|
|
||||||
|
|
||||||
return image_encoder_model
|
|
@ -1,696 +0,0 @@
|
|||||||
import bisect
|
|
||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
|
|
||||||
from .base import (
|
|
||||||
BaseModelType,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelFormat(str, Enum):
|
|
||||||
LyCORIS = "lycoris"
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel(ModelBase):
|
|
||||||
# model_size: int
|
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
model_format: LoRAModelFormat # TODO:
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.Lora
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.model_size = os.path.getsize(self.model_path)
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in lora")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in lora")
|
|
||||||
|
|
||||||
model = LoRAModelRaw.from_checkpoint(
|
|
||||||
file_path=self.model_path,
|
|
||||||
dtype=torch_dtype,
|
|
||||||
base_model=self.base_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_size = model.calc_size()
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, path: str):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
if os.path.isdir(path):
|
|
||||||
for ext in ["safetensors", "bin"]:
|
|
||||||
if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
|
|
||||||
return LoRAModelFormat.Diffusers
|
|
||||||
|
|
||||||
if os.path.isfile(path):
|
|
||||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
|
||||||
return LoRAModelFormat.LyCORIS
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
|
||||||
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
|
|
||||||
path = Path(model_path, f"pytorch_lora_weights.{ext}")
|
|
||||||
if path.exists():
|
|
||||||
return path
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
|
||||||
# rank: Optional[int]
|
|
||||||
# alpha: Optional[float]
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
# layer_key: str
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def scale(self):
|
|
||||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
if "alpha" in values:
|
|
||||||
self.alpha = values["alpha"].item()
|
|
||||||
else:
|
|
||||||
self.alpha = None
|
|
||||||
|
|
||||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
|
||||||
self.bias = torch.sparse_coo_tensor(
|
|
||||||
values["bias_indices"],
|
|
||||||
values["bias_values"],
|
|
||||||
tuple(values["bias_size"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
self.rank = None # set in layer implementation
|
|
||||||
self.layer_key = layer_key
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for val in [self.bias]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
|
||||||
class LoRALayer(LoRALayerBase):
|
|
||||||
# up: torch.Tensor
|
|
||||||
# mid: Optional[torch.Tensor]
|
|
||||||
# down: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.up = values["lora_up.weight"]
|
|
||||||
self.down = values["lora_down.weight"]
|
|
||||||
if "lora_mid.weight" in values:
|
|
||||||
self.mid = values["lora_mid.weight"]
|
|
||||||
else:
|
|
||||||
self.mid = None
|
|
||||||
|
|
||||||
self.rank = self.down.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor):
|
|
||||||
if self.mid is not None:
|
|
||||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
|
||||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
|
||||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
|
||||||
else:
|
|
||||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.up, self.mid, self.down]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.up = self.up.to(device=device, dtype=dtype)
|
|
||||||
self.down = self.down.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.mid is not None:
|
|
||||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
|
||||||
# w1_a: torch.Tensor
|
|
||||||
# w1_b: torch.Tensor
|
|
||||||
# w2_a: torch.Tensor
|
|
||||||
# w2_b: torch.Tensor
|
|
||||||
# t1: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.w1_a = values["hada_w1_a"]
|
|
||||||
self.w1_b = values["hada_w1_b"]
|
|
||||||
self.w2_a = values["hada_w2_a"]
|
|
||||||
self.w2_b = values["hada_w2_b"]
|
|
||||||
|
|
||||||
if "hada_t1" in values:
|
|
||||||
self.t1 = values["hada_t1"]
|
|
||||||
else:
|
|
||||||
self.t1 = None
|
|
||||||
|
|
||||||
if "hada_t2" in values:
|
|
||||||
self.t2 = values["hada_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor):
|
|
||||||
if self.t1 is None:
|
|
||||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
|
||||||
|
|
||||||
else:
|
|
||||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
|
||||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
|
||||||
weight = rebuild1 * rebuild2
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t1 is not None:
|
|
||||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
|
||||||
# w1: Optional[torch.Tensor] = None
|
|
||||||
# w1_a: Optional[torch.Tensor] = None
|
|
||||||
# w1_b: Optional[torch.Tensor] = None
|
|
||||||
# w2: Optional[torch.Tensor] = None
|
|
||||||
# w2_a: Optional[torch.Tensor] = None
|
|
||||||
# w2_b: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
if "lokr_w1" in values:
|
|
||||||
self.w1 = values["lokr_w1"]
|
|
||||||
self.w1_a = None
|
|
||||||
self.w1_b = None
|
|
||||||
else:
|
|
||||||
self.w1 = None
|
|
||||||
self.w1_a = values["lokr_w1_a"]
|
|
||||||
self.w1_b = values["lokr_w1_b"]
|
|
||||||
|
|
||||||
if "lokr_w2" in values:
|
|
||||||
self.w2 = values["lokr_w2"]
|
|
||||||
self.w2_a = None
|
|
||||||
self.w2_b = None
|
|
||||||
else:
|
|
||||||
self.w2 = None
|
|
||||||
self.w2_a = values["lokr_w2_a"]
|
|
||||||
self.w2_b = values["lokr_w2_b"]
|
|
||||||
|
|
||||||
if "lokr_t2" in values:
|
|
||||||
self.t2 = values["lokr_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
if "lokr_w1_b" in values:
|
|
||||||
self.rank = values["lokr_w1_b"].shape[0]
|
|
||||||
elif "lokr_w2_b" in values:
|
|
||||||
self.rank = values["lokr_w2_b"].shape[0]
|
|
||||||
else:
|
|
||||||
self.rank = None # unscaled
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor):
|
|
||||||
w1 = self.w1
|
|
||||||
if w1 is None:
|
|
||||||
w1 = self.w1_a @ self.w1_b
|
|
||||||
|
|
||||||
w2 = self.w2
|
|
||||||
if w2 is None:
|
|
||||||
if self.t2 is None:
|
|
||||||
w2 = self.w2_a @ self.w2_b
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
w2 = w2.contiguous()
|
|
||||||
weight = torch.kron(w1, w2)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w1 is not None:
|
|
||||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w2 is not None:
|
|
||||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class FullLayer(LoRALayerBase):
|
|
||||||
# weight: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.weight = values["diff"]
|
|
||||||
|
|
||||||
if len(values.keys()) > 1:
|
|
||||||
_keys = list(values.keys())
|
|
||||||
_keys.remove("diff")
|
|
||||||
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
|
||||||
|
|
||||||
self.rank = None # unscaled
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor):
|
|
||||||
return self.weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
model_size += self.weight.nelement() * self.weight.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class IA3Layer(LoRALayerBase):
|
|
||||||
# weight: torch.Tensor
|
|
||||||
# on_input: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.weight = values["weight"]
|
|
||||||
self.on_input = values["on_input"]
|
|
||||||
|
|
||||||
self.rank = None # unscaled
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor):
|
|
||||||
weight = self.weight
|
|
||||||
if not self.on_input:
|
|
||||||
weight = weight.reshape(-1, 1)
|
|
||||||
return orig_weight * weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
model_size += self.weight.nelement() * self.weight.element_size()
|
|
||||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
|
||||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
|
||||||
class LoRAModelRaw: # (torch.nn.Module):
|
|
||||||
_name: str
|
|
||||||
layers: Dict[str, LoRALayer]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
layers: Dict[str, LoRALayer],
|
|
||||||
):
|
|
||||||
self._name = name
|
|
||||||
self.layers = layers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
# TODO: try revert if exception?
|
|
||||||
for _key, layer in self.layers.items():
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for _, layer in self.layers.items():
|
|
||||||
model_size += layer.calc_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
|
|
||||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
|
||||||
|
|
||||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
|
||||||
diffusers format, then this function will have no effect.
|
|
||||||
|
|
||||||
This function is adapted from:
|
|
||||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
|
||||||
"""
|
|
||||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
|
||||||
not_converted_count = 0 # The number of keys that were not converted.
|
|
||||||
|
|
||||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
|
||||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
|
||||||
# `input_blocks_4_1_proj_in`.
|
|
||||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
|
||||||
stability_unet_keys.sort()
|
|
||||||
|
|
||||||
new_state_dict = {}
|
|
||||||
for full_key, value in state_dict.items():
|
|
||||||
if full_key.startswith("lora_unet_"):
|
|
||||||
search_key = full_key.replace("lora_unet_", "")
|
|
||||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
|
||||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
|
||||||
map_key = stability_unet_keys[position - 1]
|
|
||||||
# Now, check if the map_key *actually* matches the search_key.
|
|
||||||
if search_key.startswith(map_key):
|
|
||||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
|
||||||
new_state_dict[new_key] = value
|
|
||||||
converted_count += 1
|
|
||||||
else:
|
|
||||||
new_state_dict[full_key] = value
|
|
||||||
not_converted_count += 1
|
|
||||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
|
||||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
|
||||||
new_state_dict[full_key] = value
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
|
||||||
|
|
||||||
if converted_count > 0 and not_converted_count > 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
|
||||||
f" not_converted={not_converted_count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_state_dict
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_checkpoint(
|
|
||||||
cls,
|
|
||||||
file_path: Union[str, Path],
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
base_model: Optional[BaseModelType] = None,
|
|
||||||
):
|
|
||||||
device = device or torch.device("cpu")
|
|
||||||
dtype = dtype or torch.float32
|
|
||||||
|
|
||||||
if isinstance(file_path, str):
|
|
||||||
file_path = Path(file_path)
|
|
||||||
|
|
||||||
model = cls(
|
|
||||||
name=file_path.stem, # TODO:
|
|
||||||
layers={},
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(file_path, map_location="cpu")
|
|
||||||
|
|
||||||
state_dict = cls._group_state(state_dict)
|
|
||||||
|
|
||||||
if base_model == BaseModelType.StableDiffusionXL:
|
|
||||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
|
||||||
# lora and locon
|
|
||||||
if "lora_down.weight" in values:
|
|
||||||
layer = LoRALayer(layer_key, values)
|
|
||||||
|
|
||||||
# loha
|
|
||||||
elif "hada_w1_b" in values:
|
|
||||||
layer = LoHALayer(layer_key, values)
|
|
||||||
|
|
||||||
# lokr
|
|
||||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
|
||||||
layer = LoKRLayer(layer_key, values)
|
|
||||||
|
|
||||||
# diff
|
|
||||||
elif "diff" in values:
|
|
||||||
layer = FullLayer(layer_key, values)
|
|
||||||
|
|
||||||
# ia3
|
|
||||||
elif "weight" in values and "on_input" in values:
|
|
||||||
layer = IA3Layer(layer_key, values)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
|
||||||
raise Exception("Unknown lora format!")
|
|
||||||
|
|
||||||
# lower memory consumption by removing already parsed layer values
|
|
||||||
state_dict[layer_key].clear()
|
|
||||||
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
model.layers[layer_key] = layer
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _group_state(state_dict: dict):
|
|
||||||
state_dict_groupped = {}
|
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
stem, leaf = key.split(".", 1)
|
|
||||||
if stem not in state_dict_groupped:
|
|
||||||
state_dict_groupped[stem] = {}
|
|
||||||
state_dict_groupped[stem][leaf] = value
|
|
||||||
|
|
||||||
return state_dict_groupped
|
|
||||||
|
|
||||||
|
|
||||||
# code from
|
|
||||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
|
||||||
def make_sdxl_unet_conversion_map():
|
|
||||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
|
||||||
unet_conversion_map_layer = []
|
|
||||||
|
|
||||||
for i in range(3): # num_blocks is 3 in sdxl
|
|
||||||
# loop over downblocks/upblocks
|
|
||||||
for j in range(2):
|
|
||||||
# loop over resnets/attentions for downblocks
|
|
||||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
|
||||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
|
||||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
|
||||||
|
|
||||||
if i < 3:
|
|
||||||
# no attention layers in down_blocks.3
|
|
||||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
|
||||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
|
||||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
|
||||||
|
|
||||||
for j in range(3):
|
|
||||||
# loop over resnets/attentions for upblocks
|
|
||||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
|
||||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
|
||||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
|
||||||
|
|
||||||
# if i > 0: commentout for sdxl
|
|
||||||
# no attention layers in up_blocks.0
|
|
||||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
|
||||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
|
||||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
|
||||||
|
|
||||||
if i < 3:
|
|
||||||
# no downsample in down_blocks.3
|
|
||||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
|
||||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
|
||||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
|
||||||
|
|
||||||
# no upsample in up_blocks.3
|
|
||||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
|
||||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
|
||||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
|
||||||
|
|
||||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
|
||||||
sd_mid_atn_prefix = "middle_block.1."
|
|
||||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
|
||||||
|
|
||||||
for j in range(2):
|
|
||||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
|
||||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
|
||||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
|
||||||
|
|
||||||
unet_conversion_map_resnet = [
|
|
||||||
# (stable-diffusion, HF Diffusers)
|
|
||||||
("in_layers.0.", "norm1."),
|
|
||||||
("in_layers.2.", "conv1."),
|
|
||||||
("out_layers.0.", "norm2."),
|
|
||||||
("out_layers.3.", "conv2."),
|
|
||||||
("emb_layers.1.", "time_emb_proj."),
|
|
||||||
("skip_connection.", "conv_shortcut."),
|
|
||||||
]
|
|
||||||
|
|
||||||
unet_conversion_map = []
|
|
||||||
for sd, hf in unet_conversion_map_layer:
|
|
||||||
if "resnets" in hf:
|
|
||||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
|
||||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
|
||||||
else:
|
|
||||||
unet_conversion_map.append((sd, hf))
|
|
||||||
|
|
||||||
for j in range(2):
|
|
||||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
|
||||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
|
||||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
|
||||||
|
|
||||||
for j in range(2):
|
|
||||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
|
||||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
|
||||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
|
||||||
|
|
||||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
|
||||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
|
||||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
|
||||||
|
|
||||||
return unet_conversion_map
|
|
||||||
|
|
||||||
|
|
||||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
|
||||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
|
||||||
}
|
|
@ -1,148 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
from .base import (
|
|
||||||
BaseModelType,
|
|
||||||
DiffusersModel,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
classproperty,
|
|
||||||
read_checkpoint_meta,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLModelFormat(str, Enum):
|
|
||||||
Checkpoint = "checkpoint"
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLModel(DiffusersModel):
|
|
||||||
# TODO: check that configs overwriten properly
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: str
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
|
|
||||||
assert model_type == ModelType.Main
|
|
||||||
super().__init__(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=BaseModelType.StableDiffusionXL,
|
|
||||||
model_type=ModelType.Main,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe_config(cls, path: str, **kwargs):
|
|
||||||
model_format = cls.detect_format(path)
|
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
|
||||||
if model_format == StableDiffusionXLModelFormat.Checkpoint:
|
|
||||||
if ckpt_config_path:
|
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
|
||||||
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
checkpoint = read_checkpoint_meta(path)
|
|
||||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
||||||
|
|
||||||
elif model_format == StableDiffusionXLModelFormat.Diffusers:
|
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
|
||||||
if os.path.exists(unet_config_path):
|
|
||||||
with open(unet_config_path, "r") as f:
|
|
||||||
unet_config = json.loads(f.read())
|
|
||||||
in_channels = unet_config["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
|
|
||||||
|
|
||||||
if in_channels == 9:
|
|
||||||
variant = ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 5:
|
|
||||||
variant = ModelVariantType.Depth
|
|
||||||
elif in_channels == 4:
|
|
||||||
variant = ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise Exception("Unkown stable diffusion 2.* model format")
|
|
||||||
|
|
||||||
if ckpt_config_path is None:
|
|
||||||
# avoid circular import
|
|
||||||
from .stable_diffusion import _select_ckpt_config
|
|
||||||
|
|
||||||
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
|
|
||||||
|
|
||||||
return cls.create_config(
|
|
||||||
path=path,
|
|
||||||
model_format=model_format,
|
|
||||||
config=ckpt_config_path,
|
|
||||||
variant=variant,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, model_path: str):
|
|
||||||
if os.path.isdir(model_path):
|
|
||||||
return StableDiffusionXLModelFormat.Diffusers
|
|
||||||
else:
|
|
||||||
return StableDiffusionXLModelFormat.Checkpoint
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
# The convert script adapted from the diffusers package uses
|
|
||||||
# strings for the base model type. To avoid making too many
|
|
||||||
# source code changes, we simply translate here
|
|
||||||
if Path(output_path).exists():
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
|
||||||
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
|
||||||
|
|
||||||
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
|
|
||||||
# then we bake it into the converted model unless there is already
|
|
||||||
# a nonstandard VAE installed.
|
|
||||||
kwargs = {}
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
|
|
||||||
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
|
|
||||||
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
|
|
||||||
kwargs["vae_path"] = vae_path
|
|
||||||
|
|
||||||
return _convert_ckpt_and_cache(
|
|
||||||
version=base_model,
|
|
||||||
model_config=config,
|
|
||||||
output_path=output_path,
|
|
||||||
use_safetensors=True,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return model_path
|
|
@ -1,337 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal, Optional, Union
|
|
||||||
|
|
||||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
from .base import (
|
|
||||||
BaseModelType,
|
|
||||||
DiffusersModel,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
SilenceWarnings,
|
|
||||||
classproperty,
|
|
||||||
read_checkpoint_meta,
|
|
||||||
)
|
|
||||||
from .sdxl import StableDiffusionXLModel
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion1ModelFormat(str, Enum):
|
|
||||||
Checkpoint = "checkpoint"
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion1Model(DiffusersModel):
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: str
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model == BaseModelType.StableDiffusion1
|
|
||||||
assert model_type == ModelType.Main
|
|
||||||
super().__init__(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=BaseModelType.StableDiffusion1,
|
|
||||||
model_type=ModelType.Main,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe_config(cls, path: str, **kwargs):
|
|
||||||
model_format = cls.detect_format(path)
|
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
|
||||||
if model_format == StableDiffusion1ModelFormat.Checkpoint:
|
|
||||||
if ckpt_config_path:
|
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
checkpoint = read_checkpoint_meta(path)
|
|
||||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
||||||
|
|
||||||
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
|
||||||
if os.path.exists(unet_config_path):
|
|
||||||
with open(unet_config_path, "r") as f:
|
|
||||||
unet_config = json.loads(f.read())
|
|
||||||
in_channels = unet_config["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
|
|
||||||
|
|
||||||
if in_channels == 9:
|
|
||||||
variant = ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 4:
|
|
||||||
variant = ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise Exception("Unkown stable diffusion 1.* model format")
|
|
||||||
|
|
||||||
if ckpt_config_path is None:
|
|
||||||
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)
|
|
||||||
|
|
||||||
return cls.create_config(
|
|
||||||
path=path,
|
|
||||||
model_format=model_format,
|
|
||||||
config=ckpt_config_path,
|
|
||||||
variant=variant,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, model_path: str):
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
if os.path.isdir(model_path):
|
|
||||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
|
||||||
return StableDiffusion1ModelFormat.Diffusers
|
|
||||||
|
|
||||||
if os.path.isfile(model_path):
|
|
||||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
|
||||||
return StableDiffusion1ModelFormat.Checkpoint
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
|
||||||
return _convert_ckpt_and_cache(
|
|
||||||
version=BaseModelType.StableDiffusion1,
|
|
||||||
model_config=config,
|
|
||||||
load_safety_checker=False,
|
|
||||||
output_path=output_path,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion2ModelFormat(str, Enum):
|
|
||||||
Checkpoint = "checkpoint"
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion2Model(DiffusersModel):
|
|
||||||
# TODO: check that configs overwriten properly
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
|
||||||
vae: Optional[str] = Field(None)
|
|
||||||
config: str
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model == BaseModelType.StableDiffusion2
|
|
||||||
assert model_type == ModelType.Main
|
|
||||||
super().__init__(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=BaseModelType.StableDiffusion2,
|
|
||||||
model_type=ModelType.Main,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe_config(cls, path: str, **kwargs):
|
|
||||||
model_format = cls.detect_format(path)
|
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
|
||||||
if model_format == StableDiffusion2ModelFormat.Checkpoint:
|
|
||||||
if ckpt_config_path:
|
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
checkpoint = read_checkpoint_meta(path)
|
|
||||||
checkpoint = checkpoint.get("state_dict", checkpoint)
|
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
||||||
|
|
||||||
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
|
||||||
if os.path.exists(unet_config_path):
|
|
||||||
with open(unet_config_path, "r") as f:
|
|
||||||
unet_config = json.loads(f.read())
|
|
||||||
in_channels = unet_config["in_channels"]
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
|
|
||||||
|
|
||||||
if in_channels == 9:
|
|
||||||
variant = ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 5:
|
|
||||||
variant = ModelVariantType.Depth
|
|
||||||
elif in_channels == 4:
|
|
||||||
variant = ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise Exception("Unkown stable diffusion 2.* model format")
|
|
||||||
|
|
||||||
if ckpt_config_path is None:
|
|
||||||
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)
|
|
||||||
|
|
||||||
return cls.create_config(
|
|
||||||
path=path,
|
|
||||||
model_format=model_format,
|
|
||||||
config=ckpt_config_path,
|
|
||||||
variant=variant,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, model_path: str):
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
if os.path.isdir(model_path):
|
|
||||||
if os.path.exists(os.path.join(model_path, "model_index.json")):
|
|
||||||
return StableDiffusion2ModelFormat.Diffusers
|
|
||||||
|
|
||||||
if os.path.isfile(model_path):
|
|
||||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
|
||||||
return StableDiffusion2ModelFormat.Checkpoint
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
|
||||||
return _convert_ckpt_and_cache(
|
|
||||||
version=BaseModelType.StableDiffusion2,
|
|
||||||
model_config=config,
|
|
||||||
output_path=output_path,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: rework
|
|
||||||
# pass precision - currently defaulting to fp16
|
|
||||||
def _convert_ckpt_and_cache(
|
|
||||||
version: BaseModelType,
|
|
||||||
model_config: Union[
|
|
||||||
StableDiffusion1Model.CheckpointConfig,
|
|
||||||
StableDiffusion2Model.CheckpointConfig,
|
|
||||||
StableDiffusionXLModel.CheckpointConfig,
|
|
||||||
],
|
|
||||||
output_path: str,
|
|
||||||
use_save_model: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Convert the checkpoint model indicated in mconfig into a
|
|
||||||
diffusers, cache it to disk, and return Path to converted
|
|
||||||
file. If already on disk then just returns Path.
|
|
||||||
"""
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
weights = app_config.models_path / model_config.path
|
|
||||||
config_file = app_config.root_path / model_config.config
|
|
||||||
output_path = Path(output_path)
|
|
||||||
variant = model_config.variant
|
|
||||||
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
|
|
||||||
|
|
||||||
# return cached version if it exists
|
|
||||||
if output_path.exists():
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
# to avoid circular import errors
|
|
||||||
from ...util.devices import choose_torch_device, torch_dtype
|
|
||||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
|
||||||
|
|
||||||
model_base_to_model_type = {
|
|
||||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
|
||||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
|
||||||
BaseModelType.StableDiffusionXL: "SDXL",
|
|
||||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
|
||||||
}
|
|
||||||
logger.info(f"Converting {weights} to diffusers format")
|
|
||||||
with SilenceWarnings():
|
|
||||||
convert_ckpt_to_diffusers(
|
|
||||||
weights,
|
|
||||||
output_path,
|
|
||||||
model_type=model_base_to_model_type[version],
|
|
||||||
model_version=version,
|
|
||||||
model_variant=model_config.variant,
|
|
||||||
original_config_file=config_file,
|
|
||||||
extract_ema=True,
|
|
||||||
scan_needed=True,
|
|
||||||
pipeline_class=pipeline_class,
|
|
||||||
from_safetensors=weights.suffix == ".safetensors",
|
|
||||||
precision=torch_dtype(choose_torch_device()),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
|
||||||
ckpt_configs = {
|
|
||||||
BaseModelType.StableDiffusion1: {
|
|
||||||
ModelVariantType.Normal: "v1-inference.yaml",
|
|
||||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusion2: {
|
|
||||||
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
|
|
||||||
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
|
|
||||||
ModelVariantType.Depth: "v2-midas-inference.yaml",
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusionXL: {
|
|
||||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
|
||||||
ModelVariantType.Inpaint: None,
|
|
||||||
ModelVariantType.Depth: None,
|
|
||||||
},
|
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
|
||||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
|
||||||
ModelVariantType.Inpaint: None,
|
|
||||||
ModelVariantType.Depth: None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
try:
|
|
||||||
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
|
|
||||||
if config_path.is_relative_to(app_config.root_path):
|
|
||||||
config_path = config_path.relative_to(app_config.root_path)
|
|
||||||
return str(config_path)
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
@ -1,150 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from diffusers import OnnxRuntimeModel
|
|
||||||
|
|
||||||
from .base import (
|
|
||||||
BaseModelType,
|
|
||||||
DiffusersModel,
|
|
||||||
IAIOnnxRuntimeModel,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionOnnxModelFormat(str, Enum):
|
|
||||||
Olive = "olive"
|
|
||||||
Onnx = "onnx"
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXStableDiffusion1Model(DiffusersModel):
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
|
||||||
variant: ModelVariantType
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model == BaseModelType.StableDiffusion1
|
|
||||||
assert model_type == ModelType.ONNX
|
|
||||||
super().__init__(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=BaseModelType.StableDiffusion1,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
)
|
|
||||||
|
|
||||||
for child_name, child_type in self.child_types.items():
|
|
||||||
if child_type is OnnxRuntimeModel:
|
|
||||||
self.child_types[child_name] = IAIOnnxRuntimeModel
|
|
||||||
|
|
||||||
# TODO: check that no optimum models provided
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe_config(cls, path: str, **kwargs):
|
|
||||||
model_format = cls.detect_format(path)
|
|
||||||
in_channels = 4 # TODO:
|
|
||||||
|
|
||||||
if in_channels == 9:
|
|
||||||
variant = ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 4:
|
|
||||||
variant = ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise Exception("Unkown stable diffusion 1.* model format")
|
|
||||||
|
|
||||||
return cls.create_config(
|
|
||||||
path=path,
|
|
||||||
model_format=model_format,
|
|
||||||
variant=variant,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, model_path: str):
|
|
||||||
# TODO: Detect onnx vs olive
|
|
||||||
return StableDiffusionOnnxModelFormat.Onnx
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXStableDiffusion2Model(DiffusersModel):
|
|
||||||
# TODO: check that configs overwriten properly
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
|
||||||
variant: ModelVariantType
|
|
||||||
prediction_type: SchedulerPredictionType
|
|
||||||
upcast_attention: bool
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert base_model == BaseModelType.StableDiffusion2
|
|
||||||
assert model_type == ModelType.ONNX
|
|
||||||
super().__init__(
|
|
||||||
model_path=model_path,
|
|
||||||
base_model=BaseModelType.StableDiffusion2,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
)
|
|
||||||
|
|
||||||
for child_name, child_type in self.child_types.items():
|
|
||||||
if child_type is OnnxRuntimeModel:
|
|
||||||
self.child_types[child_name] = IAIOnnxRuntimeModel
|
|
||||||
# TODO: check that no optimum models provided
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def probe_config(cls, path: str, **kwargs):
|
|
||||||
model_format = cls.detect_format(path)
|
|
||||||
in_channels = 4 # TODO:
|
|
||||||
|
|
||||||
if in_channels == 9:
|
|
||||||
variant = ModelVariantType.Inpaint
|
|
||||||
elif in_channels == 5:
|
|
||||||
variant = ModelVariantType.Depth
|
|
||||||
elif in_channels == 4:
|
|
||||||
variant = ModelVariantType.Normal
|
|
||||||
else:
|
|
||||||
raise Exception("Unkown stable diffusion 2.* model format")
|
|
||||||
|
|
||||||
if variant == ModelVariantType.Normal:
|
|
||||||
prediction_type = SchedulerPredictionType.VPrediction
|
|
||||||
upcast_attention = True
|
|
||||||
|
|
||||||
else:
|
|
||||||
prediction_type = SchedulerPredictionType.Epsilon
|
|
||||||
upcast_attention = False
|
|
||||||
|
|
||||||
return cls.create_config(
|
|
||||||
path=path,
|
|
||||||
model_format=model_format,
|
|
||||||
variant=variant,
|
|
||||||
prediction_type=prediction_type,
|
|
||||||
upcast_attention=upcast_attention,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, model_path: str):
|
|
||||||
# TODO: Detect onnx vs olive
|
|
||||||
return StableDiffusionOnnxModelFormat.Onnx
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
return model_path
|
|
@ -1,102 +0,0 @@
|
|||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers import T2IAdapter
|
|
||||||
|
|
||||||
from invokeai.backend.model_management.models.base import (
|
|
||||||
BaseModelType,
|
|
||||||
EmptyConfigLoader,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
calc_model_size_by_data,
|
|
||||||
calc_model_size_by_fs,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterModelFormat(str, Enum):
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterModel(ModelBase):
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
|
||||||
model_format: Literal[T2IAdapterModelFormat.Diffusers]
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.T2IAdapter
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
|
||||||
|
|
||||||
model_class_name = config.get("_class_name", None)
|
|
||||||
if model_class_name not in {"T2IAdapter"}:
|
|
||||||
raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.")
|
|
||||||
|
|
||||||
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
) -> T2IAdapter:
|
|
||||||
if child_type is not None:
|
|
||||||
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
|
|
||||||
|
|
||||||
model = None
|
|
||||||
for variant in ["fp16", None]:
|
|
||||||
try:
|
|
||||||
model = self.model_class.from_pretrained(
|
|
||||||
self.model_path,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
variant=variant,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if not model:
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
# Calculate a more accurate size after loading the model into memory.
|
|
||||||
self.model_size = calc_model_size_by_data(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, path: str):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise ModelNotFoundException(f"Model not found at '{path}'.")
|
|
||||||
|
|
||||||
if os.path.isdir(path):
|
|
||||||
if os.path.exists(os.path.join(path, "config.json")):
|
|
||||||
return T2IAdapterModelFormat.Diffusers
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
format = cls.detect_format(model_path)
|
|
||||||
if format == T2IAdapterModelFormat.Diffusers:
|
|
||||||
return model_path
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported format: '{format}'.")
|
|
@ -1,87 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# TODO: naming
|
|
||||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
|
||||||
from .base import (
|
|
||||||
BaseModelType,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModel(ModelBase):
|
|
||||||
# model_size: int
|
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
model_format: None
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.TextualInversion
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
self.model_size = os.path.getsize(self.model_path)
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in textual inversion")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in textual inversion")
|
|
||||||
|
|
||||||
checkpoint_path = self.model_path
|
|
||||||
if os.path.isdir(checkpoint_path):
|
|
||||||
checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin")
|
|
||||||
|
|
||||||
if not os.path.exists(checkpoint_path):
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
model = TextualInversionModelRaw.from_checkpoint(
|
|
||||||
file_path=checkpoint_path,
|
|
||||||
dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model_size = model.embedding.nelement() * model.embedding.element_size()
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, path: str):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise ModelNotFoundException()
|
|
||||||
|
|
||||||
if os.path.isdir(path):
|
|
||||||
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
|
|
||||||
return None # diffusers-ti
|
|
||||||
|
|
||||||
if os.path.isfile(path):
|
|
||||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
|
|
||||||
return None
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
return model_path
|
|
@ -1,179 +0,0 @@
|
|||||||
import os
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import safetensors
|
|
||||||
import torch
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
from .base import (
|
|
||||||
BaseModelType,
|
|
||||||
EmptyConfigLoader,
|
|
||||||
InvalidModelException,
|
|
||||||
ModelBase,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelNotFoundException,
|
|
||||||
ModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
SubModelType,
|
|
||||||
calc_model_size_by_data,
|
|
||||||
calc_model_size_by_fs,
|
|
||||||
classproperty,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VaeModelFormat(str, Enum):
|
|
||||||
Checkpoint = "checkpoint"
|
|
||||||
Diffusers = "diffusers"
|
|
||||||
|
|
||||||
|
|
||||||
class VaeModel(ModelBase):
|
|
||||||
# vae_class: Type
|
|
||||||
# model_size: int
|
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
|
||||||
model_format: VaeModelFormat
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
assert model_type == ModelType.Vae
|
|
||||||
super().__init__(model_path, base_model, model_type)
|
|
||||||
|
|
||||||
try:
|
|
||||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
|
||||||
# config = json.loads(os.path.join(self.model_path, "config.json"))
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Invalid vae model! (config.json not found or invalid)")
|
|
||||||
|
|
||||||
try:
|
|
||||||
vae_class_name = config.get("_class_name", "AutoencoderKL")
|
|
||||||
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
|
|
||||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
|
||||||
except Exception:
|
|
||||||
raise Exception("Invalid vae model! (Unkown vae type)")
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in vae model")
|
|
||||||
return self.model_size
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
torch_dtype: Optional[torch.dtype],
|
|
||||||
child_type: Optional[SubModelType] = None,
|
|
||||||
):
|
|
||||||
if child_type is not None:
|
|
||||||
raise Exception("There is no child models in vae model")
|
|
||||||
|
|
||||||
model = self.vae_class.from_pretrained(
|
|
||||||
self.model_path,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
# calc more accurate size
|
|
||||||
self.model_size = calc_model_size_by_data(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@classproperty
|
|
||||||
def save_to_config(cls) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def detect_format(cls, path: str):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
raise ModelNotFoundException(f"Does not exist as local file: {path}")
|
|
||||||
|
|
||||||
if os.path.isdir(path):
|
|
||||||
if os.path.exists(os.path.join(path, "config.json")):
|
|
||||||
return VaeModelFormat.Diffusers
|
|
||||||
|
|
||||||
if os.path.isfile(path):
|
|
||||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
|
||||||
return VaeModelFormat.Checkpoint
|
|
||||||
|
|
||||||
raise InvalidModelException(f"Not a valid model: {path}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_if_required(
|
|
||||||
cls,
|
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
|
||||||
config: ModelConfigBase, # empty config or config of parent model
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> str:
|
|
||||||
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
|
||||||
return _convert_vae_ckpt_and_cache(
|
|
||||||
weights_path=model_path,
|
|
||||||
output_path=output_path,
|
|
||||||
base_model=base_model,
|
|
||||||
model_config=config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: rework
|
|
||||||
def _convert_vae_ckpt_and_cache(
|
|
||||||
weights_path: str,
|
|
||||||
output_path: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_config: ModelConfigBase,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
|
|
||||||
object, cache it to disk, and return Path to converted
|
|
||||||
file. If already on disk then just returns Path.
|
|
||||||
"""
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
weights_path = app_config.root_dir / weights_path
|
|
||||||
output_path = Path(output_path)
|
|
||||||
|
|
||||||
"""
|
|
||||||
this size used only in when tiling enabled to separate input in tiles
|
|
||||||
sizes in configs from stable diffusion githubs(1 and 2) set to 256
|
|
||||||
on huggingface it:
|
|
||||||
1.5 - 512
|
|
||||||
1.5-inpainting - 256
|
|
||||||
2-inpainting - 512
|
|
||||||
2-depth - 256
|
|
||||||
2-base - 512
|
|
||||||
2 - 768
|
|
||||||
2.1-base - 768
|
|
||||||
2.1 - 768
|
|
||||||
"""
|
|
||||||
image_size = 512
|
|
||||||
|
|
||||||
# return cached version if it exists
|
|
||||||
if output_path.exists():
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
|
||||||
from .stable_diffusion import _select_ckpt_config
|
|
||||||
|
|
||||||
# all sd models use same vae settings
|
|
||||||
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
|
|
||||||
else:
|
|
||||||
raise Exception(f"Vae conversion not supported for model type: {base_model}")
|
|
||||||
|
|
||||||
# this avoids circular import error
|
|
||||||
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
|
||||||
|
|
||||||
if weights_path.suffix == ".safetensors":
|
|
||||||
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
|
|
||||||
else:
|
|
||||||
checkpoint = torch.load(weights_path, map_location="cpu")
|
|
||||||
|
|
||||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
|
||||||
if "state_dict" in checkpoint:
|
|
||||||
checkpoint = checkpoint["state_dict"]
|
|
||||||
|
|
||||||
config = OmegaConf.load(app_config.root_path / config_file)
|
|
||||||
|
|
||||||
vae_model = convert_ldm_vae_to_diffusers(
|
|
||||||
checkpoint=checkpoint,
|
|
||||||
vae_config=config,
|
|
||||||
image_size=image_size,
|
|
||||||
)
|
|
||||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
|
||||||
return output_path
|
|
@ -1,102 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
|
||||||
|
|
||||||
|
|
||||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
|
||||||
"""
|
|
||||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
|
||||||
"""
|
|
||||||
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
|
|
||||||
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
|
|
||||||
return nn.functional.conv2d(
|
|
||||||
working,
|
|
||||||
weight,
|
|
||||||
bias,
|
|
||||||
self.stride,
|
|
||||||
nn.modules.utils._pair(0),
|
|
||||||
self.dilation,
|
|
||||||
self.groups,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
|
||||||
try:
|
|
||||||
to_restore = []
|
|
||||||
|
|
||||||
for m_name, m in model.named_modules():
|
|
||||||
if isinstance(model, UNet2DConditionModel):
|
|
||||||
if ".attentions." in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if ".resnets." in m_name:
|
|
||||||
if ".conv2" in m_name:
|
|
||||||
continue
|
|
||||||
if ".conv_shortcut" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
"""
|
|
||||||
if isinstance(model, UNet2DConditionModel):
|
|
||||||
if False and ".upsamplers." in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if False and ".downsamplers." in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".resnets." in m_name:
|
|
||||||
if True and ".conv1" in m_name:
|
|
||||||
if False and "down_blocks" in m_name:
|
|
||||||
continue
|
|
||||||
if False and "mid_block" in m_name:
|
|
||||||
continue
|
|
||||||
if False and "up_blocks" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".conv2" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".conv_shortcut" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".attentions." in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if False and m_name in ["conv_in", "conv_out"]:
|
|
||||||
continue
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
|
||||||
m.asymmetric_padding_mode = {}
|
|
||||||
m.asymmetric_padding = {}
|
|
||||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
|
||||||
m.asymmetric_padding["x"] = (
|
|
||||||
m._reversed_padding_repeated_twice[0],
|
|
||||||
m._reversed_padding_repeated_twice[1],
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
|
||||||
m.asymmetric_padding["y"] = (
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
m._reversed_padding_repeated_twice[2],
|
|
||||||
m._reversed_padding_repeated_twice[3],
|
|
||||||
)
|
|
||||||
|
|
||||||
to_restore.append((m, m._conv_forward))
|
|
||||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
for module, orig_conv_forward in to_restore:
|
|
||||||
module._conv_forward = orig_conv_forward
|
|
||||||
if hasattr(module, "asymmetric_padding_mode"):
|
|
||||||
del module.asymmetric_padding_mode
|
|
||||||
if hasattr(module, "asymmetric_padding"):
|
|
||||||
del module.asymmetric_padding
|
|
@ -1,79 +0,0 @@
|
|||||||
# Copyright (c) 2023 The InvokeAI Development Team
|
|
||||||
"""Utilities used by the Model Manager"""
|
|
||||||
|
|
||||||
|
|
||||||
def lora_token_vector_length(checkpoint: dict) -> int:
|
|
||||||
"""
|
|
||||||
Given a checkpoint in memory, return the lora token vector length
|
|
||||||
|
|
||||||
:param checkpoint: The checkpoint
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
|
||||||
lora_token_vector_length = None
|
|
||||||
|
|
||||||
if "." not in key:
|
|
||||||
return lora_token_vector_length # wrong key format
|
|
||||||
model_key, lora_key = key.split(".", 1)
|
|
||||||
|
|
||||||
# check lora/locon
|
|
||||||
if lora_key == "lora_down.weight":
|
|
||||||
lora_token_vector_length = tensor.shape[1]
|
|
||||||
|
|
||||||
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
|
|
||||||
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
|
|
||||||
lora_token_vector_length = tensor.shape[1]
|
|
||||||
|
|
||||||
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
|
|
||||||
elif "lokr_" in lora_key:
|
|
||||||
if model_key + ".lokr_w1" in checkpoint:
|
|
||||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
|
|
||||||
elif model_key + "lokr_w1_b" in checkpoint:
|
|
||||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
|
|
||||||
else:
|
|
||||||
return lora_token_vector_length # unknown format
|
|
||||||
|
|
||||||
if model_key + ".lokr_w2" in checkpoint:
|
|
||||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
|
|
||||||
elif model_key + "lokr_w2_b" in checkpoint:
|
|
||||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
|
|
||||||
else:
|
|
||||||
return lora_token_vector_length # unknown format
|
|
||||||
|
|
||||||
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
|
|
||||||
|
|
||||||
elif lora_key == "diff":
|
|
||||||
lora_token_vector_length = tensor.shape[1]
|
|
||||||
|
|
||||||
# ia3 can be detected only by shape[0] in text encoder
|
|
||||||
elif lora_key == "weight" and "lora_unet_" not in model_key:
|
|
||||||
lora_token_vector_length = tensor.shape[0]
|
|
||||||
|
|
||||||
return lora_token_vector_length
|
|
||||||
|
|
||||||
lora_token_vector_length = None
|
|
||||||
lora_te1_length = None
|
|
||||||
lora_te2_length = None
|
|
||||||
for key, tensor in checkpoint.items():
|
|
||||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
|
||||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
|
||||||
elif key.startswith("lora_unet_") and (
|
|
||||||
"time_emb_proj.lora_down" in key
|
|
||||||
): # recognizes format at https://civitai.com/models/224641
|
|
||||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
|
||||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
|
||||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
|
||||||
if key.startswith("lora_te_"):
|
|
||||||
lora_token_vector_length = tmp_length
|
|
||||||
elif key.startswith("lora_te1_"):
|
|
||||||
lora_te1_length = tmp_length
|
|
||||||
elif key.startswith("lora_te2_"):
|
|
||||||
lora_te2_length = tmp_length
|
|
||||||
|
|
||||||
if lora_te1_length is not None and lora_te2_length is not None:
|
|
||||||
lora_token_vector_length = lora_te1_length + lora_te2_length
|
|
||||||
|
|
||||||
if lora_token_vector_length is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
return lora_token_vector_length
|
|
@ -1,5 +1,4 @@
|
|||||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -33,3 +32,42 @@ __all__ = [
|
|||||||
"SchedulerPredictionType",
|
"SchedulerPredictionType",
|
||||||
"SubModelType",
|
"SubModelType",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
########## to help populate the openapi_schema with format enums for each config ###########
|
||||||
|
# This code is no longer necessary?
|
||||||
|
# leave it here just in case
|
||||||
|
#
|
||||||
|
# import inspect
|
||||||
|
# from enum import Enum
|
||||||
|
# from typing import Any, Iterable, Dict, get_args, Set
|
||||||
|
# def _expand(something: Any) -> Iterable[type]:
|
||||||
|
# if isinstance(something, type):
|
||||||
|
# yield something
|
||||||
|
# else:
|
||||||
|
# for x in get_args(something):
|
||||||
|
# for y in _expand(x):
|
||||||
|
# yield y
|
||||||
|
|
||||||
|
# def _find_format(cls: type) -> Iterable[Enum]:
|
||||||
|
# if hasattr(inspect, "get_annotations"):
|
||||||
|
# fields = inspect.get_annotations(cls)
|
||||||
|
# else:
|
||||||
|
# fields = cls.__annotations__
|
||||||
|
# if "format" in fields:
|
||||||
|
# for x in get_args(fields["format"]):
|
||||||
|
# yield x
|
||||||
|
# for parent_class in cls.__bases__:
|
||||||
|
# for x in _find_format(parent_class):
|
||||||
|
# yield x
|
||||||
|
# return None
|
||||||
|
|
||||||
|
# def get_model_config_formats() -> Dict[str, Set[Enum]]:
|
||||||
|
# result: Dict[str, Set[Enum]] = {}
|
||||||
|
# for model_config in _expand(AnyModelConfig):
|
||||||
|
# for field in _find_format(model_config):
|
||||||
|
# if field is None:
|
||||||
|
# continue
|
||||||
|
# if not result.get(model_config.__qualname__):
|
||||||
|
# result[model_config.__qualname__] = set()
|
||||||
|
# result[model_config.__qualname__].add(field)
|
||||||
|
# return result
|
||||||
|
@ -6,12 +6,22 @@ from importlib import import_module
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||||
from .load_base import AnyModelLoader, LoadedModel
|
from .load_base import LoadedModel, ModelLoaderBase
|
||||||
|
from .load_default import ModelLoader
|
||||||
from .model_cache.model_cache_default import ModelCache
|
from .model_cache.model_cache_default import ModelCache
|
||||||
|
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||||
|
|
||||||
# This registers the subclasses that implement loaders of specific model types
|
# This registers the subclasses that implement loaders of specific model types
|
||||||
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
||||||
for module in loaders:
|
for module in loaders:
|
||||||
import_module(f"{__package__}.model_loaders.{module}")
|
import_module(f"{__package__}.model_loaders.{module}")
|
||||||
|
|
||||||
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
__all__ = [
|
||||||
|
"LoadedModel",
|
||||||
|
"ModelCache",
|
||||||
|
"ModelConvertCache",
|
||||||
|
"ModelLoaderBase",
|
||||||
|
"ModelLoader",
|
||||||
|
"ModelLoaderRegistryBase",
|
||||||
|
"ModelLoaderRegistry",
|
||||||
|
]
|
||||||
|
@ -1,37 +1,22 @@
|
|||||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
"""
|
"""
|
||||||
Base class for model loading in InvokeAI.
|
Base class for model loading in InvokeAI.
|
||||||
|
|
||||||
Use like this:
|
|
||||||
|
|
||||||
loader = AnyModelLoader(...)
|
|
||||||
loaded_model = loader.get_model('019ab39adfa1840455')
|
|
||||||
with loaded_model as model: # context manager moves model into VRAM
|
|
||||||
# do something with loaded_model
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
from typing import Any, Optional
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelFormat,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
SubModelType,
|
||||||
VaeCheckpointConfig,
|
|
||||||
VaeDiffusersConfig,
|
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -71,7 +56,7 @@ class ModelLoaderBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
Return a model given its confguration.
|
Return a model given its confguration.
|
||||||
|
|
||||||
@ -90,106 +75,3 @@ class ModelLoaderBase(ABC):
|
|||||||
) -> int:
|
) -> int:
|
||||||
"""Return size in bytes of the model, calculated before loading."""
|
"""Return size in bytes of the model, calculated before loading."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# TO DO: Better name?
|
|
||||||
class AnyModelLoader:
|
|
||||||
"""This class manages the model loaders and invokes the correct one to load a model of given base and type."""
|
|
||||||
|
|
||||||
# this tracks the loader subclasses
|
|
||||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
|
||||||
_logger: Logger = InvokeAILogger.get_logger()
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
app_config: InvokeAIAppConfig,
|
|
||||||
logger: Logger,
|
|
||||||
ram_cache: ModelCacheBase[AnyModel],
|
|
||||||
convert_cache: ModelConvertCacheBase,
|
|
||||||
):
|
|
||||||
"""Initialize AnyModelLoader with its dependencies."""
|
|
||||||
self._app_config = app_config
|
|
||||||
self._logger = logger
|
|
||||||
self._ram_cache = ram_cache
|
|
||||||
self._convert_cache = convert_cache
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
|
||||||
"""Return the RAM cache associated used by the loaders."""
|
|
||||||
return self._ram_cache
|
|
||||||
|
|
||||||
@property
|
|
||||||
def convert_cache(self) -> ModelConvertCacheBase:
|
|
||||||
"""Return the convert cache associated used by the loaders."""
|
|
||||||
return self._convert_cache
|
|
||||||
|
|
||||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
|
||||||
"""
|
|
||||||
Return a model given its configuration.
|
|
||||||
|
|
||||||
:param key: model key, as known to the config backend
|
|
||||||
:param submodel_type: an ModelType enum indicating the portion of
|
|
||||||
the model to retrieve (e.g. ModelType.Vae)
|
|
||||||
"""
|
|
||||||
implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
|
|
||||||
return implementation(
|
|
||||||
app_config=self._app_config,
|
|
||||||
logger=self._logger,
|
|
||||||
ram_cache=self._ram_cache,
|
|
||||||
convert_cache=self._convert_cache,
|
|
||||||
).load_model(model_config, submodel_type)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
|
||||||
return "-".join([base.value, type.value, format.value])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_implementation(
|
|
||||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
|
||||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
|
||||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
|
||||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
|
||||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
|
||||||
|
|
||||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
|
||||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
|
||||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
|
||||||
if not implementation:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
|
||||||
)
|
|
||||||
return implementation, conf2, submodel_type
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _handle_subtype_overrides(
|
|
||||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
|
||||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
|
||||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
|
||||||
model_path = Path(config.vae)
|
|
||||||
config_class = (
|
|
||||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
|
||||||
)
|
|
||||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
|
||||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
|
||||||
submodel_type = None
|
|
||||||
else:
|
|
||||||
new_conf = config
|
|
||||||
return new_conf, submodel_type
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(
|
|
||||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
|
||||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
|
||||||
"""Define a decorator which registers the subclass of loader."""
|
|
||||||
|
|
||||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
|
||||||
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
|
|
||||||
key = cls._to_registry_key(base, type, format)
|
|
||||||
if key in cls._registry:
|
|
||||||
raise Exception(
|
|
||||||
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
|
||||||
)
|
|
||||||
cls._registry[key] = subclass
|
|
||||||
return subclass
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
@ -1,13 +1,9 @@
|
|||||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
"""Default implementation of model loading in InvokeAI."""
|
"""Default implementation of model loading in InvokeAI."""
|
||||||
|
|
||||||
import sys
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from diffusers import ModelMixin
|
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
@ -25,17 +21,6 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
|
|||||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||||
|
|
||||||
|
|
||||||
class ConfigLoader(ConfigMixin):
|
|
||||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
|
||||||
"""Load a diffusrs ConfigMixin configuration."""
|
|
||||||
cls.config_name = kwargs.pop("config_name")
|
|
||||||
# Diffusers doesn't provide typing info
|
|
||||||
return super().load_config(*args, **kwargs) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
# TO DO: The loader is not thread safe!
|
# TO DO: The loader is not thread safe!
|
||||||
class ModelLoader(ModelLoaderBase):
|
class ModelLoader(ModelLoaderBase):
|
||||||
"""Default implementation of ModelLoaderBase."""
|
"""Default implementation of ModelLoaderBase."""
|
||||||
@ -137,43 +122,6 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
|
||||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
|
||||||
|
|
||||||
# TO DO: Add exception handling
|
|
||||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
|
||||||
if module in ["diffusers", "transformers"]:
|
|
||||||
res_type = sys.modules[module]
|
|
||||||
else:
|
|
||||||
res_type = sys.modules["diffusers"].pipelines
|
|
||||||
result: ModelMixin = getattr(res_type, class_name)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# TO DO: Add exception handling
|
|
||||||
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
|
||||||
if submodel_type:
|
|
||||||
try:
|
|
||||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
|
||||||
module, class_name = config[submodel_type.value]
|
|
||||||
return self._hf_definition_to_type(module=module, class_name=class_name)
|
|
||||||
except KeyError as e:
|
|
||||||
raise InvalidModelConfigException(
|
|
||||||
f'The "{submodel_type}" submodel is not available for this model.'
|
|
||||||
) from e
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
|
||||||
class_name = config.get("_class_name", None)
|
|
||||||
if class_name:
|
|
||||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
|
||||||
if config.get("model_type", None) == "clip_vision_model":
|
|
||||||
class_name = config.get("architectures")[0]
|
|
||||||
return self._hf_definition_to_type(module="transformers", class_name=class_name)
|
|
||||||
if not class_name:
|
|
||||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
|
||||||
except KeyError as e:
|
|
||||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
|
||||||
|
|
||||||
# This needs to be implemented in subclasses that handle checkpoints
|
# This needs to be implemented in subclasses that handle checkpoints
|
||||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -55,7 +55,7 @@ class MemorySnapshot:
|
|||||||
vram = None
|
vram = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
malloc_info = LibcUtil().mallinfo2() # type: ignore
|
malloc_info = LibcUtil().mallinfo2()
|
||||||
except (OSError, AttributeError):
|
except (OSError, AttributeError):
|
||||||
# OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
|
# OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
|
||||||
# AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)
|
# AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)
|
||||||
|
122
invokeai/backend/model_manager/load/model_loader_registry.py
Normal file
122
invokeai/backend/model_manager/load/model_loader_registry.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||||
|
"""
|
||||||
|
This module implements a system in which model loaders register the
|
||||||
|
type, base and format of models that they know how to load.
|
||||||
|
|
||||||
|
Use like this:
|
||||||
|
|
||||||
|
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
|
||||||
|
loaded_model = cls(
|
||||||
|
app_config=app_config,
|
||||||
|
logger=logger,
|
||||||
|
ram_cache=ram_cache,
|
||||||
|
convert_cache=convert_cache
|
||||||
|
).load_model(model_config, submodel_type)
|
||||||
|
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from ..config import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelConfigBase,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
VaeCheckpointConfig,
|
||||||
|
VaeDiffusersConfig,
|
||||||
|
)
|
||||||
|
from . import ModelLoaderBase
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLoaderRegistryBase(ABC):
|
||||||
|
"""This class allows model loaders to register their type, base and format."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def register(
|
||||||
|
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||||
|
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||||
|
"""Define a decorator which registers the subclass of loader."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_implementation(
|
||||||
|
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||||
|
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||||
|
"""
|
||||||
|
Get subclass of ModelLoaderBase registered to handle base and type.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
:param config: Model configuration record, as returned by ModelRecordService
|
||||||
|
:param submodel_type: Submodel to fetch (main models only)
|
||||||
|
:return: tuple(loader_class, model_config, submodel_type)
|
||||||
|
|
||||||
|
Note that the returned model config may be different from one what passed
|
||||||
|
in, in the event that a submodel type is provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLoaderRegistry:
|
||||||
|
"""
|
||||||
|
This class allows model loaders to register their type, base and format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(
|
||||||
|
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||||
|
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||||
|
"""Define a decorator which registers the subclass of loader."""
|
||||||
|
|
||||||
|
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||||
|
key = cls._to_registry_key(base, type, format)
|
||||||
|
if key in cls._registry:
|
||||||
|
raise Exception(
|
||||||
|
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
||||||
|
)
|
||||||
|
cls._registry[key] = subclass
|
||||||
|
return subclass
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_implementation(
|
||||||
|
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||||
|
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||||
|
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||||
|
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||||
|
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||||
|
|
||||||
|
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||||
|
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||||
|
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||||
|
if not implementation:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||||
|
)
|
||||||
|
return implementation, conf2, submodel_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _handle_subtype_overrides(
|
||||||
|
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||||
|
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||||
|
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||||
|
model_path = Path(config.vae)
|
||||||
|
config_class = (
|
||||||
|
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||||
|
)
|
||||||
|
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||||
|
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||||
|
submodel_type = None
|
||||||
|
else:
|
||||||
|
new_conf = config
|
||||||
|
return new_conf, submodel_type
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||||
|
return "-".join([base.value, type.value, format.value])
|
@ -13,13 +13,13 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
|
||||||
|
|
||||||
|
from .. import ModelLoaderRegistry
|
||||||
from .generic_diffusers import GenericDiffusersLoader
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||||
class ControlnetLoader(GenericDiffusersLoader):
|
class ControlnetLoader(GenericDiffusersLoader):
|
||||||
"""Class to load ControlNet models."""
|
"""Class to load ControlNet models."""
|
||||||
|
|
||||||
|
@ -1,24 +1,27 @@
|
|||||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
"""Class for simple diffusers model loading in InvokeAI."""
|
"""Class for simple diffusers model loading in InvokeAI."""
|
||||||
|
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
InvalidModelConfigException,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..load_base import AnyModelLoader
|
from .. import ModelLoader, ModelLoaderRegistry
|
||||||
from ..load_default import ModelLoader
|
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||||
class GenericDiffusersLoader(ModelLoader):
|
class GenericDiffusersLoader(ModelLoader):
|
||||||
"""Class to load simple diffusers models."""
|
"""Class to load simple diffusers models."""
|
||||||
|
|
||||||
@ -28,9 +31,60 @@ class GenericDiffusersLoader(ModelLoader):
|
|||||||
model_variant: Optional[ModelRepoVariant] = None,
|
model_variant: Optional[ModelRepoVariant] = None,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> AnyModel:
|
) -> AnyModel:
|
||||||
model_class = self._get_hf_load_class(model_path)
|
model_class = self.get_hf_load_class(model_path)
|
||||||
if submodel_type is not None:
|
if submodel_type is not None:
|
||||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||||
variant = model_variant.value if model_variant else None
|
variant = model_variant.value if model_variant else None
|
||||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
|
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# TO DO: Add exception handling
|
||||||
|
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||||
|
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||||
|
if submodel_type:
|
||||||
|
try:
|
||||||
|
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||||
|
module, class_name = config[submodel_type.value]
|
||||||
|
result = self._hf_definition_to_type(module=module, class_name=class_name)
|
||||||
|
except KeyError as e:
|
||||||
|
raise InvalidModelConfigException(
|
||||||
|
f'The "{submodel_type}" submodel is not available for this model.'
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||||
|
class_name = config.get("_class_name", None)
|
||||||
|
if class_name:
|
||||||
|
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||||
|
if config.get("model_type", None) == "clip_vision_model":
|
||||||
|
class_name = config.get("architectures")
|
||||||
|
assert class_name is not None
|
||||||
|
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||||
|
if not class_name:
|
||||||
|
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||||
|
except KeyError as e:
|
||||||
|
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||||
|
return result
|
||||||
|
|
||||||
|
# TO DO: Add exception handling
|
||||||
|
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||||
|
if module in ["diffusers", "transformers"]:
|
||||||
|
res_type = sys.modules[module]
|
||||||
|
else:
|
||||||
|
res_type = sys.modules["diffusers"].pipelines
|
||||||
|
result: ModelMixin = getattr(res_type, class_name)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||||
|
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigLoader(ConfigMixin):
|
||||||
|
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
"""Load a diffusrs ConfigMixin configuration."""
|
||||||
|
cls.config_name = kwargs.pop("config_name")
|
||||||
|
# Diffusers doesn't provide typing info
|
||||||
|
return super().load_config(*args, **kwargs) # type: ignore
|
||||||
|
@ -15,11 +15,10 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||||
class IPAdapterInvokeAILoader(ModelLoader):
|
class IPAdapterInvokeAILoader(ModelLoader):
|
||||||
"""Class to load IP Adapter diffusers models."""
|
"""Class to load IP Adapter diffusers models."""
|
||||||
|
|
||||||
|
@ -18,13 +18,13 @@ from invokeai.backend.model_manager import (
|
|||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
|
||||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
|
|
||||||
|
from .. import ModelLoader, ModelLoaderRegistry
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||||
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||||
class LoraLoader(ModelLoader):
|
class LoraLoader(ModelLoader):
|
||||||
"""Class to load LoRA models."""
|
"""Class to load LoRA models."""
|
||||||
|
|
||||||
|
@ -13,13 +13,14 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
|
||||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
from .. import ModelLoaderRegistry
|
||||||
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||||
class OnnyxDiffusersModel(ModelLoader):
|
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||||
"""Class to load onnx models."""
|
"""Class to load onnx models."""
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
@ -30,7 +31,7 @@ class OnnyxDiffusersModel(ModelLoader):
|
|||||||
) -> AnyModel:
|
) -> AnyModel:
|
||||||
if not submodel_type is not None:
|
if not submodel_type is not None:
|
||||||
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
||||||
load_class = self._get_hf_load_class(model_path, submodel_type)
|
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||||
variant = model_variant.value if model_variant else None
|
variant = model_variant.value if model_variant else None
|
||||||
model_path = model_path / submodel_type.value
|
model_path = model_path / submodel_type.value
|
||||||
result: AnyModel = load_class.from_pretrained(
|
result: AnyModel = load_class.from_pretrained(
|
||||||
|
@ -19,13 +19,14 @@ from invokeai.backend.model_manager import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
|
||||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
from .. import ModelLoaderRegistry
|
||||||
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||||
class StableDiffusionDiffusersModel(ModelLoader):
|
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
|
|
||||||
model_base_to_model_type = {
|
model_base_to_model_type = {
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -15,12 +14,15 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
|
||||||
|
from .. import ModelLoader, ModelLoaderRegistry
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder)
|
@ModelLoaderRegistry.register(
|
||||||
|
base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder
|
||||||
|
)
|
||||||
class TextualInversionLoader(ModelLoader):
|
class TextualInversionLoader(ModelLoader):
|
||||||
"""Class to load TI models."""
|
"""Class to load TI models."""
|
||||||
|
|
||||||
|
@ -14,14 +14,14 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
|
||||||
|
|
||||||
|
from .. import ModelLoaderRegistry
|
||||||
from .generic_diffusers import GenericDiffusersLoader
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
|
||||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||||
class VaeLoader(GenericDiffusersLoader):
|
class VaeLoader(GenericDiffusersLoader):
|
||||||
"""Class to load VAE models."""
|
"""Class to load VAE models."""
|
||||||
|
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Generator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _no_op(*args, **kwargs):
|
def _no_op(*args: Any, **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def skip_torch_weight_init():
|
def skip_torch_weight_init() -> Generator[None, None, None]:
|
||||||
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
|
"""Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization.
|
||||||
to skip weight initialization.
|
|
||||||
|
|
||||||
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
|
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
|
||||||
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
|
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
|
||||||
@ -18,13 +18,14 @@ def skip_torch_weight_init():
|
|||||||
monkey-patches common torch layers to skip the weight initialization step.
|
monkey-patches common torch layers to skip the weight initialization step.
|
||||||
"""
|
"""
|
||||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
|
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
|
||||||
saved_functions = [m.reset_parameters for m in torch_modules]
|
saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for torch_module in torch_modules:
|
for torch_module in torch_modules:
|
||||||
|
assert hasattr(torch_module, "reset_parameters")
|
||||||
torch_module.reset_parameters = _no_op
|
torch_module.reset_parameters = _no_op
|
||||||
|
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
|
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
|
||||||
|
assert hasattr(torch_module, "reset_parameters")
|
||||||
torch_module.reset_parameters = saved_function
|
torch_module.reset_parameters = saved_function
|
||||||
|
@ -13,7 +13,7 @@ from typing import Any, List, Optional, Set
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import AutoPipelineForText2Image
|
from diffusers import AutoPipelineForText2Image
|
||||||
from diffusers import logging as dlogging
|
from diffusers.utils import logging as dlogging
|
||||||
|
|
||||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||||
@ -76,7 +76,7 @@ class ModelMerger(object):
|
|||||||
custom_pipeline="checkpoint_merger",
|
custom_pipeline="checkpoint_merger",
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
)
|
) # type: ignore
|
||||||
merged_pipe = pipe.merge(
|
merged_pipe = pipe.merge(
|
||||||
pretrained_model_name_or_path_list=model_paths,
|
pretrained_model_name_or_path_list=model_paths,
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
|
@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel):
|
|||||||
AllowDifferentLicense: bool = Field(
|
AllowDifferentLicense: bool = Field(
|
||||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||||
)
|
)
|
||||||
AllowCommercialUse: CommercialUsage = Field(
|
AllowCommercialUse: Optional[CommercialUsage] = Field(
|
||||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
|
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -139,6 +139,9 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
|||||||
@property
|
@property
|
||||||
def allow_commercial_use(self) -> bool:
|
def allow_commercial_use(self) -> bool:
|
||||||
"""Return True if commercial use is allowed."""
|
"""Return True if commercial use is allowed."""
|
||||||
|
if self.restrictions.AllowCommercialUse is None:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -8,7 +8,6 @@ import torch
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
|
||||||
from invokeai.backend.util.util import SilenceWarnings
|
from invokeai.backend.util.util import SilenceWarnings
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
@ -23,6 +22,7 @@ from .config import (
|
|||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
from .hash import FastModelHash
|
from .hash import FastModelHash
|
||||||
|
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||||
|
|
||||||
CkptType = Dict[str, Any]
|
CkptType = Dict[str, Any]
|
||||||
|
|
||||||
@ -53,6 +53,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ProbeBase(object):
|
class ProbeBase(object):
|
||||||
"""Base class for probes."""
|
"""Base class for probes."""
|
||||||
|
|
||||||
|
@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
|
|||||||
# returns all models that have 'anime' in the path
|
# returns all models that have 'anime' in the path
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models_found: Optional[Set[Path]] = Field(default=None)
|
models_found: Set[Path] = Field(default_factory=set)
|
||||||
scanned_dirs: Optional[Set[Path]] = Field(default=None)
|
scanned_dirs: Set[Path] = Field(default_factory=set)
|
||||||
pruned_paths: Optional[Set[Path]] = Field(default=None)
|
pruned_paths: Set[Path] = Field(default_factory=set)
|
||||||
|
|
||||||
def search_started(self) -> None:
|
def search_started(self) -> None:
|
||||||
self.models_found = set()
|
self.models_found = set()
|
||||||
|
@ -35,7 +35,7 @@ class Struct_mallinfo2(ctypes.Structure):
|
|||||||
("keepcost", ctypes.c_size_t),
|
("keepcost", ctypes.c_size_t),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
s = ""
|
s = ""
|
||||||
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
|
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
|
||||||
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
|
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
|
||||||
@ -62,7 +62,7 @@ class LibcUtil:
|
|||||||
TODO: Improve cross-OS compatibility of this class.
|
TODO: Improve cross-OS compatibility of this class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
|
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
|
||||||
|
|
||||||
def mallinfo2(self) -> Struct_mallinfo2:
|
def mallinfo2(self) -> Struct_mallinfo2:
|
||||||
@ -72,4 +72,5 @@ class LibcUtil:
|
|||||||
"""
|
"""
|
||||||
mallinfo2 = self._libc.mallinfo2
|
mallinfo2 = self._libc.mallinfo2
|
||||||
mallinfo2.restype = Struct_mallinfo2
|
mallinfo2.restype = Struct_mallinfo2
|
||||||
return mallinfo2()
|
result: Struct_mallinfo2 = mallinfo2()
|
||||||
|
return result
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
"""Utilities for parsing model files, used mostly by probe.py"""
|
"""Utilities for parsing model files, used mostly by probe.py"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import torch
|
|
||||||
from typing import Union
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
def _fast_safetensors_reader(path: str):
|
|
||||||
|
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
|
||||||
checkpoint = {}
|
checkpoint = {}
|
||||||
device = torch.device("meta")
|
device = torch.device("meta")
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
@ -37,10 +40,12 @@ def _fast_safetensors_reader(path: str):
|
|||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
|
||||||
|
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
|
||||||
if str(path).endswith(".safetensors"):
|
if str(path).endswith(".safetensors"):
|
||||||
try:
|
try:
|
||||||
checkpoint = _fast_safetensors_reader(path)
|
path_str = path.as_posix() if isinstance(path, Path) else path
|
||||||
|
checkpoint = _fast_safetensors_reader(path_str)
|
||||||
except Exception:
|
except Exception:
|
||||||
# TODO: create issue for support "meta"?
|
# TODO: create issue for support "meta"?
|
||||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||||
@ -52,14 +57,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
|||||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
def lora_token_vector_length(checkpoint: dict) -> int:
|
|
||||||
|
def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
|
||||||
"""
|
"""
|
||||||
Given a checkpoint in memory, return the lora token vector length
|
Given a checkpoint in memory, return the lora token vector length
|
||||||
|
|
||||||
:param checkpoint: The checkpoint
|
:param checkpoint: The checkpoint
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
|
||||||
lora_token_vector_length = None
|
lora_token_vector_length = None
|
||||||
|
|
||||||
if "." not in key:
|
if "." not in key:
|
||||||
|
@ -8,6 +8,7 @@ import numpy as np
|
|||||||
import onnx
|
import onnx
|
||||||
from onnx import numpy_helper
|
from onnx import numpy_helper
|
||||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
|
|
||||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||||
@ -15,7 +16,7 @@ ONNX_WEIGHTS_NAME = "model.onnx"
|
|||||||
|
|
||||||
# NOTE FROM LS: This was copied from Stalker's original implementation.
|
# NOTE FROM LS: This was copied from Stalker's original implementation.
|
||||||
# I have not yet gone through and fixed all the type hints
|
# I have not yet gone through and fixed all the type hints
|
||||||
class IAIOnnxRuntimeModel:
|
class IAIOnnxRuntimeModel(RawModel):
|
||||||
class _tensor_access:
|
class _tensor_access:
|
||||||
def __init__(self, model): # type: ignore
|
def __init__(self, model): # type: ignore
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -10,5 +10,6 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
|||||||
that adds additional methods and attributes.
|
that adds additional methods and attributes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class RawModel:
|
class RawModel:
|
||||||
"""Base class for 'Raw' model wrappers."""
|
"""Base class for 'Raw' model wrappers."""
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import List, Union
|
from typing import Callable, List, Union
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
|
||||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||||
@ -26,51 +27,32 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
||||||
|
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
||||||
|
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
||||||
try:
|
try:
|
||||||
to_restore = []
|
# Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence
|
||||||
|
skipped_layers = 1
|
||||||
for m_name, m in model.named_modules():
|
for m_name, m in model.named_modules():
|
||||||
if isinstance(model, UNet2DConditionModel):
|
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
if ".attentions." in m_name:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if ".resnets." in m_name:
|
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
|
||||||
if ".conv2" in m_name:
|
# down_blocks.1.resnets.1.conv1
|
||||||
continue
|
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
|
||||||
if ".conv_shortcut" in m_name:
|
block_num = int(block_num)
|
||||||
|
resnet_num = int(resnet_num)
|
||||||
|
|
||||||
|
if block_num >= len(model.down_blocks) - skipped_layers:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
"""
|
# Skip the second resnet (could be configurable)
|
||||||
if isinstance(model, UNet2DConditionModel):
|
if resnet_num > 0:
|
||||||
if False and ".upsamplers." in m_name:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if False and ".downsamplers." in m_name:
|
# Skip Conv2d layers (could be configurable)
|
||||||
|
if submodule_name == "conv2":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if True and ".resnets." in m_name:
|
|
||||||
if True and ".conv1" in m_name:
|
|
||||||
if False and "down_blocks" in m_name:
|
|
||||||
continue
|
|
||||||
if False and "mid_block" in m_name:
|
|
||||||
continue
|
|
||||||
if False and "up_blocks" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".conv2" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".conv_shortcut" in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if True and ".attentions." in m_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if False and m_name in ["conv_in", "conv_out"]:
|
|
||||||
continue
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
|
||||||
m.asymmetric_padding_mode = {}
|
m.asymmetric_padding_mode = {}
|
||||||
m.asymmetric_padding = {}
|
m.asymmetric_padding = {}
|
||||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
||||||
|
@ -8,8 +8,10 @@ from compel.embeddings_provider import BaseTextualInversionManager
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from .raw_model import RawModel
|
from .raw_model import RawModel
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModelRaw(RawModel):
|
class TextualInversionModelRaw(RawModel):
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
||||||
|
@ -5,9 +5,9 @@ from typing import Optional, Union
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
|
||||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, LoadedModel
|
from invokeai.app.services.model_records import UnknownModelException
|
||||||
|
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -42,7 +42,9 @@ def install_and_load_model(
|
|||||||
# If the requested model is already installed, return its LoadedModel
|
# If the requested model is already installed, return its LoadedModel
|
||||||
with contextlib.suppress(UnknownModelException):
|
with contextlib.suppress(UnknownModelException):
|
||||||
# TODO: Replace with wrapper call
|
# TODO: Replace with wrapper call
|
||||||
loaded_model: LoadedModel = model_manager.load.load_model_by_attr(name=model_name, base=base_model, type=model_type)
|
loaded_model: LoadedModel = model_manager.load.load_model_by_attr(
|
||||||
|
name=model_name, base=base_model, type=model_type
|
||||||
|
)
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
# Install the requested model.
|
# Install the requested model.
|
||||||
|
@ -4,18 +4,27 @@ Test model loading
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||||
from invokeai.app.services.model_load import ModelLoadServiceBase
|
|
||||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
|
|
||||||
store = mm2_installer.record_store
|
def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
|
||||||
|
store = mm2_model_manager.store
|
||||||
matches = store.search_by_attr(model_name="test_embedding")
|
matches = store.search_by_attr(model_name="test_embedding")
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
key = mm2_installer.register_path(embedding_file)
|
key = mm2_model_manager.install.register_path(embedding_file)
|
||||||
loaded_model = mm2_loader.load_model_by_config(store.get_model(key))
|
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
|
||||||
assert loaded_model is not None
|
assert loaded_model is not None
|
||||||
assert loaded_model.config.key == key
|
assert loaded_model.config.key == key
|
||||||
with loaded_model as model:
|
with loaded_model as model:
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
|
||||||
|
assert loaded_model.config.key == loaded_model_2.config.key
|
||||||
|
|
||||||
|
loaded_model_3 = mm2_model_manager.load_model_by_attr(
|
||||||
|
model_name=loaded_model.config.name,
|
||||||
|
model_type=loaded_model.config.type,
|
||||||
|
base_model=loaded_model.config.base,
|
||||||
|
)
|
||||||
|
assert loaded_model.config.key == loaded_model_3.config.key
|
||||||
|
@ -6,17 +6,17 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest import FixtureRequest
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from pytest import FixtureRequest
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
from requests_testadapter import TestAdapter, TestSession
|
from requests_testadapter import TestAdapter, TestSession
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
|
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
|
|
||||||
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
|
|
||||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
|
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||||
|
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
@ -95,9 +95,7 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_download_queue(mm2_session: Session,
|
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
|
||||||
request: FixtureRequest
|
|
||||||
) -> DownloadQueueServiceBase:
|
|
||||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||||
download_queue.start()
|
download_queue.start()
|
||||||
|
|
||||||
@ -107,30 +105,34 @@ def mm2_download_queue(mm2_session: Session,
|
|||||||
request.addfinalizer(stop_queue)
|
request.addfinalizer(stop_queue)
|
||||||
return download_queue
|
return download_queue
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||||
return mm2_record_store.metadata_store
|
return mm2_record_store.metadata_store
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||||
ram_cache = ModelCache(
|
ram_cache = ModelCache(
|
||||||
logger=InvokeAILogger.get_logger(),
|
logger=InvokeAILogger.get_logger(),
|
||||||
max_cache_size=mm2_app_config.ram_cache_size,
|
max_cache_size=mm2_app_config.ram_cache_size,
|
||||||
max_vram_cache_size=mm2_app_config.vram_cache_size
|
max_vram_cache_size=mm2_app_config.vram_cache_size,
|
||||||
)
|
)
|
||||||
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
||||||
return ModelLoadService(app_config=mm2_app_config,
|
return ModelLoadService(
|
||||||
record_store=mm2_record_store,
|
app_config=mm2_app_config,
|
||||||
ram_cache=ram_cache,
|
ram_cache=ram_cache,
|
||||||
convert_cache=convert_cache,
|
convert_cache=convert_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig,
|
def mm2_installer(
|
||||||
|
mm2_app_config: InvokeAIAppConfig,
|
||||||
mm2_download_queue: DownloadQueueServiceBase,
|
mm2_download_queue: DownloadQueueServiceBase,
|
||||||
mm2_session: Session,
|
mm2_session: Session,
|
||||||
request: FixtureRequest,
|
request: FixtureRequest,
|
||||||
) -> ModelInstallServiceBase:
|
) -> ModelInstallServiceBase:
|
||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
events = DummyEventService()
|
events = DummyEventService()
|
||||||
@ -213,15 +215,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
|
|||||||
store.add_model("test_config_5", raw5)
|
store.add_model("test_config_5", raw5)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_model_manager(mm2_record_store: ModelRecordServiceBase,
|
def mm2_model_manager(
|
||||||
mm2_installer: ModelInstallServiceBase,
|
mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
|
||||||
mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase:
|
) -> ModelManagerServiceBase:
|
||||||
return ModelManagerService(
|
return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)
|
||||||
store=mm2_record_store,
|
|
||||||
install=mm2_installer,
|
|
||||||
load=mm2_loader
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||||
@ -306,5 +306,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
return sess
|
return sess
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,8 +5,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
|
||||||
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
||||||
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
|
|
||||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||||
|
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
|
||||||
|
|
||||||
|
|
||||||
def test_memory_snapshot_capture():
|
def test_memory_snapshot_capture():
|
||||||
"""Smoke test of MemorySnapshot.capture()."""
|
"""Smoke test of MemorySnapshot.capture()."""
|
||||||
|
Loading…
Reference in New Issue
Block a user