all files migrated; tweaks needed

This commit is contained in:
Lincoln Stein
2023-03-03 00:02:15 -05:00
parent 3f0b0f3250
commit 6a990565ff
496 changed files with 276 additions and 934 deletions

@@ -0,0 +1,80 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from argparse import Namespace
import os
from ..services.processor import DefaultInvocationProcessor
from ..services.graph import GraphExecutionState
from ..services.sqlite import SqliteItemStorage
from ...globals import Globals
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.generate_initializer import get_generate
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?
def check_internet()->bool:
'''
Return true if the internet is reachable.
It does this by pinging huggingface.co.
'''
import urllib.request
host = 'http://huggingface.co'
try:
urllib.request.urlopen(host,timeout=1)
return True
except Exception:
return False
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker = None
@staticmethod
def initialize(
args,
config,
event_handler_id: int
):
Globals.try_patchmatch = args.patchmatch
Globals.always_use_cpu = args.always_use_cpu
Globals.internet_available = args.internet_available and check_internet()
Globals.disable_xformers = not args.xformers
Globals.ckpt_convert = args.ckpt_convert
# TODO: Use a logger
print(f'>> Internet connectivity is {Globals.internet_available}')
generate = get_generate(args, config)
events = FastAPIEventService(event_handler_id)
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs'))
images = DiskImageStorage(output_folder)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db')
services = InvocationServices(
generate = generate,
events = events,
images = images,
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
ApiDependencies.invoker = Invoker(services)
@staticmethod
def shutdown():
if ApiDependencies.invoker:
ApiDependencies.invoker.stop()

@@ -0,0 +1,54 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events import EventServiceBase
import threading
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
self.__task = asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event)) # keep a reference so the task is not garbage-collected
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put(dict(
event_name = event_name,
payload = payload
))
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block = False)
if not event: # Probably stopping
continue
dispatch(
event.get('event_name'),
payload = event.get('payload'),
middleware_id = self.event_handler_id)
except Empty:
await asyncio.sleep(0.001)
except asyncio.CancelledError as e:
raise e # Raise a proper error
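One caveat for reuse: asyncio.create_task in __init__ requires an already-running event loop, so the service must be constructed from async context (the startup hook in api_app.py satisfies this). A minimal lifecycle sketch, assuming the EventHandlerASGIMiddleware has already been registered under the same handler id:

import asyncio

async def demo():
    # Construct on the loop so __init__ can schedule the queue-draining task
    service = FastAPIEventService(event_handler_id = 1234)
    # dispatch() is safe to call from any thread; it only enqueues
    service.dispatch('invocation_started', payload = {'invocation_id': '0'})
    await asyncio.sleep(0.01)  # give the drain task a chance to run
    service.stop()

asyncio.run(demo())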

@@ -0,0 +1,57 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
from fastapi import Path, UploadFile, Request
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse, Response
from PIL import Image
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(
prefix = '/v1/images',
tags = ['images']
)
@images_router.get('/{image_type}/{image_name}',
operation_id = 'get_image'
)
async def get_image(
image_type: ImageType = Path(description = "The type of image to get"),
image_name: str = Path(description = "The name of the image to get")
):
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.post('/uploads/',
operation_id = 'upload_image',
responses = {
201: {'description': 'The image was uploaded successfully'},
404: {'description': 'Session not found'}
})
async def upload_image(
file: UploadFile,
request: Request
):
if not file.content_type.startswith('image'):
return Response(status_code = 415)
contents = await file.read()
try:
im = Image.open(io.BytesIO(contents))
except Exception:
# Error opening the image
return Response(status_code = 415)
filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png'
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
return Response(
status_code=201,
headers = {
'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename)
}
)
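For reference, a client-side round trip against these two endpoints might look like the sketch below (host and port are taken from api_app.py; the file name is illustrative):

import requests

base = 'http://localhost:9090/api/v1/images'

# upload_image: the form field must be named 'file' and carry an image/* content type
with open('photo.png', 'rb') as f:
    resp = requests.post(f'{base}/uploads/', files = {'file': ('photo.png', f, 'image/png')})
assert resp.status_code == 201

# The Location header points back at get_image for the stored upload
image_bytes = requests.get(resp.headers['Location']).content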

@@ -0,0 +1,232 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import List, Optional, Union, Annotated
from fastapi import Query, Path, Body
from fastapi.routing import APIRouter
from fastapi.responses import Response
from pydantic.fields import Field
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...invocations import *
session_router = APIRouter(
prefix = '/v1/sessions',
tags = ['sessions']
)
@session_router.post('/',
operation_id = 'create_session',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid json'}
})
async def create_session(
graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with")
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
return session
@session_router.get('/',
operation_id = 'list_sessions',
responses = {
200: {"model": PaginatedResults[GraphExecutionState]}
})
async def list_sessions(
page: int = Query(default = 0, description = "The page of results to get"),
per_page: int = Query(default = 10, description = "The number of results per page"),
query: str = Query(default = '', description = "The query string to search for")
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if query == '':
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
return result
@session_router.get('/{session_id}',
operation_id = 'get_session',
responses = {
200: {"model": GraphExecutionState},
404: {'description': 'Session not found'}
})
async def get_session(
session_id: str = Path(description = "The id of the session to get")
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
else:
return session
@session_router.post('/{session_id}/nodes',
operation_id = 'add_node',
responses = {
200: {"model": str},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_node(
session_id: str = Path(description = "The id of the session"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_node(node)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session.id
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/nodes/{node_path}',
operation_id = 'update_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def update_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node in the graph"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.update_node(node_path, node)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.delete('/{session_id}/nodes/{node_path}',
operation_id = 'delete_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node to delete")
) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.delete_node(node_path)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.post('/{session_id}/edges',
operation_id = 'add_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_edge(
session_id: str = Path(description = "The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
operation_id = 'delete_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_edge(
session_id: str = Path(description = "The id of the session"),
from_node_id: str = Path(description = "The id of the node the edge is coming from"),
from_field: str = Path(description = "The field of the node the edge is coming from"),
to_node_id: str = Path(description = "The id of the node the edge is going to"),
to_field: str = Path(description = "The field of the node the edge is going to")
) -> GraphExecutionState:
"""Deletes an edge from the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/invoke',
operation_id = 'invoke_session',
responses = {
200: {"model": None},
202: {'description': 'The invocation is queued'},
400: {'description': 'The session has no invocations ready to invoke'},
404: {'description': 'Session not found'}
})
async def invoke_session(
session_id: str = Path(description = "The id of the session to invoke"),
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
) -> None:
"""Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
if session.is_complete():
return Response(status_code = 400)
ApiDependencies.invoker.invoke(session, invoke_all = all)
return Response(status_code=202)
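Taken together, these endpoints support a create -> add nodes/edges -> invoke workflow. A client-side sketch (host and port from api_app.py; the txt2img payload is illustrative):

import requests

base = 'http://localhost:9090/api/v1/sessions'

session = requests.post(f'{base}/').json()  # create_session with no initial graph
sid = session['id']

node = {'type': 'txt2img', 'id': '0', 'prompt': 'a house'}
requests.post(f'{base}/{sid}/nodes', json = node)  # add_node returns the session id

requests.put(f'{base}/{sid}/invoke', params = {'all': 'true'})  # 202: queued
# Progress and completion events arrive over Socket.IO; subscribe with {'session': sid}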

@@ -0,0 +1,36 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import FastAPI
from fastapi_socketio import SocketManager
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from ..services.events import EventServiceBase
class SocketIO:
__sio: SocketManager
def __init__(self, app: FastAPI):
self.__sio = SocketManager(app = app)
self.__sio.on('subscribe', handler=self._handle_sub)
self.__sio.on('unsubscribe', handler=self._handle_unsub)
local_handler.register(
event_name = EventServiceBase.session_event,
_func=self._handle_session_event
)
async def _handle_session_event(self, event: Event):
await self.__sio.emit(
event = event[1]['event'],
data = event[1]['data'],
room = event[1]['data']['graph_execution_state_id']
)
async def _handle_sub(self, sid, data, *args, **kwargs):
if 'session' in data:
self.__sio.enter_room(sid, data['session'])
# @app.sio.on('unsubscribe')
async def _handle_unsub(self, sid, data, *args, **kwargs):
if 'session' in data:
self.__sio.leave_room(sid, data['session'])

invokeai/app/api_app.py Normal file
@@ -0,0 +1,164 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from inspect import signature
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from fastapi.staticfiles import StaticFiles
from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.handlers.local import local_handler
from fastapi.middleware.cors import CORSMiddleware
from pydantic.schema import schema
import uvicorn
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .api.routers import images, sessions
from .api.dependencies import ApiDependencies
from ..args import Args
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
title = "Invoke AI",
docs_url = None,
redoc_url = None
)
# Add event handler
event_handler_id: int = id(app)
app.add_middleware(
EventHandlerASGIMiddleware,
handlers = [local_handler], # TODO: consider doing this in services to support different configurations
middleware_id = event_handler_id)
# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
socket_io = SocketIO(app)
config = {}
# Add startup event to load dependencies
@app.on_event('startup')
async def startup_event():
args = Args()
config = args.parse_args()
ApiDependencies.initialize(
args = args,
config = config,
event_handler_id = event_handler_id
)
# Shut down threads
@app.on_event('shutdown')
async def shutdown_event():
ApiDependencies.shutdown()
# Include all routers
# TODO: REMOVE
# app.include_router(
# invocation.invocation_router,
# prefix = '/api')
app.include_router(
sessions.session_router,
prefix = '/api'
)
app.include_router(
images.images_router,
prefix = '/api'
)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title = app.title,
description = "An API for invoking AI image operations",
version = "1.0.0",
routes = app.routes
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = dict()
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
for schema_key, output_schema in output_schemas['definitions'].items():
openapi_schema["components"]["schemas"][schema_key] = output_schema
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema['title']
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__
output_type = signature(invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' }
if 'additionalProperties' not in invoker_schema:
invoker_schema['additionalProperties'] = {}
invoker_schema['additionalProperties']['outputs'] = outputs_ref
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
# Override API doc favicons
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
openapi_url=app.openapi_url,
title=app.title,
swagger_favicon_url="/static/favicon.ico"
)
@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
return get_redoc_html(
openapi_url=app.openapi_url,
title=app.title,
redoc_favicon_url="/static/favicon.ico"
)
def invoke_api():
# Start our own event loop for eventing usage
# TODO: determine if there's a better way to do this
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app = app,
host = "0.0.0.0",
port = 9090,
loop = loop)
# Use access_log to turn off logging
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

invokeai/app/cli_app.py Normal file
@@ -0,0 +1,315 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse
import shlex
import os
import time
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
from pydantic import BaseModel
from pydantic.fields import Field
from .services.processor import DefaultInvocationProcessor
from .services.graph import EdgeConnection, GraphExecutionState
from .services.sqlite import SqliteItemStorage
from .invocations.image import ImageField
from .services.generate_initializer import get_generate
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .invocations.baseinvocation import BaseInvocation
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .invocations import *
from ..args import Args
from .services.events import EventServiceBase
class InvocationCommand(BaseModel):
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
class InvalidArgs(Exception):
pass
def get_invocation_parser() -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
def exit(*args, **kwargs):
raise InvalidArgs
parser.exit = exit
subparsers = parser.add_subparsers(dest='type')
invocation_parsers = dict()
# Add history parser
history_parser = subparsers.add_parser('history', help="Shows the invocation history")
history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")
# Add default parser
default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
default_parser.add_argument('input', type=str, help="The input field")
default_parser.add_argument('value', help="The default value")
default_parser = subparsers.add_parser('reset_default', help="Resets a default value")
default_parser.add_argument('input', type=str, help="The input field")
# Create subparsers for each invocation
invocations = BaseInvocation.get_all_subclasses()
for invocation in invocations:
hints = get_type_hints(invocation)
cmd_name = get_args(hints['type'])[0]
command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
invocation_parsers[cmd_name] = command_parser
# Add linking capability
command_parser.add_argument('--link', '-l', action='append', nargs=3,
help="A link in the format 'source_field source_node dest_field'. source_node can be relative to history (e.g. -1)")
command_parser.add_argument('--link_node', '-ln', action='append',
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)")
# Convert all fields to arguments
fields = invocation.__fields__
for name, field in fields.items():
if name in ['id', 'type']:
continue
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default,
choices = allowed_values,
help=field.field_info.description
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
help=field.field_info.description
)
return parser
def get_invocation_command(invocation) -> str:
fields = invocation.__fields__.items()
type_hints = get_type_hints(type(invocation))
command = [invocation.type]
for name,field in fields:
if name in ['id', 'type']:
continue
# TODO: add links
# Skip image fields when serializing command
type_hint = type_hints.get(name) or None
if type_hint is ImageField or ImageField in get_args(type_hint):
continue
field_value = getattr(invocation, name)
field_default = field.default
if field_value != field_default:
if type_hint is str or str in get_args(type_hint):
command.append(f'--{name} "{field_value}"')
else:
command.append(f'--{name} {field_value}')
return ' '.join(command)
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
"""Generates all possible edges between two invocations"""
atype = type(a)
btype = type(b)
aoutputtype = atype.get_output_type()
afields = get_type_hints(aoutputtype)
bfields = get_type_hints(btype)
matching_fields = set(afields.keys()).intersection(bfields.keys())
# Remove invalid fields
invalid_fields = set(['type', 'id'])
matching_fields = matching_fields.difference(invalid_fields)
edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields]
return edges
def invoke_cli():
args = Args()
config = args.parse_args()
generate = get_generate(args, config)
# NOTE: load model on first use, uncomment to load at startup
# TODO: Make this a config option?
#generate.load_model()
events = EventServiceBase()
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db')
services = InvocationServices(
generate = generate,
events = events,
images = DiskImageStorage(output_folder),
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()
parser = get_invocation_parser()
# Uncomment to print out previous sessions at startup
# print(services.session_manager.list())
# Defaults storage
defaults: Dict[str, Any] = dict()
while True:
try:
cmd_input = input("> ")
except KeyboardInterrupt:
# Ctrl-c exits
break
if cmd_input in ['exit','q']:
break
if cmd_input in ['--help','help','h','?']:
parser.print_help()
continue
try:
# Refresh the state of the session
session = invoker.services.graph_execution_manager.get(session.id)
history = list(get_graph_execution_history(session))
# Split the command for piping
cmds = cmd_input.split('|')
start_id = len(history)
current_id = start_id
new_invocations = list()
for cmd in cmds:
if cmd is None or cmd.strip() == '':
raise InvalidArgs('Empty command')
# Parse args to create invocation
args = vars(parser.parse_args(shlex.split(cmd.strip())))
# Check for special commands
# TODO: These might be better as Pydantic models, similar to the invocations
if args['type'] == 'history':
history_count = args['count'] or 5
for i in range(min(history_count, len(history))):
entry_id = history[-1 - i]
entry = session.graph.get_node(entry_id)
print(f'{entry_id}: {get_invocation_command(entry.invocation)}')
continue
if args['type'] == 'reset_default':
if args['input'] in defaults:
del defaults[args['input']]
continue
if args['type'] == 'default':
field = args['input']
field_value = args['value']
defaults[field] = field_value
continue
# Override defaults
for field_name,field_default in defaults.items():
if field_name in args:
args[field_name] = field_default
# Parse invocation
args['id'] = current_id
command = InvocationCommand(invocation = args)
# Pipe previous command output (if there was a previous command)
edges = []
if len(history) > 0 or current_id != start_id:
from_id = history[0] if current_id == start_id else str(current_id - 1)
from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id)
matching_edges = generate_matching_edges(from_node, command.invocation)
edges.extend(matching_edges)
# Parse provided links
if 'link_node' in args and args['link_node']:
for link in args['link_node']:
link_node = session.graph.get_node(link)
matching_edges = generate_matching_edges(link_node, command.invocation)
edges.extend(matching_edges)
if 'link' in args and args['link']:
for link in args['link']:
edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2])))
new_invocations.append((command.invocation, edges))
current_id = current_id + 1
# Command line was parsed successfully
# Add the invocations to the session
for invocation in new_invocations:
session.add_node(invocation[0])
for edge in invocation[1]:
session.add_edge(edge)
# Execute all available invocations
invoker.invoke(session, invoke_all = True)
while not session.is_complete():
# Wait some time
session = invoker.services.graph_execution_manager.get(session.id)
time.sleep(0.1)
# Print any errors
if session.has_error():
for n in session.errors:
print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')
# Start a new session
print("Creating a new session")
session = invoker.create_execution_state()
except InvalidArgs:
print('Invalid command, use "help" to list commands')
continue
except SystemExit:
continue
invoker.stop()
if __name__ == "__main__":
invoke_cli()
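To make the auto-linking concrete: piping txt2img into upscale creates two nodes, and generate_matching_edges connects every output field of the first whose name matches an input field of the second. A sketch (module paths assumed from this commit's layout):

from invokeai.app.cli_app import generate_matching_edges
from invokeai.app.invocations.generate import TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation

a = TextToImageInvocation(id = '0', prompt = 'a house')
b = UpscaleInvocation(id = '1')

# ImageOutput exposes 'image' and UpscaleInvocation accepts 'image', so exactly one
# edge is generated: (node 0, field 'image') -> (node 1, field 'image')
print(generate_matching_edges(a, b))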

@@ -0,0 +1,8 @@
import os
__all__ = []
dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname):
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
__all__.append(f[:-3])
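The net effect is that 'from .invocations import *' (as used by api_app.py and cli_app.py) imports every module in this package, so each BaseInvocation subclass is defined as a side effect and becomes discoverable through BaseInvocation.get_all_subclasses() without a manual registry.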

@@ -0,0 +1,74 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints
from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices
class InvocationContext:
services: InvocationServices
graph_execution_state_id: str
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
class BaseInvocationOutput(BaseModel):
"""Base class for all invocation outputs"""
# All outputs must include a type name like this:
# type: Literal['your_output_name']
@classmethod
def get_all_subclasses_tuple(cls):
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
next = toprocess.pop(0)
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return tuple(subclasses)
class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
"""
# All invocations must include a type name like this:
# type: Literal['your_output_name']
@classmethod
def get_all_subclasses(cls):
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
next = toprocess.pop(0)
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return subclasses
@classmethod
def get_invocations(cls):
return tuple(BaseInvocation.get_all_subclasses())
@classmethod
def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
@classmethod
def get_output_type(cls):
return signature(cls.invoke).return_annotation
@abstractmethod
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass
id: str = Field(description="The id of this node. Must be unique among all nodes.")
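Following the comments above, a concrete invocation supplies a 'type' literal (the discriminator used in the API's node union), pydantic fields for its inputs, and an invoke override. A minimal hypothetical example, reusing this commit's PromptOutput:

from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, InvocationContext
from .prompt import PromptOutput  # module path assumed from this commit's layout

class UpcaseInvocation(BaseInvocation):
    """Upper-cases a prompt string (illustrative only)."""
    type: Literal['upcase'] = 'upcase'

    # Inputs
    text: str = Field(default = '', description = "The text to upper-case")

    def invoke(self, context: InvocationContext) -> PromptOutput:
        return PromptOutput(prompt = self.text.upper())

Dropping a module like this into the invocations package is enough: the package __init__ auto-imports it, and get_all_subclasses() picks it up.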

@@ -0,0 +1,42 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import numpy
from pydantic import Field
from PIL import Image, ImageOps
import cv2 as cv
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""
type: Literal['cv_inpaint'] = 'cv_inpaint'
# Inputs
image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask))
# Inpaint
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
# Convert back to Pillow
# TODO: consider making a utility function
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

@@ -0,0 +1,160 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from pydantic import Field
from PIL import Image
from skimage.exposure.histogram_matching import match_histograms
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
SAMPLER_NAME_VALUES = Literal["ddim","plms","k_lms","k_dpm_2","k_dpm_2_a","k_euler","k_euler_a","k_heun"]
# Text to image
class TextToImageInvocation(BaseInvocation):
"""Generates an image using text2img."""
type: Literal['txt2img'] = 'txt2img'
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image")
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt")
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams")
model: str = Field(default='', description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation")
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(self, context: InvocationContext, sample: Any = None, step: int = 0) -> None:
context.services.events.emit_generator_progress(
context.graph_execution_state_id, self.id, step, float(step) / float(self.steps)
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step = 0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == '':
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt = self.prompt,
step_callback = step_callback,
**self.dict(exclude = {'prompt'}) # Shorthand for passing all of the parameters above manually
)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class ImageToImageInvocation(TextToImageInvocation):
"""Generates an image using img2img."""
type: Literal['img2img'] = 'img2img'
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
fit: bool = Field(default=True, description="Whether or not the result should be fit to the aspect ratio of the input image")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name)
mask = None
def step_callback(sample, step = 0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == '':
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt = self.prompt,
init_img = image,
init_mask = mask,
step_callback = step_callback,
**self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually
)
result_image = results[0][0]
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
type: Literal['inpaint'] = 'inpaint'
# Inputs
mask: Union[ImageField,None] = Field(description="The mask")
inpaint_replace: float = Field(default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name)
mask = None if self.mask is None else context.services.images.get(self.mask.image_type, self.mask.image_name)
def step_callback(sample, step = 0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == '':
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt = self.prompt,
init_img = image,
init_mask = mask,
step_callback = step_callback,
**self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually
)
result_image = results[0][0]
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

@@ -0,0 +1,219 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from pydantic import Field, BaseModel
from PIL import Image, ImageOps, ImageFilter
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_type: str = Field(default=ImageType.RESULT, description="The type of the image")
image_name: Optional[str] = Field(default=None, description="The name of the image")
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal['image'] = 'image'
image: ImageField = Field(default=None, description="The output image")
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
type: Literal['mask'] = 'mask'
mask: ImageField = Field(default=None, description="The output mask")
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
type: Literal['load_image'] = 'load_image'
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image = ImageField(image_type = self.image_type, image_name = self.image_name)
)
class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline."""
type: Literal['show_image'] = 'show_image'
# Inputs
image: ImageField = Field(default=None, description="The image to show")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
if image:
image.show()
# TODO: how to handle failure?
return ImageOutput(
image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name)
)
class CropImageInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
type: Literal['crop'] = 'crop'
# Inputs
image: ImageField = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0))
image_crop.paste(image, (-self.x, -self.y))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class PasteImageInvocation(BaseInvocation):
"""Pastes an image into another image."""
type: Literal['paste'] = 'paste'
# Inputs
base_image: ImageField = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name)
image = context.services.images.get(self.image.image_type, self.image.image_name)
mask = None if self.mask is None else ImageOps.invert(context.services.images.get(self.mask.image_type, self.mask.image_name))
# TODO: probably shouldn't invert mask here... should user be required to do it?
min_x = min(0, self.x)
min_y = min(0, self.y)
max_x = max(base_image.width, image.width + self.x)
max_y = max(base_image.height, image.height + self.y)
new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0))
new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask)
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
type: Literal['tomask'] = 'tomask'
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_mask = image.split()[-1]
if self.invert:
image_mask = ImageOps.invert(image_mask)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, image_mask)
return MaskOutput(
mask = ImageField(image_type = image_type, image_name = image_name)
)
class BlurInvocation(BaseInvocation):
"""Blurs an image"""
type: Literal['blur'] = 'blur'
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius)
blur_image = image.filter(blur)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class LerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
type: Literal['lerp'] = 'lerp'
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.min # scale [0,1] up to [min,max]
lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class InverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
type: Literal['ilerp'] = 'ilerp'
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

@@ -0,0 +1,9 @@
from typing import Literal
from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
type: Literal['prompt'] = 'prompt'
prompt: str = Field(default=None, description="The output prompt")

@@ -0,0 +1,36 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
type: Literal['restore_face'] = 'restore_face'
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = None,
strength = self.strength, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

@@ -0,0 +1,38 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
# Inputs
image: Union[ImageField,None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description = "The upscale level")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

@@ -0,0 +1,93 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict
class EventServiceBase:
"""Basic event bus, to have an empty stand-in when not needed"""
session_event: str = 'session_event'
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self,
event_name: str,
payload: Dict) -> None:
self.dispatch(
event_name = EventServiceBase.session_event,
payload = dict(
event = event_name,
data = payload
)
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(self,
graph_execution_state_id: str,
invocation_id: str,
step: int,
percent: float
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
event_name = 'generator_progress',
payload = dict(
graph_execution_state_id = graph_execution_state_id,
invocation_id = invocation_id,
step = step,
percent = percent
)
)
def emit_invocation_complete(self,
graph_execution_state_id: str,
invocation_id: str,
result: Dict
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name = 'invocation_complete',
payload = dict(
graph_execution_state_id = graph_execution_state_id,
invocation_id = invocation_id,
result = result
)
)
def emit_invocation_error(self,
graph_execution_state_id: str,
invocation_id: str,
error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name = 'invocation_error',
payload = dict(
graph_execution_state_id = graph_execution_state_id,
invocation_id = invocation_id,
error = error
)
)
def emit_invocation_started(self,
graph_execution_state_id: str,
invocation_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name = 'invocation_started',
payload = dict(
graph_execution_state_id = graph_execution_state_id,
invocation_id = invocation_id
)
)
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
event_name = 'graph_execution_state_complete',
payload = dict(
graph_execution_state_id = graph_execution_state_id
)
)
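Every helper above wraps its payload in a common envelope via __emit_session_event, which is exactly what sockets.py unpacks (event[1]['event'] and event[1]['data']). For example, emit_generator_progress dispatches event_name 'session_event' with a payload shaped like this (values illustrative):

payload = {
    'event': 'generator_progress',
    'data': {
        'graph_execution_state_id': 'abc123',
        'invocation_id': '0',
        'step': 4,
        'percent': 0.4,
    },
}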

@@ -0,0 +1,231 @@
from argparse import Namespace
import os
import sys
import traceback
from invokeai.backend import ModelManager, Generate
from ...globals import Globals
import invokeai.version
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_generate(args, config) -> Generate:
if not args.conf:
config_file = os.path.join(Globals.root,'configs','models.yaml')
if not os.path.exists(config_file):
report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found."))
print(f'>> {invokeai.version.__app_name__}, version {invokeai.version.__version__}')
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
gfpgan,codeformer,esrgan = load_face_restoration(args)
# normalize the config directory relative to root
if not os.path.isabs(args.conf):
args.conf = os.path.normpath(os.path.join(Globals.root,args.conf))
if args.embeddings:
if not os.path.isabs(args.embedding_path):
embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path))
else:
embedding_path = args.embedding_path
else:
embedding_path = None
# migrate legacy models
ModelManager.migrate_models()
# load the infile as a list of lines
if args.infile:
try:
if os.path.isfile(args.infile):
infile = open(args.infile, 'r', encoding='utf-8')
elif args.infile == '-': # stdin
infile = sys.stdin
else:
raise FileNotFoundError(f'{args.infile} not found.')
except (FileNotFoundError, IOError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
# creating a Generate object:
try:
gen = Generate(
conf = args.conf,
model = args.model,
sampler_name = args.sampler_name,
embedding_path = embedding_path,
full_precision = args.full_precision,
precision = args.precision,
gfpgan = gfpgan,
codeformer = codeformer,
esrgan = esrgan,
free_gpu_mem = args.free_gpu_mem,
safety_checker = args.safety_checker,
max_loaded_models = args.max_loaded_models,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(args, e)
except (IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
if args.seamless:
print(">> changed to seamless tiling mode")
# preload the model
try:
gen.load_model()
except KeyError:
pass
except Exception as e:
report_model_error(args, e)
# try to autoconvert/import new .ckpt model files
if path := args.autoconvert:
gen.model_manager.autoconvert_weights(
conf_path=args.conf,
weights_directory=path,
)
return gen
def load_face_restoration(opt):
try:
gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return gfpgan,codeformer,esrgan
def report_model_error(opt:Namespace, e:Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
if yes_to_all:
print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
else:
response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
if response.startswith(('n', 'N')):
return
print('invokeai-configure is launching...\n')
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_args = sys.argv
sys.argv = [ 'invokeai-configure' ]
sys.argv.extend(root_dir)
sys.argv.extend(config)
if yes_to_all is not None:
for arg in yes_to_all.split():
sys.argv.append(arg)
from ldm.invoke.config import invokeai_configure
invokeai_configure.main()
# TODO: Figure out how to restart
# print('** InvokeAI will now restart')
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)
# Temporary initializer for Generate until we migrate off of it
def old_get_generate(args, config) -> Generate:
# TODO: Remove the need for globals
from invokeai.backend.globals import Globals
# alert - setting globals here
Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
Globals.try_patchmatch = args.patchmatch
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
gfpgan, codeformer, esrgan = None, None, None
try:
if config.restore or config.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if config.restore:
gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if config.esrgan:
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root,config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(os.path.join(Globals.root,config.embedding_path))
else:
embedding_path = config.embedding_path
else:
embedding_path = None
# TODO: lazy-initialize this by wrapping it
try:
generate = Generate(
conf = config.conf,
model = config.model,
sampler_name = config.sampler_name,
embedding_path = embedding_path,
full_precision = config.full_precision,
precision = config.precision,
gfpgan = gfpgan,
codeformer = codeformer,
esrgan = esrgan,
free_gpu_mem = config.free_gpu_mem,
safety_checker = config.safety_checker,
max_loaded_models = config.max_loaded_models,
)
except (FileNotFoundError, TypeError, AssertionError):
#emergency_model_reconfigure() # TODO?
sys.exit(-1)
except (IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
generate.free_gpu_mem = config.free_gpu_mem
return generate
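# A minimal usage sketch: build an argparse-style Namespace carrying the
# fields read above and hand it to get_generate(). Every value below is an
# illustrative assumption, not a documented default.
#
# from argparse import Namespace
# args = Namespace(
#     conf=None, model='stable-diffusion-1.5', sampler_name='k_euler_a',
#     embeddings=False, embedding_path=None, full_precision=False,
#     precision='auto', restore=False, esrgan=False, free_gpu_mem=False,
#     safety_checker=True, max_loaded_models=2, infile=None,
#     seamless=False, autoconvert=None,
# )
# gen = get_generate(args, config=None)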

View File

@ -0,0 +1,809 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import copy
import itertools
import traceback
from types import NoneType
import uuid
import networkx as nx
from pydantic import BaseModel, validator
from pydantic.fields import Field
from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints, Annotated
from .invocation_services import InvocationServices
from ..invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ..invocations import *
class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection")
field: str = Field(description="The field for this connection")
def __eq__(self, other):
return (isinstance(other, self.__class__) and
getattr(other, 'node_id', None) == self.node_id and
getattr(other, 'field', None) == self.field)
def __hash__(self):
return hash(f'{self.node_id}.{self.field}')
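# A quick illustration of the contract above: EdgeConnection hashes and
# compares by (node_id, field), so duplicate connections collapse in a set.
#
# a = EdgeConnection(node_id='n1', field='image')
# b = EdgeConnection(node_id='n1', field='image')
# assert a == b and len({a, b}) == 1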
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type())
node_output_field = node_outputs.get(field) or None
return node_output_field
def get_input_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
return False
if not to_type:
return False
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
if from_type and to_type:
# Ports are compatible
if (from_type == to_type or
from_type == Any or
to_type == Any or
Any in get_args(from_type) or
Any in get_args(to_type)):
return True
if from_type in get_args(to_type):
return True
if to_type in get_args(from_type):
return True
if not issubclass(from_type, to_type):
return False
else:
return False
return True
def are_connections_compatible(
from_node: BaseInvocation,
from_field: str,
to_node: BaseInvocation,
to_field: str) -> bool:
"""Determines if a connection between fields of two nodes is compatible."""
# TODO: handle iterators and collectors
from_node_field = get_output_field(from_node, from_field)
to_node_field = get_input_field(to_node, to_field)
return are_connection_types_compatible(from_node_field, to_node_field)
class NodeAlreadyInGraphError(Exception):
pass
class InvalidEdgeError(Exception):
pass
class NodeNotFoundError(Exception):
pass
class NodeAlreadyExecutedError(Exception):
pass
# TODO: Create and use an Empty output?
class GraphInvocationOutput(BaseInvocationOutput):
type: Literal['graph_output'] = 'graph_output'
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
type: Literal['graph'] = 'graph'
# TODO: figure out how to create a default here
graph: 'Graph' = Field(description="The graph to run", default=None)
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
type: Literal['iterate_output'] = 'iterate_output'
item: Any = Field(description="The item being iterated over")
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
type: Literal['iterate'] = 'iterate'
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
index: int = Field(description="The index, will be provided on executed iterators", default=0)
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
return IterateInvocationOutput(item = self.collection[self.index])
class CollectInvocationOutput(BaseInvocationOutput):
type: Literal['collect_output'] = 'collect_output'
collection: list[Any] = Field(description="The collection of input items")
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""
type: Literal['collect'] = 'collect'
item: Any = Field(description="The item to collect (all inputs must be of the same type)", default=None)
collection: list[Any] = Field(description="The collection, will be provided on execution", default_factory=list)
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
"""Invoke with provided services and return outputs."""
return CollectInvocationOutput(collection = copy.copy(self.collection))
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[tuple[EdgeConnection,EdgeConnection]] = Field(description="The connections between nodes and their fields in this graph", default_factory=list)
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
:raises NodeAlreadyInGraphError: the node is already present in the graph.
"""
if node.id in self.nodes:
raise NodeAlreadyInGraphError()
self.nodes[node.id] = node
def _get_graph_and_node(self, node_path: str) -> tuple['Graph', str]:
"""Returns the graph and node id for a node path."""
# Materialized graphs may have nodes at the top level
if node_path in self.nodes:
return (self, node_path)
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
if node_id not in self.nodes:
raise NodeNotFoundError(f'Node {node_path} not found in graph')
node = self.nodes[node_id]
if not isinstance(node, GraphInvocation):
# There's more node path left but this isn't a graph - failure
raise NodeNotFoundError('Node path terminated early at a non-graph node')
return node.graph._get_graph_and_node(node_path[node_path.index('.')+1:])
def delete_node(self, node_path: str) -> None:
"""Deletes a node from a graph"""
try:
graph, node_id = self._get_graph_and_node(node_path)
# Delete edges for this node
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
for edge_graph,_,edge in input_edges:
edge_graph.delete_edge(edge)
for edge_graph,_,edge in output_edges:
edge_graph.delete_edge(edge)
del graph.nodes[node_id]
except NodeNotFoundError:
pass # Ignore, node doesn't exist (should this throw?)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
"""Adds an edge to a graph
:raises InvalidEdgeError: the provided edge is invalid.
"""
if self._is_edge_valid(edge) and edge not in self.edges:
self.edges.append(edge)
else:
raise InvalidEdgeError()
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
"""Deletes an edge from a graph"""
try:
self.edges.remove(edge)
except ValueError: # list.remove raises ValueError, not KeyError, when the edge is absent
pass
def is_valid(self) -> bool:
"""Validates the graph."""
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
if not gn.graph.is_valid():
return False
# Validate all edges reference nodes in the graph
node_ids = set([e[0].node_id for e in self.edges]+[e[1].node_id for e in self.edges])
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
# Validate there are no cycles
g = self.nx_graph_flat()
if not nx.is_directed_acyclic_graph(g):
return False
# Validate all edge connections are valid
if not all((are_connections_compatible(
self.get_node(e[0].node_id), e[0].field,
self.get_node(e[1].node_id), e[1].field
) for e in self.edges)):
return False
# Validate all iterators
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
if not all((self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))):
return False
# Validate all collectors
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
if not all((self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))):
return False
return True
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
try:
from_node = self.get_node(edge[0].node_id)
to_node = self.get_node(edge[1].node_id)
except NodeNotFoundError:
return False
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge[0].node_id, edge[1].node_id)
if not nx.is_directed_acyclic_graph(g):
return False
# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge[0].field, to_node, edge[1].field):
return False
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge[1].field == 'collection':
if not self._is_iterator_connection_valid(edge[1].node_id, new_input = edge[0]):
return False
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge[0].field == 'item':
if not self._is_iterator_connection_valid(edge[0].node_id, new_output = edge[1]):
return False
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge[1].field == 'item':
if not self._is_collector_connection_valid(edge[1].node_id, new_input = edge[0]):
return False
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge[0].field == 'collection':
if not self._is_collector_connection_valid(edge[0].node_id, new_output = edge[1]):
return False
return True
def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph."""
try:
n = self.get_node(node_path)
if n is not None:
return True
else:
return False
except NodeNotFoundError:
return False
def get_node(self, node_path: str) -> InvocationsUnion:
"""Gets a node from the graph using a node path."""
# Materialized graphs may have nodes at the top level
graph, node_id = self._get_graph_and_node(node_path)
return graph.nodes[node_id]
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
return node_id if prefix is None or prefix == '' else f'{prefix}.{node_id}'
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
"""Updates a node in the graph."""
graph, node_id = self._get_graph_and_node(node_path)
node = graph.nodes[node_id]
# Ensure the node type matches the new node
if type(node) != type(new_node):
raise TypeError(f'Node {node_path} is type {type(node)} but new node is type {type(new_node)}')
# Ensure the new id is either the same or is not in the graph
prefix = None if '.' not in node_path else node_path[:node_path.rindex('.')]
new_path = self._get_node_path(new_node.id, prefix = prefix)
if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError(f'Node with id {new_node.id} already exists in graph')
# Set the new node in the graph
graph.nodes[new_node.id] = new_node
if new_node.id != node.id:
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
# Delete node and all edges
graph.delete_node(node_path)
# Create new edges for each input and output
for graph,_,edge in input_edges:
# Remove the graph prefix from the node path
new_graph_node_path = new_node.id if '.' not in edge[1].node_id else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
graph.add_edge((edge[0], EdgeConnection(node_id = new_graph_node_path, field = edge[1].field)))
for graph,_,edge in output_edges:
# Remove the graph prefix from the node path
new_graph_node_path = new_node.id if '.' not in edge[0].node_id else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
graph.add_edge((EdgeConnection(node_id = new_graph_node_path, field = edge[0].field), edge[1]))
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[tuple[EdgeConnection,EdgeConnection]]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
# Create full node paths for each edge
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
def _get_input_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e[1].node_id == node_path])
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix)
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path)
edges.extend(graph_edges)
return edges
def _get_output_edges(self, node_path: str, field: str) -> list[tuple[EdgeConnection,EdgeConnection]]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2][0].field == field)
# Create full node paths for each edge
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
def _get_output_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e[0].node_id == node_path])
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix)
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path)
edges.extend(graph_edges)
return edges
def _is_iterator_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, 'collection')])
outputs = list([e[1] for e in self._get_output_edges(node_path, 'item')])
if new_input is not None:
inputs.append(new_input)
if new_output is not None:
outputs.append(new_output)
# Only one input is allowed for iterators
if len(inputs) > 1:
return False
# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
# Input type must be a list
if get_origin(input_field) != list:
return False
# Validate that all outputs match the input type
input_field_item_type = get_args(input_field)[0]
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
return False
return True
def _is_collector_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, 'item')])
outputs = list([e[1] for e in self._get_output_edges(node_path, 'collection')])
if new_input is not None:
inputs.append(new_input)
if new_output is not None:
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
# Validate that all inputs are derived from or match a single type
input_field_types = set([t for input_field in input_fields for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) if t != NoneType]) # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return False # There is more than one root type
# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists
if not all((get_origin(f) == list for f in output_fields)):
return False
# Verify that all outputs match the input type (are a base class or the same class)
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
return False
return True
def nx_graph(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the layout of this graph"""
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph()
# Add all nodes from this graph except graph/iteration nodes
g.add_nodes_from([self._get_node_path(n.id, prefix) for n in self.nodes.values() if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)])
# Expand graph nodes
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
return g
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state", default_factory=uuid.uuid4)
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")
# The graph of materialized nodes
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes", default_factory=Graph)
# Nodes that have been executed
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
executed_history: list[str] = Field(description="The list of node ids that have been executed, in order of execution", default_factory=list)
# The results of executed nodes
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
# Map of prepared/executed nodes to their original nodes
prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
# Map of original nodes to prepared nodes
source_prepared_mapping: dict[str, set[str]] = Field(description="The map of original graph nodes to prepared nodes", default_factory=dict)
def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
# possibly with a timeout?
# If there are no prepared nodes, prepare some nodes
next_node = self._get_next_node()
if next_node is None:
prepared_id = self._prepare()
# TODO: prepare multiple nodes at once?
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
# prepared_id = self._prepare()
if prepared_id is not None:
next_node = self._get_next_node()
# Get values from edges
if next_node is not None:
self._prepare_inputs(next_node)
# If next is still none, there's no next node, return None
return next_node
def complete(self, node_id: str, output: InvocationOutputsUnion):
"""Marks a node as complete"""
if node_id not in self.execution_graph.nodes:
return # TODO: log error?
# Mark node as executed
self.executed.add(node_id)
self.results[node_id] = output
# Check if source node is complete (all prepared nodes are complete)
source_node = self.prepared_source_mapping[node_id]
prepared_nodes = self.source_prepared_mapping[source_node]
if all([n in self.executed for n in prepared_nodes]):
self.executed.add(source_node)
self.executed_history.append(source_node)
def set_node_error(self, node_id: str, error: str):
"""Marks a node as errored"""
self.errors[node_id] = error
def is_complete(self) -> bool:
"""Returns true if the graph is complete"""
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
def has_error(self) -> bool:
"""Returns true if the graph has any errors"""
return len(self.errors) > 0
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_path)
self_iteration_count = -1
# If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, 'collection')))
input_collection_prepared_node_id = next(n[1] for n in iteration_node_map if n[0] == input_collection_edge[0].node_id)
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge[0].field)
self_iteration_count = len(input_collection)
new_nodes = list()
if self_iteration_count == 0:
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
return new_nodes
# Get all input edges
input_edges = self.graph._get_input_edges(node_path)
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
new_edges = list()
for edge in input_edges:
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge[0].node_id):
new_edge = (EdgeConnection(node_id = input_node_id, field = edge[0].field), EdgeConnection(node_id = '', field = edge[1].field))
new_edges.append(new_edge)
# Create a new node (or one for each iteration of this iterator)
for i in (range(self_iteration_count) if self_iteration_count > 0 else [-1]):
# Create a new node
new_node = copy.deepcopy(node)
# Create the node id (use a random uuid)
new_node.id = str(uuid.uuid4())
# Set the iteration index for iteration invocations
if isinstance(new_node, IterateInvocation):
new_node.index = i
# Add to execution graph
self.execution_graph.add_node(new_node)
self.prepared_source_mapping[new_node.id] = node_path
if node_path not in self.source_prepared_mapping:
self.source_prepared_mapping[node_path] = set()
self.source_prepared_mapping[node_path].add(new_node.id)
# Add new edges to execution graph
for edge in new_edges:
new_edge = (edge[0], EdgeConnection(node_id = new_node.id, field = edge[1].field))
self.execution_graph.add_edge(new_edge)
new_nodes.append(new_node.id)
return new_nodes
def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph()
collectors = (n for n in self.graph.nodes if isinstance(self.graph.nodes[n], CollectInvocation))
for c in collectors:
g.remove_edges_from(list(g.in_edges(c)))
return g
def _get_node_iterators(self, node_id: str) -> list[str]:
"""Gets iterators for a node"""
g = self._iterator_graph()
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.nodes[n], IterateInvocation)]
return iterators
def _prepare(self) -> Optional[str]:
# Get flattened source graph
g = self.graph.nx_graph_flat()
# Find next unprepared node where all source nodes are executed
sorted_nodes = nx.topological_sort(g)
next_node_id = next((n for n in sorted_nodes if n not in self.source_prepared_mapping and all((e[0] in self.executed for e in g.in_edges(n)))), None)
if next_node_id is None:
return None
# Get all parents of the next node
next_node_parents = [e[0] for e in g.in_edges(next_node_id)]
# Create execution nodes
next_node = self.graph.get_node(next_node_id)
new_node_ids = list()
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(itertools.chain(*(((s,p) for p in self.source_prepared_mapping[s]) for s in next_node_parents)))
#all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)
else: # Iterators or normal nodes
# Get all iterator combinations for this node
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
iterator_nodes = self._get_node_iterators(next_node_id)
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
# Select the correct prepared parents for each iteration
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
# Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings:
create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore
if create_results is not None:
new_node_ids.extend(create_results)
return next(iter(new_node_ids), None)
def _get_iteration_node(self, source_node_path: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str]) -> Optional[str]:
"""Gets the prepared version of the specified source node that matches every iteration specified"""
prepared_nodes = self.source_prepared_mapping[source_node_path]
if len(prepared_nodes) == 1:
return next(iter(prepared_nodes))
# Check if the requested node is an iterator
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
if prepared_iterator is not None:
return prepared_iterator
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
return next((n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), None) # every parent iterator must reach n in the execution graph
def _get_next_node(self) -> Optional[BaseInvocation]:
g = self.execution_graph.nx_graph()
sorted_nodes = nx.topological_sort(g)
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
if next_node is None:
return None
return self.execution_graph.nodes[next_node]
def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
if isinstance(node, CollectInvocation):
output_collection = [getattr(self.results[edge[0].node_id], edge[0].field) for edge in input_edges if edge[1].field == 'item']
setattr(node, 'collection', output_collection)
else:
for edge in input_edges:
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
setattr(node, edge[1].field, output_value)
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
if not self.graph._is_edge_valid(edge): # delegate to the graph; calling self._is_edge_valid here would recurse forever
return False
# Invalid if destination has already been prepared or executed
if edge[1].node_id in self.source_prepared_mapping:
return False
# Otherwise, the edge is valid
return True
def _is_node_updatable(self, node_id: str) -> bool:
# The node is updatable as long as it hasn't been prepared or executed
return node_id not in self.source_prepared_mapping
def add_node(self, node: BaseInvocation) -> None:
self.graph.add_node(node)
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_path):
raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be updated')
self.graph.update_node(node_path, new_node)
def delete_node(self, node_path: str) -> None:
if not self._is_node_updatable(node_path):
raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be deleted')
self.graph.delete_node(node_path)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to')
self.graph.add_edge(edge)
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted')
self.graph.delete_edge(edge)
GraphInvocation.update_forward_refs()
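# A minimal sketch of driving a session by hand: with an empty Graph there
# is nothing to prepare, so next() returns None and the state is complete.
#
# state = GraphExecutionState(graph=Graph())
# assert state.next() is None
# assert state.is_complete()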

View File

@ -0,0 +1,104 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from enum import Enum
import datetime
import os
from pathlib import Path
from queue import Queue
from typing import Dict
from PIL.Image import Image
from invokeai.backend.image_util import PngWriter
class ImageType(str, Enum):
RESULT = 'results'
INTERMEDIATE = 'intermediates'
UPLOAD = 'uploads'
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
pass
def create_name(self, context_id: str, node_id: str) -> str:
return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png'
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__pngWriter: PngWriter
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
self.__pngWriter = PngWriter(output_folder)
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_type: ImageType, image_name: str) -> str:
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)
if os.path.exists(image_path):
os.remove(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
def __get_cache(self, image_name: str) -> Image:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):
if image_name not in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
del self.__cache[cache_id]
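# A rough usage sketch (paths illustrative): save a PIL image as a result,
# then read it back, which also warms the naive in-memory cache.
#
# from PIL import Image as PILImage
# storage = DiskImageStorage('outputs')
# storage.save(ImageType.RESULT, 'example.png', PILImage.new('RGB', (64, 64)))
# img = storage.get(ImageType.RESULT, 'example.png')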

View File

@ -0,0 +1,46 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from queue import Queue
# TODO: make this serializable
class InvocationQueueItem:
#session_id: str
graph_execution_state_id: str
invocation_id: str
invoke_all: bool
def __init__(self,
#session_id: str,
graph_execution_state_id: str,
invocation_id: str,
invoke_all: bool = False):
#self.session_id = session_id
self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id
self.invoke_all = invoke_all
class InvocationQueueABC(ABC):
"""Abstract base class for all invocation queues"""
@abstractmethod
def get(self) -> InvocationQueueItem:
pass
@abstractmethod
def put(self, item: InvocationQueueItem|None) -> None:
pass
class MemoryInvocationQueue(InvocationQueueABC):
__queue: Queue
def __init__(self):
self.__queue = Queue()
def get(self) -> InvocationQueueItem:
return self.__queue.get()
def put(self, item: InvocationQueueItem|None) -> None:
self.__queue.put(item)
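# Contract sketch: put() enqueues work, get() blocks until an item arrives,
# and None is the wake-up sentinel used while shutting consumers down.
#
# queue = MemoryInvocationQueue()
# queue.put(InvocationQueueItem(graph_execution_state_id='state-1',
#                               invocation_id='node-1'))
# item = queue.get()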

View File

@ -0,0 +1,32 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .image_storage import ImageStorageBase
from .events import EventServiceBase
from invokeai.backend import Generate
class InvocationServices():
"""Services that can be used by invocations"""
generate: Generate # TODO: wrap Generate, or split it up from model?
events: EventServiceBase
images: ImageStorageBase
queue: InvocationQueueABC
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_execution_manager: ItemStorageABC['GraphExecutionState']
processor: 'InvocationProcessorABC'
def __init__(self,
generate: Generate,
events: EventServiceBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC['GraphExecutionState'],
processor: 'InvocationProcessorABC'
):
self.generate = generate
self.events = events
self.images = images
self.queue = queue
self.graph_execution_manager = graph_execution_manager
self.processor = processor

View File

@ -0,0 +1,90 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC
from threading import Event, Thread
from .graph import Graph, GraphExecutionState
from .item_storage import ItemStorageABC
from ..invocations.baseinvocation import InvocationContext
from .invocation_services import InvocationServices
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
class Invoker:
"""The invoker, used to execute invocations"""
services: InvocationServices
def __init__(self,
services: InvocationServices
):
self.services = services
self._start()
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
# Get the next invocation
invocation = graph_execution_state.next()
if not invocation:
return None
# Save the execution state
self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
print(f'queueing item {invocation.id}')
self.services.queue.put(InvocationQueueItem(
#session_id = session.id,
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id,
invoke_all = invoke_all
))
return invocation.id
def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state
def __start_service(self, service) -> None:
# Call start() method on any services that have it
start_op = getattr(service, 'start', None)
if callable(start_op):
start_op(self)
def __stop_service(self, service) -> None:
# Call stop() method on any services that have it
stop_op = getattr(service, 'stop', None)
if callable(stop_op):
stop_op(self)
def _start(self) -> None:
"""Starts the invoker. This is called automatically when the invoker is created."""
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
self.services.queue.put(None)
class InvocationProcessorABC(ABC):
pass
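# A minimal sketch, assuming `services` and `some_graph` were assembled
# elsewhere (see InvocationServices above): create an execution state and
# run the whole graph by chaining invocations.
#
# invoker = Invoker(services)
# state = invoker.create_execution_state(graph=some_graph)
# invoker.invoke(state, invoke_all=True)  # returns the invoked node id, or None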

View File

@ -0,0 +1,57 @@
from typing import Callable, TypeVar, Generic
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from abc import ABC, abstractmethod
T = TypeVar('T', bound=BaseModel)
class PaginatedResults(GenericModel, Generic[T]):
"""Paginated results"""
items: list[T] = Field(description = "Items")
page: int = Field(description = "Current Page")
pages: int = Field(description = "Total number of pages")
per_page: int = Field(description = "Number of items per page")
total: int = Field(description = "Total number of items in result")
class ItemStorageABC(ABC, Generic[T]):
"""Base item storage class"""
_on_changed_callbacks: list[Callable[[T], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod
def get(self, item_id: str) -> T:
pass
@abstractmethod
def set(self, item: T) -> None:
pass
@abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
pass
@abstractmethod
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
pass
def on_changed(self, on_changed: Callable[[T], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: T) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
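# Observer sketch: concrete stores (e.g. the SQLite-backed one below) call
# _on_changed()/_on_deleted() after writes, so callers can subscribe:
#
# def log_change(item):
#     print(f'changed: {item}')
#
# storage.on_changed(log_change)  # `storage` is any ItemStorageABC subclass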

View File

@ -0,0 +1,95 @@
from threading import Event, Thread
import traceback
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
def start(self, invoker) -> None:
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
name = "invoker_processor",
target = self.__process,
kwargs = dict(stop_event = self.__stop_event)
)
self.__invoker_thread.daemon = True # TODO: probably better to just not use threads?
self.__invoker_thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def __process(self, stop_event: Event):
try:
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
if not queue_item: # Probably stopping
continue
graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id)
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id
)
# Invoke
try:
outputs = invocation.invoke(InvocationContext(
services = self.__invoker.services,
graph_execution_state_id = graph_execution_state.id
))
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id,
result = outputs.dict()
)
except KeyboardInterrupt:
pass
except Exception as e:
error = traceback.format_exc()
# Save error
graph_execution_state.set_node_error(invocation.id, error)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id,
error = error
)
pass
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
self.__invoker.invoke(graph_execution_state, invoke_all = True)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
except KeyboardInterrupt:
... # Log something?

View File

@ -0,0 +1,119 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults
T = TypeVar('T', bound=BaseModel)
sqlite_memory = ':memory:'
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
super().__init__()
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
try:
self._lock.acquire()
self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return self._parse_item(result[0])
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page))
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''')
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = (count + per_page - 1) // per_page # ceiling division; int(count/per_page)+1 over-counts when count is an exact multiple
return PaginatedResults[T](
items = items,
page = page,
pages = pageCount,
per_page = per_page,
total = count
)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page))
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',))
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = (count + per_page - 1) // per_page # ceiling division, as in list() above
return PaginatedResults[T](
items = items,
page = page,
pages = pageCount,
per_page = per_page,
total = count
)
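# A quick usage sketch against the in-memory database; the table name is
# arbitrary and GraphExecutionState comes from graph.py above.
#
# from .graph import Graph, GraphExecutionState
# storage = SqliteItemStorage[GraphExecutionState](
#     filename=sqlite_memory, table_name='graph_executions')
# state = GraphExecutionState(graph=Graph())
# storage.set(state)
# restored = storage.get(state.id)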

View File

@ -1,8 +1,8 @@
'''
Initialization file for invokeai.backend
'''
# this is causing circular import issues
# from .invoke_ai_web_server import InvokeAIWebServer
-from .model_manager import ModelManager
+from .model_management import ModelManager
from .generate import Generate

1347
invokeai/backend/args.py Normal file

File diff suppressed because it is too large

View File

View File

@ -0,0 +1,860 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.
#
# Coauthor: Kevin Turner http://github.com/keturn
#
print("Loading Python libraries...\n")
import argparse
import io
import os
import re
import shutil
import sys
import traceback
import warnings
from argparse import Namespace
from pathlib import Path
from urllib import request
from shutil import get_terminal_size
import npyscreen
import torch
import transformers
from diffusers import AutoencoderKL
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import (
AutoProcessor,
CLIPSegForImageSegmentation,
CLIPTextModel,
CLIPTokenizer,
)
import invokeai.configs as configs
from ..args import PRECISION_CHOICES, Args
from ..globals import Globals, global_config_dir, global_config_file, global_cache_dir
from ...frontend.config.model_install import addModelsForm, process_and_execute
from .model_install_backend import (
default_dataset,
download_from_hf,
recommended_datasets,
hf_download_with_resume,
)
from ...frontend.config.widgets import IntTitleSlider, CenteredButtonPress, set_min_terminal_size
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
# --------------------------globals-----------------------
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
Default_config_file = Path(global_config_dir()) / "models.yaml"
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
Datasets = OmegaConf.load(Dataset_path)
# minimum size for the UI
MIN_COLS = 135
MIN_LINES = 45
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
# or renaming it and then running invokeai-configure again.
# Place frequently-used startup commands here, one or more per line.
# Examples:
# --outdir=D:\data\images
# --no-nsfw_checker
# --web --host=0.0.0.0
# --steps=20
# -Ak_euler_a -C10.0
"""
# --------------------------------------------
def postscript(errors: list):
if not any(errors):
message = f"""
** INVOKEAI INSTALLATION SUCCESSFUL **
If you installed manually from source or with 'pip install': activate the virtual environment
then run one of the following commands to start InvokeAI.
Web UI:
invokeai --web # (connect to http://localhost:9090)
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
Command-line interface:
invokeai
If you installed using an installation script, run:
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
Add the '--help' argument to see all of the command-line switches available for use.
"""
else:
message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
for err in errors:
message += f"\t - {err}\n"
message += "Please check the logs above and correct any issues."
print(message)
# ---------------------------------------------
def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n"
response = input(f"{prompt} [{default}] ") or default
if default_yes:
return response[0] not in ("n", "N")
else:
return response[0] in ("y", "Y")
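# e.g. yes_or_no("Download the recommended models?") returns True when the
# user just presses Enter, because the default answer is "y".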
# ---------------------------------------------
def HfLogin(access_token) -> None:
"""
Helper for logging in to Huggingface
The stdout capture is needed to hide the irrelevant "git credential helper" warning
"""
capture = io.StringIO()
sys.stdout = capture
try:
hf_hub_login(token=access_token, add_to_git_credential=False)
sys.stdout = sys.__stdout__
except Exception as exc:
sys.stdout = sys.__stdout__
print(exc)
raise exc
# -------------------------------------
class ProgressBar:
def __init__(self, model_name="file"):
self.pbar = None
self.name = model_name
def __call__(self, block_num, block_size, total_size):
if not self.pbar:
self.pbar = tqdm(
desc=self.name,
initial=0,
unit="iB",
unit_scale=True,
unit_divisor=1000,
total=total_size,
)
self.pbar.update(block_size)
# ---------------------------------------------
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
try:
print(f"Installing {label} model file {model_url}...", end="", file=sys.stderr)
if not os.path.exists(model_dest):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve(
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
)
print("...downloaded successfully", file=sys.stderr)
else:
print("...exists", file=sys.stderr)
except Exception:
print("...download failed", file=sys.stderr)
print(f"Error downloading {label} model", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
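# Usage sketch (URL and destination are illustrative):
#
# download_with_progress_bar(
#     "https://example.com/weights/model.ckpt",
#     os.path.join(Globals.root, "models/example/model.ckpt"),
#     "example",
# )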
# ---------------------------------------------
# this will preload the Bert tokenizer files
def download_bert():
print(
"Installing bert tokenizer...",
file=sys.stderr
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
from transformers import BertTokenizerFast
download_from_hf(BertTokenizerFast, "bert-base-uncased")
# ---------------------------------------------
def download_sd1_clip():
print("Installing SD1 clip model...", file=sys.stderr)
version = "openai/clip-vit-large-patch14"
download_from_hf(CLIPTokenizer, version)
download_from_hf(CLIPTextModel, version)
# ---------------------------------------------
def download_sd2_clip():
version = 'stabilityai/stable-diffusion-2'
print("Installing SD2 clip model...", file=sys.stderr)
download_from_hf(CLIPTokenizer, version, subfolder='tokenizer')
download_from_hf(CLIPTextModel, version, subfolder='text_encoder')
# ---------------------------------------------
def download_realesrgan():
print("Installing models from RealESRGAN...", file=sys.stderr)
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
model_dest = os.path.join(
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
)
wdn_model_dest = os.path.join(
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
)
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
download_with_progress_bar(wdn_model_url, wdn_model_dest, "RealESRGANwdn")
# ---------------------------------------------
def download_gfpgan():
print("Installing GFPGAN models...", file=sys.stderr)
for model in (
[
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
"./models/gfpgan/GFPGANv1.4.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
"./models/gfpgan/weights/detection_Resnet50_Final.pth",
],
[
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
"./models/gfpgan/weights/parsing_parsenet.pth",
],
):
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
# ---------------------------------------------
def download_codeformer():
print("Installing CodeFormer model file...", file=sys.stderr)
model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
download_with_progress_bar(model_url, model_dest, "CodeFormer")
# ---------------------------------------------
def download_clipseg():
print("Installing clipseg model for text-based masking...", file=sys.stderr)
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
try:
download_from_hf(AutoProcessor, CLIPSEG_MODEL)
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL)
except Exception:
print("Error installing clipseg model:")
print(traceback.format_exc())
# -------------------------------------
def download_safety_checker():
print("Installing model for NSFW content detection...", file=sys.stderr)
try:
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor
except ModuleNotFoundError:
print("Error installing NSFW checker model:")
print(traceback.format_exc())
return
safety_model_id = "CompVis/stable-diffusion-safety-checker"
print("AutoFeatureExtractor...", file=sys.stderr)
download_from_hf(AutoFeatureExtractor, safety_model_id)
print("StableDiffusionSafetyChecker...", file=sys.stderr)
download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
# -------------------------------------
def download_vaes():
print("Installing stabilityai VAE...", file=sys.stderr)
try:
# first the diffusers version
repo_id = "stabilityai/sd-vae-ft-mse"
args = dict(
cache_dir=global_cache_dir("diffusers"),
)
if not AutoencoderKL.from_pretrained(repo_id, **args):
raise Exception(f"download of {repo_id} failed")
repo_id = "stabilityai/sd-vae-ft-mse-original"
model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
# next the legacy checkpoint version
if not hf_download_with_resume(
repo_id=repo_id,
model_name=model_name,
model_dir=str(Globals.root / Model_dir / Weights_dir),
):
raise Exception(f"download of {model_name} failed")
except Exception as e:
print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
# -------------------------------------
def get_root(root: str = None) -> str:
if root:
return root
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return Globals.root
# -------------------------------------
class editOptsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
def create(self):
program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts
first_time = not (Globals.root / Globals.initfile).exists()
access_token = HfFolder.get_token()
window_width,window_height = get_terminal_size()
for i in [
"Configure startup settings. You can come back and change these later.",
"Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
"Use cursor arrows to make a checkbox selection, and space to toggle.",
]:
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
editable=False,
color="CONTROL",
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== BASIC OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Select an output directory for images:",
editable=False,
color="CONTROL",
)
self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=old_opts.outdir or str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Activate the NSFW checker to blur images showing potential sexual imagery:",
editable=False,
color="CONTROL",
)
self.safety_checker = self.add_widget_intelligent(
npyscreen.Checkbox,
name="NSFW checker",
value=old_opts.safety_checker,
relx=5,
scroll_exit=True,
)
self.nextrely += 1
for i in [
"If you have an account at HuggingFace you may paste your access token here",
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
"See https://huggingface.co/settings/tokens",
]:
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
editable=False,
color="CONTROL",
)
self.hf_token = self.add_widget_intelligent(
npyscreen.TitlePassword,
name="Access Token (ctrl-shift-V pastes):",
value=access_token,
begin_entry_at=42,
use_two_lines=False,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== ADVANCED OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="GPU Management",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.free_gpu_mem = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Free GPU memory after each generation",
value=old_opts.free_gpu_mem,
relx=5,
scroll_exit=True,
)
self.xformers = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Enable xformers support if available",
value=old_opts.xformers,
relx=5,
scroll_exit=True,
)
self.ckpt_convert = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Load legacy checkpoint models into memory as diffusers models",
value=old_opts.ckpt_convert,
relx=5,
scroll_exit=True,
)
self.always_use_cpu = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force CPU to be used on GPU systems",
value=old_opts.always_use_cpu,
relx=5,
scroll_exit=True,
)
precision = old_opts.precision or (
"float32" if program_opts.full_precision else "auto"
)
self.precision = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Precision",
values=PRECISION_CHOICES,
value=PRECISION_CHOICES.index(precision),
begin_entry_at=3,
max_height=len(PRECISION_CHOICES) + 1,
scroll_exit=True,
)
self.max_loaded_models = self.add_widget_intelligent(
IntTitleSlider,
name="Number of models to cache in CPU memory (each will use 2-4 GB!)",
value=old_opts.max_loaded_models,
out_of=10,
lowest=1,
begin_entry_at=4,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Directory containing embedding/textual inversion files:",
editable=False,
color="CONTROL",
)
self.embedding_path = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=str(default_embedding_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== LICENSE ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
for i in [
"BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ",
"AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT",
"https://huggingface.co/spaces/CompVis/stable-diffusion-license",
]:
self.add_widget_intelligent(
npyscreen.FixedText,
value=i,
editable=False,
color="CONTROL",
)
self.license_acceptance = self.add_widget_intelligent(
npyscreen.Checkbox,
name="I accept the CreativeML Responsible AI License",
value=not first_time,
relx=2,
scroll_exit=True,
)
self.nextrely += 1
label = (
"DONE"
if program_opts.skip_sd_weights or program_opts.default_only
else "NEXT"
)
self.ok_button = self.add_widget_intelligent(
CenteredButtonPress,
name=label,
relx=(window_width - len(label)) // 2,
rely=-3,
when_pressed_function=self.on_ok,
)
def on_ok(self):
options = self.marshall_arguments()
if self.validate_field_values(options):
self.parentApp.new_opts = options
if hasattr(self.parentApp, "model_select"):
self.parentApp.setNextForm("MODELS")
else:
self.parentApp.setNextForm(None)
self.editing = False
else:
self.editing = True
def validate_field_values(self, opt: Namespace) -> bool:
bad_fields = []
if not opt.license_acceptance:
bad_fields.append(
"Please accept the license terms before proceeding to model downloads"
)
if not Path(opt.outdir).parent.exists():
bad_fields.append(
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
)
if not Path(opt.embedding_path).parent.exists():
bad_fields.append(
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
)
if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:\n"
for problem in bad_fields:
message += f"* {problem}\n"
npyscreen.notify_confirm(message)
return False
else:
return True
def marshall_arguments(self):
new_opts = Namespace()
for attr in [
"outdir",
"safety_checker",
"free_gpu_mem",
"max_loaded_models",
"xformers",
"always_use_cpu",
"embedding_path",
"ckpt_convert",
]:
setattr(new_opts, attr, getattr(self, attr).value)
new_opts.hf_token = self.hf_token.value
new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
return new_opts
class EditOptApplication(npyscreen.NPSAppManaged):
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
super().__init__()
self.program_opts = program_opts
self.invokeai_opts = invokeai_opts
self.user_cancelled = False
self.user_selections = default_user_selections(program_opts)
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
self.options = self.addForm(
"MAIN",
editOptsForm,
name="InvokeAI Startup Options",
)
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
self.model_select = self.addForm(
"MODELS",
addModelsForm,
name="Install Stable Diffusion Models",
multipage=True,
)
def new_opts(self):
return self.options.marshall_arguments()
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp.run()
return editApp.new_opts()
def default_startup_options(init_file: Path) -> Namespace:
opts = Args().parse_args([])
outdir = Path(opts.outdir)
if not outdir.is_absolute():
opts.outdir = str(Globals.root / opts.outdir)
if not init_file.exists():
opts.safety_checker = True
return opts
def default_user_selections(program_opts: Namespace) -> Namespace:
return Namespace(
starter_models=default_dataset()
if program_opts.default_only
else recommended_datasets()
if program_opts.yes_to_all
else dict(),
purge_deleted_models=False,
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
convert_to_diffusers=None,
)
# -------------------------------------
def initialize_rootdir(root: str, yes_to_all: bool = False):
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
for name in (
"models",
"configs",
"embeddings",
"text-inversion-output",
"text-inversion-training-data",
):
os.makedirs(os.path.join(root, name), exist_ok=True)
configs_src = Path(configs.__path__[0])
configs_dest = Path(root) / "configs"
if not os.path.samefile(configs_src, configs_dest):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
# -------------------------------------
def run_console_ui(
program_opts: Namespace, initfile: Path = None
) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
set_min_terminal_size(MIN_COLS, MIN_LINES)
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp.run()
if editApp.user_cancelled:
return (None, None)
else:
return (editApp.new_opts, editApp.user_selections)
# -------------------------------------
def write_opts(opts: Namespace, init_file: Path):
"""
Update the invokeai.init file with values from opts Namespace
"""
# touch file if it doesn't exist
if not init_file.exists():
with open(init_file, "w") as f:
f.write(INIT_FILE_PREAMBLE)
# We want to write in the changed arguments without clobbering
# any other initialization values the user has entered. There is
# no good way to do this because of the one-way nature of
# argparse: i.e. --outdir could be --outdir, --out, or -o
# initfile needs to be replaced with a fully structured format
# such as yaml; this is a hack that will work much of the time
args_to_skip = re.compile(
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|free_gpu)"
)
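# e.g. an existing "--outdir=/some/path" line matches the "o|out" branch above and
# is dropped here, then re-written below from the current opts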
# fix windows paths
opts.outdir = opts.outdir.replace('\\','/')
opts.embedding_path = opts.embedding_path.replace('\\','/')
new_file = f"{init_file}.new"
try:
with open(init_file, "r") as infile:
lines = [x.strip() for x in infile.readlines()]
with open(new_file, "w") as out_file:
for line in lines:
if len(line) > 0 and not args_to_skip.match(line):
out_file.write(line + "\n")
out_file.write(
f"""
--outdir={opts.outdir}
--embedding_path={opts.embedding_path}
--precision={opts.precision}
--max_loaded_models={int(opts.max_loaded_models)}
--{'no-' if not opts.safety_checker else ''}nsfw_checker
--{'no-' if not opts.xformers else ''}xformers
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
{'--always_use_cpu' if opts.always_use_cpu else ''}
"""
)
except OSError as e:
print(f"** An error occurred while writing the init file: {str(e)}")
os.replace(new_file, init_file)
if opts.hf_token:
HfLogin(opts.hf_token)
# -------------------------------------
def default_output_dir() -> Path:
return Globals.root / "outputs"
# -------------------------------------
def default_embedding_dir() -> Path:
return Globals.root / "embeddings"
# -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path):
opt = default_startup_options(initfile)
opt.hf_token = HfFolder.get_token()
write_opts(opt, initfile)
# -------------------------------------
def main():
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
parser.add_argument(
"--skip-sd-weights",
dest="skip_sd_weights",
action=argparse.BooleanOptionalAction,
default=False,
help="skip downloading the large Stable Diffusion weight files",
)
parser.add_argument(
"--skip-support-models",
dest="skip_support_models",
action=argparse.BooleanOptionalAction,
default=False,
help="skip downloading the support models",
)
parser.add_argument(
"--full-precision",
dest="full_precision",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="use 32-bit weights instead of faster 16-bit weights",
)
parser.add_argument(
"--yes",
"-y",
dest="yes_to_all",
action="store_true",
help='answer "yes" to all prompts',
)
parser.add_argument(
"--default_only",
action="store_true",
help="when --yes specified, only install the default model",
)
parser.add_argument(
"--config_file",
"-c",
dest="config_file",
type=str,
default=None,
help="path to configuration file to create",
)
parser.add_argument(
"--root_dir",
dest="root",
type=str,
default=None,
help="path to root of install directory",
)
opt = parser.parse_args()
# setting a global here
Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
errors = set()
try:
models_to_download = default_user_selections(opt)
# Check to see whether the runtime directory is correctly initialized.
init_file = Path(Globals.root, Globals.initfile)
if not init_file.exists() or not global_config_file().exists():
initialize_rootdir(Globals.root, opt.yes_to_all)
if opt.yes_to_all:
write_default_options(opt, init_file)
init_options = Namespace(
precision="float32" if opt.full_precision else "float16"
)
else:
init_options, models_to_download = run_console_ui(opt, init_file)
if init_options:
write_opts(init_options, init_file)
else:
print(
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
)
sys.exit(0)
if opt.skip_support_models:
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
else:
print("\n** DOWNLOADING SUPPORT MODELS **")
download_bert()
download_sd1_clip()
download_sd2_clip()
download_realesrgan()
download_gfpgan()
download_codeformer()
download_clipseg()
download_safety_checker()
download_vaes()
if opt.skip_sd_weights:
print("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
elif models_to_download:
print("\n** DOWNLOADING DIFFUSION WEIGHTS **")
process_and_execute(opt, models_to_download)
postscript(errors=errors)
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")
# -------------------------------------
if __name__ == "__main__":
main()

View File

@ -0,0 +1,455 @@
"""
Utility (backend) functions used by model_install.py
"""
import os
import re
import shutil
import sys
import warnings
from pathlib import Path
from tempfile import TemporaryFile
import requests
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_url
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
from typing import List
import invokeai.configs as configs
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..globals import Globals, global_cache_dir, global_config_dir
from ..model_management import ModelManager
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
# initial models omegaconf
Datasets = None
Config_preamble = """
# This file describes the alternative machine learning models
# available to the InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
def default_config_file():
return Path(global_config_dir()) / "models.yaml"
def sd_configs():
return Path(global_config_dir()) / "stable-diffusion"
def initial_models():
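# load INITIAL_MODELS.yaml once and memoize it in the module-level Datasets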
global Datasets
if Datasets:
return Datasets
return (Datasets := OmegaConf.load(Dataset_path))
def install_requested_models(
install_initial_models: List[str] = None,
remove_models: List[str] = None,
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
convert_to_diffusers: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
):
'''
Entry point for installing/deleting starter models, or installing external models.
'''
config_file_path = config_file_path or default_config_file()
if not config_file_path.exists():
config_file_path.touch()
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
if remove_models and len(remove_models) > 0:
print("== DELETING UNCHECKED STARTER MODELS ==")
for model in remove_models:
print(f'{model}...')
model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path)
if install_initial_models and len(install_initial_models) > 0:
print("== INSTALLING SELECTED STARTER MODELS ==")
successfully_downloaded = download_weight_datasets(
models=install_initial_models,
access_token=None,
precision=precision,
) # FIX: for historical reasons, we don't use model manager here
update_config_file(successfully_downloaded, config_file_path)
if len(successfully_downloaded) < len(install_initial_models):
print("** Some of the model downloads were not successful")
# due to above, we have to reload the model manager because conf file
# was changed behind its back
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
external_models = external_models or list()
if scan_directory:
external_models.append(str(scan_directory))
if len(external_models)>0:
print("== INSTALLING EXTERNAL MODELS ==")
for path_url_or_repo in external_models:
try:
model_manager.heuristic_import(
path_url_or_repo,
convert=convert_to_diffusers,
commit_to_conf=config_file_path
)
except KeyboardInterrupt:
sys.exit(-1)
except Exception:
pass
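# Persist the scan directory in the init file: drop any existing
# --autoimport/--autoconvert line and append the new one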
if scan_at_startup and scan_directory is not None and scan_directory.is_dir():
argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f'{Globals.initfile}.new')
directory = str(scan_directory).replace('\\','/')
with open(initfile,'r') as input:
with open(replacement,'w') as output:
while line := input.readline():
if not line.startswith(argument):
output.writelines([line])
output.write(f'{argument} {directory}\n')
os.replace(replacement,initfile)
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n"
response = input(f"{prompt} [{default}] ") or default
if default_yes:
return response[0] not in ("n", "N")
else:
return response[0] in ("y", "Y")
# -------------------------------------
def get_root(root: str = None) -> str:
if root:
return root
elif os.environ.get("INVOKEAI_ROOT"):
return os.environ.get("INVOKEAI_ROOT")
else:
return Globals.root
# ---------------------------------------------
def recommended_datasets() -> dict:
datasets = dict()
for ds in initial_models().keys():
if initial_models()[ds].get("recommended", False):
datasets[ds] = True
return datasets
# ---------------------------------------------
def default_dataset() -> dict:
datasets = dict()
for ds in initial_models().keys():
if initial_models()[ds].get("default", False):
datasets[ds] = True
return datasets
# ---------------------------------------------
def all_datasets() -> dict:
datasets = dict()
for ds in initial_models().keys():
datasets[ds] = True
return datasets
# ---------------------------------------------
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return
new_name = initial_models()["stable-diffusion-1.4"]["file"]
print(f'The Stable Diffusion v1.4 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.')
print(f"model.ckpt => {new_name}")
os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
)
# ---------------------------------------------
def download_weight_datasets(
models: List[str], access_token: str, precision: str = "float32"
):
migrate_models_ckpt()
successful = dict()
for mod in models:
print(f"Downloading {mod}:")
successful[mod] = _download_repo_or_file(
initial_models()[mod], access_token, precision=precision
)
return successful
def _download_repo_or_file(
mconfig: DictConfig, access_token: str, precision: str = "float32"
) -> Path:
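# "ckpt" models are fetched as single weight files; everything else is
# treated as a diffusers repo. A custom VAE repo, when specified, is
# downloaded alongside the main model.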
path = None
if mconfig["format"] == "ckpt":
path = _download_ckpt_weights(mconfig, access_token)
else:
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
_download_diffusion_weights(
mconfig["vae"], access_token, precision=precision
)
return path
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
repo_id = mconfig["repo_id"]
filename = mconfig["file"]
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
return hf_download_with_resume(
repo_id=repo_id,
model_dir=cache_dir,
model_name=filename,
access_token=access_token,
)
# ---------------------------------------------
def download_from_hf(
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
):
path = global_cache_dir(cache_subdir)
model = model_class.from_pretrained(
model_name,
cache_dir=path,
resume_download=True,
**kwargs,
)
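# the HF hub caches a repo under "models--{owner}--{name}"; rebuild that
# folder name so the local path can be returned to the caller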
model_name = "--".join(("models", *model_name.split("/")))
return path / model_name if model else None
def _download_diffusion_weights(
mconfig: DictConfig, access_token: str, precision: str = "float32"
):
repo_id = mconfig["repo_id"]
model_class = (
StableDiffusionGeneratorPipeline
if mconfig.get("format", None) == "diffusers"
else AutoencoderKL
)
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
path = None
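# when float16 is requested, try the "fp16" revision first and fall back
# to the default full-precision weights if the repo has no such branch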
for extra_args in extra_arg_list:
try:
path = download_from_hf(
model_class,
repo_id,
cache_subdir="diffusers",
safety_checker=None,
**extra_args,
)
except OSError as e:
if str(e).startswith("fp16 is not a valid"):
pass
else:
print(f"An unexpected error occurred while downloading the model: {e})")
if path:
break
return path
# ---------------------------------------------
def hf_download_with_resume(
repo_id: str, model_dir: str, model_name: str, access_token: str = None
) -> Path:
model_dest = Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
url = hf_hub_url(repo_id, model_name)
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
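# if a partial download is present, request only the remaining bytes
# (HTTP Range) and append to the existing file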
if os.path.exists(model_dest):
exist_size = os.path.getsize(model_dest)
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0))
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
print(f"* {model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code != 200:
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
elif exist_size > 0:
print(f"* {model_name}: partial file found. Resuming...")
else:
print(f"* {model_name}: Downloading...")
try:
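# a response this small is almost certainly an error page, not weights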
if total < 2000:
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
return None
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest
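# Example (resumable; re-running after an interruption picks up where it left off),
# reusing the VAE checkpoint referenced elsewhere in this installer:
#   hf_download_with_resume(
#       repo_id="stabilityai/sd-vae-ft-mse-original",
#       model_dir="models/ldm/stable-diffusion-v1",
#       model_name="vae-ft-mse-840000-ema-pruned.ckpt",
#   )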
# ---------------------------------------------
def update_config_file(successfully_downloaded: dict, config_file: Path):
config_file = (
Path(config_file) if config_file is not None else default_config_file()
)
# In some cases (incomplete setup, etc), the default configs directory might be missing.
# Create it if it doesn't exist.
# this check is ignored if opt.config_file is specified - user is assumed to know what they
# are doing if they are passing a custom config file from elsewhere.
if config_file is default_config_file() and not config_file.parent.exists():
configs_src = Dataset_path.parent
configs_dest = default_config_file().parent
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
yaml = new_config_file_contents(successfully_downloaded, config_file)
try:
backup = None
if os.path.exists(config_file):
print(
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
)
backup = config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
if sys.platform == "win32" and backup.is_file():
backup.unlink()
config_file.rename(backup)
with TemporaryFile() as tmp:
tmp.write(Config_preamble.encode())
tmp.write(yaml.encode())
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
tmp.seek(0)
new_config.write(tmp.read())
except Exception as e:
print(f"**Error creating config file {config_file}: {str(e)} **")
if backup is not None:
print("restoring previous config file")
## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file():
config_file.unlink()
backup.rename(config_file)
return
print(f"Successfully created new configuration file {config_file}")
# ---------------------------------------------
def new_config_file_contents(
successfully_downloaded: dict, config_file: Path,
) -> str:
if config_file.exists():
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
else:
conf = OmegaConf.create()
default_selected = None
for model in successfully_downloaded:
# a bit hacky - check whether a checkpoint version of this model was
# previously defined, and whether the newly downloaded model is a
# diffusers version (indicated by a directory path)
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
delete_weights(model, conf[model])
stanza = {}
mod = initial_models()[model]
stanza["description"] = mod["description"]
stanza["repo_id"] = mod["repo_id"]
stanza["format"] = mod["format"]
# diffusers don't need width and height (probably .ckpt doesn't either)
# so we no longer require these in INITIAL_MODELS.yaml
if "width" in mod:
stanza["width"] = mod["width"]
if "height" in mod:
stanza["height"] = mod["height"]
if "file" in mod:
stanza["weights"] = os.path.relpath(
successfully_downloaded[model], start=Globals.root
)
stanza["config"] = os.path.normpath(os.path.join(sd_configs(), mod["config"]))
if "vae" in mod:
if "file" in mod["vae"]:
stanza["vae"] = os.path.normpath(
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
)
else:
stanza["vae"] = mod["vae"]
if mod.get("default", False):
stanza["default"] = True
default_selected = True
conf[model] = stanza
# if no default model was chosen, then we select the first
# one in the list
if not default_selected:
conf[list(successfully_downloaded.keys())[0]]["default"] = True
return OmegaConf.to_yaml(conf)
# ---------------------------------------------
def delete_weights(model_name: str, conf_stanza: dict):
if not (weights := conf_stanza.get("weights")):
return
if re.match("/VAE/", conf_stanza.get("config")):
return
print(
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}."
)
weights = Path(weights)
if not weights.is_absolute():
weights = Path(Globals.root) / weights
try:
weights.unlink()
except OSError as e:
print(str(e))

1349
invokeai/backend/generate.py Normal file

File diff suppressed because it is too large

View File

@ -23,7 +23,7 @@ from tqdm import trange
import invokeai.assets.web as web_assets
from ..stable_diffusion.diffusion.ddpm import DiffusionWrapper
from ..util import rand_perlin_2d
from ..util.util import rand_perlin_2d
downsampling = 8
CAUTION_IMG = 'caution.png'

115
invokeai/backend/globals.py Normal file
View File

@ -0,0 +1,115 @@
'''
invokeai.backend.globals defines a small number of global variables that would
otherwise have to be passed through long and complex call chains.
It defines a Namespace object named "Globals" that contains
the attributes:
- root - the root directory under which "models" and "outputs" can be found
- initfile - path to the initialization file
- try_patchmatch - option to globally disable loading of 'patchmatch' module
- always_use_cpu - force use of CPU even if GPU is available
'''
import os
import os.path as osp
from argparse import Namespace
from pathlib import Path
from typing import Union
Globals = Namespace()
# Where to look for the initialization file and other key components
Globals.initfile = 'invokeai.init'
Globals.models_file = 'models.yaml'
Globals.models_dir = 'models'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'
Globals.converted_ckpts_dir = 'converted_ckpts'
# Set the default root directory. This can be overwritten by explicitly
# passing the `--root <directory>` argument on the command line.
# logic is:
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
# 3) use ~/invokeai
if os.environ.get('INVOKEAI_ROOT'):
Globals.root = osp.abspath(os.environ.get('INVOKEAI_ROOT'))
elif os.environ.get('VIRTUAL_ENV') and Path(os.environ.get('VIRTUAL_ENV'),'..',Globals.initfile).exists():
Globals.root = osp.abspath(osp.join(os.environ.get('VIRTUAL_ENV'), '..'))
else:
Globals.root = osp.abspath(osp.expanduser('~/invokeai'))
# Try loading patchmatch
Globals.try_patchmatch = True
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
Globals.always_use_cpu = False
# Whether the internet is reachable for dynamic downloads
# The CLI will test connectivity at startup time.
Globals.internet_available = True
# Whether to disable xformers
Globals.disable_xformers = False
# Low-memory tradeoff for guidance calculations.
Globals.sequential_guidance = False
# whether we are forcing full precision
Globals.full_precision = False
# whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = True
# logging tokenization everywhere
Globals.log_tokenization = False
def global_config_file()->Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)
def global_config_dir()->Path:
return Path(Globals.root, Globals.config_dir)
def global_models_dir()->Path:
return Path(Globals.root, Globals.models_dir)
def global_autoscan_dir()->Path:
return Path(Globals.root, Globals.autoscan_dir)
def global_converted_ckpts_dir()->Path:
return Path(global_models_dir(), Globals.converted_ckpts_dir)
def global_set_root(root_dir:Union[str,Path]):
Globals.root = root_dir
def global_cache_dir(subdir:Union[str,Path]='')->Path:
'''
Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing
for huggingface-style conventions:
global_cache_dir('diffusers')
global_cache_dir('hub')
Current HuggingFace documentation (mid-Jan 2023) indicates that
transformers models will be cached into a "transformers" subdirectory,
but in practice they seem to go into "hub". But if needed:
global_cache_dir('transformers')
One other caveat is that HuggingFace is moving some diffusers models
into the "hub" subdirectory as well, so this will need to be revisited
from time to time.
'''
home: str = os.getenv('HF_HOME')
if home is None:
home = os.getenv('XDG_CACHE_HOME')
if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + 'huggingface'
if home is not None:
return Path(home,subdir)
else:
return Path(Globals.root,'models',subdir)

View File

@ -9,6 +9,7 @@ from .pngwriter import (PngWriter,
retrieve_metadata,
write_metadata,
)
from .seamless import configure_model_padding
def debug_image(
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False

View File

@ -4,7 +4,7 @@ wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
'''
from ldm.invoke.globals import Globals
from invokeai.backend.globals import Globals
import numpy as np
class PatchMatch:

View File

@ -0,0 +1,31 @@
import torch.nn as nn
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
working = nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
return nn.functional.conv2d(working, weight, bias, self.stride, nn.modules.utils._pair(0), self.dilation, self.groups)
def configure_model_padding(model, seamless, seamless_axes):
"""
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
"""
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
if seamless:
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode['x'] = 'circular' if ('x' in seamless_axes) else 'constant'
m.asymmetric_padding['x'] = (m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0)
m.asymmetric_padding_mode['y'] = 'circular' if ('y' in seamless_axes) else 'constant'
m.asymmetric_padding['y'] = (0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3])
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
else:
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
if hasattr(m, 'asymmetric_padding_mode'):
del m.asymmetric_padding_mode
if hasattr(m, 'asymmetric_padding'):
del m.asymmetric_padding
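# A minimal, self-contained sketch of the effect (not part of the library's
# call path; assumes only that torch is installed): with seamless enabled on
# the x axis, the patched convolution wraps horizontally instead of zero-padding.
if __name__ == "__main__":
import torch
toy = nn.Conv2d(1, 1, kernel_size=3, padding=1)
configure_model_padding(toy, seamless=True, seamless_axes=['x'])
out = toy(torch.randn(1, 1, 8, 8))  # left/right columns now share context
print(out.shape)  # torch.Size([1, 1, 8, 8])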

View File

@ -32,7 +32,7 @@ import numpy as np
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from PIL import Image, ImageOps
from torchvision import transforms
from ldm.invoke.globals import global_cache_dir
from invokeai.backend.globals import global_cache_dir
CLIPSEG_MODEL = 'CIDAS/clipseg-rd64-refined'
CLIPSEG_SIZE = 352

View File

@ -0,0 +1,8 @@
'''
Initialization file for invokeai.backend.model_management
'''
from .model_manager import ModelManager
from .convert_ckpt_to_diffusers import (load_pipeline_from_original_stable_diffusion_ckpt,
convert_ckpt_to_diffusers)
from ...frontend.merge.merge_diffusers import (merge_diffusion_models,
merge_diffusion_models_and_commit)

File diff suppressed because it is too large

View File

@ -31,14 +31,13 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from .devices import CPU_DEVICE
from ldm.invoke.globals import Globals, global_cache_dir
from .util import (
from ..util import CPU_DEVICE
from invokeai.backend.globals import Globals, global_cache_dir
from ..util import (
ask_user,
download_with_resume,
url_attachment_name,
)
from .stable_diffusion import StableDiffusionGeneratorPipeline
from ..stable_diffusion import StableDiffusionGeneratorPipeline
class SDLegacyType(Enum):
V1 = 1
@ -416,6 +415,51 @@ class ModelManager(object):
return pipeline, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig):
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get("vae")
width = mconfig.width
height = mconfig.height
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root, weights))
# Convert to diffusers and return a diffusers pipeline
print(
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
)
from . import load_pipeline_from_original_stable_diffusion_ckpt
self.offload_model(self.current_model)
if vae_config := self._choose_diffusers_vae(model_name):
vae = self._load_vae(vae_config)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=weights,
original_config_file=config,
vae=vae,
return_generator_pipeline=True,
precision=torch.float16
if self.precision == "float16"
else torch.float32,
)
if self.sequential_offload:
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)
return (
pipeline,
width,
height,
"NOHASH",
)
def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path:
if isinstance(model_name, DictConfig) or isinstance(model_name, dict):
mconfig = model_name
@ -519,66 +563,6 @@ class ModelManager(object):
self.commit(commit_to_conf)
return model_name
def import_ckpt_model(
self,
weights: Union[str, Path],
config: Union[str, Path] = "configs/stable-diffusion/v1-inference.yaml",
vae: Union[str, Path] = None,
model_name: str = None,
model_description: str = None,
commit_to_conf: Path = None,
) -> str:
"""
Attempts to install the indicated ckpt file and returns True if successful.
"weights" can be either a path-like object corresponding to a local .ckpt file
or a http/https URL pointing to a remote model.
"vae" is a Path or str object pointing to a ckpt or safetensors file to be used
as the VAE for this model.
"config" is the model config file to use with this ckpt file. It defaults to
v1-inference.yaml. If a URL is provided, the config will be downloaded.
You can optionally provide a model name and/or description. If not provided,
then these will be derived from the weight file name. If you provide a commit_to_conf
path to the configuration file, then the new entry will be committed to the
models.yaml file.
Return value is the name of the imported file, or None if an error occurred.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return
if config_path is None or not config_path.exists():
return
model_name = (
model_name or Path(weights).stem
) # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"Imported stable diffusion weights file {model_name}"
)
new_config = dict(
weights=str(weights_path),
config=str(config_path),
description=model_description,
format="ckpt",
width=512,
height=512,
)
if vae:
new_config["vae"] = vae
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
return model_name
@classmethod
def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
"""
@ -746,36 +730,18 @@ class ModelManager(object):
)
return
if convert:
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
)
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
model_name=model_name,
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
)
else:
model_name = self.import_ckpt_model(
model_path,
config=model_config_file,
model_name=model_name,
model_description=description,
vae=str(
Path(
Globals.root,
"models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
)
),
commit_to_conf=commit_to_conf,
)
if commit_to_conf:
self.commit(commit_to_conf)
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
)
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
model_name=model_name,
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
)
return model_name
def convert_and_import(
@ -800,7 +766,7 @@ class ModelManager(object):
new_config = None
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
from . import convert_ckpt_to_diffusers
if diffusers_path.exists():
print(
@ -815,7 +781,7 @@ class ModelManager(object):
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = self._load_vae(vae) if vae else None
convert_ckpt_to_diffuser(
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
extract_ema=True,

View File

@ -13,9 +13,9 @@ from transformers import CLIPTokenizer, CLIPTextModel
from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
from ..devices import torch_dtype
from ..util import torch_dtype
from ..stable_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals
from invokeai.backend.globals import Globals
def get_tokenizer(model) -> CLIPTokenizer:
# TODO remove legacy ckpt fallback handling

View File

@ -0,0 +1,4 @@
'''
Initialization file for the invokeai.backend.restoration package
'''
from .base import Restoration

View File

@ -0,0 +1,38 @@
class Restoration():
def __init__(self) -> None:
pass
def load_face_restore_models(self, gfpgan_model_path='./models/gfpgan/GFPGANv1.4.pth'):
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
print('>> GFPGAN Initialized')
else:
print('>> GFPGAN Disabled')
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
print('>> CodeFormer Initialized')
else:
print('>> CodeFormer Disabled')
codeformer = None
return gfpgan, codeformer
# Face Restore Models
def load_gfpgan(self, gfpgan_model_path):
from .gfpgan import GFPGAN
return GFPGAN(gfpgan_model_path)
def load_codeformer(self):
from .codeformer import CodeFormerRestoration
return CodeFormerRestoration()
# Upscale Models
def load_esrgan(self, esrgan_bg_tile=400):
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
print('>> ESRGAN Initialized')
return esrgan

View File

@ -0,0 +1,108 @@
import os
import torch
import numpy as np
import warnings
import sys
from invokeai.backend.globals import Globals
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
class CodeFormerRestoration():
def __init__(self,
codeformer_dir='models/codeformer',
codeformer_model_path='codeformer.pth') -> None:
if not os.path.isabs(codeformer_dir):
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
print(f'>> CodeFormer - Restoring Faces for image seed:{seed}')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from .codeformer_arch import CodeFormer
from torchvision.transforms.functional import normalize
from PIL import Image
cf_class = CodeFormer
cf = cf_class(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=['32', '64', '128', '256']
).to(device)
# note that this file should already be downloaded and cached at
# this point
checkpoint_path = load_file_from_url(url=pretrained_model_url,
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
progress=True
)
checkpoint = torch.load(checkpoint_path)['params_ema']
cf.load_state_dict(checkpoint)
cf.eval()
image = image.convert('RGB')
# Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
face_helper = FaceRestoreHelper(
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath=os.path.join(Globals.root,'models','gfpgan','weights'),
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except RuntimeError as error:
print(f'\tFailed inference for CodeFormer: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
# Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1])
if strength < 1.0:
# Resize the original image to match the restored image if the sizes differ
if res.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
cf = None
return res

View File

@ -0,0 +1,275 @@
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List
from .vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have color and illumination similar
to those in the degraded features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
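# whiten the content features with their own statistics, then re-color
# them with the style statistics: (x - mu_c) / sigma_c * sigma_s + mu_s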
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2*in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
connect_list=['32', '64', '128', '256'],
fix_modules=['quantize','generator']):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd*2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = {
'16': 512,
'32': 256,
'64': 256,
'128': 128,
'256': 128,
'512': 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat
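
# --- Hedged usage sketch, not part of the original file: run CodeFormer on a
# random stand-in for an aligned 512x512 face crop. The fidelity weight w in
# [0, 1] trades quality (w=0) against fidelity to the input (w=1); all values
# here are illustrative.
if __name__ == "__main__":
    net = CodeFormer().eval()
    with torch.no_grad():
        face = torch.randn(1, 3, 512, 512)
        restored, logits, lq_feat = net(face, w=0.5)
    assert restored.shape == face.shape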

View File

@ -0,0 +1,87 @@
import torch
import warnings
import os
import sys
import numpy as np
from invokeai.backend.globals import Globals
from PIL import Image
class GFPGAN():
def __init__(
self,
gfpgan_model_path='models/gfpgan/GFPGANv1.4.pth'
) -> None:
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path=os.path.abspath(os.path.join(Globals.root,gfpgan_model_path))
        self.model_path = gfpgan_model_path
        self.gfpgan = None  # loaded lazily in process()
        self.gfpgan_model_exists = os.path.isfile(self.model_path)
        if not self.gfpgan_model_exists:
            print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)
def model_exists(self):
return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None):
if seed is not None:
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
cwd = os.getcwd()
os.chdir(os.path.join(Globals.root,'models'))
try:
from gfpgan import GFPGANer
self.gfpgan = GFPGANer(
model_path=self.model_path,
upscale=1,
arch='clean',
channel_multiplier=2,
bg_upsampler=None,
)
except Exception:
import traceback
print('>> Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
        if self.gfpgan is None:
            print('>> WARNING: GFPGAN not initialized.')
            print(
                f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}'
            )
            return image  # nothing to restore without a model
image = image.convert('RGB')
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
_, _, restored_img = self.gfpgan.enhance(
bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
# Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1])
        if strength < 1.0:
            # Resize the original to match the restored image if the sizes differ
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.gfpgan = None
return res
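
# --- Hedged usage sketch, not part of the original file. The input path is
# hypothetical and GFPGANv1.4.pth must already be installed under
# Globals.root/models/gfpgan:
#
#     restorer = GFPGAN()
#     if restorer.gfpgan_model_exists:
#         restored = restorer.process(Image.open('face.png'), strength=0.8)
#         restored.save('face_restored.png')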

View File

@ -0,0 +1,108 @@
import warnings
import math
from PIL import Image, ImageFilter
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process (
self,
extents:dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback = None,
prefix = None
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_sampler()
def wrapped_callback(img,seed,**kwargs):
preferred_seed = orig_opt.seed if orig_opt.seed is not None and orig_opt.seed >= 0 else seed
image_callback(img,preferred_seed,use_prefix=prefix,**kwargs)
result= self.generate.prompt2image(
opt.prompt,
seed = opt.seed or orig_opt.seed,
sampler = self.generate.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.generate.ddim_eta,
width = extended_image.width,
height = extended_image.height,
init_img = extended_image,
strength = 0.90,
image_callback = wrapped_callback if image_callback else None,
seam_size = opt.seam_size or 96,
seam_blur = opt.seam_blur or 16,
seam_strength = opt.seam_strength or 0.7,
seam_steps = 20,
tile_size = 32,
color_match = True,
force_outpaint = True, # this just stops the warning about erased regions
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents:dict,
) -> Image:
'''
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
'''
image = self.image
for direction in extents:
assert direction in ['top', 'left', 'bottom', 'right'],'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels/64) * 64
print(f'>> extending image {direction}ward by {pixels} pixels')
image = self._rotate(image,direction)
image = self._extend(image,pixels)
image = self._rotate(image,direction,reverse=True)
return image
def _rotate(self,image:Image,direction:str,reverse=False) -> Image:
'''
        Rotates the image so that the area to extend is always at the top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
'''
transposes = {
'right': ['ROTATE_90','ROTATE_270'],
'bottom': ['ROTATE_180','ROTATE_180'],
'left': ['ROTATE_270','ROTATE_90']
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self,image:Image,pixels:int)-> Image:
extended_img = Image.new('RGBA',(image.width,image.height+pixels))
extended_img.paste((0,0,0),[0,0,image.width,image.height+pixels])
extended_img.paste(image,box=(0,pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel('A')
alpha.paste(0,(0,0,extended_img.width,pixels))
extended_img.putalpha(alpha)
return extended_img
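
# --- Hedged usage sketch, not part of the original file: extents maps
# directions to pixel counts, which _extend_all() rounds up to multiples of
# 64. `generate`, `opt` and `orig_opt` are assumed to be the usual Generate
# object and option namespaces used elsewhere in this codebase:
#
#     outcrop = Outcrop(image, generate)
#     result = outcrop.process({'top': 128, 'right': 200}, opt, orig_opt)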

View File

@ -0,0 +1,92 @@
import warnings
import math
from PIL import Image, ImageFilter
class Outpaint(object):
def __init__(self, image, generate):
self.image = image
self.generate = generate
def process(self, opt, old_opt, image_callback = None, prefix = None):
image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed
prompt = old_opt.prompt
def wrapped_callback(img,seed,**kwargs):
image_callback(img,seed,use_prefix=prefix,**kwargs)
return self.generate.prompt2image(
prompt,
seed = seed,
sampler = self.generate.sampler,
steps = opt.steps,
cfg_scale = opt.cfg_scale,
ddim_eta = self.generate.ddim_eta,
width = opt.width,
height = opt.height,
init_img = image,
strength = 0.83,
image_callback = wrapped_callback,
prefix = prefix,
)
def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'
if len(direction_args) == 1:
direction = direction_args[0]
pixels = None
elif len(direction_args) == 2:
direction = direction_args[0]
pixels = int(direction_args[1])
assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side
if direction == 'left':
image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == 'bottom':
image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height//2 if pixels is None else int(pixels)
        assert 0 < pixels < image.height, 'Direction (-D) pixel count must be between 0 and the image height'
# the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image
top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
top = top.crop((0, top.height - pixels, top.width, top.height))
# setting all alpha of the top part to 0
alpha = top.getchannel("A")
alpha.paste(0, (0, 0, top.width, top.height))
top.putalpha(alpha)
# taking the bottom from the original image
bottom = image.crop((0, 0, image.width, image.height - pixels))
new_img = image.copy()
new_img.paste(top, (0, 0))
new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle
dither = min(image.height//10, pixels)
for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again
if direction == 'left':
new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == 'bottom':
new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right':
new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img
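
# --- Hedged usage sketch, not part of the original file: out_direction
# mirrors the CLI's -D argument, a direction optionally followed by a pixel
# count (values illustrative):
#
#     outpaint = Outpaint(image, generate)
#     result = outpaint.process(opt, old_opt)  # opt.out_direction = ['left', '128']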

View File

@ -0,0 +1,92 @@
import torch
import warnings
import numpy as np
import os
from invokeai.backend.globals import Globals
from PIL import Image
from PIL.Image import Image as ImageType
class ESRGAN():
    def __init__(self, bg_tile_size=400) -> None:
        self.bg_tile_size = bg_tile_size
def load_esrgan_bg_upsampler(self, denoise_str):
        # fp16 only helps on CUDA; CPU and MPS (M1) fall back to fp32
        use_half_precision = torch.cuda.is_available()
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from realesrgan import RealESRGANer
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-x4v3.pth')
wdn_model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-wdn-x4v3.pth')
scale = 4
bg_upsampler = RealESRGANer(
scale=scale,
model_path=[model_path, wdn_model_path],
model=model,
tile=self.bg_tile_size,
dni_weight=[denoise_str, 1 - denoise_str],
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2, denoise_str: float = 0.75):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
            try:
                upsampler = self.load_esrgan_bg_upsampler(denoise_str)
            except Exception:
                import traceback
                import sys
                print('>> Error loading Real-ESRGAN:', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                return image  # upsampler unavailable; return the input unchanged
if upsampler_scale == 0:
print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
return image
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}'
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")
        # RealESRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
output, _ = upsampler.enhance(
bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler='realesrgan',
)
# Flip the channels back to RGB
res = Image.fromarray(output[...,::-1])
        if strength < 1.0:
            # Resize the original to match the upscaled image if the sizes differ
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res
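
# --- Hedged usage sketch, not part of the original file. The input path is
# hypothetical and the Real-ESRGAN weights must already be installed under
# Globals.root/models/realesrgan:
#
#     esrgan = ESRGAN(bg_tile_size=400)
#     upscaled = esrgan.process(Image.open('photo.png'), strength=1.0, upsampler_scale=2)
#     upscaled.save('photo_2x.png')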

View File

@ -0,0 +1,435 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
return x*torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, {
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance
}
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
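
# --- Hedged shape check, not part of the original file (sizes illustrative):
#
#     vq = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
#     z = torch.randn(1, 256, 16, 16)
#     z_q, loss, stats = vq(z)   # z_q keeps z's shape; gradients pass straight through
#     feat = vq.get_codebook_feat(stats['min_encoding_indices'], shape=[1, 16, 16, 256])
#     assert feat.shape == z_q.shape == z.shape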
class GumbelQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {
"min_encoding_indices": min_encoding_indices
}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x+h_
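
# --- Hedged shape check, not part of the original file: AttnBlock is
# shape-preserving single-head self-attention over the spatial positions:
#
#     attn = AttnBlock(64)
#     assert attn(torch.randn(2, 64, 16, 16)).shape == (2, 64, 16, 16)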
class Encoder(nn.Module):
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,)+tuple(ch_mult)
blocks = []
        # initial convolution
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
blocks = []
# initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest":
self.beta = beta #0.25
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError(f'Unexpected checkpoint format in {model_path}: expected a "params_ema" or "params" key')
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError(f'Unexpected checkpoint format in {model_path}: expected a "params_d" or "params" key')
def forward(self, x):
return self.main(x)
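
# --- Hedged end-to-end sketch, not part of the original file: round-trip a
# random image through the autoencoder and score the reconstruction with the
# patch discriminator. Sizes follow the CodeFormer defaults above.
if __name__ == "__main__":
    vae = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8])
    disc = VQGANDiscriminator()
    img = torch.randn(1, 3, 512, 512)
    recon, codebook_loss, stats = vae(img)
    patch_pred = disc(recon)  # single-channel patch prediction map
    assert recon.shape == img.shape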

View File

@ -10,7 +10,7 @@ import traceback
from typing import Callable
from urllib import request, error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals
from invokeai.backend.globals import Globals
class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):

View File

@ -26,11 +26,11 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from ldm.invoke.globals import Globals
from ..stable_diffusion.diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, AttentionMapSaver
from ..stable_diffusion.textual_inversion_manager import TextualInversionManager
from ..stable_diffusion.offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ..devices import normalize_device, CPU_DEVICE
from invokeai.backend.globals import Globals
from .diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, AttentionMapSaver
from .textual_inversion_manager import TextualInversionManager
from .offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ..util import normalize_device, CPU_DEVICE
from compel import EmbeddingsProvider
@dataclass

View File

@ -15,7 +15,7 @@ from torch import nn
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.cross_attention import AttnProcessor
from ...devices import torch_dtype
from ...util import torch_dtype
class CrossAttentionType(enum.Enum):

View File

@ -23,7 +23,7 @@ from omegaconf import ListConfig
import urllib
from ..textual_inversion_manager import TextualInversionManager
from ...util import (
from ...util.util import (
log_txt_as_img,
exists,
default,

View File

@ -4,7 +4,7 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ...devices import choose_torch_device
from ...util import choose_torch_device
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
from .sampler import Sampler
from ..diffusionmodules.util import noise_like

View File

@ -7,7 +7,7 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ...devices import choose_torch_device
from ...util import choose_torch_device
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..diffusionmodules.util import (

View File

@ -8,7 +8,7 @@ import torch
from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias
from ldm.invoke.globals import Globals
from invokeai.backend.globals import Globals
from .cross_attention_control import Arguments, \
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext

View File

@ -15,7 +15,7 @@ import torch.nn as nn
import numpy as np
from einops import repeat
from ...util import instantiate_from_config
from ...util.util import instantiate_from_config
def make_beta_schedule(

View File

@ -10,7 +10,7 @@ from einops import repeat
from transformers import CLIPTokenizer, CLIPTextModel
from ldm.invoke.devices import choose_torch_device
from ldm.invoke.globals import global_cache_dir
from invokeai.backend.globals import global_cache_dir
from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,

View File

@ -0,0 +1,4 @@
'''
Initialization file for invokeai.backend.training
'''
from .textual_inversion_training import do_textual_inversion_training, parse_args

File diff suppressed because it is too large.

View File

@ -0,0 +1,18 @@
'''
Initialization file for invokeai.backend.util
'''
from .devices import (choose_torch_device,
choose_precision,
normalize_device,
torch_dtype,
CPU_DEVICE,
CUDA_DEVICE,
MPS_DEVICE,
)
from .util import (ask_user,
download_with_resume,
instantiate_from_config,
url_attachment_name,
)
from .log import write_log

View File

@ -5,9 +5,11 @@ from contextlib import nullcontext
import torch
from torch import autocast
from ldm.invoke.globals import Globals
from invokeai.backend.globals import Globals
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
def choose_torch_device() -> torch.device:
'''Convenience routine for guessing which GPU device to run model on'''

View File

@ -0,0 +1,66 @@
"""
Functions for better format logging
write_log -- logs the name of the output image, prompt, and prompt args to the terminal and different types of file
1 write_log_message -- Writes a message to the console
2 write_log_files -- Writes a message to files
2.1 write_log_default -- File in plain text
2.2 write_log_txt -- File in txt format
2.3 write_log_markdown -- File in markdown format
"""
import os
def write_log(results, log_path, file_types, output_cntr):
"""
logs the name of the output image, prompt, and prompt args to the terminal and files
"""
output_cntr = write_log_message(results, output_cntr)
write_log_files(results, log_path, file_types)
return output_cntr
def write_log_message(results, output_cntr):
"""logs to the terminal"""
if len(results) == 0:
return output_cntr
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
if len(log_lines)>1:
subcntr = 1
for l in log_lines:
print(f"[{output_cntr}.{subcntr}] {l}", end="")
subcntr += 1
else:
print(f"[{output_cntr}] {log_lines[0]}", end="")
return output_cntr+1
def write_log_files(results, log_path, file_types):
for file_type in file_types:
if file_type == "txt":
write_log_txt(log_path, results)
elif file_type == "md" or file_type == "markdown":
write_log_markdown(log_path, results)
else:
            print(f"'{file_type}' format is not supported; falling back to plain text")
write_log_default(log_path, results, file_type)
def write_log_default(log_path, results, file_type):
plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
file.writelines(plain_txt_lines)
def write_log_txt(log_path, results):
txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + ".txt", "a", encoding="utf-8") as file:
file.writelines(txt_lines)
def write_log_markdown(log_path, results):
md_lines = []
for path, prompt in results:
file_name = os.path.basename(path)
md_lines.append(f"## {file_name}\n![]({file_name})\n\n{prompt}\n")
with open(log_path + ".md", "a", encoding="utf-8") as file:
file.writelines(md_lines)
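
# --- Hedged usage sketch, not part of the original file; paths and prompts
# are illustrative. Prints both results to the terminal and appends them to
# invoke_log.md and invoke_log.txt in the current directory.
if __name__ == "__main__":
    results = [("outputs/000001.png", "a red fox, oil painting"),
               ("outputs/000002.png", "a red fox, watercolor")]
    next_cntr = write_log(results, "invoke_log", ["md", "txt"], output_cntr=1)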

View File

@ -0,0 +1,4 @@
'''
Initialization file for the web backend.
'''
from .invoke_ai_web_server import InvokeAIWebServer

View File

@ -12,7 +12,7 @@ from threading import Event
from uuid import uuid4
import eventlet
import invokeai.frontend.dist as frontend
import invokeai.frontend.web.dist as frontend
from PIL import Image
from PIL.Image import Image as ImageType
from compel.prompt_parser import Blend
@ -20,24 +20,24 @@ from flask import Flask, redirect, send_from_directory, request, make_response
from flask_socketio import SocketIO
from werkzeug.utils import secure_filename
from invokeai.backend.modules.get_canvas_generation_mode import (
from .modules.get_canvas_generation_mode import (
get_canvas_generation_mode,
)
from .modules.parameters import parameters_to_command
from .prompting import (get_tokens_for_prompt_object,
get_prompt_structure,
get_tokenizer
)
from .image_util import PngWriter, retrieve_metadata
from .generator import infill_methods
from .stable_diffusion import PipelineIntermediateState
from ..prompting import (get_tokens_for_prompt_object,
get_prompt_structure,
get_tokenizer
)
from ..image_util import PngWriter, retrieve_metadata
from ..generator import infill_methods
from ..stable_diffusion import PipelineIntermediateState
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.globals import ( Globals, global_converted_ckpts_dir,
global_models_dir
)
from ldm.invoke.merge_diffusers import merge_diffusion_models
from .. import Generate
from ..args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ..globals import ( Globals, global_converted_ckpts_dir,
global_models_dir
)
from ..model_management import merge_diffusion_models
# Loading Arguments
opt = Args()
@ -236,7 +236,7 @@ class InvokeAIWebServer:
sys.exit(0)
else:
useSSL = args.certfile or args.keyfile
print(">> Started Invoke AI Web Server!")
print(">> Started Invoke AI Web Server")
if self.host == "0.0.0.0":
print(
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."

View File


@ -1,4 +1,4 @@
from invokeai.backend.modules.parse_seed_weights import parse_seed_weights
from .parse_seed_weights import parse_seed_weights
import argparse
SAMPLER_CHOICES = [

View File

Binary image file; size unchanged (2.7 KiB).

View File

Binary image file; size unchanged (292 KiB).

View File

Binary image file; size unchanged (9.5 KiB).

View File

Binary image file; size unchanged (3.4 KiB).

invokeai/frontend/CLI/CLI.py (new file, 1237 lines)

File diff suppressed because it is too large.

View File

@ -0,0 +1,4 @@
'''
Initialization file for invokeai.frontend.CLI
'''
from .CLI import main as invokeai_command_line_interface

View File

@ -0,0 +1,455 @@
"""
Readline helper functions for invoke.py.
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:
from ldm.invoke.readline import completer
completer.add_seed(18247566)
completer.add_seed(9281839)
"""
import os
import re
import atexit
from ...backend.args import Args
from ...backend.globals import Globals
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except (ImportError,ModuleNotFoundError) as e:
print(f'** An error occurred when loading the readline module: {str(e)}')
readline_available = False
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
WEIGHT_EXTENSIONS = ('.ckpt','.vae','.safetensors')
TEXT_EXTENSIONS = ('.txt','.TXT')
CONFIG_EXTENSIONS = ('.yaml','.yml')
COMMANDS = (
'--steps','-s',
'--seed','-S',
'--iterations','-n',
'--width','-W','--height','-H',
'--cfg_scale','-C',
'--threshold',
'--perlin',
'--grid','-g',
'--individual','-i',
'--save_intermediates',
'--init_img','-I',
'--init_mask','-M',
'--init_color',
'--strength','-f',
'--variants','-v',
'--outdir','-o',
'--sampler','-A','-m',
'--embedding_path',
'--device',
'--grid','-g',
'--facetool','-ft',
'--facetool_strength','-G',
'--codeformer_fidelity','-cf',
'--upscale','-U',
'-save_orig','--save_original',
'--log_tokenization','-t',
'--hires_fix',
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'--h_symmetry_time_pct',
'--v_symmetry_time_pct',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
'!mask','!triggers',
)
MODEL_COMMANDS = (
'!switch',
'!edit_model',
'!del_model',
)
CKPT_MODEL_COMMANDS = (
'!optimize_model',
)
WEIGHT_COMMANDS = (
'!import_model',
'!convert_model',
)
IMG_PATH_COMMANDS = (
    r'--outdir[=\s]',
)
TEXT_PATH_COMMANDS = (
    '!replay',
)
IMG_FILE_COMMANDS = (
    '!fix',
    '!fetch',
    '!mask',
    r'--init_img[=\s]', '-I',
    r'--init_mask[=\s]', '-M',
    r'--init_color[=\s]',
    r'--embedding_path[=\s]',
)
# raw strings avoid DeprecationWarning: invalid escape sequence '\s'
path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + r')\s*\S*$'
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + r')\s*\S*$'
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + r')\s*\S*$'
class Completer(object):
def __init__(self, options, models={}):
self.options = sorted(options)
self.models = models
self.seeds = set()
self.matches = list()
self.default_dir = None
self.linebuffer = None
self.auto_history_active = True
self.extensions = None
self.concepts = None
self.embedding_terms = set()
return
def complete(self, text, state):
'''
Completes invoke command line.
BUG: it doesn't correctly complete files that have spaces in the name.
'''
buffer = readline.get_line_buffer()
if state == 0:
# extensions defined, so go directly into path completion mode
if self.extensions is not None:
self.matches = self._path_completions(text, state, self.extensions)
# looking for an image file
elif re.search(path_regexp,buffer):
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
            # looking for a seed
            elif re.search(r'(-S\s*|--seed[=\s])\d*$', buffer):
                self.matches = self._seed_completions(text, state)
            # looking for an embedding concept
            elif re.search(r'<[\w-]*$', buffer):
                self.matches = self._concept_completions(text, state)
# looking for a model
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state)
# looking for a ckpt model
elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state, ckpt_only=True)
elif re.search(weight_regexp,buffer):
self.matches = self._path_completions(
text,
state,
WEIGHT_EXTENSIONS,
default_dir=Globals.root,
)
elif re.search(text_regexp,buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
# This is the first time for this text, so build a match list.
elif text:
self.matches = [
s for s in self.options if s and s.startswith(text)
]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def complete_extensions(self, extensions:list):
'''
If called with a list of extensions, will force completer
to do file path completions.
'''
self.extensions=extensions
def add_history(self,line):
'''
Pass thru to readline
'''
if not self.auto_history_active:
readline.add_history(line)
def clear_history(self):
'''
Pass clear_history() thru to readline
'''
readline.clear_history()
def search_history(self,match:str):
'''
Like show_history() but only shows items that
contain the match string.
'''
self.show_history(match)
def remove_history_item(self,pos):
readline.remove_history_item(pos)
def add_seed(self, seed):
'''
Add a seed to the autocomplete list for display when -S is autocompleted.
'''
if seed is not None:
self.seeds.add(str(seed))
def set_default_dir(self, path):
self.default_dir=path
def set_options(self,options):
self.options = options
def get_line(self,index):
try:
line = self.get_history_item(index)
except IndexError:
return None
return line
def get_current_history_length(self):
return readline.get_current_history_length()
def get_history_item(self,index):
return readline.get_history_item(index)
def show_history(self,match=None):
'''
Print the session history using the pydoc pager
'''
import pydoc
lines = list()
h_len = self.get_current_history_length()
if h_len < 1:
print('<empty history>')
return
for i in range(0,h_len):
line = self.get_history_item(i+1)
if match and match not in line:
continue
lines.append(f'[{i+1}] {line}')
pydoc.pager('\n'.join(lines))
def set_line(self,line)->None:
'''
Set the default string displayed in the next line of input.
'''
self.linebuffer = line
readline.redisplay()
def update_models(self,models:dict)->None:
'''
update our list of models
'''
self.models = models
def _seed_completions(self, text, state):
        m = re.search(r'(-S\s?|--seed[=\s]?)(\d*)', text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ''
partial = text
matches = list()
for s in self.seeds:
if s.startswith(partial):
matches.append(switch+s)
matches.sort()
return matches
def add_embedding_terms(self, terms:list[str]):
self.embedding_terms = set(terms)
if self.concepts:
self.embedding_terms.update(set(self.concepts.list_concepts()))
def _concept_completions(self, text, state):
if self.concepts is None:
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
self.concepts = HuggingFaceConceptsLibrary()
self.embedding_terms.update(set(self.concepts.list_concepts()))
else:
self.embedding_terms.update(set(self.concepts.list_concepts()))
partial = text[1:] # this removes the leading '<'
if len(partial) == 0:
return list(self.embedding_terms) # whole dump - think if user wants this!
matches = list()
for concept in self.embedding_terms:
if concept.startswith(partial):
matches.append(f'<{concept}>')
matches.sort()
return matches
def _model_completions(self, text, state, ckpt_only=False):
        m = re.search(r'(!switch\s+)(\w*)', text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else:
switch = ''
partial = text
matches = list()
for s in self.models:
format = self.models[s]['format']
if format == 'vae':
continue
if ckpt_only and format != 'ckpt':
continue
if s.startswith(partial):
matches.append(switch+s)
matches.sort()
return matches
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def _path_completions(self, text, state, extensions, shortcut_ok=True, default_dir:str=''):
# separate the switch from the partial path
        match = re.search(r'^(-\w|--\w+=?)(.*)', text)
if match is None:
switch = None
partial_path = text
else:
switch,partial_path = match.groups()
partial_path = partial_path.lstrip()
matches = list()
path = os.path.expanduser(partial_path)
if os.path.isdir(path):
dir = path
elif os.path.dirname(path) != '':
dir = os.path.dirname(path)
else:
dir = default_dir if os.path.exists(default_dir) else ''
path= os.path.join(dir,path)
dir_list = os.listdir(dir or '.')
        if shortcut_ok and self.default_dir and os.path.exists(self.default_dir) and dir == '':
dir_list += os.listdir(self.default_dir)
for node in dir_list:
if node.startswith('.') and len(node) > 1:
continue
full_path = os.path.join(dir, node)
if not (node.endswith(extensions) or os.path.isdir(full_path)):
continue
if path and not full_path.startswith(path):
continue
if switch is None:
match_path = os.path.join(dir,node)
matches.append(match_path+'/' if os.path.isdir(full_path) else match_path)
elif os.path.isdir(full_path):
matches.append(
switch+os.path.join(os.path.dirname(full_path), node) + '/'
)
elif node.endswith(extensions):
matches.append(
switch+os.path.join(os.path.dirname(full_path), node)
)
return matches
class DummyCompleter(Completer):
def __init__(self,options):
super().__init__(options)
self.history = list()
def add_history(self,line):
self.history.append(line)
def clear_history(self):
self.history = list()
def get_current_history_length(self):
return len(self.history)
def get_history_item(self,index):
return self.history[index-1]
def remove_history_item(self,index):
return self.history.pop(index-1)
def set_line(self,line):
print(f'# {line}')
def generic_completer(commands:list)->Completer:
if readline_available:
completer = Completer(commands,[])
readline.set_completer(completer.complete)
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
readline.parse_and_bind('set print-completions-horizontally off')
readline.parse_and_bind('set page-completions on')
readline.parse_and_bind('set skip-completed-text on')
readline.parse_and_bind('set show-all-if-ambiguous on')
else:
completer = DummyCompleter(commands)
return completer
def get_completer(opt:Args, models=[])->Completer:
if readline_available:
completer = Completer(COMMANDS,models)
readline.set_completer(
completer.complete
)
# pyreadline3 does not have a set_auto_history() method
        try:
            readline.set_auto_history(False)
            completer.auto_history_active = False
        except AttributeError:
            completer.auto_history_active = True
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
readline.parse_and_bind('set print-completions-horizontally off')
readline.parse_and_bind('set page-completions on')
readline.parse_and_bind('set skip-completed-text on')
readline.parse_and_bind('set show-all-if-ambiguous on')
outdir = os.path.expanduser(opt.outdir)
if os.path.isabs(outdir):
histfile = os.path.join(outdir,'.invoke_history')
else:
histfile = os.path.join(Globals.root, outdir, '.invoke_history')
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
except OSError: # file likely corrupted
newname = f'{histfile}.old'
print(f'## Your history file {histfile} couldn\'t be loaded and may be corrupted. Renaming it to {newname}')
os.replace(histfile,newname)
atexit.register(readline.write_history_file, histfile)
else:
completer = DummyCompleter(COMMANDS)
return completer
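
# --- Hedged usage sketch, not part of the original file; the command list is
# illustrative. Builds a standalone completer without an Args object:
if __name__ == "__main__":
    completer = generic_completer(['--steps', '--seed', '!history', '!models'])
    completer.add_seed(18247566)    # seeds now tab-complete after -S / --seed
    completer.set_default_dir('.')  # path completions fall back to the cwd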

View File

@ -0,0 +1,3 @@
'''
Initialization file for invokeai.frontend
'''

View File

@ -0,0 +1,7 @@
'''
Initialization file for invokeai.frontend.config
'''
from .model_install import main as invokeai_model_install
from .invokeai_configure import main as invokeai_configure
from .invokeai_update import main as invokeai_update

View File

@ -0,0 +1,4 @@
'''
Wrapper for invokeai.backend.configure.invokeai_configure
'''
from ...backend.config.invokeai_configure import main

View File

@ -0,0 +1,88 @@
'''
Minimalist updater script. Prompts user for the tag or branch to update to and runs
pip install <path_to_git_source>.
'''
import os
import platform
import requests
from rich import box, print
from rich.console import Console, Group, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.style import Style
from rich.syntax import Syntax
from rich.text import Text
from invokeai.version import __version__
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
OS = platform.uname().system
ARCH = platform.uname().machine
if OS == "Windows":
# Windows terminals look better without a background colour
console = Console(style=Style(color="grey74"))
else:
console = Console(style=Style(color="grey74", bgcolor="grey19"))
def get_versions()->dict:
return requests.get(url=INVOKE_AI_REL).json()
def welcome(versions: dict):
@group()
def text():
yield f'InvokeAI Version: [bold yellow]{__version__}'
yield ''
yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
yield ''
yield '[bold yellow]Options:'
        yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
[2] Update to the bleeding-edge development version ([italic]main[/italic])
[3] Manually enter the tag or branch name you wish to update to'''
console.rule()
print(
Panel(
title="[bold wheat1]InvokeAI Updater",
renderable=text(),
box=box.DOUBLE,
expand=True,
padding=(1, 2),
style=Style(bgcolor="grey23", color="orange1"),
subtitle=f"[bold grey39]{OS}-{ARCH}",
)
)
console.line()
def main():
versions = get_versions()
welcome(versions)
tag = None
choice = Prompt.ask('Choice:',choices=['1','2','3'],default='1')
if choice=='1':
tag = versions[0]['tag_name']
elif choice=='2':
tag = 'main'
elif choice=='3':
tag = Prompt.ask('Enter an InvokeAI tag or branch name')
print(f':crossed_fingers: Upgrading to [yellow]{tag}[/yellow]')
cmd = f'pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517'
print('')
print('')
    if os.system(cmd) == 0:
        print(':heavy_check_mark: Upgrade successful')
    else:
        print(':exclamation: [bold red]Upgrade failed[/bold red]')
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
pass
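
# --- Hedged example, not part of the original file: the command main()
# assembles for a hypothetical tag, built from INVOKE_AI_SRC as above:
#
#     tag = 'v2.3.1'
#     cmd = f'pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517'
#     # -> pip install https://github.com/invoke-ai/InvokeAI/archive/v2.3.1.zip --use-pep517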

View File

@ -0,0 +1,504 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.
"""
This is the npyscreen frontend to the model installation application.
The work is actually done in backend code in model_install_backend.py.
"""
import argparse
import os
import sys
from argparse import Namespace
from pathlib import Path
from typing import List
import npyscreen
import torch
from npyscreen import widget
from omegaconf import OmegaConf
from shutil import get_terminal_size
from ...backend.util import choose_precision, choose_torch_device
from invokeai.backend.globals import Globals, global_config_dir
from ...backend.config.model_install_backend import (Dataset_path, default_config_file,
default_dataset, get_root,
install_requested_models,
recommended_datasets,
)
from .widgets import (MultiSelectColumns, TextBox,
OffsetButtonPress, CenteredTitleText,
set_min_terminal_size,
)
# minimum size for the UI
MIN_COLS = 120
MIN_LINES = 45
class addModelsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled
#FIX_MINIMUM_SIZE_WHEN_CREATED = False
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
self.multipage = multipage
self.initial_models = OmegaConf.load(Dataset_path)
        try:
            self.existing_models = OmegaConf.load(default_config_file())
        except Exception:
            self.existing_models = dict()
self.starter_model_list = [
x for x in list(self.initial_models.keys()) if x not in self.existing_models
]
self.installed_models = dict()
super().__init__(parentApp=parentApp, name=name, *args, **keywords)
def create(self):
window_width, window_height = get_terminal_size()
starter_model_labels = self._get_starter_model_labels()
recommended_models = [
x
for x in self.starter_model_list
if self.initial_models[x].get("recommended", False)
]
self.installed_models = sorted(
[x for x in list(self.initial_models.keys()) if x in self.existing_models]
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields,",
editable=False,
color='CAUTION',
)
self.add_widget_intelligent(
npyscreen.FixedText,
value="Use cursor arrows to make a selection, and space to toggle checkboxes.",
editable=False,
color='CAUTION'
)
self.nextrely += 1
if len(self.installed_models) > 0:
self.add_widget_intelligent(
CenteredTitleText,
name="== INSTALLED STARTER MODELS ==",
editable=False,
color="CONTROL",
)
self.nextrely -= 1
self.add_widget_intelligent(
CenteredTitleText,
name="Currently installed starter models. Uncheck to delete:",
editable=False,
labelColor="CAUTION",
)
self.nextrely -= 1
columns = self._get_columns()
self.previously_installed_models = self.add_widget_intelligent(
MultiSelectColumns,
columns=columns,
values=self.installed_models,
value=[x for x in range(0, len(self.installed_models))],
max_height=1 + len(self.installed_models) // columns,
relx=4,
slow_scroll=True,
scroll_exit=True,
)
self.purge_deleted = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Purge deleted models from disk",
value=False,
scroll_exit=True,
relx=4,
)
self.nextrely += 1
if len(self.starter_model_list) > 0:
self.add_widget_intelligent(
CenteredTitleText,
name="== STARTER MODELS (recommended ones selected) ==",
editable=False,
color="CONTROL",
)
self.nextrely -= 1
self.add_widget_intelligent(
CenteredTitleText,
name="Select from a starter set of Stable Diffusion models from HuggingFace.",
editable=False,
labelColor="CAUTION",
)
self.nextrely -= 1
# if user has already installed some initial models, then don't patronize them
# by showing more recommendations
show_recommended = not self.existing_models
self.models_selected = self.add_widget_intelligent(
npyscreen.MultiSelect,
name="Install Starter Models",
values=starter_model_labels,
value=[
self.starter_model_list.index(x)
for x in self.starter_model_list
if show_recommended and x in recommended_models
],
max_height=len(starter_model_labels) + 1,
relx=4,
scroll_exit=True,
)
self.add_widget_intelligent(
CenteredTitleText,
name='== IMPORT LOCAL AND REMOTE MODELS ==',
editable=False,
color="CONTROL",
)
self.nextrely -= 1
for line in [
"In the box below, enter URLs, file paths, or HuggingFace repository IDs.",
"Separate model names by lines or whitespace (Use shift-control-V to paste):",
]:
self.add_widget_intelligent(
CenteredTitleText,
name=line,
editable=False,
labelColor="CONTROL",
relx = 4,
)
self.nextrely -= 1
self.import_model_paths = self.add_widget_intelligent(
TextBox,
max_height=7,
scroll_exit=True,
editable=True,
relx=4
)
self.nextrely += 1
self.show_directory_fields = self.add_widget_intelligent(
npyscreen.FormControlCheckbox,
name="Select a directory for models to import",
value=False,
)
self.autoload_directory = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="Directory (<tab> autocompletes):",
select_dir=True,
must_exist=True,
use_two_lines=False,
labelColor="DANGER",
begin_entry_at=34,
scroll_exit=True,
)
self.autoscan_on_startup = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Scan this directory each time InvokeAI starts for new models to import",
value=False,
relx=4,
scroll_exit=True,
)
self.nextrely += 1
self.convert_models = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
values=["Keep original format", "Convert to diffusers"],
value=0,
begin_entry_at=4,
max_height=4,
hidden=True, # will appear when imported models box is edited
scroll_exit=True,
)
self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress,
name="CANCEL",
rely=-3,
when_pressed_function=self.on_cancel,
)
done_label = "DONE"
back_label = "BACK"
button_length = len(done_label)
button_offset = 0
if self.multipage:
button_length += len(back_label) + 1
button_offset += len(back_label) + 1
self.back_button = self.add_widget_intelligent(
OffsetButtonPress,
name=back_label,
relx=(window_width - button_length) // 2,
offset=-3,
rely=-3,
when_pressed_function=self.on_back,
)
self.ok_button = self.add_widget_intelligent(
OffsetButtonPress,
name=done_label,
offset=+3,
relx=button_offset + 1 + (window_width - button_length) // 2,
rely=-3,
when_pressed_function=self.on_ok,
)
for i in [self.autoload_directory, self.autoscan_on_startup]:
self.show_directory_fields.addVisibleWhenSelected(i)
self.show_directory_fields.when_value_edited = self._clear_scan_directory
self.import_model_paths.when_value_edited = self._show_hide_convert
self.autoload_directory.when_value_edited = self._show_hide_convert
def resize(self):
super().resize()
if hasattr(self,'models_selected'):
self.models_selected.values = self._get_starter_model_labels()
def _clear_scan_directory(self):
if not self.show_directory_fields.value:
self.autoload_directory.value = ""
def _show_hide_convert(self):
model_paths = self.import_model_paths.value or ""
autoload_directory = self.autoload_directory.value or ""
self.convert_models.hidden = (
len(model_paths) == 0 and len(autoload_directory) == 0
)
def _get_starter_model_labels(self) -> List[str]:
window_width, window_height = get_terminal_size()
label_width = 25
checkbox_width = 4
spacing_width = 2
description_width = window_width - label_width - checkbox_width - spacing_width
im = self.initial_models
names = self.starter_model_list
descriptions = [
im[x].description[0 : description_width - 3] + "..."
if len(im[x].description) > description_width
else im[x].description
for x in names
]
return [
f"%-{label_width}s %s" % (names[x], descriptions[x])
for x in range(0, len(names))
]
def _get_columns(self) -> int:
window_width, window_height = get_terminal_size()
cols = (
4
if window_width > 240
else 3
if window_width > 160
else 2
if window_width > 80
else 1
)
return min(cols, len(self.installed_models))
def on_ok(self):
self.parentApp.setNextForm(None)
self.editing = False
self.parentApp.user_cancelled = False
self.marshall_arguments()
def on_back(self):
self.parentApp.switchFormPrevious()
self.editing = False
def on_cancel(self):
if npyscreen.notify_yes_no(
"Are you sure you want to cancel?\nYou may re-run this script later using the invoke.sh or invoke.bat command.\n"
):
self.parentApp.setNextForm(None)
self.parentApp.user_cancelled = True
self.editing = False
def marshall_arguments(self):
"""
Assemble arguments and store as attributes of the application:
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
True => Install
False => Remove
.scan_directory: Path to a directory of models to scan and import
.autoscan_on_startup: True if invokeai should scan and import at startup time
.import_model_paths: list of URLs, repo_ids and file paths to import
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
"""
# we're using a global here rather than storing the result in the parentapp
# due to some bug in npyscreen that is causing attributes to be lost
selections = self.parentApp.user_selections
# starter models to install/remove
if hasattr(self,'models_selected'):
starter_models = dict(
map(
lambda x: (self.starter_model_list[x], True), self.models_selected.value
)
)
else:
starter_models = dict()
selections.purge_deleted_models = False
if hasattr(self, "previously_installed_models"):
unchecked = [
self.previously_installed_models.values[x]
for x in range(0, len(self.previously_installed_models.values))
if x not in self.previously_installed_models.value
]
starter_models.update(map(lambda x: (x, False), unchecked))
selections.purge_deleted_models = self.purge_deleted.value
selections.starter_models = starter_models
# load directory and whether to scan on startup
if self.show_directory_fields.value:
selections.scan_directory = self.autoload_directory.value
selections.autoscan_on_startup = self.autoscan_on_startup.value
else:
selections.scan_directory = None
selections.autoscan_on_startup = False
# URLs and the like
selections.import_model_paths = self.import_model_paths.value.split()
selections.convert_to_diffusers = self.convert_models.value[0] == 1
class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self):
super().__init__()
self.user_cancelled = False
self.user_selections = Namespace(
starter_models=None,
purge_deleted_models=False,
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
convert_to_diffusers=None,
)
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
self.main_form = self.addForm(
"MAIN", addModelsForm, name="Install Stable Diffusion Models"
)
# --------------------------------------------------------
def process_and_execute(opt: Namespace, selections: Namespace):
models_to_remove = [
x for x in selections.starter_models if not selections.starter_models[x]
]
models_to_install = [
x for x in selections.starter_models if selections.starter_models[x]
]
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
convert_to_diffusers = selections.convert_to_diffusers
install_requested_models(
install_initial_models=models_to_install,
remove_models=models_to_remove,
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,
convert_to_diffusers=convert_to_diffusers,
precision="float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device())),
purge_deleted=selections.purge_deleted_models,
config_file_path=Path(opt.config_file) if opt.config_file else None,
)
# --------------------------------------------------------
def select_and_download_models(opt: Namespace):
precision = (
"float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device()))
)
if opt.default_only:
install_requested_models(
install_initial_models=default_dataset(),
precision=precision,
)
elif opt.yes_to_all:
install_requested_models(
install_initial_models=recommended_datasets(),
precision=precision,
)
else:
set_min_terminal_size(MIN_COLS, MIN_LINES)
installApp = AddModelApplication()
installApp.run()
if not installApp.user_cancelled:
process_and_execute(opt, installApp.user_selections)
# -------------------------------------
def main():
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
parser.add_argument(
"--full-precision",
dest="full_precision",
action=argparse.BooleanOptionalAction,
default=False,
help="use 32-bit weights instead of faster 16-bit weights",
)
parser.add_argument(
"--yes",
"-y",
dest="yes_to_all",
action="store_true",
help='answer "yes" to all prompts',
)
parser.add_argument(
"--default_only",
action="store_true",
help="only install the default model",
)
parser.add_argument(
"--config_file",
"-c",
dest="config_file",
type=str,
default=None,
help="path to configuration file to create",
)
parser.add_argument(
"--root_dir",
dest="root",
type=str,
default=None,
help="path to root of install directory",
)
opt = parser.parse_args()
# setting a global here
Globals.root = os.path.expanduser(get_root(opt.root) or "")
if not global_config_dir().exists():
print(
">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
)
import ldm.invoke.config.invokeai_configure
ldm.invoke.config.invokeai_configure.main()
sys.exit(0)
try:
select_and_download_models(opt)
except AssertionError as e:
print(str(e))
sys.exit(-1)
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
            print(
                "** Insufficient vertical space for the interface. Please make your window taller and try again."
            )
elif str(e).startswith("addwstr"):
print(
"** Insufficient horizontal space for the interface. Please make your window wider and try again."
)
# -------------------------------------
if __name__ == "__main__":
main()

View File

@ -0,0 +1,164 @@
'''
Widget class definitions used by model_select.py, merge_diffusers.py and textual_inversion.py
'''
import math
import platform
import npyscreen
import os
import sys
import curses
import struct
from shutil import get_terminal_size
# -------------------------------------
def set_terminal_size(columns: int, lines: int):
OS = platform.uname().system
if OS=="Windows":
os.system(f'mode con: cols={columns} lines={lines}')
elif OS in ['Darwin', 'Linux']:
import termios
import fcntl
winsize = struct.pack("HHHH", lines, columns, 0, 0)
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
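        # also emit the xterm window-resize control sequence (CSI 8 ; rows ; cols t)
        # for terminal emulators that honor it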
sys.stdout.write("\x1b[8;{rows};{cols}t".format(rows=lines, cols=columns))
sys.stdout.flush()
def set_min_terminal_size(min_cols: int, min_lines: int):
# make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size()
cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines)
    set_terminal_size(cols, lines)
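# Illustrative example: set_min_terminal_size(120, 40) grows an 80x24 window
# to 120x40, but leaves a terminal that is already 200x50 untouched.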
class IntSlider(npyscreen.Slider):
    def translate_value(self):
        stri = "%2d / %2d" % (self.value, self.out_of)
        field_width = len(str(self.out_of)) * 2 + 4
        return stri.rjust(field_width)
# -------------------------------------
class CenteredTitleText(npyscreen.TitleText):
def __init__(self,*args,**keywords):
super().__init__(*args,**keywords)
self.resize()
def resize(self):
super().resize()
maxy, maxx = self.parent.curses_pad.getmaxyx()
label = self.name
self.relx = (maxx - len(label)) // 2
# -------------------------------------
class CenteredButtonPress(npyscreen.ButtonPress):
def resize(self):
super().resize()
maxy, maxx = self.parent.curses_pad.getmaxyx()
label = self.name
self.relx = (maxx - len(label)) // 2
# -------------------------------------
class OffsetButtonPress(npyscreen.ButtonPress):
def __init__(self, screen, offset=0, *args, **keywords):
super().__init__(screen, *args, **keywords)
self.offset = offset
def resize(self):
maxy, maxx = self.parent.curses_pad.getmaxyx()
width = len(self.name)
self.relx = self.offset + (maxx - width) // 2
class IntTitleSlider(npyscreen.TitleText):
_entry_type = IntSlider
class FloatSlider(npyscreen.Slider):
# this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        stri = "%3.2f / %3.2f" % (self.value, self.out_of)
        field_width = len(str(self.out_of)) * 2 + 4
        return stri.rjust(field_width)
class FloatTitleSlider(npyscreen.TitleText):
_entry_type = FloatSlider
class MultiSelectColumns(npyscreen.MultiSelect):
    def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
        values = values or []  # avoid sharing a mutable default argument
        self.columns = columns
        self.value_cnt = len(values)
        self.rows = math.ceil(self.value_cnt / self.columns)
        super().__init__(screen, values=values, **keywords)
def make_contained_widgets(self):
self._my_widgets = []
column_width = self.width // self.columns
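        # lay the checkboxes out column-major: item h occupies row h % rows,
        # column h // rows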
for h in range(self.value_cnt):
self._my_widgets.append(
self._contained_widgets(self.parent,
rely=self.rely + (h % self.rows) * self._contained_widget_height,
relx=self.relx + (h // self.rows) * column_width,
max_width=column_width,
max_height=self.__class__._contained_widget_height,
)
)
def set_up_handlers(self):
super().set_up_handlers()
self.handlers.update({
curses.KEY_UP: self.h_cursor_line_left,
curses.KEY_DOWN: self.h_cursor_line_right,
}
)
def h_cursor_line_down(self, ch):
self.cursor_line += self.rows
if self.cursor_line >= len(self.values):
if self.scroll_exit:
self.cursor_line = len(self.values)-self.rows
self.h_exit_down(ch)
return True
else:
self.cursor_line -= self.rows
return True
def h_cursor_line_up(self, ch):
self.cursor_line -= self.rows
if self.cursor_line < 0:
if self.scroll_exit:
self.cursor_line = 0
self.h_exit_up(ch)
else:
self.cursor_line = 0
def h_cursor_line_left(self,ch):
super().h_cursor_line_up(ch)
def h_cursor_line_right(self,ch):
super().h_cursor_line_down(ch)
class TextBox(npyscreen.MultiLineEdit):
def update(self, clear=True):
if clear: self.clear()
HEIGHT = self.height
WIDTH = self.width
# draw box.
self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
self.parent.curses_pad.hline(self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH)
self.parent.curses_pad.vline(self.rely, self.relx, curses.ACS_VLINE, self.height)
self.parent.curses_pad.vline(self.rely, self.relx+WIDTH, curses.ACS_VLINE, HEIGHT)
# draw corners
self.parent.curses_pad.addch(self.rely, self.relx, curses.ACS_ULCORNER, )
self.parent.curses_pad.addch(self.rely, self.relx+WIDTH, curses.ACS_URCORNER, )
self.parent.curses_pad.addch(self.rely+HEIGHT, self.relx, curses.ACS_LLCORNER, )
self.parent.curses_pad.addch(self.rely+HEIGHT, self.relx+WIDTH, curses.ACS_LRCORNER, )
# fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
(relx,rely,height,width) = (self.relx, self.rely, self.height, self.width)
self.relx += 1
self.rely += 1
self.height -= 1
self.width -= 1
super().update(clear=False)
(self.relx,self.rely,self.height,self.width) = (relx, rely, height, width)

View File

@ -0,0 +1,4 @@
'''
Initialization file for invokeai.frontend.merge
'''
from .merge_diffusers import main as invokeai_merge_diffusers

View File

@ -0,0 +1,467 @@
"""
ldm.invoke.merge_diffusers exports a single function, merge_diffusion_models(),
used to merge two or three models together and create a new InvokeAI-registered diffusion model.
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import os
import sys
import traceback
import warnings
from argparse import Namespace
from pathlib import Path
from typing import List, Optional, Union
import npyscreen
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget
from omegaconf import OmegaConf
from ...frontend.config.widgets import FloatTitleSlider
from ...backend.globals import (Globals, global_cache_dir, global_config_file,
global_models_dir, global_set_root)
from ...backend.model_management import ModelManager
DEST_MERGED_MODEL_DIR = "merged_models"
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
    interp: Optional[str] = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", global_cache_dir()),
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
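# Illustrative usage sketch (not part of this module; the repo_id and paths
# below are placeholders). With the default weighted-sum interpolation, alpha
# weights the later models, so alpha=0.3 gives roughly 30% weight to the
# second model:
#
#     merged = merge_diffusion_models(
#         ["runwayml/stable-diffusion-v1-5", "/path/to/local/finetune"],
#         alpha=0.3,
#         interp=None,  # None selects the default weighted-sum merge
#     )
#     merged.save_pretrained("/path/to/merged_model")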
def merge_diffusion_models_and_commit(
models: List["str"],
merged_model_name: str,
alpha: float = 0.5,
    interp: Optional[str] = None,
force: bool = False,
**kwargs,
):
"""
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
config_file = global_config_file()
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
assert (
model_manager.model_info(mod).get("format", None) == "diffusers"
), f"** {mod} is not a diffusers model. It must be optimized before merging."
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
merged_pipe = merge_diffusion_models(
model_ids_or_paths, alpha, interp, force, **kwargs
)
dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / merged_model_name
    merged_pipe.save_pretrained(dump_path, safe_serialization=True)
import_args = dict(
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
)
if vae := model_manager.config[models[0]].get("vae", None):
print(f">> Using configured VAE assigned to {models[0]}")
import_args.update(vae=vae)
model_manager.import_diffuser_model(dump_path, **import_args)
model_manager.commit(config_file)
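# Illustrative sketch (model names are hypothetical entries from models.yaml):
#
#     merge_diffusion_models_and_commit(
#         models=["stable-diffusion-1.5", "my-finetune"],
#         merged_model_name="sd15+my-finetune",
#         alpha=0.5,
#         interp=None,  # weighted-sum merge
#     )
#
# The merged pipeline is written under the models directory in merged_models/
# and registered with the ModelManager.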
def _parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="InvokeAI model merging")
parser.add_argument(
"--root_dir",
type=Path,
default=Globals.root,
help="Path to the invokeai runtime directory",
)
parser.add_argument(
"--front_end",
"--gui",
dest="front_end",
action="store_true",
default=False,
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
)
parser.add_argument(
"--models",
type=str,
nargs="+",
help="Two to three model names to be merged",
)
parser.add_argument(
"--merged_model_name",
"--destination",
dest="merged_model_name",
type=str,
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
)
parser.add_argument(
"--alpha",
type=float,
default=0.5,
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
)
parser.add_argument(
"--interpolation",
dest="interp",
type=str,
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
default="weighted_sum",
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
)
parser.add_argument(
"--force",
action="store_true",
help="Try to merge models even if they are incompatible with each other",
)
parser.add_argument(
"--clobber",
"--overwrite",
dest="clobber",
action="store_true",
help="Overwrite the merged model if --merged_model_name already exists",
)
return parser.parse_args()
# ------------------------- GUI HERE -------------------------
class mergeModelsForm(npyscreen.FormMultiPageAction):
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
def __init__(self, parentApp, name):
self.parentApp = parentApp
self.ALLOW_RESIZE = True
self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
super().__init__(parentApp, name)
@property
def model_manager(self):
return self.parentApp.model_manager
def afterEditing(self):
self.parentApp.setNextForm(None)
def create(self):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
self.add_widget_intelligent(
npyscreen.FixedText,
color="CONTROL",
value="Select two models to merge and optionally a third.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
color="CONTROL",
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 1",
color="GOOD",
editable=False,
rely=4 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
values=self.model_names,
value=0,
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 2",
color="GOOD",
editable=False,
relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
name="(2)",
values=self.model_names,
value=1,
max_height=len(self.model_names),
max_width=max_width,
relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 3",
color="GOOD",
editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
self.model3 = self.add_widget_intelligent(
npyscreen.SelectOne,
name="(3)",
values=models_plus_none,
value=0,
max_height=len(self.model_names) + 1,
max_width=max_width,
scroll_exit=True,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
)
for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
name="Name for merged model:",
labelColor="CONTROL",
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
labelColor="CONTROL",
value=False,
scroll_exit=True,
)
self.merge_method = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Merge Method:",
values=self.interpolations,
value=0,
labelColor="CONTROL",
max_height=len(self.interpolations) + 1,
scroll_exit=True,
)
self.alpha = self.add_widget_intelligent(
FloatTitleSlider,
name="Weight (alpha) to assign to second and third models:",
out_of=1.0,
step=0.01,
lowest=0,
value=0.5,
labelColor="CONTROL",
scroll_exit=True,
)
self.model1.editing = True
def models_changed(self):
models = self.model1.values
selected_model1 = self.model1.value[0]
selected_model2 = self.model2.value[0]
selected_model3 = self.model3.value[0]
merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
self.merged_model_name.value = merged_model_name
if selected_model3 > 0:
self.merge_method.values = ['add_difference ( A+(B-C) )']
            # model3's list has an extra leading "None" entry, so subtract one
            self.merged_model_name.value += f"+{models[selected_model3 - 1]}"
else:
self.merge_method.values = self.interpolations
self.merge_method.value = 0
def on_ok(self):
if self.validate_field_values() and self.check_for_overwrite():
self.parentApp.setNextForm(None)
self.editing = False
self.parentApp.merge_arguments = self.marshall_arguments()
npyscreen.notify("Starting the merge...")
else:
self.editing = True
def on_cancel(self):
sys.exit(0)
def marshall_arguments(self) -> dict:
model_names = self.model_names
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
]
        if self.model3.value[0] > 0:
            models.append(model_names[self.model3.value[0] - 1])
            interp = "add_difference"
        else:
            interp = self.interpolations[self.merge_method.value[0]]
args = dict(
models=models,
alpha=self.alpha.value,
interp=interp,
force=self.force.value,
merged_model_name=self.merged_model_name.value,
)
return args
def check_for_overwrite(self) -> bool:
model_out = self.merged_model_name.value
if model_out not in self.model_names:
return True
else:
return npyscreen.notify_yes_no(
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
)
def validate_field_values(self) -> bool:
bad_fields = []
model_names = self.model_names
selected_models = set(
(model_names[self.model1.value[0]], model_names[self.model2.value[0]])
)
if self.model3.value[0] > 0:
selected_models.add(model_names[self.model3.value[0] - 1])
if len(selected_models) < 2:
bad_fields.append(
f"Please select two or three DIFFERENT models to compare. You selected {selected_models}"
)
if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:"
for problem in bad_fields:
message += f"\n* {problem}"
npyscreen.notify_confirm(message)
return False
else:
return True
def get_model_names(self) -> List[str]:
model_names = [
name
for name in self.model_manager.model_names()
if self.model_manager.model_info(name).get("format") == "diffusers"
]
return sorted(model_names)
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
super().__init__()
conf = OmegaConf.load(global_config_file())
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
def run_gui(args: Namespace):
mergeapp = Mergeapp()
mergeapp.run()
args = mergeapp.merge_arguments
merge_diffusion_models_and_commit(**args)
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace):
    assert 0 <= args.alpha <= 1.0, "alpha must be between 0 and 1"
    assert (
        args.models and 2 <= len(args.models) <= 3
    ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
print(
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(global_config_file()))
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args))
print(f'>> Models merged into new model: "{args.merged_model_name}".')
def main():
args = _parse_args()
global_set_root(args.root_dir)
cache_dir = str(global_cache_dir("diffusers"))
    # it is unclear whether the merge pipeline honors cache_dir, so set HF_HOME as well
    os.environ["HF_HOME"] = cache_dir
args.cache_dir = cache_dir
try:
if args.front_end:
run_gui(args)
else:
run_cli(args)
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
print(
"** You need to have at least two diffusers models defined in models.yaml in order to merge"
)
else:
print("** Not enough room for the user interface. Try making this window larger.")
sys.exit(-1)
except Exception:
print(">> An error occurred:")
traceback.print_exc()
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,5 @@
'''
Initialization file for invokeai.frontend.training
'''
from .textual_inversion import main as invokeai_textual_inversion

View File

@ -0,0 +1,461 @@
#!/usr/bin/env python
"""
This is the frontend to "textual_inversion_training.py".
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import os
import re
import shutil
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import List, Tuple
import npyscreen
from npyscreen import widget
from omegaconf import OmegaConf
from invokeai.backend.globals import Globals, global_set_root
from ...backend.training import (
do_textual_inversion_training,
parse_args,
)
TRAINING_DATA = "text-inversion-training-data"
TRAINING_DIR = "text-inversion-output"
CONF_FILE = "preferences.conf"
class textualInversionForm(npyscreen.FormMultiPageAction):
resolutions = [512, 768, 1024]
lr_schedulers = [
"linear",
"cosine",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
]
precisions = ["no", "fp16", "bf16"]
learnable_properties = ["object", "style"]
def __init__(self, parentApp, name, saved_args=None):
self.saved_args = saved_args or {}
super().__init__(parentApp, name)
def afterEditing(self):
self.parentApp.setNextForm(None)
def create(self):
self.model_names, default = self.get_model_names()
default_initializer_token = ""
default_placeholder_token = ""
saved_args = self.saved_args
        try:
            default = self.model_names.index(saved_args["model"])
        except (KeyError, ValueError):
            pass
self.add_widget_intelligent(
npyscreen.FixedText,
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
editable=False,
)
self.model = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Model Name:",
values=self.model_names,
value=default,
max_height=len(self.model_names) + 1,
scroll_exit=True,
)
self.placeholder_token = self.add_widget_intelligent(
npyscreen.TitleText,
name="Trigger Term:",
value="", # saved_args.get('placeholder_token',''), # to restore previous term
scroll_exit=True,
)
self.placeholder_token.when_value_edited = self.initializer_changed
self.nextrely -= 1
self.nextrelx += 30
self.prompt_token = self.add_widget_intelligent(
npyscreen.FixedText,
name="Trigger term for use in prompt",
value="",
editable=False,
scroll_exit=True,
)
self.nextrelx -= 30
self.initializer_token = self.add_widget_intelligent(
npyscreen.TitleText,
name="Initializer:",
value=saved_args.get("initializer_token", default_initializer_token),
scroll_exit=True,
)
self.resume_from_checkpoint = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Resume from last saved checkpoint",
value=False,
scroll_exit=True,
)
self.learnable_property = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Learnable property:",
values=self.learnable_properties,
value=self.learnable_properties.index(
saved_args.get("learnable_property", "object")
),
max_height=4,
scroll_exit=True,
)
self.train_data_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="Data Training Directory:",
select_dir=True,
must_exist=False,
value=str(
saved_args.get(
"train_data_dir",
Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
)
),
scroll_exit=True,
)
self.output_dir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="Output Destination Directory:",
select_dir=True,
must_exist=False,
value=str(
saved_args.get(
"output_dir",
Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
)
),
scroll_exit=True,
)
self.resolution = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Image resolution (pixels):",
values=self.resolutions,
value=self.resolutions.index(saved_args.get("resolution", 512)),
max_height=4,
scroll_exit=True,
)
self.center_crop = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Center crop images before resizing to resolution",
value=saved_args.get("center_crop", False),
scroll_exit=True,
)
self.mixed_precision = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Mixed Precision:",
values=self.precisions,
value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
max_height=4,
scroll_exit=True,
)
self.num_train_epochs = self.add_widget_intelligent(
npyscreen.TitleSlider,
name="Number of training epochs:",
out_of=1000,
step=50,
lowest=1,
value=saved_args.get("num_train_epochs", 100),
scroll_exit=True,
)
self.max_train_steps = self.add_widget_intelligent(
npyscreen.TitleSlider,
name="Max Training Steps:",
out_of=10000,
step=500,
lowest=1,
value=saved_args.get("max_train_steps", 3000),
scroll_exit=True,
)
self.train_batch_size = self.add_widget_intelligent(
npyscreen.TitleSlider,
name="Batch Size (reduce if you run out of memory):",
out_of=50,
step=1,
lowest=1,
value=saved_args.get("train_batch_size", 8),
scroll_exit=True,
)
self.gradient_accumulation_steps = self.add_widget_intelligent(
npyscreen.TitleSlider,
name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
out_of=10,
step=1,
lowest=1,
value=saved_args.get("gradient_accumulation_steps", 4),
scroll_exit=True,
)
self.lr_warmup_steps = self.add_widget_intelligent(
npyscreen.TitleSlider,
name="Warmup Steps:",
out_of=100,
step=1,
lowest=0,
value=saved_args.get("lr_warmup_steps", 0),
scroll_exit=True,
)
self.learning_rate = self.add_widget_intelligent(
npyscreen.TitleText,
name="Learning Rate:",
value=str(
saved_args.get("learning_rate", "5.0e-04"),
),
scroll_exit=True,
)
self.scale_lr = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Scale learning rate by number GPUs, steps and batch size",
value=saved_args.get("scale_lr", True),
scroll_exit=True,
)
self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Use xformers acceleration",
value=saved_args.get("enable_xformers_memory_efficient_attention", False),
scroll_exit=True,
)
self.lr_scheduler = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Learning rate scheduler:",
values=self.lr_schedulers,
max_height=7,
value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
scroll_exit=True,
)
self.model.editing = True
def initializer_changed(self):
placeholder = self.placeholder_token.value
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
self.train_data_dir.value = str(
Path(Globals.root) / TRAINING_DATA / placeholder
)
self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
def on_ok(self):
if self.validate_field_values():
self.parentApp.setNextForm(None)
self.editing = False
self.parentApp.ti_arguments = self.marshall_arguments()
npyscreen.notify(
"Launching textual inversion training. This will take a while..."
)
else:
self.editing = True
    def on_cancel(self):
        sys.exit(0)
def validate_field_values(self) -> bool:
bad_fields = []
if self.model.value is None:
bad_fields.append(
"Model Name must correspond to a known model in models.yaml"
)
if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
bad_fields.append(
"Trigger term must only contain alphanumeric characters, the dot and hyphen"
)
if self.train_data_dir.value is None:
bad_fields.append("Data Training Directory cannot be empty")
if self.output_dir.value is None:
bad_fields.append("The Output Destination Directory cannot be empty")
if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:"
for problem in bad_fields:
message += f"\n* {problem}"
npyscreen.notify_confirm(message)
return False
else:
return True
def get_model_names(self) -> Tuple[List[str], int]:
conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
        model_names = [
            name
            for name in sorted(list(conf.keys()))
            if conf[name].get("format", None) == "diffusers"
        ]
defaults = [
idx
for idx in range(len(model_names))
if "default" in conf[model_names[idx]]
]
default = defaults[0] if len(defaults) > 0 else 0
return (model_names, default)
def marshall_arguments(self) -> dict:
args = dict()
# the choices
args.update(
model=self.model_names[self.model.value[0]],
resolution=self.resolutions[self.resolution.value[0]],
lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
mixed_precision=self.precisions[self.mixed_precision.value[0]],
learnable_property=self.learnable_properties[
self.learnable_property.value[0]
],
)
# all the strings and booleans
for attr in (
"initializer_token",
"placeholder_token",
"train_data_dir",
"output_dir",
"scale_lr",
"center_crop",
"enable_xformers_memory_efficient_attention",
):
args[attr] = getattr(self, attr).value
# all the integers
for attr in (
"train_batch_size",
"gradient_accumulation_steps",
"num_train_epochs",
"max_train_steps",
"lr_warmup_steps",
):
args[attr] = int(getattr(self, attr).value)
# the floats (just one)
args.update(learning_rate=float(self.learning_rate.value))
# a special case
if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
args["resume_from_checkpoint"] = "latest"
return args
class MyApplication(npyscreen.NPSAppManaged):
def __init__(self, saved_args=None):
super().__init__()
self.ti_arguments = None
self.saved_args = saved_args
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
self.main = self.addForm(
"MAIN",
textualInversionForm,
name="Textual Inversion Settings",
saved_args=self.saved_args,
)
def copy_to_embeddings_folder(args: dict):
"""
Copy learned_embeds.bin into the embeddings folder, and offer to
delete the full model and checkpoints.
"""
source = Path(args["output_dir"], "learned_embeds.bin")
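    # do_front_end() wraps the trigger term in angle brackets; strip them for the folder name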
dest_dir_name = args["placeholder_token"].strip("<>")
destination = Path(Globals.root, "embeddings", dest_dir_name)
os.makedirs(destination, exist_ok=True)
print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
shutil.copy(source, destination)
if (
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
).startswith(("y", "Y")):
shutil.rmtree(Path(args["output_dir"]))
else:
print(f'>> Keeping {args["output_dir"]}')
def save_args(args: dict):
"""
Save the current argument values to an omegaconf file
"""
dest_dir = Path(Globals.root) / TRAINING_DIR
os.makedirs(dest_dir, exist_ok=True)
conf_file = dest_dir / CONF_FILE
conf = OmegaConf.create(args)
OmegaConf.save(config=conf, f=conf_file)
def previous_args() -> dict:
"""
Get the previous arguments used.
"""
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
try:
conf = OmegaConf.load(conf_file)
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
    except Exception:
conf = None
return conf
def do_front_end(args: Namespace):
saved_args = previous_args()
myapplication = MyApplication(saved_args=saved_args)
myapplication.run()
if args := myapplication.ti_arguments:
os.makedirs(args["output_dir"], exist_ok=True)
# Automatically add angle brackets around the trigger
if not re.match("^<.+>$", args["placeholder_token"]):
args["placeholder_token"] = f"<{args['placeholder_token']}>"
args["only_save_embeds"] = True
save_args(args)
try:
do_textual_inversion_training(**args)
copy_to_embeddings_folder(args)
except Exception as e:
print("** An exception occurred during training. The exception was:")
print(str(e))
print("** DETAILS:")
print(traceback.format_exc())
def main():
args = parse_args()
global_set_root(args.root_dir or Globals.root)
try:
if args.front_end:
do_front_end(args)
else:
do_textual_inversion_training(**vars(args))
except AssertionError as e:
print(str(e))
sys.exit(-1)
except KeyboardInterrupt:
pass
    except Exception as e:  # also covers widget.NotEnoughSpaceForWidget
if str(e).startswith("Height of 1 allocated"):
print(
"** You need to have at least one diffusers models defined in models.yaml in order to train"
)
        elif str(e).startswith("addwstr"):
            print(
                "** Not enough window space for the interface. Please make your window larger and try again."
            )
else:
print(f"** An error has occurred: {str(e)}")
sys.exit(-1)
if __name__ == "__main__":
main()

Some files were not shown because too many files have changed in this diff