mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/nodes/freeu
This commit is contained in:
@ -1,35 +1,35 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
||||
from ..services.board_images.board_images_default import BoardImagesService
|
||||
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from ..services.boards.boards_default import BoardService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.thread import lock
|
||||
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.shared.default_graphs import create_system_graphs
|
||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.shared.sqlite import SqliteDatabase
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -63,100 +63,64 @@ class ApiDependencies:
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
if config.use_memory_db:
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
db = SqliteDatabase(config, logger)
|
||||
|
||||
logger.info(f"Using database at {db_location}")
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
configuration = config
|
||||
logger = logger
|
||||
|
||||
if config.log_sql:
|
||||
db_conn.set_trace_callback(print)
|
||||
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
board_image_records = SqliteBoardImageRecordStorage(db=db)
|
||||
board_images = BoardImagesService()
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
board_images = BoardImagesService(
|
||||
services=BoardImagesServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
services=ImageServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
processor = DefaultInvocationProcessor()
|
||||
queue = MemoryInvocationQueue()
|
||||
session_processor = DefaultSessionProcessor()
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_image_records=board_image_records,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
|
||||
board_records=board_records,
|
||||
boards=boards,
|
||||
configuration=configuration,
|
||||
events=events,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
graph_library=graph_library,
|
||||
image_files=image_files,
|
||||
image_records=image_records,
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
latents=latents,
|
||||
logger=logger,
|
||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||
session_processor=DefaultSessionProcessor(),
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
model_manager=model_manager,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
queue=queue,
|
||||
session_processor=session_processor,
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
try:
|
||||
lock.acquire()
|
||||
db_conn.execute("VACUUM;")
|
||||
db_conn.commit()
|
||||
logger.info("Cleaned database")
|
||||
finally:
|
||||
lock.release()
|
||||
db.clean()
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
|
@ -7,7 +7,7 @@ from typing import Any
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
|
@ -4,9 +4,9 @@ from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.board_record_storage import BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
|
@ -8,9 +8,9 @@ from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import ImageDTO, ImageRecordChanges, ImageUrlsDTO
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -42,7 +42,7 @@ async def upload_image(
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await file.read()
|
||||
@ -322,3 +322,20 @@ async def unstar_images_in_list(
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
|
||||
class ImagesDownloaded(BaseModel):
|
||||
response: Optional[str] = Field(
|
||||
description="If defined, the message to display to the user when images begin downloading"
|
||||
)
|
||||
|
||||
|
||||
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
|
||||
async def download_images_from_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
||||
board_id: Optional[str] = Body(
|
||||
default=None, description="The board from which image should be downloaded from", embed=True
|
||||
),
|
||||
) -> ImagesDownloaded:
|
||||
# return ImagesDownloaded(response="Your images are downloading")
|
||||
raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")
|
||||
|
@ -2,11 +2,11 @@
|
||||
|
||||
|
||||
import pathlib
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
@ -23,8 +23,14 @@ from ..dependencies import ApiDependencies
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
update_models_response_adapter = TypeAdapter(UpdateModelResponse)
|
||||
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
import_models_response_adapter = TypeAdapter(ImportModelResponse)
|
||||
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
|
||||
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
@ -32,6 +38,11 @@ ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
models_list_adapter = TypeAdapter(ModelsList)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
@ -49,7 +60,7 @@ async def list_models(
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||
models = models_list_adapter.validate_python({"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@ -105,11 +116,14 @@ async def update_model(
|
||||
info.path = new_info.get("path")
|
||||
|
||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||
info_dict = info.dict()
|
||||
info_dict = info.model_dump()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info_dict,
|
||||
)
|
||||
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
@ -117,7 +131,7 @@ async def update_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||
model_response = update_models_response_adapter.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
@ -152,13 +166,15 @@ async def import_model(
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||
|
||||
location = location.strip("\"' ")
|
||||
items_to_import = {location}
|
||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||
items_to_import=items_to_import,
|
||||
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
@ -170,7 +186,7 @@ async def import_model(
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
return import_models_response_adapter.validate_python(model_raw)
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
@ -204,13 +220,18 @@ async def add_model(
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes=info.model_dump(),
|
||||
)
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type,
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
return import_models_response_adapter.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -222,7 +243,10 @@ async def add_model(
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
@ -278,7 +302,7 @@ async def convert_model(
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
response = convert_models_response_adapter.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
except ValueError as e:
|
||||
@ -301,7 +325,8 @@ async def search_for_models(
|
||||
) -> List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||
status_code=404,
|
||||
detail=f"The search path '{search_path}' does not exist or is not directory",
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
|
||||
@ -336,6 +361,26 @@ async def sync_to_config() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
|
||||
# TODO: After a few updates, see if it works inside the route operation handler?
|
||||
class MergeModelsBody(BaseModel):
|
||||
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
|
||||
merged_model_name: Optional[str] = Field(description="Name of destination model")
|
||||
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
|
||||
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
|
||||
force: Optional[bool] = Field(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
)
|
||||
|
||||
merge_dest_directory: Optional[str] = Field(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
@ -348,31 +393,23 @@ async def sync_to_config() -> bool:
|
||||
response_model=MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(
|
||||
description="Force merging of models created with different versions of diffusers", default=False
|
||||
),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
logger.info(
|
||||
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
|
||||
)
|
||||
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
model_names=body.model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=body.merged_model_name or "+".join(body.model_names),
|
||||
alpha=body.alpha,
|
||||
interp=body.interp,
|
||||
force=body.force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
@ -380,9 +417,12 @@ async def merge_models(
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
response = convert_models_response_adapter.validate_python(model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{body.model_names}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -18,9 +18,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
|
||||
from ...services.graph import Graph
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
|
@ -1,56 +1,50 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
from fastapi import HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.fields import Field
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ...services.shared.graph import GraphExecutionState
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/",
|
||||
operation_id="create_session",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def create_session(
|
||||
queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||
return session
|
||||
# @session_router.post(
|
||||
# "/",
|
||||
# operation_id="create_session",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid json"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def create_session(
|
||||
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Creates a new session, optionally initializing it with an invocation graph"""
|
||||
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||
# return session
|
||||
|
||||
|
||||
@session_router.get(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
deprecated=True,
|
||||
)
|
||||
async def list_sessions(
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
per_page: int = Query(default=10, description="The number of results per page"),
|
||||
query: str = Query(default="", description="The query string to search for"),
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if query == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
return result
|
||||
# @session_router.get(
|
||||
# "/",
|
||||
# operation_id="list_sessions",
|
||||
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def list_sessions(
|
||||
# page: int = Query(default=0, description="The page of results to get"),
|
||||
# per_page: int = Query(default=10, description="The number of results per page"),
|
||||
# query: str = Query(default="", description="The query string to search for"),
|
||||
# ) -> PaginatedResults[GraphExecutionState]:
|
||||
# """Gets a list of sessions, optionally searching"""
|
||||
# if query == "":
|
||||
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
# else:
|
||||
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
# return result
|
||||
|
||||
|
||||
@session_router.get(
|
||||
@ -60,7 +54,6 @@ async def list_sessions(
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str = Path(description="The id of the session to get"),
|
||||
@ -73,211 +66,211 @@ async def get_session(
|
||||
return session
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/{session_id}/nodes",
|
||||
operation_id="add_node",
|
||||
responses={
|
||||
200: {"model": str},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The node to add"
|
||||
),
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
# @session_router.post(
|
||||
# "/{session_id}/nodes",
|
||||
# operation_id="add_node",
|
||||
# responses={
|
||||
# 200: {"model": str},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def add_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
# description="The node to add"
|
||||
# ),
|
||||
# ) -> str:
|
||||
# """Adds a node to the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
# try:
|
||||
# session.add_node(node)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session.id
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
"/{session_id}/nodes/{node_path}",
|
||||
operation_id="update_node",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node in the graph"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The new node"
|
||||
),
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
# @session_router.put(
|
||||
# "/{session_id}/nodes/{node_path}",
|
||||
# operation_id="update_node",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def update_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node_path: str = Path(description="The path to the node in the graph"),
|
||||
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
# description="The new node"
|
||||
# ),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Updates a node in the graph and removes all linked edges"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
# try:
|
||||
# session.update_node(node_path, node)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
"/{session_id}/nodes/{node_path}",
|
||||
operation_id="delete_node",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def delete_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node to delete"),
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/nodes/{node_path}",
|
||||
# operation_id="delete_node",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def delete_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node_path: str = Path(description="The path to the node to delete"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Deletes a node in the graph and removes all linked edges"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
# try:
|
||||
# session.delete_node(node_path)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/{session_id}/edges",
|
||||
operation_id="add_edge",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
edge: Edge = Body(description="The edge to add"),
|
||||
) -> GraphExecutionState:
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
# @session_router.post(
|
||||
# "/{session_id}/edges",
|
||||
# operation_id="add_edge",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def add_edge(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# edge: Edge = Body(description="The edge to add"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Adds an edge to the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
# try:
|
||||
# session.add_edge(edge)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@session_router.delete(
|
||||
"/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
|
||||
operation_id="delete_edge",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def delete_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
from_node_id: str = Path(description="The id of the node the edge is coming from"),
|
||||
from_field: str = Path(description="The field of the node the edge is coming from"),
|
||||
to_node_id: str = Path(description="The id of the node the edge is going to"),
|
||||
to_field: str = Path(description="The field of the node the edge is going to"),
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
# # TODO: the edge being in the path here is really ugly, find a better solution
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
|
||||
# operation_id="delete_edge",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def delete_edge(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
|
||||
# from_field: str = Path(description="The field of the node the edge is coming from"),
|
||||
# to_node_id: str = Path(description="The id of the node the edge is going to"),
|
||||
# to_field: str = Path(description="The field of the node the edge is going to"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Deletes an edge from the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
edge = Edge(
|
||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
)
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
# try:
|
||||
# edge = Edge(
|
||||
# source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
# )
|
||||
# session.delete_edge(edge)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="invoke_session",
|
||||
responses={
|
||||
200: {"model": None},
|
||||
202: {"description": "The invocation is queued"},
|
||||
400: {"description": "The session has no invocations ready to invoke"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def invoke_session(
|
||||
queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||
session_id: str = Path(description="The id of the session to invoke"),
|
||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
# @session_router.put(
|
||||
# "/{session_id}/invoke",
|
||||
# operation_id="invoke_session",
|
||||
# responses={
|
||||
# 200: {"model": None},
|
||||
# 202: {"description": "The invocation is queued"},
|
||||
# 400: {"description": "The session has no invocations ready to invoke"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def invoke_session(
|
||||
# queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||
# session_id: str = Path(description="The id of the session to invoke"),
|
||||
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
# ) -> Response:
|
||||
# """Invokes a session"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
if session.is_complete():
|
||||
raise HTTPException(status_code=400)
|
||||
# if session.is_complete():
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||
# return Response(status_code=202)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={202: {"description": "The invocation is canceled"}},
|
||||
deprecated=True,
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
ApiDependencies.invoker.cancel(session_id)
|
||||
return Response(status_code=202)
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/invoke",
|
||||
# operation_id="cancel_session_invoke",
|
||||
# responses={202: {"description": "The invocation is canceled"}},
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def cancel_session_invoke(
|
||||
# session_id: str = Path(description="The id of the session to cancel"),
|
||||
# ) -> Response:
|
||||
# """Invokes a session"""
|
||||
# ApiDependencies.invoker.cancel(session_id)
|
||||
# return Response(status_code=202)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||
from fastapi import Body
|
||||
@ -27,6 +27,7 @@ async def parse_dynamicprompts(
|
||||
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||
) -> DynamicPromptsResponse:
|
||||
"""Creates a batch process"""
|
||||
generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator]
|
||||
try:
|
||||
error: Optional[str] = None
|
||||
if combinatorial:
|
||||
|
@ -5,7 +5,7 @@ from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class SocketIO:
|
||||
@ -30,8 +30,8 @@ class SocketIO:
|
||||
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
|
||||
if "queue_id" in data:
|
||||
self.__sio.enter_room(sid, data["queue_id"])
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
|
||||
if "queue_id" in data:
|
||||
self.__sio.enter_room(sid, data["queue_id"])
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
@ -22,7 +22,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
@ -51,7 +51,7 @@ mimetypes.add_type("text/css", ".css")
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
@ -63,18 +63,18 @@ app.add_middleware(
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
|
||||
|
||||
@ -85,11 +85,6 @@ async def shutdown_event():
|
||||
|
||||
|
||||
# Include all routers
|
||||
# TODO: REMOVE
|
||||
# app.include_router(
|
||||
# invocation.invocation_router,
|
||||
# prefix = '/api')
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
@ -117,6 +112,7 @@ def custom_openapi():
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# Add all outputs
|
||||
@ -127,29 +123,32 @@ def custom_openapi():
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
|
||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
||||
output_schema["class"] = "output"
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
|
||||
output_schemas = models_json_schema(
|
||||
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
|
||||
ui_config_schemas = models_json_schema(
|
||||
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_type = signature(obj=invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
|
||||
@ -172,7 +171,7 @@ def custom_openapi():
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
||||
|
@ -24,8 +24,8 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
if get_origin(field.annotation) == Literal:
|
||||
allowed_values = get_args(field.annotation)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
@ -38,15 +38,15 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
help=field.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
type=field.annotation,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
help=field.description,
|
||||
)
|
||||
|
||||
|
||||
@ -142,7 +142,6 @@ class BaseCommand(ABC, BaseModel):
|
||||
"""A CLI command"""
|
||||
|
||||
# All commands must include a type name like this:
|
||||
# type: Literal['your_command_name'] = 'your_command_name'
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
|
@ -7,28 +7,16 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_type_hints,
|
||||
)
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic.fields import ModelField, Undefined
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@ -215,6 +203,11 @@ class _InputField(BaseModel):
|
||||
ui_choice_labels: Optional[dict[str, str]]
|
||||
item_default: Optional[Any]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
"""
|
||||
@ -228,34 +221,36 @@ class _OutputField(BaseModel):
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
def get_type(klass: BaseModel) -> str:
|
||||
"""Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
|
||||
return klass.model_fields["type"].default
|
||||
|
||||
|
||||
def InputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
@ -263,7 +258,6 @@ def InputField(
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
item_default: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
@ -293,18 +287,26 @@ def InputField(
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
|
||||
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
||||
Ignored for non-collection fields..
|
||||
Ignored for non-collection fields.
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
|
||||
json_schema_extra_: dict[str, Any] = dict(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
)
|
||||
|
||||
field_args = dict(
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
@ -313,57 +315,92 @@ def InputField(
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
"""
|
||||
Invocation definitions have their fields typed correctly for their `invoke()` functions.
|
||||
This typing is often more specific than the actual invocation definition requires, because
|
||||
fields may have values provided only by connections.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
an ancestor node that outputs the image.
|
||||
|
||||
So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
|
||||
we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
|
||||
value or not. This is very tedious.
|
||||
|
||||
Ideally, the invocation definition would be able to specify that the field is required during
|
||||
invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
|
||||
but when calling the `invoke()` function, we raise an error if the field is not present.
|
||||
|
||||
To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do
|
||||
extra validation when calling `invoke()`.
|
||||
|
||||
There is some additional logic here to cleaning create the pydantic field via the wrapper.
|
||||
"""
|
||||
|
||||
# Filter out field args not provided
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
|
||||
raise ValueError("Cannot specify both default and default_factory")
|
||||
|
||||
# because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
if default is PydanticUndefined and default_factory is PydanticUndefined:
|
||||
json_schema_extra_.update(dict(orig_required=True))
|
||||
else:
|
||||
json_schema_extra_.update(dict(orig_required=False))
|
||||
|
||||
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update(dict(default=default_))
|
||||
if default is not PydanticUndefined:
|
||||
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
|
||||
json_schema_extra_.update(dict(default=default))
|
||||
json_schema_extra_.update(dict(orig_default=default))
|
||||
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update(dict(default=default_))
|
||||
json_schema_extra_.update(dict(orig_default=default_))
|
||||
elif default_factory is not PydanticUndefined:
|
||||
provided_args.update(dict(default_factory=default_factory))
|
||||
# TODO: cannot serialize default_factory...
|
||||
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_,
|
||||
)
|
||||
|
||||
|
||||
def OutputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
@ -383,15 +420,12 @@ def OutputField(
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
@ -400,19 +434,13 @@ def OutputField(
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
**kwargs,
|
||||
json_schema_extra=dict(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -426,7 +454,13 @@ class UIConfigBase(BaseModel):
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
version: Optional[str] = Field(
|
||||
default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
|
||||
default=None,
|
||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
@ -461,23 +495,38 @@ class BaseInvocationOutput(BaseModel):
|
||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses_tuple(cls):
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
return tuple(subclasses)
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type"])
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
cls._output_classes.add(output)
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_outputs_union(cls) -> UnionType:
|
||||
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
||||
return outputs_union # type: ignore [return-value]
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
)
|
||||
|
||||
|
||||
class RequiredConnectionException(Exception):
|
||||
@ -502,104 +551,91 @@ class BaseInvocation(ABC, BaseModel):
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
cls._invocation_classes.add(invocation)
|
||||
|
||||
@classmethod
|
||||
def get_invocations_union(cls) -> UnionType:
|
||||
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
||||
return invocations_union # type: ignore [return-value]
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
allowed_invocations = []
|
||||
for sc in subclasses:
|
||||
allowed_invocations: set[BaseInvocation] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = get_type(sc)
|
||||
is_in_allowlist = (
|
||||
sc.__fields__.get("type").default in app_config.allow_nodes
|
||||
if isinstance(app_config.allow_nodes, list)
|
||||
else True
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
|
||||
is_in_denylist = (
|
||||
sc.__fields__.get("type").default in app_config.deny_nodes
|
||||
if isinstance(app_config.deny_nodes, list)
|
||||
else False
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.append(sc)
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls):
|
||||
return tuple(BaseInvocation.get_all_subclasses())
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls):
|
||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(
|
||||
map(
|
||||
lambda t: (get_args(get_type_hints(t)["type"])[0], t),
|
||||
BaseInvocation.get_all_subclasses(),
|
||||
lambda i: (get_type(i), i),
|
||||
BaseInvocation.get_invocations(),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_output_type(cls):
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
|
||||
|
||||
@classmethod
|
||||
def get_output_type(cls) -> BaseInvocationOutput:
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
validate_all = True
|
||||
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
def __init__(self, **data):
|
||||
# nodes may have required fields, that can accept input from connections
|
||||
# on instantiation of the model, we need to exclude these from validation
|
||||
restore = dict()
|
||||
try:
|
||||
field_names = list(self.__fields__.keys())
|
||||
for field_name in field_names:
|
||||
# if the field is required and may get its value from a connection, exclude it from validation
|
||||
field = self.__fields__[field_name]
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if _input in [Input.Connection, Input.Any] and field.required:
|
||||
if field_name not in data:
|
||||
restore[field_name] = self.__fields__.pop(field_name)
|
||||
# instantiate the node, which will validate the data
|
||||
super().__init__(**data)
|
||||
finally:
|
||||
# restore the removed fields
|
||||
for field_name, field in restore.items():
|
||||
self.__fields__[field_name] = field
|
||||
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
for field_name, field in self.__fields__.items():
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if field.required and not hasattr(self, field_name):
|
||||
if _input == Input.Connection:
|
||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||
elif _input == Input.Any:
|
||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||
for field_name, field in self.model_fields.items():
|
||||
if not field.json_schema_extra or callable(field.json_schema_extra):
|
||||
# something has gone terribly awry, we should always have this and it should be a dict
|
||||
continue
|
||||
|
||||
# Here we handle the case where the field is optional in the pydantic class, but required
|
||||
# in the `invoke()` method.
|
||||
|
||||
orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
|
||||
orig_required = field.json_schema_extra.get("orig_required", True)
|
||||
input_ = field.json_schema_extra.get("input", None)
|
||||
if orig_default is not PydanticUndefined and not hasattr(self, field_name):
|
||||
setattr(self, field_name, orig_default)
|
||||
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
|
||||
if input_ == Input.Connection:
|
||||
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
|
||||
elif input_ == Input.Any:
|
||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
@ -622,23 +658,31 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return self.invoke(context)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return self.__fields__["type"].default
|
||||
return self.model_fields["type"].default
|
||||
|
||||
id: str = Field(
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
||||
default_factory=uuid_string,
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||
)
|
||||
is_intermediate: bool = InputField(
|
||||
default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
|
||||
is_intermediate: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
||||
)
|
||||
workflow: Optional[str] = InputField(
|
||||
workflow: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The workflow to save with the image",
|
||||
ui_type=UIType.WorkflowField,
|
||||
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
||||
)
|
||||
use_cache: Optional[bool] = Field(
|
||||
default=True,
|
||||
description="Whether or not to use the cache",
|
||||
)
|
||||
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
||||
|
||||
@validator("workflow", pre=True)
|
||||
@field_validator("workflow", mode="before")
|
||||
@classmethod
|
||||
def validate_workflow_is_json(cls, v):
|
||||
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
||||
if v is None:
|
||||
return None
|
||||
try:
|
||||
@ -649,8 +693,14 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
def invocation(
|
||||
@ -660,7 +710,7 @@ def invocation(
|
||||
category: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
|
||||
@ -672,12 +722,15 @@ def invocation(
|
||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
|
||||
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
||||
# Validate invocation types on creation of invocation classes
|
||||
# TODO: ensure unique?
|
||||
if re.compile(r"^\S+$").match(invocation_type) is None:
|
||||
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
|
||||
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
@ -695,59 +748,83 @@ def invocation(
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
if use_cache is not None:
|
||||
cls.__fields__["use_cache"].default = use_cache
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
# Add the invocation type to the model.
|
||||
|
||||
# You'd be tempted to just add the type field and rebuild the model, like this:
|
||||
# cls.model_fields.update(type=FieldInfo.from_annotated_attribute(Literal[invocation_type], invocation_type))
|
||||
# cls.model_rebuild() or cls.model_rebuild(force=True)
|
||||
|
||||
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
|
||||
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
|
||||
|
||||
# Add the invocation type to the pydantic model of the invocation
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = ModelField.infer(
|
||||
name="type",
|
||||
value=invocation_type,
|
||||
annotation=invocation_type_annotation,
|
||||
class_validators=None,
|
||||
config=cls.__config__,
|
||||
invocation_type_field = Field(
|
||||
title="type",
|
||||
default=invocation_type,
|
||||
)
|
||||
cls.__fields__.update({"type": invocation_type_field})
|
||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
||||
if annotations := cls.__dict__.get("__annotations__", None):
|
||||
annotations.update({"type": invocation_type_annotation})
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(invocation_type_annotation, invocation_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
|
||||
BaseInvocation.register_invocation(cls) # type: ignore
|
||||
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
|
||||
TBaseInvocationOutput = TypeVar("TBaseInvocationOutput", bound=BaseInvocationOutput)
|
||||
|
||||
|
||||
def invocation_output(
|
||||
output_type: str,
|
||||
) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
|
||||
) -> Callable[[Type[TBaseInvocationOutput]], Type[TBaseInvocationOutput]]:
|
||||
"""
|
||||
Adds metadata to an invocation output.
|
||||
|
||||
:param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
|
||||
def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
|
||||
# Validate output types on creation of invocation output classes
|
||||
# TODO: ensure unique?
|
||||
if re.compile(r"^\S+$").match(output_type) is None:
|
||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||
|
||||
# Add the output type to the pydantic model of the invocation output
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = ModelField.infer(
|
||||
name="type",
|
||||
value=output_type,
|
||||
annotation=output_type_annotation,
|
||||
class_validators=None,
|
||||
config=cls.__config__,
|
||||
)
|
||||
cls.__fields__.update({"type": output_type_field})
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
||||
if annotations := cls.__dict__.get("__annotations__", None):
|
||||
annotations.update({"type": output_type_annotation})
|
||||
# Add the output type to the model.
|
||||
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(
|
||||
title="type",
|
||||
default=output_type,
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(output_type_annotation, output_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
|
||||
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
import numpy as np
|
||||
from pydantic import validator
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
@ -20,9 +20,9 @@ class RangeInvocation(BaseInvocation):
|
||||
stop: int = InputField(default=10, description="The stop of the range")
|
||||
step: int = InputField(default=1, description="The step of the range")
|
||||
|
||||
@validator("stop")
|
||||
def stop_gt_start(cls, v, values):
|
||||
if "start" in values and v <= values["start"]:
|
||||
@field_validator("stop")
|
||||
def stop_gt_start(cls, v: int, info: ValidationInfo):
|
||||
if "start" in info.data and v <= info.data["start"]:
|
||||
raise ValueError("stop must be greater than start")
|
||||
return v
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
@ -43,7 +43,13 @@ class ConditioningFieldData:
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
|
||||
@invocation(
|
||||
"compel",
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -61,17 +67,19 @@ class CompelInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -160,11 +168,11 @@ class SDXLPromptInvocationBase:
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
**clip_field.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
**clip_field.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
@ -172,7 +180,11 @@ class SDXLPromptInvocationBase:
|
||||
if prompt == "" and zero_on_empty:
|
||||
cpu_text_encoder = text_encoder_info.context.model
|
||||
c = torch.zeros(
|
||||
(1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
|
||||
(
|
||||
1,
|
||||
cpu_text_encoder.config.max_position_embeddings,
|
||||
cpu_text_encoder.config.hidden_size,
|
||||
),
|
||||
dtype=text_encoder_info.context.cache.precision,
|
||||
)
|
||||
if get_pooled:
|
||||
@ -186,7 +198,9 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -273,8 +287,16 @@ class SDXLPromptInvocationBase:
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
prompt: str = InputField(
|
||||
default="",
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
style: str = InputField(
|
||||
default="",
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
@ -310,7 +332,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
[
|
||||
c1,
|
||||
torch.zeros(
|
||||
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]), device=c1.device, dtype=c1.dtype
|
||||
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]),
|
||||
device=c1.device,
|
||||
dtype=c1.dtype,
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
@ -321,7 +345,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
[
|
||||
c2,
|
||||
torch.zeros(
|
||||
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]), device=c2.device, dtype=c2.dtype
|
||||
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]),
|
||||
device=c2.device,
|
||||
dtype=c2.dtype,
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
@ -359,7 +385,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
style: str = InputField(
|
||||
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
||||
default="",
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
) # TODO: ?
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
@ -403,10 +431,16 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
|
||||
@invocation(
|
||||
"clip_skip",
|
||||
title="CLIP Skip",
|
||||
tags=["clipskip", "clip", "skip"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
@ -421,7 +455,9 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
tokenizer,
|
||||
prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
truncate_if_too_long=False,
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
|
@ -2,7 +2,7 @@
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -24,12 +24,12 @@ from controlnet_aux import (
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -57,6 +57,8 @@ class ControlNetModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the ControlNet model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
@ -71,7 +73,7 @@ class ControlField(BaseModel):
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
@field_validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
"""Validate that all control weights in the valid range"""
|
||||
if isinstance(v, list):
|
||||
@ -124,9 +126,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
@ -393,9 +393,9 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image):
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
@ -575,14 +575,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
image = image.convert("RGB")
|
||||
image = np.array(image, dtype=np.uint8)
|
||||
height, width = image.shape[:2]
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
width_tile_size = min(self.color_map_tile_size, width)
|
||||
height_tile_size = min(self.color_map_tile_size, height)
|
||||
|
||||
color_map = cv2.resize(
|
||||
image,
|
||||
np_image,
|
||||
(width // width_tile_size, height // height_tile_size),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ import numpy
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
@ -8,7 +8,7 @@ import numpy as np
|
||||
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
|
||||
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
||||
from PIL.Image import Image as ImageType
|
||||
from pydantic import validator
|
||||
from pydantic import field_validator
|
||||
|
||||
import invokeai.assets.fonts as font_assets
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
@ -20,7 +20,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
|
||||
@invocation_output("face_mask_output")
|
||||
@ -46,6 +46,8 @@ class FaceResultData(TypedDict):
|
||||
y_center: float
|
||||
mesh_width: int
|
||||
mesh_height: int
|
||||
chunk_x_offset: int
|
||||
chunk_y_offset: int
|
||||
|
||||
|
||||
class FaceResultDataWithId(FaceResultData):
|
||||
@ -78,6 +80,48 @@ FONT_SIZE = 32
|
||||
FONT_STROKE_WIDTH = 4
|
||||
|
||||
|
||||
def coalesce_faces(face1: FaceResultData, face2: FaceResultData) -> FaceResultData:
|
||||
face1_x_offset = face1["chunk_x_offset"] - min(face1["chunk_x_offset"], face2["chunk_x_offset"])
|
||||
face2_x_offset = face2["chunk_x_offset"] - min(face1["chunk_x_offset"], face2["chunk_x_offset"])
|
||||
face1_y_offset = face1["chunk_y_offset"] - min(face1["chunk_y_offset"], face2["chunk_y_offset"])
|
||||
face2_y_offset = face2["chunk_y_offset"] - min(face1["chunk_y_offset"], face2["chunk_y_offset"])
|
||||
|
||||
new_im_width = (
|
||||
max(face1["image"].width, face2["image"].width)
|
||||
+ max(face1["chunk_x_offset"], face2["chunk_x_offset"])
|
||||
- min(face1["chunk_x_offset"], face2["chunk_x_offset"])
|
||||
)
|
||||
new_im_height = (
|
||||
max(face1["image"].height, face2["image"].height)
|
||||
+ max(face1["chunk_y_offset"], face2["chunk_y_offset"])
|
||||
- min(face1["chunk_y_offset"], face2["chunk_y_offset"])
|
||||
)
|
||||
pil_image = Image.new(mode=face1["image"].mode, size=(new_im_width, new_im_height))
|
||||
pil_image.paste(face1["image"], (face1_x_offset, face1_y_offset))
|
||||
pil_image.paste(face2["image"], (face2_x_offset, face2_y_offset))
|
||||
|
||||
# Mask images are always from the origin
|
||||
new_mask_im_width = max(face1["mask"].width, face2["mask"].width)
|
||||
new_mask_im_height = max(face1["mask"].height, face2["mask"].height)
|
||||
mask_pil = create_white_image(new_mask_im_width, new_mask_im_height)
|
||||
black_image = create_black_image(face1["mask"].width, face1["mask"].height)
|
||||
mask_pil.paste(black_image, (0, 0), ImageOps.invert(face1["mask"]))
|
||||
black_image = create_black_image(face2["mask"].width, face2["mask"].height)
|
||||
mask_pil.paste(black_image, (0, 0), ImageOps.invert(face2["mask"]))
|
||||
|
||||
new_face = FaceResultData(
|
||||
image=pil_image,
|
||||
mask=mask_pil,
|
||||
x_center=max(face1["x_center"], face2["x_center"]),
|
||||
y_center=max(face1["y_center"], face2["y_center"]),
|
||||
mesh_width=max(face1["mesh_width"], face2["mesh_width"]),
|
||||
mesh_height=max(face1["mesh_height"], face2["mesh_height"]),
|
||||
chunk_x_offset=max(face1["chunk_x_offset"], face2["chunk_x_offset"]),
|
||||
chunk_y_offset=max(face2["chunk_y_offset"], face2["chunk_y_offset"]),
|
||||
)
|
||||
return new_face
|
||||
|
||||
|
||||
def prepare_faces_list(
|
||||
face_result_list: list[FaceResultData],
|
||||
) -> list[FaceResultDataWithId]:
|
||||
@ -91,7 +135,7 @@ def prepare_faces_list(
|
||||
should_add = True
|
||||
candidate_x_center = candidate["x_center"]
|
||||
candidate_y_center = candidate["y_center"]
|
||||
for face in deduped_faces:
|
||||
for idx, face in enumerate(deduped_faces):
|
||||
face_center_x = face["x_center"]
|
||||
face_center_y = face["y_center"]
|
||||
face_radius_w = face["mesh_width"] / 2
|
||||
@ -105,6 +149,7 @@ def prepare_faces_list(
|
||||
)
|
||||
|
||||
if p < 1: # Inside of the already-added face's radius
|
||||
deduped_faces[idx] = coalesce_faces(face, candidate)
|
||||
should_add = False
|
||||
break
|
||||
|
||||
@ -138,7 +183,6 @@ def generate_face_box_mask(
|
||||
chunk_x_offset: int = 0,
|
||||
chunk_y_offset: int = 0,
|
||||
draw_mesh: bool = True,
|
||||
check_bounds: bool = True,
|
||||
) -> list[FaceResultData]:
|
||||
result = []
|
||||
mask_pil = None
|
||||
@ -211,33 +255,20 @@ def generate_face_box_mask(
|
||||
mask_pil = create_white_image(w + chunk_x_offset, h + chunk_y_offset)
|
||||
mask_pil.paste(init_mask_pil, (chunk_x_offset, chunk_y_offset))
|
||||
|
||||
left_side = x_center - mesh_width
|
||||
right_side = x_center + mesh_width
|
||||
top_side = y_center - mesh_height
|
||||
bottom_side = y_center + mesh_height
|
||||
im_width, im_height = pil_image.size
|
||||
over_w = im_width * 0.1
|
||||
over_h = im_height * 0.1
|
||||
if not check_bounds or (
|
||||
(left_side >= -over_w)
|
||||
and (right_side < im_width + over_w)
|
||||
and (top_side >= -over_h)
|
||||
and (bottom_side < im_height + over_h)
|
||||
):
|
||||
x_center = float(x_center)
|
||||
y_center = float(y_center)
|
||||
face = FaceResultData(
|
||||
image=pil_image,
|
||||
mask=mask_pil or create_white_image(*pil_image.size),
|
||||
x_center=x_center + chunk_x_offset,
|
||||
y_center=y_center + chunk_y_offset,
|
||||
mesh_width=mesh_width,
|
||||
mesh_height=mesh_height,
|
||||
)
|
||||
x_center = float(x_center)
|
||||
y_center = float(y_center)
|
||||
face = FaceResultData(
|
||||
image=pil_image,
|
||||
mask=mask_pil or create_white_image(*pil_image.size),
|
||||
x_center=x_center + chunk_x_offset,
|
||||
y_center=y_center + chunk_y_offset,
|
||||
mesh_width=mesh_width,
|
||||
mesh_height=mesh_height,
|
||||
chunk_x_offset=chunk_x_offset,
|
||||
chunk_y_offset=chunk_y_offset,
|
||||
)
|
||||
|
||||
result.append(face)
|
||||
else:
|
||||
context.services.logger.info("FaceTools --> Face out of bounds, ignoring.")
|
||||
result.append(face)
|
||||
|
||||
return result
|
||||
|
||||
@ -346,7 +377,6 @@ def get_faces_list(
|
||||
chunk_x_offset=0,
|
||||
chunk_y_offset=0,
|
||||
draw_mesh=draw_mesh,
|
||||
check_bounds=False,
|
||||
)
|
||||
if should_chunk or len(result) == 0:
|
||||
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||
@ -360,24 +390,26 @@ def get_faces_list(
|
||||
if width > height:
|
||||
# Landscape - slice the image horizontally
|
||||
fx = 0.0
|
||||
steps = int(width * 2 / height)
|
||||
steps = int(width * 2 / height) + 1
|
||||
increment = (width - height) / (steps - 1)
|
||||
while fx <= (width - height):
|
||||
x = int(fx)
|
||||
image_chunks.append(image.crop((x, 0, x + height - 1, height - 1)))
|
||||
image_chunks.append(image.crop((x, 0, x + height, height)))
|
||||
x_offsets.append(x)
|
||||
y_offsets.append(0)
|
||||
fx += (width - height) / steps
|
||||
fx += increment
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||
elif height > width:
|
||||
# Portrait - slice the image vertically
|
||||
fy = 0.0
|
||||
steps = int(height * 2 / width)
|
||||
steps = int(height * 2 / width) + 1
|
||||
increment = (height - width) / (steps - 1)
|
||||
while fy <= (height - width):
|
||||
y = int(fy)
|
||||
image_chunks.append(image.crop((0, y, width - 1, y + width - 1)))
|
||||
image_chunks.append(image.crop((0, y, width, y + width)))
|
||||
x_offsets.append(0)
|
||||
y_offsets.append(y)
|
||||
fy += (height - width) / steps
|
||||
fy += increment
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||
|
||||
for idx in range(len(image_chunks)):
|
||||
@ -404,7 +436,7 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.1")
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.2")
|
||||
class FaceOffInvocation(BaseInvocation):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
@ -498,7 +530,7 @@ class FaceOffInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.1")
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.2")
|
||||
class FaceMaskInvocation(BaseInvocation):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
@ -518,7 +550,7 @@ class FaceMaskInvocation(BaseInvocation):
|
||||
)
|
||||
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
|
||||
|
||||
@validator("face_ids")
|
||||
@field_validator("face_ids")
|
||||
def validate_comma_separated_ints(cls, v) -> str:
|
||||
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
|
||||
if comma_separated_ints_regex.match(v) is None:
|
||||
@ -616,7 +648,7 @@ class FaceMaskInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.1"
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.2"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
@ -9,10 +9,10 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@ -36,7 +36,13 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"blank_image",
|
||||
title="Blank Image",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BlankImageInvocation(BaseInvocation):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
|
||||
@ -65,7 +71,13 @@ class BlankImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_crop",
|
||||
title="Crop Image",
|
||||
tags=["image", "crop"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageCropInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
@ -98,7 +110,13 @@ class ImageCropInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
|
||||
@invocation(
|
||||
"img_paste",
|
||||
title="Paste Image",
|
||||
tags=["image", "paste"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ImagePasteInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
@ -151,7 +169,13 @@ class ImagePasteInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"tomask",
|
||||
title="Mask from Alpha",
|
||||
tags=["image", "mask"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
@ -182,7 +206,13 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_mul",
|
||||
title="Multiply Images",
|
||||
tags=["image", "multiply"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageMultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
@ -215,7 +245,13 @@ class ImageMultiplyInvocation(BaseInvocation):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_chan",
|
||||
title="Extract Image Channel",
|
||||
tags=["image", "channel"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageChannelInvocation(BaseInvocation):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
@ -247,7 +283,13 @@ class ImageChannelInvocation(BaseInvocation):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_conv",
|
||||
title="Convert Image Mode",
|
||||
tags=["image", "convert"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageConvertInvocation(BaseInvocation):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
@ -276,7 +318,13 @@ class ImageConvertInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_blur",
|
||||
title="Blur Image",
|
||||
tags=["image", "blur"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageBlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
|
||||
@ -330,7 +378,13 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_resize",
|
||||
title="Resize Image",
|
||||
tags=["image", "resize"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
@ -359,7 +413,7 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -370,7 +424,13 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_scale",
|
||||
title="Scale Image",
|
||||
tags=["image", "scale"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageScaleInvocation(BaseInvocation):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
@ -411,7 +471,13 @@ class ImageScaleInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_lerp",
|
||||
title="Lerp Image",
|
||||
tags=["image", "lerp"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageLerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
@ -444,7 +510,13 @@ class ImageLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_ilerp",
|
||||
title="Inverse Lerp Image",
|
||||
tags=["image", "ilerp"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageInverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
@ -456,7 +528,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
|
||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment]
|
||||
|
||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
@ -477,7 +549,13 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_nsfw",
|
||||
title="Blur NSFW Image",
|
||||
tags=["image", "nsfw"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
@ -505,7 +583,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -515,7 +593,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
def _get_caution_img(self) -> Image:
|
||||
def _get_caution_img(self) -> Image.Image:
|
||||
import invokeai.app.assets.images as image_assets
|
||||
|
||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||
@ -523,7 +601,11 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
|
||||
"img_watermark",
|
||||
title="Add Invisible Watermark",
|
||||
tags=["image", "watermark"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageWatermarkInvocation(BaseInvocation):
|
||||
"""Add an invisible watermark to an image"""
|
||||
@ -544,7 +626,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -555,7 +637,13 @@ class ImageWatermarkInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"mask_edge",
|
||||
title="Mask Edge",
|
||||
tags=["image", "mask", "inpaint"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskEdgeInvocation(BaseInvocation):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
@ -601,7 +689,11 @@ class MaskEdgeInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
|
||||
"mask_combine",
|
||||
title="Combine Masks",
|
||||
tags=["image", "mask", "multiply"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskCombineInvocation(BaseInvocation):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
@ -632,7 +724,13 @@ class MaskCombineInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"color_correct",
|
||||
title="Color Correct",
|
||||
tags=["image", "color"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ColorCorrectInvocation(BaseInvocation):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
@ -742,7 +840,13 @@ class ColorCorrectInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"img_hue_adjust",
|
||||
title="Adjust Image Hue",
|
||||
tags=["image", "hue"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
@ -980,7 +1084,7 @@ class SaveImageInvocation(BaseInvocation):
|
||||
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||
metadata: CoreMetadata = InputField(
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
@ -997,7 +1101,7 @@ class SaveImageInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
|
@ -7,12 +7,12 @@ import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
from builtins import float
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -25,11 +25,15 @@ class IPAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the IP-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
||||
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
|
@ -19,7 +19,7 @@ from diffusers.models.attention_processor import (
|
||||
)
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import validator
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
@ -34,6 +34,7 @@ from invokeai.app.invocations.primitives import (
|
||||
build_latents_output,
|
||||
)
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
@ -54,7 +55,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -84,12 +84,20 @@ class SchedulerOutput(BaseInvocationOutput):
|
||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||
|
||||
|
||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
|
||||
@invocation(
|
||||
"scheduler",
|
||||
title="Scheduler",
|
||||
tags=["scheduler"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SchedulerInvocation(BaseInvocation):
|
||||
"""Selects a scheduler."""
|
||||
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SchedulerOutput:
|
||||
@ -97,7 +105,11 @@ class SchedulerInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
|
||||
"create_denoise_mask",
|
||||
title="Create Denoise Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@ -106,7 +118,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)
|
||||
fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == "float32",
|
||||
description=FieldDescriptions.fp32,
|
||||
ui_order=4,
|
||||
)
|
||||
|
||||
def prep_mask_tensor(self, mask_image):
|
||||
if mask_image.mode != "L":
|
||||
@ -134,7 +150,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
@ -167,7 +183,7 @@ def get_scheduler(
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.dict(),
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
@ -209,34 +225,64 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||
)
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
||||
noise: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
ui_order=3,
|
||||
)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
ui_order=2,
|
||||
)
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
ui_order=5,
|
||||
)
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
||||
description=FieldDescriptions.ip_adapter,
|
||||
title="IP-Adapter",
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
ui_order=6,
|
||||
)
|
||||
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
|
||||
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
|
||||
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
|
||||
description=FieldDescriptions.t2i_adapter,
|
||||
title="T2I-Adapter",
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
ui_order=7,
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None, description=FieldDescriptions.latents, input=Input.Connection
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
input=Input.Connection,
|
||||
ui_order=8,
|
||||
)
|
||||
|
||||
@validator("cfg_scale")
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
@ -259,7 +305,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
node=self.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
@ -451,9 +497,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
input_image, image_encoder_model
|
||||
)
|
||||
(
|
||||
image_prompt_embeds,
|
||||
uncond_image_prompt_embeds,
|
||||
) = ip_adapter_model.get_image_embeds(input_image, image_encoder_model)
|
||||
conditioning_data.ip_adapter_conditioning.append(
|
||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||
)
|
||||
@ -628,7 +675,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
t2i_adapter_data = self.run_t2i_adapters(
|
||||
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
|
||||
context,
|
||||
self.t2i_adapter,
|
||||
latents.shape,
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
@ -641,7 +691,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}),
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
@ -649,7 +699,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
with (
|
||||
@ -701,7 +751,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
@ -729,7 +782,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
|
||||
"l2i",
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
@ -744,7 +801,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
metadata: CoreMetadata = InputField(
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
@ -755,7 +812,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
@ -817,7 +874,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -831,7 +888,13 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
|
||||
|
||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||
@invocation(
|
||||
"lresize",
|
||||
title="Resize Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
@ -877,7 +940,13 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||
@invocation(
|
||||
"lscale",
|
||||
title="Scale Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
@ -916,7 +985,11 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
|
||||
"i2l",
|
||||
title="Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@ -980,7 +1053,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
@ -1008,7 +1081,13 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
return vae.encode(image_tensor).latents
|
||||
|
||||
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
||||
@invocation(
|
||||
"lblend",
|
||||
title="Blend Latents",
|
||||
tags=["latents", "blend"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import validator
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
||||
|
||||
@ -72,7 +72,14 @@ class RandomIntInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||
|
||||
|
||||
@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
|
||||
@invocation(
|
||||
"rand_float",
|
||||
title="Random Float",
|
||||
tags=["math", "float", "random"],
|
||||
category="math",
|
||||
version="1.0.1",
|
||||
use_cache=False,
|
||||
)
|
||||
class RandomFloatInvocation(BaseInvocation):
|
||||
"""Outputs a single random float"""
|
||||
|
||||
@ -178,7 +185,7 @@ class IntegerMathInvocation(BaseInvocation):
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
@validator("b")
|
||||
@field_validator("b")
|
||||
def no_unrepresentable_results(cls, v, values):
|
||||
if values["operation"] == "DIV" and v == 0:
|
||||
raise ValueError("Cannot divide by zero")
|
||||
@ -252,7 +259,7 @@ class FloatMathInvocation(BaseInvocation):
|
||||
a: float = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: float = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
@validator("b")
|
||||
@field_validator("b")
|
||||
def no_unrepresentable_results(cls, v, values):
|
||||
if values["operation"] == "DIV" and v == 0:
|
||||
raise ValueError("Cannot divide by zero")
|
||||
|
@ -44,28 +44,31 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
|
||||
generation_mode: str = Field(
|
||||
generation_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
created_by: Optional[str] = Field(description="The name of the creator of the image")
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
height: int = Field(description="The height parameter")
|
||||
seed: int = Field(description="The seed used for noise generation")
|
||||
rand_device: str = Field(description="The device used for random number generation")
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
created_by: Optional[str] = Field(default=None, description="The name of the creator of the image")
|
||||
positive_prompt: Optional[str] = Field(default=None, description="The positive prompt parameter")
|
||||
negative_prompt: Optional[str] = Field(default=None, description="The negative prompt parameter")
|
||||
width: Optional[int] = Field(default=None, description="The width parameter")
|
||||
height: Optional[int] = Field(default=None, description="The height parameter")
|
||||
seed: Optional[int] = Field(default=None, description="The seed used for noise generation")
|
||||
rand_device: Optional[str] = Field(default=None, description="The device used for random number generation")
|
||||
cfg_scale: Optional[float] = Field(default=None, description="The classifier-free guidance scale parameter")
|
||||
steps: Optional[int] = Field(default=None, description="The number of steps used for inference")
|
||||
scheduler: Optional[str] = Field(default=None, description="The scheduler used for inference")
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
|
||||
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
model: Optional[MainModelField] = Field(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlField]] = Field(default=None, description="The ControlNets used for inference")
|
||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = Field(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
t2iAdapters: Optional[list[T2IAdapterField]] = Field(default=None, description="The IP Adapters used for inference")
|
||||
loras: Optional[list[LoRAMetadataField]] = Field(default=None, description="The LoRAs used for inference")
|
||||
vae: Optional[VAEModelField] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
@ -122,27 +125,34 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
generation_mode: str = InputField(
|
||||
generation_mode: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = InputField(description="The positive prompt parameter")
|
||||
negative_prompt: str = InputField(description="The negative prompt parameter")
|
||||
width: int = InputField(description="The width parameter")
|
||||
height: int = InputField(description="The height parameter")
|
||||
seed: int = InputField(description="The seed used for noise generation")
|
||||
rand_device: str = InputField(description="The device used for random number generation")
|
||||
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||
steps: int = InputField(description="The number of steps used for inference")
|
||||
scheduler: str = InputField(description="The scheduler used for inference")
|
||||
clip_skip: Optional[int] = Field(
|
||||
positive_prompt: Optional[str] = InputField(default=None, description="The positive prompt parameter")
|
||||
negative_prompt: Optional[str] = InputField(default=None, description="The negative prompt parameter")
|
||||
width: Optional[int] = InputField(default=None, description="The width parameter")
|
||||
height: Optional[int] = InputField(default=None, description="The height parameter")
|
||||
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
|
||||
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
|
||||
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
||||
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
||||
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
||||
clip_skip: Optional[int] = InputField(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
|
||||
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
|
||||
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlField]] = InputField(
|
||||
default=None, description="The ControlNets used for inference"
|
||||
)
|
||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
||||
strength: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
@ -156,6 +166,20 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# High resolution fix metadata.
|
||||
hrf_width: Optional[int] = InputField(
|
||||
default=None,
|
||||
description="The high resolution fix height and width multipler.",
|
||||
)
|
||||
hrf_height: Optional[int] = InputField(
|
||||
default=None,
|
||||
description="The high resolution fix height and width multipler.",
|
||||
)
|
||||
hrf_strength: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The high resolution fix img2img strength used in the upscale pass.",
|
||||
)
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Optional[str] = InputField(
|
||||
default=None,
|
||||
@ -199,4 +223,4 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
|
||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))
|
||||
|
@ -1,7 +1,7 @@
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.shared import FreeUConfig
|
||||
|
||||
@ -26,6 +26,8 @@ class ModelInfo(BaseModel):
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
@ -87,6 +89,8 @@ class MainModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
@ -94,8 +98,16 @@ class LoRAModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
@ -202,10 +214,16 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
default=None,
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
@ -266,20 +284,35 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
|
||||
|
||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
|
||||
@invocation(
|
||||
"sdxl_lora_loader",
|
||||
title="SDXL LoRA",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
default=None,
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 1",
|
||||
)
|
||||
clip2: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
@ -352,13 +385,18 @@ class VAEModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.VaeModel,
|
||||
title="VAE",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
@ -387,19 +425,31 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
class SeamlessModeOutput(BaseInvocationOutput):
|
||||
"""Modified Seamless Model output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
|
||||
@invocation(
|
||||
"seamless",
|
||||
title="Seamless",
|
||||
tags=["seamless", "model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SeamlessModeInvocation(BaseInvocation):
|
||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
default=None,
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
vae: Optional[VaeField] = InputField(
|
||||
default=None, description=FieldDescriptions.vae_model, input=Input.Connection, title="VAE"
|
||||
default=None,
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Connection,
|
||||
title="VAE",
|
||||
)
|
||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
from pydantic import validator
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
@ -65,7 +65,7 @@ Nodes
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
|
||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||
noise: LatentsField = OutputField(description=FieldDescriptions.noise)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
@ -78,7 +78,13 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||
)
|
||||
|
||||
|
||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
|
||||
@invocation(
|
||||
"noise",
|
||||
title="Noise",
|
||||
tags=["latents", "noise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
@ -105,7 +111,7 @@ class NoiseInvocation(BaseInvocation):
|
||||
description="Use CPU for noise generation (for reproducible results across platforms)",
|
||||
)
|
||||
|
||||
@validator("seed", pre=True)
|
||||
@field_validator("seed", mode="before")
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
return v % (SEED_MAX + 1)
|
||||
|
@ -9,18 +9,18 @@ from typing import List, Literal, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -63,14 +63,17 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.clip.loras
|
||||
]
|
||||
|
||||
@ -175,14 +178,14 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
)
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
||||
@validator("cfg_scale")
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
@ -241,7 +244,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
node=self.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
@ -254,12 +257,15 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
||||
|
||||
with unet_info as unet: # , ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.unet.loras
|
||||
]
|
||||
|
||||
@ -346,7 +352,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
**self.vae.vae.model_dump(),
|
||||
)
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
@ -375,7 +381,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -403,6 +409,8 @@ class OnnxModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
|
@ -44,13 +44,22 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
|
||||
@invocation(
|
||||
"float_range",
|
||||
title="Float Range",
|
||||
tags=["math", "range"],
|
||||
category="math",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
start: float = InputField(default=5, description="The first value of the range")
|
||||
stop: float = InputField(default=10, description="The last value of the range")
|
||||
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
steps: int = InputField(
|
||||
default=30,
|
||||
description="number of values to interpolate over (including start and stop)",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
@ -95,7 +104,13 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
|
||||
@invocation(
|
||||
"step_param_easing",
|
||||
title="Step Param Easing",
|
||||
tags=["step", "easing"],
|
||||
category="step",
|
||||
version="1.0.0",
|
||||
)
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
@ -159,7 +174,9 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||
easing_function = easing_class(
|
||||
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
|
||||
start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1,
|
||||
)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
@ -199,7 +216,11 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
|
||||
easing_function = easing_class(
|
||||
start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1,
|
||||
)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||
from pydantic import validator
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
|
||||
@ -21,7 +21,10 @@ from .baseinvocation import BaseInvocation, InputField, InvocationContext, UICom
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
||||
prompt: str = InputField(
|
||||
description="The prompt to parse with dynamicprompts",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
@ -36,21 +39,31 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
|
||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
|
||||
@invocation(
|
||||
"prompt_from_file",
|
||||
title="Prompts from File",
|
||||
tags=["prompt", "file"],
|
||||
category="prompt",
|
||||
version="1.0.0",
|
||||
)
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
"""Loads prompts from a text file"""
|
||||
|
||||
file_path: str = InputField(description="Path to prompt text file")
|
||||
pre_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
||||
default=None,
|
||||
description="String to prepend to each prompt",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
post_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
|
||||
default=None,
|
||||
description="String to append to each prompt",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
||||
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
|
||||
@validator("file_path")
|
||||
@field_validator("file_path")
|
||||
def file_path_exists(cls, v):
|
||||
if not exists(v):
|
||||
raise ValueError(FileNotFoundError)
|
||||
@ -79,6 +92,10 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
prompts = self.promptsFromFile(
|
||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
||||
self.file_path,
|
||||
self.pre_prompt,
|
||||
self.post_prompt,
|
||||
self.start_line,
|
||||
self.max_prompts,
|
||||
)
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -23,6 +23,8 @@ class T2IAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the T2I-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
|
@ -7,10 +7,11 @@ import numpy as np
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from pydantic import ConfigDict
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
@ -38,6 +39,8 @@ class ESRGANInvocation(BaseInvocation):
|
||||
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
models_path = context.services.configuration.models_path
|
||||
|
@ -1,4 +0,0 @@
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
|
||||
pass
|
@ -1,71 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
- INTERNAL: The resource was created by the application.
|
||||
- EXTERNAL: The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
INTERNAL = "internal"
|
||||
"""The resource was created by the application."""
|
||||
EXTERNAL = "external"
|
||||
"""The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
|
||||
class InvalidOriginException(ValueError):
|
||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid resource origin."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image.
|
||||
|
||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||
- MASK: The image is a mask image.
|
||||
- CONTROL: The image is a ControlNet control image.
|
||||
- USER: The image is a user-provide image.
|
||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||
"""
|
||||
|
||||
GENERAL = "general"
|
||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||
MASK = "mask"
|
||||
"""MASK: The image is a mask image."""
|
||||
CONTROL = "control"
|
||||
"""CONTROL: The image is a ControlNet control image."""
|
||||
USER = "user"
|
||||
"""USER: The image is a user-provide image."""
|
||||
OTHER = "other"
|
||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||
|
||||
|
||||
class InvalidImageCategoryException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageCategory.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
@ -0,0 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BoardImageRecordStorageBase(ABC):
|
||||
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
"""Gets all board images for a board, as a list of the image names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_count_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> int:
|
||||
"""Gets the number of images for a board."""
|
||||
pass
|
@ -1,69 +1,24 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, cast
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
|
||||
|
||||
class BoardImageRecordStorageBase(ABC):
|
||||
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
"""Gets all board images for a board, as a list of the image names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_count_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> int:
|
||||
"""Gets the number of images for a board."""
|
||||
pass
|
||||
from .board_image_records_base import BoardImageRecordStorageBase
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
@ -1,112 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardRecordStorageBase
|
||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardImagesServiceABC(ABC):
|
||||
"""High-level service for board-image relationship management."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
"""Gets all board images for a board, as a list of the image names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardImagesServiceDependencies:
|
||||
"""Service dependencies for the BoardImagesService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardImagesService(BoardImagesServiceABC):
|
||||
_services: BoardImagesServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardImagesServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.remove_image_from_board(image_name)
|
||||
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
||||
return board_id
|
||||
|
||||
|
||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.dict(exclude={"cover_image_name"}),
|
||||
cover_image_name=cover_image_name,
|
||||
image_count=image_count,
|
||||
)
|
0
invokeai/app/services/board_images/__init__.py
Normal file
0
invokeai/app/services/board_images/__init__.py
Normal file
39
invokeai/app/services/board_images/board_images_base.py
Normal file
39
invokeai/app/services/board_images/board_images_base.py
Normal file
@ -0,0 +1,39 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BoardImagesServiceABC(ABC):
|
||||
"""High-level service for board-image relationship management."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
"""Gets all board images for a board, as a list of the image names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
38
invokeai/app/services/board_images/board_images_default.py
Normal file
38
invokeai/app/services/board_images/board_images_default.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .board_images_base import BoardImagesServiceABC
|
||||
|
||||
|
||||
class BoardImagesService(BoardImagesServiceABC):
|
||||
__invoker: Invoker
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self.__invoker.services.board_image_records.remove_image_from_board(image_name)
|
||||
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
|
||||
return board_id
|
55
invokeai/app/services/board_records/board_records_base.py
Normal file
55
invokeai/app/services/board_records/board_records_base.py
Normal file
@ -0,0 +1,55 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from .board_records_common import BoardChanges, BoardRecord
|
||||
|
||||
|
||||
class BoardRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the board record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, board_id: str) -> None:
|
||||
"""Deletes a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
"""Saves a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
"""Gets a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
"""Updates a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
"""Gets many board records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
"""Gets all board records."""
|
||||
pass
|
@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
@ -18,21 +18,12 @@ class BoardRecord(BaseModelExcludeNull):
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
|
||||
deleted_at: Optional[Union[datetime, str]] = Field(default=None, description="The deleted timestamp of the board.")
|
||||
"""The updated timestamp of the image."""
|
||||
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
|
||||
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
|
||||
"""The name of the cover image of the board."""
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
||||
|
||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
"""Deserializes a board record."""
|
||||
|
||||
@ -53,3 +44,29 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
)
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra="forbid"):
|
||||
board_name: Optional[str] = Field(default=None, description="The board's new name.")
|
||||
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
"""Raised when an board record is not found."""
|
||||
|
||||
def __init__(self, message="Board record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
|
||||
"""Raised when an board record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Board record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
|
||||
"""Raised when an board record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Board record not deleted"):
|
||||
super().__init__(message)
|
@ -1,103 +1,32 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Union, cast
|
||||
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
board_name: Optional[str] = Field(description="The board's new name.")
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
"""Raised when an board record is not found."""
|
||||
|
||||
def __init__(self, message="Board record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
|
||||
"""Raised when an board record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Board record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
|
||||
"""Raised when an board record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Board record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the board record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, board_id: str) -> None:
|
||||
"""Deletes a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
"""Saves a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
"""Gets a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
"""Updates a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
"""Gets many board records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
"""Gets all board records."""
|
||||
pass
|
||||
from .board_records_base import BoardRecordStorageBase
|
||||
from .board_records_common import (
|
||||
BoardChanges,
|
||||
BoardRecord,
|
||||
BoardRecordDeleteException,
|
||||
BoardRecordNotFoundException,
|
||||
BoardRecordSaveException,
|
||||
deserialize_board_record,
|
||||
)
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
@ -1,158 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_images import board_record_to_dto
|
||||
from invokeai.app.services.board_record_storage import BoardChanges, BoardRecordStorageBase
|
||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase, OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardServiceABC(ABC):
|
||||
"""High-level service for board management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
"""Creates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardDTO:
|
||||
"""Gets a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
"""Updates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> None:
|
||||
"""Deletes a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
"""Gets many boards."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardDTO]:
|
||||
"""Gets all boards."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardServiceDependencies:
|
||||
"""Service dependencies for the BoardService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardService(BoardServiceABC):
|
||||
_services: BoardServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.save(board_name)
|
||||
return board_record_to_dto(board_record, None, 0)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self._services.board_records.get(board_id)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.update(board_id, changes)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self._services.board_records.delete(board_id)
|
||||
|
||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self._services.board_records.get_many(offset, limit)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
board_records = self._services.board_records.get_all()
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return board_dtos
|
0
invokeai/app/services/boards/__init__.py
Normal file
0
invokeai/app/services/boards/__init__.py
Normal file
59
invokeai/app/services/boards/boards_base.py
Normal file
59
invokeai/app/services/boards/boards_base.py
Normal file
@ -0,0 +1,59 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from .boards_common import BoardDTO
|
||||
|
||||
|
||||
class BoardServiceABC(ABC):
|
||||
"""High-level service for board management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
"""Creates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardDTO:
|
||||
"""Gets a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
"""Updates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> None:
|
||||
"""Deletes a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
"""Gets many boards."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardDTO]:
|
||||
"""Gets all boards."""
|
||||
pass
|
23
invokeai/app/services/boards/boards_common.py
Normal file
23
invokeai/app/services/boards/boards_common.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..board_records.board_records_common import BoardRecord
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
||||
|
||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.model_dump(exclude={"cover_image_name"}),
|
||||
cover_image_name=cover_image_name,
|
||||
image_count=image_count,
|
||||
)
|
79
invokeai/app/services/boards/boards_default.py
Normal file
79
invokeai/app/services/boards/boards_default.py
Normal file
@ -0,0 +1,79 @@
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from .boards_base import BoardServiceABC
|
||||
from .boards_common import board_record_to_dto
|
||||
|
||||
|
||||
class BoardService(BoardServiceABC):
|
||||
__invoker: Invoker
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
board_record = self.__invoker.services.board_records.save(board_name)
|
||||
return board_record_to_dto(board_record, None, 0)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self.__invoker.services.board_records.get(board_id)
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self.__invoker.services.board_records.update(board_id, changes)
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self.__invoker.services.board_records.delete(board_id)
|
||||
|
||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self.__invoker.services.board_records.get_many(offset, limit)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
board_records = self.__invoker.services.board_records.get_all()
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return board_dtos
|
@ -2,5 +2,5 @@
|
||||
Init file for InvokeAI configure package
|
||||
"""
|
||||
|
||||
from .base import PagingArgumentParser # noqa F401
|
||||
from .invokeai_config import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
||||
from .config_base import PagingArgumentParser # noqa F401
|
||||
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
||||
|
@ -12,25 +12,15 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pydoc
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
@ -42,12 +32,14 @@ class InvokeAISettings(BaseSettings):
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||
|
||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
if len(unknown_opts) > 0:
|
||||
print("Unknown args:", unknown_opts)
|
||||
for name in self.__fields__:
|
||||
for name in self.model_fields:
|
||||
if name not in self._excluded():
|
||||
value = getattr(opt, name)
|
||||
if isinstance(value, ListConfig):
|
||||
@ -64,10 +56,12 @@ class InvokeAISettings(BaseSettings):
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
for name, field in self.__fields__.items():
|
||||
for name, field in self.model_fields.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
category = (
|
||||
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
|
||||
)
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
@ -83,7 +77,7 @@ class InvokeAISettings(BaseSettings):
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = getattr(cls.Config, "env_prefix", None)
|
||||
env_prefix = getattr(cls.model_config, "env_prefix", None)
|
||||
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
@ -99,14 +93,18 @@ class InvokeAISettings(BaseSettings):
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
fields = cls.model_fields
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category", "Uncategorized")
|
||||
category = (
|
||||
field.json_schema_extra.get("category", "Uncategorized")
|
||||
if field.json_schema_extra
|
||||
else "Uncategorized"
|
||||
)
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
@ -156,11 +154,6 @@ class InvokeAISettings(BaseSettings):
|
||||
"tiled_decode",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file_encoding = "utf-8"
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
@ -171,7 +164,7 @@ class InvokeAISettings(BaseSettings):
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category := (field.json_schema_extra.get("category", None) if field.json_schema_extra else None):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
@ -179,7 +172,7 @@ class InvokeAISettings(BaseSettings):
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_values = get_args(field.annotation)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
@ -192,7 +185,7 @@ class InvokeAISettings(BaseSettings):
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
help=field.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == Union:
|
||||
@ -201,7 +194,7 @@ class InvokeAISettings(BaseSettings):
|
||||
dest=name,
|
||||
type=int_or_float_or_str,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
help=field.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
@ -209,32 +202,17 @@ class InvokeAISettings(BaseSettings):
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*",
|
||||
type=field.type_,
|
||||
type=field.annotation,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||
help=field.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
type=field.annotation,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||
help=field.description,
|
||||
)
|
||||
|
||||
|
||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||
"""
|
||||
Workaround for argparse type checking.
|
||||
"""
|
||||
try:
|
||||
return int(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
return str(value)
|
41
invokeai/app/services/config/config_common.py
Normal file
41
invokeai/app/services/config/config_common.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Base class for the InvokeAI configuration system.
|
||||
It defines a type of pydantic BaseSettings object that
|
||||
is able to read and write from an omegaconf-based config file,
|
||||
with overriding of settings from environment variables and/or
|
||||
the command line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pydoc
|
||||
from typing import Union
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||
"""
|
||||
Workaround for argparse type checking.
|
||||
"""
|
||||
try:
|
||||
return int(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
return str(value)
|
@ -144,8 +144,8 @@ which is set to the desired top-level name. For example, to create a
|
||||
|
||||
class InvokeBatch(InvokeAISettings):
|
||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
||||
node_count : int = Field(default=1, description="Number of nodes to run on", json_schema_extra=dict(category='Resources'))
|
||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", json_schema_extra=dict(category='Resources'))
|
||||
|
||||
This will now read and write from the "InvokeBatch" section of the
|
||||
config file, look for environment variables named INVOKEBATCH_*, and
|
||||
@ -175,9 +175,10 @@ from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field, parse_obj_as
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from .base import InvokeAISettings
|
||||
from .config_base import InvokeAISettings
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
@ -185,6 +186,21 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class Categories(object):
|
||||
WebServer = dict(category="Web Server")
|
||||
Features = dict(category="Features")
|
||||
Paths = dict(category="Paths")
|
||||
Logging = dict(category="Logging")
|
||||
Development = dict(category="Development")
|
||||
Other = dict(category="Other")
|
||||
ModelCache = dict(category="Model Cache")
|
||||
Device = dict(category="Device")
|
||||
Generation = dict(category="Generation")
|
||||
Queue = dict(category="Queue")
|
||||
Nodes = dict(category="Nodes")
|
||||
MemoryPerformance = dict(category="Memory/Performance")
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
@ -201,86 +217,88 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
|
||||
# WEB
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
|
||||
port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
|
||||
# FEATURES
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
|
||||
|
||||
# PATHS
|
||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||
autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||
|
||||
# LOGGING
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
|
||||
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
# CACHE
|
||||
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||
|
||||
# DEVICE
|
||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
|
||||
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
||||
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
||||
|
||||
# GENERATION
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
|
||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
|
||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||
|
||||
# QUEUE
|
||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
|
||||
|
||||
# NODES
|
||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
|
||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
|
||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
env_prefix = "INVOKEAI"
|
||||
model_config = SettingsConfigDict(validate_assignment=True, env_prefix="INVOKEAI")
|
||||
|
||||
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
|
||||
def parse_args(
|
||||
self,
|
||||
argv: Optional[list[str]] = None,
|
||||
conf: Optional[DictConfig] = None,
|
||||
clobber=False,
|
||||
):
|
||||
"""
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
@ -308,7 +326,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
if self.singleton_init and not clobber:
|
||||
hints = get_type_hints(self.__class__)
|
||||
for k in self.singleton_init:
|
||||
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
|
||||
setattr(
|
||||
self,
|
||||
k,
|
||||
TypeAdapter(hints[k]).validate_python(self.singleton_init[k]),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
0
invokeai/app/services/events/__init__.py
Normal file
0
invokeai/app/services/events/__init__.py
Normal file
@ -2,8 +2,7 @@
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
@ -11,6 +10,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
@ -54,7 +55,7 @@ class EventServiceBase:
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node_id=node.get("id"),
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
progress_image=progress_image.model_dump() if progress_image is not None else None,
|
||||
step=step,
|
||||
order=order,
|
||||
total_steps=total_steps,
|
||||
@ -290,8 +291,8 @@ class EventServiceBase:
|
||||
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
),
|
||||
batch_status=batch_status.dict(),
|
||||
queue_status=queue_status.dict(),
|
||||
batch_status=batch_status.model_dump(),
|
||||
queue_status=queue_status.model_dump(),
|
||||
),
|
||||
)
|
||||
|
0
invokeai/app/services/image_files/__init__.py
Normal file
0
invokeai/app/services/image_files/__init__.py
Normal file
43
invokeai/app/services/image_files/image_files_base.py
Normal file
43
invokeai/app/services/image_files/image_files_base.py
Normal file
@ -0,0 +1,43 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||
# 500 internal server error. I don't like having this method on the service.
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
20
invokeai/app/services/image_files/image_files_common.py
Normal file
20
invokeai/app/services/image_files/image_files_common.py
Normal file
@ -0,0 +1,20 @@
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
@ -1,6 +1,5 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
@ -9,68 +8,11 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||
# 500 internal server error. I don't like having this method on the service.
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
from .image_files_base import ImageFileStorageBase
|
||||
from .image_files_common import ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException
|
||||
|
||||
|
||||
class DiskImageFileStorage(ImageFileStorageBase):
|
||||
@ -80,7 +22,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[Path, PILImageType]
|
||||
__max_cache_size: int
|
||||
__compress_level: int
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__cache = dict()
|
||||
@ -89,10 +31,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||
self.__compress_level = InvokeAIAppConfig.get_config().png_compress_level
|
||||
# Validate required output folders at launch
|
||||
self.__validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_name)
|
||||
@ -136,7 +80,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
if original_workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo, compress_level=self.__compress_level)
|
||||
image.save(
|
||||
image_path,
|
||||
"PNG",
|
||||
pnginfo=pnginfo,
|
||||
compress_level=self.__invoker.services.configuration.png_compress_level,
|
||||
)
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
0
invokeai/app/services/image_records/__init__.py
Normal file
0
invokeai/app/services/image_records/__init__.py
Normal file
84
invokeai/app/services/image_records/image_records_base.py
Normal file
84
invokeai/app/services/image_records/image_records_base.py
Normal file
@ -0,0 +1,84 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||
"""Gets an image's metadata'."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
"""Deletes many image records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
"""Deletes all intermediate image records, returning a list of deleted image names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
@ -1,13 +1,117 @@
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Extra, Field, StrictBool, StrictStr
|
||||
from pydantic import Field, StrictBool, StrictStr
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
- INTERNAL: The resource was created by the application.
|
||||
- EXTERNAL: The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
INTERNAL = "internal"
|
||||
"""The resource was created by the application."""
|
||||
EXTERNAL = "external"
|
||||
"""The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
|
||||
class InvalidOriginException(ValueError):
|
||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid resource origin."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image.
|
||||
|
||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||
- MASK: The image is a mask image.
|
||||
- CONTROL: The image is a ControlNet control image.
|
||||
- USER: The image is a user-provide image.
|
||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||
"""
|
||||
|
||||
GENERAL = "general"
|
||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||
MASK = "mask"
|
||||
"""MASK: The image is a mask image."""
|
||||
CONTROL = "control"
|
||||
"""CONTROL: The image is a ControlNet control image."""
|
||||
USER = "user"
|
||||
"""USER: The image is a user-provide image."""
|
||||
OTHER = "other"
|
||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||
|
||||
|
||||
class InvalidImageCategoryException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageCategory.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
IMAGE_DTO_COLS = ", ".join(
|
||||
list(
|
||||
map(
|
||||
lambda c: "images." + c,
|
||||
[
|
||||
"image_name",
|
||||
"image_origin",
|
||||
"image_category",
|
||||
"width",
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ImageRecord(BaseModelExcludeNull):
|
||||
"""Deserialized image record without metadata."""
|
||||
|
||||
@ -25,7 +129,9 @@ class ImageRecord(BaseModelExcludeNull):
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
|
||||
deleted_at: Optional[Union[datetime.datetime, str]] = Field(
|
||||
default=None, description="The deleted timestamp of the image."
|
||||
)
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
@ -43,7 +149,7 @@ class ImageRecord(BaseModelExcludeNull):
|
||||
"""Whether this image is starred."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||
class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
|
||||
"""A set of changes to apply to an image record.
|
||||
|
||||
Only limited changes are valid:
|
||||
@ -66,41 +172,6 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||
"""The image's new `starred` state."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
"""The URL of the image's thumbnail."""
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord,
|
||||
image_url: str,
|
||||
thumbnail_url: str,
|
||||
board_id: Optional[str],
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
"""Deserializes an image record."""
|
||||
|
@ -1,164 +1,36 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
offset: int = Field(description="Offset from which to retrieve items")
|
||||
limit: int = Field(description="Limit of items to get")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
IMAGE_DTO_COLS = ", ".join(
|
||||
list(
|
||||
map(
|
||||
lambda c: "images." + c,
|
||||
[
|
||||
"image_name",
|
||||
"image_origin",
|
||||
"image_category",
|
||||
"width",
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
],
|
||||
)
|
||||
)
|
||||
from .image_records_base import ImageRecordStorageBase
|
||||
from .image_records_common import (
|
||||
IMAGE_DTO_COLS,
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ResourceOrigin,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||
"""Gets an image's metadata'."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
"""Deletes many image records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
"""Deletes all intermediate image records, returning a list of deleted image names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -245,7 +117,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
def get(self, image_name: str) -> Optional[ImageRecord]:
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
@ -351,8 +223,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
@ -377,7 +249,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
@ -515,13 +387,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
width: int,
|
||||
height: int,
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
@ -1,449 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
InvalidImageCategoryException,
|
||||
InvalidOriginException,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_file_storage import (
|
||||
ImageFileDeleteException,
|
||||
ImageFileNotFoundException,
|
||||
ImageFileSaveException,
|
||||
ImageFileStorageBase,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
|
||||
from invokeai.app.services.resource_name import NameServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""High-level service for image management."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||
"""Register a callback for when an image is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an image is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: ImageDTO) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
"""Updates an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||
"""Gets an image's metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_intermediates(self) -> int:
|
||||
"""Deletes all intermediate images."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
"""Deletes all images on a board."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
image_records: ImageRecordStorageBase
|
||||
image_files: ImageFileStorageBase
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
names: NameServiceBase
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.image_records = image_record_storage
|
||||
self.image_files = image_file_storage
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
self.names = names
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(self, services: ImageServiceDependencies):
|
||||
super().__init__()
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
|
||||
image_name = self._services.names.create_image_name()
|
||||
|
||||
# TODO: Do we want to store the graph in the image at all? I don't think so...
|
||||
# graph = None
|
||||
# if session_id is not None:
|
||||
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
||||
# if session_raw is not None:
|
||||
# try:
|
||||
# graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
# except Exception as e:
|
||||
# self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
# graph = None
|
||||
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
self._services.image_records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
session_id=session_id,
|
||||
)
|
||||
if board_id is not None:
|
||||
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileSaveException:
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
|
||||
raise e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.image_records.update(image_name, changes)
|
||||
image_dto = self.get_dto(image_name)
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.image_files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.image_records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_name),
|
||||
self._services.urls.get_image_url(image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||
try:
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
metadata = self._services.image_records.get_metadata(image_name)
|
||||
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata(metadata=metadata)
|
||||
|
||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||
graph = None
|
||||
|
||||
if session_raw:
|
||||
try:
|
||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
except Exception as e:
|
||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
|
||||
return ImageMetadata(graph=graph, metadata=metadata)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.image_files.get_path(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self._services.image_files.validate_path(path)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.image_records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(r.image_name),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
)
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image record")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error("Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
try:
|
||||
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete_many(image_names)
|
||||
for image_name in image_names:
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error("Failed to delete image files")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image records and files")
|
||||
raise e
|
||||
|
||||
def delete_intermediates(self) -> int:
|
||||
try:
|
||||
image_names = self._services.image_records.delete_intermediates()
|
||||
count = len(image_names)
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
return count
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error("Failed to delete image files")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image records and files")
|
||||
raise e
|
0
invokeai/app/services/images/__init__.py
Normal file
0
invokeai/app/services/images/__init__.py
Normal file
129
invokeai/app/services/images/images_base.py
Normal file
129
invokeai/app/services/images/images_base.py
Normal file
@ -0,0 +1,129 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""High-level service for image management."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||
"""Register a callback for when an image is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an image is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: ImageDTO) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
"""Updates an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||
"""Gets an image's metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_intermediates(self) -> int:
|
||||
"""Deletes all intermediate images."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
"""Deletes all images on a board."""
|
||||
pass
|
43
invokeai/app/services/images/images_common.py
Normal file
43
invokeai/app/services/images/images_common.py
Normal file
@ -0,0 +1,43 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecord
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
"""The URL of the image's thumbnail."""
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Optional[str] = Field(
|
||||
default=None, description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord,
|
||||
image_url: str,
|
||||
thumbnail_url: str,
|
||||
board_id: Optional[str],
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.model_dump(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
286
invokeai/app/services/images/images_default.py
Normal file
286
invokeai/app/services/images/images_default.py
Normal file
@ -0,0 +1,286 @@
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||
|
||||
from ..image_files.image_files_common import (
|
||||
ImageFileDeleteException,
|
||||
ImageFileNotFoundException,
|
||||
ImageFileSaveException,
|
||||
)
|
||||
from ..image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
InvalidImageCategoryException,
|
||||
InvalidOriginException,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from .images_base import ImageServiceABC
|
||||
from .images_common import ImageDTO, image_record_to_dto
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
__invoker: Invoker
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
|
||||
image_name = self.__invoker.services.names.create_image_name()
|
||||
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
self.__invoker.services.image_records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
session_id=session_id,
|
||||
)
|
||||
if board_id is not None:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
self.__invoker.services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||
)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self.__invoker.services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileSaveException:
|
||||
self.__invoker.services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}")
|
||||
raise e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self.__invoker.services.image_records.update(image_name, changes)
|
||||
image_dto = self.get_dto(image_name)
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self.__invoker.services.logger.error("Failed to update image record")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self.__invoker.services.image_files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self.__invoker.services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self.__invoker.services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self.__invoker.services.image_records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self.__invoker.services.urls.get_image_url(image_name),
|
||||
self.__invoker.services.urls.get_image_url(image_name, True),
|
||||
self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordNotFoundException:
|
||||
self.__invoker.services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||
try:
|
||||
image_record = self.__invoker.services.image_records.get(image_name)
|
||||
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
||||
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata(metadata=metadata)
|
||||
|
||||
session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
|
||||
graph = None
|
||||
|
||||
if session_raw:
|
||||
try:
|
||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
|
||||
return ImageMetadata(graph=graph, metadata=metadata)
|
||||
except ImageRecordNotFoundException:
|
||||
self.__invoker.services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self.__invoker.services.image_files.validate_path(path)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self.__invoker.services.urls.get_image_url(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self.__invoker.services.image_records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
)
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self.__invoker.services.image_files.delete(image_name)
|
||||
self.__invoker.services.image_records.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self.__invoker.services.logger.error("Failed to delete image record")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self.__invoker.services.logger.error("Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
try:
|
||||
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
for image_name in image_names:
|
||||
self.__invoker.services.image_files.delete(image_name)
|
||||
self.__invoker.services.image_records.delete_many(image_names)
|
||||
for image_name in image_names:
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self.__invoker.services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self.__invoker.services.logger.error("Failed to delete image files")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem deleting image records and files")
|
||||
raise e
|
||||
|
||||
def delete_intermediates(self) -> int:
|
||||
try:
|
||||
image_names = self.__invoker.services.image_records.delete_intermediates()
|
||||
count = len(image_names)
|
||||
for image_name in image_names:
|
||||
self.__invoker.services.image_files.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
return count
|
||||
except ImageRecordDeleteException:
|
||||
self.__invoker.services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self.__invoker.services.logger.error("Failed to delete image files")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem deleting image records and files")
|
||||
raise e
|
@ -58,7 +58,12 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
# If the cache is full, we need to remove the least used
|
||||
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
|
||||
self._cache[key] = CachedItem(
|
||||
invocation_output,
|
||||
invocation_output.model_dump_json(
|
||||
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||
),
|
||||
)
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
number_to_delete = min(number_to_delete, len(self._cache))
|
||||
@ -85,7 +90,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
|
||||
@staticmethod
|
||||
def create_key(invocation: BaseInvocation) -> int:
|
||||
return hash(invocation.json(exclude={"id"}))
|
||||
return hash(invocation.model_dump_json(exclude={"id"}, warnings=False))
|
||||
|
||||
def disable(self) -> None:
|
||||
with self._lock:
|
||||
|
@ -0,0 +1,5 @@
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC):
|
||||
pass
|
@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
|
||||
pass
|
@ -4,12 +4,12 @@ from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ..models.exceptions import CanceledException
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invocation_stats import InvocationStatsServiceBase
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..invoker import Invoker
|
||||
from .invocation_processor_base import InvocationProcessorABC
|
||||
from .invocation_processor_common import CanceledException
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
@ -37,7 +37,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||
queue_item: Optional[InvocationQueueItem] = None
|
||||
|
||||
while not stop_event.is_set():
|
||||
@ -90,15 +89,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
# Invoke
|
||||
try:
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
@ -129,17 +127,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
statistics.log_stats()
|
||||
self.__invoker.services.performance_statistics.log_stats()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
statistics.reset_stats(graph_execution_state.id)
|
||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
@ -159,12 +157,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
statistics.reset_stats(graph_execution_state.id)
|
||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
@ -189,7 +187,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
0
invokeai/app/services/invocation_queue/__init__.py
Normal file
0
invokeai/app/services/invocation_queue/__init__.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from .invocation_queue_common import InvocationQueueItem
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
"""Abstract base class for all invocation queues"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self) -> InvocationQueueItem:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||
pass
|
@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||
session_queue_item_id: int = Field(
|
||||
description="The ID of session queue item from which this invocation queue item came"
|
||||
)
|
||||
session_queue_batch_id: str = Field(
|
||||
description="The ID of the session batch from which this invocation queue item came"
|
||||
)
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
@ -1,45 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||
session_queue_item_id: int = Field(
|
||||
description="The ID of session queue item from which this invocation queue item came"
|
||||
)
|
||||
session_queue_batch_id: str = Field(
|
||||
description="The ID of the session batch from which this invocation queue item came"
|
||||
)
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
"""Abstract base class for all invocation queues"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self) -> InvocationQueueItem:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||
pass
|
||||
from .invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_queue_common import InvocationQueueItem
|
||||
|
||||
|
||||
class MemoryInvocationQueue(InvocationQueueABC):
|
@ -6,21 +6,27 @@ from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||
from .board_images.board_images_base import BoardImagesServiceABC
|
||||
from .board_records.board_records_base import BoardRecordStorageBase
|
||||
from .boards.boards_base import BoardServiceABC
|
||||
from .config import InvokeAIAppConfig
|
||||
from .events.events_base import EventServiceBase
|
||||
from .image_files.image_files_base import ImageFileStorageBase
|
||||
from .image_records.image_records_base import ImageRecordStorageBase
|
||||
from .images.images_base import ImageServiceABC
|
||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
|
||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||
from .urls.urls_base import UrlServiceBase
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
@ -28,12 +34,16 @@ class InvocationServices:
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
board_images: "BoardImagesServiceABC"
|
||||
board_image_record_storage: "BoardImageRecordStorageBase"
|
||||
boards: "BoardServiceABC"
|
||||
board_records: "BoardRecordStorageBase"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||
images: "ImageServiceABC"
|
||||
image_records: "ImageRecordStorageBase"
|
||||
image_files: "ImageFileStorageBase"
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
@ -43,16 +53,22 @@ class InvocationServices:
|
||||
session_queue: "SessionQueueBase"
|
||||
session_processor: "SessionProcessorBase"
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
names: "NameServiceBase"
|
||||
urls: "UrlServiceBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_images: "BoardImagesServiceABC",
|
||||
board_image_records: "BoardImageRecordStorageBase",
|
||||
boards: "BoardServiceABC",
|
||||
board_records: "BoardRecordStorageBase",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||
images: "ImageServiceABC",
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
@ -62,14 +78,20 @@ class InvocationServices:
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
self.boards = boards
|
||||
self.board_records = board_records
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.graph_library = graph_library
|
||||
self.images = images
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
@ -79,3 +101,5 @@ class InvocationServices:
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.urls = urls
|
||||
|
0
invokeai/app/services/invocation_stats/__init__.py
Normal file
0
invokeai/app/services/invocation_stats/__init__.py
Normal file
121
invokeai/app/services/invocation_stats/invocation_stats_base.py
Normal file
121
invokeai/app/services/invocation_stats/invocation_stats_base.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
||||
|
||||
Usage:
|
||||
|
||||
statistics = InvocationStatsService(graph_execution_manager)
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
... execute graphs...
|
||||
statistics.log_stats()
|
||||
|
||||
Typical output:
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||
|
||||
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Dict
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
from .invocation_stats_common import NodeLog
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
# {graph_id => NodeLog}
|
||||
_stats: Dict[str, NodeLog]
|
||||
_cache_stats: Dict[str, CacheStats]
|
||||
ram_used: float
|
||||
ram_changed: float
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the InvocationStatsService and reset counters to zero
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> AbstractContextManager:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state_id: The id of the current session.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
"""
|
||||
Reset all statistics for the indicated graph
|
||||
:param graph_execution_state_id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_invocation_stats(
|
||||
self,
|
||||
graph_id: str,
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def log_stats(self):
|
||||
"""
|
||||
Write out the accumulated statistics to the log or somewhere else.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
pass
|
@ -0,0 +1,25 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
@ -1,171 +1,35 @@
|
||||
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
||||
|
||||
Usage:
|
||||
|
||||
statistics = InvocationStatsService(graph_execution_manager)
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
... execute graphs...
|
||||
statistics.log_stats()
|
||||
|
||||
Typical output:
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||
|
||||
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
"""
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .graph import GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from .model_manager_service import ModelManagerService
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
# {graph_id => NodeLog}
|
||||
_stats: Dict[str, NodeLog]
|
||||
_cache_stats: Dict[str, CacheStats]
|
||||
ram_used: float
|
||||
ram_changed: float
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
"""
|
||||
Initialize the InvocationStatsService and reset counters to zero
|
||||
:param graph_execution_manager: Graph execution manager for this session
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> AbstractContextManager:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
"""
|
||||
Reset all statistics for the indicated graph
|
||||
:param graph_execution_state_id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_invocation_stats(
|
||||
self,
|
||||
graph_id: str,
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def log_stats(self):
|
||||
"""
|
||||
Write out the accumulated statistics to the log or somewhere else.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
pass
|
||||
from .invocation_stats_base import InvocationStatsServiceBase
|
||||
from .invocation_stats_common import GIG, NodeLog, NodeStats
|
||||
|
||||
|
||||
class InvocationStatsService(InvocationStatsServiceBase):
|
||||
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
_invoker: Invoker
|
||||
|
||||
def __init__(self):
|
||||
# {graph_id => NodeLog}
|
||||
self._stats: Dict[str, NodeLog] = {}
|
||||
self._cache_stats: Dict[str, CacheStats] = {}
|
||||
self.ram_used: float = 0.0
|
||||
self.ram_changed: float = 0.0
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
class StatsContext:
|
||||
"""Context manager for collecting statistics."""
|
||||
|
||||
@ -174,13 +38,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
graph_id: str
|
||||
start_time: float
|
||||
ram_used: int
|
||||
model_manager: ModelManagerService
|
||||
model_manager: ModelManagerServiceBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
model_manager: ModelManagerServiceBase,
|
||||
collector: "InvocationStatsServiceBase",
|
||||
):
|
||||
"""Initialize statistics for this run."""
|
||||
@ -208,7 +72,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
invocation_type=self.invocation.type, # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
@ -217,12 +81,11 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
) -> StatsContext:
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
||||
return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
|
||||
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
@ -261,7 +124,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
errored = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
try:
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
|
||||
except Exception:
|
||||
errored.add(graph_id)
|
||||
continue
|
@ -1,11 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from .graph import Graph, GraphExecutionState
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from .invocation_services import InvocationServices
|
||||
from .shared.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class Invoker:
|
||||
@ -84,7 +83,3 @@ class Invoker:
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC):
|
||||
pass
|
||||
|
0
invokeai/app/services/item_storage/__init__.py
Normal file
0
invokeai/app/services/item_storage/__init__.py
Normal file
@ -1,25 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class PaginatedResults(GenericModel, Generic[T]):
|
||||
"""Paginated results"""
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
page: int = Field(description="Current Page")
|
||||
pages: int = Field(description="Total number of pages")
|
||||
per_page: int = Field(description="Number of items per page")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ItemStorageABC(ABC, Generic[T]):
|
||||
"""Provides storage for a single type of item. The type must be a Pydantic model."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[T], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
@ -2,30 +2,33 @@ import sqlite3
|
||||
import threading
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from .item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
|
||||
from .item_storage_base import ItemStorageABC
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: threading.Lock
|
||||
_lock: threading.RLock
|
||||
_adapter: Optional[TypeAdapter[T]]
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
|
||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = lock
|
||||
self._conn = conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._adapter: Optional[TypeAdapter[T]] = None
|
||||
|
||||
self._create_table()
|
||||
|
||||
@ -44,15 +47,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._lock.release()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
if self._adapter is None:
|
||||
"""
|
||||
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||
we can create it when it is first needed instead.
|
||||
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||
"""
|
||||
self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||
return self._adapter.validate_json(item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.json(),),
|
||||
(item.model_dump_json(warnings=False, exclude_none=True),),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
@ -1,119 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: torch.Tensor) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if name not in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
||||
|
||||
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: Union[str, Path]
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
latent_path.unlink()
|
||||
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
0
invokeai/app/services/latents_storage/__init__.py
Normal file
0
invokeai/app/services/latents_storage/__init__.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: torch.Tensor) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
|
||||
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: Path
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
latent_path.unlink()
|
||||
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if name not in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
0
invokeai/app/services/model_manager/__init__.py
Normal file
0
invokeai/app/services/model_manager/__init__.py
Normal file
286
invokeai/app/services/model_manager/model_manager_base.py
Normal file
286
invokeai/app/services/model_manager/model_manager_base.py
Normal file
@ -0,0 +1,286 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
pass
|
@ -2,16 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
@ -26,273 +25,12 @@ from invokeai.backend.model_management import (
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
pass
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
# simple implementation
|
||||
@ -589,7 +327,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
0
invokeai/app/services/names/__init__.py
Normal file
0
invokeai/app/services/names/__init__.py
Normal file
11
invokeai/app/services/names/names_base.py
Normal file
11
invokeai/app/services/names/names_base.py
Normal file
@ -0,0 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class NameServiceBase(ABC):
|
||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
@abstractmethod
|
||||
def create_image_name(self) -> str:
|
||||
"""Creates a name for an image."""
|
||||
pass
|
8
invokeai/app/services/names/names_common.py
Normal file
8
invokeai/app/services/names/names_common.py
Normal file
@ -0,0 +1,8 @@
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
|
||||
IMAGE = "image"
|
||||
LATENT = "latent"
|
13
invokeai/app/services/names/names_default.py
Normal file
13
invokeai/app/services/names/names_default.py
Normal file
@ -0,0 +1,13 @@
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from .names_base import NameServiceBase
|
||||
|
||||
|
||||
class SimpleNameService(NameServiceBase):
|
||||
"""Creates image names from UUIDs."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = uuid_string()
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
@ -1,31 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
|
||||
IMAGE = "image"
|
||||
LATENT = "latent"
|
||||
|
||||
|
||||
class NameServiceBase(ABC):
|
||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
@abstractmethod
|
||||
def create_image_name(self) -> str:
|
||||
"""Creates a name for an image."""
|
||||
pass
|
||||
|
||||
|
||||
class SimpleNameService(NameServiceBase):
|
||||
"""Creates image names from UUIDs."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = uuid_string()
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
@ -7,7 +7,7 @@ from typing import Optional
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
from ..invoker import Invoker
|
||||
@ -97,7 +97,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
resume_event.set()
|
||||
self.__threadLimit.acquire()
|
||||
queue_item: Optional[SessionQueueItem] = None
|
||||
self.__invoker.services.logger
|
||||
while not stop_event.is_set():
|
||||
poll_now_event.clear()
|
||||
try:
|
||||
|
@ -1,7 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
@ -18,7 +17,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
|
||||
|
||||
class SessionQueueBase(ABC):
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user