mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
all files migrated; tweaks needed
This commit is contained in:
80
invokeai/app/api/dependencies.py
Normal file
80
invokeai/app/api/dependencies.py
Normal file
@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from argparse import Namespace
|
||||
import os
|
||||
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
|
||||
from ...globals import Globals
|
||||
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.generate_initializer import get_generate
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
def check_internet()->bool:
|
||||
'''
|
||||
Return true if the internet is reachable.
|
||||
It does this by pinging huggingface.co.
|
||||
'''
|
||||
import urllib.request
|
||||
host = 'http://huggingface.co'
|
||||
try:
|
||||
urllib.request.urlopen(host,timeout=1)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(
|
||||
args,
|
||||
config,
|
||||
event_handler_id: int
|
||||
):
|
||||
Globals.try_patchmatch = args.patchmatch
|
||||
Globals.always_use_cpu = args.always_use_cpu
|
||||
Globals.internet_available = args.internet_available and check_internet()
|
||||
Globals.disable_xformers = not args.xformers
|
||||
Globals.ckpt_convert = args.ckpt_convert
|
||||
|
||||
# TODO: Use a logger
|
||||
print(f'>> Internet connectivity is {Globals.internet_available}')
|
||||
|
||||
generate = get_generate(args, config)
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs'))
|
||||
|
||||
images = DiskImageStorage(output_folder)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, 'invokeai.db')
|
||||
|
||||
services = InvocationServices(
|
||||
generate = generate,
|
||||
events = events,
|
||||
images = images,
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
if ApiDependencies.invoker:
|
||||
ApiDependencies.invoker.stop()
|
54
invokeai/app/api/events.py
Normal file
54
invokeai/app/api/events.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
from ..services.events import EventServiceBase
|
||||
import threading
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
__queue: Queue
|
||||
__stop_event: threading.Event
|
||||
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self.__queue = Queue()
|
||||
self.__stop_event = threading.Event()
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put(dict(
|
||||
event_name = event_name,
|
||||
payload = payload
|
||||
))
|
||||
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self.__queue.get(block = False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
|
||||
dispatch(
|
||||
event.get('event_name'),
|
||||
payload = event.get('payload'),
|
||||
middleware_id = self.event_handler_id)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.001)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
57
invokeai/app/api/routers/images.py
Normal file
57
invokeai/app/api/routers/images.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import Path, UploadFile, Request
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from PIL import Image
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(
|
||||
prefix = '/v1/images',
|
||||
tags = ['images']
|
||||
)
|
||||
|
||||
|
||||
@images_router.get('/{image_type}/{image_name}',
|
||||
operation_id = 'get_image'
|
||||
)
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description = "The type of image to get"),
|
||||
image_name: str = Path(description = "The name of the image to get")
|
||||
):
|
||||
"""Gets a result"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
@images_router.post('/uploads/',
|
||||
operation_id = 'upload_image',
|
||||
responses = {
|
||||
201: {'description': 'The image was uploaded successfully'},
|
||||
404: {'description': 'Session not found'}
|
||||
})
|
||||
async def upload_image(
|
||||
file: UploadFile,
|
||||
request: Request
|
||||
):
|
||||
if not file.content_type.startswith('image'):
|
||||
return Response(status_code = 415)
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
im = Image.open(contents)
|
||||
except:
|
||||
# Error opening the image
|
||||
return Response(status_code = 415)
|
||||
|
||||
filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png'
|
||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||
|
||||
return Response(
|
||||
status_code=201,
|
||||
headers = {
|
||||
'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename)
|
||||
}
|
||||
)
|
232
invokeai/app/api/routers/sessions.py
Normal file
232
invokeai/app/api/routers/sessions.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import List, Optional, Union, Annotated
|
||||
from fastapi import Query, Path, Body
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import Response
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||
from ...invocations import *
|
||||
|
||||
session_router = APIRouter(
|
||||
prefix = '/v1/sessions',
|
||||
tags = ['sessions']
|
||||
)
|
||||
|
||||
|
||||
@session_router.post('/',
|
||||
operation_id = 'create_session',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid json'}
|
||||
})
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with")
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
return session
|
||||
|
||||
|
||||
@session_router.get('/',
|
||||
operation_id = 'list_sessions',
|
||||
responses = {
|
||||
200: {"model": PaginatedResults[GraphExecutionState]}
|
||||
})
|
||||
async def list_sessions(
|
||||
page: int = Query(default = 0, description = "The page of results to get"),
|
||||
per_page: int = Query(default = 10, description = "The number of results per page"),
|
||||
query: str = Query(default = '', description = "The query string to search for")
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if filter == '':
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
return result
|
||||
|
||||
|
||||
@session_router.get('/{session_id}',
|
||||
operation_id = 'get_session',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {'description': 'Session not found'}
|
||||
})
|
||||
async def get_session(
|
||||
session_id: str = Path(description = "The id of the session to get")
|
||||
) -> GraphExecutionState:
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
else:
|
||||
return session
|
||||
|
||||
|
||||
@session_router.post('/{session_id}/nodes',
|
||||
operation_id = 'add_node',
|
||||
responses = {
|
||||
200: {"model": str},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.put('/{session_id}/nodes/{node_path}',
|
||||
operation_id = 'update_node',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def update_node(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
node_path: str = Path(description = "The path to the node in the graph"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.delete('/{session_id}/nodes/{node_path}',
|
||||
operation_id = 'delete_node',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def delete_node(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
node_path: str = Path(description = "The path to the node to delete")
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.post('/{session_id}/edges',
|
||||
operation_id = 'add_edge',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
|
||||
) -> GraphExecutionState:
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
|
||||
operation_id = 'delete_edge',
|
||||
responses = {
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {'description': 'Invalid node or link'},
|
||||
404: {'description': 'Session not found'}
|
||||
}
|
||||
)
|
||||
async def delete_edge(
|
||||
session_id: str = Path(description = "The id of the session"),
|
||||
from_node_id: str = Path(description = "The id of the node the edge is coming from"),
|
||||
from_field: str = Path(description = "The field of the node the edge is coming from"),
|
||||
to_node_id: str = Path(description = "The id of the node the edge is going to"),
|
||||
to_field: str = Path(description = "The field of the node the edge is going to")
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
try:
|
||||
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code = 400)
|
||||
except IndexError:
|
||||
return Response(status_code = 400)
|
||||
|
||||
|
||||
@session_router.put('/{session_id}/invoke',
|
||||
operation_id = 'invoke_session',
|
||||
responses = {
|
||||
200: {"model": None},
|
||||
202: {'description': 'The invocation is queued'},
|
||||
400: {'description': 'The session has no invocations ready to invoke'},
|
||||
404: {'description': 'Session not found'}
|
||||
})
|
||||
async def invoke_session(
|
||||
session_id: str = Path(description = "The id of the session to invoke"),
|
||||
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
|
||||
) -> None:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code = 404)
|
||||
|
||||
if session.is_complete():
|
||||
return Response(status_code = 400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all = all)
|
||||
return Response(status_code=202)
|
36
invokeai/app/api/sockets.py
Normal file
36
invokeai/app/api/sockets.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi_socketio import SocketManager
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
class SocketIO:
|
||||
__sio: SocketManager
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app = app)
|
||||
self.__sio.on('subscribe', handler=self._handle_sub)
|
||||
self.__sio.on('unsubscribe', handler=self._handle_unsub)
|
||||
|
||||
local_handler.register(
|
||||
event_name = EventServiceBase.session_event,
|
||||
_func=self._handle_session_event
|
||||
)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event = event[1]['event'],
|
||||
data = event[1]['data'],
|
||||
room = event[1]['data']['graph_execution_state_id']
|
||||
)
|
||||
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
if 'session' in data:
|
||||
self.__sio.enter_room(sid, data['session'])
|
||||
|
||||
# @app.sio.on('unsubscribe')
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
if 'session' in data:
|
||||
self.__sio.leave_room(sid, data['session'])
|
164
invokeai/app/api_app.py
Normal file
164
invokeai/app/api_app.py
Normal file
@ -0,0 +1,164 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from inspect import signature
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic.schema import schema
|
||||
import uvicorn
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .api.routers import images, sessions
|
||||
from .api.dependencies import ApiDependencies
|
||||
from ..args import Args
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(
|
||||
title = "Invoke AI",
|
||||
docs_url = None,
|
||||
redoc_url = None
|
||||
)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers = [local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id = event_handler_id)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
origins = []
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event('startup')
|
||||
async def startup_event():
|
||||
args = Args()
|
||||
config = args.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
args = args,
|
||||
config = config,
|
||||
event_handler_id = event_handler_id
|
||||
)
|
||||
|
||||
# Shut down threads
|
||||
@app.on_event('shutdown')
|
||||
async def shutdown_event():
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
# Include all routers
|
||||
# TODO: REMOVE
|
||||
# app.include_router(
|
||||
# invocation.invocation_router,
|
||||
# prefix = '/api')
|
||||
|
||||
app.include_router(
|
||||
sessions.session_router,
|
||||
prefix = '/api'
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
images.images_router,
|
||||
prefix = '/api'
|
||||
)
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title = app.title,
|
||||
description = "An API for invoking AI image operations",
|
||||
version = "1.0.0",
|
||||
routes = app.routes
|
||||
)
|
||||
|
||||
# Add all outputs
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
output_types = set()
|
||||
output_type_titles = dict()
|
||||
for invoker in all_invocations:
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
|
||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in output_schemas['definitions'].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema['title']
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' }
|
||||
if 'additionalProperties' not in invoker_schema:
|
||||
invoker_schema['additionalProperties'] = {}
|
||||
|
||||
invoker_schema['additionalProperties']['outputs'] = outputs_ref
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=app.openapi_url,
|
||||
title=app.title,
|
||||
swagger_favicon_url="/static/favicon.ico"
|
||||
)
|
||||
|
||||
@app.get("/redoc", include_in_schema=False)
|
||||
def overridden_redoc():
|
||||
return get_redoc_html(
|
||||
openapi_url=app.openapi_url,
|
||||
title=app.title,
|
||||
redoc_favicon_url="/static/favicon.ico"
|
||||
)
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(
|
||||
app = app,
|
||||
host = "0.0.0.0",
|
||||
port = 9090,
|
||||
loop = loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
315
invokeai/app/cli_app.py
Normal file
315
invokeai/app/cli_app.py
Normal file
@ -0,0 +1,315 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
|
||||
from .services.graph import EdgeConnection, GraphExecutionState
|
||||
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
from .invocations.image import ImageField
|
||||
from .services.generate_initializer import get_generate
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .invocations import *
|
||||
from ..args import Args
|
||||
from .services.events import EventServiceBase
|
||||
|
||||
|
||||
class InvocationCommand(BaseModel):
|
||||
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
|
||||
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_invocation_parser() -> argparse.ArgumentParser:
|
||||
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
def exit(*args, **kwargs):
|
||||
raise InvalidArgs
|
||||
parser.exit = exit
|
||||
|
||||
subparsers = parser.add_subparsers(dest='type')
|
||||
invocation_parsers = dict()
|
||||
|
||||
# Add history parser
|
||||
history_parser = subparsers.add_parser('history', help="Shows the invocation history")
|
||||
history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")
|
||||
|
||||
# Add default parser
|
||||
default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
|
||||
default_parser.add_argument('input', type=str, help="The input field")
|
||||
default_parser.add_argument('value', help="The default value")
|
||||
|
||||
default_parser = subparsers.add_parser('reset_default', help="Resets a default value")
|
||||
default_parser.add_argument('input', type=str, help="The input field")
|
||||
|
||||
# Create subparsers for each invocation
|
||||
invocations = BaseInvocation.get_all_subclasses()
|
||||
for invocation in invocations:
|
||||
hints = get_type_hints(invocation)
|
||||
cmd_name = get_args(hints['type'])[0]
|
||||
command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
|
||||
invocation_parsers[cmd_name] = command_parser
|
||||
|
||||
# Add linking capability
|
||||
command_parser.add_argument('--link', '-l', action='append', nargs=3,
|
||||
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)")
|
||||
|
||||
command_parser.add_argument('--link_node', '-ln', action='append',
|
||||
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)")
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = invocation.__fields__
|
||||
for name, field in fields.items():
|
||||
if name in ['id', 'type']:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
choices = allowed_values,
|
||||
help=field.field_info.description
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default,
|
||||
help=field.field_info.description
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
|
||||
fields = invocation.__fields__.items()
|
||||
type_hints = get_type_hints(type(invocation))
|
||||
command = [invocation.type]
|
||||
for name,field in fields:
|
||||
if name in ['id', 'type']:
|
||||
continue
|
||||
|
||||
# TODO: add links
|
||||
|
||||
# Skip image fields when serializing command
|
||||
type_hint = type_hints.get(name) or None
|
||||
if type_hint is ImageField or ImageField in get_args(type_hint):
|
||||
continue
|
||||
|
||||
field_value = getattr(invocation, name)
|
||||
field_default = field.default
|
||||
if field_value != field_default:
|
||||
if type_hint is str or str in get_args(type_hint):
|
||||
command.append(f'--{name} "{field_value}"')
|
||||
else:
|
||||
command.append(f'--{name} {field_value}')
|
||||
|
||||
return ' '.join(command)
|
||||
|
||||
|
||||
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
|
||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
|
||||
|
||||
|
||||
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
# Remove invalid fields
|
||||
invalid_fields = set(['type', 'id'])
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields]
|
||||
return edges
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
args = Args()
|
||||
config = args.parse_args()
|
||||
|
||||
generate = get_generate(args, config)
|
||||
|
||||
# NOTE: load model on first use, uncomment to load at startup
|
||||
# TODO: Make this a config option?
|
||||
#generate.load_model()
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, 'invokeai.db')
|
||||
|
||||
services = InvocationServices(
|
||||
generate = generate,
|
||||
events = events,
|
||||
images = DiskImageStorage(output_folder),
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
)
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
|
||||
parser = get_invocation_parser()
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
|
||||
# Defaults storage
|
||||
defaults: Dict[str, Any] = dict()
|
||||
|
||||
while True:
|
||||
try:
|
||||
cmd_input = input("> ")
|
||||
except KeyboardInterrupt:
|
||||
# Ctrl-c exits
|
||||
break
|
||||
|
||||
if cmd_input in ['exit','q']:
|
||||
break;
|
||||
|
||||
if cmd_input in ['--help','help','h','?']:
|
||||
parser.print_help()
|
||||
continue
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
history = list(get_graph_execution_history(session))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split('|')
|
||||
start_id = len(history)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
if cmd is None or cmd.strip() == '':
|
||||
raise InvalidArgs('Empty command')
|
||||
|
||||
# Parse args to create invocation
|
||||
args = vars(parser.parse_args(shlex.split(cmd.strip())))
|
||||
|
||||
# Check for special commands
|
||||
# TODO: These might be better as Pydantic models, similar to the invocations
|
||||
if args['type'] == 'history':
|
||||
history_count = args['count'] or 5
|
||||
for i in range(min(history_count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = session.graph.get_node(entry_id)
|
||||
print(f'{entry_id}: {get_invocation_command(entry.invocation)}')
|
||||
continue
|
||||
|
||||
if args['type'] == 'reset_default':
|
||||
if args['input'] in defaults:
|
||||
del defaults[args['input']]
|
||||
continue
|
||||
|
||||
if args['type'] == 'default':
|
||||
field = args['input']
|
||||
field_value = args['value']
|
||||
defaults[field] = field_value
|
||||
continue
|
||||
|
||||
# Override defaults
|
||||
for field_name,field_default in defaults.items():
|
||||
if field_name in args:
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args['id'] = current_id
|
||||
command = InvocationCommand(invocation = args)
|
||||
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges = []
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
from_id = history[0] if current_id == start_id else str(current_id - 1)
|
||||
from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id)
|
||||
matching_edges = generate_matching_edges(from_node, command.invocation)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
# Parse provided links
|
||||
if 'link_node' in args and args['link_node']:
|
||||
for link in args['link_node']:
|
||||
link_node = session.graph.get_node(link)
|
||||
matching_edges = generate_matching_edges(link_node, command.invocation)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
if 'link' in args and args['link']:
|
||||
for link in args['link']:
|
||||
edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2])))
|
||||
|
||||
new_invocations.append((command.invocation, edges))
|
||||
|
||||
current_id = current_id + 1
|
||||
|
||||
# Command line was parsed successfully
|
||||
# Add the invocations to the session
|
||||
for invocation in new_invocations:
|
||||
session.add_node(invocation[0])
|
||||
for edge in invocation[1]:
|
||||
session.add_edge(edge)
|
||||
|
||||
# Execute all available invocations
|
||||
invoker.invoke(session, invoke_all = True)
|
||||
while not session.is_complete():
|
||||
# Wait some time
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
if session.has_error():
|
||||
for n in session.errors:
|
||||
print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')
|
||||
|
||||
# Start a new session
|
||||
print("Creating a new session")
|
||||
session = invoker.create_execution_state()
|
||||
|
||||
except InvalidArgs:
|
||||
print('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except SystemExit:
|
||||
continue
|
||||
|
||||
invoker.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_cli()
|
8
invokeai/app/invocations/__init__.py
Normal file
8
invokeai/app/invocations/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
__all__ = []
|
||||
|
||||
dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
for f in os.listdir(dirname):
|
||||
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
|
||||
__all__.append(f[:-3])
|
74
invokeai/app/invocations/baseinvocation.py
Normal file
74
invokeai/app/invocations/baseinvocation.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
|
||||
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
"""Base class for all invocation outputs"""
|
||||
|
||||
# All outputs must include a type name like this:
|
||||
# type: Literal['your_output_name']
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses_tuple(cls):
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
return tuple(subclasses)
|
||||
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
|
||||
"""A node to process inputs and produce outputs.
|
||||
May use dependency injection in __init__ to receive providers.
|
||||
"""
|
||||
|
||||
# All invocations must include a type name like this:
|
||||
# type: Literal['your_output_name']
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
return subclasses
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls):
|
||||
return tuple(BaseInvocation.get_all_subclasses())
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls):
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
|
||||
|
||||
@classmethod
|
||||
def get_output_type(cls):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
42
invokeai/app/invocations/cv.py
Normal file
42
invokeai/app/invocations/cv.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
import numpy
|
||||
from pydantic import Field
|
||||
from PIL import Image, ImageOps
|
||||
import cv2 as cv
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
"""Simple inpaint using opencv."""
|
||||
type: Literal['cv_inpaint'] = 'cv_inpaint'
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
|
||||
# Convert to cv image/mask
|
||||
# TODO: consider making these utility functions
|
||||
cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR)
|
||||
cv_mask = numpy.array(ImageOps.invert(mask))
|
||||
|
||||
# Inpaint
|
||||
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
|
||||
|
||||
# Convert back to Pillow
|
||||
# TODO: consider making a utility function
|
||||
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, image_inpainted)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
160
invokeai/app/invocations/generate.py
Normal file
160
invokeai/app/invocations/generate.py
Normal file
@ -0,0 +1,160 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Union
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from PIL import Image
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal["ddim","plms","k_lms","k_dpm_2","k_dpm_2_a","k_euler","k_euler_a","k_heun"]
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
type: Literal['txt2img'] = 'txt2img'
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image")
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image")
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt")
|
||||
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use")
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams")
|
||||
model: str = Field(default='', description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation")
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(self, context: InvocationContext, sample: Any = None, step: int = 0) -> None:
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id, self.id, step, float(step) / float(self.steps)
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
def step_callback(sample, step = 0):
|
||||
self.dispatch_progress(context, sample, step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
if self.model is None or self.model == '':
|
||||
self.model = context.services.generate.model_name
|
||||
|
||||
# Set the model (if already cached, this does nothing)
|
||||
context.services.generate.set_model(self.model)
|
||||
|
||||
results = context.services.generate.prompt2image(
|
||||
prompt = self.prompt,
|
||||
step_callback = step_callback,
|
||||
**self.dict(exclude = {'prompt'}) # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
|
||||
class ImageToImageInvocation(TextToImageInvocation):
|
||||
"""Generates an image using img2img."""
|
||||
type: Literal['img2img'] = 'img2img'
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
|
||||
fit: bool = Field(default=True, description="Whether or not the result should be fit to the aspect ratio of the input image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
mask = None
|
||||
|
||||
def step_callback(sample, step = 0):
|
||||
self.dispatch_progress(context, sample, step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
if self.model is None or self.model == '':
|
||||
self.model = context.services.generate.model_name
|
||||
|
||||
# Set the model (if already cached, this does nothing)
|
||||
context.services.generate.set_model(self.model)
|
||||
|
||||
results = context.services.generate.prompt2image(
|
||||
prompt = self.prompt,
|
||||
init_img = image,
|
||||
init_mask = mask,
|
||||
step_callback = step_callback,
|
||||
**self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
result_image = results[0][0]
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
type: Literal['inpaint'] = 'inpaint'
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField,None] = Field(description="The mask")
|
||||
inpaint_replace: float = Field(default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
mask = None if self.mask is None else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
|
||||
def step_callback(sample, step = 0):
|
||||
self.dispatch_progress(context, sample, step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
if self.model is None or self.model == '':
|
||||
self.model = context.services.generate.model_name
|
||||
|
||||
# Set the model (if already cached, this does nothing)
|
||||
context.services.generate.set_model(self.model)
|
||||
|
||||
results = context.services.generate.prompt2image(
|
||||
prompt = self.prompt,
|
||||
init_img = image,
|
||||
init_mask = mask,
|
||||
step_callback = step_callback,
|
||||
**self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
result_image = results[0][0]
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
219
invokeai/app/invocations/image.py
Normal file
219
invokeai/app/invocations/image.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
import numpy
|
||||
from pydantic import Field, BaseModel
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
image_type: str = Field(default=ImageType.RESULT, description="The type of the image")
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
type: Literal['image'] = 'image'
|
||||
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
type: Literal['mask'] = 'mask'
|
||||
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
|
||||
|
||||
# TODO: this isn't really necessary anymore
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
"""Load an image from a filename and provide it as output."""
|
||||
type: Literal['load_image'] = 'load_image'
|
||||
|
||||
# Inputs
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = self.image_type, image_name = self.image_name)
|
||||
)
|
||||
|
||||
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
||||
type: Literal['show_image'] = 'show_image'
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to show")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
if image:
|
||||
image.show()
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name)
|
||||
)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
type: Literal['crop'] = 'crop'
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
|
||||
image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0))
|
||||
image_crop.paste(image, (-self.x, -self.y))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, image_crop)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
type: Literal['paste'] = 'paste'
|
||||
|
||||
# Inputs
|
||||
base_image: ImageField = Field(default=None, description="The base image")
|
||||
image: ImageField = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name)
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
mask = None if self.mask is None else ImageOps.invert(services.images.get(self.mask.image_type, self.mask.image_name))
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
|
||||
min_x = min(0, self.x)
|
||||
min_y = min(0, self.y)
|
||||
max_x = max(base_image.width, image.width + self.x)
|
||||
max_y = max(base_image.height, image.height + self.y)
|
||||
|
||||
new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0))
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask)
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, new_image)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
type: Literal['tomask'] = 'tomask'
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
if self.invert:
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, image_mask)
|
||||
return MaskOutput(
|
||||
mask = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
|
||||
class BlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
type: Literal['blur'] = 'blur'
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
|
||||
blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, blur_image)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
|
||||
class LerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
type: Literal['lerp'] = 'lerp'
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
image_arr = image_arr * (self.max - self.min) + self.max
|
||||
|
||||
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, lerp_image)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
|
||||
|
||||
class InverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
type: Literal['ilerp'] = 'ilerp'
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
|
||||
|
||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, ilerp_image)
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
9
invokeai/app/invocations/prompt.py
Normal file
9
invokeai/app/invocations/prompt.py
Normal file
@ -0,0 +1,9 @@
|
||||
from typing import Literal
|
||||
from pydantic.fields import Field
|
||||
from .baseinvocation import BaseInvocationOutput
|
||||
|
||||
class PromptOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a prompt"""
|
||||
type: Literal['prompt'] = 'prompt'
|
||||
|
||||
prompt: str = Field(default=None, description="The output prompt")
|
36
invokeai/app/invocations/reconstruct.py
Normal file
36
invokeai/app/invocations/reconstruct.py
Normal file
@ -0,0 +1,36 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
from pydantic import Field
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
type: Literal['restore_face'] = 'restore_face'
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration")
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = None,
|
||||
strength = self.strength, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
38
invokeai/app/invocations/upscale.py
Normal file
38
invokeai/app/invocations/upscale.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
from pydantic import Field
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
type: Literal['upscale'] = 'upscale'
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2,4] = Field(default=2, description = "The upscale level")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = (self.level, self.strength),
|
||||
strength = 0.0, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
0
invokeai/app/services/__init__.py
Normal file
0
invokeai/app/services/__init__.py
Normal file
93
invokeai/app/services/events.py
Normal file
93
invokeai/app/services/events.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = 'session_event'
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self,
|
||||
event_name: str,
|
||||
payload: Dict) -> None:
|
||||
self.dispatch(
|
||||
event_name = EventServiceBase.session_event,
|
||||
payload = dict(
|
||||
event = event_name,
|
||||
data = payload
|
||||
)
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
step: int,
|
||||
percent: float
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_session_event(
|
||||
event_name = 'generator_progress',
|
||||
payload = dict(
|
||||
graph_execution_state_id = graph_execution_state_id,
|
||||
invocation_id = invocation_id,
|
||||
step = step,
|
||||
percent = percent
|
||||
)
|
||||
)
|
||||
|
||||
def emit_invocation_complete(self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
result: Dict
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name = 'invocation_complete',
|
||||
payload = dict(
|
||||
graph_execution_state_id = graph_execution_state_id,
|
||||
invocation_id = invocation_id,
|
||||
result = result
|
||||
)
|
||||
)
|
||||
|
||||
def emit_invocation_error(self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
error: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name = 'invocation_error',
|
||||
payload = dict(
|
||||
graph_execution_state_id = graph_execution_state_id,
|
||||
invocation_id = invocation_id,
|
||||
error = error
|
||||
)
|
||||
)
|
||||
|
||||
def emit_invocation_started(self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
event_name = 'invocation_started',
|
||||
payload = dict(
|
||||
graph_execution_state_id = graph_execution_state_id,
|
||||
invocation_id = invocation_id
|
||||
)
|
||||
)
|
||||
|
||||
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_session_event(
|
||||
event_name = 'graph_execution_state_complete',
|
||||
payload = dict(
|
||||
graph_execution_state_id = graph_execution_state_id
|
||||
)
|
||||
)
|
231
invokeai/app/services/generate_initializer.py
Normal file
231
invokeai/app/services/generate_initializer.py
Normal file
@ -0,0 +1,231 @@
|
||||
from argparse import Namespace
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from invokeai.backend import ModelManager, Generate
|
||||
|
||||
from ...globals import Globals
|
||||
import invokeai.version
|
||||
|
||||
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
||||
def get_generate(args, config) -> Generate:
|
||||
if not args.conf:
|
||||
config_file = os.path.join(Globals.root,'configs','models.yaml')
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found."))
|
||||
|
||||
print(f'>> {invokeai.version.__app_name__}, version {invokeai.version.__version__}')
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
import transformers # type: ignore
|
||||
transformers.logging.set_verbosity_error()
|
||||
import diffusers
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
# Loading Face Restoration and ESRGAN Modules
|
||||
gfpgan,codeformer,esrgan = load_face_restoration(args)
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(args.conf):
|
||||
args.conf = os.path.normpath(os.path.join(Globals.root,args.conf))
|
||||
|
||||
if args.embeddings:
|
||||
if not os.path.isabs(args.embedding_path):
|
||||
embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path))
|
||||
else:
|
||||
embedding_path = args.embedding_path
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
# migrate legacy models
|
||||
ModelManager.migrate_models()
|
||||
|
||||
# load the infile as a list of lines
|
||||
if args.infile:
|
||||
try:
|
||||
if os.path.isfile(args.infile):
|
||||
infile = open(args.infile, 'r', encoding='utf-8')
|
||||
elif args.infile == '-': # stdin
|
||||
infile = sys.stdin
|
||||
else:
|
||||
raise FileNotFoundError(f'{args.infile} not found.')
|
||||
except (FileNotFoundError, IOError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
# creating a Generate object:
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = args.conf,
|
||||
model = args.model,
|
||||
sampler_name = args.sampler_name,
|
||||
embedding_path = embedding_path,
|
||||
full_precision = args.full_precision,
|
||||
precision = args.precision,
|
||||
gfpgan = gfpgan,
|
||||
codeformer = codeformer,
|
||||
esrgan = esrgan,
|
||||
free_gpu_mem = args.free_gpu_mem,
|
||||
safety_checker = args.safety_checker,
|
||||
max_loaded_models = args.max_loaded_models,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(opt,e)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
if args.seamless:
|
||||
print(">> changed to seamless tiling mode")
|
||||
|
||||
# preload the model
|
||||
try:
|
||||
gen.load_model()
|
||||
except KeyError:
|
||||
pass
|
||||
except Exception as e:
|
||||
report_model_error(args, e)
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := args.autoconvert:
|
||||
gen.model_manager.autoconvert_weights(
|
||||
conf_path=args.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
return gen
|
||||
|
||||
|
||||
def load_face_restoration(opt):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if opt.restore or opt.esrgan:
|
||||
from ldm.invoke.restoration import Restoration
|
||||
restoration = Restoration()
|
||||
if opt.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)
|
||||
else:
|
||||
print('>> Face restoration disabled')
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||
else:
|
||||
print('>> Upscaling disabled')
|
||||
else:
|
||||
print('>> Face restoration and upscaling disabled')
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
return gfpgan,codeformer,esrgan
|
||||
|
||||
|
||||
def report_model_error(opt:Namespace, e:Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
|
||||
yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
|
||||
if yes_to_all:
|
||||
print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
|
||||
else:
|
||||
response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
|
||||
if response.startswith(('n', 'N')):
|
||||
return
|
||||
|
||||
print('invokeai-configure is launching....\n')
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
previous_args = sys.argv
|
||||
sys.argv = [ 'invokeai-configure' ]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config)
|
||||
if yes_to_all is not None:
|
||||
for arg in yes_to_all.split():
|
||||
sys.argv.append(arg)
|
||||
|
||||
from ldm.invoke.config import invokeai_configure
|
||||
invokeai_configure.main()
|
||||
# TODO: Figure out how to restart
|
||||
# print('** InvokeAI will now restart')
|
||||
# sys.argv = previous_args
|
||||
# main() # would rather do a os.exec(), but doesn't exist?
|
||||
# sys.exit(0)
|
||||
|
||||
|
||||
# Temporary initializer for Generate until we migrate off of it
|
||||
def old_get_generate(args, config) -> Generate:
|
||||
# TODO: Remove the need for globals
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
# alert - setting globals here
|
||||
Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
|
||||
Globals.try_patchmatch = args.patchmatch
|
||||
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
import transformers
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# Loading Face Restoration and ESRGAN Modules
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
try:
|
||||
if config.restore or config.esrgan:
|
||||
from ldm.invoke.restoration import Restoration
|
||||
restoration = Restoration()
|
||||
if config.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path)
|
||||
else:
|
||||
print('>> Face restoration disabled')
|
||||
if config.esrgan:
|
||||
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
|
||||
else:
|
||||
print('>> Upscaling disabled')
|
||||
else:
|
||||
print('>> Face restoration and upscaling disabled')
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(config.conf):
|
||||
config.conf = os.path.normpath(os.path.join(Globals.root,config.conf))
|
||||
|
||||
if config.embeddings:
|
||||
if not os.path.isabs(config.embedding_path):
|
||||
embedding_path = os.path.normpath(os.path.join(Globals.root,config.embedding_path))
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
|
||||
# TODO: lazy-initialize this by wrapping it
|
||||
try:
|
||||
generate = Generate(
|
||||
conf = config.conf,
|
||||
model = config.model,
|
||||
sampler_name = config.sampler_name,
|
||||
embedding_path = embedding_path,
|
||||
full_precision = config.full_precision,
|
||||
precision = config.precision,
|
||||
gfpgan = gfpgan,
|
||||
codeformer = codeformer,
|
||||
esrgan = esrgan,
|
||||
free_gpu_mem = config.free_gpu_mem,
|
||||
safety_checker = config.safety_checker,
|
||||
max_loaded_models = config.max_loaded_models,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError):
|
||||
#emergency_model_reconfigure() # TODO?
|
||||
sys.exit(-1)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
sys.exit(-1)
|
||||
|
||||
generate.free_gpu_mem = config.free_gpu_mem
|
||||
|
||||
return generate
|
809
invokeai/app/services/graph.py
Normal file
809
invokeai/app/services/graph.py
Normal file
@ -0,0 +1,809 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
from types import NoneType
|
||||
import uuid
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic.fields import Field
|
||||
from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints, Annotated
|
||||
|
||||
from .invocation_services import InvocationServices
|
||||
from ..invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ..invocations import *
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
node_id: str = Field(description="The id of the node for this edge connection")
|
||||
field: str = Field(description="The field for this connection")
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, self.__class__) and
|
||||
getattr(other, 'node_id', None) == self.node_id and
|
||||
getattr(other, 'field', None) == self.field)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(f'{self.node_id}.{self.field}')
|
||||
|
||||
|
||||
def get_output_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_outputs = get_type_hints(node_type.get_output_type())
|
||||
node_output_field = node_outputs.get(field) or None
|
||||
return node_output_field
|
||||
|
||||
|
||||
def get_input_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_inputs = get_type_hints(node_type)
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
|
||||
|
||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if not from_type:
|
||||
return False
|
||||
if not to_type:
|
||||
return False
|
||||
|
||||
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
|
||||
if from_type and to_type:
|
||||
# Ports are compatible
|
||||
if (from_type == to_type or
|
||||
from_type == Any or
|
||||
to_type == Any or
|
||||
Any in get_args(from_type) or
|
||||
Any in get_args(to_type)):
|
||||
return True
|
||||
|
||||
if from_type in get_args(to_type):
|
||||
return True
|
||||
|
||||
if to_type in get_args(from_type):
|
||||
return True
|
||||
|
||||
if not issubclass(from_type, to_type):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def are_connections_compatible(
|
||||
from_node: BaseInvocation,
|
||||
from_field: str,
|
||||
to_node: BaseInvocation,
|
||||
to_field: str) -> bool:
|
||||
"""Determines if a connection between fields of two nodes is compatible."""
|
||||
|
||||
# TODO: handle iterators and collectors
|
||||
from_node_field = get_output_field(from_node, from_field)
|
||||
to_node_field = get_input_field(to_node, to_field)
|
||||
|
||||
return are_connection_types_compatible(from_node_field, to_node_field)
|
||||
|
||||
|
||||
class NodeAlreadyInGraphError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidEdgeError(Exception):
|
||||
pass
|
||||
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
class NodeAlreadyExecutedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Create and use an Empty output?
|
||||
class GraphInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal['graph_output'] = 'graph_output'
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
|
||||
type: Literal['graph'] = 'graph'
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
graph: 'Graph' = Field(description="The graph to run", default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
return GraphInvocationOutput()
|
||||
|
||||
|
||||
class IterateInvocationOutput(BaseInvocationOutput):
|
||||
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
||||
type: Literal['iterate_output'] = 'iterate_output'
|
||||
|
||||
item: Any = Field(description="The item being iterated over")
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
|
||||
type: Literal['iterate'] = 'iterate'
|
||||
|
||||
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
||||
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||
"""Produces the outputs as values"""
|
||||
return IterateInvocationOutput(item = self.collection[self.index])
|
||||
|
||||
|
||||
class CollectInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal['collect_output'] = 'collect_output'
|
||||
|
||||
collection: list[Any] = Field(description="The collection of input items")
|
||||
|
||||
|
||||
class CollectInvocation(BaseInvocation):
|
||||
"""Collects values into a collection"""
|
||||
type: Literal['collect'] = 'collect'
|
||||
|
||||
item: Any = Field(description="The item to collect (all inputs must be of the same type)", default=None)
|
||||
collection: list[Any] = Field(description="The collection, will be provided on execution", default_factory=list)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
return CollectInvocationOutput(collection = copy.copy(self.collection))
|
||||
|
||||
|
||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
edges: list[tuple[EdgeConnection,EdgeConnection]] = Field(description="The connections between nodes and their fields in this graph", default_factory=list)
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
:raises NodeAlreadyInGraphError: the node is already present in the graph.
|
||||
"""
|
||||
|
||||
if node.id in self.nodes:
|
||||
raise NodeAlreadyInGraphError()
|
||||
|
||||
self.nodes[node.id] = node
|
||||
|
||||
|
||||
def _get_graph_and_node(self, node_path: str) -> tuple['Graph', str]:
|
||||
"""Returns the graph and node id for a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
if node_path in self.nodes:
|
||||
return (self, node_path)
|
||||
|
||||
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
|
||||
if node_id not in self.nodes:
|
||||
raise NodeNotFoundError(f'Node {node_path} not found in graph')
|
||||
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if not isinstance(node, GraphInvocation):
|
||||
# There's more node path left but this isn't a graph - failure
|
||||
raise NodeNotFoundError('Node path terminated early at a non-graph node')
|
||||
|
||||
return node.graph._get_graph_and_node(node_path[node_path.index('.')+1:])
|
||||
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
"""Deletes a node from a graph"""
|
||||
|
||||
try:
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
|
||||
# Delete edges for this node
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
for edge_graph,_,edge in input_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
|
||||
for edge_graph,_,edge in output_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
|
||||
del graph.nodes[node_id]
|
||||
|
||||
except NodeNotFoundError:
|
||||
pass # Ignore, not doesn't exist (should this throw?)
|
||||
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
"""Adds an edge to a graph
|
||||
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
"""
|
||||
|
||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
"""Deletes an edge from a graph"""
|
||||
|
||||
try:
|
||||
self.edges.remove(edge)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Validates the graph."""
|
||||
|
||||
# Validate all subgraphs
|
||||
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
|
||||
if not gn.graph.is_valid():
|
||||
return False
|
||||
|
||||
# Validate all edges reference nodes in the graph
|
||||
node_ids = set([e[0].node_id for e in self.edges]+[e[1].node_id for e in self.edges])
|
||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||
return False
|
||||
|
||||
# Validate there are no cycles
|
||||
g = self.nx_graph_flat()
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
|
||||
# Validate all edge connections are valid
|
||||
if not all((are_connections_compatible(
|
||||
self.get_node(e[0].node_id), e[0].field,
|
||||
self.get_node(e[1].node_id), e[1].field
|
||||
) for e in self.edges)):
|
||||
return False
|
||||
|
||||
# Validate all iterators
|
||||
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
||||
if not all((self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))):
|
||||
return False
|
||||
|
||||
# Validate all collectors
|
||||
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
||||
if not all((self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
try:
|
||||
from_node = self.get_node(edge[0].node_id)
|
||||
to_node = self.get_node(edge[1].node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge[0].node_id, edge[1].node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(from_node, edge[0].field, to_node, edge[1].field):
|
||||
return False
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge[1].field == 'collection':
|
||||
if not self._is_iterator_connection_valid(edge[1].node_id, new_input = edge[0]):
|
||||
return False
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge[0].field == 'item':
|
||||
if not self._is_iterator_connection_valid(edge[0].node_id, new_output = edge[1]):
|
||||
return False
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge[1].field == 'item':
|
||||
if not self._is_collector_connection_valid(edge[1].node_id, new_input = edge[0]):
|
||||
return False
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge[0].field == 'collection':
|
||||
if not self._is_collector_connection_valid(edge[0].node_id, new_output = edge[1]):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
try:
|
||||
n = self.get_node(node_path)
|
||||
if n is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
def get_node(self, node_path: str) -> InvocationsUnion:
|
||||
"""Gets a node from the graph using a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
return graph.nodes[node_id]
|
||||
|
||||
|
||||
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
|
||||
return node_id if prefix is None or prefix == '' else f'{prefix}.{node_id}'
|
||||
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
"""Updates a node in the graph."""
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
node = graph.nodes[node_id]
|
||||
|
||||
# Ensure the node type matches the new node
|
||||
if type(node) != type(new_node):
|
||||
raise TypeError(f'Node {node_path} is type {type(node)} but new node is type {type(new_node)}')
|
||||
|
||||
# Ensure the new id is either the same or is not in the graph
|
||||
prefix = None if '.' not in node_path else node_path[:node_path.rindex('.')]
|
||||
new_path = self._get_node_path(new_node.id, prefix = prefix)
|
||||
if new_node.id != node.id and self.has_node(new_path):
|
||||
raise NodeAlreadyInGraphError('Node with id {new_node.id} already exists in graph')
|
||||
|
||||
# Set the new node in the graph
|
||||
graph.nodes[new_node.id] = new_node
|
||||
if new_node.id != node.id:
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Delete node and all edges
|
||||
graph.delete_node(node_path)
|
||||
|
||||
# Create new edges for each input and output
|
||||
for graph,_,edge in input_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = new_node.id if '.' not in edge[1].node_id else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
|
||||
graph.add_edge((edge[0], EdgeConnection(node_id = new_graph_node_path, field = edge[1].field)))
|
||||
|
||||
for graph,_,edge in output_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = new_node.id if '.' not in edge[0].node_id else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
|
||||
graph.add_edge((EdgeConnection(node_id = new_graph_node_path, field = edge[0].field), edge[1]))
|
||||
|
||||
|
||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[tuple[EdgeConnection,EdgeConnection]]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
|
||||
|
||||
|
||||
def _get_input_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e[1].node_id == node_path])
|
||||
|
||||
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix)
|
||||
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
def _get_output_edges(self, node_path: str, field: str) -> list[tuple[EdgeConnection,EdgeConnection]]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2][0].field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
|
||||
|
||||
|
||||
def _get_output_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e[0].node_id == node_path])
|
||||
|
||||
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix)
|
||||
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
def _is_iterator_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, 'collection')])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, 'item')])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
if new_output is not None:
|
||||
outputs.append(new_output)
|
||||
|
||||
# Only one input is allowed for iterators
|
||||
if len(inputs) > 1:
|
||||
return False
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
return False
|
||||
|
||||
# Validate that all outputs match the input type
|
||||
input_field_item_type = get_args(input_field)[0]
|
||||
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_collector_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, 'item')])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, 'collection')])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
if new_output is not None:
|
||||
outputs.append(new_output)
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
|
||||
# Validate that all inputs are derived from or match a single type
|
||||
input_field_types = set([t for input_field in input_fields for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) if t != NoneType]) # Get unique types
|
||||
type_tree = nx.DiGraph()
|
||||
type_tree.add_nodes_from(input_field_types)
|
||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||
type_degrees = type_tree.in_degree(type_tree.nodes)
|
||||
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
||||
return False # There is more than one root type
|
||||
|
||||
# Get the input root type
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all((get_origin(f) == list for f in output_fields)):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def nx_graph(self) -> nx.DiGraph:
|
||||
"""Returns a NetworkX DiGraph representing the layout of this graph"""
|
||||
# TODO: Cache this?
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
||||
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||
g = nx_graph or nx.DiGraph()
|
||||
|
||||
# Add all nodes from this graph except graph/iteration nodes
|
||||
g.add_nodes_from([self._get_node_path(n.id, prefix) for n in self.nodes.values() if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)])
|
||||
|
||||
# Expand graph nodes
|
||||
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
return g
|
||||
|
||||
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
id: str = Field(description="The id of the execution state", default_factory=uuid.uuid4)
|
||||
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
# The graph of materialized nodes
|
||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes", default_factory=Graph)
|
||||
|
||||
# Nodes that have been executed
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
|
||||
executed_history: list[str] = Field(description="The list of node ids that have been executed, in order of execution", default_factory=list)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
|
||||
# Map of prepared/executed nodes to their original nodes
|
||||
prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
|
||||
|
||||
# Map of original nodes to prepared nodes
|
||||
source_prepared_mapping: dict[str, set[str]] = Field(description="The map of original graph nodes to prepared nodes", default_factory=dict)
|
||||
|
||||
def next(self) -> BaseInvocation | None:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
||||
# possibly with a timeout?
|
||||
|
||||
# If there are no prepared nodes, prepare some nodes
|
||||
next_node = self._get_next_node()
|
||||
if next_node is None:
|
||||
prepared_id = self._prepare()
|
||||
|
||||
# TODO: prepare multiple nodes at once?
|
||||
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
|
||||
# prepared_id = self._prepare()
|
||||
|
||||
if prepared_id is not None:
|
||||
next_node = self._get_next_node()
|
||||
|
||||
# Get values from edges
|
||||
if next_node is not None:
|
||||
self._prepare_inputs(next_node)
|
||||
|
||||
# If next is still none, there's no next node, return None
|
||||
return next_node
|
||||
|
||||
def complete(self, node_id: str, output: InvocationOutputsUnion):
|
||||
"""Marks a node as complete"""
|
||||
|
||||
if node_id not in self.execution_graph.nodes:
|
||||
return # TODO: log error?
|
||||
|
||||
# Mark node as executed
|
||||
self.executed.add(node_id)
|
||||
self.results[node_id] = output
|
||||
|
||||
# Check if source node is complete (all prepared nodes are complete)
|
||||
source_node = self.prepared_source_mapping[node_id]
|
||||
prepared_nodes = self.source_prepared_mapping[source_node]
|
||||
|
||||
if all([n in self.executed for n in prepared_nodes]):
|
||||
self.executed.add(source_node)
|
||||
self.executed_history.append(source_node)
|
||||
|
||||
def set_node_error(self, node_id: str, error: str):
|
||||
"""Marks a node as errored"""
|
||||
self.errors[node_id] = error
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
return len(self.errors) > 0
|
||||
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||
|
||||
node = self.graph.get_node(node_path)
|
||||
|
||||
self_iteration_count = -1
|
||||
|
||||
# If this is an iterator node, we must create a copy for each iteration
|
||||
if isinstance(node, IterateInvocation):
|
||||
# Get input collection edge (should error if there are no inputs)
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, 'collection')))
|
||||
input_collection_prepared_node_id = next(n[1] for n in iteration_node_map if n[0] == input_collection_edge[0].node_id)
|
||||
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
|
||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge[0].field)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
new_nodes = list()
|
||||
if self_iteration_count == 0:
|
||||
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
||||
return new_nodes
|
||||
|
||||
# Get all input edges
|
||||
input_edges = self.graph._get_input_edges(node_path)
|
||||
|
||||
# Create new edges for this iteration
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
new_edges = list()
|
||||
for edge in input_edges:
|
||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge[0].node_id):
|
||||
new_edge = (EdgeConnection(node_id = input_node_id, field = edge[0].field), EdgeConnection(node_id = '', field = edge[1].field))
|
||||
new_edges.append(new_edge)
|
||||
|
||||
# Create a new node (or one for each iteration of this iterator)
|
||||
for i in (range(self_iteration_count) if self_iteration_count > 0 else [-1]):
|
||||
# Create a new node
|
||||
new_node = copy.deepcopy(node)
|
||||
|
||||
# Create the node id (use a random uuid)
|
||||
new_node.id = str(uuid.uuid4())
|
||||
|
||||
# Set the iteration index for iteration invocations
|
||||
if isinstance(new_node, IterateInvocation):
|
||||
new_node.index = i
|
||||
|
||||
# Add to execution graph
|
||||
self.execution_graph.add_node(new_node)
|
||||
self.prepared_source_mapping[new_node.id] = node_path
|
||||
if node_path not in self.source_prepared_mapping:
|
||||
self.source_prepared_mapping[node_path] = set()
|
||||
self.source_prepared_mapping[node_path].add(new_node.id)
|
||||
|
||||
# Add new edges to execution graph
|
||||
for edge in new_edges:
|
||||
new_edge = (edge[0], EdgeConnection(node_id = new_node.id, field = edge[1].field))
|
||||
self.execution_graph.add_edge(new_edge)
|
||||
|
||||
new_nodes.append(new_node.id)
|
||||
|
||||
return new_nodes
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph()
|
||||
collectors = (n for n in self.graph.nodes if isinstance(self.graph.nodes[n], CollectInvocation))
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
return g
|
||||
|
||||
|
||||
def _get_node_iterators(self, node_id: str) -> list[str]:
|
||||
"""Gets iterators for a node"""
|
||||
g = self._iterator_graph()
|
||||
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.nodes[n], IterateInvocation)]
|
||||
return iterators
|
||||
|
||||
|
||||
def _prepare(self) -> Optional[str]:
|
||||
# Get flattened source graph
|
||||
g = self.graph.nx_graph_flat()
|
||||
|
||||
# Find next unprepared node where all source nodes are executed
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node_id = next((n for n in sorted_nodes if n not in self.source_prepared_mapping and all((e[0] in self.executed for e in g.in_edges(n)))), None)
|
||||
|
||||
if next_node_id == None:
|
||||
return None
|
||||
|
||||
# Get all parents of the next node
|
||||
next_node_parents = [e[0] for e in g.in_edges(next_node_id)]
|
||||
|
||||
# Create execution nodes
|
||||
next_node = self.graph.get_node(next_node_id)
|
||||
new_node_ids = list()
|
||||
if isinstance(next_node, CollectInvocation):
|
||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||
all_iteration_mappings = list(itertools.chain(*(((s,p) for p in self.source_prepared_mapping[s]) for s in next_node_parents)))
|
||||
#all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
||||
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
||||
if create_results is not None:
|
||||
new_node_ids.extend(create_results)
|
||||
else: # Iterators or normal nodes
|
||||
# Get all iterator combinations for this node
|
||||
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
||||
iterator_nodes = self._get_node_iterators(next_node_id)
|
||||
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
|
||||
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
|
||||
|
||||
# Select the correct prepared parents for each iteration
|
||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||
# TODO: Handle a node mapping to none
|
||||
eg = self.execution_graph.nx_graph_flat()
|
||||
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
|
||||
|
||||
# Create execution node for each iteration
|
||||
for iteration_mappings in prepared_parent_mappings:
|
||||
create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore
|
||||
if create_results is not None:
|
||||
new_node_ids.extend(create_results)
|
||||
|
||||
return next(iter(new_node_ids), None)
|
||||
|
||||
def _get_iteration_node(self, source_node_path: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str]) -> Optional[str]:
|
||||
"""Gets the prepared version of the specified source node that matches every iteration specified"""
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_path]
|
||||
if len(prepared_nodes) == 1:
|
||||
return next(iter(prepared_nodes))
|
||||
|
||||
# Check if the requested node is an iterator
|
||||
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
|
||||
if prepared_iterator is not None:
|
||||
return prepared_iterator
|
||||
|
||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
||||
|
||||
return next((n for n in prepared_nodes if all(pit for pit in parent_iterators if nx.has_path(execution_graph, pit[0], n))), None)
|
||||
|
||||
def _get_next_node(self) -> Optional[BaseInvocation]:
|
||||
g = self.execution_graph.nx_graph()
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
|
||||
if next_node is None:
|
||||
return None
|
||||
|
||||
return self.execution_graph.nodes[next_node]
|
||||
|
||||
def _prepare_inputs(self, node: BaseInvocation):
|
||||
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
|
||||
if isinstance(node, CollectInvocation):
|
||||
output_collection = [getattr(self.results[edge[0].node_id], edge[0].field) for edge in input_edges if edge[1].field == 'item']
|
||||
setattr(node, 'collection', output_collection)
|
||||
else:
|
||||
for edge in input_edges:
|
||||
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
|
||||
setattr(node, edge[1].field, output_value)
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
if edge[1].node_id in self.source_prepared_mapping:
|
||||
return False
|
||||
|
||||
# Otherwise, the edge is valid
|
||||
return True
|
||||
|
||||
def _is_node_updatable(self, node_id: str) -> bool:
|
||||
# The node is updatable as long as it hasn't been prepared or executed
|
||||
return node_id not in self.source_prepared_mapping
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
self.graph.add_node(node)
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be updated')
|
||||
self.graph.update_node(node_path, new_node)
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be deleted')
|
||||
self.graph.delete_node(node_path)
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to')
|
||||
self.graph.add_edge(edge)
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted')
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
104
invokeai/app/services/image_storage.py
Normal file
104
invokeai/app/services/image_storage.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
from PIL.Image import Image
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = 'results'
|
||||
INTERMEDIATE = 'intermediates'
|
||||
UPLOAD = 'uploads'
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
pass
|
||||
|
||||
def create_name(self, context_id: str, node_id: str) -> str:
|
||||
return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png'
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
|
||||
"""Stores images on disk"""
|
||||
__output_folder: str
|
||||
__pngWriter: PngWriter
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, Image]
|
||||
__max_cache_size: int
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
self.__pngWriter = PngWriter(output_folder)
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
image_subpath = os.path.join(image_type, image_name)
|
||||
self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer
|
||||
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
self.__set_cache(image_path, image)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: Image):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
del self.__cache[cache_id]
|
46
invokeai/app/services/invocation_queue.py
Normal file
46
invokeai/app/services/invocation_queue.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
|
||||
|
||||
# TODO: make this serializable
|
||||
class InvocationQueueItem:
|
||||
#session_id: str
|
||||
graph_execution_state_id: str
|
||||
invocation_id: str
|
||||
invoke_all: bool
|
||||
|
||||
def __init__(self,
|
||||
#session_id: str,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
invoke_all: bool = False):
|
||||
#self.session_id = session_id
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.invocation_id = invocation_id
|
||||
self.invoke_all = invoke_all
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
"""Abstract base class for all invocation queues"""
|
||||
@abstractmethod
|
||||
def get(self) -> InvocationQueueItem:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, item: InvocationQueueItem|None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MemoryInvocationQueue(InvocationQueueABC):
|
||||
__queue: Queue
|
||||
|
||||
def __init__(self):
|
||||
self.__queue = Queue()
|
||||
|
||||
def get(self) -> InvocationQueueItem:
|
||||
return self.__queue.get()
|
||||
|
||||
def put(self, item: InvocationQueueItem|None) -> None:
|
||||
self.__queue.put(item)
|
32
invokeai/app/services/invocation_services.py
Normal file
32
invokeai/app/services/invocation_services.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
from .image_storage import ImageStorageBase
|
||||
from .events import EventServiceBase
|
||||
from invokeai.backend import Generate
|
||||
|
||||
class InvocationServices():
|
||||
"""Services that can be used by invocations"""
|
||||
generate: Generate # TODO: wrap Generate, or split it up from model?
|
||||
events: EventServiceBase
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_execution_manager: ItemStorageABC['GraphExecutionState']
|
||||
processor: 'InvocationProcessorABC'
|
||||
|
||||
def __init__(self,
|
||||
generate: Generate,
|
||||
events: EventServiceBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC['GraphExecutionState'],
|
||||
processor: 'InvocationProcessorABC'
|
||||
):
|
||||
self.generate = generate
|
||||
self.events = events
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
90
invokeai/app/services/invoker.py
Normal file
90
invokeai/app/services/invoker.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC
|
||||
from threading import Event, Thread
|
||||
from .graph import Graph, GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_services import InvocationServices
|
||||
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
|
||||
|
||||
|
||||
class Invoker:
|
||||
"""The invoker, used to execute invocations"""
|
||||
|
||||
services: InvocationServices
|
||||
|
||||
def __init__(self,
|
||||
services: InvocationServices
|
||||
):
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
|
||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None:
|
||||
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
if not invocation:
|
||||
return None
|
||||
|
||||
# Save the execution state
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Queue the invocation
|
||||
print(f'queueing item {invocation.id}')
|
||||
self.services.queue.put(InvocationQueueItem(
|
||||
#session_id = session.id,
|
||||
graph_execution_state_id = graph_execution_state.id,
|
||||
invocation_id = invocation.id,
|
||||
invoke_all = invoke_all
|
||||
))
|
||||
|
||||
return invocation.id
|
||||
|
||||
|
||||
def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
|
||||
"""Creates a new execution state for the given graph"""
|
||||
new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
|
||||
def __start_service(self, service) -> None:
|
||||
# Call start() method on any services that have it
|
||||
start_op = getattr(service, 'start', None)
|
||||
if callable(start_op):
|
||||
start_op(self)
|
||||
|
||||
|
||||
def __stop_service(self, service) -> None:
|
||||
# Call stop() method on any services that have it
|
||||
stop_op = getattr(service, 'stop', None)
|
||||
if callable(stop_op):
|
||||
stop_op(self)
|
||||
|
||||
|
||||
def _start(self) -> None:
|
||||
"""Starts the invoker. This is called automatically when the invoker is created."""
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the invoker. A new invoker will have to be created to execute further."""
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC):
|
||||
pass
|
57
invokeai/app/services/item_storage.py
Normal file
57
invokeai/app/services/item_storage.py
Normal file
@ -0,0 +1,57 @@
|
||||
|
||||
from typing import Callable, TypeVar, Generic
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
class PaginatedResults(GenericModel, Generic[T]):
|
||||
"""Paginated results"""
|
||||
items: list[T] = Field(description = "Items")
|
||||
page: int = Field(description = "Current Page")
|
||||
pages: int = Field(description = "Total number of pages")
|
||||
per_page: int = Field(description = "Number of items per page")
|
||||
total: int = Field(description = "Total number of items in result")
|
||||
|
||||
|
||||
class ItemStorageABC(ABC, Generic[T]):
|
||||
_on_changed_callbacks: list[Callable[[T], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
"""Base item storage class"""
|
||||
@abstractmethod
|
||||
def get(self, item_id: str) -> T:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, item: T) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: T) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
95
invokeai/app/services/processor.py
Normal file
95
invokeai/app/services/processor.py
Normal file
@ -0,0 +1,95 @@
|
||||
from threading import Event, Thread
|
||||
import traceback
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
__invoker: Invoker
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self.__stop_event = Event()
|
||||
self.__invoker_thread = Thread(
|
||||
name = "invoker_processor",
|
||||
target = self.__process,
|
||||
kwargs = dict(stop_event = self.__stop_event)
|
||||
)
|
||||
self.__invoker_thread.daemon = True # TODO: probably better to just not use threads?
|
||||
self.__invoker_thread.start()
|
||||
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
if not queue_item: # Probably stopping
|
||||
continue
|
||||
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id)
|
||||
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id = graph_execution_state.id,
|
||||
invocation_id = invocation.id
|
||||
)
|
||||
|
||||
# Invoke
|
||||
try:
|
||||
outputs = invocation.invoke(InvocationContext(
|
||||
services = self.__invoker.services,
|
||||
graph_execution_state_id = graph_execution_state.id
|
||||
))
|
||||
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id = graph_execution_state.id,
|
||||
invocation_id = invocation.id,
|
||||
result = outputs.dict()
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
graph_execution_state.set_node_error(invocation.id, error)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id = graph_execution_state.id,
|
||||
invocation_id = invocation.id,
|
||||
error = error
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all = True)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
... # Log something?
|
119
invokeai/app/services/sqlite.py
Normal file
119
invokeai/app/services/sqlite.py
Normal file
@ -0,0 +1,119 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Generic, TypeVar, Union, get_args
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from .item_storage import ItemStorageABC, PaginatedResults
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
sqlite_memory = ':memory:'
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_filename: str
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
item TEXT,
|
||||
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
|
||||
self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Union[T, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page))
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''')
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](
|
||||
items = items,
|
||||
page = page,
|
||||
pages = pageCount,
|
||||
per_page = per_page,
|
||||
total = count
|
||||
)
|
||||
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page))
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',))
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](
|
||||
items = items,
|
||||
page = page,
|
||||
pages = pageCount,
|
||||
per_page = per_page,
|
||||
total = count
|
||||
)
|
Reference in New Issue
Block a user