mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' of github.com:invoke-ai/InvokeAI into feat/controlnet-control-modes
Only "real" conflicts were in: invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
This commit is contained in:
commit
c5faffc18b
17
README.md
17
README.md
@ -43,6 +43,23 @@ _Note: InvokeAI is rapidly evolving. Please use the
|
||||
[Issues](https://github.com/invoke-ai/InvokeAI/issues) tab to report bugs and make feature
|
||||
requests. Be sure to use the provided templates. They will help us diagnose issues faster._
|
||||
|
||||
## FOR DEVELOPERS - MIGRATING TO THE 3.0.0 MODELS FORMAT
|
||||
|
||||
The models directory and models.yaml have changed. To migrate to the
|
||||
new layout, please follow this recipe:
|
||||
|
||||
1. Run `python scripts/migrate_models_to_3.0.py <path_to_root_directory>`.
|
||||
|
||||
2. This will create a new models directory named `models-3.0` and a
|
||||
new config file named `models.yaml-3.0`, both in the current
|
||||
working directory. If you prefer to name them something else, pass
|
||||
the `--dest-directory` and/or `--dest-yaml` arguments.
|
||||
|
||||
3. Check that the new models directory and yaml file look ok.
|
||||
|
||||
4. Replace the existing directory and file, keeping backup copies just in
|
||||
case.
|
||||
|
||||
<div align="center">
|
||||
|
||||
![canvas preview](https://github.com/invoke-ai/InvokeAI/raw/main/docs/assets/canvas_preview.png)
|
||||
|
@ -67,7 +67,7 @@ title: Home
|
||||
implementation of Stable Diffusion, the open source text-to-image and
|
||||
image-to-image generator. It provides a streamlined process with various new
|
||||
features and options to aid the image generation process. It runs on Windows,
|
||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB or RAM.
|
||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||
|
||||
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
|
||||
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a
|
||||
|
@ -25,7 +25,7 @@ done
|
||||
|
||||
if [ -z "$PYTHON" ]; then
|
||||
echo "A suitable Python interpreter could not be found"
|
||||
echo "Please install Python 3.9 or higher before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
read -p "Press any key to exit"
|
||||
exit -1
|
||||
fi
|
||||
|
@ -2,8 +2,17 @@
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.board_images import (
|
||||
BoardImagesService,
|
||||
BoardImagesServiceDependencies,
|
||||
)
|
||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
@ -11,7 +20,6 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
@ -20,6 +28,7 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -57,7 +66,7 @@ class ApiDependencies:
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
@ -72,21 +81,49 @@ class ApiDependencies:
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
board_images = BoardImagesService(
|
||||
services=BoardImagesServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
services=ImageServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config, logger),
|
||||
model_manager=ModelManagerService(config,logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
|
69
invokeai/app/api/routers/board_images.py
Normal file
69
invokeai/app/api/routers/board_images.py
Normal file
@ -0,0 +1,69 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="create_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def create_board_image(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
"""Creates a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
||||
|
||||
@board_images_router.delete(
|
||||
"/",
|
||||
operation_id="remove_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def remove_board_image(
|
||||
board_id: str = Body(description="The id of the board"),
|
||||
image_name: str = Body(description="The name of the image to remove"),
|
||||
):
|
||||
"""Deletes a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
|
||||
@board_images_router.get(
|
||||
"/{board_id}",
|
||||
operation_id="list_board_images",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_board_images(
|
||||
board_id: str = Path(description="The id of the board"),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of boards per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images for a board"""
|
||||
|
||||
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
|
||||
board_id,
|
||||
)
|
||||
return results
|
||||
|
108
invokeai/app/api/routers/boards.py
Normal file
108
invokeai/app/api/routers/boards.py
Normal file
@ -0,0 +1,108 @@
|
||||
from typing import Optional, Union
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
"/",
|
||||
operation_id="create_board",
|
||||
responses={
|
||||
201: {"description": "The board was created successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def create_board(
|
||||
board_name: str = Query(description="The name of the board to create"),
|
||||
) -> BoardDTO:
|
||||
"""Creates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
|
||||
|
||||
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
|
||||
async def get_board(
|
||||
board_id: str = Path(description="The id of board to get"),
|
||||
) -> BoardDTO:
|
||||
"""Gets a board"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
|
||||
@boards_router.patch(
|
||||
"/{board_id}",
|
||||
operation_id="update_board",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The board was updated successfully",
|
||||
},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def update_board(
|
||||
board_id: str = Path(description="The id of board to update"),
|
||||
changes: BoardChanges = Body(description="The changes to apply to the board"),
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(
|
||||
board_id=board_id, changes=changes
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board")
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
) -> None:
|
||||
"""Deletes a board"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@boards_router.get(
|
||||
"/",
|
||||
operation_id="list_boards",
|
||||
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||
)
|
||||
async def list_boards(
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||
limit: Optional[int] = Query(
|
||||
default=None, description="The number of boards per page"
|
||||
),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
if all:
|
||||
return ApiDependencies.invoker.services.boards.get_all()
|
||||
elif offset is not None and limit is not None:
|
||||
return ApiDependencies.invoker.services.boards.get_many(
|
||||
offset,
|
||||
limit,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
|
||||
)
|
@ -221,6 +221,9 @@ async def list_images_with_metadata(
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None, description="The board id to filter by"
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
@ -232,6 +235,7 @@ async def list_images_with_metadata(
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
@ -1,13 +1,14 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
from typing import Annotated, Literal, Optional, Union, Dict
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -19,6 +20,15 @@ class VaeRepo(BaseModel):
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
model_name: str = Field(description="The name of the model")
|
||||
model_type: str = Field(description="The type of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['folder'] = 'folder'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
@ -29,12 +39,8 @@ class CkptModelInfo(ModelInfo):
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['diffusers'] = 'diffusers'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
class SafetensorsModelInfo(CkptModelInfo):
|
||||
format: Literal['safetensors'] = 'safetensors'
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
@ -56,7 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||
models: list[MODEL_CONFIGS]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@ -64,9 +70,16 @@ class ModelsList(BaseModel):
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models() -> ModelsList:
|
||||
async def list_models(
|
||||
base_model: Optional[BaseModelType] = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
@ -121,7 +134,7 @@ async def delete_model(model_name: str) -> None:
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error(f"Model not found")
|
||||
logger.error("Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images
|
||||
from .api.routers import sessions, models, images, boards, board_images
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
@ -116,6 +120,22 @@ def custom_openapi():
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
|
||||
if name in openapi_schema["components"]["schemas"]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
type="string",
|
||||
enum=list(v.value for v in model_config_format_enum),
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
@ -6,10 +6,7 @@ import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
from typing import (
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
from typing import Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
@ -26,23 +23,25 @@ from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||
create_system_graphs)
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
|
||||
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
||||
SortedHelpFormatter, add_graph_parsers, add_parsers)
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
|
||||
GraphInvocation, LibraryGraph,
|
||||
are_connection_types_compatible)
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
|
||||
@ -197,7 +196,6 @@ def invoke_all(context: CliContext):
|
||||
raise SessionError()
|
||||
|
||||
def invoke_cli():
|
||||
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
@ -208,7 +206,7 @@ def invoke_cli():
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile,"r")
|
||||
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
model_manager = ModelManagerService(config,logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
output_folder = config.output_path
|
||||
@ -258,8 +256,10 @@ def invoke_cli():
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
set_autocompleter(services)
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
|
@ -1,13 +1,15 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from contextlib import ExitStack
|
||||
import re
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from ...backend.prompting.conditioning import try_parse_legacy_blend
|
||||
from .model import ClipField
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
@ -40,7 +42,7 @@ class CompelInvocation(BaseInvocation):
|
||||
type: Literal["compel"] = "compel"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
model: str = Field(default="", description="Model to use")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -56,73 +58,74 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
|
||||
# TODO: load without model
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
pipeline = model["model"]
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
#use_full_precision = False
|
||||
|
||||
# TODO: redo TI when separate model loding implemented
|
||||
#textual_inversion_manager = TextualInversionManager(
|
||||
# tokenizer=tokenizer,
|
||||
# text_encoder=text_encoder,
|
||||
# full_precision=use_full_precision,
|
||||
#)
|
||||
|
||||
def load_huggingface_concepts(concepts: list[str]):
|
||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
||||
self.prompt,
|
||||
lambda concepts: load_huggingface_concepts(concepts),
|
||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_str
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer,\
|
||||
text_encoder_info as text_encoder,\
|
||||
ExitStack() as stack:
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
|
||||
if legacy_blend is not None:
|
||||
conjunction = legacy_blend
|
||||
else:
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
#print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
||||
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
|
@ -3,23 +3,27 @@
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
from diffusers import ControlNetModel
|
||||
from torch import Tensor
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||
ResourceOrigin)
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
|
||||
from ...backend.generator import Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
|
||||
import re
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from .model import UNetField, VaeField
|
||||
from .compel import ConditioningField
|
||||
from contextlib import contextmanager, ExitStack, ContextDecorator
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
@ -28,114 +32,48 @@ DEFAULT_INFILL_METHOD = (
|
||||
)
|
||||
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
from .latent import get_scheduler
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
},
|
||||
},
|
||||
}
|
||||
class OldModelContext(ContextDecorator):
|
||||
model: StableDiffusionGeneratorPipeline
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def __enter__(self):
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
class OldModelInfo:
|
||||
name: str
|
||||
hash: str
|
||||
context: OldModelContext
|
||||
|
||||
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
|
||||
self.name = name
|
||||
self.hash = hash
|
||||
self.context = OldModelContext(
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
class InpaintInvocation(BaseInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["txt2img"] = "txt2img"
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control_model: Optional[str] = Field(default=None, description="The control model to use")
|
||||
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get_pil_image(self.control_image.image_name)
|
||||
)
|
||||
# loading controlnet model
|
||||
if (self.control_model is None or self.control_model==''):
|
||||
control_model = None
|
||||
else:
|
||||
# FIXME: change this to dropdown menu?
|
||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||
control_model = ControlNetModel.from_pretrained(self.control_model,
|
||||
torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
txt2img = Txt2Img(model, control_model=control_model)
|
||||
outputs = txt2img.generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
control_image=control_image,
|
||||
**self.dict(
|
||||
exclude={"prompt", "control_image" }
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generate_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageToImageInvocation(TextToImageInvocation):
|
||||
"""Generates an image using img2img."""
|
||||
|
||||
type: Literal["img2img"] = "img2img"
|
||||
unet: UNetField = Field(default=None, description="UNet model")
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
@ -147,72 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
@ -255,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
description="The amount by which to replace masked areas with latent noise",
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -268,6 +148,49 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_conditioning(self, context):
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
return (uc, c, extra_conditioning_info)
|
||||
|
||||
@contextmanager
|
||||
def load_model_old_way(self, context, scheduler):
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
|
||||
#unet = unet_info.context.model
|
||||
#vae = vae_info.context.model
|
||||
|
||||
with ExitStack() as stack:
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with vae_info as vae,\
|
||||
unet_info as unet,\
|
||||
ModelPatcher.apply_lora_unet(unet, loras):
|
||||
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
pipeline = StableDiffusionGeneratorPipeline(
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if dtype == torch.float16 else "float32",
|
||||
execution_device=device,
|
||||
)
|
||||
|
||||
yield OldModelInfo(
|
||||
name=self.unet.unet.model_name,
|
||||
hash="<NO-HASH>",
|
||||
model=pipeline,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
@ -280,25 +203,31 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
conditioning = self.get_conditioning(context)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
with self.load_model_old_way(context, scheduler) as model:
|
||||
outputs = Inpaint(model).generate(
|
||||
conditioning=conditioning,
|
||||
scheduler=scheduler,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
@ -1,43 +1,36 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
import einops
|
||||
from typing import Literal, Optional, Union, List
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from compel import Compel
|
||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
||||
import einops
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from .controlnet_image_processors import ControlField
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_file_storage import ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .compel import ConditioningField
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline, ControlNetModel
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
||||
image_resized_to_grid_as_tensor)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||
PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
@ -90,10 +83,17 @@ SAMPLER_NAME_VALUES = Literal[
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
@ -128,7 +128,6 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||
return x
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
@ -176,10 +175,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
@ -219,44 +218,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_name = model_info['model_name']
|
||||
model_hash = model_info['hash']
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model.scheduler = get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
# if isinstance(model, DiffusionPipeline):
|
||||
# for component in [model.unet, model.vae]:
|
||||
# configure_model_padding(component,
|
||||
# self.seamless,
|
||||
# self.seamless_axes
|
||||
# )
|
||||
# else:
|
||||
# configure_model_padding(model,
|
||||
# self.seamless,
|
||||
# self.seamless_axes
|
||||
# )
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
@ -268,16 +233,56 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
scheduler,
|
||||
|
||||
# for ddim scheduler
|
||||
eta=0.0, #ddim_eta
|
||||
|
||||
# for ancestral and sde schedulers
|
||||
generator=torch.Generator(device=uc.device).manual_seed(0),
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
def prep_control_data(self,
|
||||
context: InvocationContext,
|
||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
#configure_model_padding(
|
||||
# unet,
|
||||
# self.seamless,
|
||||
# self.seamless_axes,
|
||||
#)
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self):
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self):
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return StableDiffusionGeneratorPipeline(
|
||||
vae=FakeVae(), # TODO: oh...
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if unet.dtype == torch.float16 else "float32",
|
||||
)
|
||||
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
|
||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||
control_height_resize = latents_shape[2] * 8
|
||||
control_width_resize = latents_shape[3] * 8
|
||||
@ -354,23 +359,38 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
with unet_info as unet,\
|
||||
ExitStack() as stack:
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -379,7 +399,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
@ -413,32 +432,52 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||
with unet_info as unet,\
|
||||
ExitStack() as stack:
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=unet.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=unet.device,
|
||||
)
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -456,16 +495,14 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@ -473,37 +510,45 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
|
||||
# what happened to metadata?
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
with vae_info as vae:
|
||||
if self.tiled or context.services.configuration.tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# new (post Image service refactor) way of using services to save image
|
||||
# and gnenerate unique image_name
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
)
|
||||
with torch.inference_mode():
|
||||
# copied from diffusers pipeline
|
||||
latents = latents / vae.config.scaling_factor
|
||||
image = vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
@ -579,14 +624,14 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@ -597,20 +642,30 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# )
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = model.non_noised_latents_from_image(
|
||||
image_tensor,
|
||||
device=model._model_group.device_for(model.unet),
|
||||
dtype=model.unet.dtype,
|
||||
)
|
||||
with vae_info as vae:
|
||||
if self.tiled:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(
|
||||
dtype=vae.dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
|
||||
latents = 0.18215 * latents
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, latents)
|
||||
|
217
invokeai/app/invocations/model.py
Normal file
217
invokeai/app/invocations/model.py
Normal file
@ -0,0 +1,217 @@
|
||||
from typing import Literal, Optional, Union, List
|
||||
from pydantic import BaseModel, Field
|
||||
import copy
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
|
||||
class ModelInfo(BaseModel):
    # Identifies a model (and optionally one of its submodels). Instances are
    # expanded with `.dict()` into keyword arguments for the model manager's
    # `get_model(...)` by the invocations that consume these fields.
    # NOTE: deliberately a comment rather than a docstring, so the generated
    # pydantic schema does not gain a new model-level description.

    # The descriptions below used to be copy-pasted from `submodel`
    # ("Info to load submodel"), which was misleading in the generated schema.
    model_name: str = Field(description="Name of the model")
    base_model: BaseModelType = Field(description="Base model")
    model_type: ModelType = Field(description="Type of the model")
    # Optional: not every model type is split into submodels.
    submodel: Optional[SubModelType] = Field(description="Info to load submodel")
|
||||
|
||||
class LoraInfo(ModelInfo):
    # A ModelInfo extended with the weight to use for that LoRA.
    weight: float = Field(
        description="Lora's weight which to use when apply to model",
    )
|
||||
|
||||
class UNetField(BaseModel):
    # Bundles the model references needed for a UNet: the UNet submodel itself,
    # its scheduler, and the list of LoRAs to apply when the model is loaded.
    unet: ModelInfo = Field(
        description="Info to load unet submodel",
    )
    scheduler: ModelInfo = Field(
        description="Info to load scheduler submodel",
    )
    loras: List[LoraInfo] = Field(
        description="Loras to apply on model loading",
    )
|
||||
|
||||
class ClipField(BaseModel):
    # Bundles the model references for the text side: tokenizer, text encoder,
    # and the list of LoRAs to apply when the model is loaded.
    tokenizer: ModelInfo = Field(
        description="Info to load tokenizer submodel",
    )
    text_encoder: ModelInfo = Field(
        description="Info to load text_encoder submodel",
    )
    loras: List[LoraInfo] = Field(
        description="Loras to apply on model loading",
    )
|
||||
|
||||
class VaeField(BaseModel):
    # Wrapper around the single model reference needed to load a VAE.
    # TODO: better naming?
    vae: ModelInfo = Field(
        description="Info to load vae submodel",
    )
|
||||
|
||||
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    # One field per group of submodels: unet (+scheduler), clip
    # (tokenizer + text encoder) and vae. All three default to None.

    #fmt: off
    type: Literal["model_loader_output"] = "model_loader_output"

    unet: UNetField = Field(
        default=None,
        description="UNet submodel",
    )
    clip: ClipField = Field(
        default=None,
        description="Tokenizer and text_encoder submodels",
    )
    vae: VaeField = Field(
        default=None,
        description="Vae submodel",
    )
    #fmt: on
|
||||
|
||||
|
||||
class PipelineModelField(BaseModel):
    """Pipeline model field"""

    # The (name, base model) pair that selects which pipeline model to load.
    model_name: str = Field(
        description="Name of the model",
    )
    base_model: BaseModelType = Field(
        description="Base model",
    )
|
||||
|
||||
|
||||
class PipelineModelLoaderInvocation(BaseInvocation):
    """Loads a pipeline model, outputting its submodels."""

    type: Literal["pipeline_model_loader"] = "pipeline_model_loader"

    model: PipelineModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["model", "loader"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
        # Verifies the pipeline model is known to the model manager, then
        # returns references (not loaded weights) to each of its submodels.
        # Raises a plain Exception when the model is unknown.
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.Pipeline

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        # TODO: only the pipeline as a whole is checked; the existence of the
        # individual submodels (tokenizer, text_encoder, unet, ...) is not
        # verified here.

        def submodel_info(submodel: SubModelType) -> ModelInfo:
            # Every submodel reference shares the pipeline's name/base/type and
            # differs only in which submodel it points at.
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

        return ModelLoaderOutput(
            unet=UNetField(
                unet=submodel_info(SubModelType.UNet),
                scheduler=submodel_info(SubModelType.Scheduler),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=submodel_info(SubModelType.Tokenizer),
                text_encoder=submodel_info(SubModelType.TextEncoder),
                loras=[],
            ),
            vae=VaeField(
                vae=submodel_info(SubModelType.Vae),
            ),
        )
|
||||
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    # Both fields are optional and default to None.

    #fmt: off
    type: Literal["lora_loader_output"] = "lora_loader_output"

    unet: Optional[UNetField] = Field(
        default=None,
        description="UNet submodel",
    )
    clip: Optional[ClipField] = Field(
        default=None,
        description="Tokenizer and text_encoder submodels",
    )
    #fmt: on
|
||||
|
||||
class LoraLoaderInvocation(BaseInvocation):
    """Apply selected lora to unet and text_encoder."""

    type: Literal["lora_loader"] = "lora_loader"

    lora_name: str = Field(description="Lora model name")
    weight: float = Field(default=0.75, description="With what weight to apply lora")

    unet: Optional[UNetField] = Field(description="UNet model for applying lora")
    clip: Optional[ClipField] = Field(description="Clip model for applying lora")

    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
        # Returns copies of the incoming unet/clip fields with this LoRA
        # appended to their `loras` lists; the inputs are never mutated.
        # Raises a plain Exception if the LoRA is unknown or already applied.

        # `ModelInfo.base_model` is a required field, and this invocation has
        # no base-model input of its own, so it is taken from the model the
        # LoRA is being applied to.
        unet_base = self.unet.unet.base_model if self.unet is not None else None
        clip_base = self.clip.text_encoder.base_model if self.clip is not None else None
        base_model = unet_base if unet_base is not None else clip_base

        # With neither a unet nor a clip connected there is no base model to
        # look the LoRA up under (and nothing to apply it to), so the existence
        # check is skipped and an empty output is returned below.
        # NOTE(review): `base_model=` mirrors the `model_exists` call in
        # PipelineModelLoaderInvocation, and `ModelType.Lora` replaces the
        # un-imported `SDModelType.Lora` - confirm both against the model manager.
        if base_model is not None and not context.services.model_manager.model_exists(
            model_name=self.lora_name,
            base_model=base_model,
            model_type=ModelType.Lora,
        ):
            raise Exception(f"Unknown lora name: {self.lora_name}!")

        if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
            raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")

        if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
            raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")

        output = LoraLoaderOutput()

        if self.unet is not None:
            # Deep copy so the upstream node's field is left untouched.
            output.unet = copy.deepcopy(self.unet)
            output.unet.loras.append(
                LoraInfo(
                    model_name=self.lora_name,
                    base_model=unet_base,
                    model_type=ModelType.Lora,
                    submodel=None,
                    weight=self.weight,
                )
            )

        if self.clip is not None:
            output.clip = copy.deepcopy(self.clip)
            output.clip.loras.append(
                LoraInfo(
                    model_name=self.lora_name,
                    base_model=clip_base,
                    model_type=ModelType.Lora,
                    submodel=None,
                    weight=self.weight,
                )
            )

        return output
|
||||
|
@ -1,14 +0,0 @@
|
||||
from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
    """Fetch `model_name` from the manager, falling back to the default model.

    An empty `model_name`, or one the manager reports as valid, is passed
    straight to `get_model`. A non-empty name that is not valid logs a warning
    and `get_model()` is called with no argument instead.
    """
    log = model_manager.logger

    # Guard clause: nothing to fall back from when the name is empty or valid.
    if not model_name or model_manager.valid_model(model_name):
        return model_manager.get_model(model_name)

    default_model_name = model_manager.default_model()
    log.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
    return model_manager.get_model()
|
254
invokeai/app/services/board_image_record_storage.py
Normal file
254
invokeai/app/services/board_image_record_storage.py
Normal file
@ -0,0 +1,254 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Union, cast
|
||||
from invokeai.app.services.board_record_storage import BoardRecord
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
|
||||
class BoardImageRecordStorageBase(ABC):
    """Abstract base class for the one-to-many board-image relationship record storage."""

    @abstractmethod
    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """Associate the image `image_name` with the board `board_id`."""

    @abstractmethod
    def remove_image_from_board(self, board_id: str, image_name: str) -> None:
        """Remove the association between `image_name` and `board_id`."""

    @abstractmethod
    def get_images_for_board(self, board_id: str) -> OffsetPaginatedResults[ImageRecord]:
        """Return the image records that belong to the board `board_id`."""

    @abstractmethod
    def get_board_for_image(self, image_name: str) -> Union[str, None]:
        """Return the id of the board holding `image_name`, or None if it has none."""

    @abstractmethod
    def get_image_count_for_board(self, board_id: str) -> int:
        """Return how many images belong to the board `board_id`."""
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
    """SQLite implementation of the board-image relationship storage.

    Uses a single connection and cursor; every operation that touches them is
    performed while holding `self._lock`.
    """

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, filename: str) -> None:
        """Open (or create) the database at `filename` and ensure the schema exists."""
        super().__init__()
        self._filename = filename
        self._conn = sqlite3.connect(filename, check_same_thread=False)
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        with self._lock:
            # Enable foreign keys
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()

    def _create_tables(self) -> None:
        """Creates the `board_images` junction table."""

        # Create the `board_images` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS board_images (
                board_id TEXT NOT NULL,
                image_name TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between boards and images using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (image_name),
                FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
                FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
            );
            """
        )

        # Add index for board id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
            """
        )

        # Add index for board id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
            """
        )

        # Add trigger for `updated_at`.
        # The row is located by its primary key (image_name) using the NEW
        # value. Matching on `old.board_id` missed the row whenever the update
        # moved the image to a different board, which is exactly what
        # `add_image_to_board` does on conflict.
        # NOTE: because of IF NOT EXISTS, a database that already has the old
        # trigger keeps it until that trigger is dropped.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
            AFTER UPDATE
            ON board_images FOR EACH ROW
            BEGIN
                UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE image_name = new.image_name;
            END;
            """
        )

    def add_image_to_board(
        self,
        board_id: str,
        image_name: str,
    ) -> None:
        """Put `image_name` on `board_id`, moving it if it is already on another board.

        Re-raises any `sqlite3.Error` after rolling back.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT INTO board_images (board_id, image_name)
                    VALUES (?, ?)
                    ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
                    """,
                    (board_id, image_name, board_id),
                )
                self._conn.commit()
            except sqlite3.Error:
                self._conn.rollback()
                raise

    def remove_image_from_board(
        self,
        board_id: str,
        image_name: str,
    ) -> None:
        """Delete the (`board_id`, `image_name`) row, if present.

        Re-raises any `sqlite3.Error` after rolling back.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    DELETE FROM board_images
                    WHERE board_id = ? AND image_name = ?;
                    """,
                    (board_id, image_name),
                )
                self._conn.commit()
            except sqlite3.Error:
                self._conn.rollback()
                raise

    def get_images_for_board(
        self,
        board_id: str,
        offset: int = 0,
        limit: int = 10,
    ) -> OffsetPaginatedResults[ImageRecord]:
        """Return the images on `board_id`, most recently updated association first.

        `offset` and `limit` are only echoed back in the result; the query
        itself returns every image on the board. `total` is the number of
        images on this board.
        """
        # TODO: this isn't paginated yet?
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT images.*
                    FROM board_images
                    INNER JOIN images ON board_images.image_name = images.image_name
                    WHERE board_images.board_id = ?
                    ORDER BY board_images.updated_at DESC;
                    """,
                    (board_id,),
                )
                result = cast(list[sqlite3.Row], self._cursor.fetchall())
                images = [deserialize_image_record(dict(r)) for r in result]

                # Count with the same join and filter as the query above, so
                # the total describes this board rather than the whole
                # `images` table.
                self._cursor.execute(
                    """--sql
                    SELECT COUNT(*)
                    FROM board_images
                    INNER JOIN images ON board_images.image_name = images.image_name
                    WHERE board_images.board_id = ?;
                    """,
                    (board_id,),
                )
                count = cast(int, self._cursor.fetchone()[0])
            except sqlite3.Error:
                self._conn.rollback()
                raise

        return OffsetPaginatedResults(
            items=images, offset=offset, limit=limit, total=count
        )

    def get_board_for_image(
        self,
        image_name: str,
    ) -> Union[str, None]:
        """Return the board id `image_name` is on, or None if it is on no board."""
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT board_id
                    FROM board_images
                    WHERE image_name = ?;
                    """,
                    (image_name,),
                )
                result = self._cursor.fetchone()
            except sqlite3.Error:
                self._conn.rollback()
                raise

        if result is None:
            return None
        return cast(str, result[0])

    def get_image_count_for_board(self, board_id: str) -> int:
        """Return the number of `board_images` rows for `board_id`."""
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT COUNT(*) FROM board_images WHERE board_id = ?;
                    """,
                    (board_id,),
                )
                return cast(int, self._cursor.fetchone()[0])
            except sqlite3.Error:
                self._conn.rollback()
                raise
|
142
invokeai/app/services/board_images.py
Normal file
142
invokeai/app/services/board_images.py
Normal file
@ -0,0 +1,142 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import List, Union
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardRecord,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardImagesServiceABC(ABC):
    """High-level service for board-image relationship management."""

    @abstractmethod
    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """Associate the image `image_name` with the board `board_id`."""

    @abstractmethod
    def remove_image_from_board(self, board_id: str, image_name: str) -> None:
        """Remove the association between `image_name` and `board_id`."""

    @abstractmethod
    def get_images_for_board(self, board_id: str) -> OffsetPaginatedResults[ImageDTO]:
        """Return the image DTOs that belong to the board `board_id`."""

    @abstractmethod
    def get_board_for_image(self, image_name: str) -> Union[str, None]:
        """Return the id of the board holding `image_name`, or None if it has none."""
|
||||
|
||||
|
||||
class BoardImagesServiceDependencies:
    """Service dependencies for the BoardImagesService."""

    board_image_records: BoardImageRecordStorageBase
    board_records: BoardRecordStorageBase
    image_records: ImageRecordStorageBase
    urls: UrlServiceBase
    logger: Logger

    def __init__(
        self,
        board_image_record_storage: BoardImageRecordStorageBase,
        image_record_storage: ImageRecordStorageBase,
        board_record_storage: BoardRecordStorageBase,
        url: UrlServiceBase,
        logger: Logger,
    ):
        # Plain container: each constructor argument is stored unchanged.
        # Record storages first, in the order the attributes are declared above.
        self.board_image_records = board_image_record_storage
        self.board_records = board_record_storage
        self.image_records = image_record_storage
        # Supporting services.
        self.urls = url
        self.logger = logger
|
||||
|
||||
|
||||
class BoardImagesService(BoardImagesServiceABC):
    """Board-image service that delegates to the injected record storage and URL service."""

    _services: BoardImagesServiceDependencies

    def __init__(self, services: BoardImagesServiceDependencies):
        self._services = services

    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """Forward to the board-image record storage."""
        storage = self._services.board_image_records
        storage.add_image_to_board(board_id, image_name)

    def remove_image_from_board(self, board_id: str, image_name: str) -> None:
        """Forward to the board-image record storage."""
        storage = self._services.board_image_records
        storage.remove_image_from_board(board_id, image_name)

    def get_images_for_board(self, board_id: str) -> OffsetPaginatedResults[ImageDTO]:
        """Fetch the board's image records and convert each one to an ImageDTO.

        The offset/limit/total of the storage result are passed through as-is.
        """
        page = self._services.board_image_records.get_images_for_board(board_id)
        urls = self._services.urls

        dtos = [
            image_record_to_dto(
                record,
                urls.get_image_url(record.image_name),
                # second positional argument True: presumably selects the
                # thumbnail URL - confirm against UrlServiceBase
                urls.get_image_url(record.image_name, True),
                board_id,
            )
            for record in page.items
        ]

        return OffsetPaginatedResults[ImageDTO](
            items=dtos,
            offset=page.offset,
            limit=page.limit,
            total=page.total,
        )

    def get_board_for_image(self, image_name: str) -> Union[str, None]:
        """Return the storage's answer for which board (if any) holds `image_name`."""
        return self._services.board_image_records.get_board_for_image(image_name)
|
||||
|
||||
|
||||
def board_record_to_dto(
    board_record: BoardRecord, cover_image_name: Union[str, None], image_count: int
) -> BoardDTO:
    """Converts a board record to a board DTO.

    The record's own `cover_image_name` is excluded and replaced by the
    `cover_image_name` argument; `image_count` is added on top.
    """
    # `Union[str, None]` instead of `str | None`: the annotation is evaluated
    # when the function is defined, and the `|` form raises TypeError on
    # Python versions before 3.10. It also matches the rest of this module.
    return BoardDTO(
        **board_record.dict(exclude={'cover_image_name'}),
        cover_image_name=cover_image_name,
        image_count=image_count,
    )
|
329
invokeai/app/services/board_record_storage.py
Normal file
329
invokeai/app/services/board_record_storage.py
Normal file
@ -0,0 +1,329 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import (
|
||||
BoardRecord,
|
||||
deserialize_board_record,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
    # The set of board attributes that may be changed by an update. Both are
    # optional; unknown extra fields are rejected (Extra.forbid).
    board_name: Optional[str] = Field(
        description="The board's new name.",
    )
    cover_image_name: Optional[str] = Field(
        description="The name of the board's new cover image.",
    )
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
    """Raised when a board record is not found."""

    def __init__(self, message="Board record not found"):
        # The default text is used whenever the caller supplies no message.
        super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
    """Raised when a board record cannot be saved."""

    def __init__(self, message="Board record not saved"):
        # The default text is used whenever the caller supplies no message.
        super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
    """Raised when a board record cannot be deleted."""

    def __init__(self, message="Board record not deleted"):
        # The default text is used whenever the caller supplies no message.
        super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordStorageBase(ABC):
    """Low-level service responsible for interfacing with the board record store."""

    @abstractmethod
    def delete(self, board_id: str) -> None:
        """Delete the board record identified by `board_id`."""

    @abstractmethod
    def save(self, board_name: str) -> BoardRecord:
        """Store a new board record named `board_name` and return it."""

    @abstractmethod
    def get(self, board_id: str) -> BoardRecord:
        """Return the board record identified by `board_id`."""

    @abstractmethod
    def update(self, board_id: str, changes: BoardChanges) -> BoardRecord:
        """Apply `changes` to the board `board_id` and return the resulting record."""

    @abstractmethod
    def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardRecord]:
        """Return one page of board records, selected by `offset` and `limit`."""

    @abstractmethod
    def get_all(self) -> list[BoardRecord]:
        """Return every board record."""
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
    """SQLite implementation of the board record storage.

    Uses a single connection and cursor; every operation that touches them is
    performed while holding `self._lock`. The lock is not re-entrant, so
    methods that finish with `self.get(...)` call it only after releasing it.
    """

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, filename: str) -> None:
        """Open (or create) the database at `filename` and ensure the schema exists."""
        super().__init__()
        self._filename = filename
        self._conn = sqlite3.connect(filename, check_same_thread=False)
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        with self._lock:
            # Enable foreign keys
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()

    def _create_tables(self) -> None:
        """Creates the `boards` table, its `created_at` index and its `updated_at` trigger."""

        # Create the `boards` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS boards (
                board_id TEXT NOT NULL PRIMARY KEY,
                board_name TEXT NOT NULL,
                cover_image_name TEXT,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
            );
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
            """
        )

        # Add trigger for `updated_at`.
        # Uses the same millisecond STRFTIME expression as the column default,
        # so `updated_at` keeps one format and resolution before and after an
        # update (`current_timestamp` only has whole seconds).
        # NOTE: because of IF NOT EXISTS, a database that already has the old
        # trigger keeps it until that trigger is dropped.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
            AFTER UPDATE
            ON boards FOR EACH ROW
            BEGIN
                UPDATE boards SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE board_id = old.board_id;
            END;
            """
        )

    def delete(self, board_id: str) -> None:
        """Delete the board `board_id`.

        Raises:
            BoardRecordDeleteException: on any error; the transaction is rolled back.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    DELETE FROM boards
                    WHERE board_id = ?;
                    """,
                    (board_id,),
                )
                self._conn.commit()
            except Exception as e:
                # sqlite3.Error is an Exception subclass, so a single handler
                # covers both of the original (identical) except arms.
                self._conn.rollback()
                raise BoardRecordDeleteException from e

    def save(
        self,
        board_name: str,
    ) -> BoardRecord:
        """Insert a board named `board_name` under a fresh UUID and return its record.

        Raises:
            BoardRecordSaveException: on a sqlite error; the transaction is rolled back.
        """
        board_id = str(uuid.uuid4())
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    INSERT OR IGNORE INTO boards (board_id, board_name)
                    VALUES (?, ?);
                    """,
                    (board_id, board_name),
                )
                self._conn.commit()
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BoardRecordSaveException from e
        return self.get(board_id)

    def get(
        self,
        board_id: str,
    ) -> BoardRecord:
        """Return the record for `board_id`.

        Raises:
            BoardRecordNotFoundException: if no such row exists, or on a sqlite error.
        """
        with self._lock:
            try:
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM boards
                    WHERE board_id = ?;
                    """,
                    (board_id,),
                )
                result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BoardRecordNotFoundException from e

        if result is None:
            raise BoardRecordNotFoundException
        return BoardRecord(**dict(result))

    def update(
        self,
        board_id: str,
        changes: BoardChanges,
    ) -> BoardRecord:
        """Apply the non-None fields of `changes` to `board_id` and return the updated record.

        Raises:
            BoardRecordSaveException: on a sqlite error; the transaction is rolled back.
        """
        with self._lock:
            try:
                # Change the name of a board
                if changes.board_name is not None:
                    self._cursor.execute(
                        """--sql
                        UPDATE boards
                        SET board_name = ?
                        WHERE board_id = ?;
                        """,
                        (changes.board_name, board_id),
                    )

                # Change the cover image of a board
                if changes.cover_image_name is not None:
                    self._cursor.execute(
                        """--sql
                        UPDATE boards
                        SET cover_image_name = ?
                        WHERE board_id = ?;
                        """,
                        (changes.cover_image_name, board_id),
                    )

                self._conn.commit()
            except sqlite3.Error as e:
                self._conn.rollback()
                raise BoardRecordSaveException from e
        return self.get(board_id)

    def get_many(
        self,
        offset: int = 0,
        limit: int = 10,
    ) -> OffsetPaginatedResults[BoardRecord]:
        """Return one page of boards, newest first, plus the total number of boards."""
        with self._lock:
            try:
                # Get the requested page of boards
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM boards
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?;
                    """,
                    (limit, offset),
                )
                result = cast(list[sqlite3.Row], self._cursor.fetchall())
                boards = [deserialize_board_record(dict(r)) for r in result]

                # Get the total number of boards
                self._cursor.execute(
                    """--sql
                    SELECT COUNT(*)
                    FROM boards
                    WHERE 1=1;
                    """
                )
                count = cast(int, self._cursor.fetchone()[0])
            except sqlite3.Error:
                self._conn.rollback()
                raise

        return OffsetPaginatedResults[BoardRecord](
            items=boards, offset=offset, limit=limit, total=count
        )

    def get_all(
        self,
    ) -> list[BoardRecord]:
        """Return every board, newest first."""
        with self._lock:
            try:
                # Get all the boards
                self._cursor.execute(
                    """--sql
                    SELECT *
                    FROM boards
                    ORDER BY created_at DESC
                    """
                )
                result = cast(list[sqlite3.Row], self._cursor.fetchall())
            except sqlite3.Error:
                self._conn.rollback()
                raise

        return [deserialize_board_record(dict(r)) for r in result]
|
185
invokeai/app/services/boards.py
Normal file
185
invokeai/app/services/boards.py
Normal file
@ -0,0 +1,185 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from logging import Logger
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_images import board_record_to_dto
|
||||
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardChanges,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardServiceABC(ABC):
    """Abstract interface of the high-level board management service."""

    @abstractmethod
    def create(self, board_name: str) -> BoardDTO:
        """Creates a board named `board_name` and returns it as a DTO."""
        ...

    @abstractmethod
    def get_dto(self, board_id: str) -> BoardDTO:
        """Gets the board identified by `board_id` as a DTO."""
        ...

    @abstractmethod
    def update(self, board_id: str, changes: BoardChanges) -> BoardDTO:
        """Applies `changes` to the board `board_id` and returns the updated DTO."""
        ...

    @abstractmethod
    def delete(self, board_id: str) -> None:
        """Deletes the board identified by `board_id`."""
        ...

    @abstractmethod
    def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
        """Gets one page of boards, starting at `offset` and holding at most `limit` items."""
        ...

    @abstractmethod
    def get_all(self) -> list[BoardDTO]:
        """Gets every board."""
        ...
|
||||
|
||||
|
||||
class BoardServiceDependencies:
    """Container for the services that the BoardService is built on."""

    # Declared attribute types; the values are assigned in __init__.
    board_image_records: BoardImageRecordStorageBase
    board_records: BoardRecordStorageBase
    image_records: ImageRecordStorageBase
    urls: UrlServiceBase
    logger: Logger

    def __init__(
        self,
        board_image_record_storage: BoardImageRecordStorageBase,
        image_record_storage: ImageRecordStorageBase,
        board_record_storage: BoardRecordStorageBase,
        url: UrlServiceBase,
        logger: Logger,
    ):
        """Stores each dependency under the attribute name declared on the class."""
        # Storage back-ends.
        self.board_image_records = board_image_record_storage
        self.board_records = board_record_storage
        self.image_records = image_record_storage
        # Supporting services.
        self.urls = url
        self.logger = logger
|
||||
|
||||
|
||||
class BoardService(BoardServiceABC):
    """Board management service backed by the record-storage services.

    Every DTO returned for an existing board carries the name of the board's
    most recent image (used as its cover) and the number of images on it.
    """

    _services: BoardServiceDependencies

    def __init__(self, services: BoardServiceDependencies):
        self._services = services

    def _build_dto(self, board_record) -> BoardDTO:
        """Builds the DTO for `board_record`, looking up its cover image and image count."""
        # NOTE(review): both lookups use the record's own board_id; this assumes
        # the record store returns a record whose id equals the requested one.
        board_id = board_record.board_id
        cover_image = self._services.image_records.get_most_recent_image_for_board(
            board_id
        )
        cover_image_name = cover_image.image_name if cover_image else None
        image_count = self._services.board_image_records.get_image_count_for_board(
            board_id
        )
        return board_record_to_dto(board_record, cover_image_name, image_count)

    def create(
        self,
        board_name: str,
    ) -> BoardDTO:
        """Saves a new board record and returns its DTO."""
        board_record = self._services.board_records.save(board_name)
        # The DTO is built with no cover image and a zero image count,
        # without querying the image tables.
        return board_record_to_dto(board_record, None, 0)

    def get_dto(self, board_id: str) -> BoardDTO:
        """Fetches the board `board_id` and returns its DTO."""
        board_record = self._services.board_records.get(board_id)
        return self._build_dto(board_record)

    def update(
        self,
        board_id: str,
        changes: BoardChanges,
    ) -> BoardDTO:
        """Applies `changes` to the board `board_id` and returns the updated DTO."""
        board_record = self._services.board_records.update(board_id, changes)
        return self._build_dto(board_record)

    def delete(self, board_id: str) -> None:
        """Deletes the board record `board_id`."""
        self._services.board_records.delete(board_id)

    def get_many(
        self, offset: int = 0, limit: int = 10
    ) -> OffsetPaginatedResults[BoardDTO]:
        """Returns one page of board DTOs.

        `total` is the number of boards reported by the record store for the
        whole table, not the number of items on this page, so that clients can
        tell whether more pages exist.
        """
        board_records = self._services.board_records.get_many(offset, limit)
        board_dtos = [self._build_dto(r) for r in board_records.items]

        return OffsetPaginatedResults[BoardDTO](
            items=board_dtos,
            offset=offset,
            limit=limit,
            # Forward the store's count; len(board_dtos) is capped at `limit`.
            total=board_records.total,
        )

    def get_all(self) -> list[BoardDTO]:
        """Returns the DTOs of every board."""
        board_records = self._services.board_records.get_all()
        return [self._build_dto(r) for r in board_records]
|
@ -15,10 +15,7 @@ InvokeAI:
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
embedding_dir: embeddings
|
||||
lora_dir: loras
|
||||
autoconvert_dir: null
|
||||
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
|
||||
Models:
|
||||
model: stable-diffusion-1.5
|
||||
embeddings: true
|
||||
@ -171,7 +168,7 @@ from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
DB_FILE = Path('invokeai.db')
|
||||
@ -374,23 +371,19 @@ setting environment variables INVOKEAI_<setting>.
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
|
||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
|
||||
controlnet_dir : Path = Field(default="controlnets", description='Path to directory of ControlNet models.', category='Paths')
|
||||
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
@ -492,39 +485,11 @@ setting environment variables INVOKEAI_<setting>.
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def cache_dir(self)->Path:
|
||||
'''
|
||||
Path to the global cache directory for HuggingFace hub-managed models
|
||||
'''
|
||||
return self.models_dir / "hub"
|
||||
|
||||
@property
|
||||
def models_dir(self)->Path:
|
||||
def models_path(self)->Path:
|
||||
'''
|
||||
Path to the models directory
|
||||
'''
|
||||
return self._resolve("models")
|
||||
|
||||
@property
|
||||
def embedding_path(self)->Path:
|
||||
'''
|
||||
Path to the textual inversion embeddings directory.
|
||||
'''
|
||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
||||
|
||||
@property
|
||||
def lora_path(self)->Path:
|
||||
'''
|
||||
Path to the LoRA models directory.
|
||||
'''
|
||||
return self._resolve(self.lora_dir) if self.lora_dir else None
|
||||
|
||||
@property
|
||||
def controlnet_path(self)->Path:
|
||||
'''
|
||||
Path to the controlnet models directory.
|
||||
'''
|
||||
return self._resolve(self.controlnet_dir) if self.controlnet_dir else None
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
@ -533,13 +498,6 @@ setting environment variables INVOKEAI_<setting>.
|
||||
'''
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
@property
|
||||
def gfpgan_model_path(self)->Path:
|
||||
'''
|
||||
Path to the GFPGAN model.
|
||||
'''
|
||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self)->bool:
|
||||
|
@ -3,7 +3,8 @@
|
||||
from typing import Any
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@ -101,3 +102,53 @@ class EventServiceBase:
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_started(
    self,
    graph_execution_state_id: str,
    node: dict,
    source_node_id: str,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel: SubModelType,
) -> None:
    """Emits the `model_load_started` session event when a model is requested.

    All arguments are forwarded unchanged as the event payload.
    """
    payload = {
        "graph_execution_state_id": graph_execution_state_id,
        "node": node,
        "source_node_id": source_node_id,
        "model_name": model_name,
        "base_model": base_model,
        "model_type": model_type,
        "submodel": submodel,
    }
    self.__emit_session_event(event_name="model_load_started", payload=payload)
|
||||
|
||||
def emit_model_load_completed(
    self,
    graph_execution_state_id: str,
    node: dict,
    source_node_id: str,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel: SubModelType,
    model_info: ModelInfo,
) -> None:
    """Emits the `model_load_completed` session event once a model is correctly loaded.

    All arguments, including `model_info`, are forwarded unchanged as the
    event payload.
    """
    payload = {
        "graph_execution_state_id": graph_execution_state_id,
        "node": node,
        "source_node_id": source_node_id,
        "model_name": model_name,
        "base_model": base_model,
        "model_type": model_type,
        "submodel": submodel,
        "model_info": model_info,
    }
    self.__emit_session_event(event_name="model_load_completed", payload=payload)
|
||||
|
@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
    """Gets the most recent image for a board.

    :param board_id: id of the board whose newest image is wanted.
    :returns: the image record, or None when there is no such image.
    """
    # `Optional[...]` matches the annotations used by the other methods of
    # this class; `ImageRecord | None` on a class needs Python 3.10+ when
    # annotations are evaluated at definition time.
    pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the tables for the `images` database."""
|
||||
"""Creates the `images` table."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
board_id TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = current_timestamp
|
||||
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||
images_query = """--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||
if board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = self._cursor.fetchone()[0]
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Union[ImageRecord, None]:
    """Returns the newest image (by `images.created_at`) linked to `board_id`.

    The lookup joins `images` with `board_images` and runs while `self._lock`
    is held. Returns None when the query yields no row.
    """
    try:
        self._lock.acquire()
        self._cursor.execute(
            """--sql
            SELECT images.*
            FROM images
            JOIN board_images ON images.image_name = board_images.image_name
            WHERE board_images.board_id = ?
            ORDER BY images.created_at DESC
            LIMIT 1;
            """,
            (board_id,),
        )
        row = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
    finally:
        self._lock.release()

    # Deserialization happens after the lock has been released.
    return None if row is None else deserialize_image_record(dict(row))
|
||||
|
@ -10,6 +10,7 @@ from invokeai.app.models.image import (
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
intermediate: bool = False,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -79,7 +80,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str) -> str:
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@ -101,6 +102,7 @@ class ImageServiceABC(ABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
@ -114,8 +116,9 @@ class ImageServiceABC(ABC):
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
records: ImageRecordStorageBase
|
||||
files: ImageFileStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
image_files: ImageFileStorageBase
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
@ -126,14 +129,16 @@ class ImageServiceDependencies:
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.image_records = image_record_storage
|
||||
self.image_files = image_file_storage
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
@ -144,25 +149,8 @@ class ImageServiceDependencies:
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
def __init__(self, services: ImageServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
@ -187,7 +175,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
created_at = self._services.records.save(
|
||||
self._services.image_records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
@ -202,35 +190,15 @@ class ImageService(ImageServiceABC):
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
self._services.image_files.save(
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(image_name, True)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
# Meta fields
|
||||
created_at=created_at,
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
is_intermediate=is_intermediate,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
@ -247,7 +215,7 @@ class ImageService(ImageServiceABC):
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, changes)
|
||||
self._services.image_records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
@ -258,7 +226,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_name)
|
||||
return self._services.image_files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
@ -268,7 +236,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_name)
|
||||
return self._services.image_records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
@ -278,12 +246,13 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_name)
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_name),
|
||||
self._services.urls.get_image_url(image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -296,14 +265,14 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_name, thumbnail)
|
||||
return self._services.image_files.get_path(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self._services.files.validate_path(path)
|
||||
return self._services.image_files.validate_path(path)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
@ -322,14 +291,16 @@ class ImageService(ImageServiceABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
results = self._services.image_records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
@ -338,6 +309,9 @@ class ImageService(ImageServiceABC):
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(
|
||||
r.image_name
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
@ -355,8 +329,8 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_name)
|
||||
self._services.records.delete(image_name)
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
|
@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.backend import ModelManager
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
@ -26,9 +28,9 @@ class InvocationServices:
|
||||
model_manager: "ModelManager"
|
||||
restoration: "RestorationServices"
|
||||
configuration: "InvokeAISettings"
|
||||
images: "ImageService"
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
images: "ImageServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
board_images: "BoardImagesServiceABC"
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
@ -39,7 +41,9 @@ class InvocationServices:
|
||||
events: "EventServiceBase",
|
||||
logger: "Logger",
|
||||
latents: "LatentsStorageBase",
|
||||
images: "ImageService",
|
||||
images: "ImageServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
board_images: "BoardImagesServiceABC",
|
||||
queue: "InvocationQueueABC",
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
@ -52,9 +56,12 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.boards = boards
|
||||
self.board_images = board_images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
self.boards = boards
|
||||
|
@ -1,104 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
from .config import InvokeAISettings
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> ModelManager:
    """Creates the ModelManager described by the models config file.

    Steps, in order: report a missing config file, log version and runtime
    root, silence transformers/diffusers logging, migrate legacy models,
    construct the ModelManager, then import anything found in the
    autoconvert directory.

    :param config: application settings (model config path, precision,
        embedding path, autoconvert path, ...).
    :param logger: object exposing `info` / `error` logging functions.
    :returns: the constructed model manager.
    """
    conf_file = config.model_conf_path
    if not conf_file.exists():
        missing = FileNotFoundError(f"The file {conf_file} could not be found.")
        report_model_error(config, missing, logger)

    logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
    logger.info(f'InvokeAI runtime directory is "{config.root}"')

    # Lowering the log verbosity of both libraries prevents a horrible warning
    # message from appearing when the frozen CLIP tokenizer is imported.
    import transformers  # type: ignore
    transformers.logging.set_verbosity_error()

    import diffusers
    diffusers.logging.set_verbosity_error()

    embedding_path = config.embedding_path

    # Migrate legacy models before the manager is created.
    ModelManager.migrate_models()

    try:
        device = torch.device(choose_torch_device())

        # An explicit float16/float32 setting wins; anything else is decided
        # from the selected device.
        if config.precision == 'float16':
            precision = 'float16'
        elif config.precision == 'float32':
            precision = 'float32'
        else:
            precision = choose_precision(device)

        model_manager = ModelManager(
            OmegaConf.load(config.model_conf_path),
            precision=precision,
            device_type=device,
            max_loaded_models=config.max_loaded_models,
            embedding_path=embedding_path,
            logger=logger,
        )
    except (FileNotFoundError, TypeError, AssertionError) as e:
        # NOTE(review): if report_model_error returns, `model_manager` is
        # unbound and the code below raises UnboundLocalError.
        report_model_error(config, e, logger)
    except (IOError, KeyError) as e:
        logger.error(f"{e}. Aborting.")
        sys.exit(-1)

    # Try to autoconvert / autoimport new .ckpt files.
    if config.autoconvert_path:
        model_manager.heuristic_import(
            config.autoconvert_path,
        )
    return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
    """
    Report a model-initialization failure and offer to run invokeai-configure.

    The error is logged, then the user is asked whether to launch the
    configuration script. Setting the INVOKE_MODEL_RECONFIGURE environment
    variable to a non-empty value skips the prompt; its whitespace-separated
    words are appended to the script's argument list.

    :param opt: Parsed options; `opt.root_dir` and `opt.conf` are forwarded
        to the configure script as `--root` / `--config` when not None.
    :param e: The exception that triggered the report.
    :param logger: Object providing `error`, `warning` and `info`.
    :return: None. Returns early, without launching anything, when the user
        answers with a string starting in "n" or "N".

    Side effect: `sys.argv` is replaced with the configure script's arguments.
    """
    logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
    logger.error(
        "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
    )
    yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
    if yes_to_all:
        logger.warning(
            "Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
        )
    else:
        response = input(
            "Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
        )
        if response.startswith(("n", "N")):
            return

    logger.info("invokeai-configure is launching....\n")

    # Match arguments that were set on the CLI.
    # Only the arguments accepted by the configuration script are passed on.
    # Values are stringified because argv entries must be str (they may
    # arrive here as Path objects).
    root_dir = ["--root", str(opt.root_dir)] if opt.root_dir is not None else []
    config = ["--config", str(opt.conf)] if opt.conf is not None else []
    sys.argv = ["invokeai-configure"]
    sys.argv.extend(root_dir)
    # `config` is a plain list; the previous `config.to_dict()` call raised
    # AttributeError before the configure script could start.
    sys.argv.extend(config)
    if yes_to_all is not None:
        sys.argv.extend(yes_to_all.split())

    from invokeai.frontend.install import invokeai_configure

    invokeai_configure()
    # TODO: Figure out how to restart
    # print('** InvokeAI will now restart')
    # sys.argv = previous_args
    # main() # would rather do a os.exec(), but doesn't exist?
    # sys.exit(0)
|
363
invokeai/app/services/model_manager_service.py
Normal file
363
invokeai/app/services/model_manager_service.py
Normal file
@ -0,0 +1,363 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
ModelManager,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
)
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from .config import InvokeAIAppConfig
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
    """Responsible for managing models on disk and in memory"""

    @abstractmethod
    def __init__(
        self,
        config: InvokeAIAppConfig,
        logger: types.ModuleType,
    ):
        """
        Initialize the service.

        :param config: Application configuration from which the implementation
            derives its settings (for example, where models.yaml lives).
        :param logger: Logger the service should report through.
        """
        pass

    @abstractmethod
    def get_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        node: Optional[BaseInvocation] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """Retrieve the indicated model with name and type.
        submodel can be used to get a part (such as the vae)
        of a diffusers pipeline.

        `node` and `context` are optional; they identify the invocation on
        whose behalf the model is being loaded."""
        pass

    @property
    @abstractmethod
    def logger(self):
        """The logger used by this service."""
        pass

    @abstractmethod
    def model_exists(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ) -> bool:
        """
        Return True if the given name/base/type combination identifies a
        known model.
        """
        pass

    @abstractmethod
    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
        """
        Given a model name returns a dict-like (OmegaConf) object describing it.
        """
        pass

    @abstractmethod
    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
        """
        Returns a list of all the model names known, each as a
        (name, base model, model type) tuple.
        """
        pass

    @abstractmethod
    def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> List[dict]:
        """
        Return a list of models, optionally restricted to one base model
        and/or one model type. Each entry is a dict describing one model.

        NOTE(review): the keys of each entry are defined by the underlying
        model manager; an earlier version of this docstring described a nested
        {model_type: {model_name: {...}}} dict, which no longer matches the
        concrete service's `list[dict]` return type -- confirm the exact shape
        against the model manager.
        """
        pass

    @abstractmethod
    def add_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        model_attributes: dict,
        clobber: bool = False
    ) -> None:
        """
        Update the named model with a dictionary of attributes. Will fail with an
        assertion error if the name already exists. Pass clobber=True to overwrite.
        On a successful update, the config will be changed in memory. Will fail
        with an assertion error if provided attributes are incorrect or
        the model name is missing. Call commit() to write changes to disk.
        """
        pass

    @abstractmethod
    def del_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        """
        Delete the named model from configuration. Call commit() to write
        the change to disk.

        NOTE(review): an earlier docstring referred to a `delete_files`
        option; this signature has no such parameter, so whether the
        underlying weight file or diffusers directory is removed is up to
        the implementation -- confirm before relying on it.
        """
        pass

    @abstractmethod
    def commit(self, conf_file: Optional[Path] = None) -> None:
        """
        Write current configuration out to the indicated file.
        If no conf_file is provided, then replaces the
        original file/database used to initialize the object.
        """
        pass
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
    """Model manager service backed by a single `ModelManager` held in `self.mgr`."""

    def __init__(
        self,
        config: InvokeAIAppConfig,
        logger: types.ModuleType,
    ):
        """
        Create the underlying ModelManager from the application config.

        The models config file is `config.model_conf_path` when that is set
        and exists, otherwise `<root_dir>/configs/models.yaml`. The torch
        device is chosen by `choose_torch_device()`, and a precision of
        "auto" is resolved with `choose_precision(device)`.

        Raises IOError when no models config file can be found.
        """
        config_file = config.model_conf_path
        if not (config_file and config_file.exists()):
            config_file = config.root_dir / "configs/models.yaml"
        if not config_file.exists():
            raise IOError(f"The file {config_file} could not be found.")

        logger.debug(f'config file={config_file}')

        device = torch.device(choose_torch_device())

        resolved_precision = config.precision
        if resolved_precision == "auto":
            resolved_precision = choose_precision(device)
        # Anything other than an explicit 'float32' runs at half precision.
        dtype = torch.float16 if resolved_precision != 'float32' else torch.float32

        # Transitional backward compatibility: use the `max_cache_size`
        # setting when the config has one; otherwise derive the cache size
        # from the deprecated `max_loaded_models` value at 2.5 GB per model.
        if hasattr(config, 'max_cache_size'):
            max_cache_size = config.max_cache_size
        else:
            max_cache_size = config.max_loaded_models * 2.5

        sequential_offload = config.sequential_guidance

        self.mgr = ModelManager(
            config=config_file,
            device_type=device,
            precision=dtype,
            max_cache_size=max_cache_size,
            sequential_offload=sequential_offload,
            logger=logger,
        )
        logger.info('Model manager service initialized')

    def get_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        node: Optional[BaseInvocation] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """
        Fetch a model (or, with `submodel`, one part of it such as the vae)
        from the underlying manager.

        When both `node` and `context` are supplied, a load-started event is
        emitted before the fetch and a load-completed event after it.
        """
        emit_events = bool(node and context)
        event_kwargs = dict(
            node=node,
            context=context,
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
            submodel=submodel,
        )

        if emit_events:
            self._emit_load_event(**event_kwargs)

        model_info = self.mgr.get_model(
            model_name,
            base_model,
            model_type,
            submodel,
        )

        if emit_events:
            self._emit_load_event(model_info=model_info, **event_kwargs)

        return model_info

    def model_exists(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ) -> bool:
        """
        Report whether the underlying manager recognises this
        name/base/type combination.
        """
        return self.mgr.model_exists(model_name, base_model, model_type)

    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
        """
        Return the manager's dict-like (OmegaConf) description of one model.
        """
        return self.mgr.model_info(model_name, base_model, model_type)

    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
        """
        Return every model name the manager knows about.
        """
        return self.mgr.model_names()

    def list_models(
        self,
        base_model: Optional[BaseModelType] = None,
        model_type: Optional[ModelType] = None
    ) -> list[dict]:
        """
        Return the manager's list of models, optionally filtered by base
        model and/or model type.
        """
        return self.mgr.list_models(base_model, model_type)

    def add_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        model_attributes: dict,
        clobber: bool = False,
    ) -> None:
        """
        Add or update the named model using a dictionary of attributes.
        Fails with an assertion error if the name already exists (unless
        clobber=True), if the attributes are incorrect, or if the model name
        is missing. The change is made in memory only; call commit() to write
        it to disk.
        """
        return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)

    def del_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        """
        Remove the named model from the configuration via the underlying
        manager. Call commit() to write the change to disk.
        """
        self.mgr.del_model(model_name, base_model, model_type)

    def commit(self, conf_file: Optional[Path]=None):
        """
        Write the current configuration to `conf_file`, or, when none is
        given, back to the file/database the manager was initialized from.
        """
        return self.mgr.commit(conf_file)

    def _emit_load_event(
        self,
        node,
        context,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: SubModelType,
        model_info: Optional[ModelInfo] = None,
    ):
        """
        Emit a model-load event for `node`: "completed" when a truthy
        `model_info` is given, "started" otherwise.

        Raises CanceledException if the queue reports the graph execution
        as canceled.
        """
        services = context.services
        execution_id = context.graph_execution_state_id

        if services.queue.is_canceled(execution_id):
            raise CanceledException()

        execution_state = services.graph_execution_manager.get(execution_id)
        source_node_id = execution_state.prepared_source_mapping[node.id]

        # Fields shared by the started and completed events.
        payload = dict(
            graph_execution_state_id=execution_id,
            node=node.dict(),
            source_node_id=source_node_id,
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
            submodel=submodel,
        )
        if model_info:
            services.events.emit_model_load_completed(model_info=model_info, **payload)
        else:
            services.events.emit_model_load_started(**payload)

    @property
    def logger(self):
        """The logger of the underlying model manager."""
        return self.mgr.logger
|
||||
|
62
invokeai/app/services/models/board_record.py
Normal file
62
invokeai/app/services/models/board_record.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class BoardRecord(BaseModel):
    """Deserialized board record."""

    board_id: str = Field(description="The unique ID of the board.")
    """The unique ID of the board."""
    board_name: str = Field(description="The name of the board.")
    """The name of the board."""
    created_at: Union[datetime, str] = Field(
        description="The created timestamp of the board."
    )
    """The created timestamp of the board."""
    updated_at: Union[datetime, str] = Field(
        description="The updated timestamp of the board."
    )
    """The updated timestamp of the board."""
    deleted_at: Union[datetime, str, None] = Field(
        description="The deleted timestamp of the board."
    )
    """The deleted timestamp of the board."""
    cover_image_name: Optional[str] = Field(
        description="The name of the cover image of the board."
    )
    """The name of the cover image of the board."""
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
    """Deserialized board record with cover image name and image count."""

    cover_image_name: Optional[str] = Field(
        description="The name of the board's cover image."
    )
    """The name of the board's cover image."""
    image_count: int = Field(description="The number of images in the board.")
    """The number of images in the board."""
|
||||
|
||||
|
||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
    """
    Build a BoardRecord from a dict of board columns.

    Missing keys fall back to defaults: "unknown" for `board_id`,
    `board_name` and `cover_image_name`, the current ISO timestamp for
    `created_at` and `updated_at`, and None for `deleted_at`.

    :param board_dict: Mapping of board field names to values.
    :return: The populated BoardRecord.
    """

    # Retrieve all the values, setting "reasonable" defaults if they are not present.

    board_id = board_dict.get("board_id", "unknown")
    board_name = board_dict.get("board_name", "unknown")
    cover_image_name = board_dict.get("cover_image_name", "unknown")
    created_at = board_dict.get("created_at", get_iso_timestamp())
    updated_at = board_dict.get("updated_at", get_iso_timestamp())
    # A record without a deletion timestamp has not been deleted. Defaulting
    # to "now" here would stamp every such board as deleted at load time.
    deleted_at = board_dict.get("deleted_at")

    return BoardRecord(
        board_id=board_id,
        board_name=board_name,
        cover_image_name=cover_image_name,
        created_at=created_at,
        updated_at=updated_at,
        deleted_at=deleted_at,
    )
|
@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Union[str, None] = Field(
|
||||
description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -16,13 +16,14 @@ class RestorationServices:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
restoration = Restoration()
|
||||
if args.restore:
|
||||
# TODO: redo for new model structure
|
||||
if False and args.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
logger.info("Face restoration disabled")
|
||||
if args.esrgan:
|
||||
if False and args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
logger.info("Upscaling disabled")
|
||||
|
@ -5,9 +5,11 @@ from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import ModelManager, SDModelComponent
|
||||
from .model_management import (
|
||||
ModelManager, ModelCache, BaseModelType,
|
||||
ModelType, SubModelType, ModelInfo
|
||||
)
|
||||
from .safety_checker import SafetyChecker
|
||||
|
@ -5,7 +5,6 @@ from .base import (
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint,
|
||||
Generator,
|
||||
|
@ -29,7 +29,6 @@ import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
@ -81,13 +80,15 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
self.params=params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(self,
|
||||
prompt: str='',
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
def generate(
|
||||
self,
|
||||
conditioning: tuple,
|
||||
scheduler,
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
'''
|
||||
Return an iterator across the indicated number of generations.
|
||||
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||
@ -113,54 +114,46 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
generator_args.update(keyword_args)
|
||||
|
||||
model_info = self.model_info
|
||||
model_name = model_info['model_name']
|
||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model_hash = model_info['hash']
|
||||
scheduler: Scheduler = self.get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
model_name = model_info.name
|
||||
model_hash = model_info.hash
|
||||
with model_info.context as model:
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
generator_args.get('with_variations')
|
||||
)
|
||||
|
||||
# get conditioning from prompt via Compel package
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
|
||||
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
generator_args.get('with_variations')
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
attention_maps_images=results[0][2],
|
||||
model_hash = model_hash,
|
||||
params=Namespace(model_name=model_name,**generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
results = generator.generate(
|
||||
conditioning=conditioning,
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
attention_maps_images=results[0][2],
|
||||
model_hash = model_hash,
|
||||
params=Namespace(model_name=model_name,**generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
yield output
|
||||
|
||||
@classmethod
|
||||
@ -173,20 +166,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
    """Instantiate `generator_class` for `model` at the precision set in `self.params`."""
    return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
    """
    Build a scheduler instance for `model`.

    `scheduler_name` is looked up in SCHEDULER_MAP, falling back to the
    'ddim' entry for unknown names. The new scheduler is created from the
    model's current scheduler config (or the pristine copy stored under
    "_backup", if present) overlaid with the map entry's extra settings.
    The pristine config is kept under "_backup" in the result.
    """
    fallback_entry = SCHEDULER_MAP['ddim']
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, fallback_entry)

    # Always start from the unmodified config so overrides from an earlier
    # scheduler swap do not accumulate.
    base_config = model.scheduler.config
    if "_backup" in base_config:
        base_config = base_config["_backup"]

    merged_config = {**base_config, **scheduler_extra_config}
    merged_config["_backup"] = base_config
    scheduler = scheduler_class.from_config(merged_config)

    # hack copied over from generate.py
    if not hasattr(scheduler, 'uses_inpainting_model'):
        scheduler.uses_inpainting_model = lambda: False
    return scheduler
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls)->Type[Generator]:
|
||||
'''
|
||||
@ -196,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
return Generator
|
||||
|
||||
# ------------------------------------
|
||||
class Txt2Img(InvokeAIGenerator):
    """InvokeAIGenerator whose backing generator class is `.txt2img.Txt2Img`."""

    @classmethod
    def _generator_class(cls):
        """Return the generator class from the sibling `.txt2img` module."""
        # Imported locally; note the imported name shadows this wrapper
        # class's own name inside this method only.
        from .txt2img import Txt2Img
        return Txt2Img
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
@ -256,25 +228,6 @@ class Inpaint(Img2Img):
|
||||
from .inpaint import Inpaint
|
||||
return Inpaint
|
||||
|
||||
# ------------------------------------
|
||||
class Embiggen(Txt2Img):
    """Txt2Img variant whose backing generator class is `.embiggen.Embiggen`."""

    def generate(
        self,
        embiggen: list=None,
        embiggen_tiles: list = None,
        strength: float=0.75,
        **kwargs)->Iterator[InvokeAIGeneratorOutput]:
        """
        Forward to the parent `generate`, passing `embiggen`,
        `embiggen_tiles` and `strength` along as keyword arguments together
        with any extra `kwargs`.
        """
        return super().generate(embiggen=embiggen,
                                embiggen_tiles=embiggen_tiles,
                                strength=strength,
                                **kwargs)

    @classmethod
    def _generator_class(cls):
        """Return the generator class from the sibling `.embiggen` module."""
        # Imported locally; the imported name shadows this wrapper class's
        # own name inside this method only.
        from .embiggen import Embiggen
        return Embiggen
|
||||
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
@ -285,7 +238,7 @@ class Generator:
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.latent_channels = model.unet.config.in_channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
@ -296,7 +249,7 @@ class Generator:
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self, prompt, **kwargs):
|
||||
def get_make_image(self, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
@ -312,7 +265,6 @@ class Generator:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
width,
|
||||
height,
|
||||
sampler,
|
||||
@ -337,7 +289,6 @@ class Generator:
|
||||
saver.get_stacked_maps_image()
|
||||
)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=init_image,
|
||||
width=width,
|
||||
|
@ -1,559 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.embiggen descends from .generator
|
||||
and generates with .generator.img2img
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
    """Initialize the base Generator with `model` and `precision`, then clear `init_latent`."""
    super().__init__(model, precision)
    self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normally
|
||||
def generate(
    self,
    prompt,
    iterations=1,
    seed=None,
    image_callback=None,
    step_callback=None,
    **kwargs,
):
    """
    Produce `iterations` images and return them as a list of [image, seed] pairs.

    The first image is paired with `seed` (or a fresh seed when `seed` is
    falsy); each later image is paired with a seed drawn from
    `self.new_seed()`. When `image_callback` is given it is called as
    `image_callback(image, seed, prompt_in=prompt)` after every image.
    Remaining keyword arguments go to `get_make_image`.
    """
    make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
    generated = []
    if not seed:
        seed = self.new_seed()

    # Noise will be generated by the Img2Img generator when called
    for _ in trange(iterations, desc="Generating"):
        # make_image will call Img2Img which will do the equivalent of get_noise itself
        image = make_image()
        generated.append([image, seed])
        if image_callback is not None:
            image_callback(image, seed, prompt_in=prompt)
        # Draw the seed to be recorded for the next iteration.
        seed = self.new_seed()
    return generated
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_img,
|
||||
strength,
|
||||
width,
|
||||
height,
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert (
|
||||
not sampler.uses_inpainting_model()
|
||||
), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
logger.warning(
|
||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
logger.warning(
|
||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
logger.warning(
|
||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
logger.warning(
|
||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = Img2Img(self.model, self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
initsuperimage = Image.open(init_img)
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert("RGB")
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth * embiggen[0])
|
||||
initsuperheight = round(initsuperheight * embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
logger.info(
|
||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
4, # upscale scale
|
||||
)
|
||||
else:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
2, # upscale scale
|
||||
)
|
||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
if embiggen[2] < 1:
|
||||
if embiggen[2] < 0:
|
||||
embiggen[2] = 0
|
||||
overlap_size_x = round(embiggen[2] * width)
|
||||
overlap_size_y = round(embiggen[2] * height)
|
||||
else:
|
||||
overlap_size_x = round(embiggen[2])
|
||||
overlap_size_y = round(embiggen[2])
|
||||
|
||||
# With overall image width and height known, determine how many tiles we need
|
||||
def ceildiv(a, b):
    """Divide ``a`` by ``b`` and round the quotient up.

    Uses the identity ceil(a / b) == -floor(-a / b) so the result stays an
    integer when both arguments are integers.
    """
    floored_negative = (-a) // b
    return -floored_negative
|
||||
|
||||
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
|
||||
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
|
||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
|
||||
# Sanity
|
||||
assert (
|
||||
emb_tiles_x > 1 or emb_tiles_y > 1
|
||||
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
|
||||
|
||||
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = (
|
||||
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
|
||||
)
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
# Place the pixel as invert of distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
|
||||
# Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
|
||||
|
||||
# Create alpha layers default fully white
|
||||
alphaLayerL = Image.new("L", (width, height), 255)
|
||||
alphaLayerT = Image.new("L", (width, height), 255)
|
||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||
# Paste gradients into alpha layers
|
||||
alphaLayerL.paste(agradientL, (0, 0))
|
||||
alphaLayerT.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||
alphaLayerTaC = alphaLayerT.copy()
|
||||
alphaLayerTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||
alphaLayerLTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
if embiggen_tiles:
|
||||
# Individual unconnected sides
|
||||
alphaLayerR = Image.new("L", (width, height), 255)
|
||||
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerB = Image.new("L", (width, height), 255)
|
||||
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||
alphaLayerTB.paste(agradientT, (0, 0))
|
||||
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||
alphaLayerLR.paste(agradientL, (0, 0))
|
||||
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
|
||||
# Sides and corner Layers
|
||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerRBC.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLBC.paste(
|
||||
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
|
||||
(0, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||
alphaLayerRTC.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# All but X layers
|
||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerABT.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerABL.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABR.paste(agradientT, (0, 0))
|
||||
alphaLayerABR.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABB.paste(agradientL, (0, 0))
|
||||
alphaLayerABB.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
|
||||
# All-around layer
|
||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||
alphaLayerAA.paste(agradientT, (0, 0))
|
||||
alphaLayerAA.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerAA.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# Clean up temporary gradients
|
||||
del agradientL
|
||||
del agradientT
|
||||
del agradientC
|
||||
|
||||
def make_image():
    """Generate the Embiggen tiles with img2img and composite them into one image.

    This is a closure over the enclosing scope: it reads the upscaled init
    image (``initsuperimage``), the tile grid (``emb_tiles_x`` x ``emb_tiles_y``),
    the overlap sizes, the pre-built alpha layers and the img2img generator
    (``gen_img2img``) that were prepared before it was defined.

    When ``embiggen_tiles`` is non-empty only the listed tile indices are
    regenerated and they are blended on top of the existing init image;
    otherwise every tile of the grid is generated.

    Returns:
        An RGBA PIL image of size ``(initsuperwidth, initsuperheight)``.

    Raises:
        RuntimeError: if the number of generated tiles matches neither the
            full grid nor the requested ``embiggen_tiles``. (Previously this
            path logged an error and then failed with ``UnboundLocalError``.)
    """
    total_tiles = emb_tiles_x * emb_tiles_y

    def tile_origin(emb_row_i, emb_column_i):
        """Return the (left, top) pixel position of a tile in the super image."""
        # The last column/row is pinned to the far edge of the image so the
        # final tile never extends past it; other tiles advance by the tile
        # size minus the overlap.
        if emb_column_i + 1 == emb_tiles_x:
            left = initsuperwidth - width
        else:
            left = round(emb_column_i * (width - overlap_size_x))
        if emb_row_i + 1 == emb_tiles_y:
            top = initsuperheight - height
        else:
            top = round(emb_row_i * (height - overlap_size_y))
        return left, top

    def rerun_alpha_layer(tile, emb_row_i, emb_column_i):
        """Pick the alpha mask for a re-run tile, or None to leave it opaque.

        The choice depends on where the tile sits in the grid and on whether
        the tiles to its right and below are also being re-run.
        """
        rerun_right = (tile + 1) in embiggen_tiles  # Look-ahead right
        rerun_below = (tile + emb_tiles_x) in embiggen_tiles  # Look-ahead down
        first_column = emb_column_i == 0
        last_column = emb_column_i == emb_tiles_x - 1

        # top of image
        if emb_row_i == 0:
            if first_column:
                if rerun_right:
                    # With both neighbours re-run, nothing is applied to this tile
                    return None if rerun_below else alphaLayerB
                if rerun_below:  # Look-ahead down only
                    return alphaLayerR
                return alphaLayerRBC
            if last_column:
                return alphaLayerL if rerun_below else alphaLayerLBC
            if rerun_right:
                return alphaLayerL if rerun_below else alphaLayerLBC
            if rerun_below:  # Look-ahead down only
                return alphaLayerLR
            return alphaLayerABT

        # bottom of image
        if emb_row_i == emb_tiles_y - 1:
            if first_column:
                return alphaLayerTaC if rerun_right else alphaLayerRTC
            if last_column:
                # No tiles to look ahead to
                return alphaLayerLTC
            return alphaLayerLTaC if rerun_right else alphaLayerABB

        # vertical middle of image
        if first_column:
            if rerun_right:
                return alphaLayerTaC if rerun_below else alphaLayerTB
            if rerun_below:  # Look-ahead down only
                return alphaLayerRTC
            return alphaLayerABL
        if last_column:
            return alphaLayerLTC if rerun_below else alphaLayerABR
        if rerun_right:
            return alphaLayerLTaC if rerun_below else alphaLayerABR
        if rerun_below:  # Look-ahead down only
            return alphaLayerABB
        return alphaLayerAA

    def normal_alpha_layer(emb_row_i, emb_column_i):
        """Pick the alpha mask for the normal left-to-right, top-to-bottom tiling."""
        # True when nothing can be placed to the right of this tile
        nothing_to_right = emb_column_i + 1 == emb_tiles_x
        if emb_row_i == 0 and emb_column_i >= 1:
            return alphaLayerL
        if emb_row_i >= 1 and emb_column_i == 0:
            return alphaLayerT if nothing_to_right else alphaLayerTaC
        return alphaLayerLTC if nothing_to_right else alphaLayerLTaC

    # Make main tiles -------------------------------------------------
    if embiggen_tiles:
        logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
    else:
        logger.info(
            f"Making {total_tiles} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
        )

    emb_tile_store = []
    # Although we could use the same seed for every tile for determinism, at higher strengths this may
    # produce duplicated structures for each tile and make the tiling effect more obvious
    # instead track and iterate a local seed we pass to Img2Img
    seed = self.seed
    seedintlimit = np.iinfo(np.uint32).max - 1  # only retrieve this one from numpy

    # Optional hook forwarded to img2img; it does not change between tiles.
    clear_cuda_cache = kwargs.get("clear_cuda_cache")

    for tile in range(total_tiles):
        # Don't iterate on first tile. The seed advances for every later grid
        # position, including tiles that are skipped below.
        if tile != 0:
            seed = seed + 1 if seed < seedintlimit else 0

        # Determine if this is a re-run and replace
        if embiggen_tiles and tile not in embiggen_tiles:
            continue

        # Determine bounds to cut up the init image
        left, top = tile_origin(tile // emb_tiles_x, tile % emb_tiles_x)
        right = left + width
        bottom = top + height

        # Cropped image of above dimension (does not modify the original)
        newinitimage = initsuperimage.crop((left, top, right, bottom))
        # DEBUG:
        # newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
        # newinitimage.save(newinitimagepath)

        if embiggen_tiles:
            logger.debug(
                f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
            )
        else:
            logger.debug(f"Starting {tile + 1} of {total_tiles} tiles")

        # create a torch tensor from an Image: scale to [0, 1], add a leading
        # axis, move the last axis to position 1, then rescale to [-1, 1]
        newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
        newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
        newinitimage = torch.from_numpy(newinitimage)
        newinitimage = 2.0 * newinitimage - 1.0
        newinitimage = newinitimage.to(self.model.device)

        tile_results = gen_img2img.generate(
            prompt,
            iterations=1,
            seed=seed,
            sampler=sampler,
            steps=steps,
            cfg_scale=cfg_scale,
            conditioning=conditioning,
            ddim_eta=ddim_eta,
            image_callback=None,  # called only after the final image is generated
            step_callback=step_callback,  # called after each intermediate image is generated
            width=width,
            height=height,
            init_image=newinitimage,  # notice that init_image is different from init_img
            mask_image=None,
            strength=strength,
            clear_cuda_cache=clear_cuda_cache,
        )

        emb_tile_store.append(tile_results[0][0])
        # DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
        # emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
        del newinitimage

    # Sanity check we have them all. ``embiggen_tiles`` may be None or empty,
    # so test its truthiness before taking its length.
    have_full_grid = len(emb_tile_store) == total_tiles
    have_requested = bool(embiggen_tiles) and len(emb_tile_store) == len(
        embiggen_tiles
    )
    if not (have_full_grid or have_requested):
        message = "Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
        logger.error(message)
        # There is no composite to return, so fail explicitly instead of
        # falling through to an unbound local.
        raise RuntimeError(message)

    outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
    if embiggen_tiles:
        # Re-run tiles are blended over the existing (upscaled) init image.
        outputsuperimage.alpha_composite(initsuperimage.convert("RGBA"), (0, 0))

    for tile in range(total_tiles):
        if embiggen_tiles:
            if tile not in embiggen_tiles:
                continue
            # Tiles were stored in ascending tile order, so consume from the front.
            intileimage = emb_tile_store.pop(0)
        else:
            intileimage = emb_tile_store[tile]
        intileimage = intileimage.convert("RGBA")

        # Get row and column entries
        emb_row_i = tile // emb_tiles_x
        emb_column_i = tile % emb_tiles_x

        if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
            # The first tile of a full run is placed at the origin, fully opaque.
            left = 0
            top = 0
        else:
            left, top = tile_origin(emb_row_i, emb_column_i)
            # Handle gradients for various conditions
            if embiggen_tiles:
                alpha_layer = rerun_alpha_layer(tile, emb_row_i, emb_column_i)
            else:
                # Normal tiling case (much simpler - since we tile left to right, top to bottom)
                alpha_layer = normal_alpha_layer(emb_row_i, emb_column_i)
            if alpha_layer is not None:
                intileimage.putalpha(alpha_layer)

        # Layer tile onto final image
        outputsuperimage.alpha_composite(intileimage, (left, top))

    # after internal loops and patching up return Embiggen image
    return outputsuperimage
|
||||
|
||||
# end of function declaration
|
||||
return make_image
|
@ -22,7 +22,6 @@ class Img2Img(Generator):
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
|
@ -161,9 +161,7 @@ class Inpaint(Img2Img):
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -177,8 +175,6 @@ class Inpaint(Img2Img):
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -203,8 +199,6 @@ class Inpaint(Img2Img):
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -306,7 +300,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
@ -345,9 +338,7 @@ class Inpaint(Img2Img):
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -360,8 +351,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# Restore original settings
|
||||
self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
|
@ -1,125 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
||||
|
||||
from ..stable_diffusion import (
|
||||
ConditioningData,
|
||||
PostprocessingSettings,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from .base import Generator
|
||||
|
||||
|
||||
class Txt2Img(Generator):
    """Generator that builds an image from a prompt, optionally guided by ControlNet model(s)."""

    def __init__(
        self,
        model,
        precision,
        control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
        **kwargs,
    ):
        """Store the optional control model, then initialise the base generator.

        A list of ControlNet models is wrapped in a single ``MultiControlNetModel``.
        """
        self.control_model = (
            MultiControlNetModel(control_model)
            if isinstance(control_model, list)
            else control_model
        )
        super().__init__(model, precision, **kwargs)

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        width,
        height,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Build and return the ``make_image(x_T, _)`` callable for this request.

        The returned callable runs the pipeline with ``x_T`` as the noise and
        returns the first resulting PIL image. ``conditioning`` is unpacked as
        ``(uc, c, extra_conditioning_info)``. An optional ``control_image`` is
        read from ``kwargs``; the remaining ``kwargs`` are forwarded to
        ``pipeline.image_from_embeddings``.
        """
        self.perlin = perlin
        control_image = kwargs.get("control_image")
        do_classifier_free_guidance = cfg_scale > 1.0

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.control_model = self.control_model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        postprocessing_settings = PostprocessingSettings(
            threshold=threshold,
            warmup=warmup,
            h_symmetry_time_pct=h_symmetry_time_pct,
            v_symmetry_time_pct=v_symmetry_time_pct,
        )
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=postprocessing_settings,
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def prepare(image):
            # Shared by the single- and multi-ControlNet paths below.
            return pipeline.prepare_control_image(
                image=image,
                do_classifier_free_guidance=do_classifier_free_guidance,
                width=width,
                height=height,
                # batch_size=batch_size * num_images_per_prompt,
                # num_images_per_prompt=num_images_per_prompt,
                device=self.control_model.device,
                dtype=self.control_model.dtype,
            )

        # FIXME: still need to test with different widths, heights, devices, dtypes
        #        and add in batch_size, num_images_per_prompt?
        if control_image is not None:
            if isinstance(self.control_model, ControlNetModel):
                control_image = prepare(control_image)
            elif isinstance(self.control_model, MultiControlNetModel):
                # One prepared image per entry of the supplied sequence
                control_image = [prepare(image_) for image_ in control_image]
            # NOTE(review): nesting reconstructed from a flattened diff view;
            # confirm this assignment belongs inside the `is not None` branch.
            kwargs["control_image"] = control_image

        def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
            # Start from all-zero latents and supply x_T as the noise.
            zero_latents = torch.zeros_like(x_T, dtype=self.torch_dtype())
            pipeline_output = pipeline.image_from_embeddings(
                latents=zero_latents,
                noise=x_T,
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                callback=step_callback,
                **kwargs,
            )

            saver = pipeline_output.attention_map_saver
            if saver is not None and attention_maps_callback is not None:
                attention_maps_callback(saver)

            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image
|
@ -1,209 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||
|
||||
from ..stable_diffusion import PostprocessingSettings
|
||||
from .base import Generator
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Txt2Img2Img(Generator):
    """Two-pass generator: a first latent pass, a bilinear resize of those
    latents to the requested size, then an img2img pass over the result."""

    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None  # for get_noise()

    def get_make_image(
        self,
        prompt: str,
        sampler,
        steps: int,
        cfg_scale: float,
        ddim_eta,
        conditioning,
        width: int,
        height: int,
        strength: float,
        step_callback: Optional[Callable] = None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Build and return the ``make_image(x_T, _)`` callable for this request.

        ``conditioning`` is unpacked as ``(uc, c, extra_conditioning_info)``.
        ``width`` and ``height`` are the final output size; ``strength`` is
        passed to the second (img2img) pass. ``kwargs`` may carry an optional
        ``clear_cuda_cache`` callable that is invoked between the two passes.
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                # Honour the caller's value (was hard-coded to 0.2, which
                # silently ignored the ``warmup`` argument).
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T: torch.Tensor, _: int):
            first_pass_latent_output, _ = pipeline.latents_from_embeddings(
                latents=torch.zeros_like(x_T),
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                noise=x_T,
                callback=step_callback,
            )

            # Get our initial generation width and height directly from the latent output so
            # the message below is accurate.
            init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
            init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
            logger.info(
                f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
            )

            # resizing
            resized_latents = torch.nn.functional.interpolate(
                first_pass_latent_output,
                size=(
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ),
                mode="bilinear",
            )

            # Free up memory from the last generation. The hook is optional:
            # a missing key is treated the same as a falsy value.
            clear_cuda_cache = kwargs.get("clear_cuda_cache") or None
            if clear_cuda_cache is not None:
                clear_cuda_cache()

            second_pass_noise = self.get_noise_like(
                resized_latents, override_perlin=True
            )

            # Clear symmetry for the second pass
            from dataclasses import replace

            new_postprocessing_settings = replace(
                conditioning_data.postprocessing_settings,
                h_symmetry_time_pct=None,
                v_symmetry_time_pct=None,
            )
            new_conditioning_data = replace(
                conditioning_data, postprocessing_settings=new_postprocessing_settings
            )

            # Lower the diffusers log level for the second pass and always put
            # it back, even if the pass raises.
            verbosity = get_verbosity()
            set_verbosity_error()
            try:
                pipeline_output = pipeline.img2img_from_latents_and_embeddings(
                    resized_latents,
                    num_inference_steps=steps,
                    conditioning_data=new_conditioning_data,
                    strength=strength,
                    noise=second_pass_noise,
                    callback=step_callback,
                )
            finally:
                set_verbosity(verbosity)

            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)

            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        # FIXME: do we really need something entirely different for the inpainting model?

        # in the case of the inpainting model being loaded, the trick of
        # providing an interpolated latent doesn't work, so we transiently
        # create a 512x512 PIL image, upscale it, and run the inpainting
        # over it in img2img mode. Because the inpainting model is so conservative
        # it doesn't change the image (much)

        return make_image

    def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
        """Return normally distributed noise with the same shape as ``like``.

        On an ``mps`` device the noise is drawn on the CPU and then moved to
        the device. When ``self.perlin`` is positive and ``override_perlin``
        is not set, the result is blended with perlin noise.
        """
        device = like.device
        if device.type == "mps":
            x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
                device
            )
        else:
            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
        if self.perlin > 0.0 and not override_perlin:
            shape = like.shape
            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
                shape[3], shape[2]
            )
        return x

    # returns a tensor filled with random numbers from a normal distribution
    def get_noise(self, width, height, scale=True):
        """Return a noise tensor for the initial generation.

        With ``scale`` true, ``width``/``height`` are first rescaled to the
        model's native area while keeping the aspect ratio; otherwise they are
        used as given. The tensor shape is
        ``(1, channels, height // downsampling_factor, width // downsampling_factor)``.
        """
        # print(f"Get noise: {width}x{height}")
        if scale:
            # Scale the input width and height for the initial generation
            # Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
            # while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)

            aspect = width / height
            dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
            min_dimension = math.floor(dimension * 0.5)
            model_area = (
                dimension * dimension
            )  # hardcoded for now since all models are trained on square images

            if aspect > 1.0:
                init_height = max(min_dimension, math.sqrt(model_area / aspect))
                init_width = init_height * aspect
            else:
                init_width = max(min_dimension, math.sqrt(model_area * aspect))
                init_height = init_width / aspect

            scaled_width, scaled_height = trim_to_multiple_of(
                math.floor(init_width), math.floor(init_height)
            )

        else:
            scaled_width = width
            scaled_height = height

        device = self.model.device
        channels = self.latent_channels
        if channels == 9:
            channels = 4  # we don't really want noise for all the mask channels
        shape = (
            1,
            channels,
            scaled_height // self.downsampling_factor,
            scaled_width // self.downsampling_factor,
        )
        if self.use_mps_noise or device.type == "mps":
            # Draw on the CPU, then move the finished noise to the device.
            tensor = torch.empty(size=shape, device="cpu")
            tensor = self.get_noise_like(like=tensor).to(device)
        else:
            tensor = torch.empty(size=shape, device=device)
            tensor = self.get_noise_like(like=tensor)
        return tensor
|
@ -9,6 +9,7 @@ SAMPLER_CHOICES = [
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"lms_k",
|
||||
"pndm",
|
||||
"heun",
|
||||
"heun_k",
|
||||
@ -18,8 +19,13 @@ SAMPLER_CHOICES = [
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2s_k",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"dpmpp_2m_sde",
|
||||
"dpmpp_2m_sde_k",
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
@ -1,11 +1,6 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .convert_ckpt_to_diffusers import (
|
||||
convert_ckpt_to_diffusers,
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
from .model_manager import ModelManager,SDModelComponent
|
||||
|
||||
|
||||
|
||||
from .model_manager import ModelManager, ModelInfo
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
|
@ -28,10 +28,13 @@ from safetensors.torch import load_file
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
from .model_manager import ModelManager
|
||||
from .model_cache import ModelCache
|
||||
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
|
||||
@ -56,10 +59,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||
LDMBertConfig,
|
||||
LDMBertModel,
|
||||
)
|
||||
from diffusers.pipelines.paint_by_example import (
|
||||
PaintByExampleImageEncoder,
|
||||
PaintByExamplePipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
@ -74,6 +73,8 @@ from transformers import (
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
MODEL_ROOT = None
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||
@ -158,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
new_item = new_item.replace("k.weight", "to_k.weight")
|
||||
new_item = new_item.replace("k.bias", "to_k.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
new_item = new_item.replace("v.weight", "to_v.weight")
|
||||
new_item = new_item.replace("v.bias", "to_v.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
||||
|
||||
new_item = shave_segments(
|
||||
new_item, n_shave_prefix_segments=n_shave_prefix_segments
|
||||
@ -183,7 +184,6 @@ def assign_to_checkpoint(
|
||||
paths,
|
||||
checkpoint,
|
||||
old_checkpoint,
|
||||
attention_paths_to_split=None,
|
||||
additional_replacements=None,
|
||||
config=None,
|
||||
):
|
||||
@ -198,35 +198,9 @@ def assign_to_checkpoint(
|
||||
paths, list
|
||||
), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
# Splits the attention layers into three variables.
|
||||
if attention_paths_to_split is not None:
|
||||
for path, path_map in attention_paths_to_split.items():
|
||||
old_tensor = old_checkpoint[path]
|
||||
channels = old_tensor.shape[0] // 3
|
||||
|
||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||
|
||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||
|
||||
old_tensor = old_tensor.reshape(
|
||||
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
|
||||
)
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path["new"]
|
||||
|
||||
# These have already been assigned
|
||||
if (
|
||||
attention_paths_to_split is not None
|
||||
and new_path in attention_paths_to_split
|
||||
):
|
||||
continue
|
||||
|
||||
# Global renaming happens here
|
||||
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||
@ -245,14 +219,14 @@ def assign_to_checkpoint(
|
||||
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
elif "to_out.0.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
|
||||
|
||||
def create_unet_diffusers_config(original_config, image_size: int):
|
||||
@ -612,16 +586,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
# Extract state dict for VAE. Works both with burnt-in
|
||||
# VAEs, and with standalone VAEs.
|
||||
|
||||
# checkpoint can either be a all-in-one stable diffusion
|
||||
# model, or an isolated vae .ckpt. This tests for
|
||||
# a key that will be present in the all-in-one model
|
||||
# that isn't present in the isolated ckpt.
|
||||
probe_key = "first_stage_model.encoder.conv_in.weight"
|
||||
if probe_key in checkpoint:
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
else:
|
||||
vae_state_dict = checkpoint
|
||||
|
||||
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
|
||||
return new_checkpoint
|
||||
|
||||
def convert_ldm_vae_state_dict(vae_state_dict, config):
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||
@ -841,10 +828,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
|
||||
)
|
||||
|
||||
text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
@ -896,82 +880,10 @@ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
|
||||
def convert_paint_by_example_checkpoint(checkpoint):
|
||||
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||
config = CLIPVisionConfig.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
model = PaintByExampleImageEncoder(config)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
|
||||
key
|
||||
]
|
||||
|
||||
# load clip vision
|
||||
model.model.load_state_dict(text_model_dict)
|
||||
|
||||
# load mapper
|
||||
keys_mapper = {
|
||||
k[len("cond_stage_model.mapper.res") :]: v
|
||||
for k, v in checkpoint.items()
|
||||
if k.startswith("cond_stage_model.mapper")
|
||||
}
|
||||
|
||||
MAPPING = {
|
||||
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
||||
"attn.c_proj": ["attn1.to_out.0"],
|
||||
"ln_1": ["norm1"],
|
||||
"ln_2": ["norm3"],
|
||||
"mlp.c_fc": ["ff.net.0.proj"],
|
||||
"mlp.c_proj": ["ff.net.2"],
|
||||
}
|
||||
|
||||
mapped_weights = {}
|
||||
for key, value in keys_mapper.items():
|
||||
prefix = key[: len("blocks.i")]
|
||||
suffix = key.split(prefix)[-1].split(".")[-1]
|
||||
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
||||
mapped_names = MAPPING[name]
|
||||
|
||||
num_splits = len(mapped_names)
|
||||
for i, mapped_name in enumerate(mapped_names):
|
||||
new_name = ".".join([prefix, mapped_name, suffix])
|
||||
shape = value.shape[0] // num_splits
|
||||
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
||||
|
||||
model.mapper.load_state_dict(mapped_weights)
|
||||
|
||||
# load final layer norm
|
||||
model.final_layer_norm.load_state_dict(
|
||||
{
|
||||
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
||||
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
# load final proj
|
||||
model.proj_out.load_state_dict(
|
||||
{
|
||||
"bias": checkpoint["proj_out.bias"],
|
||||
"weight": checkpoint["proj_out.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
# load uncond vector
|
||||
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
||||
return model
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
||||
MODEL_ROOT / 'stable-diffusion-2-clip',
|
||||
subfolder='text_encoder',
|
||||
)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
@ -1047,22 +959,30 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
|
||||
new_key = f'first_stage_model.{vae_key}'
|
||||
checkpoint[new_key] = state_dict[vae_key]
|
||||
|
||||
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL:
    """
    Build a diffusers AutoencoderKL from an LDM-style VAE checkpoint.

    :param checkpoint: State dict handed to convert_ldm_vae_checkpoint().
    :param vae_config: Original (LDM) configuration handed to create_vae_diffusers_config().
    :param image_size: Image size forwarded to create_vae_diffusers_config().
    :return: An AutoencoderKL with the converted weights loaded.
    """
    # Translate the original config into keyword arguments for AutoencoderKL.
    diffusers_config = create_vae_diffusers_config(vae_config, image_size=image_size)

    # Rename/reshape the checkpoint tensors to the diffusers layout.
    converted_state = convert_ldm_vae_checkpoint(checkpoint, diffusers_config)

    autoencoder = AutoencoderKL(**diffusers_config)
    autoencoder.load_state_dict(converted_state)
    return autoencoder
|
||||
|
||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
original_config_file: str = None,
|
||||
num_in_channels: int = None,
|
||||
scheduler_type: str = "pndm",
|
||||
pipeline_type: str = None,
|
||||
image_size: int = None,
|
||||
prediction_type: str = None,
|
||||
model_version: BaseModelType,
|
||||
model_variant: ModelVariantType,
|
||||
original_config_file: str,
|
||||
extract_ema: bool = True,
|
||||
upcast_attn: bool = False,
|
||||
vae: AutoencoderKL = None,
|
||||
vae_path: str = None,
|
||||
precision: torch.dtype = torch.float32,
|
||||
return_generator_pipeline: bool = False,
|
||||
scan_needed:bool=True,
|
||||
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
|
||||
upcast_attention: bool = False,
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
|
||||
scan_needed: bool = True,
|
||||
) -> StableDiffusionPipeline:
|
||||
"""
|
||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||
config file.
|
||||
@ -1074,148 +994,68 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
:param checkpoint_path: Path to `.ckpt` file.
|
||||
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
|
||||
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
|
||||
Base. Use 768 for Stable Diffusion v2.
|
||||
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
|
||||
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
|
||||
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
|
||||
inferred.
|
||||
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
|
||||
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
|
||||
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
|
||||
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
|
||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
|
||||
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
|
||||
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
|
||||
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
||||
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
||||
running stable diffusion 2.1.
|
||||
:param vae: A diffusers VAE to load into the pipeline.
|
||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
cache_dir = config.cache_dir
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
if Path(checkpoint_path).suffix == '.ckpt':
|
||||
if scan_needed:
|
||||
ModelManager.scan_model(checkpoint_path,checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
else:
|
||||
if str(checkpoint_path).endswith(".safetensors"):
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
|
||||
pipeline_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if return_generator_pipeline
|
||||
else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
# Sometimes models don't have the global_step item
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
logger.debug("global_step key not found in model")
|
||||
global_step = None
|
||||
if scan_needed:
|
||||
ModelCache.scan_model(checkpoint_path, checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
upcast_attention = False
|
||||
if original_config_file is None:
|
||||
model_type = ModelManager.probe_model_type(checkpoint)
|
||||
|
||||
if model_type == SDLegacyType.V2_v:
|
||||
original_config_file = (
|
||||
config.legacy_conf_path / "v2-inference-v.yaml"
|
||||
)
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
original_config_file = (
|
||||
config.legacy_conf_path / "v2-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
original_config_file = (
|
||||
config.legacy_conf_path / "v1-inpainting-inference.yaml"
|
||||
)
|
||||
|
||||
elif model_type == SDLegacyType.V1:
|
||||
original_config_file = (
|
||||
config.legacy_conf_path / "v1-inference.yaml"
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Unknown checkpoint type")
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"][
|
||||
"in_channels"
|
||||
] = num_in_channels
|
||||
|
||||
if (
|
||||
"parameterization" in original_config["model"]["params"]
|
||||
and original_config["model"]["params"]["parameterization"] == "v"
|
||||
):
|
||||
if prediction_type is None:
|
||||
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
|
||||
# as it relies on a brittle global step parameter here
|
||||
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
|
||||
if image_size is None:
|
||||
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
|
||||
# as it relies on a brittle global step parameter here
|
||||
image_size = 512 if global_step == 875000 else 768
|
||||
if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
|
||||
image_size = 768
|
||||
else:
|
||||
if prediction_type is None:
|
||||
prediction_type = "epsilon"
|
||||
if image_size is None:
|
||||
image_size = 512
|
||||
image_size = 512
|
||||
|
||||
#
|
||||
# convert scheduler
|
||||
#
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
scheduler = PNDMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
skip_prk_steps=True
|
||||
)
|
||||
# make sure scheduler works correctly with DDIM
|
||||
scheduler.register_to_config(clip_sample=False)
|
||||
|
||||
if scheduler_type == "pndm":
|
||||
config = dict(scheduler.config)
|
||||
config["skip_prk_steps"] = True
|
||||
scheduler = PNDMScheduler.from_config(config)
|
||||
elif scheduler_type == "lms":
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "heun":
|
||||
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler":
|
||||
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler-ancestral":
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == 'unipc':
|
||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
#
|
||||
# convert unet
|
||||
#
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
)
|
||||
@ -1228,44 +1068,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
# If a replacement VAE path was specified, we'll incorporate that into
|
||||
# the checkpoint model and then convert it
|
||||
if vae_path:
|
||||
logger.debug(f"Converting VAE {vae_path}")
|
||||
replace_checkpoint_vae(checkpoint,vae_path)
|
||||
# otherwise we use the original VAE, provided that
|
||||
# an externally loaded diffusers VAE was not passed
|
||||
elif not vae:
|
||||
logger.debug("Using checkpoint model's original VAE")
|
||||
#
|
||||
# convert vae
|
||||
#
|
||||
|
||||
if vae:
|
||||
logger.debug("Using replacement diffusers VAE")
|
||||
else: # convert the original or replacement VAE
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||
checkpoint, vae_config
|
||||
)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
vae = convert_ldm_vae_to_diffusers(
|
||||
checkpoint,
|
||||
original_config,
|
||||
image_size,
|
||||
)
|
||||
|
||||
# Convert the text model.
|
||||
model_type = pipeline_type
|
||||
if model_type is None:
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(
|
||||
"."
|
||||
)[-1]
|
||||
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2",
|
||||
subfolder="tokenizer",
|
||||
cache_dir=cache_dir,
|
||||
MODEL_ROOT / 'stable-diffusion-2-clip',
|
||||
subfolder='tokenizer',
|
||||
)
|
||||
pipe = pipeline_class(
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae.to(precision),
|
||||
text_encoder=text_model.to(precision),
|
||||
tokenizer=tokenizer,
|
||||
@ -1275,49 +1096,26 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
elif model_type == "PaintByExample":
|
||||
vision_model = convert_paint_by_example_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
||||
)
|
||||
pipe = PaintByExamplePipeline(
|
||||
vae=vae,
|
||||
image_encoder=vision_model,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
|
||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker",
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
||||
)
|
||||
pipe = pipeline_class(
|
||||
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae.to(precision),
|
||||
text_encoder=text_model.to(precision),
|
||||
tokenizer=tokenizer,
|
||||
unet=unet.to(precision),
|
||||
scheduler=scheduler,
|
||||
safety_checker=None if return_generator_pipeline else safety_checker.to(precision),
|
||||
safety_checker=safety_checker.to(precision),
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
tokenizer = BertTokenizerFast.from_pretrained(
|
||||
"bert-base-uncased", cache_dir=cache_dir
|
||||
)
|
||||
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ROOT / "bert-base-uncased")
|
||||
pipe = LDMTextToImagePipeline(
|
||||
vqvae=vae,
|
||||
bert=text_model,
|
||||
@ -1331,15 +1129,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
|
||||
def convert_ckpt_to_diffusers(
|
||||
checkpoint_path: Union[str, Path],
|
||||
dump_path: Union[str, Path],
|
||||
**kwargs,
|
||||
checkpoint_path: Union[str, Path],
|
||||
dump_path: Union[str, Path],
|
||||
model_root: Union[str, Path],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Takes all the arguments of load_pipeline_from_original_stable_diffusion_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
model to be written.
|
||||
"""
|
||||
# setting global here to avoid massive changes late at night
|
||||
global MODEL_ROOT
|
||||
MODEL_ROOT = Path(model_root) / 'core/convert'
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
|
||||
|
||||
pipe.save_pretrained(
|
||||
|
678
invokeai/backend/model_management/lora.py
Normal file
678
invokeai/backend/model_management/lora.py
Normal file
@ -0,0 +1,678 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
|
||||
class LoRALayerBase:
    """
    Common state and forward logic shared by the LoRA-family layer patches.

    Attributes assigned by __init__:
      alpha     - Python scalar taken from values["alpha"], or None when absent
      bias      - sparse COO tensor built from the three "bias_*" entries, or None
      rank      - None here; each concrete layer overwrites it
      layer_key - the key supplied by the caller
    """

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        self.alpha = values["alpha"].item() if "alpha" in values else None

        # A bias is only reconstructed when all three pieces are present.
        has_bias = all(
            name in values
            for name in ("bias_indices", "bias_values", "bias_size")
        )
        if has_bias:
            self.bias = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            )
        else:
            self.bias = None

        self.rank = None  # set in layer implementation
        self.layer_key = layer_key

    def forward(
        self,
        module: torch.nn.Module,
        input_h: Any,  # unpacked positionally into the functional op
        multiplier: float,
    ):
        """
        Apply this layer's weight delta to `input_h` using the operation that
        matches `module` (conv2d for an exact torch.nn.Conv2d, linear otherwise)
        and scale the result by `multiplier` and by alpha / rank.
        """
        if type(module) is torch.nn.Conv2d:
            functional_op = torch.nn.functional.conv2d
            op_kwargs = {
                "stride": module.stride,
                "padding": module.padding,
                "dilation": module.dilation,
                "groups": module.groups,
            }
        else:
            functional_op = torch.nn.functional.linear
            op_kwargs = {}

        delta = self.get_weight(module)
        bias_term = 0 if self.bias is None else self.bias

        # Falls back to 1.0 when either alpha or rank is missing (or zero).
        if self.alpha and self.rank:
            scale = self.alpha / self.rank
        else:
            scale = 1.0

        patched_weight = (delta + bias_term).view(module.weight.shape)
        result = functional_op(*input_h, patched_weight, None, **op_kwargs)
        return result * multiplier * scale

    def get_weight(self, module: torch.nn.Module):
        """Return the weight delta for `module`; concrete layers must override."""
        raise NotImplementedError()

    def calc_size(self) -> int:
        """Return the number of bytes occupied by the tensors held by the base class."""
        if self.bias is None:
            return 0
        return 0 + self.bias.nelement() * self.bias.element_size()

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast the bias tensor (when there is one) to `device` / `dtype`."""
        if self.bias is None:
            return
        self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
    """
    Standard LoRA layer: an up/down projection pair with an optional `mid` tensor.

    Attributes:
      up   - values["lora_up.weight"]
      mid  - values["lora_mid.weight"] when present, otherwise None
      down - values["lora_down.weight"]
      rank - first dimension of `down`
    """

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        super().__init__(layer_key, values)

        self.up = values["lora_up.weight"]
        self.down = values["lora_down.weight"]
        if "lora_mid.weight" in values:
            self.mid = values["lora_mid.weight"]
        else:
            self.mid = None

        self.rank = self.down.shape[0]

    def get_weight(self, module: torch.nn.Module):
        """
        Compute the weight delta contributed by this layer.

        With a `mid` tensor the three factors are contracted with einsum;
        otherwise the delta is the matrix product of the flattened `up` and
        `down` tensors. `module` is accepted for interface parity and unused.
        """
        if self.mid is not None:
            # Each factor is reshaped with ITS OWN leading two dimensions.
            # The previous code read the local `up` before it was assigned
            # (UnboundLocalError) and reshaped `down` with `up`'s dimensions.
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
        else:
            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)

        return weight

    def calc_size(self) -> int:
        """Return the byte size of the base-class tensors plus up/mid/down."""
        model_size = super().calc_size()
        for val in [self.up, self.mid, self.down]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast every tensor held by this layer to `device` / `dtype`."""
        super().to(device=device, dtype=dtype)

        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

        if self.mid is not None:
            self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
    """
    LoHA layer: the weight delta is the element-wise product of two low-rank
    reconstructions (see get_weight).

    Attributes:
      w1_a, w1_b - values["hada_w1_a"], values["hada_w1_b"]
      w2_a, w2_b - values["hada_w2_a"], values["hada_w2_b"]
      t1, t2     - values["hada_t1"], values["hada_t2"] when present, else None
      rank       - first dimension of `w1_b`
    """

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        # The base initializer takes (layer_key, values). The previous call
        # passed four undefined names (module_key, rank, alpha, bias) and
        # raised NameError on every construction.
        super().__init__(layer_key, values)

        self.w1_a = values["hada_w1_a"]
        self.w1_b = values["hada_w1_b"]
        self.w2_a = values["hada_w2_a"]
        self.w2_b = values["hada_w2_b"]

        if "hada_t1" in values:
            self.t1 = values["hada_t1"]
        else:
            self.t1 = None

        if "hada_t2" in values:
            self.t2 = values["hada_t2"]
        else:
            self.t2 = None

        self.rank = self.w1_b.shape[0]

    def get_weight(self, module: torch.nn.Module):
        """
        Compute the weight delta contributed by this layer.

        Without `t1` the delta is (w1_a @ w1_b) * (w2_a @ w2_b); with `t1` each
        half is rebuilt through an einsum over its `t` tensor before the
        element-wise product. `module` is accepted for interface parity and unused.
        """
        if self.t1 is None:
            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

        else:
            rebuild1 = torch.einsum(
                "i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
            )
            rebuild2 = torch.einsum(
                "i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
            )
            weight = rebuild1 * rebuild2

        return weight

    def calc_size(self) -> int:
        """Return the byte size of the base-class tensors plus all LoHA factors."""
        model_size = super().calc_size()
        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast every tensor held by this layer to `device` / `dtype`."""
        super().to(device=device, dtype=dtype)

        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        if self.t1 is not None:
            self.t1 = self.t1.to(device=device, dtype=dtype)

        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
#w1: Optional[torch.Tensor] = None
|
||||
#w1_a: Optional[torch.Tensor] = None
|
||||
#w1_b: Optional[torch.Tensor] = None
|
||||
#w2: Optional[torch.Tensor] = None
|
||||
#w2_a: Optional[torch.Tensor] = None
|
||||
#w2_b: Optional[torch.Tensor] = None
|
||||
#t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(module_key, rank, alpha, bias)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1 = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2 = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2 = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
w1 = self.w1
|
||||
if w1 is None:
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
    self,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Move/cast every tensor held by this layer (and the base class state) in place."""
    super().to(device=device, dtype=dtype)

    def _moved(tensor):
        return tensor.to(device=device, dtype=dtype)

    # Each factor lives either in its whole slot or in its (a, b) pair.
    if self.w1 is None:
        self.w1_a = _moved(self.w1_a)
        self.w1_b = _moved(self.w1_b)
    else:
        self.w1 = _moved(self.w1)

    if self.w2 is None:
        self.w2_a = _moved(self.w2_a)
        self.w2_b = _moved(self.w2_b)
    else:
        self.w2 = _moved(self.w2)

    if self.t2 is not None:
        self.t2 = _moved(self.t2)
|
||||
|
||||
|
||||
class LoRAModel: #(torch.nn.Module):
    """A named set of LoRA-family layers, keyed by layer key, plus the device/dtype they were last moved to."""

    _name: str
    layers: Dict[str, LoRALayer]
    _device: torch.device
    _dtype: torch.dtype

    def __init__(
        self,
        name: str,
        layers: Dict[str, LoRALayer],
        device: torch.device,
        dtype: torch.dtype,
    ):
        """
        :param name: Display name of the model.
        :param layers: Mapping of layer key to layer object.
        :param device: Device the layers live on; falls back to CPU when falsy.
        :param dtype: Dtype of the layers; falls back to float32 when falsy.
        """
        self._name = name
        # `torch.cpu` is not a device object; build a real CPU device instead.
        self._device = device or torch.device("cpu")
        self._dtype = dtype or torch.float32
        self.layers = layers

    @property
    def name(self):
        return self._name

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "LoRAModel":
        """Move/cast every layer and return ``self``.

        An argument left as None is forwarded unchanged to the layers, and the
        recorded device/dtype is kept rather than overwritten with None.
        """
        # TODO: try revert if exception?
        for layer in self.layers.values():
            layer.to(device=device, dtype=dtype)
        if device is not None:
            self._device = device
        if dtype is not None:
            self._dtype = dtype
        return self

    def calc_size(self) -> int:
        """Return the sum of ``calc_size()`` over all layers."""
        return sum(layer.calc_size() for layer in self.layers.values())

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load a LoRA/LoCon, LoHA or LoKR checkpoint from disk.

        :param file_path: Path to a ``.safetensors`` file or a torch checkpoint.
        :param device: Target device for the layers (default CPU).
        :param dtype: Target dtype for the layers (default float32).
        :returns: The loaded model, or None when a layer of unknown format is
            encountered (a message is printed in that case).
        """
        device = device or torch.device("cpu")
        dtype = dtype or torch.float32

        if isinstance(file_path, str):
            file_path = Path(file_path)

        model = cls(
            device=device,
            dtype=dtype,
            name=file_path.stem, # TODO:
            layers=dict(),
        )

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            # NOTE(review): torch.load unpickles the file; only use with trusted checkpoints.
            state_dict = torch.load(file_path, map_location="cpu")

        state_dict = cls._group_state(state_dict)

        for layer_key, values in state_dict.items():

            # lora and locon
            if "lora_down.weight" in values:
                layer = LoRALayer(layer_key, values)

            # loha
            elif "hada_w1_b" in values:
                layer = LoHALayer(layer_key, values)

            # lokr
            elif "lokr_w1_b" in values or "lokr_w1" in values:
                layer = LoKRLayer(layer_key, values)

            else:
                # TODO: diff/ia3/... format
                # `self` does not exist in a classmethod; report the model being built.
                print(
                    f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
                )
                return None

            # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()

            layer.to(device=device, dtype=dtype)
            model.layers[layer_key] = layer

        return model

    @staticmethod
    def _group_state(state_dict: dict):
        """Group a flat state dict by the part of each key before the first ".".

        ``{"a.b.c": x, "a.d": y}`` becomes ``{"a": {"b.c": x, "d": y}}``.
        Keys without a "." are not supported (the split cannot be unpacked).
        """
        state_dict_groupped = dict()

        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            state_dict_groupped.setdefault(stem, dict())[leaf] = value

        return state_dict_groupped
|
||||
|
||||
|
||||
"""
|
||||
loras = [
|
||||
(lora_model1, 0.7),
|
||||
(lora_model2, 0.4),
|
||||
]
|
||||
with ModelPatcher.apply_lora_unet(unet, loras):
|
||||
# unet with applied loras
|
||||
# unmodified unet
|
||||
|
||||
"""
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
    """Context managers that temporarily patch a model with LoRA hooks or textual-inversion embeddings."""

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        """Map an underscore-joined LoRA key onto a submodule of ``model``.

        Submodule names may themselves contain underscores, so key parts are
        joined greedily until ``get_submodule`` accepts the name.

        :returns: ``(module_key, module)`` where ``module_key`` is the full
            dotted path of ``module`` inside ``model``.
        :raises Exception: if ``lora_key`` does not start with ``prefix``.
        :raises AttributeError: if the final name is not a submodule.
        """
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix):].split('_')

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
            except AttributeError:
                # Not a submodule yet: the name itself contains "_", extend it.
                submodule_name += "_" + key_parts.pop(0)
            else:
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        # Include the final component and drop the leading dot. The previous
        # code returned only the parent path, so sibling modules shared one key
        # and only the first of them was hooked by apply_lora().
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)

    @staticmethod
    def _lora_forward_hook(
        applied_loras: List[Tuple[LoRAModel, float]],
        layer_name: str,
    ):
        """Create a forward hook that adds the contribution of every applied LoRA having ``layer_name``."""

        def lora_forward(module, input_h, output):
            if len(applied_loras) == 0:
                return output

            for lora, weight in applied_loras:
                layer = lora.layers.get(layer_name, None)
                if layer is None:
                    continue
                # accumulated in place on the module output
                output += layer.forward(module, input_h, weight)
            return output

        return lora_forward

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: UNet2DConditionModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply ``loras`` to ``unet`` (layer keys prefixed "lora_unet_") for the duration of the context."""
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply ``loras`` to ``text_encoder`` (layer keys prefixed "lora_te_") for the duration of the context."""
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: torch.nn.Module,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        """Register one forward hook per targeted module and remove them all on exit.

        :param model: Module whose submodules are patched.
        :param loras: ``(lora_model, weight)`` pairs to apply.
        :param prefix: Only layer keys starting with this prefix are considered.
        """
        hooks = dict()
        try:
            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
                    # One hook per module is enough: the hook walks all loras itself.
                    if module_key not in hooks:
                        hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))

            yield # wait for context manager exit

        finally:
            for module_key, hook in hooks.items():
                hook.remove()
            hooks.clear()

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        ti_list: List[Any],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        """Temporarily add textual-inversion tokens and embeddings.

        Works on a deep copy of ``tokenizer``; ``text_encoder``'s embedding
        table is resized in place and resized back to its initial token count
        on exit.

        :param ti_list: Objects exposing ``name`` and ``embedding``; one token
            is added per row of ``embedding``.
        :yields: ``(ti_tokenizer, ti_manager)``.
        :raises RuntimeError: if an added trigger token cannot be resolved to an id.
        :raises ValueError: if an embedding row does not match the model's token dimension.
        """
        init_tokens_count = None
        new_tokens_added = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)
            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

            def _get_trigger(ti, index):
                trigger = ti.name
                if index > 0:
                    # use the parameter; this used to read the enclosing loop variable
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
            model_embeddings = text_encoder.get_input_embeddings()

            for ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    embedding = ti.embedding[i]
                    trigger = _get_trigger(ti, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
                        )

                    model_embeddings.weight.data[token_id] = embedding
                    ti_tokens.append(token_id)

                # The first token expands to the remaining ones at prompt time.
                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            yield ti_tokenizer, ti_manager

        finally:
            if init_tokens_count and new_tokens_added:
                text_encoder.resize_token_embeddings(init_tokens_count)
|
||||
|
||||
|
||||
class TextualInversionModel:
    """A textual-inversion embedding loaded from a checkpoint file."""

    name: str
    embedding: torch.Tensor # [n, 768]|[n, 1280]

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load an embedding file and return a populated instance.

        :param file_path: ``.safetensors`` file or torch checkpoint.
        :param device: Accepted but not used by this loader.
        :param dtype: Accepted but not used by this loader.
        :raises ValueError: if the selected entry is not a tensor.
        """
        file_path = file_path if isinstance(file_path, Path) else Path(file_path)

        loaded = cls() # TODO:
        loaded.name = file_path.stem # TODO:

        if file_path.suffix != ".safetensors":
            state_dict = torch.load(file_path, map_location="cpu")
        else:
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")

        if "string_to_param" in state_dict:
            # both v1 and v2 format embeddings
            # difference mostly in metadata
            params = state_dict["string_to_param"]
            if len(params) > 1:
                print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
            loaded.embedding = next(iter(params.values()))
        elif "emb_params" in state_dict:
            # v3 (easynegative)
            loaded.embedding = state_dict["emb_params"]
        else:
            # v4(diffusers bin files)
            loaded.embedding = next(iter(state_dict.values()))

        if isinstance(loaded.embedding, torch.Tensor):
            return loaded
        raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
    """Expands textual-inversion trigger tokens into their full multi-token sequences."""

    # Maps the id of a trigger token to the ids of the pad tokens that follow it.
    pad_tokens: Dict[int, List[int]]
    tokenizer: CLIPTokenizer

    def __init__(self, tokenizer: CLIPTokenizer):
        self.pad_tokens = dict()
        self.tokenizer = tokenizer

    def expand_textual_inversion_token_ids_if_necessary(
        self, token_ids: list[int]
    ) -> list[int]:
        """Insert the registered pad tokens after every trigger token.

        :param token_ids: Token ids without the leading bos / trailing eos token.
        :returns: ``token_ids`` itself when there is nothing to expand,
            otherwise a new list with the pad tokens inserted.
        :raises ValueError: if ``token_ids`` starts with bos or ends with eos.
        """
        # Nothing to expand; the emptiness check also protects the [0]/[-1]
        # lookups below from raising IndexError on an empty prompt.
        if len(self.pad_tokens) == 0 or len(token_ids) == 0:
            return token_ids

        if token_ids[0] == self.tokenizer.bos_token_id:
            raise ValueError("token_ids must not start with bos_token_id")
        if token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("token_ids must not end with eos_token_id")

        new_token_ids = []
        for token_id in token_ids:
            new_token_ids.append(token_id)
            if token_id in self.pad_tokens:
                new_token_ids.extend(self.pad_tokens[token_id])

        return new_token_ids
|
||||
|
391
invokeai/backend/model_management/model_cache.py
Normal file
391
invokeai/backend/model_management/model_cache.py
Normal file
@ -0,0 +1,391 @@
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, types, Optional, Type, Any
|
||||
|
||||
import torch
|
||||
|
||||
import logging
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from .lora import LoRAModel, TextualInversionModel
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||
|
||||
# Upper bound on the cache size, expressed in gigabytes.
# The default is roughly enough for three fp16 diffusers models in RAM at once.
DEFAULT_MAX_CACHE_SIZE = 6.0

# Number of bytes in one gigabyte (2**30).
GIG = 2 ** 30
|
||||
|
||||
class ModelLocker(object):
    """Forward declaration"""
|
||||
|
||||
class ModelCache(object):
    """Forward declaration"""
|
||||
|
||||
class _CacheRecord:
|
||||
size: int
|
||||
model: Any
|
||||
cache: ModelCache
|
||||
_locks: int
|
||||
|
||||
def __init__(self, cache, model: Any, size: int):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.cache = cache
|
||||
self._locks = 0
|
||||
|
||||
def lock(self):
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self):
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self):
|
||||
return self._locks > 0
|
||||
|
||||
@property
|
||||
def loaded(self):
|
||||
if self.model is not None and hasattr(self.model, "device"):
|
||||
return self.model.device != self.cache.storage_device
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(
    self,
    max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
    execution_device: torch.device=torch.device('cuda'),
    storage_device: torch.device=torch.device('cpu'),
    precision: torch.dtype=torch.float16,
    sequential_offload: bool=False,
    lazy_offloading: bool=True,
    sha_chunksize: int = 16777216,
    logger: types.ModuleType = logger
):
    '''
    :param max_cache_size: Maximum size of the cache, in gigabytes [DEFAULT_MAX_CACHE_SIZE]
    :param execution_device: Torch device to load active model into [torch.device('cuda')]
    :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
    :param precision: Precision for loaded models [torch.float16]
    :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially (currently accepted but not stored or used here)
    :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
    :param sha_chunksize: Chunksize to use when calculating sha256 model hash
    :param logger: Object used for info/debug messages
    '''
    # The execution_device argument used to be overwritten with
    # torch.device('cuda') here, which silently ignored the caller's choice.

    # Model info objects, keyed by get_key() without a submodel type.
    self.model_infos: Dict[str, ModelBase] = dict()
    self.lazy_offloading = lazy_offloading
    #self.sequential_offload: bool=sequential_offload
    self.precision: torch.dtype=precision
    self.max_cache_size: float=max_cache_size
    self.execution_device: torch.device=execution_device
    self.storage_device: torch.device=storage_device
    self.sha_chunksize=sha_chunksize
    self.logger = logger

    # key -> _CacheRecord, plus the keys ordered from least to most recently requested
    self._cached_models = dict()
    self._cache_stack = list()
|
||||
|
||||
def get_key(
    self,
    model_path: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel_type: Optional[SubModelType] = None,
):
    """Build the cache key "path:base:type", with ":submodel" appended when a submodel type is given."""
    parts = [f"{model_path}", f"{base_model}", f"{model_type}"]
    if submodel_type:
        parts.append(f"{submodel_type}")
    return ":".join(parts)
|
||||
|
||||
#def get_model(
|
||||
# self,
|
||||
# repo_id_or_path: Union[str, Path],
|
||||
# model_type: ModelType = ModelType.Diffusers,
|
||||
# subfolder: Path = None,
|
||||
# submodel: ModelType = None,
|
||||
# revision: str = None,
|
||||
# attach_model_part: Tuple[ModelType, str] = (None, None),
|
||||
# gpu_load: bool = True,
|
||||
#) -> ModelLocker: # ?? what does it return
|
||||
def _get_model_info(
|
||||
self,
|
||||
model_path: str,
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
model_info_key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=None,
|
||||
)
|
||||
|
||||
if model_info_key not in self.model_infos:
|
||||
self.model_infos[model_info_key] = model_class(
|
||||
model_path,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
return self.model_infos[model_info_key]
|
||||
|
||||
# TODO: args
|
||||
def get_model(
    self,
    model_path: Union[str, Path],
    model_class: Type[ModelBase],
    base_model: BaseModelType,
    model_type: ModelType,
    submodel: Optional[SubModelType] = None,
    gpu_load: bool = True,
) -> Any:
    """Return a ModelLocker for the requested (sub)model, loading it into the cache on a miss.

    :param model_path: Location of the model on disk.
    :param model_class: Class used to build the model info object.
    :param base_model: Base model type, part of the cache key.
    :param model_type: Model type, part of the cache key.
    :param submodel: Optional submodel to fetch from the model.
    :param gpu_load: Passed through to the ModelLocker.
    :raises Exception: if ``model_path`` does not exist.
    """
    model_path = model_path if isinstance(model_path, Path) else Path(model_path)

    if not os.path.exists(model_path):
        raise Exception(f"Model not found: {model_path}")

    model_info = self._get_model_info(
        model_path=model_path,
        model_class=model_class,
        base_model=base_model,
        model_type=model_type,
    )
    key = self.get_key(
        model_path=model_path,
        base_model=base_model,
        model_type=model_type,
        submodel_type=submodel,
    )

    # TODO: lock for no copies on simultaneous calls?
    cache_entry = self._cached_models.get(key, None)
    if cache_entry is None:
        self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')

        # this will remove older cached models until
        # there is sufficient room to load the requested model
        self._make_cache_room(model_info.get_size(submodel))

        # clean memory to make MemoryUsage() more accurate
        gc.collect()
        model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
        mem_used = model_info.get_size(submodel)
        if mem_used:
            self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')

        cache_entry = _CacheRecord(self, model, mem_used)
        self._cached_models[key] = cache_entry

    # Move the key to the end of the stack (most recently requested).
    if key in self._cache_stack:
        self._cache_stack.remove(key)
    self._cache_stack.append(key)

    return self.ModelLocker(self, key, cache_entry.model, gpu_load)
|
||||
|
||||
class ModelLocker(object):
    """Context manager handed out by get_model().

    When ``gpu_load`` is true, entering locks the cache record and moves the
    model to the cache's execution device; leaving releases that lock.
    Models without a ``to`` method are returned as-is and never locked or moved.
    """

    def __init__(self, cache, key, model, gpu_load):
        self.gpu_load = gpu_load
        self.cache = cache
        self.key = key
        self.model = model
        self.cache_entry = self.cache._cached_models[self.key]

    def __enter__(self) -> Any:
        if not hasattr(self.model, 'to'):
            return self.model

        # NOTE that the model has to have the to() method in order for this
        # code to move it into GPU!
        if self.gpu_load:
            self.cache_entry.lock()

            try:
                if self.cache.lazy_offloading:
                    self.cache._offload_unlocked_models()

                if self.model.device != self.cache.execution_device:
                    self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
                    with VRAMUsage() as mem:
                        self.model.to(self.cache.execution_device) # move into GPU
                    self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')

                self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
                self.cache._print_cuda_stats()

            except:
                # Do not leave the record locked if the move failed.
                self.cache_entry.unlock()
                raise

        # TODO: not fully understand
        # in the event that the caller wants the model in RAM, we
        # move it into CPU if it is in GPU and not locked
        elif self.cache_entry.loaded and not self.cache_entry.locked:
            self.model.to(self.cache.storage_device)

        return self.model

    def __exit__(self, type, value, traceback):
        if not hasattr(self.model, 'to'):
            return

        # Only release what __enter__ acquired: with gpu_load=False no lock was
        # taken, and unlocking anyway drove the counter below zero.
        if self.gpu_load:
            self.cache_entry.unlock()
        if not self.cache.lazy_offloading:
            self.cache._offload_unlocked_models()
            self.cache._print_cuda_stats()
|
||||
|
||||
# TODO: should it be called untrack_model?
|
||||
def uncache_model(self, cache_id: str):
    """Forget ``cache_id``: drop it from the cache stack and from the cached models, ignoring unknown ids."""
    try:
        self._cache_stack.remove(cache_id)
    except ValueError:
        pass # the id was not on the stack
    self._cached_models.pop(cache_id, None)
|
||||
|
||||
def model_hash(
    self,
    model_path: Union[str, Path],
) -> str:
    '''
    Return the hash of the model stored at the given location.
    This simply delegates to the local, on-disk hashing routine.
    :param model_path: Path to model file/directory on disk.
    '''
    digest = self._local_model_hash(model_path)
    return digest
|
||||
|
||||
def cache_size(self) -> float:
    """Return the current size of the cache, in GB"""
    total_bytes = sum(record.size for record in self._cached_models.values())
    return total_bytes / GIG
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.execution_device.type == 'cuda'
|
||||
|
||||
def _print_cuda_stats(self):
    """Log (at debug level) allocated VRAM, cache RAM size and how many records are cached, loaded and locked."""
    vram = f"{torch.cuda.memory_allocated() / GIG:4.2f}G"
    ram = f"{self.cache_size():4.2f}G"

    records = list(self._cached_models.values())
    cached_models = len(records)
    loaded_models = sum(1 for record in records if record.loaded)
    locked_models = sum(1 for record in records if record.locked)

    self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
|
||||
|
||||
|
||||
def _make_cache_room(self, model_size):
    """Evict cached models, oldest first, until ``model_size`` more bytes fit under the size limit.

    A record is evicted only when it is not locked and nothing besides the
    record itself still references its model. Records that cannot be evicted
    are skipped, so the limit may still be exceeded afterwards.
    """
    # calculate how much memory this model will require
    #multiplier = 2 if self.precision==torch.float32 else 1
    bytes_needed = model_size
    maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
    current_size = sum(m.size for m in self._cached_models.values())

    if current_size + bytes_needed > maximum_size:
        self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')

    self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

    idx = 0
    while current_size + bytes_needed > maximum_size and idx < len(self._cache_stack):
        model_key = self._cache_stack[idx]
        cache_entry = self._cached_models[model_key]

        # Counted straight off the attribute so that no extra local reference
        # to the model exists while counting.
        refs = sys.getrefcount(cache_entry.model)

        device = getattr(cache_entry.model, "device", None)
        self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")

        # 2 refs:
        # 1 from cache_entry
        # 1 from getrefcount function
        if cache_entry.locked or refs > 2:
            idx += 1
            continue

        self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
        current_size -= cache_entry.size
        # The stack shrinks, so the same index now points at the next candidate.
        del self._cache_stack[idx]
        del self._cached_models[model_key]
        del cache_entry

    gc.collect()
    torch.cuda.empty_cache()

    self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
|
||||
def _offload_unlocked_models(self):
|
||||
for model_key, cache_entry in self._cached_models.items():
|
||||
if not cache_entry.locked and cache_entry.loaded:
|
||||
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
|
||||
cache_entry.model.to(self.storage_device)
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||
sha = hashlib.sha256()
|
||||
path = Path(model_path)
|
||||
|
||||
hashpath = path / "checksum.sha256"
|
||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
self.logger.debug(f'computing hash of model {path.name}')
|
||||
for file in list(path.rglob("*.ckpt")) \
|
||||
+ list(path.rglob("*.safetensors")) \
|
||||
+ list(path.rglob("*.pth")):
|
||||
with open(file, "rb") as f:
|
||||
while chunk := f.read(self.sha_chunksize):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
class VRAMUsage(object):
    """Context manager recording how much the allocated CUDA memory changed inside the block.

    After the block, ``vram_used`` holds ``torch.cuda.memory_allocated()`` at
    exit minus its value at entry.
    """

    def __init__(self):
        self.vram_used = 0
        self.vram = None

    def __enter__(self):
        # allocation level at block entry
        self.vram = torch.cuda.memory_allocated()
        return self

    def __exit__(self, *args):
        allocated_now = torch.cuda.memory_allocated()
        self.vram_used = allocated_now - self.vram
|
118
invokeai/backend/model_management/model_install.py
Normal file
118
invokeai/backend/model_management/model_install.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""
|
||||
Routines for downloading and installing models.
|
||||
"""
|
||||
import json
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import shutil
|
||||
import tempfile
|
||||
import torch
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from diffusers import ModelMixin
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from . import ModelManager
|
||||
from .models import BaseModelType, ModelType, VariantType
|
||||
from .model_probe import ModelProbe, ModelVariantInfo
|
||||
from .model_cache import SilenceWarnings
|
||||
|
||||
class ModelInstall(object):
    '''
    This class is able to download and install several different kinds of
    InvokeAI models. The helper function, if provided, is called on to distinguish
    between v2-base and v2-768 stable diffusion pipelines. This usually involves
    asking the user to select the proper type, as there is no way of distinguishing
    the two type of v2 file programmatically (as far as I know).
    '''
    def __init__(self,
                 config: InvokeAIAppConfig,
                 model_base_helper: Callable[[Path],BaseModelType]=None,
                 clobber:bool = False
                 ):
        '''
        :param config: InvokeAI configuration object
        :param model_base_helper: A function call that accepts the Path to a checkpoint model and returns a ModelType enum
        :param clobber: If true, models with colliding names will be overwritten
        '''
        self.config = config
        # Was stored as `self.clogger`, so _check_for_collision() could never read it.
        self.clobber = clobber
        self.helper = model_base_helper
        self.prober = ModelProbe()

    def install_checkpoint_file(self, checkpoint: Path)->dict:
        '''
        Install the checkpoint file at path and return a
        configuration entry that can be added to `models.yaml`.
        Model checkpoints and VAEs will be converted into
        diffusers before installation. Note that the model manager
        does not hold entries for anything but diffusers pipelines,
        and the configuration file stanzas returned from such models
        can be safely ignored.

        :raises ValueError: if the checkpoint cannot be probed, or if the
            destination exists and clobber is false.
        '''
        # The helper must go in by keyword: the second positional parameter of
        # ModelProbe.probe() is the already-loaded model, not the helper.
        model_info = self.prober.probe(checkpoint, prediction_type_helper=self.helper)
        if not model_info:
            raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")

        key = ModelManager.create_key(
            model_name = checkpoint.stem,
            base_model = model_info.base_type,
            model_type = model_info.model_type,
        )

        # Pipelines are converted into a folder named after the checkpoint;
        # everything else is copied under its own file name. Joining with the
        # full checkpoint path would resolve to the source file itself when
        # that path is absolute.
        # NOTE(review): _dest_path() and _pipeline_type_to_config_file() are not
        # defined in the visible part of this class -- confirm they exist.
        is_pipeline = model_info.model_type == ModelType.Pipeline
        dest_dir = self._dest_path(model_info)
        destination_path = dest_dir / (checkpoint.stem if is_pipeline else checkpoint.name)
        destination_path.parent.mkdir(parents=True, exist_ok=True)
        self._check_for_collision(destination_path)

        stanza = {
            key: dict(
                name = checkpoint.stem,
                description = f'{model_info.model_type} model {checkpoint.stem}',
                # the probe result calls this field base_type, not base_model
                base = model_info.base_type.value,
                type = model_info.model_type.value,
                variant = model_info.variant_type.value,
                path = str(destination_path),
            )
        }

        # non-pipeline; no conversion needed, just copy into right place
        if not is_pipeline:
            shutil.copyfile(checkpoint, destination_path)
            stanza[key].update({'format': 'checkpoint'})

        # pipeline - conversion needed here
        else:
            config_file = self._pipeline_type_to_config_file(model_info.model_type)

            from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
            # NOTE(review): assumes SilenceWarnings is a context-manager class
            # that has to be instantiated -- confirm against its definition.
            with SilenceWarnings():
                convert_ckpt_to_diffusers(
                    checkpoint,
                    destination_path,
                    extract_ema=True,
                    original_config_file=config_file,
                    scan_needed=False,
                )
            stanza[key].update({'format': 'folder'})

        return stanza

    def _check_for_collision(self, path: Path):
        '''
        Make sure `path` is free to be written to.
        :raises ValueError: if `path` exists and clobber is false.
        '''
        if not path.exists():
            return
        if not self.clobber:
            raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
        # rmtree() only works on directories; a copied checkpoint is a plain file.
        if path.is_dir():
            shutil.rmtree(path)
        else:
            path.unlink()

    def _staging_directory(self)->tempfile.TemporaryDirectory:
        '''Create a temporary directory under the configured root path.'''
        return tempfile.TemporaryDirectory(dir=self.config.root_path)
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
417
invokeai/backend/model_management/model_probe.py
Normal file
417
invokeai/backend/model_management/model_probe.py
Normal file
@ -0,0 +1,417 @@
|
||||
import json
|
||||
import traceback
|
||||
import torch
|
||||
import safetensors.torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, Union, Dict
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
|
||||
|
||||
@dataclass
class ModelVariantInfo:
    """Result of probing a model: what kind of model it is and how it is stored."""

    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    # on-disk layout: a directory ('folder') or a single file ('checkpoint')
    format: Literal['folder','checkpoint']
    image_size: int
|
||||
|
||||
class ProbeBase(object):
    """forward declaration"""
|
||||
|
||||
class ModelProbe(object):
|
||||
|
||||
PROBES = {
|
||||
'folder': { },
|
||||
'checkpoint': { },
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
'StableDiffusionPipeline' : ModelType.Pipeline,
|
||||
'AutoencoderKL' : ModelType.Vae,
|
||||
'ControlNetModel' : ModelType.ControlNet,
|
||||
}
|
||||
|
||||
@classmethod
def register_probe(cls,
                   format: Literal['folder','checkpoint'],
                   model_type: ModelType,
                   probe_class: ProbeBase):
    '''
    Register the probe class to use for a given storage format and model type.

    :param format: One of the keys of PROBES, i.e. 'folder' or 'checkpoint'
        (the annotation used to say 'file', which is not a PROBES key).
    :param model_type: Model type the probe class handles.
    :param probe_class: Class that probe() instantiates for this combination.
    :raises KeyError: if `format` is not a key of PROBES.
    '''
    cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
def heuristic_probe(cls,
                    model: Union[Dict, ModelMixin, Path],
                    prediction_type_helper: Callable[[Path],BaseModelType]=None,
                    )->ModelVariantInfo:
    '''
    Probe either a model on disk or one that is already loaded.

    :param model: A Path to the model, or a loaded model (dict, ModelMixin or ConfigMixin).
    :param prediction_type_helper: Passed through to probe().
    :raises Exception: if `model` is none of the supported kinds.
    '''
    if isinstance(model,Path):
        return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
    elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
        return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
    else:
        # the f-prefix was missing, so the message showed a literal "{model}"
        raise Exception(f"model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(cls,
|
||||
model_path: Path,
|
||||
model: Union[Dict, ModelMixin] = None,
|
||||
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
|
||||
'''
|
||||
Probe the model at model_path and return sufficient information about it
|
||||
to place it somewhere in the models directory hierarchy. If the model is
|
||||
already loaded into memory, you may provide it as model in order to avoid
|
||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||
the path to the model and returns the BaseModelType. It is called to distinguish
|
||||
between V2-Base and V2-768 SD models.
|
||||
'''
|
||||
if model_path:
|
||||
format = 'folder' if model_path.is_dir() else 'checkpoint'
|
||||
else:
|
||||
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||
|
||||
model_info = None
|
||||
try:
|
||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||
if format == 'folder' \
|
||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||
probe_class = cls.PROBES[format].get(model_type)
|
||||
if not probe_class:
|
||||
return None
|
||||
probe = probe_class(model_path, model, prediction_type_helper)
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
model_info = ModelVariantInfo(
|
||||
model_type = model_type,
|
||||
base_type = base_type,
|
||||
variant_type = variant_type,
|
||||
prediction_type = prediction_type,
|
||||
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
|
||||
and prediction_type==SchedulerPredictionType.VPrediction),
|
||||
format = format,
|
||||
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
|
||||
and prediction_type==SchedulerPredictionType.VPrediction \
|
||||
) else 512,
|
||||
)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
|
||||
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'):
|
||||
return None
|
||||
if model_path.name=='learned_embeds.bin':
|
||||
return ModelType.TextualInversion
|
||||
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
|
||||
return ModelType.Pipeline
|
||||
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
|
||||
return ModelType.Vae
|
||||
if "string_to_token" in state_dict or "emb_params" in state_dict:
|
||||
return ModelType.TextualInversion
|
||||
if any([x.startswith("lora") for x in state_dict.keys()]):
|
||||
return ModelType.Lora
|
||||
if any([x.startswith("control_model") for x in state_dict.keys()]):
|
||||
return ModelType.ControlNet
|
||||
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
|
||||
return ModelType.ControlNet
|
||||
return None # give up
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||
'''
|
||||
Get the model type of a hugging-face style folder.
|
||||
'''
|
||||
class_name = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
if (folder_path / 'learned_embeds.bin').exists():
|
||||
return ModelType.TextualInversion
|
||||
|
||||
if (folder_path / 'pytorch_lora_weights.bin').exists():
|
||||
return ModelType.Lora
|
||||
|
||||
i = folder_path / 'model_index.json'
|
||||
c = folder_path / 'config.json'
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path,'r') as file:
|
||||
conf = json.load(file)
|
||||
class_name = conf['_class_name']
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
|
||||
# give up
|
||||
raise ValueError("Unable to determine model type")
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model_path, model_path)
|
||||
return torch.load(model_path)
|
||||
else:
|
||||
return safetensors.torch.load_file(model_path)
|
||||
|
||||
@classmethod
|
||||
def _scan_model(cls, model_name, checkpoint):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
||||
|
||||
###################################################3
|
||||
# Checkpoint probing
|
||||
###################################################3
|
||||
class ProbeBase:
    '''
    Interface shared by all probe classes. Each method is a no-op here
    (it returns None) and is meant to be overridden.
    '''

    def get_base_type(self)->BaseModelType:
        '''Return the base model type; None in this base class.'''
        return None

    def get_variant_type(self)->ModelVariantType:
        '''Return the model variant; None in this base class.'''
        return None

    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
        '''Return the scheduler prediction type; None in this base class.'''
        return None
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
    '''Common behaviour for probes that inspect a single-file checkpoint.'''

    def __init__(self,
                 checkpoint_path: Path,
                 checkpoint: dict,
                 helper: Callable[[Path],BaseModelType] = None
                 )->None:
        '''
        checkpoint_path: location of the checkpoint file.
        checkpoint: already-loaded checkpoint; when falsy the file at
            checkpoint_path is scanned and loaded instead.
        helper: optional callable that receives checkpoint_path; subclasses
            call it when the checkpoint alone is not conclusive.
        '''
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self)->BaseModelType:
        '''Overridden by subclasses; returns None here.'''
        pass

    def get_variant_type(self)-> ModelVariantType:
        '''
        Return the variant of the checkpoint. Anything that is not a
        pipeline is Normal; for pipelines the input channel count of the
        first UNet block decides (9 = inpaint, 5 = depth, 4 = normal).

        Raises Exception for any other channel count.
        '''
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
        if model_type != ModelType.Pipeline:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get('state_dict') or self.checkpoint
        in_channels = state_dict[
            "model.diffusion_model.input_blocks.0.0.weight"
        ].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise Exception("Cannot determine variant type")
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
    '''Probe for checkpoints that hold a whole Stable Diffusion pipeline.'''

    def get_base_type(self)->BaseModelType:
        '''
        Decide between SD-1 and SD-2 from the last dimension of a
        cross-attention weight (768 vs 1024). Raises Exception when the
        weight is absent or has another size.
        '''
        weights = self.checkpoint.get('state_dict') or self.checkpoint
        attn_key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if attn_key in weights:
            if weights[attn_key].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            if weights[attn_key].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
        raise Exception("Cannot determine base type")

    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
        '''
        SD-1 checkpoints are always epsilon. For SD-2 the "global_step"
        entry is consulted first, then the helper callable (if both a
        path and a helper were supplied); otherwise None is returned.
        '''
        if self.get_base_type() == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon

        ckpt = self.checkpoint
        weights = ckpt.get('state_dict') or ckpt
        attn_key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if not (attn_key in weights and weights[attn_key].shape[-1] == 1024):
            return None

        if 'global_step' in ckpt:
            if ckpt['global_step'] == 220000:
                return SchedulerPredictionType.Epsilon
            if ckpt["global_step"] == 110000:
                return SchedulerPredictionType.VPrediction
        if self.checkpoint_path and self.helper:
            return self.helper(self.checkpoint_path)
        return None
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
    '''Probe for standalone VAE checkpoints.'''

    def get_base_type(self)->BaseModelType:
        '''Always reports SD-1.'''
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
    '''Probe for LoRA checkpoints.'''

    def get_base_type(self)->BaseModelType:
        '''
        Read the token vector length from one of two known text-encoder
        keys (defaulting to 768 when neither is present) and map it to
        SD-1 (768) or SD-2 (1024). Any other length yields None.
        '''
        ckpt = self.checkpoint
        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"

        if key1 in ckpt:
            token_len = ckpt[key1].shape[1]
        elif key2 in ckpt:
            token_len = ckpt[key2].shape[0]
        else:
            token_len = 768

        if token_len == 768:
            return BaseModelType.StableDiffusion1
        if token_len == 1024:
            return BaseModelType.StableDiffusion2
        return None
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
    '''Probe for textual-inversion embedding files.'''

    def get_base_type(self)->BaseModelType:
        '''
        Find the embedding width in one of the three layouts handled below
        and map 768 to SD-1 and 1024 to SD-2; other widths yield None.
        '''
        ckpt = self.checkpoint
        if 'string_to_token' in ckpt:
            first_param = list(ckpt['string_to_param'].values())[0]
            width = first_param.shape[-1]
        elif 'emb_params' in ckpt:
            width = ckpt['emb_params'].shape[-1]
        else:
            # Fallback layout: the first value's leading dimension is used.
            width = list(ckpt.values())[0].shape[0]

        if width == 768:
            return BaseModelType.StableDiffusion1
        if width == 1024:
            return BaseModelType.StableDiffusion2
        return None
|
||||
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
    '''Probe for ControlNet checkpoints.'''

    def get_base_type(self)->BaseModelType:
        '''
        Look for a cross-attention weight under either of two key layouts
        and map its last dimension to SD-1 (768) or SD-2 (1024). For any
        other size the helper callable is asked, when a path and a helper
        are available.

        Raises Exception when no answer could be found.
        '''
        checkpoint = self.checkpoint
        for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
                         'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
                         ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                return self.helper(self.checkpoint_path)
        raise Exception(f"Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
class FolderProbeBase(ProbeBase):
    '''Common behaviour for probes that inspect a model stored as a folder.'''

    def __init__(self,
                 folder_path: Path,
                 model: ModelMixin = None,
                 helper: Callable=None # not used
                 ):
        '''
        folder_path: location of the model folder.
        model: optional already-loaded model object.
        helper: accepted so every probe has the same constructor; ignored.
        '''
        self.folder_path = folder_path
        self.model = model

    def get_variant_type(self)->ModelVariantType:
        '''Folder probes report the Normal variant unless they override this.'''
        return ModelVariantType.Normal
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
    '''Probe for pipelines stored as a diffusers-style folder.'''

    def get_base_type(self)->BaseModelType:
        '''
        Map the UNet's cross_attention_dim to SD-1 (768) or SD-2 (1024).
        The UNet config comes from the loaded model when one was given,
        otherwise from unet/config.json.

        Raises ValueError for any other dimension.
        '''
        if self.model:
            unet_conf = self.model.unet.config
        else:
            with open(self.folder_path / 'unet' / 'config.json','r') as file:
                unet_conf = json.load(file)

        if unet_conf['cross_attention_dim'] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf['cross_attention_dim'] == 1024:
            return BaseModelType.StableDiffusion2
        else:
            raise ValueError(f'Unknown base model for {self.folder_path}')

    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
        '''
        Translate the scheduler config's "prediction_type" entry into a
        SchedulerPredictionType; unrecognised values yield None.
        '''
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
                scheduler_conf = json.load(file)
        if scheduler_conf['prediction_type'] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf['prediction_type'] == 'epsilon':
            return SchedulerPredictionType.Epsilon
        else:
            return None

    def get_variant_type(self)->ModelVariantType:
        '''
        Derive the variant from the UNet's in_channels
        (9 = inpaint, 5 = depth, 4 = normal).
        '''
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            if self.model:
                conf = self.model.unet.config
            else:
                config_file = self.folder_path / 'unet' / 'config.json'
                with open(config_file,'r') as file:
                    conf = json.load(file)

            in_channels = conf['in_channels']
            if in_channels == 9:
                # The enum member is named Inpaint, as used by the checkpoint probe.
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
    '''Probe for VAEs stored as a folder.'''

    def get_base_type(self)->BaseModelType:
        '''Always reports SD-1.'''
        return BaseModelType.StableDiffusion1
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
    '''Probe for textual-inversion embeddings stored as a folder.'''

    def get_base_type(self)->BaseModelType:
        '''
        Load learned_embeds.bin from the folder and delegate to the
        checkpoint probe. Returns None when the file does not exist.
        '''
        embeds_file = self.folder_path / 'learned_embeds.bin'
        if embeds_file.exists():
            loaded = ModelProbe._scan_and_load_checkpoint(embeds_file)
            return TextualInversionCheckpointProbe(None,checkpoint=loaded).get_base_type()
        return None
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
    '''Probe for ControlNet models stored as a folder.'''

    def get_base_type(self)->BaseModelType:
        '''
        Read config.json and report SD-1 when cross_attention_dim is 768,
        SD-2 for every other value.

        Raises Exception when config.json is missing.
        '''
        config_file = self.folder_path / 'config.json'
        if not config_file.exists():
            raise Exception(f"Cannot determine base type for {self.folder_path}")
        with open(config_file,'r') as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        if config['cross_attention_dim']==768:
            return BaseModelType.StableDiffusion1
        return BaseModelType.StableDiffusion2
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
    '''Probe for LoRA models stored as a folder; adds nothing to its base class.'''
    # I've never seen one of these in the wild, so this is a noop
|
||||
|
||||
############## register probe classes ######
|
||||
# Table of (format, model type, probe class), registered in this order.
for _fmt, _model_type, _probe_class in (
    ('folder', ModelType.Pipeline, PipelineFolderProbe),
    ('folder', ModelType.Vae, VaeFolderProbe),
    ('folder', ModelType.Lora, LoRAFolderProbe),
    ('folder', ModelType.TextualInversion, TextualInversionFolderProbe),
    ('folder', ModelType.ControlNet, ControlNetFolderProbe),
    ('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe),
    ('checkpoint', ModelType.Vae, VaeCheckpointProbe),
    ('checkpoint', ModelType.Lora, LoRACheckpointProbe),
    ('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe),
    ('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe),
):
    ModelProbe.register_probe(_fmt, _model_type, _probe_class)
# Keep the loop variables out of the module namespace.
del _fmt, _model_type, _probe_class
|
95
invokeai/backend/model_management/models/__init__.py
Normal file
95
invokeai/backend/model_management/models/__init__.py
Normal file
@ -0,0 +1,95 @@
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .textual_inversion import TextualInversionModel
|
||||
|
||||
# Implementation class for every (base model, model type) pair. The two
# Stable Diffusion bases differ only in their pipeline class.
MODEL_CLASSES = {
    base: {
        ModelType.Pipeline: pipeline_class,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
    }
    for base, pipeline_class in (
        (BaseModelType.StableDiffusion1, StableDiffusion1Model),
        (BaseModelType.StableDiffusion2, StableDiffusion2Model),
    )
    #BaseModelType.Kandinsky2_1: {
    #    ModelType.Pipeline: Kandinsky2_1Model,
    #    ModelType.MoVQ: MoVQModel,
    #    ModelType.Lora: LoRAModel,
    #    ModelType.ControlNet: ControlNetModel,
    #    ModelType.TextualInversion: TextualInversionModel,
    #},
}
|
||||
|
||||
# Config classes collected from every entry of MODEL_CLASSES (filled below).
MODEL_CONFIGS = []
# Generated wrappers that pair each config class with OpenAPIModelInfoBase.
OPENAPI_MODEL_CONFIGS = []
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
    """Fields mixed into every generated OpenAPI config wrapper."""
    # Model name.
    name: str
    # Base model the entry belongs to.
    base_model: BaseModelType
    # Model type; generated wrappers narrow this to a single Literal value.
    type: ModelType
|
||||
|
||||
|
||||
# Collect the config classes of every model class and publish one OpenAPI
# wrapper per config class, both in OPENAPI_MODEL_CONFIGS and as a module
# attribute named <ModelClass><ConfigClass>.
for base_model, models in MODEL_CLASSES.items():
    for model_type, model_class in models.items():
        model_configs = set(model_class._get_configs().values())
        model_configs.discard(None)
        MODEL_CONFIGS.extend(model_configs)

        for cfg in model_configs:
            # The last two parts of the qualified name are the model class
            # and the nested config class.
            model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
            openapi_cfg_name = model_name + cfg_name
            if openapi_cfg_name not in vars():
                # The wrapper inherits from the config class and narrows
                # "type" to the literal value of this model type.
                api_wrapper = type(
                    openapi_cfg_name,
                    (cfg, OpenAPIModelInfoBase),
                    {
                        "__annotations__": {
                            "type": Literal[model_type.value],
                        },
                    },
                )

                #globals()[openapi_cfg_name] = api_wrapper
                vars()[openapi_cfg_name] = api_wrapper
                OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = list()
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
fields = inspect.get_annotations(model_config)
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except:
|
||||
raise Exception("format field not found")
|
||||
|
||||
# model_format: None
|
||||
# model_format: SomeModelFormat
|
||||
# model_format: Literal[SomeModelFormat.Diffusers]
|
||||
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
||||
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
enums.append(field)
|
||||
|
||||
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
enums.append(type(field.__args__[0]))
|
||||
|
||||
elif field is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||
|
||||
return enums
|
||||
|
415
invokeai/backend/model_management/models/base.py
Normal file
415
invokeai/backend/model_management/models/base.py
Normal file
@ -0,0 +1,415 @@
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
import safetensors.torch
|
||||
from diffusers import DiffusionPipeline, ConfigMixin
|
||||
|
||||
from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
class BaseModelType(str, Enum):
    """Base model families; each member's value is a short string id."""
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    #Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
class ModelType(str, Enum):
    """Kinds of model; each member's value is a short string id."""
    Pipeline = "pipeline"
    Vae = "vae"
    Lora = "lora"
    ControlNet = "controlnet" # used by model_probe
    TextualInversion = "embedding"
|
||||
|
||||
class SubModelType(str, Enum):
    """Named sub-models; DiffusersModel uses the value as the subfolder name."""
    UNet = "unet"
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    Vae = "vae"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
    #MoVQ = "movq"
|
||||
|
||||
class ModelVariantType(str, Enum):
    """Model variants: normal, inpaint and depth."""
    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
    """Scheduler prediction types: epsilon, v_prediction and sample."""
    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"
|
||||
|
||||
class ModelError(str, Enum):
    """Error states that can be recorded on a model config."""
    NotFound = "not_found"
|
||||
|
||||
class ModelConfigBase(BaseModel):
    """Fields common to every model config class."""
    # Location of the model.
    path: str # or Path
    # Optional free-text description.
    description: Optional[str] = Field(None)
    # Storage format; subclasses re-annotate this with their own format enum.
    model_format: Optional[str] = Field(None)
    # do not save to config
    error: Optional[ModelError] = Field(None)

    class Config:
        # Pydantic option: populate fields with enum values, not enum members.
        use_enum_values = True
|
||||
|
||||
|
||||
class EmptyConfigLoader(ConfigMixin):
    """ConfigMixin whose config file name is chosen per call."""

    @classmethod
    def load_config(cls, *args, **kwargs):
        """
        Requires a `config_name` keyword argument: it is removed from
        kwargs, stored on the class, and the remaining arguments are
        forwarded to ConfigMixin.load_config().
        """
        requested_name = kwargs.pop("config_name")
        cls.config_name = requested_name
        return super().load_config(*args, **kwargs)
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)


class classproperty(Generic[T_co]):
    """
    Descriptor that calls the wrapped function with the owning class, so
    the value can be read from the class as well as from an instance.
    """

    def __init__(self, fget: Callable[[Any], T_co]) -> None:
        # The function to evaluate on every read.
        self.fget = fget

    def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
        # The instance is ignored; only the owning class is passed on.
        getter = self.fget
        return getter(owner)

    def __set__(self, instance: Optional[Any], value: Any) -> None:
        # Assigning through an instance is rejected.
        raise AttributeError('cannot set attribute')
|
||||
|
||||
class ModelBase(metaclass=ABCMeta):
    """
    Abstract base of all model classes. Concrete subclasses declare nested
    config classes (subclasses of ModelConfigBase), which _get_configs()
    discovers and indexes by model format.
    """
    #model_path: str
    #base_model: BaseModelType
    #model_type: ModelType

    def __init__(
        self,
        model_path: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        self.model_path = model_path
        self.base_model = base_model
        self.model_type = model_type

    def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
        """
        Resolve a component definition (a list of names, as found in a
        model index) to the object it names.

        If the first entry is "diffusers" or "transformers", the rest is
        resolved as attributes of that module; otherwise every entry is
        resolved under `diffusers.pipelines`. Returns None when every entry
        is None.

        Raises Exception for lists shorter than two entries or lists that
        mix None with real names.
        """
        if len(subtypes) < 2:
            raise Exception("Invalid subfolder definition!")
        if all(t is None for t in subtypes):
            return None
        elif any(t is None for t in subtypes):
            raise Exception(f"Unsupported definition: {subtypes}")

        if subtypes[0] in ["diffusers", "transformers"]:
            res_type = sys.modules[subtypes[0]]
            subtypes = subtypes[1:]

        else:
            res_type = sys.modules["diffusers"]
            res_type = getattr(res_type, "pipelines")

        for subtype in subtypes:
            res_type = getattr(res_type, subtype)
        return res_type

    @classmethod
    def _get_configs(cls):
        """
        Return a dict mapping each model format value (or None) to the
        nested config class that handles it. The result is cached on the
        class after the first call.

        Raises Exception when a nested config class lacks a `model_format`
        annotation or declares it in an unsupported way.
        """
        # Look only in this class's own namespace, so that a subclass never
        # picks up the table cached on one of its parents.
        cached = vars(cls).get("_ModelBase__configs")
        if cached is not None:
            return cached

        configs = dict()
        for name in dir(cls):
            if name.startswith("__"):
                continue

            value = getattr(cls, name)
            if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
                continue

            fields = inspect.get_annotations(value)
            try:
                field = fields["model_format"]
            except KeyError:
                raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")

            if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
                # An enum class: this config handles every format in it.
                for model_format in field:
                    configs[model_format.value] = value

            elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
                # A Literal of enum members: only the listed formats.
                for model_format in field.__args__:
                    configs[model_format.value] = value

            elif field is None:
                configs[None] = value

            else:
                raise Exception(f"Unsupported format definition in {cls.__qualname__}")

        cls.__configs = configs
        return cls.__configs

    @classmethod
    def create_config(cls, **kwargs) -> ModelConfigBase:
        """
        Instantiate the config class registered for kwargs["model_format"],
        passing all kwargs to it.

        Raises Exception when `model_format` is missing, KeyError when no
        config class handles that format.
        """
        if "model_format" not in kwargs:
            raise Exception("Field 'model_format' not found in model config")

        configs = cls._get_configs()
        return configs[kwargs["model_format"]](**kwargs)

    @classmethod
    def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
        """
        Build a config for the model at `path`, using detect_format() to
        choose the format. Extra keyword arguments are currently ignored.
        """
        return cls.create_config(
            path=path,
            model_format=cls.detect_format(path),
        )

    @classmethod
    @abstractmethod
    def detect_format(cls, path: str) -> str:
        """Return the model format of the model at `path`."""
        raise NotImplementedError()

    @classproperty
    @abstractmethod
    def save_to_config(cls) -> bool:
        """Class-level flag to be supplied by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        """Return the size of the model, or of one sub-model."""
        raise NotImplementedError()

    @abstractmethod
    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> Any:
        """Load and return the model, or one sub-model."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class DiffusersModel(ModelBase):
    """
    A model stored as a diffusers folder. Records the class and the size
    of every component listed in the folder's pipeline config.
    """
    #child_types: Dict[str, Type]
    #child_sizes: Dict[str, int]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """
        Read the pipeline config at model_path and record, per component,
        the resolved class and the on-disk size.

        Raises Exception when the config cannot be loaded.
        """
        super().__init__(model_path, base_model, model_type)

        self.child_types: Dict[str, Type] = dict()
        self.child_sizes: Dict[str, int] = dict()

        try:
            config_data = DiffusionPipeline.load_config(self.model_path)
            #config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
        except Exception as e:
            # Keep the underlying error attached for diagnosis.
            raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") from e

        config_data.pop("_ignore_files", None)

        # retrieve all folder_names that contain relevant files
        child_components = [k for k, v in config_data.items() if isinstance(v, list)]

        for child_name in child_components:
            child_type = self._hf_definition_to_type(config_data[child_name])
            self.child_types[child_name] = child_type
            self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)

    def get_size(self, child_type: Optional[SubModelType] = None):
        """
        Return the recorded size of one component, or the sum over all
        components when child_type is None.
        """
        if child_type is None:
            return sum(self.child_sizes.values())
        else:
            return self.child_sizes[child_type]

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """
        Load one component with from_pretrained(). For float16 the "fp16"
        variant is tried before the default one, otherwise the other way
        round. The component's recorded size is updated from the loaded
        data.

        Returns None when the component is not part of this model.
        Raises Exception when child_type is None or no variant loads.
        """
        # return pipeline in different function to pass more arguments
        if child_type is None:
            raise Exception("Child model type can't be null on diffusers model")
        if child_type not in self.child_types:
            return None # TODO: or raise

        if torch_dtype == torch.float16:
            variants = ["fp16", None]
        else:
            variants = [None, "fp16"]

        # TODO: better error handling(differentiate not found from others)
        last_error = None
        for variant in variants:
            try:
                # TODO: set cache_dir to /dev/null to be sure that cache not used?
                model = self.child_types[child_type].from_pretrained(
                    self.model_path,
                    subfolder=child_type.value,
                    torch_dtype=torch_dtype,
                    variant=variant,
                    local_files_only=True,
                )
                break
            except Exception as e:
                # Remember the failure and try the next variant.
                last_error = e
        else:
            raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") from last_error

        # calc more accurate size
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model
|
||||
|
||||
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||
|
||||
|
||||
|
||||
def calc_model_size_by_fs(
    model_path: str,
    subfolder: Optional[str] = None,
    variant: Optional[str] = None
):
    """
    Estimate the size in bytes of the model stored in a directory.

    model_path: directory to inspect.
    subfolder: optional sub-directory of model_path to inspect instead.
    variant: None for files without a variant marker, or "fp16" / "8bit"
        for files whose name contains that marker.

    If a readable index file (name ending in ".index.json", or
    ".index.<variant>.json") is present, its metadata.total_size is
    returned. Otherwise the sizes of the weight files of the first format
    found are added up. Returns 0 when the directory does not exist or
    holds no weight files.

    Raises NotImplementedError for an unknown variant.
    """
    # Imported locally: the index fast path below needs it.
    import json

    if subfolder is not None:
        model_path = os.path.join(model_path, subfolder)

    # this can happen when, for example, the safety checker
    # is not downloaded.
    if not os.path.exists(model_path):
        return 0

    all_files = os.listdir(model_path)
    all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]

    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
    other_files = set(all_files) - fp16_files - bit8_files

    if variant is None:
        files = other_files
    elif variant == "fp16":
        files = fp16_files
    elif variant == "8bit":
        files = bit8_files
    else:
        raise NotImplementedError(f"Unknown variant: {variant}")

    # try read from index if exists
    index_postfix = ".index.json"
    if variant is not None:
        index_postfix = f".index.{variant}.json"

    for file in files:
        if not file.endswith(index_postfix):
            continue
        try:
            with open(os.path.join(model_path, file), "r") as f:
                index_data = json.loads(f.read())
            return int(index_data["metadata"]["total_size"])
        except Exception:
            # Unreadable or malformed index: fall back to file sizes.
            pass

    # calculate files size if there is no index file
    formats = [
        (".safetensors",), # safetensors
        (".bin",), # torch
        (".onnx", ".pb"), # onnx
        (".msgpack",), # flax
        (".ckpt",), # tf
        (".h5",), # tf2
    ]

    for file_format in formats:
        model_files = [f for f in files if f.endswith(file_format)]
        if len(model_files) == 0:
            continue

        # Only the first format that has files is counted.
        return sum(
            os.stat(os.path.join(model_path, model_file)).st_size
            for model_file in model_files
        )

    #raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
    return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||
|
||||
|
||||
def calc_model_size_by_data(model) -> int:
    """
    Return the in-memory size in bytes of a loaded model: pipelines are
    summed over their components, torch modules over their tensors, and
    anything else counts as 0.
    """
    if isinstance(model, DiffusionPipeline):
        return _calc_pipeline_by_data(model)
    if isinstance(model, torch.nn.Module):
        return _calc_model_by_data(model)
    return 0
|
||||
|
||||
|
||||
def _calc_pipeline_by_data(pipeline) -> int:
|
||||
res = 0
|
||||
for submodel_key in pipeline.components.keys():
|
||||
submodel = getattr(pipeline, submodel_key)
|
||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||
res += _calc_model_by_data(submodel)
|
||||
return res
|
||||
|
||||
|
||||
def _calc_model_by_data(model) -> int:
|
||||
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
||||
mem = mem_params + mem_bufs # in bytes
|
||||
return mem
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
definition_len = int.from_bytes(f.read(8), 'little')
|
||||
definition_json = f.read(definition_len)
|
||||
definition = json.loads(definition_json)
|
||||
|
||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
|
||||
raise Exception("Supported only pytorch safetensors files")
|
||||
definition.pop("__metadata__", None)
|
||||
|
||||
for key, info in definition.items():
|
||||
dtype = {
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"I32": torch.int32,
|
||||
"I64": torch.int64,
|
||||
"F16": torch.float16,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
}[info["dtype"]]
|
||||
|
||||
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def read_checkpoint_meta(path: str):
    """
    Load a checkpoint for inspection. For ".safetensors" files the
    header-only reader is tried first and a full CPU load is the fallback;
    every other file goes through torch.load() mapped to the "meta" device.
    """
    if path.endswith(".safetensors"):
        try:
            checkpoint = _fast_safetensors_reader(path)
        except Exception:
            # Only ordinary errors trigger the fallback; interrupts propagate.
            # TODO: create issue for support "meta"?
            checkpoint = safetensors.torch.load_file(path, device="cpu")
    else:
        checkpoint = torch.load(path, map_location=torch.device("meta"))
    return checkpoint
|
||||
|
||||
import warnings
|
||||
from diffusers import logging as diffusers_logging
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
class SilenceWarnings(object):
    """
    Context manager that lowers transformers and diffusers logging to
    errors only and ignores Python warnings inside the `with` block.
    Everything is restored to its previous state on exit.
    """

    def __init__(self):
        # Verbosity levels are captured when the object is created.
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()
        self._warnings_context = None

    def __enter__(self):
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        # catch_warnings() snapshots the current warning filters so that
        # __exit__ can put them back instead of resetting them.
        self._warnings_context = warnings.catch_warnings()
        self._warnings_context.__enter__()
        warnings.simplefilter('ignore')

    def __exit__(self, type, value, traceback):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        if self._warnings_context is not None:
            self._warnings_context.__exit__(type, value, traceback)
            self._warnings_context = None
|
92
invokeai/backend/model_management/models/controlnet.py
Normal file
92
invokeai/backend/model_management/models/controlnet.py
Normal file
@ -0,0 +1,92 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
EmptyConfigLoader,
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
    """Format tags for ControlNet models; members are plain strings."""
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"
|
||||
|
||||
class ControlNetModel(ModelBase):
    """Model-manager wrapper around a diffusers-format ControlNet model.

    The constructor validates ``config.json`` in the model directory and
    resolves the diffusers class named there; ``get_model`` loads it.
    """
    #model_class: Type
    #model_size: int

    class Config(ModelConfigBase):
        model_format: ControlNetModelFormat

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Validate the model's config and record its class and on-disk size.

        Raises:
            Exception: if config.json cannot be loaded, names an unknown
                ``_class_name``, or the class/size cannot be resolved.
        """
        assert model_type == ModelType.ControlNet
        super().__init__(model_path, base_model, model_type)

        try:
            config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
        except Exception as e:
            # keep the original error chained so the root cause stays visible
            raise Exception("Invalid controlnet model! (config.json not found or invalid)") from e

        model_class_name = config.get("_class_name", None)
        if model_class_name not in {"ControlNetModel"}:
            raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")

        try:
            self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
            self.model_size = calc_model_size_by_fs(self.model_path)
        except Exception as e:
            raise Exception("Invalid ControlNet model!") from e

    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the recorded model size; ``child_type`` must be None."""
        if child_type is not None:
            raise Exception("There is no child models in controlnet model")
        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load the model with ``from_pretrained`` and refresh ``model_size`` from it."""
        if child_type is not None:
            raise Exception("There is no child models in controlnet model")

        model = self.model_class.from_pretrained(
            self.model_path,
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.model_size = calc_model_size_by_data(model)
        return model

    @classproperty
    def save_to_config(cls) -> bool:
        return False

    @classmethod
    def detect_format(cls, path: str):
        """A directory is treated as diffusers format, anything else as a checkpoint."""
        if os.path.isdir(path):
            return ControlNetModelFormat.Diffusers
        else:
            return ControlNetModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,  # empty config or config of parent model
        base_model: BaseModelType,
    ) -> str:
        """Return ``model_path`` for diffusers models; checkpoints are not supported."""
        if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
            raise NotImplementedError("Checkpoint controlnet models currently unsupported")
        else:
            return model_path
|
76
invokeai/backend/model_management/models/lora.py
Normal file
76
invokeai/backend/model_management/models/lora.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
|
||||
class LoRAModelFormat(str, Enum):
    """Format tags for LoRA models; members are plain strings."""
    LyCORIS = "lycoris"
    Diffusers = "diffusers"
|
||||
|
||||
class LoRAModel(ModelBase):
    """Model-manager wrapper that loads a LoRA through ``LoRAModelRaw.from_checkpoint``."""
    #model_size: int

    class Config(ModelConfigBase):
        model_format: LoRAModelFormat  # TODO:

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert model_type == ModelType.Lora
        super().__init__(model_path, base_model, model_type)

        # size on disk; replaced by the loaded model's own figure in get_model()
        self.model_size = os.path.getsize(self.model_path)

    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the recorded size; a LoRA has no sub-models to ask for."""
        if child_type is None:
            return self.model_size
        raise Exception("There is no child models in lora")

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load the LoRA from ``model_path`` and update ``model_size`` from it."""
        if child_type is not None:
            raise Exception("There is no child models in lora")

        lora = LoRAModelRaw.from_checkpoint(file_path=self.model_path, dtype=torch_dtype)
        self.model_size = lora.calc_size()
        return lora

    @classproperty
    def save_to_config(cls) -> bool:
        return False

    @classmethod
    def detect_format(cls, path: str):
        """Directories are tagged Diffusers, everything else LyCORIS."""
        return LoRAModelFormat.Diffusers if os.path.isdir(path) else LoRAModelFormat.LyCORIS

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Return ``model_path`` unchanged unless it is a (unsupported) diffusers LoRA."""
        if cls.detect_format(model_path) != LoRAModelFormat.Diffusers:
            return model_path
        # TODO: add diffusers lora when it stabilizes a bit
        raise NotImplementedError("Diffusers lora not supported")
|
321
invokeai/backend/model_management/models/stable_diffusion.py
Normal file
321
invokeai/backend/model_management/models/stable_diffusion.py
Normal file
@ -0,0 +1,321 @@
|
||||
import os
|
||||
import json
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelVariantType,
|
||||
DiffusersModel,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
read_checkpoint_meta,
|
||||
classproperty,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
class StableDiffusion1ModelFormat(str, Enum):
    """Format tags for Stable Diffusion 1.* models; members are plain strings."""
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
    """Model-manager entry for Stable Diffusion 1.* pipelines (checkpoint or diffusers)."""

    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.Pipeline
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.Pipeline,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Build a config for the model at ``path`` by inspecting its UNet input channels.

        Args:
            path: checkpoint file or diffusers directory.
            **kwargs: ``config`` may carry the path of a checkpoint yaml config.

        Returns:
            The result of ``cls.create_config`` with the detected format and variant.

        Raises:
            Exception: for a diffusers folder without ``unet/config.json`` or an
                unrecognised channel count.
            NotImplementedError: for an unknown model format.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == StableDiffusion1ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                # the value was previously read but never assigned, which left
                # in_channels unbound whenever a config path was supplied
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]

            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get('state_dict', checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion1ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.load(f)
                in_channels = unet_config['in_channels']

            else:
                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")

        # the number of UNet input channels identifies the variant
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 1.* model format")

        return cls.create_config(
            path=path,
            model_format=model_format,

            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """A directory is treated as diffusers format, anything else as a checkpoint."""
        if os.path.isdir(model_path):
            return StableDiffusion1ModelFormat.Diffusers
        else:
            return StableDiffusion1ModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Convert checkpoint configs via ``_convert_ckpt_and_cache``; otherwise return ``model_path``."""
        assert model_path == config.path

        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion1,
                model_config=config,
                output_path=output_path,
            )  # TODO: args
        else:
            return model_path
|
||||
|
||||
class StableDiffusion2ModelFormat(str, Enum):
    """Format tags for Stable Diffusion 2.* models; members are plain strings."""
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
    """Model-manager entry for Stable Diffusion 2.* pipelines (checkpoint or diffusers)."""

    # TODO: check that configs overwriten properly
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.Pipeline
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.Pipeline,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Build a config for the model at ``path`` by inspecting its UNet input channels.

        Args:
            path: checkpoint file or diffusers directory.
            **kwargs: ``config`` may carry the path of a checkpoint yaml config.

        Returns:
            The result of ``cls.create_config`` with the detected format, variant,
            prediction type and upcast-attention flag.

        Raises:
            Exception: for a diffusers folder without ``unet/config.json`` or an
                unrecognised channel count.
            NotImplementedError: for an unknown model format.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == StableDiffusion2ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                # the value was previously read but never assigned, which left
                # in_channels unbound whenever a config path was supplied
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]

            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get('state_dict', checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion2ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.load(f)
                in_channels = unet_config['in_channels']

            else:
                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

        # the number of UNet input channels identifies the variant
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 2.* model format")

        # defaults chosen per variant
        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True

        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            model_format=model_format,

            config=ckpt_config_path,
            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """A directory is treated as diffusers format, anything else as a checkpoint."""
        if os.path.isdir(model_path):
            return StableDiffusion2ModelFormat.Diffusers
        else:
            return StableDiffusion2ModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Convert checkpoint configs via ``_convert_ckpt_and_cache``; otherwise return ``model_path``."""
        assert model_path == config.path

        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2,
                model_config=config,
                output_path=output_path,
            )  # TODO: args
        else:
            return model_path
|
||||
|
||||
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
    """Return the path of the legacy yaml config for a model version/variant.

    The path is built as ``<root_dir>/configs/stable-diffusion/<file name>``.

    Returns:
        The config path, or None when the combination is not in the table
        (or the path cannot be built).
    """
    ckpt_configs = {
        BaseModelType.StableDiffusion1: {
            ModelVariantType.Normal: "v1-inference.yaml",
            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        },
        BaseModelType.StableDiffusion2: {
            # code further will manually set upcast_attention and v_prediction
            ModelVariantType.Normal: "v2-inference.yaml",
            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
            ModelVariantType.Depth: "v2-midas-inference.yaml",
        }
    }

    try:
        # TODO: path
        return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant]

    except Exception:
        # best-effort lookup; ``Exception`` (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate
        return None
|
||||
|
||||
|
||||
# TODO: rework
|
||||
def _convert_ckpt_and_cache(
    version: BaseModelType,
    model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
    output_path: str,
) -> str:
    """
    Convert the checkpoint model described by ``model_config`` into a
    diffusers model stored at ``output_path`` and return that location.
    If ``output_path`` already exists it is returned without converting.

    Note: ``model_config.config`` is filled in (mutated) when it is None,
    and the returned value is ``output_path`` wrapped in ``Path``.

    Raises:
        Exception: if no config is available for the variant, or ``version``
            is neither StableDiffusion1 nor StableDiffusion2.
    """
    app_config = InvokeAIAppConfig.get_config()

    # fall back to the stock config for this version/variant
    if model_config.config is None:
        model_config.config = _select_ckpt_config(version, model_config.variant)
        if model_config.config is None:
            raise Exception(f"Model variant {model_config.variant} not supported for {version}")

    # both paths are joined onto the application root directory
    weights = app_config.root_dir / model_config.path
    config_file = app_config.root_dir / model_config.config
    output_path = Path(output_path)

    if version == BaseModelType.StableDiffusion1:
        upcast_attention = False
        prediction_type = SchedulerPredictionType.Epsilon

    elif version == BaseModelType.StableDiffusion2:
        upcast_attention = model_config.upcast_attention
        prediction_type = model_config.prediction_type

    else:
        raise Exception(f"Unknown model provided: {version}")

    # return cached version if it exists
    # (checked only after the validation above, so invalid input still raises)
    if output_path.exists():
        return output_path

    # TODO: I think that it more correctly to convert with embedded vae
    # as if user will delete custom vae he will got not embedded but also custom vae
    #vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)

    # to avoid circular import errors
    from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
    with SilenceWarnings():
        convert_ckpt_to_diffusers(
            weights,
            output_path,
            model_version=version,
            model_variant=model_config.variant,
            original_config_file=config_file,
            extract_ema=True,
            upcast_attention=upcast_attention,
            prediction_type=prediction_type,
            scan_needed=True,
            model_root=app_config.models_path,
        )
    return output_path
|
@ -0,0 +1,64 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
|
||||
class TextualInversionModel(ModelBase):
    """Model-manager wrapper that loads an embedding through ``TextualInversionModelRaw.from_checkpoint``."""
    #model_size: int

    class Config(ModelConfigBase):
        model_format: None

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert model_type == ModelType.TextualInversion
        super().__init__(model_path, base_model, model_type)

        # size on disk; replaced by the embedding's byte size in get_model()
        self.model_size = os.path.getsize(self.model_path)

    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the recorded size; there are no sub-models to ask for."""
        if child_type is None:
            return self.model_size
        raise Exception("There is no child models in textual inversion")

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load the embedding from ``model_path`` and update ``model_size`` from it."""
        if child_type is not None:
            raise Exception("There is no child models in textual inversion")

        ti = TextualInversionModelRaw.from_checkpoint(file_path=self.model_path, dtype=torch_dtype)
        embedding = ti.embedding
        self.model_size = embedding.nelement() * embedding.element_size()
        return ti

    @classproperty
    def save_to_config(cls) -> bool:
        return False

    @classmethod
    def detect_format(cls, path: str):
        """Textual inversions carry no format tag."""
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """No conversion is performed; ``model_path`` is returned as given."""
        return model_path
|
166
invokeai/backend/model_management/models/vae.py
Normal file
166
invokeai/backend/model_management/models/vae.py
Normal file
@ -0,0 +1,166 @@
|
||||
import os
|
||||
import torch
|
||||
import safetensors
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelVariantType,
|
||||
EmptyConfigLoader,
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
class VaeModelFormat(str, Enum):
    """Format tags for VAE models; members are plain strings."""
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"
|
||||
|
||||
class VaeModel(ModelBase):
    """Model-manager wrapper around a diffusers-format VAE.

    The constructor reads ``config.json`` in the model directory and
    resolves the diffusers class named there; ``get_model`` loads it.
    """
    #vae_class: Type
    #model_size: int

    class Config(ModelConfigBase):
        model_format: VaeModelFormat

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Validate the model's config and record its class and on-disk size.

        Raises:
            Exception: if config.json cannot be loaded or the VAE class
                cannot be resolved.
        """
        assert model_type == ModelType.Vae
        super().__init__(model_path, base_model, model_type)

        try:
            config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
        except Exception as e:
            # keep the original error chained so the root cause stays visible
            raise Exception("Invalid vae model! (config.json not found or invalid)") from e

        try:
            # configs without a class name are treated as AutoencoderKL
            vae_class_name = config.get("_class_name", "AutoencoderKL")
            self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
            self.model_size = calc_model_size_by_fs(self.model_path)
        except Exception as e:
            raise Exception("Invalid vae model! (Unkown vae type)") from e

    def get_size(self, child_type: Optional[SubModelType] = None):
        """Return the recorded model size; ``child_type`` must be None."""
        if child_type is not None:
            raise Exception("There is no child models in vae model")
        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ):
        """Load the VAE with ``from_pretrained`` and refresh ``model_size`` from it."""
        if child_type is not None:
            raise Exception("There is no child models in vae model")

        model = self.vae_class.from_pretrained(
            self.model_path,
            torch_dtype=torch_dtype,
        )
        # calc more accurate size
        self.model_size = calc_model_size_by_data(model)
        return model

    @classproperty
    def save_to_config(cls) -> bool:
        return False

    @classmethod
    def detect_format(cls, path: str):
        """A directory is treated as diffusers format, anything else as a checkpoint."""
        if os.path.isdir(path):
            return VaeModelFormat.Diffusers
        else:
            return VaeModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,  # empty config or config of parent model
        base_model: BaseModelType,
    ) -> str:
        """Convert checkpoint VAEs via ``_convert_vae_ckpt_and_cache``; otherwise return ``model_path``."""
        if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
            return _convert_vae_ckpt_and_cache(
                weights_path=model_path,
                output_path=output_path,
                base_model=base_model,
                model_config=config,
            )
        else:
            return model_path
|
||||
|
||||
# TODO: rework
|
||||
def _convert_vae_ckpt_and_cache(
    weights_path: str,
    output_path: str,
    base_model: BaseModelType,
    model_config: ModelConfigBase,
) -> str:
    """
    Convert the VAE checkpoint at ``weights_path`` into a diffusers model
    saved at ``output_path`` and return that location. If ``output_path``
    already exists it is returned without converting.

    ``model_config`` is accepted but not read by this function. The
    returned value is ``output_path`` wrapped in ``Path``.

    Raises:
        Exception: if ``base_model`` is neither StableDiffusion1 nor
            StableDiffusion2.
    """
    app_config = InvokeAIAppConfig.get_config()
    # the weights path is joined onto the application root directory
    weights_path = app_config.root_dir / weights_path
    output_path = Path(output_path)

    """
    this size used only in when tiling enabled to separate input in tiles
    sizes in configs from stable diffusion githubs(1 and 2) set to 256
    on huggingface it:
    1.5 - 512
    1.5-inpainting - 256
    2-inpainting - 512
    2-depth - 256
    2-base - 512
    2 - 768
    2.1-base - 768
    2.1 - 768
    """
    image_size = 512

    # return cached version if it exists
    if output_path.exists():
        return output_path

    if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
        from .stable_diffusion import _select_ckpt_config
        # all sd models use same vae settings
        config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)

    else:
        raise Exception(f"Vae conversion not supported for model type: {base_model}")

    # this avoids circular import error
    from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
    # pick the loader by file extension
    if weights_path.suffix == '.safetensors':
        checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
    else:
        # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
        checkpoint = torch.load(weights_path, map_location="cpu")

    # sometimes weights are hidden under "state_dict", and sometimes not
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    config = OmegaConf.load(config_file)

    vae_model = convert_ldm_vae_to_diffusers(
        checkpoint = checkpoint,
        vae_config = config,
        image_size = image_size,
    )
    vae_model.save_pretrained(
        output_path,
        safe_serialization=is_safetensors_available()
    )
    return output_path
|
@ -1,9 +0,0 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.prompting
|
||||
"""
|
||||
from .conditioning import (
|
||||
get_prompt_structure,
|
||||
get_tokens_for_prompt_object,
|
||||
get_uc_and_c_and_ec,
|
||||
split_weighted_subprompts,
|
||||
)
|
@ -1,296 +0,0 @@
|
||||
"""
|
||||
This module handles the generation of the conditioning tensors.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
||||
|
||||
"""
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
PromptParser,
|
||||
Conjunction,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ..util import torch_dtype
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def get_uc_and_c_and_ec(prompt_string,
                        model: InvokeAIDiffuserComponent,
                        log_tokens=False, skip_normalize_legacy_blend=False):
    """Build the conditioning tensors for a prompt string.

    The prompt is split into a positive part and a negative part (the text
    inside ``[...]``); each is parsed and turned into a conditioning tensor,
    and the two tensors are padded to the same length.

    Args:
        prompt_string: the raw prompt.
        model: supplies the tokenizer, text encoder and textual inversion manager.
        log_tokens: log the tokenization (also enabled by ``config.log_tokenization``).
        skip_normalize_legacy_blend: forwarded to ``try_parse_legacy_blend``.

    Returns:
        ``(uc, c, ec)``: the negative-prompt conditioning, the positive-prompt
        conditioning, and an ``ExtraConditioningInfo`` holding the token count
        and any cross-attention-control options.
    """
    # lazy-load any deferred textual inversions.
    # this might take a couple of seconds the first time a textual inversion is used.
    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

    compel = Compel(tokenizer=model.tokenizer,
                    text_encoder=model.text_encoder,
                    textual_inversion_manager=model.textual_inversion_manager,
                    dtype_for_device_getter=torch_dtype,
                    truncate_long_prompts=False,
                    )

    # get rid of any newline characters
    prompt_string = prompt_string.replace("\n", " ")
    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)

    # a legacy blend, when one is parsed, takes precedence over the regular parser
    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
    positive_conjunction: Conjunction
    if legacy_blend is not None:
        positive_conjunction = legacy_blend
    else:
        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
    # only the first prompt of each conjunction is used
    positive_prompt = positive_conjunction.prompts[0]

    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
    if log_tokens or config.log_tokenization:
        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                         cross_attention_control_args=options.get(
                                                             'cross_attention_control', None))
    return uc, c, ec
|
||||
|
||||
def get_prompt_structure(
    prompt_string, skip_normalize_legacy_blend: bool = False
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
    """Parse a prompt string into its positive and negative prompt objects.

    The text inside ``[...]`` forms the negative prompt; the remainder is the
    positive prompt. A legacy blend, when one is parsed, takes precedence
    over the regular parser for the positive part.

    Returns:
        ``(positive_prompt, negative_prompt)``: the first prompt of each
        parsed conjunction.
    """
    (
        positive_prompt_string,
        negative_prompt_string,
    ) = split_prompt_to_positive_and_negative(prompt_string)
    legacy_blend = try_parse_legacy_blend(
        positive_prompt_string, skip_normalize_legacy_blend
    )
    # the annotation belongs on the conjunction, not on the prompt taken from it
    positive_conjunction: Conjunction
    if legacy_blend is not None:
        positive_conjunction = legacy_blend
    else:
        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
    positive_prompt = positive_conjunction.prompts[0]
    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
    negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]

    return positive_prompt, negative_prompt
|
||||
|
||||
def get_max_token_count(
    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
    """Return the token count of ``prompt``; for a Blend, the largest count among its parts."""
    if type(prompt) is not Blend:
        return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))

    # a blend is measured by its longest constituent prompt
    return max(
        get_max_token_count(tokenizer, part, truncate_if_too_long)
        for part in prompt.prompts
    )
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
    """Tokenize the text of a flattened prompt.

    The text of every child is joined with spaces and passed to
    ``tokenizer.tokenize``. With ``truncate_if_too_long`` the result is cut
    to ``tokenizer.model_max_length - 2`` tokens.

    Raises:
        ValueError: if ``parsed_prompt`` is a Blend.
    """
    if type(parsed_prompt) is Blend:
        raise ValueError(
            "Blend is not supported here - you need to get tokens for each of its .children"
        )

    def _child_text(child):
        # plain fragments contribute their text, swaps contribute their
        # original fragments, anything else its string form
        if type(child) is Fragment:
            return child.text
        if type(child) is CrossAttentionControlSubstitute:
            return " ".join([piece.text for piece in child.original])
        return str(child)

    joined = " ".join(_child_text(child) for child in parsed_prompt.children)
    tokens = tokenizer.tokenize(joined)
    if not truncate_if_too_long:
        return tokens
    limit = tokenizer.model_max_length - 2  # typically 75
    return tokens[0:limit]
|
||||
|
||||
|
||||
def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
    """Split a prompt into its positive text and its ``[...]`` negative text.

    Returns:
        ``(positive, negative)``. When the prompt contains bracketed spans,
        ``negative`` is their contents joined by spaces and ``positive`` is
        the prompt with each span replaced by a space and runs of spaces
        collapsed. Without bracketed spans the prompt is returned untouched
        together with an empty string.
    """
    bracketed = re.compile(r"\[(.*?)\]")
    negatives = bracketed.findall(prompt_string_uncleaned)
    if len(negatives) == 0:
        return prompt_string_uncleaned, ""

    # Remove Unconditioned Words From Prompt
    without_negatives = bracketed.sub(" ", prompt_string_uncleaned)
    return re.sub(" +", " ", without_negatives), " ".join(negatives)
|
||||
|
||||
|
||||
def log_tokenization(
    positive_prompt: Union[Blend, FlattenedPrompt],
    negative_prompt: Union[Blend, FlattenedPrompt],
    tokenizer,
):
    """Log the parsed positive and negative prompts, then the tokenization of each."""
    for heading, parsed in (
        ("Parsed Prompt", positive_prompt),
        ("Parsed Negative Prompt", negative_prompt),
    ):
        logger.info(f"[TOKENLOG] {heading}: {parsed}")

    log_tokenization_for_prompt_object(positive_prompt, tokenizer)
    negative_label = "(negative prompt)"
    log_tokenization_for_prompt_object(
        negative_prompt, tokenizer, display_label_prefix=negative_label
    )
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
    """
    Log the tokenization of a parsed prompt object.

    A ``Blend`` is handled by recursing into each of its prompts with a label
    naming the part number and weight. A ``FlattenedPrompt`` is joined into
    text and logged; when it wants cross-attention control, the "original"
    and "edited" variants of the text are logged separately. Objects of any
    other exact type are ignored.

    :param p: the prompt object to log.
    :param tokenizer: tokenizer forwarded to ``log_tokenization_for_text``.
    :param display_label_prefix: optional text prepended to every label;
        ``None`` is treated as the empty string.
    """
    prefix = display_label_prefix or ""

    if type(p) is Blend:
        for index, part in enumerate(p.prompts):
            part_label = f"{prefix}(blend part {index + 1}, weight={p.weights[index]})"
            log_tokenization_for_prompt_object(
                part, tokenizer, display_label_prefix=part_label
            )
        return

    if type(p) is not FlattenedPrompt:
        return

    if not p.wants_cross_attention_control:
        plain_text = " ".join([child.text for child in p.children])
        log_tokenization_for_text(plain_text, tokenizer, display_label=prefix)
        return

    # Build both sides of every .swap(): substitutes contribute their own
    # original/edited fragments, everything else appears on both sides.
    originals = []
    replacements = []
    for child in p.children:
        if type(child) is CrossAttentionControlSubstitute:
            originals += child.original
            replacements += child.edited
        else:
            originals.append(child)
            replacements.append(child)

    originals_text = " ".join([fragment.text for fragment in originals])
    log_tokenization_for_text(
        originals_text,
        tokenizer,
        display_label=f"{prefix}(.swap originals)",
    )

    replacements_text = " ".join([fragment.text for fragment in replacements])
    log_tokenization_for_text(
        replacements_text,
        tokenizer,
        display_label=f"{prefix}(.swap replacements)",
    )
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
    """
    Log how ``text`` is split into tokens, colouring consecutive tokens with
    alternating ANSI colours.

    Tokens usually carry ``'</w>'`` to mark end-of-word; for readability that
    marker is replaced with a space before logging.

    :param text: the prompt text to tokenize.
    :param tokenizer: object providing ``tokenize(text)`` and, when
        truncation is requested, ``model_max_length``.
    :param display_label: optional label included in the "Tokens" header.
    :param truncate_if_too_long: when true, tokens at index
        ``tokenizer.model_max_length`` and beyond are reported under a
        separate "Tokens Discarded" header instead of the main list.
    """
    tokens = tokenizer.tokenize(text)
    total_count = len(tokens)
    kept_parts = []
    dropped_parts = []

    for position, raw_token in enumerate(tokens):
        readable = raw_token.replace("</w>", " ")
        # Cycle through six colours, keyed on how many tokens were kept so far.
        colour = (len(kept_parts) % 6) + 1
        coloured = f"\x1b[0;3{colour};40m{readable}"
        if truncate_if_too_long and position >= tokenizer.model_max_length:
            dropped_parts.append(coloured)
        else:
            kept_parts.append(coloured)

    used_count = len(kept_parts)
    if used_count > 0:
        kept_text = "".join(kept_parts)
        logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({used_count}):')
        logger.debug(f"{kept_text}\x1b[0m")

    dropped_text = "".join(dropped_parts)
    if dropped_text != "":
        logger.info(f"[TOKENLOG] Tokens Discarded ({total_count - used_count}):")
        logger.debug(f"{dropped_text}\x1b[0m")
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
    """
    Try to interpret ``text`` as a legacy ``prompt:weight prompt:weight`` blend.

    :param text: the raw prompt text.
    :param skip_normalize: forwarded to ``split_weighted_subprompts``; the
        resulting ``Blend`` is created with ``normalize_weights`` set to the
        opposite value.
    :return: ``None`` when the text splits into one subprompt or fewer;
        otherwise a ``Conjunction`` wrapping a single ``Blend`` built from the
        first parsed prompt of every subprompt that produced one, paired with
        that subprompt's weight.
    """
    subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
    if len(subprompts) <= 1:
        return None

    # Parse every subprompt first, then keep only those that yielded a prompt.
    parser = PromptParser()
    conjunctions = [parser.parse_conjunction(entry[0]) for entry in subprompts]
    kept = [
        (conjunction.prompts[0], entry[1])
        for conjunction, entry in zip(conjunctions, subprompts)
        if len(conjunction.prompts) > 0
    ]

    blend = Blend(
        prompts=[prompt for prompt, _ in kept],
        weights=[weight for _, weight in kept],
        normalize_weights=not skip_normalize,
    )
    return Conjunction([blend])
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
    """
    Legacy blend parsing.

    Grabs all text up to the first occurrence of ':', uses the grabbed text
    as a sub-prompt, and takes the value following ':' as its weight. If ':'
    has no value defined, the weight defaults to 1.0. Repeats until no text
    remains. A colon escaped as '\\:' is kept as part of the sub-prompt and
    unescaped in the result.

    :param text: the prompt text to split.
    :param skip_normalize: when true, return the weights exactly as parsed.
    :return: a list of ``(subprompt, weight)`` tuples; empty when nothing
        matches. Unless ``skip_normalize`` is set, weights are divided by
        their sum; if that sum is zero a warning is logged and every
        sub-prompt receives the same weight.
    """
    # Raw string: the pattern is full of regex escapes (\\, \d, \., \s) that
    # are invalid *string* escapes and trigger warnings in a normal literal.
    prompt_parser = re.compile(
        r"""
        (?P<prompt>         # capture group for 'prompt'
        (?:\\:|[^:])+       # match one or more non ':' characters or escaped colons '\:'
        )                   # end 'prompt'
        (?:                 # non-capture group
        :+                  # match one or more ':' characters
        (?P<weight>         # capture group for 'weight'
        -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
        )?                  # end weight capture group, make optional
        \s*                 # strip spaces after weight
        |                   # OR
        $                   # else, if no ':' then match end of line
        )                   # end non-capture group
        """,
        re.VERBOSE,
    )
    parsed_prompts = [
        (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
        for match in prompt_parser.finditer(text)
    ]
    if len(parsed_prompts) == 0:
        return []
    if skip_normalize:
        return parsed_prompts

    weight_sum = sum(weight for _, weight in parsed_prompts)
    if weight_sum == 0:
        logger.warning(
            "Subprompt weights add up to zero. Discarding and using even weights instead."
        )
        # parsed_prompts is known to be non-empty here.
        equal_weight = 1 / len(parsed_prompts)
        return [(subprompt, equal_weight) for subprompt, _ in parsed_prompts]
    return [(subprompt, weight / weight_sum) for subprompt, weight in parsed_prompts]
|
@ -5,7 +5,7 @@ class Restoration:
|
||||
pass
|
||||
|
||||
def load_face_restore_models(
|
||||
self, gfpgan_model_path="./models/gfpgan/GFPGANv1.4.pth"
|
||||
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
|
||||
):
|
||||
# Load GFPGAN
|
||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
||||
|
@ -15,7 +15,7 @@ pretrained_model_url = (
|
||||
|
||||
class CodeFormerRestoration:
|
||||
def __init__(
|
||||
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
||||
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
|
||||
) -> None:
|
||||
|
||||
self.globals = InvokeAIAppConfig.get_config()
|
||||
@ -24,7 +24,7 @@ class CodeFormerRestoration:
|
||||
self.codeformer_model_exists = self.model_path.exists()
|
||||
|
||||
if not self.codeformer_model_exists:
|
||||
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
|
||||
sys.path.append(os.path.abspath(codeformer_dir))
|
||||
|
||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||
@ -71,7 +71,7 @@ class CodeFormerRestoration:
|
||||
upscale_factor=1,
|
||||
use_parse=True,
|
||||
device=device,
|
||||
model_rootpath = self.globals.root_dir / "gfpgan" / "weights"
|
||||
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
|
||||
)
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(bgr_image_array)
|
||||
|
@ -18,7 +18,7 @@ class GFPGAN:
|
||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.gfpgan_model_exists:
|
||||
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
|
||||
return None
|
||||
|
||||
def model_exists(self):
|
||||
|
@ -30,8 +30,8 @@ class ESRGAN:
|
||||
upscale=4,
|
||||
act_type="prelu",
|
||||
)
|
||||
model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
scale = 4
|
||||
|
||||
bg_upsampler = RealESRGANer(
|
||||
|
@ -30,18 +30,10 @@ class SafetyChecker(object):
|
||||
self.device = device
|
||||
|
||||
try:
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_model_path = config.cache_dir
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
safety_model_id = config.models_path / 'core/convert/stable-diffusion-safety-checker'
|
||||
feature_extractor_id = config.models_path / 'core/convert/stable-diffusion-safety-checker-extractor'
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_id)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"An error was encountered while installing the safety checker:"
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Initialization file for the invokeai.backend.stable_diffusion package
|
||||
"""
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
from .diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
PipelineIntermediateState,
|
||||
@ -10,4 +9,3 @@ from .diffusers_pipeline import (
|
||||
from .diffusion import InvokeAIDiffuserComponent
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
@ -1,275 +0,0 @@
|
||||
"""
|
||||
Query and install embeddings from the HuggingFace SD Concepts Library
|
||||
at https://huggingface.co/sd-concepts-library.
|
||||
|
||||
The interface is through the Concepts() object.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from typing import Callable
|
||||
from urllib import error as ul_error
|
||||
from urllib import request
|
||||
|
||||
from huggingface_hub import (
|
||||
HfApi,
|
||||
HfFolder,
|
||||
ModelFilter,
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
    """
    Query and install textual-inversion embeddings from the HuggingFace
    ``sd-concepts-library`` organisation, alongside embedding files found in
    the local ``<root>/embeddings`` directory.
    """

    def __init__(self, root=None):
        """
        Initialize the Concepts object. May optionally pass a root directory;
        otherwise the root from the application config is used.
        """
        self.config = InvokeAIAppConfig.get_config()
        self.root = root or self.config.root
        self.hf_api = HfApi()
        self.local_concepts = dict()  # concept name to local embedding file name
        self.concept_list = None  # stays None until the remote list is fetched
        self.concepts_loaded = dict()
        self.triggers = dict()  # concept name to trigger phrase
        self.concept_names = dict()  # trigger phrase to concept name
        # Raw strings: \w and \- are regex escapes, not string escapes.
        self.match_trigger = re.compile(
            r"(<[\w\- >]+>)"
        )  # trigger is slightly less restrictive than HF concept name
        self.match_concept = re.compile(
            r"<([\w\-]+)>"
        )  # HF concept name can only contain A-Za-z0-9_-

    def list_concepts(self) -> list:
        """
        Return a list of all the concepts by name, without the
        'sd-concepts-library' part. Also adds local concepts found in the
        ``embeddings`` folder under the root.

        Returns ``None`` when no list has been built yet and either the
        config reports no internet access or the HuggingFace query fails.
        """
        local_concepts_now = self.get_local_concepts(
            os.path.join(self.root, "embeddings")
        )
        local_concepts_to_add = set(local_concepts_now).difference(
            set(self.local_concepts)
        )
        self.local_concepts.update(local_concepts_now)

        if self.concept_list is not None:
            # Already initialised: only add concepts that appeared since then.
            if local_concepts_to_add:
                self.concept_list.extend(list(local_concepts_to_add))
        elif self.config.internet_available is True:
            try:
                models = self.hf_api.list_models(
                    filter=ModelFilter(model_name="sd-concepts-library/")
                )
                self.concept_list = [a.id.split("/")[1] for a in models]
                # when init, add all in dir. when not init, add only concepts added between init and now
                self.concept_list.extend(list(local_concepts_to_add))
            except Exception as e:
                logger.warning(
                    f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
                )
                logger.warning(
                    "You may load .bin and .pt file(s) manually using the --embedding_directory argument."
                )
        return self.concept_list

    def get_concept_model_path(self, concept_name: str) -> str:
        """
        Returns the path to the 'learned_embeds.bin' file in
        the named concept. Returns None if invalid or cannot
        be downloaded.
        """
        # list_concepts() may return None (offline / failed fetch); it still
        # refreshes self.local_concepts, so local embeddings remain usable.
        known_concepts = self.list_concepts()
        is_known = known_concepts is not None and concept_name in known_concepts
        if not is_known and not self.concept_is_local(concept_name):
            logger.warning(
                f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
            )
            return None
        return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")

    def concept_to_trigger(self, concept_name: str) -> str:
        """
        Given a concept name returns its trigger by looking in the
        "token_identifier.txt" file. Local concepts use ``<concept_name>`` as
        their trigger. Results are cached in both lookup directions.
        Returns None when the identifier file is not available locally.
        """
        if concept_name in self.triggers:
            return self.triggers[concept_name]
        elif self.concept_is_local(concept_name):
            trigger = f"<{concept_name}>"
            self.triggers[concept_name] = trigger
            self.concept_names[trigger] = concept_name
            return trigger

        file = self.get_concept_file(
            concept_name, "token_identifier.txt", local_only=True
        )
        if not file:
            return None
        with open(file, "r") as f:
            trigger = f.readline()
            trigger = trigger.strip()
        self.triggers[concept_name] = trigger
        self.concept_names[trigger] = concept_name
        return trigger

    def trigger_to_concept(self, trigger: str) -> str:
        """
        Given a trigger phrase, maps it to the concept library name.
        Only works if concept_to_trigger() has previously been called
        on this library. There needs to be a persistent database for
        this. Unknown triggers are returned unchanged.
        """
        concept = self.concept_names.get(trigger, None)
        return f"<{concept}>" if concept else f"{trigger}"

    def replace_triggers_with_concepts(self, prompt: str) -> str:
        """
        Given a prompt string that contains <trigger> tags, replace these
        tags with the concept name. The reason for this is so that the
        concept names get stored in the prompt metadata. There is no
        controlling of colliding triggers in the SD library, so it is
        better to store the concept name (unique) than the concept trigger
        (not necessarily unique!)
        """
        if not prompt:
            return prompt
        triggers = self.match_trigger.findall(prompt)
        if not triggers:
            return prompt

        def do_replace(match) -> str:
            return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"

        return self.match_trigger.sub(do_replace, prompt)

    def replace_concepts_with_triggers(
        self,
        prompt: str,
        load_concepts_callback: Callable[[list], any],
        excluded_tokens: list[str],
    ) -> str:
        """
        Given a prompt string that contains `<concept_name>` tags, replace
        these tags with the appropriate trigger.

        If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
        of `concepts_name` strings.

        `excluded_tokens` are any tokens that should not be replaced, typically because they
        are trigger tokens from a locally-loaded embedding.
        """
        concepts = self.match_concept.findall(prompt)
        if not concepts:
            return prompt
        load_concepts_callback(concepts)

        def do_replace(match) -> str:
            if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
                return f"<{match.group(1)}>"
            return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"

        return self.match_concept.sub(do_replace, prompt)

    def get_concept_file(
        self,
        concept_name: str,
        file_name: str = "learned_embeds.bin",
        local_only: bool = False,
    ) -> str:
        """
        Return the path of ``file_name`` for a concept, or None if it does
        not exist on disk.

        The concept is downloaded first unless it is already downloaded, is a
        local concept, or ``local_only`` is set. For a local concept the path
        of its embedding file is returned regardless of ``file_name``.
        """
        if not (
            self.concept_is_downloaded(concept_name)
            or self.concept_is_local(concept_name)
            or local_only
        ):
            self.download_concept(concept_name)

        # get local path in invokeai/embeddings if local concept
        if self.concept_is_local(concept_name):
            path = self._concept_local_path(concept_name)
        else:
            path = os.path.join(self._concept_path(concept_name), file_name)
        return path if os.path.exists(path) else None

    def concept_is_local(self, concept_name) -> bool:
        """True when the concept was found among the local embedding files."""
        return concept_name in self.local_concepts

    def concept_is_downloaded(self, concept_name) -> bool:
        """True when the concept's download directory exists."""
        concept_directory = self._concept_path(concept_name)
        return os.path.exists(concept_directory)

    def download_concept(self, concept_name) -> bool:
        """
        Download the files of a concept repository into its directory under
        ``<root>/models/sd-concepts-library``.

        Installs a urllib opener (with the stored HuggingFace token as a
        bearer header when one is available) before downloading.

        :return: True when every file was retrieved; False when an HTTP or
            URL error occurred, in which case the concept directory,
            including any partially downloaded files, is removed.
        """
        # Function-scope import: only needed for failure cleanup.
        import shutil

        repo_id = self._concept_id(concept_name)
        dest = self._concept_path(concept_name)

        access_token = HfFolder.get_token()
        header = [("Authorization", f"Bearer {access_token}")] if access_token else []
        opener = request.build_opener()
        opener.addheaders = header
        request.install_opener(opener)

        os.makedirs(dest, exist_ok=True)

        downloaded_bytes = 0

        def tally_download_size(chunk, size, total):
            # urlretrieve reports the total size with every chunk; count it
            # once per file, on the first chunk.
            nonlocal downloaded_bytes
            if chunk == 0:
                downloaded_bytes += total

        # logging methods do not take print()'s `end` keyword.
        logger.info(f"Downloading {repo_id}...")
        try:
            for file in (
                "README.md",
                "learned_embeds.bin",
                "token_identifier.txt",
                "type_of_concept.txt",
            ):
                url = hf_hub_url(repo_id, file)
                request.urlretrieve(
                    url, os.path.join(dest, file), reporthook=tally_download_size
                )
        except ul_error.HTTPError as e:
            if e.code == 404:
                logger.warning(
                    f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
                )
            else:
                logger.warning(
                    f"Failed to download {concept_name}/{file} ({str(e)}). Generation will continue without the concept."
                )
            # The directory may already hold earlier files, so rmdir is not enough.
            shutil.rmtree(dest, ignore_errors=True)
            return False
        except ul_error.URLError as e:
            logger.error(
                f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
            )
            shutil.rmtree(dest, ignore_errors=True)
            return False
        logger.info("...{:.2f}Kb".format(downloaded_bytes / 1024))
        return True

    def _concept_id(self, concept_name: str) -> str:
        """HuggingFace repository id for a concept."""
        return f"sd-concepts-library/{concept_name}"

    def _concept_path(self, concept_name: str) -> str:
        """Directory in which a downloaded concept is stored."""
        return os.path.join(self.root, "models", "sd-concepts-library", concept_name)

    def _concept_local_path(self, concept_name: str) -> str:
        """Path of the local embedding file registered for a concept."""
        filename = self.local_concepts[concept_name]
        return os.path.join(self.root, "embeddings", filename)

    def get_local_concepts(self, loc_dir: str):
        """
        Scan ``loc_dir`` for ``.bin`` and ``.pt`` files.

        :return: dict mapping each file's stem to its file name; empty when
            ``loc_dir`` is not a directory.
        """
        locs_dic = dict()
        if os.path.isdir(loc_dir):
            for file in os.listdir(loc_dir):
                stem, extension = os.path.splitext(file)
                if extension in (".bin", ".pt"):
                    locs_dic[stem] = file
        return locs_dic
|
@ -16,7 +16,6 @@ from accelerate.utils import set_seed
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from compel import EmbeddingsProvider
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
@ -48,7 +47,6 @@ from .diffusion import (
|
||||
PostprocessingSettings,
|
||||
)
|
||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
@ -319,6 +317,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
requires_safety_checker: bool = False,
|
||||
precision: str = "float32",
|
||||
control_model: ControlNetModel = None,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@ -343,22 +342,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# control_model=control_model,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||
self.unet, self._unet_forward, is_running_diffusers=True
|
||||
)
|
||||
use_full_precision = precision == "float32" or precision == "autocast"
|
||||
self.textual_inversion_manager = TextualInversionManager(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision,
|
||||
)
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.embeddings_provider = EmbeddingsProvider(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
textual_inversion_manager=self.textual_inversion_manager,
|
||||
self.unet, self._unet_forward
|
||||
)
|
||||
|
||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
||||
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
||||
self._model_group.install(*self._submodels)
|
||||
self.control_model = control_model
|
||||
|
||||
@ -406,50 +393,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
|
||||
def enable_offload_submodels(self, device: torch.device):
    """
    Switch this pipeline to a ``LazilyLoadedModelGroup`` so each submodel is
    offloaded when it is not in use.

    Useful for low-vRAM situations where the size of the model in memory is a
    big chunk of the total available resource, and you want to free up as much
    for inference as possible. This requires more moving parts and may add
    some delay as the U-Net is swapped out for the VAE and vice-versa.

    :param device: device handed to the new ``LazilyLoadedModelGroup``.
    """
    submodels = self._submodels
    current_group = self._model_group
    if current_group is not None:
        # Detach the submodels from the group they are currently managed by.
        current_group.uninstall(*submodels)

    lazy_group = LazilyLoadedModelGroup(device)
    lazy_group.install(*submodels)
    self._model_group = lazy_group
|
||||
|
||||
def disable_offload_submodels(self):
    """
    Switch this pipeline to a ``FullyLoadedModelGroup`` so all submodels stay
    loaded.

    Appropriate for cases where the size of the model in memory is small
    compared to the memory required for inference. Avoids the delay and
    complexity of shuffling the submodels to and from the GPU.

    The new group reuses the current group's ``execution_device``.
    """
    submodels = self._submodels
    if self._model_group is not None:
        self._model_group.uninstall(*submodels)

    # NOTE(review): this reads execution_device from self._model_group even
    # though the guard above allows for it being None - confirm None cannot
    # actually occur here.
    loaded_group = FullyLoadedModelGroup(self._model_group.execution_device)
    loaded_group.install(*submodels)
    self._model_group = loaded_group
|
||||
|
||||
def offload_all(self):
    """
    Offload all of this pipeline's models to CPU by asking the current model
    group to offload whatever it has loaded.
    """
    model_group = self._model_group
    model_group.offload_current()
|
||||
|
||||
def ready(self):
    """
    Ready this pipeline's models by delegating to the model group,
    i.e. preload them to the GPU if appropriate.
    """
    model_group = self._model_group
    model_group.ready()
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
# overridden method; types match the superclass.
|
||||
if torch_device is None:
|
||||
@ -1013,25 +956,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
device = self._model_group.device_for(self.safety_checker)
|
||||
return super().run_safety_checker(image, device, dtype)
|
||||
|
||||
@torch.inference_mode()
def get_learned_conditioning(
    self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
):
    """
    Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.

    Forwards to the embeddings provider's
    ``get_embeddings_for_weighted_prompt_fragments``.

    :param c: batch of prompt fragment lists, passed as ``text_batch``.
    :param return_tokens: passed as ``should_return_tokens``.
    :param fragment_weights: passed as ``fragment_weights_batch``.
    :return: whatever the embeddings provider returns.
    """
    embed = self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments
    # Use the device the model group assigns to the U-Net.
    unet_device = self._model_group.device_for(self.unet)
    return embed(
        text_batch=c,
        fragment_weights_batch=fragment_weights,
        should_return_tokens=return_tokens,
        device=unet_device,
    )
|
||||
|
||||
@property
def channels(self) -> int:
    """
    Number of input channels declared in the U-Net's config.

    Compatible with DiffusionWrapper.
    """
    unet_config = self.unet.config
    return unet_config.in_channels
|
||||
|
||||
def decode_latents(self, latents):
|
||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||
self._model_group.load(self.vae)
|
||||
@ -1048,8 +972,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
||||
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
||||
@staticmethod
|
||||
def prepare_control_image(
|
||||
self,
|
||||
image,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents,
|
||||
|
@ -18,7 +18,6 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
@ -66,7 +65,6 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
model,
|
||||
model_forward_callback: ModelForwardCallback,
|
||||
is_running_diffusers: bool = False,
|
||||
):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
@ -75,7 +73,6 @@ class InvokeAIDiffuserComponent:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
@ -112,37 +109,6 @@ class InvokeAIDiffuserComponent:
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
# apparently unused code
|
||||
# TODO: delete
|
||||
# def override_cross_attention(
|
||||
# self, conditioning: ExtraConditioningInfo, step_count: int
|
||||
# ) -> Dict[str, AttentionProcessor]:
|
||||
# """
|
||||
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
# the previous attention processor is returned so that the caller can restore it later.
|
||||
# """
|
||||
# self.conditioning = conditioning
|
||||
# self.cross_attention_control_context = Context(
|
||||
# arguments=self.conditioning.cross_attention_control_args,
|
||||
# step_count=step_count,
|
||||
# )
|
||||
# return override_cross_attention(
|
||||
# self.model,
|
||||
# self.cross_attention_control_context,
|
||||
# is_running_diffusers=self.is_running_diffusers,
|
||||
# )
|
||||
|
||||
def restore_default_cross_attention(
    self, restore_attention_processor: Optional["AttentionProcessor"] = None
):
    """
    Drop the stored conditioning and cross-attention control context, then
    restore default cross attention on the wrapped model.

    :param restore_attention_processor: optional attention processor that is
        forwarded unchanged to the restore call.
    """
    # Forget the per-generation state first.
    self.cross_attention_control_context = None
    self.conditioning = None

    restore_default_cross_attention(
        self.model,
        is_running_diffusers=self.is_running_diffusers,
        restore_attention_processor=restore_attention_processor,
    )
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
@ -204,9 +170,7 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
context.get_active_cross_attention_control_types_for_step(
|
||||
percent_through
|
||||
@ -264,9 +228,7 @@ class InvokeAIDiffuserComponent:
|
||||
total_step_count,
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
latents = self.apply_threshold(
|
||||
postprocessing_settings, latents, percent_through
|
||||
)
|
||||
@ -275,22 +237,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return latents
|
||||
|
||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
    """
    Work out how far through the diffusion process the current step is.

    :param sigma: current sigma; only used on the legacy path.
    :param step_index: index of the current step, or None.
    :param total_step_count: total number of steps, or None.
    :return: ``step_index / total_step_count`` when both are given,
        otherwise the value of ``self.estimate_percent_through``.
    :raises ValueError: if the step count information is incomplete and both
        ``step_index`` and ``sigma`` are None.
    """
    have_step_info = step_index is not None and total_step_count is not None
    if have_step_info:
        # diffusers codepath: will never reach 1.0 - this is deliberate
        return step_index / total_step_count

    # legacy compvis codepath
    # TODO remove when compvis codepath support is dropped
    if step_index is None and sigma is None:
        raise ValueError(
            "Either step_index or sigma is required when doing cross attention control, but both are None."
        )
    return self.estimate_percent_through(step_index, sigma)
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
@ -323,6 +269,7 @@ class InvokeAIDiffuserComponent:
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
# TODO: looks unused
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
@ -350,34 +297,6 @@ class InvokeAIDiffuserComponent:
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
if self.is_running_diffusers:
|
||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
@ -409,54 +328,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__compvis(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_save_attention_maps(ca_type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
||||
context.clear_requests(cleanup=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
# print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_apply_saved_attention_maps(ca_type)
|
||||
edited_conditioning = (
|
||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
)
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x, sigma, edited_conditioning, **kwargs,
|
||||
)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
except:
|
||||
context.clear_requests(cleanup=True)
|
||||
raise
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||
|
@ -157,7 +157,7 @@ class LazilyLoadedModelGroup(ModelGroup):
|
||||
def offload_current(self):
|
||||
module = self._current_model_ref()
|
||||
if module is not NO_MODEL:
|
||||
module.to(device=OFFLOAD_DEVICE)
|
||||
module.to(OFFLOAD_DEVICE)
|
||||
self.clear_current_model()
|
||||
|
||||
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
@ -228,7 +228,7 @@ class FullyLoadedModelGroup(ModelGroup):
|
||||
def install(self, *models: torch.nn.Module):
|
||||
for model in models:
|
||||
self._models.add(model)
|
||||
model.to(device=self.execution_device)
|
||||
model.to(self.execution_device)
|
||||
|
||||
def uninstall(self, *models: torch.nn.Module):
|
||||
for model in models:
|
||||
@ -238,11 +238,11 @@ class FullyLoadedModelGroup(ModelGroup):
|
||||
self.uninstall(*self._models)
|
||||
|
||||
def load(self, model):
|
||||
model.to(device=self.execution_device)
|
||||
model.to(self.execution_device)
|
||||
|
||||
def offload_current(self):
|
||||
for model in self._models:
|
||||
model.to(device=OFFLOAD_DEVICE)
|
||||
model.to(OFFLOAD_DEVICE)
|
||||
|
||||
def ready(self):
|
||||
for model in self._models:
|
||||
@ -252,7 +252,7 @@ class FullyLoadedModelGroup(ModelGroup):
|
||||
self.execution_device = device
|
||||
for model in self._models:
|
||||
if model.device != OFFLOAD_DEVICE:
|
||||
model.to(device=device)
|
||||
model.to(device)
|
||||
|
||||
def device_for(self, model):
|
||||
if model not in self:
|
||||
|
@ -1,13 +1,14 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
ddpm=(DDPMScheduler, dict()),
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
@ -16,8 +17,13 @@ SCHEDULER_MAP = dict(
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
|
||||
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
|
@ -1,429 +0,0 @@
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
class EmbeddingInfo:
    """
    One embedding extracted from a textual inversion checkpoint by the `_parse_embedding_*` helpers.

    The three `trained_*` fields default to None and are therefore annotated as Optional;
    only the v1 parser fills them in.
    """

    # Trigger string proposed for this embedding (read by load_textual_inversion as the trigger).
    name: str
    # The embedding tensor exactly as stored in the checkpoint.
    embedding: torch.Tensor
    # Set from embedding.size()[0] by the v1-v3 parsers; hard-coded to 1 by the v4 parser.
    num_vectors_per_token: int
    # Compared against the width of the text encoder's input embeddings before loading.
    token_dim: int
    # Value of the checkpoint's "step" entry, when the format provides it.
    trained_steps: Optional[int] = None
    # Value of the checkpoint's "sd_checkpoint_name" entry, when the format provides it.
    trained_model_name: Optional[str] = None
    # Value of the checkpoint's "sd_checkpoint" entry, when the format provides it.
    trained_model_checksum: Optional[str] = None
|
||||
|
||||
@dataclass
class TextualInversion:
    """
    A loaded textual inversion: the prompt text that triggers it plus its embedding.

    The two token-id fields stay None until tokens have been injected for this inversion.
    """

    # Text in the prompt that activates this inversion.
    trigger_string: str
    # Embedding tensor; its first dimension is reported as the vector length below.
    embedding: torch.Tensor
    # Token id assigned to the trigger string once injected.
    trigger_token_id: Optional[int] = None
    # Token ids assigned to the extra padding tokens once injected.
    pad_token_ids: Optional[list[int]] = None

    @property
    def embedding_vector_length(self) -> int:
        """Size of the embedding's first dimension."""
        dims = self.embedding.shape
        return dims[0]
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
def __init__(
    self,
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    full_precision: bool = True,
):
    """
    Store the tokenizer/text encoder pair and start with no textual inversions loaded.

    :param tokenizer: Tokenizer that new trigger and padding tokens are added to.
    :param text_encoder: Text encoder whose input embeddings receive the loaded vectors.
    :param full_precision: When False, embeddings are converted with `.half()` before being added.
    """
    self.tokenizer = tokenizer
    self.text_encoder = text_encoder
    self.full_precision = full_precision
    self.hf_concepts_library = HuggingFaceConceptsLibrary()
    # trigger string -> name of the source file that claimed it
    self.trigger_to_sourcefile = {}
    # every TextualInversion added so far, in load order
    self.textual_inversions: list[TextualInversion] = []
|
||||
|
||||
def load_huggingface_concepts(self, concepts: list[str]):
    """
    Load the embedding for each named concept from the HuggingFace concepts library.

    A concept is skipped when it is already marked as loaded, when a local textual inversion
    already answers to its trigger (or to the bare / angle-bracketed concept name), or when
    the library returns no model path for it.

    :param concepts: Concept names to load.
    """
    library = self.hf_concepts_library
    for concept_name in concepts:
        if concept_name in library.concepts_loaded:
            continue

        trigger = library.concept_to_trigger(concept_name)
        # The last candidate covers a token with literal angle brackets.
        candidates = (trigger, concept_name, f"<{concept_name}>")
        if any(self.has_textual_inversion_for_trigger_string(c) for c in candidates):
            logger.info(f"Loaded local embedding for trigger {concept_name}")
            continue

        bin_file = library.get_concept_model_path(concept_name)
        if not bin_file:
            continue

        logger.info(f"Loaded remote embedding for trigger {concept_name}")
        self.load_textual_inversion(bin_file)
        library.concepts_loaded[concept_name] = True
|
||||
|
||||
def get_all_trigger_strings(self) -> list[str]:
    """Return the trigger string of every loaded textual inversion, in load order."""
    triggers = []
    for inversion in self.textual_inversions:
        triggers.append(inversion.trigger_string)
    return triggers
|
||||
|
||||
def load_textual_inversion(
    self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
):
    """
    Parse an embedding file and register every compatible embedding it contains.

    Paths that are not regular files, and paths ending in ".DS_Store", are ignored. Embeddings
    whose token dimension differs from the text encoder's are skipped with a warning. When a
    trigger string has already been claimed by another file, a replacement trigger derived
    from the file (or its parent directory for "learned_embeds.bin") is used instead.

    :param ckpt_path: Path to the embedding file.
    :param defer_injecting_tokens: Forwarded to `_add_textual_inversion`.
    """
    ckpt_path = Path(ckpt_path)

    if not ckpt_path.is_file():
        return

    if str(ckpt_path).endswith(".DS_Store"):
        return

    embedding_list = self._parse_embedding(str(ckpt_path))
    for embedding_info in embedding_list:
        model_token_dim = self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
        if model_token_dim != embedding_info.token_dim:
            logger.warning(
                f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {model_token_dim} vs {embedding_info.token_dim}."
            )
            continue

        # Resolve the situation in which an earlier embedding has claimed the same
        # trigger string. We replace the trigger with '<source_file>', as we used to.
        trigger_str = embedding_info.name
        sourcefile = (
            f"{ckpt_path.parent.name}/{ckpt_path.name}"
            if ckpt_path.name == "learned_embeds.bin"
            else ckpt_path.name
        )

        if trigger_str in self.trigger_to_sourcefile:
            replacement_trigger_str = (
                f"<{ckpt_path.parent.name}>"
                if ckpt_path.name == "learned_embeds.bin"
                else f"<{ckpt_path.stem}>"
            )
            logger.info(
                f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
            )
            trigger_str = replacement_trigger_str

        try:
            self._add_textual_inversion(
                trigger_str,
                embedding_info.embedding,
                defer_injecting_tokens=defer_injecting_tokens,
            )
            # remember which source file claims this trigger
            self.trigger_to_sourcefile[trigger_str] = sourcefile

        except ValueError as e:
            # EmbeddingInfo is a dataclass, so the name is read as an attribute; subscripting
            # it here raised TypeError and hid the ValueError that was being reported.
            logger.debug(f'Ignoring incompatible embedding {embedding_info.name}')
            logger.debug(f"The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
    self, trigger_str, embedding, defer_injecting_tokens=False
) -> Optional[TextualInversion]:
    """
    Add a textual inversion to be recognised.

    :param trigger_str: The trigger text in the prompt that activates this textual inversion.
    :param embedding: The embedding data, shaped [token_dim] or [V, token_dim]; a 1-D embedding
        is given a leading dimension of 1.
    :param defer_injecting_tokens: When True, the inversion is recorded without injecting its
        tokens into the tokenizer/text encoder yet.
    :return: The new TextualInversion, or None when `trigger_str` is already loaded or the
        failure was reported as a warning.
    :raises ValueError: When the embedding has more than two dimensions, or token injection
        fails with a message that does not start with "Warning".
    """
    already_loaded = [ti.trigger_string for ti in self.textual_inversions]
    if trigger_str in already_loaded:
        logger.warning(
            f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
        )
        return

    if not self.full_precision:
        embedding = embedding.half()

    rank = len(embedding.shape)
    if rank > 2:
        raise ValueError(
            f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
        )
    if rank == 1:
        embedding = embedding.unsqueeze(0)

    try:
        inversion = TextualInversion(trigger_string=trigger_str, embedding=embedding)
        if not defer_injecting_tokens:
            self._inject_tokens_and_assign_embeddings(inversion)
        self.textual_inversions.append(inversion)
        return inversion

    except ValueError as e:
        if str(e).startswith("Warning"):
            logger.warning(f"{str(e)}")
            return None
        # NOTE(review): nesting was reconstructed from de-indented text; the re-raise is
        # assumed to belong to this non-warning branch only - confirm against the real file.
        traceback.print_exc()
        logger.error(
            f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
        )
        raise
|
||||
|
||||
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
    """
    Create (or reuse) the tokens for a textual inversion and store their ids on it.

    The trigger string receives the first embedding vector. Each further vector is assigned
    to a padding token named "<trigger>-!pad-<k>" for k = 1 .. vector length - 1.

    :param ti: The inversion to inject; its `trigger_token_id` must still be None.
    :return: The token id assigned to the trigger string.
    :raises ValueError: When tokens were already injected for `ti`.
    """
    if ti.trigger_token_id is not None:
        raise ValueError(
            f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
        )

    main_token_id = self._get_or_create_token_id_and_assign_embedding(
        ti.trigger_string, ti.embedding[0]
    )

    # todo: batched UI for faster loading when vector length >2
    extra_token_ids = []
    for pad_index in range(1, ti.embedding_vector_length):
        pad_token_str = ti.trigger_string + "-!pad-" + str(pad_index)
        extra_token_ids.append(
            self._get_or_create_token_id_and_assign_embedding(
                pad_token_str, ti.embedding[pad_index]
            )
        )

    ti.trigger_token_id = main_token_id
    ti.pad_token_ids = extra_token_ids
    return ti.trigger_token_id
|
||||
|
||||
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
    """
    Report whether a loaded textual inversion answers to `trigger_string`.

    :param trigger_string: The trigger text to look up.
    :return: False when the lookup raises StopIteration or yields None, True otherwise.
    """
    try:
        match = self.get_textual_inversion_for_trigger_string(trigger_string)
    except StopIteration:
        return False
    return match is not None
|
||||
|
||||
def get_textual_inversion_for_trigger_string(
    self, trigger_string: str
) -> TextualInversion:
    """
    Return the first loaded textual inversion whose trigger string equals `trigger_string`.

    :param trigger_string: The trigger text to look up.
    :raises StopIteration: When no loaded inversion matches.
    """
    for candidate in self.textual_inversions:
        if candidate.trigger_string == trigger_string:
            return candidate
    raise StopIteration
|
||||
|
||||
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
    """
    Return the first loaded textual inversion whose trigger token id equals `token_id`.

    :param token_id: The trigger token id to look up.
    :raises StopIteration: When no loaded inversion matches.
    """
    matches = filter(
        lambda ti: ti.trigger_token_id == token_id, self.textual_inversions
    )
    return next(matches)
|
||||
|
||||
def create_deferred_token_ids_for_any_trigger_terms(
    self, prompt_string: str
) -> list[int]:
    """
    Inject tokens for every not-yet-injected textual inversion whose trigger appears in the prompt.

    :param prompt_string: The prompt text to search for trigger strings.
    :return: The trigger and padding token ids that were injected by this call, in load order.
        Inversions whose injection raises ValueError are skipped.
    """
    newly_injected = []
    for ti in self.textual_inversions:
        needs_injection = (
            ti.trigger_token_id is None and ti.trigger_string in prompt_string
        )
        if not needs_injection:
            continue

        if ti.embedding_vector_length > 1:
            logger.info(
                f"Preparing tokens for textual inversion {ti.trigger_string}..."
            )
        try:
            self._inject_tokens_and_assign_embeddings(ti)
        except ValueError as e:
            logger.debug(
                f"Ignoring incompatible embedding trigger {ti.trigger_string}"
            )
            logger.debug(f"The error was {str(e)}")
            continue

        newly_injected.append(ti.trigger_token_id)
        newly_injected.extend(ti.pad_token_ids)
    return newly_injected
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(
    self, prompt_token_ids: list[int]
) -> list[int]:
    """
    Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.

    :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
    :return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
        long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
        An empty input list is returned as-is; otherwise a new list is returned and the input is left untouched.
    :raises ValueError: When the list starts with the bos token id or ends with the eos token id.
    """
    if len(prompt_token_ids) == 0:
        return prompt_token_ids

    if prompt_token_ids[0] == self.tokenizer.bos_token_id:
        raise ValueError("prompt_token_ids must not start with bos_token_id")
    if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
        raise ValueError("prompt_token_ids must not end with eos_token_id")

    known_trigger_ids = [ti.trigger_token_id for ti in self.textual_inversions]
    expanded = prompt_token_ids.copy()

    # Walk from the end so insertions never shift positions that are still to be visited.
    for position in range(len(expanded) - 1, -1, -1):
        token_id = expanded[position]
        if token_id not in known_trigger_ids:
            continue
        owner = next(
            ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
        )
        pad_count = owner.embedding_vector_length - 1
        for offset in range(pad_count):
            expanded.insert(position + 1 + offset, owner.pad_token_ids[offset])

    return expanded
|
||||
|
||||
def _get_or_create_token_id_and_assign_embedding(
|
||||
self, token_str: str, embedding: torch.Tensor
|
||||
) -> int:
|
||||
if len(embedding.shape) != 1:
|
||||
raise ValueError(
|
||||
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
|
||||
)
|
||||
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if existing_token_id == self.tokenizer.unk_token_id:
|
||||
num_tokens_added = self.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.text_encoder.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
# the following call is slow - todo make batched for better performance with vector length >1
|
||||
self.text_encoder.resize_token_embeddings(new_token_count)
|
||||
|
||||
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if token_id == self.tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
|
||||
if (
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
|
||||
!= embedding.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
|
||||
)
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||
|
||||
return token_id
|
||||
|
||||
|
||||
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
    """
    Load an embedding file and dispatch to the parser matching its layout.

    Files with a .pt/.ckpt/.bin suffix are scanned with picklescan and then read with
    `torch.load`; every other suffix is read with `safetensors.torch.load_file`.

    :param embedding_file: Path of the embedding file.
    :return: The parsed embeddings; an empty list when the scan reports infected files or
        when loading raises.
    """
    pickle_suffixes = [".pt", ".ckpt", ".bin"]
    try:
        if Path(embedding_file).suffix not in pickle_suffixes:
            ckpt = safetensors.torch.load_file(embedding_file)
        else:
            scan_result = scan_file_path(embedding_file)
            if scan_result.infected_files > 0:
                logger.critical(
                    f"Security Issues Found in Model: {scan_result.issues_count}"
                )
                logger.critical("For your safety, InvokeAI will not load this embed.")
                return []
            ckpt = torch.load(embedding_file, map_location="cpu")
    except Exception as e:
        logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
        return []

    # try to figure out what kind of embedding file it is and parse accordingly
    keys = list(ckpt.keys())
    has_token_tables = 'string_to_token' in keys and 'string_to_param' in keys

    if has_token_tables and 'name' in keys and 'step' in keys:
        return self._parse_embedding_v1(ckpt, embedding_file)  # example rem_rezero.pt

    if has_token_tables:
        return self._parse_embedding_v2(ckpt, embedding_file)  # example midj-strong.pt

    if 'emb_params' in keys:
        return self._parse_embedding_v3(ckpt, embedding_file)  # example easynegative.safetensors

    return self._parse_embedding_v4(ckpt, embedding_file)  # usually a '.bin' file
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
    """
    Parse a 'version 1' embedding checkpoint (one that has "name" and "step" entries).

    Triggers are assigned by position in "string_to_param": the first embedding uses the
    checkpoint's "name", the second uses "<basename>", and later ones use "<basename-N>"
    with N counting up from 1.

    :param embedding_ckpt: The loaded checkpoint dictionary.
    :param file_path: Path the checkpoint was loaded from; its stem is used in fallback triggers.
    :return: One EmbeddingInfo per entry of "string_to_param".
    """
    basename = Path(file_path).stem
    logger.debug(f'Loading v1 embedding file: {basename}')

    embeddings = list()
    for index, embedding in enumerate(embedding_ckpt["string_to_param"].values()):
        if index == 0:
            trigger = embedding_ckpt["name"]
        elif index == 1:
            # Previously the literal text '<basename>' (missing f-string prefix).
            trigger = f'<{basename}>'
        else:
            trigger = f'<{basename}-{index - 1}>'
        embedding_info = EmbeddingInfo(
            name = trigger,
            embedding = embedding,
            num_vectors_per_token = embedding.size()[0],
            token_dim = embedding.size()[1],
            trained_steps = embedding_ckpt["step"],
            # The caller only checks for "name" and "step", so these two may be absent.
            trained_model_name = embedding_ckpt.get("sd_checkpoint_name"),
            trained_model_checksum = embedding_ckpt.get("sd_checkpoint"),
        )
        embeddings.append(embedding_info)
    return embeddings
|
||||
|
||||
def _parse_embedding_v2 (
    self, embedding_ckpt: dict, file_path: str
) -> List[EmbeddingInfo]:
    """
    This handles embedding .pt file variant #2.

    The format is accepted only when the first value of "string_to_token" is a tensor;
    otherwise (including when "string_to_token" is empty) a warning is logged and an empty
    list is returned. A token named '*' is replaced by the trigger "<basename>"; every other
    token is used as its own trigger.

    :param embedding_ckpt: The loaded checkpoint dictionary.
    :param file_path: Path the checkpoint was loaded from; its stem replaces the '*' token.
    :return: One EmbeddingInfo per entry of "string_to_param", or an empty list.
    """
    basename = Path(file_path).stem
    logger.debug(f'Loading v2 embedding file: {basename}')
    embeddings = list()

    # next(..., None) avoids the IndexError that list(...)[0] raised on an empty mapping.
    first_token_value = next(iter(embedding_ckpt["string_to_token"].values()), None)
    if not isinstance(first_token_value, torch.Tensor):
        logger.warning(f"{basename}: Unrecognized embedding format")
        return embeddings

    for token, embedding in embedding_ckpt["string_to_param"].items():
        # The old counter-based "<basename-N>" branch could never run: the counter was only
        # advanced inside that same branch, so it stayed at 0.
        trigger = token if token != '*' else f'<{basename}>'
        embedding_info = EmbeddingInfo(
            name = trigger,
            embedding = embedding,
            num_vectors_per_token = embedding.size()[0],
            token_dim = embedding.size()[1],
        )
        embeddings.append(embedding_info)

    return embeddings
|
||||
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
    """
    Parse 'version 3' of the .pt textual inversion embedding files.

    The single embedding is read from the "emb_params" entry and is triggered by "<basename>".

    :param embedding_ckpt: The loaded checkpoint dictionary.
    :param file_path: Path the checkpoint was loaded from; its stem forms the trigger.
    :return: A one-element list.
    """
    basename = Path(file_path).stem
    logger.debug(f'Loading v3 embedding file: {basename}')

    params = embedding_ckpt['emb_params']
    vector_count = params.size()[0]
    token_dim = params.size()[1]
    return [
        EmbeddingInfo(
            name=f'<{basename}>',
            embedding=params,
            num_vectors_per_token=vector_count,
            token_dim=token_dim,
        )
    ]
|
||||
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
    """
    Parse 'version 4' of the textual inversion embedding files. This one
    is usually associated with .bin files trained by HuggingFace diffusers.

    Every key of the checkpoint is treated as a trigger token (falling back to "<basename>"
    for an empty key) and every value as its embedding.

    :param embedding_ckpt: The loaded checkpoint dictionary.
    :param filepath: Path the checkpoint was loaded from.
    :return: One EmbeddingInfo per checkpoint entry; empty (with a warning) for an empty checkpoint.
    """
    basename = Path(filepath).stem
    short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name

    logger.debug(f'Loading v4 embedding file: {short_path}')

    embeddings = list()
    # Was `list(keys) == 0`, which compares a list with an int and is never true.
    if len(embedding_ckpt) == 0:
        logger.warning(f"Invalid embeddings file: {short_path}")
        return embeddings

    for token, embedding in embedding_ckpt.items():
        embedding_info = EmbeddingInfo(
            name = token or f"<{basename}>",
            embedding = embedding,
            num_vectors_per_token = 1, # All Concepts seem to default to 1
            token_dim = embedding.size()[0],
        )
        embeddings.append(embedding_info)
    return embeddings
|
@ -358,7 +358,6 @@ class InvokeAILogger(object):
|
||||
|
||||
elif handler_name=='syslog':
|
||||
ch = cls._parse_syslog_args(args)
|
||||
ch.setFormatter(InvokeAISyslogFormatter())
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name=='file':
|
||||
@ -367,7 +366,8 @@ class InvokeAILogger(object):
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name=='http':
|
||||
handlers.append(cls._parse_http_args(args))
|
||||
ch = cls._parse_http_args(args)
|
||||
handlers.append(ch)
|
||||
return handlers
|
||||
|
||||
@staticmethod
|
||||
|
@ -1277,13 +1277,14 @@ class InvokeAIWebServer:
|
||||
eventlet.sleep(0)
|
||||
|
||||
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
|
||||
tokens = (
|
||||
None
|
||||
if type(parsed_prompt) is Blend
|
||||
else get_tokens_for_prompt_object(
|
||||
self.generate.model.tokenizer, parsed_prompt
|
||||
with self.generate.model_context as model:
|
||||
tokens = (
|
||||
None
|
||||
if type(parsed_prompt) is Blend
|
||||
else get_tokens_for_prompt_object(
|
||||
model.tokenizer, parsed_prompt
|
||||
)
|
||||
)
|
||||
)
|
||||
attention_maps_image_base64_url = (
|
||||
None
|
||||
if attention_maps_image is None
|
||||
|
@ -7,6 +7,7 @@ SAMPLER_CHOICES = [
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"lms_k",
|
||||
"pndm",
|
||||
"heun",
|
||||
'heun_k',
|
||||
@ -16,8 +17,13 @@ SAMPLER_CHOICES = [
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2s_k",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"dpmpp_2m_sde",
|
||||
"dpmpp_2m_sde_k",
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-b060dbab.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script>
|
||||
</head>
|
||||
|
||||
<body dir="ltr">
|
||||
|
4
invokeai/frontend/web/dist/locales/en.json
vendored
4
invokeai/frontend/web/dist/locales/en.json
vendored
@ -506,8 +506,8 @@
|
||||
"isScheduled": "Canceling",
|
||||
"setType": "Set cancel type"
|
||||
},
|
||||
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
|
||||
"negativePrompts": "Negative Prompts",
|
||||
"positivePromptPlaceholder": "Positive Prompt",
|
||||
"negativePromptPlaceholder": "Negative Prompt",
|
||||
"sendTo": "Send to",
|
||||
"sendToImg2Img": "Send to Image to Image",
|
||||
"sendToUnifiedCanvas": "Send To Unified Canvas",
|
||||
|
@ -23,8 +23,7 @@
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
|
||||
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
|
||||
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/schema.d.ts -t",
|
||||
"preview": "vite preview",
|
||||
"lint:madge": "madge --circular src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
@ -81,9 +80,12 @@
|
||||
"i18next-http-backend": "^2.2.0",
|
||||
"konva": "^9.0.1",
|
||||
"lodash-es": "^4.17.21",
|
||||
"nanostores": "^0.9.2",
|
||||
"openapi-fetch": "^0.4.0",
|
||||
"overlayscrollbars": "^2.1.1",
|
||||
"overlayscrollbars-react": "^0.5.0",
|
||||
"patch-package": "^7.0.0",
|
||||
"query-string": "^8.1.0",
|
||||
"re-resizable": "^6.9.9",
|
||||
"react": "^18.2.0",
|
||||
"react-colorful": "^5.6.1",
|
||||
@ -140,6 +142,7 @@
|
||||
"lint-staged": "^13.2.2",
|
||||
"madge": "^6.0.0",
|
||||
"openapi-types": "^12.1.0",
|
||||
"openapi-typescript": "^6.2.8",
|
||||
"openapi-typescript-codegen": "^0.24.0",
|
||||
"postinstall-postinstall": "^2.1.0",
|
||||
"prettier": "^2.8.8",
|
||||
|
55
invokeai/frontend/web/patches/openapi-fetch+0.4.0.patch
Normal file
55
invokeai/frontend/web/patches/openapi-fetch+0.4.0.patch
Normal file
@ -0,0 +1,55 @@
|
||||
diff --git a/node_modules/openapi-fetch/dist/index.js b/node_modules/openapi-fetch/dist/index.js
|
||||
index cd4528a..8976b51 100644
|
||||
--- a/node_modules/openapi-fetch/dist/index.js
|
||||
+++ b/node_modules/openapi-fetch/dist/index.js
|
||||
@@ -1,5 +1,5 @@
|
||||
// settings & const
|
||||
-const DEFAULT_HEADERS = {
|
||||
+const CONTENT_TYPE_APPLICATION_JSON = {
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
const TRAILING_SLASH_RE = /\/*$/;
|
||||
@@ -29,18 +29,29 @@ export function createFinalURL(url, options) {
|
||||
}
|
||||
return finalURL;
|
||||
}
|
||||
+function stringifyBody(body) {
|
||||
+ if (body instanceof ArrayBuffer || body instanceof File || body instanceof DataView || body instanceof Blob || ArrayBuffer.isView(body) || body instanceof URLSearchParams || body instanceof FormData) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ if (typeof body === "string") {
|
||||
+ return body;
|
||||
+ }
|
||||
+
|
||||
+ return JSON.stringify(body);
|
||||
+ }
|
||||
+
|
||||
export default function createClient(clientOptions = {}) {
|
||||
const { fetch = globalThis.fetch, ...options } = clientOptions;
|
||||
- const defaultHeaders = new Headers({
|
||||
- ...DEFAULT_HEADERS,
|
||||
- ...(options.headers ?? {}),
|
||||
- });
|
||||
+ const defaultHeaders = new Headers(options.headers ?? {});
|
||||
async function coreFetch(url, fetchOptions) {
|
||||
const { headers, body: requestBody, params = {}, parseAs = "json", querySerializer = defaultSerializer, ...init } = fetchOptions || {};
|
||||
// URL
|
||||
const finalURL = createFinalURL(url, { baseUrl: options.baseUrl, params, querySerializer });
|
||||
+ // Stringify body if needed
|
||||
+ const stringifiedBody = stringifyBody(requestBody);
|
||||
// headers
|
||||
- const baseHeaders = new Headers(defaultHeaders); // clone defaults (don’t overwrite!)
|
||||
+ const baseHeaders = new Headers(stringifiedBody ? { ...CONTENT_TYPE_APPLICATION_JSON, ...defaultHeaders } : defaultHeaders); // clone defaults (don’t overwrite!)
|
||||
const headerOverrides = new Headers(headers);
|
||||
for (const [k, v] of headerOverrides.entries()) {
|
||||
if (v === undefined || v === null)
|
||||
@@ -54,7 +65,7 @@ export default function createClient(clientOptions = {}) {
|
||||
...options,
|
||||
...init,
|
||||
headers: baseHeaders,
|
||||
- body: typeof requestBody === "string" ? requestBody : JSON.stringify(requestBody),
|
||||
+ body: stringifiedBody ?? requestBody,
|
||||
});
|
||||
// handle empty content
|
||||
// note: we return `{}` because we want user truthy checks for `.data` or `.error` to succeed
|
@ -548,7 +548,8 @@
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"availableSchedulers": "Available Schedulers"
|
||||
"favoriteSchedulers": "Favorite Schedulers",
|
||||
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
|
@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
@ -45,6 +47,18 @@ const App = ({
|
||||
|
||||
const isApplicationReady = useIsApplicationReady();
|
||||
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
});
|
||||
const { data: controlnetModels } = useListModelsQuery({
|
||||
model_type: 'controlnet',
|
||||
});
|
||||
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
|
||||
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
|
||||
const { data: embeddingModels } = useListModelsQuery({
|
||||
model_type: 'embedding',
|
||||
});
|
||||
|
||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
@ -143,6 +157,7 @@ const App = ({
|
||||
</Portal>
|
||||
</Grid>
|
||||
<DeleteImageModal />
|
||||
<UpdateImageBoardModal />
|
||||
<Toaster />
|
||||
<GlobalHotkeys />
|
||||
</>
|
||||
|
@ -11,8 +11,8 @@ import {
|
||||
} from '@dnd-kit/core';
|
||||
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
||||
import OverlayDragImage from './OverlayDragImage';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { isImageDTO } from 'services/types/guards';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { isImageDTO } from 'services/api/guards';
|
||||
import { snapCenterToCursor } from '@dnd-kit/modifiers';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { Box, Image } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
|
||||
type OverlayDragImageProps = {
|
||||
image: ImageDTO;
|
||||
|
@ -7,7 +7,7 @@ import React, {
|
||||
} from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { store } from 'app/store/store';
|
||||
import { OpenAPI } from 'services/api';
|
||||
// import { OpenAPI } from 'services/api/types';
|
||||
|
||||
import Loading from '../../common/components/Loading/Loading';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
@ -21,6 +21,9 @@ import {
|
||||
DeleteImageContext,
|
||||
DeleteImageContextProvider,
|
||||
} from 'app/contexts/DeleteImageContext';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
|
||||
import { $authToken, $baseUrl } from 'services/api/client';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
@ -45,12 +48,12 @@ const InvokeAIUI = ({
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
if (token) {
|
||||
OpenAPI.TOKEN = token;
|
||||
$authToken.set(token);
|
||||
}
|
||||
|
||||
// configure API client base url
|
||||
if (apiUrl) {
|
||||
OpenAPI.BASE = apiUrl;
|
||||
$baseUrl.set(apiUrl);
|
||||
}
|
||||
|
||||
// reset dynamically added middlewares
|
||||
@ -67,6 +70,12 @@ const InvokeAIUI = ({
|
||||
} else {
|
||||
addMiddleware(socketMiddleware());
|
||||
}
|
||||
|
||||
return () => {
|
||||
// Reset the API client token and base url on unmount
|
||||
$baseUrl.set(undefined);
|
||||
$authToken.set(undefined);
|
||||
};
|
||||
}, [apiUrl, token, middleware]);
|
||||
|
||||
return (
|
||||
@ -76,11 +85,13 @@ const InvokeAIUI = ({
|
||||
<ThemeLocaleProvider>
|
||||
<ImageDndContext>
|
||||
<DeleteImageContextProvider>
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
<AddImageToBoardContextProvider>
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
</AddImageToBoardContextProvider>
|
||||
</DeleteImageContextProvider>
|
||||
</ImageDndContext>
|
||||
</ThemeLocaleProvider>
|
||||
|
@ -1,25 +1,62 @@
|
||||
// TODO: use Enums?
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
|
||||
export const SCHEDULERS = [
|
||||
'ddim',
|
||||
'lms',
|
||||
// zod needs the array to be `as const` to infer the type correctly
|
||||
// this is the source of the `SchedulerParam` type, which is generated by zod
|
||||
export const SCHEDULER_NAMES_AS_CONST = [
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'deis',
|
||||
'ddim',
|
||||
'ddpm',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_k',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'deis',
|
||||
'ddpm',
|
||||
'pndm',
|
||||
'dpmpp_2m_sde',
|
||||
'dpmpp_sde',
|
||||
'heun',
|
||||
'heun_k',
|
||||
'kdpm_2',
|
||||
'lms',
|
||||
'pndm',
|
||||
'unipc',
|
||||
'euler_k',
|
||||
'dpmpp_2s_k',
|
||||
'dpmpp_2m_k',
|
||||
'dpmpp_2m_sde_k',
|
||||
'dpmpp_sde_k',
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
'kdpm_2_a',
|
||||
] as const;
|
||||
|
||||
export type Scheduler = (typeof SCHEDULERS)[number];
|
||||
export const DEFAULT_SCHEDULER_NAME = 'euler';
|
||||
|
||||
export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST];
|
||||
|
||||
export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
|
||||
euler: 'Euler',
|
||||
deis: 'DEIS',
|
||||
ddim: 'DDIM',
|
||||
ddpm: 'DDPM',
|
||||
dpmpp_sde: 'DPM++ SDE',
|
||||
dpmpp_2s: 'DPM++ 2S',
|
||||
dpmpp_2m: 'DPM++ 2M',
|
||||
dpmpp_2m_sde: 'DPM++ 2M SDE',
|
||||
heun: 'Heun',
|
||||
kdpm_2: 'KDPM 2',
|
||||
lms: 'LMS',
|
||||
pndm: 'PNDM',
|
||||
unipc: 'UniPC',
|
||||
euler_k: 'Euler Karras',
|
||||
dpmpp_sde_k: 'DPM++ SDE Karras',
|
||||
dpmpp_2s_k: 'DPM++ 2S Karras',
|
||||
dpmpp_2m_k: 'DPM++ 2M Karras',
|
||||
dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras',
|
||||
heun_k: 'Heun Karras',
|
||||
lms_k: 'LMS Karras',
|
||||
euler_a: 'Euler Ancestral',
|
||||
kdpm_2_a: 'KDPM 2 Ancestral',
|
||||
};
|
||||
|
||||
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
|
||||
|
||||
// Valid upscaling levels
|
||||
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
|
||||
|
@ -0,0 +1,89 @@
|
||||
import { useDisclosure } from '@chakra-ui/react';
|
||||
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { useAddImageToBoardMutation } from 'services/api/endpoints/boardImages';
|
||||
|
||||
export type ImageUsage = {
|
||||
isInitialImage: boolean;
|
||||
isCanvasImage: boolean;
|
||||
isNodesImage: boolean;
|
||||
isControlNetImage: boolean;
|
||||
};
|
||||
|
||||
type AddImageToBoardContextValue = {
|
||||
/**
|
||||
* Whether the move image dialog is open.
|
||||
*/
|
||||
isOpen: boolean;
|
||||
/**
|
||||
* Closes the move image dialog.
|
||||
*/
|
||||
onClose: () => void;
|
||||
/**
|
||||
* The image pending movement
|
||||
*/
|
||||
image?: ImageDTO;
|
||||
onClickAddToBoard: (image: ImageDTO) => void;
|
||||
handleAddToBoard: (boardId: string) => void;
|
||||
};
|
||||
|
||||
export const AddImageToBoardContext =
|
||||
createContext<AddImageToBoardContextValue>({
|
||||
isOpen: false,
|
||||
onClose: () => undefined,
|
||||
onClickAddToBoard: () => undefined,
|
||||
handleAddToBoard: () => undefined,
|
||||
});
|
||||
|
||||
type Props = PropsWithChildren;
|
||||
|
||||
/**
 * Provides the state and handlers behind the "add image to board" dialog.
 *
 * Tracks the image that is pending a move and the dialog's open state, and
 * exposes handlers to open the dialog for an image and to add that image to
 * a chosen board.
 */
export const AddImageToBoardContextProvider = (props: Props) => {
  const [imageToMove, setImageToMove] = useState<ImageDTO>();
  const { isOpen, onOpen, onClose } = useDisclosure();

  // Only the mutation trigger is used; the request result is not read here.
  const [addImageToBoard] = useAddImageToBoardMutation();

  // Clear the pending image and close the dialog (after adding or dismissing).
  const closeAndClearImageToMove = useCallback(() => {
    setImageToMove(undefined);
    onClose();
  }, [onClose]);

  // Remember the image and open the dialog; a missing image is a no-op.
  const onClickAddToBoard = useCallback(
    (image?: ImageDTO) => {
      if (!image) {
        return;
      }
      setImageToMove(image);
      onOpen();
    },
    [setImageToMove, onOpen]
  );

  // Trigger the mutation for the pending image, then reset the dialog state.
  // The dialog is closed immediately; the request is not awaited.
  const handleAddToBoard = useCallback(
    (boardId: string) => {
      if (imageToMove) {
        addImageToBoard({
          board_id: boardId,
          image_name: imageToMove.image_name,
        });
        closeAndClearImageToMove();
      }
    },
    [addImageToBoard, closeAndClearImageToMove, imageToMove]
  );

  return (
    <AddImageToBoardContext.Provider
      value={{
        isOpen,
        image: imageToMove,
        onClose: closeAndClearImageToMove,
        onClickAddToBoard,
        handleAddToBoard,
      }}
    >
      {props.children}
    </AddImageToBoardContext.Provider>
  );
};
|
@ -11,7 +11,7 @@ import {
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||
@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
|
||||
(state: RootState, image_name?: string) => image_name,
|
||||
],
|
||||
(generation, canvas, nodes, controlNet, image_name) => {
|
||||
const isInitialImage = generation.initialImage?.image_name === image_name;
|
||||
const isInitialImage = generation.initialImage?.imageName === image_name;
|
||||
|
||||
const isCanvasImage = canvas.layerState.objects.some(
|
||||
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
|
||||
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
||||
);
|
||||
|
||||
const isNodesImage = nodes.nodes.some((node) => {
|
||||
return some(
|
||||
node.data.inputs,
|
||||
(input) =>
|
||||
input.type === 'image' && input.value?.image_name === image_name
|
||||
(input) => input.type === 'image' && input.value === image_name
|
||||
);
|
||||
});
|
||||
|
||||
const isControlNetImage = some(
|
||||
controlNet.controlNets,
|
||||
(c) =>
|
||||
c.controlImage?.image_name === image_name ||
|
||||
c.processedControlImage?.image_name === image_name
|
||||
c.controlImage === image_name || c.processedControlImage === image_name
|
||||
);
|
||||
|
||||
const imageUsage: ImageUsage = {
|
||||
|
@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
|
||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||
import { omit } from 'lodash-es';
|
||||
@ -18,7 +17,6 @@ const serializationDenylist: {
|
||||
gallery: galleryPersistDenylist,
|
||||
generation: generationPersistDenylist,
|
||||
lightbox: lightboxPersistDenylist,
|
||||
models: modelsPersistDenylist,
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
|
@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||
import { initialConfigState } from 'features/system/store/configSlice';
|
||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||
@ -21,7 +20,6 @@ const initialStates: {
|
||||
gallery: initialGalleryState,
|
||||
generation: initialGenerationState,
|
||||
lightbox: initialLightboxState,
|
||||
models: initialModelsState,
|
||||
nodes: initialNodesState,
|
||||
postprocessing: initialPostprocessingState,
|
||||
system: initialSystemState,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { AnyAction } from '@reduxjs/toolkit';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { Graph } from 'services/api';
|
||||
import { Graph } from 'services/api/types';
|
||||
|
||||
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
||||
if (isAnyGraphBuilt(action)) {
|
||||
|
@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh
|
||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
|
||||
import {
|
||||
addImageAddedToBoardFulfilledListener,
|
||||
addImageAddedToBoardRejectedListener,
|
||||
} from './listeners/imageAddedToBoard';
|
||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||
import {
|
||||
addImageRemovedFromBoardFulfilledListener,
|
||||
addImageRemovedFromBoardRejectedListener,
|
||||
} from './listeners/imageRemovedFromBoard';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect<
|
||||
AppDispatch
|
||||
>;
|
||||
|
||||
/**
|
||||
* The RTK listener middleware is a lightweight alternative to sagas/observables.
|
||||
*
|
||||
* Most side effect logic should live in a listener.
|
||||
*/
|
||||
|
||||
// Image uploaded
|
||||
addImageUploadedFulfilledListener();
|
||||
addImageUploadedRejectedListener();
|
||||
@ -183,3 +198,10 @@ addControlNetAutoProcessListener();
|
||||
|
||||
// Update image URLs on connect
|
||||
addUpdateImageUrlsOnConnectListener();
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener();
|
||||
addImageAddedToBoardRejectedListener();
|
||||
addImageRemovedFromBoardFulfilledListener();
|
||||
addImageRemovedFromBoardRejectedListener();
|
||||
addBoardIdSelectedListener();
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
|
||||
import { sessionCanceled } from 'services/thunks/session';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvas' });
|
||||
|
||||
@ -10,7 +10,7 @@ export const addCommitStagingAreaImageListener = () => {
|
||||
actionCreator: commitStagingAreaImage,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const { sessionId, isProcessing } = state.system;
|
||||
const { sessionId: session_id, isProcessing } = state.system;
|
||||
const canvasSessionId = action.payload;
|
||||
|
||||
if (!isProcessing) {
|
||||
@ -23,12 +23,12 @@ export const addCommitStagingAreaImageListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (canvasSessionId !== sessionId) {
|
||||
if (canvasSessionId !== session_id) {
|
||||
moduleLog.debug(
|
||||
{
|
||||
data: {
|
||||
canvasSessionId,
|
||||
sessionId,
|
||||
session_id,
|
||||
},
|
||||
},
|
||||
'Canvas session does not match global session, skipping cancel'
|
||||
@ -36,7 +36,7 @@ export const addCommitStagingAreaImageListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(sessionCanceled({ sessionId }));
|
||||
dispatch(sessionCanceled({ session_id }));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,108 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
|
||||
import {
|
||||
IMAGES_PER_PAGE,
|
||||
receivedPageOfImages,
|
||||
} from 'services/api/thunks/image';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
/**
 * Registers a listener for `boardIdSelected`.
 *
 * When a board is selected, its cover image becomes the selected image and,
 * if fewer than one page of that board's images is loaded, another page is
 * requested. When the board is unselected or cannot be found in the cached
 * board list, the first loaded image is selected instead.
 */
export const addBoardIdSelectedListener = () => {
  startAppListening({
    actionCreator: boardIdSelected,
    effect: (action, listenerApi) => {
      const selectedBoardId = action.payload;
      const rootState = listenerApi.getState();
      const loadedImages = selectImagesAll(rootState);

      // Fallback selection used whenever no board can be resolved.
      const selectFirstLoadedImage = () => {
        listenerApi.dispatch(imageSelected(loadedImages[0]?.image_name));
      };

      if (!selectedBoardId) {
        // a board was unselected
        selectFirstLoadedImage();
        return;
      }

      const { categories } = rootState.images;

      // Images already loaded that belong to this board and the active
      // categories. The board id is known to be set past the guard above.
      const loadedBoardImages = loadedImages.filter(
        (img) =>
          categories.includes(img.image_category) &&
          img.board_id === selectedBoardId
      );

      // Look the board up in the RTK Query cache.
      const cachedBoards =
        boardsApi.endpoints.listAllBoards.select()(rootState).data;
      const selectedBoard = cachedBoards?.find(
        (candidate) => candidate.board_id === selectedBoardId
      );

      if (!selectedBoard) {
        // can't find the board in cache...
        selectFirstLoadedImage();
        return;
      }

      listenerApi.dispatch(imageSelected(selectedBoard.cover_image_name));

      // Load more only when the board has images we have not loaded yet and
      // we hold less than one full page of them.
      const loadedCount = loadedBoardImages.length;
      const hasUnloadedImages = loadedCount < selectedBoard.image_count;
      const hasPartialPage = loadedCount < IMAGES_PER_PAGE;

      if (hasUnloadedImages && hasPartialPage) {
        listenerApi.dispatch(
          receivedPageOfImages({
            categories,
            board_id: selectedBoardId,
            is_intermediate: false,
          })
        );
      }
    },
  });
};
|
||||
|
||||
/**
 * Registers a listener for `boardIdSelected` that only tops up the loaded
 * images for the selected board.
 *
 * Despite its name, this listener never dispatches `imageSelected`; it only
 * requests another page of images when fewer than one page of the board's
 * images is loaded. It does nothing when the board is unselected or is not
 * present in the cached board list.
 */
export const addBoardIdSelected_changeSelectedImage_listener = () => {
  startAppListening({
    actionCreator: boardIdSelected,
    effect: (action, listenerApi) => {
      const selectedBoardId = action.payload;
      const rootState = listenerApi.getState();

      if (!selectedBoardId) {
        // a board was unselected - we don't need to do anything
        return;
      }

      const { categories } = rootState.images;

      // Images already loaded that belong to this board and the active
      // categories. The board id is known to be set past the guard above.
      const loadedBoardImages = selectImagesAll(rootState).filter(
        (img) =>
          categories.includes(img.image_category) &&
          img.board_id === selectedBoardId
      );

      // Look the board up in the RTK Query cache.
      const cachedBoards =
        boardsApi.endpoints.listAllBoards.select()(rootState).data;
      const selectedBoard = cachedBoards?.find(
        (candidate) => candidate.board_id === selectedBoardId
      );

      if (!selectedBoard) {
        // can't find the board in cache...
        return;
      }

      // Load more only when the board has images we have not loaded yet and
      // we hold less than one full page of them.
      const loadedCount = loadedBoardImages.length;
      const hasUnloadedImages = loadedCount < selectedBoard.image_count;
      const hasPartialPage = loadedCount < IMAGES_PER_PAGE;

      if (hasUnloadedImages && hasPartialPage) {
        listenerApi.dispatch(
          receivedPageOfImages({
            categories,
            board_id: selectedBoardId,
            is_intermediate: false,
          })
        );
      }
    },
  });
};
|
@ -2,7 +2,7 @@ import { canvasMerged } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { imageUploaded } from 'services/api/thunks/image';
|
||||
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
|
||||
@ -47,13 +47,11 @@ export const addCanvasMergedListener = () => {
|
||||
|
||||
const imageUploadedRequest = dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([blob], 'mergedCanvas.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: true,
|
||||
file: new File([blob], 'mergedCanvas.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'general',
|
||||
is_intermediate: true,
|
||||
postUploadAction: {
|
||||
type: 'TOAST_CANVAS_MERGED',
|
||||
},
|
||||
@ -68,13 +66,13 @@ export const addCanvasMergedListener = () => {
|
||||
uploadedImageAction.meta.requestId === imageUploadedRequest.requestId
|
||||
);
|
||||
|
||||
const mergedCanvasImage = payload;
|
||||
const { image_name } = payload;
|
||||
|
||||
dispatch(
|
||||
setMergedCanvas({
|
||||
kind: 'image',
|
||||
layer: 'base',
|
||||
image: mergedCanvasImage,
|
||||
imageName: image_name,
|
||||
...baseLayerRect,
|
||||
})
|
||||
);
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { canvasSavedToGallery } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { imageUploaded } from 'services/api/thunks/image';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
@ -30,13 +30,11 @@ export const addCanvasSavedToGalleryListener = () => {
|
||||
|
||||
const imageUploadedRequest = dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([blob], 'savedCanvas.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: false,
|
||||
file: new File([blob], 'savedCanvas.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
image_category: 'general',
|
||||
is_intermediate: false,
|
||||
postUploadAction: {
|
||||
type: 'TOAST_CANVAS_SAVED_TO_GALLERY',
|
||||
},
|
||||
|
@ -1,14 +1,13 @@
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { imageMetadataReceived } from 'services/api/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
|
||||
import { Graph } from 'services/api';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { isImageOutput } from 'services/api/guards';
|
||||
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'controlNet' });
|
||||
|
||||
@ -34,7 +33,7 @@ export const addControlNetImageProcessedListener = () => {
|
||||
[controlNet.processorNode.id]: {
|
||||
...controlNet.processorNode,
|
||||
is_intermediate: true,
|
||||
image: pick(controlNet.controlImage, ['image_name']),
|
||||
image: { image_name: controlNet.controlImage },
|
||||
},
|
||||
},
|
||||
};
|
||||
@ -81,7 +80,7 @@ export const addControlNetImageProcessedListener = () => {
|
||||
dispatch(
|
||||
controlNetProcessedImageChanged({
|
||||
controlNetId,
|
||||
processedControlImage,
|
||||
processedControlImage: processedControlImage.image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/api/thunks/image';
|
||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
/**
 * Registers a listener for fulfilled `addImageToBoard` requests.
 *
 * Logs the board/image pair at debug level, then dispatches
 * `imageMetadataReceived` for the image (presumably to refresh the stored
 * image record with its new board — confirm against the thunk).
 */
export const addImageAddedToBoardFulfilledListener = () => {
  startAppListening({
    matcher: boardImagesApi.endpoints.addImageToBoard.matchFulfilled,
    effect: (action, listenerApi) => {
      // The arguments the mutation was originally called with.
      const requestArgs = action.meta.arg.originalArgs;
      const logPayload = {
        data: {
          board_id: requestArgs.board_id,
          image_name: requestArgs.image_name,
        },
      };

      moduleLog.debug(logPayload, 'Image added to board');

      listenerApi.dispatch(
        imageMetadataReceived({ image_name: requestArgs.image_name })
      );
    },
  });
};
|
||||
|
||||
/**
 * Registers a listener for rejected `addImageToBoard` requests.
 *
 * Only logs the board/image pair at debug level; no action is dispatched.
 */
export const addImageAddedToBoardRejectedListener = () => {
  startAppListening({
    matcher: boardImagesApi.endpoints.addImageToBoard.matchRejected,
    effect: (action) => {
      // The arguments the mutation was originally called with.
      const requestArgs = action.meta.arg.originalArgs;
      const logPayload = {
        data: {
          board_id: requestArgs.board_id,
          image_name: requestArgs.image_name,
        },
      };

      moduleLog.debug(logPayload, 'Problem adding image to board');
    },
  });
};
|
@ -1,6 +1,6 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import {
|
||||
imageCategoriesChanged,
|
||||
selectFilteredImagesAsArray,
|
||||
@ -12,12 +12,17 @@ export const addImageCategoriesChangedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageCategoriesChanged,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(
|
||||
getState()
|
||||
).length;
|
||||
const state = getState();
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(state).length;
|
||||
|
||||
if (!filteredImagesCount) {
|
||||
dispatch(receivedPageOfImages());
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: action.payload,
|
||||
board_id: state.boards.selectedBoardId,
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -1,20 +1,20 @@
|
||||
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { imageDeleted } from 'services/thunks/image';
|
||||
import { imageDeleted } from 'services/api/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageRemoved,
|
||||
selectImagesEntities,
|
||||
selectImagesIds,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { api } from 'services/api';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
/**
|
||||
* Called when the user requests an image deletion
|
||||
@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
export const addRequestedImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: requestedImageDeletion,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
effect: async (action, { dispatch, getState, condition }) => {
|
||||
const { image, imageUsage } = action.payload;
|
||||
|
||||
const { image_name } = image;
|
||||
@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => {
|
||||
const state = getState();
|
||||
const selectedImage = state.gallery.selectedImage;
|
||||
|
||||
if (selectedImage && selectedImage.image_name === image_name) {
|
||||
if (selectedImage === image_name) {
|
||||
const ids = selectImagesIds(state);
|
||||
const entities = selectImagesEntities(state);
|
||||
|
||||
const deletedImageIndex = ids.findIndex(
|
||||
(result) => result.toString() === image_name
|
||||
@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => {
|
||||
|
||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||
|
||||
const newSelectedImage = entities[newSelectedImageId];
|
||||
|
||||
if (newSelectedImageId) {
|
||||
dispatch(imageSelected(newSelectedImage));
|
||||
dispatch(imageSelected(newSelectedImageId as string));
|
||||
} else {
|
||||
dispatch(imageSelected());
|
||||
}
|
||||
@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => {
|
||||
dispatch(imageRemoved(image_name));
|
||||
|
||||
// Delete from server
|
||||
dispatch(imageDeleted({ imageName: image_name }));
|
||||
const { requestId } = dispatch(imageDeleted({ image_name }));
|
||||
|
||||
// Wait for successful deletion, then trigger boards to re-fetch
|
||||
const wasImageDeleted = await condition(
|
||||
(action): action is ReturnType<typeof imageDeleted.fulfilled> =>
|
||||
imageDeleted.fulfilled.match(action) &&
|
||||
action.meta.requestId === requestId,
|
||||
30000
|
||||
);
|
||||
|
||||
if (wasImageDeleted) {
|
||||
dispatch(
|
||||
api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived, imageUpdated } from 'services/thunks/image';
|
||||
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
@ -19,8 +19,8 @@ export const addImageMetadataReceivedFulfilledListener = () => {
|
||||
) {
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: image.image_name,
|
||||
requestBody: { is_intermediate: false },
|
||||
image_name: image.image_name,
|
||||
is_intermediate: image.is_intermediate,
|
||||
})
|
||||
);
|
||||
} else if (image.is_intermediate) {
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/api/thunks/image';
|
||||
import { boardImagesApi } from 'services/api/endpoints/boardImages';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
/**
 * Registers a listener for fulfilled `removeImageFromBoard` requests.
 *
 * Logs the board/image pair at debug level, then dispatches
 * `imageMetadataReceived` for the image (presumably to refresh the stored
 * image record now that it has left the board — confirm against the thunk).
 */
export const addImageRemovedFromBoardFulfilledListener = () => {
  startAppListening({
    matcher: boardImagesApi.endpoints.removeImageFromBoard.matchFulfilled,
    effect: (action, { dispatch }) => {
      // The arguments the mutation was originally called with.
      const { board_id, image_name } = action.meta.arg.originalArgs;

      // This is the removal path; the message previously said "added".
      moduleLog.debug(
        { data: { board_id, image_name } },
        'Image removed from board'
      );

      dispatch(
        imageMetadataReceived({
          image_name,
        })
      );
    },
  });
};
|
||||
|
||||
/**
 * Registers a listener for rejected `removeImageFromBoard` requests.
 *
 * Only logs the board/image pair at debug level; no action is dispatched.
 */
export const addImageRemovedFromBoardRejectedListener = () => {
  startAppListening({
    matcher: boardImagesApi.endpoints.removeImageFromBoard.matchRejected,
    effect: (action) => {
      // The arguments the mutation was originally called with.
      const { board_id, image_name } = action.meta.arg.originalArgs;

      // This is the removal path; the message previously said "adding".
      moduleLog.debug(
        { data: { board_id, image_name } },
        'Problem removing image from board'
      );
    },
  });
};
|
@ -1,5 +1,5 @@
|
||||
import { startAppListening } from '..';
|
||||
import { imageUpdated } from 'services/thunks/image';
|
||||
import { imageUpdated } from 'services/api/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { startAppListening } from '..';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { imageUploaded } from 'services/api/thunks/image';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
|
||||
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
|
||||
const { controlNetId } = postUploadAction;
|
||||
dispatch(controlNetImageChanged({ controlNetId, controlImage: image }));
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId,
|
||||
controlImage: image.image_name,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
import { imageUrlsReceived } from 'services/api/thunks/image';
|
||||
import { imageUpdatedOne } from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
@ -5,7 +5,7 @@ import { startAppListening } from '..';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { selectImagesById } from 'features/gallery/store/imagesSlice';
|
||||
import { isImageDTO } from 'services/types/guards';
|
||||
import { isImageDTO } from 'services/api/guards';
|
||||
|
||||
export const addInitialImageSelectedListener = () => {
|
||||
startAppListening({
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user