all files migrated; tweaks needed
80
invokeai/app/api/dependencies.py
Normal file
@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from argparse import Namespace
|
||||
import os
|
||||
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
|
||||
from ...globals import Globals
|
||||
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.generate_initializer import get_generate
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
def check_internet()->bool:
    '''
    Return true if the internet is reachable.

    Reachability is probed by opening http://huggingface.co with a
    one-second timeout. Any error raised while doing so is reported
    as False rather than propagated.
    '''
    import urllib.request
    host = 'http://huggingface.co'
    try:
        # Use the response as a context manager so the connection is closed
        # instead of being left for the garbage collector.
        with urllib.request.urlopen(host,timeout=1):
            return True
    except Exception:
        # Best-effort probe: every ordinary failure means "not reachable".
        # Unlike a bare except, KeyboardInterrupt/SystemExit still propagate.
        return False
|
||||
|
||||
|
||||
class ApiDependencies:
    """Builds and holds the service objects shared by the API routes"""
    # Populated by initialize(); remains None until then.
    invoker: Invoker = None

    @staticmethod
    def initialize(
        args,
        config,
        event_handler_id: int,
    ):
        """Copy settings from args into Globals and construct the shared Invoker."""
        Globals.try_patchmatch = args.patchmatch
        Globals.always_use_cpu = args.always_use_cpu
        # The connectivity probe only runs when the argument allows it.
        Globals.internet_available = args.internet_available and check_internet()
        Globals.disable_xformers = not args.xformers
        Globals.ckpt_convert = args.ckpt_convert

        # TODO: Use a logger
        print(f'>> Internet connectivity is {Globals.internet_available}')

        generate = get_generate(args, config)
        events = FastAPIEventService(event_handler_id)

        # Outputs live four directory levels above this module.
        module_dir = os.path.dirname(__file__)
        output_folder = os.path.abspath(os.path.join(module_dir, '../../../../outputs'))
        images = DiskImageStorage(output_folder)

        # TODO: build a file/path manager?
        db_location = os.path.join(output_folder, 'invokeai.db')

        queue = MemoryInvocationQueue()
        graph_executions = SqliteItemStorage[GraphExecutionState](
            filename=db_location,
            table_name='graph_executions',
        )
        processor = DefaultInvocationProcessor()

        services = InvocationServices(
            generate=generate,
            events=events,
            images=images,
            queue=queue,
            graph_execution_manager=graph_executions,
            processor=processor,
        )

        ApiDependencies.invoker = Invoker(services)

    @staticmethod
    def shutdown():
        """Stop the invoker if one has been created."""
        invoker = ApiDependencies.invoker
        if invoker:
            invoker.stop()
|
54
invokeai/app/api/events.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
from ..services.events import EventServiceBase
|
||||
import threading
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
    """Event service that hands queued events to fastapi_events.

    Events are placed on a thread-safe queue by dispatch() and forwarded by
    a task created in the constructor. Because the constructor calls
    asyncio.create_task, instances must be created while an event loop is
    running.
    """
    event_handler_id: int
    __queue: Queue
    __stop_event: threading.Event
    __dispatch_task: asyncio.Task

    def __init__(self, event_handler_id: int) -> None:
        self.event_handler_id = event_handler_id
        self.__queue = Queue()
        self.__stop_event = threading.Event()
        # Keep a strong reference to the task: the event loop only holds a
        # weak one, so an unreferenced task can be garbage-collected mid-run.
        self.__dispatch_task = asyncio.create_task(
            self.__dispatch_from_queue(stop_event = self.__stop_event)
        )

        super().__init__()


    def stop(self, *args, **kwargs):
        """Signal the dispatch loop to finish."""
        self.__stop_event.set()
        # A None entry is recognised by the dispatch loop as "no event".
        self.__queue.put(None)


    def dispatch(self, event_name: str, payload: Any) -> None:
        """Queue an event for forwarding by the dispatch task."""
        self.__queue.put(dict(
            event_name = event_name,
            payload = payload
        ))


    async def __dispatch_from_queue(self, stop_event: threading.Event):
        """Get events on from the queue and dispatch them, from the correct thread"""
        while not stop_event.is_set():
            try:
                event = self.__queue.get(block = False)
                if not event: # Probably stopping
                    continue

                dispatch(
                    event.get('event_name'),
                    payload = event.get('payload'),
                    middleware_id = self.event_handler_id)

            except Empty:
                # Nothing queued: yield to the event loop briefly, then poll again.
                await asyncio.sleep(0.001)

            except asyncio.CancelledError:
                raise # Raise a proper error
|
57
invokeai/app/api/routers/images.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import Path, UploadFile, Request
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from PIL import Image
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
# Router for the image endpoints; all of its routes live under /v1/images.
images_router = APIRouter(prefix='/v1/images', tags=['images'])
|
||||
|
||||
|
||||
@images_router.get(
    '/{image_type}/{image_name}',
    operation_id='get_image',
)
async def get_image(
    image_type: ImageType = Path(description="The type of image to get"),
    image_name: str = Path(description="The name of the image to get"),
):
    """Gets a result"""
    # Resolve the on-disk path through the image storage service and
    # return that file as the response.
    # TODO: This is not really secure at all. At least make sure only output results are served
    image_storage = ApiDependencies.invoker.services.images
    image_path = image_storage.get_path(image_type, image_name)
    return FileResponse(image_path)
|
||||
|
||||
@images_router.post('/uploads/',
    operation_id = 'upload_image',
    responses = {
        201: {'description': 'The image was uploaded successfully'},
        415: {'description': 'The upload is not an image that could be read'}
    })
async def upload_image(
    file: UploadFile,
    request: Request
):
    # Stores an uploaded image under ImageType.UPLOAD and answers 201 with a
    # Location header pointing at the get_image route for the stored file.
    # Answers 415 when the upload is not declared as, or cannot be opened
    # as, an image.
    from io import BytesIO

    # content_type may be absent; treat that as "not an image" rather than
    # failing with AttributeError.
    if not (file.content_type or '').startswith('image'):
        return Response(status_code = 415)

    contents = await file.read()
    try:
        # Image.open takes a path or a file object, so the raw bytes have to
        # be wrapped; passing bytes directly is interpreted as a filename.
        im = Image.open(BytesIO(contents))
    except Exception:
        # Error opening the image
        return Response(status_code = 415)

    # The stored name is the current UTC timestamp in whole seconds.
    filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png'
    ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)

    return Response(
        status_code=201,
        headers = {
            # Header values must be plain strings.
            'Location': str(request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename))
        }
    )
|
232
invokeai/app/api/routers/sessions.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import List, Optional, Union, Annotated
|
||||
from fastapi import Query, Path, Body
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import Response
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||
from ...invocations import *
|
||||
|
||||
# Router for the session endpoints; all of its routes live under /v1/sessions.
session_router = APIRouter(prefix='/v1/sessions', tags=['sessions'])
|
||||
|
||||
|
||||
@session_router.post(
    '/',
    operation_id='create_session',
    responses={
        200: {"model": GraphExecutionState},
        400: {'description': 'Invalid json'},
    },
)
async def create_session(
    graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
) -> GraphExecutionState:
    """Creates a new session, optionally initializing it with an invocation graph"""
    # The invoker builds the execution state; graph is None when no body is sent.
    return ApiDependencies.invoker.create_execution_state(graph)
|
||||
|
||||
|
||||
@session_router.get('/',
    operation_id = 'list_sessions',
    responses = {
        200: {"model": PaginatedResults[GraphExecutionState]}
    })
async def list_sessions(
    page: int = Query(default = 0, description = "The page of results to get"),
    per_page: int = Query(default = 10, description = "The number of results per page"),
    query: str = Query(default = '', description = "The query string to search for")
) -> PaginatedResults[GraphExecutionState]:
    """Gets a list of sessions, optionally searching"""
    # An empty query lists every session; anything else is passed to search().
    # (Comparing the builtin `filter` here would never be true.)
    if query == '':
        result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
    else:
        result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
    return result
|
||||
|
||||
|
||||
@session_router.get(
    '/{session_id}',
    operation_id='get_session',
    responses={
        200: {"model": GraphExecutionState},
        404: {'description': 'Session not found'},
    },
)
async def get_session(
    session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
    """Gets a session"""
    found = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if found is not None:
        return found
    # Unknown session id.
    return Response(status_code=404)
|
||||
|
||||
|
||||
@session_router.post(
    '/{session_id}/nodes',
    operation_id='add_node',
    responses={
        200: {"model": str},
        400: {'description': 'Invalid node or link'},
        404: {'description': 'Session not found'},
    },
)
async def add_node(
    session_id: str = Path(description="The id of the session"),
    node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description="The node to add"),
) -> str:
    """Adds a node to the graph"""
    storage = ApiDependencies.invoker.services.graph_execution_manager
    session = storage.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.add_node(node)
        # Persist the modified session.
        storage.set(session) # TODO: can this be done automatically, or add node through an API?
        return session.id
    except (NodeAlreadyExecutedError, IndexError):
        # Both failure kinds are reported to the client as a bad request.
        return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
    '/{session_id}/nodes/{node_path}',
    operation_id='update_node',
    responses={
        200: {"model": GraphExecutionState},
        400: {'description': 'Invalid node or link'},
        404: {'description': 'Session not found'},
    },
)
async def update_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node in the graph"),
    node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description="The new node"),
) -> GraphExecutionState:
    """Updates a node in the graph and removes all linked edges"""
    storage = ApiDependencies.invoker.services.graph_execution_manager
    session = storage.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.update_node(node_path, node)
        # Persist the modified session.
        storage.set(session) # TODO: can this be done automatically, or add node through an API?
        return session
    except (NodeAlreadyExecutedError, IndexError):
        # Both failure kinds are reported to the client as a bad request.
        return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.delete(
    '/{session_id}/nodes/{node_path}',
    operation_id='delete_node',
    responses={
        200: {"model": GraphExecutionState},
        400: {'description': 'Invalid node or link'},
        404: {'description': 'Session not found'},
    },
)
async def delete_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node to delete"),
) -> GraphExecutionState:
    """Deletes a node in the graph and removes all linked edges"""
    storage = ApiDependencies.invoker.services.graph_execution_manager
    session = storage.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.delete_node(node_path)
        # Persist the modified session.
        storage.set(session) # TODO: can this be done automatically, or add node through an API?
        return session
    except (NodeAlreadyExecutedError, IndexError):
        # Both failure kinds are reported to the client as a bad request.
        return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.post(
    '/{session_id}/edges',
    operation_id='add_edge',
    responses={
        200: {"model": GraphExecutionState},
        400: {'description': 'Invalid node or link'},
        404: {'description': 'Session not found'},
    },
)
async def add_edge(
    session_id: str = Path(description="The id of the session"),
    edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
) -> GraphExecutionState:
    """Adds an edge to the graph"""
    storage = ApiDependencies.invoker.services.graph_execution_manager
    session = storage.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.add_edge(edge)
        # Persist the modified session.
        storage.set(session) # TODO: can this be done automatically, or add node through an API?
        return session
    except (NodeAlreadyExecutedError, IndexError):
        # Both failure kinds are reported to the client as a bad request.
        return Response(status_code=400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@session_router.delete(
    '/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
    operation_id='delete_edge',
    responses={
        200: {"model": GraphExecutionState},
        400: {'description': 'Invalid node or link'},
        404: {'description': 'Session not found'},
    },
)
async def delete_edge(
    session_id: str = Path(description="The id of the session"),
    from_node_id: str = Path(description="The id of the node the edge is coming from"),
    from_field: str = Path(description="The field of the node the edge is coming from"),
    to_node_id: str = Path(description="The id of the node the edge is going to"),
    to_field: str = Path(description="The field of the node the edge is going to"),
) -> GraphExecutionState:
    """Deletes an edge from the graph"""
    storage = ApiDependencies.invoker.services.graph_execution_manager
    session = storage.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        # Rebuild the (source, destination) pair from the path segments.
        source = EdgeConnection(node_id=from_node_id, field=from_field)
        destination = EdgeConnection(node_id=to_node_id, field=to_field)
        session.delete_edge((source, destination))
        # Persist the modified session.
        storage.set(session) # TODO: can this be done automatically, or add node through an API?
        return session
    except (NodeAlreadyExecutedError, IndexError):
        # Both failure kinds are reported to the client as a bad request.
        return Response(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
    '/{session_id}/invoke',
    operation_id='invoke_session',
    responses={
        200: {"model": None},
        202: {'description': 'The invocation is queued'},
        400: {'description': 'The session has no invocations ready to invoke'},
        404: {'description': 'Session not found'},
    },
)
async def invoke_session(
    session_id: str = Path(description="The id of the session to invoke"),
    all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
) -> None:
    """Invokes a session"""
    invoker = ApiDependencies.invoker
    session = invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    # A completed session has nothing left to run.
    if session.is_complete():
        return Response(status_code=400)

    invoker.invoke(session, invoke_all=all)
    return Response(status_code=202)
|
36
invokeai/app/api/sockets.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi_socketio import SocketManager
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
class SocketIO:
    """Relays session events to socket.io rooms.

    Clients join or leave a room by sending 'subscribe' / 'unsubscribe'
    with a 'session' key; each session event is emitted to the room named
    by its 'graph_execution_state_id'.
    """
    __sio: SocketManager

    def __init__(self, app: FastAPI):
        manager = SocketManager(app=app)
        self.__sio = manager
        manager.on('subscribe', handler=self._handle_sub)
        manager.on('unsubscribe', handler=self._handle_unsub)

        # Receive session events dispatched through fastapi_events.
        local_handler.register(
            event_name=EventServiceBase.session_event,
            _func=self._handle_session_event,
        )

    async def _handle_session_event(self, event: Event):
        """Emit a session event to the room of its graph execution state."""
        # The second element of the event carries the payload dict.
        payload = event[1]
        event_name = payload['event']
        data = payload['data']
        await self.__sio.emit(
            event=event_name,
            data=data,
            room=data['graph_execution_state_id'],
        )

    async def _handle_sub(self, sid, data, *args, **kwargs):
        """Put the client into the room of the requested session."""
        if 'session' not in data:
            return
        self.__sio.enter_room(sid, data['session'])

    async def _handle_unsub(self, sid, data, *args, **kwargs):
        """Remove the client from the room of the given session."""
        if 'session' not in data:
            return
        self.__sio.leave_room(sid, data['session'])
|
164
invokeai/app/api_app.py
Normal file
@ -0,0 +1,164 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from inspect import signature
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic.schema import schema
|
||||
import uvicorn
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .api.routers import images, sessions
|
||||
from .api.dependencies import ApiDependencies
|
||||
from ..args import Args
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
# The built-in docs routes are disabled here; replacements are defined below.
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)

# Add event handler
# The app's id doubles as the middleware id used when dispatching events.
event_handler_id: int = id(app)
app.add_middleware(
    EventHandlerASGIMiddleware,
    handlers=[local_handler], # TODO: consider doing this in services to support different configurations
    middleware_id=event_handler_id,
)

# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

socket_io = SocketIO(app)

config = {}
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event('startup')
async def startup_event():
    """Parse the command-line arguments and initialize the API dependencies."""
    args = Args()
    parsed_config = args.parse_args()

    ApiDependencies.initialize(
        args=args,
        config=parsed_config,
        event_handler_id=event_handler_id,
    )
|
||||
|
||||
# Shut down threads
|
||||
@app.on_event('shutdown')
async def shutdown_event():
    """Release the API dependencies when the application shuts down."""
    ApiDependencies.shutdown()
|
||||
|
||||
# Include all routers
|
||||
# TODO: REMOVE
|
||||
# app.include_router(
|
||||
# invocation.invocation_router,
|
||||
# prefix = '/api')
|
||||
|
||||
# Both routers are exposed under the /api prefix.
app.include_router(sessions.session_router, prefix='/api')

app.include_router(images.images_router, prefix='/api')
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
    """Build the OpenAPI schema, adding each invocation's output type.

    The result is stored on app.openapi_schema and returned as-is on later
    calls.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        description="An API for invoking AI image operations",
        version="1.0.0",
        routes=app.routes,
    )
    component_schemas = openapi_schema["components"]["schemas"]

    # Add all outputs
    all_invocations = BaseInvocation.get_invocations()
    output_types = {
        signature(invocation.invoke).return_annotation
        for invocation in all_invocations
    }

    output_type_titles = {}
    output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
    for schema_key, output_schema in output_schemas['definitions'].items():
        component_schemas[schema_key] = output_schema

        # TODO: note that we assume the schema_key here is the TYPE.__name__
        # This could break in some cases, figure out a better way to do it
        output_type_titles[schema_key] = output_schema['title']

    # Add a reference to the output type to additionalProperties of the invoker schema
    for invocation in all_invocations:
        return_type = signature(invocation.invoke).return_annotation
        title = output_type_titles[return_type.__name__]
        invocation_schema = component_schemas[invocation.__name__]
        extra = invocation_schema.setdefault('additionalProperties', {})
        extra['outputs'] = { '$ref': f'#/components/schemas/{title}' }

    app.openapi_schema = openapi_schema
    return app.openapi_schema
|
||||
|
||||
# Serve the customized schema in place of the default generator.
app.openapi = custom_openapi

# Override API doc favicons
app.mount(
    '/static',
    StaticFiles(directory='static/dream_web'),
    name='static',
)
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
    """Serve the Swagger UI page with the favicon from /static."""
    page = get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title,
        swagger_favicon_url="/static/favicon.ico",
    )
    return page
|
||||
|
||||
@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
    """Serve the ReDoc page with the favicon from /static."""
    page = get_redoc_html(
        openapi_url=app.openapi_url,
        title=app.title,
        redoc_favicon_url="/static/favicon.ico",
    )
    return page
|
||||
|
||||
def invoke_api():
    """Serve the app with uvicorn on 0.0.0.0:9090 using a newly created event loop."""
    # Start our own event loop for eventing usage
    # TODO: determine if there's a better way to do this
    event_loop = asyncio.new_event_loop()
    server_config = uvicorn.Config(
        app=app,
        host="0.0.0.0",
        port=9090,
        loop=event_loop,
    )
    # Use access_log to turn off logging

    event_loop.run_until_complete(uvicorn.Server(server_config).serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
315
invokeai/app/cli_app.py
Normal file
@ -0,0 +1,315 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
|
||||
from .services.graph import EdgeConnection, GraphExecutionState
|
||||
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
from .invocations.image import ImageField
|
||||
from .services.generate_initializer import get_generate
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .invocations import *
|
||||
from ..args import Args
|
||||
from .services.events import EventServiceBase
|
||||
|
||||
|
||||
# Wraps a single invocation; the concrete invocation class is chosen by the
# value of its "type" field (the discriminator).
class InvocationCommand(BaseModel):
    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
|
||||
|
||||
class InvalidArgs(Exception):
    """Raised when a CLI command cannot be parsed."""
|
||||
|
||||
|
||||
def _literal_arg_type(allowed_values):
    """Return an argparse ``type`` callable for a Literal's allowed values."""
    value_types = {type(val) for val in allowed_values}
    if len(value_types) == 1:
        # Homogeneous Literal: the value type itself converts the token.
        return value_types.pop()

    def convert(token: str):
        # Mixed-type Literal: map the token back to the allowed value whose
        # string form matches, so the result has the right Python type.
        for val in allowed_values:
            if str(val) == token:
                return val
        raise argparse.ArgumentTypeError(f"invalid value: {token!r}")

    return convert


def get_invocation_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the interactive CLI.

    The parser has a 'history', a 'default' and a 'reset_default'
    sub-command, plus one sub-command per invocation class whose options
    mirror that class's fields (except 'id' and 'type'). The top-level
    parser raises InvalidArgs where argparse would otherwise exit.
    """

    # Create invocation parser
    parser = argparse.ArgumentParser()

    def raise_invalid_args(*args, **kwargs):
        raise InvalidArgs
    # Replace exit() so a bad command does not terminate the process.
    parser.exit = raise_invalid_args

    subparsers = parser.add_subparsers(dest='type')

    # Add history parser
    history_parser = subparsers.add_parser('history', help="Shows the invocation history")
    history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")

    # Add default parser
    default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
    default_parser.add_argument('input', type=str, help="The input field")
    default_parser.add_argument('value', help="The default value")

    reset_default_parser = subparsers.add_parser('reset_default', help="Resets a default value")
    reset_default_parser.add_argument('input', type=str, help="The input field")

    # Create subparsers for each invocation
    invocations = BaseInvocation.get_all_subclasses()
    for invocation in invocations:
        hints = get_type_hints(invocation)
        # The command name is the first value of the class's Literal 'type' hint.
        cmd_name = get_args(hints['type'])[0]
        command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)

        # Add linking capability
        command_parser.add_argument('--link', '-l', action='append', nargs=3,
            help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)")

        command_parser.add_argument('--link_node', '-ln', action='append',
            help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)")

        # Convert all fields to arguments
        fields = invocation.__fields__
        for name, field in fields.items():
            if name in ['id', 'type']:
                continue

            if get_origin(field.type_) == Literal:
                allowed_values = get_args(field.type_)
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=_literal_arg_type(allowed_values),
                    default=field.default,
                    choices = allowed_values,
                    help=field.field_info.description
                )
            else:
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=field.type_,
                    default=field.default,
                    help=field.field_info.description
                )

    return parser
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
    """Render an invocation as CLI command text.

    The text starts with the invocation's type, followed by a ``--name value``
    option for every field whose value differs from its default. The 'id'
    and 'type' fields and image fields are left out.
    """
    hints = get_type_hints(type(invocation))
    parts = [invocation.type]
    for name, field in invocation.__fields__.items():
        if name in ['id', 'type']:
            continue

        # TODO: add links

        # Skip image fields when serializing command
        hint = hints.get(name) or None
        if hint is ImageField or ImageField in get_args(hint):
            continue

        value = getattr(invocation, name)
        if value != field.default:
            # String-typed values are wrapped in double quotes.
            is_text = hint is str or str in get_args(hint)
            parts.append(f'--{name} "{value}"' if is_text else f'--{name} {value}')

    return ' '.join(parts)
|
||||
|
||||
|
||||
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
    """Gets the history of fully-executed invocations for a graph execution"""
    # Walk the executed history newest-first and keep only the entries that
    # are present in the graph's nodes.
    newest_first = reversed(graph_execution_state.executed_history)
    return filter(lambda node_id: node_id in graph_execution_state.graph.nodes, newest_first)
|
||||
|
||||
|
||||
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
    """Generates all possible edges between two invocations"""
    # Candidate links are names that appear both in the type hints of a's
    # output type and in the type hints of b's class.
    output_hints = get_type_hints(type(a).get_output_type())
    input_hints = get_type_hints(type(b))
    shared_names = set(output_hints.keys()).intersection(input_hints.keys())

    # Remove invalid fields
    shared_names = shared_names.difference({'type', 'id'})

    edges = []
    for field_name in shared_names:
        source = EdgeConnection(node_id=a.id, field=field_name)
        destination = EdgeConnection(node_id=b.id, field=field_name)
        edges.append((source, destination))
    return edges
|
||||
|
||||
|
||||
def invoke_cli():
    """Run the interactive command-line loop.

    Builds the invocation services (generator, events, disk image storage,
    in-memory queue, SQLite-backed graph execution storage, processor), then
    reads commands from stdin until the user exits. Each input line may hold
    several commands separated by ``|``; every command becomes a node that is
    linked to the previous one, added to the current session and executed.

    Typing ``exit``/``q``, pressing Ctrl-C, or reaching end-of-input (Ctrl-D)
    leaves the loop; the invoker is stopped before returning.
    """
    args = Args()
    config = args.parse_args()

    generate = get_generate(args, config)

    # NOTE: load model on first use, uncomment to load at startup
    # TODO: Make this a config option?
    #generate.load_model()

    events = EventServiceBase()

    output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, 'invokeai.db')

    services = InvocationServices(
        generate = generate,
        events = events,
        images = DiskImageStorage(output_folder),
        queue = MemoryInvocationQueue(),
        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
        processor = DefaultInvocationProcessor()
    )

    invoker = Invoker(services)
    session: GraphExecutionState = invoker.create_execution_state()

    parser = get_invocation_parser()

    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

    # Defaults storage
    defaults: Dict[str, Any] = dict()

    while True:
        try:
            cmd_input = input("> ")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-c exits; so does end-of-input (Ctrl-d / closed stdin),
            # which would otherwise escape as an unhandled EOFError and
            # skip invoker.stop() below
            break

        if cmd_input in ['exit','q']:
            break

        if cmd_input in ['--help','help','h','?']:
            parser.print_help()
            continue

        try:
            # Refresh the state of the session
            session = invoker.services.graph_execution_manager.get(session.id)
            history = list(get_graph_execution_history(session))

            # Split the command for piping
            cmds = cmd_input.split('|')
            # New node ids continue counting from the number of executed nodes
            start_id = len(history)
            current_id = start_id
            new_invocations = list()
            for cmd in cmds:
                if cmd is None or cmd.strip() == '':
                    raise InvalidArgs('Empty command')

                # Parse args to create invocation
                # (note: rebinds `args`; the Args() object is no longer needed here)
                args = vars(parser.parse_args(shlex.split(cmd.strip())))

                # Check for special commands
                # TODO: These might be better as Pydantic models, similar to the invocations
                if args['type'] == 'history':
                    history_count = args['count'] or 5
                    for i in range(min(history_count, len(history))):
                        entry_id = history[-1 - i]
                        entry = session.graph.get_node(entry_id)
                        print(f'{entry_id}: {get_invocation_command(entry.invocation)}')
                    continue

                if args['type'] == 'reset_default':
                    if args['input'] in defaults:
                        del defaults[args['input']]
                    continue

                if args['type'] == 'default':
                    field = args['input']
                    field_value = args['value']
                    defaults[field] = field_value
                    continue

                # Override defaults
                for field_name,field_default in defaults.items():
                    if field_name in args:
                        args[field_name] = field_default

                # Parse invocation
                args['id'] = current_id
                command = InvocationCommand(invocation = args)

                # Pipe previous command output (if there was a previous command)
                edges = []
                if len(history) > 0 or current_id != start_id:
                    # First command on the line links to the first history entry;
                    # later commands link to the command just before them
                    from_id = history[0] if current_id == start_id else str(current_id - 1)
                    from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id)
                    matching_edges = generate_matching_edges(from_node, command.invocation)
                    edges.extend(matching_edges)

                # Parse provided links
                if 'link_node' in args and args['link_node']:
                    for link in args['link_node']:
                        link_node = session.graph.get_node(link)
                        matching_edges = generate_matching_edges(link_node, command.invocation)
                        edges.extend(matching_edges)

                if 'link' in args and args['link']:
                    for link in args['link']:
                        edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2])))

                new_invocations.append((command.invocation, edges))

                current_id = current_id + 1

            # Command line was parsed successfully
            # Add the invocations to the session
            for invocation in new_invocations:
                session.add_node(invocation[0])
                for edge in invocation[1]:
                    session.add_edge(edge)

            # Execute all available invocations
            invoker.invoke(session, invoke_all = True)
            while not session.is_complete():
                # Wait some time
                session = invoker.services.graph_execution_manager.get(session.id)
                time.sleep(0.1)

            # Print any errors
            if session.has_error():
                for n in session.errors:
                    print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')

                # Start a new session
                print("Creating a new session")
                session = invoker.create_execution_state()

        except InvalidArgs:
            print('Invalid command, use "help" to list commands')
            continue

        except SystemExit:
            # Raised while parsing a command; keep the prompt alive instead of exiting
            continue

    invoker.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_cli()
|
8
invokeai/app/invocations/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
# Collect the name of every sibling ``.py`` module (except this file) so that
# ``from <package> import *`` pulls all of them in.
__all__ = []

dirname = os.path.dirname(os.path.abspath(__file__))

for f in os.listdir(dirname):
    if f == "__init__.py":
        continue
    if not os.path.isfile("%s/%s" % (dirname, f)):
        continue
    if f[-3:] != ".py":
        continue
    # Strip the ".py" suffix to get the module name
    __all__.append(f[:-3])
|
74
invokeai/app/invocations/baseinvocation.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvocationContext:
    """Bundle of state handed to an invocation's ``invoke`` call."""

    # Service container used by invocations (e.g. ``services.images``)
    services: InvocationServices
    # Id of the graph execution state this invocation is running under
    graph_execution_state_id: str

    def __init__(self, services: InvocationServices, graph_execution_state_id: str):
        """Store the services and the graph execution state id."""
        self.services = services
        self.graph_execution_state_id = graph_execution_state_id
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
    """Base class for all invocation outputs"""

    # All outputs must include a type name like this:
    # type: Literal['your_output_name']

    @classmethod
    def get_all_subclasses_tuple(cls):
        """Return every direct and indirect subclass of ``cls`` as a tuple.

        Subclasses are listed breadth-first: direct subclasses first, then
        their subclasses, and so on.
        """
        discovered = []
        current_level = [cls]
        while current_level:
            # Gather the children of the whole level, keeping parent order
            children = []
            for klass in current_level:
                children.extend(klass.__subclasses__())
            discovered.extend(children)
            current_level = children
        return tuple(discovered)
|
||||
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
    """A node to process inputs and produce outputs.
    May use dependency injection in __init__ to receive providers.
    """

    # All invocations must include a type name like this:
    # type: Literal['your_output_name']

    @classmethod
    def get_all_subclasses(cls):
        """Return a list of every direct and indirect subclass of ``cls``.

        The list is ordered breadth-first: direct subclasses first, then
        their subclasses, and so on.
        """
        discovered = []
        current_level = [cls]
        while current_level:
            # Gather the children of the whole level, keeping parent order
            children = []
            for klass in current_level:
                children.extend(klass.__subclasses__())
            discovered.extend(children)
            current_level = children
        return discovered

    @classmethod
    def get_invocations(cls):
        """Return all BaseInvocation subclasses as a tuple."""
        all_invocations = BaseInvocation.get_all_subclasses()
        return tuple(all_invocations)

    @classmethod
    def get_invocations_map(cls):
        """Return a dict mapping each invocation's ``type`` literal to its class."""
        # Get the type strings out of the literals and into a dictionary
        return {
            get_args(get_type_hints(invocation_class)['type'])[0]: invocation_class
            for invocation_class in BaseInvocation.get_all_subclasses()
        }

    @classmethod
    def get_output_type(cls):
        """Return the return annotation declared on this class's ``invoke``."""
        invoke_signature = signature(cls.invoke)
        return invoke_signature.return_annotation

    @abstractmethod
    def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
        """Invoke with provided context and return outputs."""
        pass

    id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
42
invokeai/app/invocations/cv.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
import numpy
|
||||
from pydantic import Field
|
||||
from PIL import Image, ImageOps
|
||||
import cv2 as cv
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation):
    """Simple inpaint using opencv."""
    type: Literal['cv_inpaint'] = 'cv_inpaint'

    # Inputs
    image: ImageField = Field(default=None, description="The image to inpaint")
    mask: ImageField = Field(default=None, description="The mask to use when inpainting")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Inpaint ``image`` under ``mask`` with OpenCV and store the result.

        The result is saved as an INTERMEDIATE image named after the current
        graph execution state and this node's id.
        """
        images = context.services.images
        source = images.get(self.image.image_type, self.image.image_name)
        source_mask = images.get(self.mask.image_type, self.mask.image_name)

        # Convert to cv image/mask
        # TODO: consider making these utility functions
        rgb_pixels = numpy.array(source.convert('RGB'))
        bgr_pixels = cv.cvtColor(rgb_pixels, cv.COLOR_RGB2BGR)
        # The mask is inverted before being handed to OpenCV
        inverted_mask = numpy.array(ImageOps.invert(source_mask))

        # Inpaint (radius 3, Telea method)
        bgr_result = cv.inpaint(bgr_pixels, inverted_mask, 3, cv.INPAINT_TELEA)

        # Convert back to Pillow
        # TODO: consider making a utility function
        rgb_result = cv.cvtColor(bgr_result, cv.COLOR_BGR2RGB)
        image_inpainted = Image.fromarray(rgb_result)

        result_type = ImageType.INTERMEDIATE
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, image_inpainted)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
160
invokeai/app/invocations/generate.py
Normal file
@ -0,0 +1,160 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Union
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from PIL import Image
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal["ddim","plms","k_lms","k_dpm_2","k_dpm_2_a","k_euler","k_euler_a","k_heun"]
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation):
    """Generates an image using text2img."""
    type: Literal['txt2img'] = 'txt2img'

    # Inputs
    # TODO: consider making prompt optional to enable providing prompt through a link
    prompt: Optional[str] = Field(description="The prompt to generate an image from")
    seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image")
    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image")
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt")
    sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use")
    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams")
    model: str = Field(default='', description="The model to use (currently ignored)")
    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation")

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(self, context: InvocationContext, sample: Any = None, step: int = 0) -> None:
        """Emit a generator-progress event for this node.

        ``sample`` is accepted for callback compatibility but not used; the
        reported percentage is ``step / self.steps``.
        """
        events = context.services.events
        events.emit_generator_progress(
            context.graph_execution_state_id,
            self.id,
            step,
            float(step) / float(self.steps),
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Generate an image from the prompt and save it as a RESULT image."""
        generate = context.services.generate
        images = context.services.images

        def on_step(sample, step = 0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        if self.model is None or self.model == '':
            self.model = generate.model_name

        # Set the model (if already cached, this does nothing)
        generate.set_model(self.model)

        # Every field except the prompt is forwarded as a keyword argument
        extra_params = self.dict(exclude={'prompt'})
        results = generate.prompt2image(
            prompt=self.prompt,
            step_callback=on_step,
            **extra_params
        )

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        generated_image = results[0][0]
        result_type = ImageType.RESULT
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, generated_image)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
||||
|
||||
|
||||
class ImageToImageInvocation(TextToImageInvocation):
    """Generates an image using img2img."""
    type: Literal['img2img'] = 'img2img'

    # Inputs
    image: Union[ImageField,None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
    fit: bool = Field(default=True, description="Whether or not the result should be fit to the aspect ratio of the input image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Generate an image from the prompt and optional input image.

        No mask is used here (``init_mask`` is always ``None``). The result is
        saved as a RESULT image.
        """
        generate = context.services.generate
        images = context.services.images

        if self.image is None:
            init_image = None
        else:
            init_image = images.get(self.image.image_type, self.image.image_name)
        init_mask = None

        def on_step(sample, step = 0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        if self.model is None or self.model == '':
            self.model = generate.model_name

        # Set the model (if already cached, this does nothing)
        generate.set_model(self.model)

        # Fields passed explicitly below are excluded from the forwarded kwargs
        extra_params = self.dict(exclude={'prompt', 'image', 'mask'})
        results = generate.prompt2image(
            prompt=self.prompt,
            init_img=init_image,
            init_mask=init_mask,
            step_callback=on_step,
            **extra_params
        )

        result_image = results[0][0]

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        result_type = ImageType.RESULT
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, result_image)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
    """Generates an image using inpaint."""
    type: Literal['inpaint'] = 'inpaint'

    # Inputs
    mask: Union[ImageField,None] = Field(description="The mask")
    inpaint_replace: float = Field(default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Generate an image from the prompt, optional input image and optional mask.

        The result is saved as a RESULT image.
        """
        generate = context.services.generate
        images = context.services.images

        if self.image is None:
            init_image = None
        else:
            init_image = images.get(self.image.image_type, self.image.image_name)

        if self.mask is None:
            init_mask = None
        else:
            init_mask = images.get(self.mask.image_type, self.mask.image_name)

        def on_step(sample, step = 0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        if self.model is None or self.model == '':
            self.model = generate.model_name

        # Set the model (if already cached, this does nothing)
        generate.set_model(self.model)

        # Fields passed explicitly below are excluded from the forwarded kwargs
        extra_params = self.dict(exclude={'prompt', 'image', 'mask'})
        results = generate.prompt2image(
            prompt=self.prompt,
            init_img=init_image,
            init_mask=init_mask,
            step_callback=on_step,
            **extra_params
        )

        result_image = results[0][0]

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        result_type = ImageType.RESULT
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, result_image)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
219
invokeai/app/invocations/image.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
import numpy
|
||||
from pydantic import Field, BaseModel
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class ImageField(BaseModel):
    """An image field used for passing image objects between invocations"""
    # Images are referenced by (image_type, image_name); invocations in this
    # file pass that pair to the image storage service to load the pixels.
    image_type: str = Field(default=ImageType.RESULT, description="The type of the image")
    image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""
    # Discriminator identifying this output model
    type: Literal['image'] = 'image'

    # Reference to the produced image (defaults to None when not supplied)
    image: ImageField = Field(default=None, description="The output image")
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
    """Base class for invocations that output a mask"""
    # Discriminator identifying this output model
    type: Literal['mask'] = 'mask'

    # Reference to the produced mask image (defaults to None when not supplied)
    mask: ImageField = Field(default=None, description="The output mask")
|
||||
|
||||
|
||||
# TODO: this isn't really necessary anymore
|
||||
class LoadImageInvocation(BaseInvocation):
    """Load an image from a filename and provide it as output."""
    type: Literal['load_image'] = 'load_image'

    # Inputs
    image_type: ImageType = Field(description="The type of the image")
    image_name: str = Field(description="The name of the image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Wrap the configured type/name pair in an ImageOutput.

        Nothing is read from storage here; ``context`` is unused.
        """
        reference = ImageField(image_type=self.image_type, image_name=self.image_name)
        return ImageOutput(image=reference)
|
||||
|
||||
|
||||
class ShowImageInvocation(BaseInvocation):
    """Displays a provided image, and passes it forward in the pipeline."""
    type: Literal['show_image'] = 'show_image'

    # Inputs
    image: ImageField = Field(default=None, description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Show the input image (if it loads) and return a reference to it."""
        source = self.image
        loaded = context.services.images.get(source.image_type, source.image_name)
        if loaded:
            loaded.show()

        # TODO: how to handle failure?

        # The output points at the same stored image as the input
        passthrough = ImageField(image_type=source.image_type, image_name=source.image_name)
        return ImageOutput(image=passthrough)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation):
    """Crops an image to a specified box. The box can be outside of the image."""
    type: Literal['crop'] = 'crop'

    # Inputs
    image: ImageField = Field(default=None, description="The image to crop")
    x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
    y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
    width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
    height: int = Field(default=512, gt=0, description="The height of the crop rectangle")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Crop the input image and save the result as an INTERMEDIATE image."""
        images = context.services.images
        source = images.get(self.image.image_type, self.image.image_name)

        # Start from a fully transparent canvas of the requested size, then
        # paste the source shifted so that (x, y) lands on the canvas origin
        canvas_size = (self.width, self.height)
        paste_offset = (-self.x, -self.y)
        image_crop = Image.new(mode='RGBA', size=canvas_size, color=(0, 0, 0, 0))
        image_crop.paste(source, paste_offset)

        result_type = ImageType.INTERMEDIATE
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, image_crop)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation):
    """Pastes an image into another image."""
    type: Literal['paste'] = 'paste'

    # Inputs
    base_image: ImageField = Field(default=None, description="The base image")
    image: ImageField = Field(default=None, description="The image to paste")
    mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
    x: int = Field(default=0, description="The left x coordinate at which to paste the image")
    y: int = Field(default=0, description="The top y coordinate at which to paste the image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Paste ``image`` onto ``base_image`` at (x, y) and save the result.

        The canvas is enlarged as needed so that both images fit. When a mask
        is given it is inverted before being used for the paste. The result
        is saved as a RESULT image.
        """
        base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name)
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        # The image service lives on the context; a bare `services` name is
        # not defined in this scope and raised NameError whenever a mask was set.
        mask = None if self.mask is None else ImageOps.invert(context.services.images.get(self.mask.image_type, self.mask.image_name))
        # TODO: probably shouldn't invert mask here... should user be required to do it?

        # Bounding box covering both the base image and the pasted image
        min_x = min(0, self.x)
        min_y = min(0, self.y)
        max_x = max(base_image.width, image.width + self.x)
        max_y = max(base_image.height, image.height + self.y)

        new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0))
        # A negative paste position shifts the base image right/down instead
        new_image.paste(base_image, (abs(min_x), abs(min_y)))
        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask)

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        context.services.images.save(image_type, image_name, new_image)
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
        )
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
    """Extracts the alpha channel of an image as a mask."""
    type: Literal['tomask'] = 'tomask'

    # Inputs
    image: ImageField = Field(default=None, description="The image to create the mask from")
    invert: bool = Field(default=False, description="Whether or not to invert the mask")

    def invoke(self, context: InvocationContext) -> MaskOutput:
        """Take the last band of the image as a mask and save it as INTERMEDIATE."""
        images = context.services.images
        source = images.get(self.image.image_type, self.image.image_name)

        # The last band returned by split() is used as the mask
        bands = source.split()
        image_mask = ImageOps.invert(bands[-1]) if self.invert else bands[-1]

        result_type = ImageType.INTERMEDIATE
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, image_mask)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return MaskOutput(mask=output_field)
|
||||
|
||||
|
||||
class BlurInvocation(BaseInvocation):
    """Blurs an image"""
    type: Literal['blur'] = 'blur'

    # Inputs
    image: ImageField = Field(default=None, description="The image to blur")
    radius: float = Field(default=8.0, ge=0, description="The blur radius")
    blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Apply a gaussian or box blur and save the result as INTERMEDIATE."""
        images = context.services.images
        source = images.get(self.image.image_type, self.image.image_name)

        # Anything other than 'gaussian' falls through to the box filter
        if self.blur_type == 'gaussian':
            blur_filter = ImageFilter.GaussianBlur(self.radius)
        else:
            blur_filter = ImageFilter.BoxBlur(self.radius)
        blur_image = source.filter(blur_filter)

        result_type = ImageType.INTERMEDIATE
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, blur_image)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
||||
|
||||
|
||||
class LerpInvocation(BaseInvocation):
    """Linear interpolation of all pixels of an image"""
    type: Literal['lerp'] = 'lerp'

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum output value")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Remap every pixel value from [0, 255] onto [min, max].

        Each value is scaled to a 0..1 fraction ``t`` and mapped to
        ``min + t * (max - min)``. The result is saved as an INTERMEDIATE image.
        """
        image = context.services.images.get(self.image.image_type, self.image.image_name)

        # Normalise to 0..1 so the pixel value acts as the interpolation factor
        image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
        # Lerp starts at `min`; offsetting by `max` pushed values past 255,
        # which then wrapped around in the uint8 conversion below
        image_arr = image_arr * (self.max - self.min) + self.min

        lerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        context.services.images.save(image_type, image_name, lerp_image)
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
        )
|
||||
|
||||
|
||||
class InverseLerpInvocation(BaseInvocation):
    """Inverse linear interpolation of all pixels of an image"""
    type: Literal['ilerp'] = 'ilerp'

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum input value")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Rescale pixel values relative to [min, max] and save as INTERMEDIATE.

        Values are shifted by ``min`` (floored at 0), divided by
        ``max - min``, capped at 1 and scaled back to the 0..255 range.
        """
        images = context.services.images
        source = images.get(self.image.image_type, self.image.image_name)

        pixels = numpy.asarray(source, dtype=numpy.float32)
        above_min = numpy.maximum(pixels - self.min, 0)
        fraction = numpy.minimum(above_min / float(self.max - self.min), 1)
        pixels = fraction * 255

        ilerp_image = Image.fromarray(numpy.uint8(pixels))

        result_type = ImageType.INTERMEDIATE
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, ilerp_image)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
9
invokeai/app/invocations/prompt.py
Normal file
@ -0,0 +1,9 @@
|
||||
from typing import Literal
|
||||
from pydantic.fields import Field
|
||||
from .baseinvocation import BaseInvocationOutput
|
||||
|
||||
class PromptOutput(BaseInvocationOutput):
    """Base class for invocations that output a prompt"""
    # Discriminator identifying this output model
    type: Literal['prompt'] = 'prompt'

    # The produced prompt text (defaults to None when not supplied)
    prompt: str = Field(default=None, description="The output prompt")
|
36
invokeai/app/invocations/reconstruct.py
Normal file
@ -0,0 +1,36 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
from pydantic import Field
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
    """Restores faces in an image."""
    type: Literal['restore_face'] = 'restore_face'

    # Inputs
    image: Union[ImageField,None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Run face restoration (no upscaling) and save the result as RESULT."""
        images = context.services.images
        source = images.get(self.image.image_type, self.image.image_name)

        # The generator takes [image, seed] pairs; the seed slot is filled with 0
        restore_request = dict(
            image_list=[[source, 0]],
            upscale=None,
            strength=self.strength,  # GFPGAN strength
            save_original=False,
            image_callback=None,
        )
        results = context.services.generate.upscale_and_reconstruct(**restore_request)

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        restored = results[0][0]
        result_type = ImageType.RESULT
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, restored)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
38
invokeai/app/invocations/upscale.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
from pydantic import Field
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""
    type: Literal['upscale'] = 'upscale'

    # Inputs
    image: Union[ImageField,None] = Field(description="The input image", default=None)
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
    level: Literal[2,4] = Field(default=2, description = "The upscale level")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Upscale the input image (face restoration strength 0) and save as RESULT."""
        images = context.services.images
        source = images.get(self.image.image_type, self.image.image_name)

        # The generator takes [image, seed] pairs; the seed slot is filled with 0
        upscale_request = dict(
            image_list=[[source, 0]],
            upscale=(self.level, self.strength),
            strength=0.0,  # GFPGAN strength
            save_original=False,
            image_callback=None,
        )
        results = context.services.generate.upscale_and_reconstruct(**upscale_request)

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        upscaled = results[0][0]
        result_type = ImageType.RESULT
        result_name = images.create_name(context.graph_execution_state_id, self.id)
        images.save(result_type, result_name, upscaled)

        output_field = ImageField(image_type=result_type, image_name=result_name)
        return ImageOutput(image=output_field)
|
93
invokeai/app/services/events.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class EventServiceBase:
    """Basic event bus, to have an empty stand-in when not needed"""

    # Name under which every session-scoped event is dispatched.
    session_event: str = 'session_event'

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Deliver an event. The base implementation does nothing."""
        pass

    def __emit_session_event(self,
        event_name: str,
        payload: Dict) -> None:
        """Wrap an event as ``{event, data}`` and dispatch it as a session event."""
        self.dispatch(
            event_name = EventServiceBase.session_event,
            payload = dict(
                event = event_name,
                data = payload
            )
        )

    # Define events here for every event in the system.
    # This will make them easier to integrate until we find a schema generator.
    def emit_generator_progress(self,
        graph_execution_state_id: str,
        invocation_id: str,
        step: int,
        percent: float
    ) -> None:
        """Emitted when there is generation progress"""
        self.__emit_session_event(
            event_name = 'generator_progress',
            payload = dict(
                graph_execution_state_id = graph_execution_state_id,
                invocation_id = invocation_id,
                step = step,
                percent = percent
            )
        )

    def emit_invocation_complete(self,
        graph_execution_state_id: str,
        invocation_id: str,
        result: Dict
    ) -> None:
        """Emitted when an invocation has completed"""
        self.__emit_session_event(
            event_name = 'invocation_complete',
            payload = dict(
                graph_execution_state_id = graph_execution_state_id,
                invocation_id = invocation_id,
                result = result
            )
        )

    def emit_invocation_error(self,
        graph_execution_state_id: str,
        invocation_id: str,
        error: str
    ) -> None:
        """Emitted when an invocation has failed with an error"""
        self.__emit_session_event(
            event_name = 'invocation_error',
            payload = dict(
                graph_execution_state_id = graph_execution_state_id,
                invocation_id = invocation_id,
                error = error
            )
        )

    def emit_invocation_started(self,
        graph_execution_state_id: str,
        invocation_id: str
    ) -> None:
        """Emitted when an invocation has started"""
        self.__emit_session_event(
            event_name = 'invocation_started',
            payload = dict(
                graph_execution_state_id = graph_execution_state_id,
                invocation_id = invocation_id
            )
        )

    def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
        """Emitted when a session has completed all invocations"""
        self.__emit_session_event(
            event_name = 'graph_execution_state_complete',
            payload = dict(
                graph_execution_state_id = graph_execution_state_id
            )
        )
|
231
invokeai/app/services/generate_initializer.py
Normal file
@ -0,0 +1,231 @@
|
||||
from argparse import Namespace
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from invokeai.backend import ModelManager, Generate
|
||||
|
||||
from ...globals import Globals
|
||||
import invokeai.version
|
||||
|
||||
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
||||
def get_generate(args, config) -> Generate:
    """Build the legacy ``Generate`` object from parsed options and preload its model.

    :param args: Parsed options. Read here: ``conf``, ``embeddings``, ``embedding_path``,
        ``infile``, ``model``, ``sampler_name``, ``full_precision``, ``precision``,
        ``free_gpu_mem``, ``safety_checker``, ``max_loaded_models``, ``seamless`` and
        ``autoconvert``. ``args.conf`` is rewritten in place to an absolute path.
    :param config: Not used by this function; kept so the call signature is unchanged.
    :returns: The constructed ``Generate`` instance.
    """
    if not args.conf:
        config_file = os.path.join(Globals.root, 'configs', 'models.yaml')
        if not os.path.exists(config_file):
            report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found."))
        # NOTE(review): config_file is not written back to args.conf, so an empty
        # args.conf is still used as-is below - confirm whether it should default here.

    print(f'>> {invokeai.version.__app_name__}, version {invokeai.version.__version__}')
    print(f'>> InvokeAI runtime directory is "{Globals.root}"')

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers  # type: ignore
    transformers.logging.set_verbosity_error()
    import diffusers
    diffusers.logging.set_verbosity_error()

    # Loading Face Restoration and ESRGAN Modules
    gfpgan, codeformer, esrgan = load_face_restoration(args)

    # normalize the config directory relative to root
    if not os.path.isabs(args.conf):
        args.conf = os.path.normpath(os.path.join(Globals.root, args.conf))

    if args.embeddings:
        if not os.path.isabs(args.embedding_path):
            embedding_path = os.path.normpath(os.path.join(Globals.root, args.embedding_path))
        else:
            embedding_path = args.embedding_path
    else:
        embedding_path = None

    # migrate legacy models
    ModelManager.migrate_models()

    # Validate the infile. Its contents are never consumed in this function, so the
    # file is only opened to prove it is readable and is closed again immediately
    # (the previous code left the handle open).
    if args.infile:
        try:
            if os.path.isfile(args.infile):
                with open(args.infile, 'r', encoding='utf-8'):
                    pass
            elif args.infile != '-':  # '-' means stdin
                raise FileNotFoundError(f'{args.infile} not found.')
        except (FileNotFoundError, IOError) as e:
            print(f'{e}. Aborting.')
            sys.exit(-1)

    # creating a Generate object:
    try:
        gen = Generate(
            conf = args.conf,
            model = args.model,
            sampler_name = args.sampler_name,
            embedding_path = embedding_path,
            full_precision = args.full_precision,
            precision = args.precision,
            gfpgan = gfpgan,
            codeformer = codeformer,
            esrgan = esrgan,
            free_gpu_mem = args.free_gpu_mem,
            safety_checker = args.safety_checker,
            max_loaded_models = args.max_loaded_models,
        )
    except (FileNotFoundError, TypeError, AssertionError) as e:
        # Previously passed the undefined name ``opt`` here, which raised NameError.
        report_model_error(args, e)
        # No Generate object exists at this point and restarting is not implemented,
        # so exit explicitly instead of failing later on the unbound ``gen``.
        sys.exit(-1)
    except (IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        sys.exit(-1)

    if args.seamless:
        print(">> changed to seamless tiling mode")

    # preload the model
    try:
        gen.load_model()
    except KeyError:
        pass
    except Exception as e:
        report_model_error(args, e)

    # try to autoconvert new models
    # autoimport new .ckpt files
    if path := args.autoconvert:
        gen.model_manager.autoconvert_weights(
            conf_path=args.conf,
            weights_directory=path,
        )

    return gen
|
||||
|
||||
|
||||
def load_face_restoration(opt):
    """Load the optional face-restoration and upscaling models.

    :param opt: Options object; ``restore`` and ``esrgan`` select what to load, and
        ``gfpgan_model_path`` / ``esrgan_bg_tile`` are passed to the loaders.
    :returns: ``(gfpgan, codeformer, esrgan)``; each entry is None when not loaded.
    """
    gfpgan = codeformer = esrgan = None
    try:
        if not (opt.restore or opt.esrgan):
            print('>> Face restoration and upscaling disabled')
        else:
            from ldm.invoke.restoration import Restoration
            restoration = Restoration()

            if opt.restore:
                gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)
            else:
                print('>> Face restoration disabled')

            if opt.esrgan:
                esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
            else:
                print('>> Upscaling disabled')
    except (ModuleNotFoundError, ImportError):
        # Missing optional modules are reported but do not abort startup.
        print(traceback.format_exc(), file=sys.stderr)
        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
    return gfpgan, codeformer, esrgan
|
||||
|
||||
|
||||
def report_model_error(opt:Namespace, e:Exception):
    """Report a model initialization failure and offer to run ``invokeai-configure``.

    The configure script runs without prompting when the environment variable
    INVOKE_MODEL_RECONFIGURE is set to a non-empty value; its whitespace-separated
    contents are appended to the script's arguments. Otherwise the user is asked,
    and an answer starting with 'n' or 'N' returns without doing anything.

    :param opt: Parsed options; ``root_dir`` and ``conf`` are forwarded when not None.
    :param e: The exception that triggered the report.
    """
    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
    print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')

    yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
    if not yes_to_all:
        response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
        if response.startswith(('n', 'N')):
            return
    else:
        print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')

    print('invokeai-configure is launching....\n')

    # Match arguments that were set on the CLI
    # only the arguments accepted by the configuration script are parsed
    configure_argv = ['invokeai-configure']
    if opt.root_dir is not None:
        configure_argv += ["--root", opt.root_dir]
    if opt.conf is not None:
        configure_argv += ["--config", opt.conf]
    if yes_to_all is not None:
        configure_argv.extend(yes_to_all.split())

    # The configure script reads its options from sys.argv, which is replaced
    # here and not restored (see the TODO below).
    sys.argv = configure_argv

    from ldm.invoke.config import invokeai_configure
    invokeai_configure.main()
    # TODO: Figure out how to restart
    # print('** InvokeAI will now restart')
    # sys.argv = previous_args
    # main() # would rather do a os.exec(), but doesn't exist?
    # sys.exit(0)
|
||||
|
||||
|
||||
# Temporary initializer for Generate until we migrate off of it
|
||||
def old_get_generate(args, config) -> Generate:
    """Temporary initializer that builds ``Generate`` from ``config`` and sets globals.

    :param args: Parsed CLI options; ``root_dir`` and ``patchmatch`` are read here.
    :param config: Options object providing the model, restoration and embedding settings.
        ``config.conf`` is rewritten in place to an absolute path.
    :returns: The constructed ``Generate`` instance.
    """
    # TODO: Remove the need for globals
    from invokeai.backend.globals import Globals

    # alert - setting globals here
    Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
    Globals.try_patchmatch = args.patchmatch

    print(f'>> InvokeAI runtime directory is "{Globals.root}"')

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers
    transformers.logging.set_verbosity_error()

    # Loading Face Restoration and ESRGAN Modules (shared with get_generate)
    gfpgan, codeformer, esrgan = load_face_restoration(config)

    # normalize the config directory relative to root
    if not os.path.isabs(config.conf):
        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))

    # Resolve the embedding path the same way get_generate does. The previous code
    # never assigned an absolute embedding path, leaving embedding_path unset/wrong.
    if config.embeddings:
        if not os.path.isabs(config.embedding_path):
            embedding_path = os.path.normpath(os.path.join(Globals.root, config.embedding_path))
        else:
            embedding_path = config.embedding_path
    else:
        embedding_path = None

    # TODO: lazy-initialize this by wrapping it
    try:
        generate = Generate(
            conf = config.conf,
            model = config.model,
            sampler_name = config.sampler_name,
            embedding_path = embedding_path,
            full_precision = config.full_precision,
            precision = config.precision,
            gfpgan = gfpgan,
            codeformer = codeformer,
            esrgan = esrgan,
            free_gpu_mem = config.free_gpu_mem,
            safety_checker = config.safety_checker,
            max_loaded_models = config.max_loaded_models,
        )
    except (FileNotFoundError, TypeError, AssertionError) as e:
        #emergency_model_reconfigure() # TODO?
        # Say why we are exiting instead of terminating silently.
        print(f'{e}. Aborting.')
        sys.exit(-1)
    except (IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        sys.exit(-1)

    generate.free_gpu_mem = config.free_gpu_mem

    return generate
|
809
invokeai/app/services/graph.py
Normal file
@ -0,0 +1,809 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
from types import NoneType
|
||||
import uuid
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic.fields import Field
|
||||
from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints, Annotated
|
||||
|
||||
from .invocation_services import InvocationServices
|
||||
from ..invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ..invocations import *
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
    # One endpoint of a graph edge: a node id plus the name of a field on that node.
    node_id: str = Field(description="The id of the node for this edge connection")
    field: str = Field(description="The field for this connection")

    def __eq__(self, other):
        # Equal only to instances of this class with the same node id and field.
        if not isinstance(other, self.__class__):
            return False
        same_node = getattr(other, 'node_id', None) == self.node_id
        return same_node and getattr(other, 'field', None) == self.field

    def __hash__(self):
        # Hash on the "node_id.field" string, matching the fields compared in __eq__.
        key = f'{self.node_id}.{self.field}'
        return hash(key)
|
||||
|
||||
|
||||
def get_output_field(node: BaseInvocation, field: str) -> Any:
    """Return the type hint of ``field`` on the node's output type, or None.

    The output type is obtained from ``type(node).get_output_type()``. A missing
    field (or a falsy hint) yields None.
    """
    output_hints = get_type_hints(type(node).get_output_type())
    hint = output_hints.get(field)
    return hint if hint else None
|
||||
|
||||
|
||||
def get_input_field(node: BaseInvocation, field: str) -> Any:
    """Return the type hint of ``field`` declared on the node's own class, or None.

    A missing field (or a falsy hint) yields None.
    """
    input_hints = get_type_hints(type(node))
    hint = input_hints.get(field)
    return hint if hint else None
|
||||
|
||||
|
||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    """Decide whether a value of ``from_type`` may be connected to an input of ``to_type``.

    The types are compatible when they are equal, when either is ``Any`` (directly
    or as a generic argument), when one appears among the generic arguments of the
    other (e.g. ``int`` vs ``Optional[int]``), or when ``from_type`` is a subclass
    of ``to_type``.

    :param from_type: Type hint of the source field; a falsy value means "unknown".
    :param to_type: Type hint of the destination field; a falsy value means "unknown".
    :returns: True if compatible, False otherwise (including unknown types).
    """
    if not from_type or not to_type:
        return False

    # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
    if (from_type == to_type or
        from_type == Any or
        to_type == Any or
        Any in get_args(from_type) or
        Any in get_args(to_type)):
        return True

    if from_type in get_args(to_type):
        return True

    if to_type in get_args(from_type):
        return True

    # issubclass() raises TypeError when an argument is not a plain class
    # (e.g. list[int]); such a pair is treated as incompatible instead of crashing.
    try:
        return issubclass(from_type, to_type)
    except TypeError:
        return False
|
||||
|
||||
|
||||
def are_connections_compatible(
    from_node: BaseInvocation,
    from_field: str,
    to_node: BaseInvocation,
    to_field: str) -> bool:
    """Check whether an output field of one node may feed an input field of another.

    Looks up the output hint of ``from_field`` on ``from_node`` and the input hint
    of ``to_field`` on ``to_node`` and compares them for type compatibility.
    """
    # TODO: handle iterators and collectors
    source_type = get_output_field(from_node, from_field)
    destination_type = get_input_field(to_node, to_field)
    return are_connection_types_compatible(source_type, destination_type)
|
||||
|
||||
|
||||
class NodeAlreadyInGraphError(Exception):
    """Raised when a node id is already present in the graph."""
|
||||
|
||||
|
||||
class InvalidEdgeError(Exception):
    """Raised when an edge cannot be added to the graph."""
|
||||
|
||||
class NodeNotFoundError(Exception):
    """Raised when a node path does not resolve to a node in the graph."""
|
||||
|
||||
class NodeAlreadyExecutedError(Exception):
    """Error type for operations on a node that has already been executed."""
|
||||
|
||||
|
||||
# TODO: Create and use an Empty output?
|
||||
class GraphInvocationOutput(BaseInvocationOutput):
    # Output of GraphInvocation; declares no fields beyond the type discriminator.
    type: Literal["graph_output"] = "graph_output"
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
    # A node that wraps a nested graph.
    type: Literal["graph"] = "graph"

    # TODO: figure out how to create a default here
    graph: "Graph" = Field(default=None, description="The graph to run")

    def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
        """Return an empty graph output; no work is performed in this method."""
        output = GraphInvocationOutput()
        return output
|
||||
|
||||
|
||||
class IterateInvocationOutput(BaseInvocationOutput):
    """Used to connect iteration outputs. Will be expanded to a specific output."""
    type: Literal["iterate_output"] = "iterate_output"

    # Typed as Any: the concrete item type depends on the collection being iterated.
    item: Any = Field(description="The item being iterated over")
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
    # Yields a single element of ``collection``, selected by ``index``.
    type: Literal["iterate"] = "iterate"

    collection: list[Any] = Field(default_factory=list, description="The list of items to iterate over")
    index: int = Field(default=0, description="The index, will be provided on executed iterators")

    def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
        """Return the element of the collection at the configured index."""
        current_item = self.collection[self.index]
        return IterateInvocationOutput(item=current_item)
|
||||
|
||||
|
||||
class CollectInvocationOutput(BaseInvocationOutput):
    # Output of CollectInvocation: the gathered items as one list.
    type: Literal["collect_output"] = "collect_output"

    collection: list[Any] = Field(description="The collection of input items")
|
||||
|
||||
|
||||
class CollectInvocation(BaseInvocation):
    """Collects values into a collection"""
    type: Literal["collect"] = "collect"

    item: Any = Field(default=None, description="The item to collect (all inputs must be of the same type)")
    collection: list[Any] = Field(default_factory=list, description="The collection, will be provided on execution")

    def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
        """Return a shallow copy of the collected items as the output collection."""
        collected = copy.copy(self.collection)
        return CollectInvocationOutput(collection=collected)
|
||||
|
||||
|
||||
# Union over whatever BaseInvocation.get_invocations() returns; used below as the
# discriminated node type of Graph.nodes.
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
# Union over whatever BaseInvocationOutput.get_all_subclasses_tuple() returns.
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
edges: list[tuple[EdgeConnection,EdgeConnection]] = Field(description="The connections between nodes and their fields in this graph", default_factory=list)
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
    """Register ``node`` in this graph under its id.

    :raises NodeAlreadyInGraphError: the node is already present in the graph.
    """
    node_key = node.id
    if node_key in self.nodes:
        raise NodeAlreadyInGraphError()

    self.nodes[node_key] = node
|
||||
|
||||
|
||||
def _get_graph_and_node(self, node_path: str) -> tuple['Graph', str]:
|
||||
"""Returns the graph and node id for a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
if node_path in self.nodes:
|
||||
return (self, node_path)
|
||||
|
||||
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
|
||||
if node_id not in self.nodes:
|
||||
raise NodeNotFoundError(f'Node {node_path} not found in graph')
|
||||
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if not isinstance(node, GraphInvocation):
|
||||
# There's more node path left but this isn't a graph - failure
|
||||
raise NodeNotFoundError('Node path terminated early at a non-graph node')
|
||||
|
||||
return node.graph._get_graph_and_node(node_path[node_path.index('.')+1:])
|
||||
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
    """Remove a node and every edge attached to it; unknown paths are ignored."""
    try:
        owner, local_id = self._get_graph_and_node(node_path)

        # Gather the attached edges first, then delete each from the graph that holds it.
        incoming = self._get_input_edges_and_graphs(node_path)
        outgoing = self._get_output_edges_and_graphs(node_path)

        for holder, _, attached in incoming:
            holder.delete_edge(attached)

        for holder, _, attached in outgoing:
            holder.delete_edge(attached)

        del owner.nodes[local_id]
    except NodeNotFoundError:
        pass # Ignore, node doesn't exist (should this throw?)
|
||||
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
    """Append ``edge`` to the graph after validating it.

    :raises InvalidEdgeError: the edge fails validation or is already in the graph.
    """
    if not self._is_edge_valid(edge) or edge in self.edges:
        raise InvalidEdgeError()

    self.edges.append(edge)
|
||||
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
    """Remove ``edge`` from this graph's edge list; a missing edge is ignored."""
    try:
        self.edges.remove(edge)
    except ValueError:
        # list.remove() raises ValueError (not KeyError) when the item is absent.
        pass
|
||||
|
||||
|
||||
def is_valid(self) -> bool:
    """Validate the graph: sub-graphs, edge endpoints, acyclicity, field types,
    iterators and collectors. Returns False at the first failed check."""
    all_nodes = list(self.nodes.values())

    # Validate all subgraphs
    for subgraph_node in all_nodes:
        if isinstance(subgraph_node, GraphInvocation) and not subgraph_node.graph.is_valid():
            return False

    # Validate all edges reference nodes in the graph
    referenced_ids = {endpoint.node_id for edge in self.edges for endpoint in (edge[0], edge[1])}
    if any(not self.has_node(node_id) for node_id in referenced_ids):
        return False

    # Validate there are no cycles
    flat_graph = self.nx_graph_flat()
    if not nx.is_directed_acyclic_graph(flat_graph):
        return False

    # Validate all edge connections are valid
    for edge in self.edges:
        source, destination = edge[0], edge[1]
        if not are_connections_compatible(
            self.get_node(source.node_id), source.field,
            self.get_node(destination.node_id), destination.field
        ):
            return False

    # Validate all iterators
    # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
    for node in self.nodes.values():
        if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
            return False

    # Validate all collectors
    # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
    for node in self.nodes.values():
        if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
            return False

    return True
|
||||
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
    """Check whether ``edge`` may be added to this graph.

    The edge is rejected when either endpoint node is missing, the destination
    field already has an input (unless the destination is a collector), the edge
    would create a cycle, the field types are incompatible, or an attached
    iterator/collector would become inconsistent.
    """
    source, destination = edge[0], edge[1]

    # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
    try:
        from_node = self.get_node(source.node_id)
        to_node = self.get_node(destination.node_id)
    except NodeNotFoundError:
        return False

    # Validate that an edge to this node+field doesn't already exist
    existing_inputs = self._get_input_edges(destination.node_id, destination.field)
    if existing_inputs and not isinstance(to_node, CollectInvocation):
        return False

    # Validate that no cycles would be created
    candidate = self.nx_graph_flat()
    candidate.add_edge(source.node_id, destination.node_id)
    if not nx.is_directed_acyclic_graph(candidate):
        return False

    # Validate that the field types are compatible
    if not are_connections_compatible(from_node, source.field, to_node, destination.field):
        return False

    # Validate if iterator output type matches iterator input type (if this edge results in both being set)
    if (isinstance(to_node, IterateInvocation) and destination.field == 'collection'
            and not self._is_iterator_connection_valid(destination.node_id, new_input=source)):
        return False

    # Validate if iterator input type matches output type (if this edge results in both being set)
    if (isinstance(from_node, IterateInvocation) and source.field == 'item'
            and not self._is_iterator_connection_valid(source.node_id, new_output=destination)):
        return False

    # Validate if collector input type matches output type (if this edge results in both being set)
    if (isinstance(to_node, CollectInvocation) and destination.field == 'item'
            and not self._is_collector_connection_valid(destination.node_id, new_input=source)):
        return False

    # Validate if collector output type matches input type (if this edge results in both being set)
    if (isinstance(from_node, CollectInvocation) and source.field == 'collection'
            and not self._is_collector_connection_valid(source.node_id, new_output=destination)):
        return False

    return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
    """Return True when ``node_path`` resolves to a non-None node, else False."""
    try:
        return self.get_node(node_path) is not None
    except NodeNotFoundError:
        return False
|
||||
|
||||
def get_node(self, node_path: str) -> InvocationsUnion:
    """Look up the node addressed by ``node_path`` (possibly inside nested graphs)."""
    # Materialized graphs may have nodes at the top level
    owner, local_id = self._get_graph_and_node(node_path)
    return owner.nodes[local_id]
|
||||
|
||||
|
||||
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
|
||||
return node_id if prefix is None or prefix == '' else f'{prefix}.{node_id}'
|
||||
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
    """Replace the node at ``node_path`` with ``new_node``.

    If the id changes, the old node and its edges are deleted and the edges are
    re-created against the new id.

    :param node_path: Path of the node to replace, relative to this graph.
    :param new_node: Replacement node; must be exactly the same type as the old one.
    :raises TypeError: the new node's type differs from the existing node's type.
    :raises NodeAlreadyInGraphError: the new id is already used by another node.
    """
    graph, node_id = self._get_graph_and_node(node_path)
    node = graph.nodes[node_id]

    # Ensure the node type matches the new node
    if type(node) != type(new_node):
        raise TypeError(f'Node {node_path} is type {type(node)} but new node is type {type(new_node)}')

    # Ensure the new id is either the same or is not in the graph
    prefix = None if '.' not in node_path else node_path[:node_path.rindex('.')]
    new_path = self._get_node_path(new_node.id, prefix = prefix)
    if new_node.id != node.id and self.has_node(new_path):
        # The message was previously missing its f-prefix and printed the placeholder verbatim.
        raise NodeAlreadyInGraphError(f'Node with id {new_node.id} already exists in graph')

    # Set the new node in the graph
    graph.nodes[new_node.id] = new_node
    if new_node.id == node.id:
        return

    def renamed(path: str) -> str:
        # Keep the path's graph prefix (everything before the last '.') and swap
        # only the final segment for the new id.
        return new_node.id if '.' not in path else f'{path[:path.rindex(".")]}.{new_node.id}'

    input_edges = self._get_input_edges_and_graphs(node_path)
    output_edges = self._get_output_edges_and_graphs(node_path)

    # Delete node and all edges. node_path is relative to this graph, so the
    # deletion has to start here rather than on the sub-graph that owns the node.
    self.delete_node(node_path)

    # Create new edges for each input and output, in the graph that held the old edge
    for edge_graph, _, edge in input_edges:
        edge_graph.add_edge((edge[0], EdgeConnection(node_id = renamed(edge[1].node_id), field = edge[1].field)))

    for edge_graph, _, edge in output_edges:
        edge_graph.add_edge((EdgeConnection(node_id = renamed(edge[0].node_id), field = edge[0].field), edge[1]))
|
||||
|
||||
|
||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[tuple[EdgeConnection,EdgeConnection]]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
|
||||
|
||||
|
||||
def _get_input_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
    """Gets all input edges for a node along with the graph they are in and the graph's path"""
    # Input edges stored in this graph whose destination is exactly node_path
    found = [(self, prefix, edge) for edge in self.edges if edge[1].node_id == node_path]

    # The first path segment names a node of this graph; the rest (possibly empty)
    # is the path inside that node's sub-graph.
    head, _, remainder = node_path.partition('.')
    node = self.nodes[head]

    if isinstance(node, GraphInvocation):
        if prefix is None or prefix == '':
            graph_path = node.id
        else:
            graph_path = self._get_node_path(node.id, prefix = prefix)
        found.extend(node.graph._get_input_edges_and_graphs(remainder, prefix=graph_path))

    return found
|
||||
|
||||
|
||||
def _get_output_edges(self, node_path: str, field: str) -> list[tuple[EdgeConnection,EdgeConnection]]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2][0].field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
|
||||
|
||||
|
||||
def _get_output_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
    """Gets all output edges for a node along with the graph they are in and the graph's path"""
    # Output edges stored in this graph whose source is exactly node_path
    found = [(self, prefix, edge) for edge in self.edges if edge[0].node_id == node_path]

    # The first path segment names a node of this graph; the rest (possibly empty)
    # is the path inside that node's sub-graph.
    head, _, remainder = node_path.partition('.')
    node = self.nodes[head]

    if isinstance(node, GraphInvocation):
        if prefix is None or prefix == '':
            graph_path = node.id
        else:
            graph_path = self._get_node_path(node.id, prefix = prefix)
        found.extend(node.graph._get_output_edges_and_graphs(remainder, prefix=graph_path))

    return found
|
||||
|
||||
|
||||
def _is_iterator_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, 'collection')])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, 'item')])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
if new_output is not None:
|
||||
outputs.append(new_output)
|
||||
|
||||
# Only one input is allowed for iterators
|
||||
if len(inputs) > 1:
|
||||
return False
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
return False
|
||||
|
||||
# Validate that all outputs match the input type
|
||||
input_field_item_type = get_args(input_field)[0]
|
||||
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_collector_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
    """Checks whether the connections of the collector node at node_path are valid.

    The node's existing 'item' input edges and 'collection' output edges are checked
    together with the optional new_input / new_output connections.

    Returns True only when all input types share exactly one root type, every
    output field is a list, and the root type is a subclass of (or the same as)
    each output list's item type.
    """
    sources = [edge[0] for edge in self._get_input_edges(node_path, 'item')]
    destinations = [edge[1] for edge in self._get_output_edges(node_path, 'collection')]

    if new_input is not None:
        sources.append(new_input)
    if new_output is not None:
        destinations.append(new_output)

    # Resolve the field types linked to the collector's inputs and outputs
    input_fields = [get_output_field(self.get_node(s.node_id), s.field) for s in sources]
    output_fields = [get_input_field(self.get_node(d.node_id), d.field) for d in destinations]

    # Unique input types: plain types are taken as-is, parameterized types contribute
    # their type arguments, and NoneType is ignored
    input_field_types = {
        member
        for input_field in input_fields
        for member in ([input_field] if get_origin(input_field) == None else get_args(input_field))
        if member != NoneType
    }

    # Build a base-class -> subclass graph over the input types; a type that no other
    # type is a base of has in-degree 0 and is a root
    type_tree = nx.DiGraph()
    type_tree.add_nodes_from(input_field_types)
    type_tree.add_edges_from([pair for pair in itertools.permutations(input_field_types, 2) if issubclass(pair[1], pair[0])])
    root_types = [node_type for node_type, in_degree in type_tree.in_degree(type_tree.nodes) if in_degree == 0]

    # All inputs must derive from or match a single type
    if len(root_types) != 1:
        return False
    input_root_type = root_types[0]

    # Verify that all outputs are lists
    for output_field in output_fields:
        if get_origin(output_field) != list:
            return False

    # Verify that all outputs match the input type (are a base class or the same class)
    for output_field in output_fields:
        if not issubclass(input_root_type, get_args(output_field)[0]):
            return False

    return True
|
||||
|
||||
def nx_graph(self) -> nx.DiGraph:
    """Returns a NetworkX DiGraph representing the layout of this graph.

    Nodes are this graph's node ids; there is one edge per distinct
    (source node id, destination node id) pair, whatever fields are connected.
    """
    # TODO: Cache this?
    layout = nx.DiGraph()
    layout.add_nodes_from(list(self.nodes.keys()))
    node_pairs = {(edge[0].node_id, edge[1].node_id) for edge in self.edges}
    layout.add_edges_from(node_pairs)
    return layout
|
||||
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
    """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)

    nx_graph: an existing DiGraph to add to (passed when recursing into subgraphs); a new one is created when None.
    prefix: passed to _get_node_path to build the flattened id of each node and edge endpoint.
    """
    # Compare against None explicitly: an empty DiGraph is falsy, so `nx_graph or nx.DiGraph()`
    # would replace a passed-in (still empty) graph with a new one and lose the subgraph's nodes
    g = nx.DiGraph() if nx_graph is None else nx_graph

    # Add all nodes from this graph except graph/iteration nodes
    g.add_nodes_from([self._get_node_path(n.id, prefix) for n in self.nodes.values() if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)])

    # Expand graph nodes
    for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
        sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))

    # TODO: figure out if iteration nodes need to be expanded

    # One edge per distinct (source node, destination node) pair
    unique_edges = {(e[0].node_id, e[1].node_id) for e in self.edges}
    g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
    return g
|
||||
|
||||
|
||||
class GraphExecutionState(BaseModel):
    """Tracks the state of a graph execution"""

    # uuid.uuid4() returns a uuid.UUID object; convert it so the default matches the declared str type
    id: str = Field(description="The id of the execution state", default_factory=lambda: str(uuid.uuid4()))

    # TODO: Store a reference to the graph instead of the actual graph?
    graph: Graph = Field(description="The graph being executed")

    # The graph of materialized nodes
    execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes", default_factory=Graph)

    # Nodes that have been executed
    executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
    executed_history: list[str] = Field(description="The list of node ids that have been executed, in order of execution", default_factory=list)

    # The results of executed nodes
    results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)

    # Errors raised when executing nodes
    errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)

    # Map of prepared/executed nodes to their original nodes
    prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)

    # Map of original nodes to prepared nodes
    source_prepared_mapping: dict[str, set[str]] = Field(description="The map of original graph nodes to prepared nodes", default_factory=dict)
|
||||
|
||||
def next(self) -> BaseInvocation | None:
    """Gets the next node ready to execute.

    When no prepared node is waiting, more nodes are prepared via _prepare(). The
    returned node has its inputs populated via _prepare_inputs(). Returns None when
    there is no next node.
    """
    # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
    #       possibly with a timeout?

    candidate = self._get_next_node()

    # If there are no prepared nodes, prepare some nodes
    # TODO: prepare multiple nodes at once?
    if candidate is None and self._prepare() is not None:
        candidate = self._get_next_node()

    # Still nothing: there is no next node
    if candidate is None:
        return None

    # Get values from edges
    self._prepare_inputs(candidate)
    return candidate
|
||||
|
||||
def complete(self, node_id: str, output: InvocationOutputsUnion):
    """Marks a node as complete.

    Records the output of the prepared (execution graph) node. Once every node
    prepared from the same source node has executed, the source node is also marked
    as executed and appended to the execution history. Unknown node ids are ignored.
    """
    if node_id not in self.execution_graph.nodes:
        return  # TODO: log error?

    # Mark node as executed
    self.executed.add(node_id)
    self.results[node_id] = output

    # The source node is complete only when all of its prepared nodes are complete
    source_id = self.prepared_source_mapping[node_id]
    siblings = self.source_prepared_mapping[source_id]
    if any(sibling not in self.executed for sibling in siblings):
        return

    self.executed.add(source_id)
    self.executed_history.append(source_id)
|
||||
|
||||
def set_node_error(self, node_id: str, error: str):
    """Marks a node as errored by storing the error text under the node's id."""
    self.errors.update({node_id: error})
|
||||
|
||||
def is_complete(self) -> bool:
    """Returns true if the graph is complete.

    The graph counts as complete when any error has been recorded, or when every
    node id of the source graph is in the executed set.
    """
    if self.has_error():
        return True
    return all(node_id in self.executed for node_id in self.graph.nodes)
|
||||
|
||||
def has_error(self) -> bool:
    """Returns true if at least one node error has been recorded."""
    return bool(self.errors)
|
||||
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
    """Prepares execution node(s) for a source node and connects all of their input edges.

    node_path: path of the source node in the source graph.
    iteration_node_map: (source node id, prepared node id) pairs choosing which
        prepared parent node feeds each input edge.

    Returns the ids of the created execution nodes. An iterator node gets one copy
    per item of its input collection (none for an empty collection); any other node
    gets exactly one copy.
    """
    source_node = self.graph.get_node(node_path)

    # -1 means "not an iterator"
    iteration_count = -1

    # If this is an iterator node, we must create a copy for each iteration
    if isinstance(source_node, IterateInvocation):
        # Get input collection edge (should error if there are no inputs)
        collection_edge = next(iter(self.graph._get_input_edges(node_path, 'collection')))
        collection_source = collection_edge[0]
        prepared_parent_id = next(prepared for source, prepared in iteration_node_map if source == collection_source.node_id)
        parent_output = self.results[prepared_parent_id]
        iteration_count = len(getattr(parent_output, collection_source.field))

    created_ids = []
    if iteration_count == 0:
        # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
        return created_ids

    # Build edge templates from the prepared parents; the destination node id is filled in per copy.
    # For collect nodes, this may contain multiple inputs to the same field
    edge_templates = [
        (EdgeConnection(node_id = prepared, field = edge[0].field), EdgeConnection(node_id = '', field = edge[1].field))
        for edge in self.graph._get_input_edges(node_path)
        for source, prepared in iteration_node_map
        if source == edge[0].node_id
    ]

    # Create a new node (or one for each iteration of this iterator)
    indices = range(iteration_count) if iteration_count > 0 else [-1]
    for index in indices:
        execution_node = copy.deepcopy(source_node)

        # Create the node id (use a random uuid)
        execution_node.id = str(uuid.uuid4())

        # Set the iteration index for iteration invocations
        if isinstance(execution_node, IterateInvocation):
            execution_node.index = index

        # Add to execution graph and record the mapping in both directions
        self.execution_graph.add_node(execution_node)
        self.prepared_source_mapping[execution_node.id] = node_path
        self.source_prepared_mapping.setdefault(node_path, set()).add(execution_node.id)

        # Add new edges to execution graph
        for template_source, template_destination in edge_templates:
            destination = EdgeConnection(node_id = execution_node.id, field = template_destination.field)
            self.execution_graph.add_edge((template_source, destination))

        created_ids.append(execution_node.id)

    return created_ids
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
    """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
    pruned = self.graph.nx_graph()
    for node_id, node in self.graph.nodes.items():
        if isinstance(node, CollectInvocation):
            # Drop every edge that points into the collector
            pruned.remove_edges_from(list(pruned.in_edges(node_id)))
    return pruned
|
||||
|
||||
|
||||
def _get_node_iterators(self, node_id: str) -> list[str]:
    """Gets the ids of the iterator nodes that are ancestors of the given node in the iterator graph"""
    upstream = nx.ancestors(self._iterator_graph(), node_id)
    return [ancestor for ancestor in upstream if isinstance(self.graph.nodes[ancestor], IterateInvocation)]
|
||||
|
||||
|
||||
def _prepare(self) -> Optional[str]:
    """Prepares execution nodes for the next source node that is ready.

    A source node is ready when it has not been prepared yet and all of its parents
    in the flattened source graph have executed. Returns the id of the first
    execution node created, or None when nothing was prepared.
    """
    # Get flattened source graph
    source_graph = self.graph.nx_graph_flat()

    def is_ready(candidate) -> bool:
        if candidate in self.source_prepared_mapping:
            return False
        return all(parent in self.executed for parent, _ in source_graph.in_edges(candidate))

    # Find next unprepared node where all source nodes are executed
    next_node_id = next((n for n in nx.topological_sort(source_graph) if is_ready(n)), None)
    if next_node_id is None:
        return None

    # Get all parents of the next node
    parent_ids = [parent for parent, _ in source_graph.in_edges(next_node_id)]

    # Create execution nodes
    next_node = self.graph.get_node(next_node_id)
    created_ids = []
    if isinstance(next_node, CollectInvocation):
        # Collapse all iterator input mappings and create a single execution node for the collect invocation
        collapsed_mappings = [(parent, prepared) for parent in parent_ids for prepared in self.source_prepared_mapping[parent]]
        create_results = self._create_execution_node(next_node_id, collapsed_mappings)
        if create_results is not None:
            created_ids.extend(create_results)
    else:  # Iterators or normal nodes
        # Get all iterator combinations for this node: one prepared node per ancestor iterator
        iterator_ids = self._get_node_iterators(next_node_id)
        prepared_per_iterator = [list(self.source_prepared_mapping[iterator_id]) for iterator_id in iterator_ids]
        combinations = list(itertools.product(*prepared_per_iterator))

        # Select the correct prepared parents for each iteration
        # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
        # TODO: Handle a node mapping to none
        execution_flat = self.execution_graph.nx_graph_flat()
        parent_mappings = [
            [(parent, self._get_iteration_node(parent, source_graph, execution_flat, combination)) for parent in parent_ids]  # type: ignore
            for combination in combinations
        ]

        # Create execution node for each iteration
        for mapping in parent_mappings:
            create_results = self._create_execution_node(next_node_id, mapping)  # type: ignore
            if create_results is not None:
                created_ids.extend(create_results)

    return next(iter(created_ids), None)
|
||||
|
||||
def _get_iteration_node(self, source_node_path: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str]) -> Optional[str]:
    """Gets the prepared version of the specified source node that matches every iteration specified

    source_node_path: path of the source node whose prepared node is wanted.
    graph: flattened source graph.
    execution_graph: flattened execution graph.
    prepared_iterator_nodes: prepared iterator node ids identifying the iteration.

    Returns the matching prepared node id, or None when no prepared node matches.
    """
    prepared_nodes = self.source_prepared_mapping[source_node_path]
    if len(prepared_nodes) == 1:
        return next(iter(prepared_nodes))

    # Check if the requested node is an iterator
    prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
    if prepared_iterator is not None:
        return prepared_iterator

    # Keep only the prepared iterators whose source iterator is a parent of the specified node
    parent_iterators = [
        iterator_id
        for iterator_id in prepared_iterator_nodes
        if nx.has_path(graph, self.prepared_source_mapping[iterator_id], source_node_path)
    ]

    # The match must be reachable from EVERY parent iterator in the execution graph.
    # (Filtering with `if has_path(...)` inside all() only tested tuple truthiness,
    # so any prepared node was accepted.)
    return next(
        (n for n in prepared_nodes if all(nx.has_path(execution_graph, iterator_id, n) for iterator_id in parent_iterators)),
        None,
    )
|
||||
|
||||
def _get_next_node(self) -> Optional[BaseInvocation]:
    """Gets the first node of the execution graph, in topological order, that has not executed (None if all have)"""
    ordered_ids = nx.topological_sort(self.execution_graph.nx_graph())
    for node_id in ordered_ids:
        if node_id not in self.executed:
            return self.execution_graph.nodes[node_id]
    return None
|
||||
|
||||
def _prepare_inputs(self, node: BaseInvocation):
    """Copies upstream results onto the node's input fields.

    A collect node gets its 'collection' set to the list of values arriving on its
    'item' input. For any other node, each incoming edge sets the destination field
    to the source node's result field.
    """
    incoming = [edge for edge in self.execution_graph.edges if edge[1].node_id == node.id]

    if not isinstance(node, CollectInvocation):
        for source, destination in incoming:
            value = getattr(self.results[source.node_id], source.field)
            setattr(node, destination.field, value)
        return

    collected = [
        getattr(self.results[source.node_id], source.field)
        for source, destination in incoming
        if destination.field == 'item'
    ]
    setattr(node, 'collection', collected)
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
    """Returns true if the edge may be added given the current execution state.

    The edge must be valid for the underlying graph, and its destination node must
    not have been prepared or executed yet.
    """
    # Delegate the structural check to the source graph. Calling self._is_edge_valid
    # here would recurse without end.
    if not self.graph._is_edge_valid(edge):
        return False

    # Invalid if destination has already been prepared or executed
    if edge[1].node_id in self.source_prepared_mapping:
        return False

    # Otherwise, the edge is valid
    return True
|
||||
|
||||
def _is_node_updatable(self, node_id: str) -> bool:
    """Returns true as long as the node hasn't been prepared or executed"""
    already_prepared = node_id in self.source_prepared_mapping
    return not already_prepared
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
    """Adds a node to the source graph"""
    source_graph = self.graph
    source_graph.add_node(node)
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
    """Replaces a node in the source graph.

    Raises NodeAlreadyExecutedError if the node has already been prepared or executed.
    """
    if self._is_node_updatable(node_path):
        self.graph.update_node(node_path, new_node)
        return
    raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be updated')
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
    """Deletes a node from the source graph.

    Raises NodeAlreadyExecutedError if the node has already been prepared or executed.
    """
    if self._is_node_updatable(node_path):
        self.graph.delete_node(node_path)
        return
    raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be deleted')
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
    """Adds an edge to the source graph.

    Raises NodeAlreadyExecutedError if the destination node has already been prepared or executed.
    """
    destination_id = edge[1].node_id
    if self._is_node_updatable(destination_id):
        self.graph.add_edge(edge)
        return
    raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to')
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
    """Deletes an edge from the source graph.

    Raises NodeAlreadyExecutedError if the destination node has already been prepared or executed.
    """
    destination_id = edge[1].node_id
    if self._is_node_updatable(destination_id):
        self.graph.delete_edge(edge)
        return
    raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted')
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
104
invokeai/app/services/image_storage.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
from PIL.Image import Image
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
|
||||
|
||||
class ImageType(str, Enum):
    """Categories of stored images.

    Each value is a plain string; DiskImageStorage uses it as the name of the
    subfolder that holds images of that type.
    """

    RESULT = 'results'
    INTERMEDIATE = 'intermediates'
    UPLOAD = 'uploads'
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
    """Responsible for storing and retrieving images."""

    @abstractmethod
    def get(self, image_type: ImageType, image_name: str) -> Image:
        """Return the image stored under the given type and name."""

    # TODO: make this a bit more flexible for e.g. cloud storage
    @abstractmethod
    def get_path(self, image_type: ImageType, image_name: str) -> str:
        """Return the path of the image with the given type and name."""

    @abstractmethod
    def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
        """Store the image under the given type and name."""

    @abstractmethod
    def delete(self, image_type: ImageType, image_name: str) -> None:
        """Remove the image stored under the given type and name."""

    def create_name(self, context_id: str, node_id: str) -> str:
        """Build a .png file name from the context id, the node id and the current UTC time in whole seconds."""
        now_utc = datetime.datetime.now(datetime.timezone.utc)
        seconds = int(now_utc.timestamp())
        return f'{context_id}_{node_id}_{seconds}.png'
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
    """Stores images on disk and keeps a small in-memory cache keyed by image path"""

    __output_folder: str
    __pngWriter: PngWriter
    __cache_ids: Queue  # TODO: this is an incredibly naive cache
    __cache: Dict[str, Image]
    __max_cache_size: int

    def __init__(self, output_folder: str):
        """Creates the output folder, plus one subfolder per ImageType, if they do not exist yet."""
        self.__output_folder = output_folder
        self.__pngWriter = PngWriter(output_folder)
        self.__cache = dict()
        self.__cache_ids = Queue()
        self.__max_cache_size = 10  # TODO: get this from config

        Path(output_folder).mkdir(parents=True, exist_ok=True)

        # TODO: don't hard-code. get/save/delete should maybe take subpath?
        for image_type in ImageType:
            Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True)

    def get(self, image_type: ImageType, image_name: str) -> Image:
        """Returns the image with the given type and name, from the cache when present, otherwise from disk."""
        image_path = self.get_path(image_type, image_name)
        cache_item = self.__get_cache(image_path)
        if cache_item is not None:
            return cache_item

        # The module-level name `Image` is the PIL image class (from PIL.Image import Image),
        # which has no open(); open() is a function of the PIL.Image module
        from PIL import Image as pil_image_module

        image = pil_image_module.open(image_path)
        self.__set_cache(image_path, image)
        return image

    # TODO: make this a bit more flexible for e.g. cloud storage
    def get_path(self, image_type: ImageType, image_name: str) -> str:
        """Returns <output folder>/<image type>/<image name>."""
        path = os.path.join(self.__output_folder, image_type, image_name)
        return path

    def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
        """Writes the image as a PNG under the type's subfolder and caches it."""
        image_subpath = os.path.join(image_type, image_name)
        self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None)  # TODO: just pass full path to png writer

        image_path = self.get_path(image_type, image_name)
        self.__set_cache(image_path, image)

    def delete(self, image_type: ImageType, image_name: str) -> None:
        """Removes the image file (if it exists) and its cache entry (if any)."""
        image_path = self.get_path(image_type, image_name)
        if os.path.exists(image_path):
            os.remove(image_path)

        # The path stays in the eviction queue; __set_cache tolerates such stale ids
        if image_path in self.__cache:
            del self.__cache[image_path]

    def __get_cache(self, image_name: str) -> Image | None:
        """Returns the cached image for the key, or None when it is not cached."""
        return None if image_name not in self.__cache else self.__cache[image_name]

    def __set_cache(self, image_name: str, image: Image):
        """Caches the image under the key, evicting the oldest entries beyond the size limit."""
        if image_name not in self.__cache:
            self.__cache_ids.put(image_name)  # TODO: this should refresh position for LRU cache

        # Always store the latest image so that saving again under the same name does not leave a stale entry
        self.__cache[image_name] = image

        # Evict in insertion order. Ids of entries removed by delete() are no longer in
        # the cache, so pop() with a default skips them instead of raising KeyError
        while len(self.__cache) > self.__max_cache_size and not self.__cache_ids.empty():
            self.__cache.pop(self.__cache_ids.get(), None)
|
46
invokeai/app/services/invocation_queue.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
|
||||
|
||||
# TODO: make this serializable
|
||||
class InvocationQueueItem:
    """A queue entry naming one invocation to run and the graph execution state it belongs to."""

    #session_id: str
    graph_execution_state_id: str
    invocation_id: str
    invoke_all: bool

    def __init__(self, graph_execution_state_id: str, invocation_id: str, invoke_all: bool = False):
        """Stores the ids and the invoke_all flag on the item."""
        #self.session_id = session_id
        self.graph_execution_state_id = graph_execution_state_id
        self.invocation_id = invocation_id
        self.invoke_all = invoke_all
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
    """Abstract base class for all invocation queues"""

    @abstractmethod
    def get(self) -> InvocationQueueItem:
        """Take the next item from the queue."""

    @abstractmethod
    def put(self, item: InvocationQueueItem|None) -> None:
        """Add an item (or None) to the queue."""
|
||||
|
||||
|
||||
class MemoryInvocationQueue(InvocationQueueABC):
    """Invocation queue backed by an in-memory queue.Queue"""

    __queue: Queue

    def __init__(self):
        self.__queue = Queue()

    def get(self) -> InvocationQueueItem:
        """Returns the next item taken from the underlying queue."""
        next_item = self.__queue.get()
        return next_item

    def put(self, item: InvocationQueueItem|None) -> None:
        """Puts the item (or None) on the underlying queue."""
        self.__queue.put(item)
|
32
invokeai/app/services/invocation_services.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
from .image_storage import ImageStorageBase
|
||||
from .events import EventServiceBase
|
||||
from invokeai.backend import Generate
|
||||
|
||||
class InvocationServices():
    """Services that can be used by invocations"""

    generate: Generate  # TODO: wrap Generate, or split it up from model?
    events: EventServiceBase
    images: ImageStorageBase
    queue: InvocationQueueABC

    # NOTE: we must forward-declare any types that include invocations, since invocations can use services
    graph_execution_manager: ItemStorageABC['GraphExecutionState']
    processor: 'InvocationProcessorABC'

    def __init__(
        self,
        generate: Generate,
        events: EventServiceBase,
        images: ImageStorageBase,
        queue: InvocationQueueABC,
        graph_execution_manager: ItemStorageABC['GraphExecutionState'],
        processor: 'InvocationProcessorABC',
    ):
        """Stores each service on the instance under the name of its parameter."""
        self.generate = generate
        self.events = events
        self.images = images
        self.queue = queue
        self.graph_execution_manager = graph_execution_manager
        self.processor = processor
|
90
invokeai/app/services/invoker.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC
|
||||
from threading import Event, Thread
|
||||
from .graph import Graph, GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_services import InvocationServices
|
||||
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
|
||||
|
||||
|
||||
class Invoker:
    """The invoker, used to execute invocations"""

    services: InvocationServices

    def __init__(self, services: InvocationServices):
        """Stores the services and starts every service that has a start() method."""
        self.services = services
        self._start()

    def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None:
        """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""

        # Get the next invocation
        invocation = graph_execution_state.next()
        if not invocation:
            return None

        # Save the execution state
        self.services.graph_execution_manager.set(graph_execution_state)

        # Queue the invocation
        print(f'queueing item {invocation.id}')
        self.services.queue.put(InvocationQueueItem(
            #session_id = session.id,
            graph_execution_state_id = graph_execution_state.id,
            invocation_id = invocation.id,
            invoke_all = invoke_all
        ))

        return invocation.id

    def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
        """Creates a new execution state for the given graph (an empty graph when None) and stores it"""
        new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
        self.services.graph_execution_manager.set(new_state)
        return new_state

    def __start_service(self, service) -> None:
        # Call start() method on any services that have it
        start_op = getattr(service, 'start', None)
        if callable(start_op):
            start_op(self)

    def __stop_service(self, service) -> None:
        # Call stop() method on any services that have it
        stop_op = getattr(service, 'stop', None)
        if callable(stop_op):
            stop_op(self)

    def _start(self) -> None:
        """Starts the invoker. This is called automatically when the invoker is created."""
        # Start each service exactly once: a second pass would call start() again on
        # every service (for the processor, that launches a second worker thread)
        for service in vars(self.services):
            self.__start_service(getattr(self.services, service))

    def stop(self) -> None:
        """Stops the invoker. A new invoker will have to be created to execute further."""
        # First stop all services (once each)
        for service in vars(self.services):
            self.__stop_service(getattr(self.services, service))

        # Then put None on the queue so that a consumer blocked in queue.get() wakes up
        self.services.queue.put(None)
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC):
    """Base class for invocation processors; it declares no members of its own."""
|
57
invokeai/app/services/item_storage.py
Normal file
@ -0,0 +1,57 @@
|
||||
|
||||
from typing import Callable, TypeVar, Generic
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
class PaginatedResults(GenericModel, Generic[T]):
    """One page of results plus the paging information that goes with it"""

    items: list[T] = Field(description="Items")
    page: int = Field(description="Current Page")
    pages: int = Field(description="Total number of pages")
    per_page: int = Field(description="Number of items per page")
    total: int = Field(description="Total number of items in result")
|
||||
|
||||
|
||||
class ItemStorageABC(ABC, Generic[T]):
    """Base item storage class

    Subclasses implement get/set/list/search. Listeners registered through
    on_changed/on_deleted are called when _on_changed/_on_deleted is invoked.
    """

    _on_changed_callbacks: list[Callable[[T], None]]
    _on_deleted_callbacks: list[Callable[[str], None]]

    def __init__(self) -> None:
        self._on_changed_callbacks = list()
        self._on_deleted_callbacks = list()

    @abstractmethod
    def get(self, item_id: str) -> T:
        """Return the item with the given id."""
        pass

    @abstractmethod
    def set(self, item: T) -> None:
        """Store the item."""
        pass

    @abstractmethod
    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        """Return one page of stored items."""
        pass

    @abstractmethod
    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        """Return one page of stored items matching the query."""
        pass

    def on_changed(self, on_changed: Callable[[T], None]) -> None:
        """Register a callback for when an item is changed"""
        self._on_changed_callbacks.append(on_changed)

    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
        """Register a callback for when an item is deleted"""
        self._on_deleted_callbacks.append(on_deleted)

    def _on_changed(self, item: T) -> None:
        """Call every registered changed-callback with the item, in registration order"""
        for callback in self._on_changed_callbacks:
            callback(item)

    def _on_deleted(self, item_id: str) -> None:
        """Call every registered deleted-callback with the item id, in registration order"""
        for callback in self._on_deleted_callbacks:
            callback(item_id)
|
95
invokeai/app/services/processor.py
Normal file
@ -0,0 +1,95 @@
|
||||
from threading import Event, Thread
|
||||
import traceback
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
    """Runs queued invocations on a background daemon thread"""

    __invoker_thread: Thread
    __stop_event: Event
    __invoker: Invoker

    def start(self, invoker) -> None:
        """Remembers the invoker and launches the thread that processes its queue."""
        self.__invoker = invoker
        self.__stop_event = Event()
        worker = Thread(
            name = "invoker_processor",
            target = self.__process,
            kwargs = dict(stop_event = self.__stop_event)
        )
        worker.daemon = True  # TODO: probably better to just not use threads?
        self.__invoker_thread = worker
        worker.start()

    def stop(self, *args, **kwargs) -> None:
        """Signals the processing loop to exit; the loop checks the event before each queue read."""
        self.__stop_event.set()

    def __process(self, stop_event: Event):
        """Processing loop: takes items from the queue and invokes them until stop_event is set."""
        try:
            while not stop_event.is_set():
                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                if not queue_item:  # Probably stopping
                    continue

                services = self.__invoker.services
                state = services.graph_execution_manager.get(queue_item.graph_execution_state_id)
                invocation = state.execution_graph.get_node(queue_item.invocation_id)

                # Send starting event
                services.events.emit_invocation_started(
                    graph_execution_state_id = state.id,
                    invocation_id = invocation.id
                )

                # Invoke
                try:
                    context = InvocationContext(
                        services = services,
                        graph_execution_state_id = state.id
                    )
                    outputs = invocation.invoke(context)

                    # Save outputs and history
                    state.complete(invocation.id, outputs)

                    # Save the state changes
                    services.graph_execution_manager.set(state)

                    # Send complete event
                    services.events.emit_invocation_complete(
                        graph_execution_state_id = state.id,
                        invocation_id = invocation.id,
                        result = outputs.dict()
                    )

                except KeyboardInterrupt:
                    pass

                except Exception:
                    error = traceback.format_exc()

                    # Save error
                    state.set_node_error(invocation.id, error)

                    # Save the state changes
                    services.graph_execution_manager.set(state)

                    # Send error event
                    services.events.emit_invocation_error(
                        graph_execution_state_id = state.id,
                        invocation_id = invocation.id,
                        error = error
                    )

                # Either report completion, or queue the next command when invoking all
                finished = state.is_complete()
                if finished:
                    services.events.emit_graph_execution_complete(state.id)
                elif queue_item.invoke_all:
                    self.__invoker.invoke(state, invoke_all = True)

        except KeyboardInterrupt:
            ...  # Log something?
|
119
invokeai/app/services/sqlite.py
Normal file
@ -0,0 +1,119 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Generic, TypeVar, Union, get_args
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from .item_storage import ItemStorageABC, PaginatedResults
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)

sqlite_memory = ':memory:'


class SqliteItemStorage(ItemStorageABC, Generic[T]):
    """Item storage that keeps pydantic models as JSON text in one SQLite table.

    Each row holds the item's JSON in the ``item`` column. The ``id`` column is
    a generated column extracted from that JSON (``$.<id_field>``) and carries
    a unique index, so ``INSERT OR REPLACE`` overwrites an item with the same id.

    All access to the shared connection/cursor is serialized with ``_lock``.
    """
    _filename: str
    _table_name: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _id_field: str
    _lock: Lock

    def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
        """Open (or create) the database and make sure the table exists.

        :param filename: path passed to ``sqlite3.connect``
        :param table_name: table used for this storage; interpolated into SQL
            as an identifier, so it must come from trusted code
        :param id_field: name of the JSON field used as the item id
        """
        super().__init__()

        self._filename = filename
        self._table_name = table_name
        self._id_field = id_field # TODO: validate that T has this field
        self._lock = Lock()

        self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()

        self._create_table()

    def _create_table(self):
        """Create the item table and its unique id index if they are missing."""
        with self._lock:
            self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
                item TEXT,
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
            self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
            self._conn.commit()

    def _parse_item(self, item: str) -> T:
        """Deserialize stored JSON into the model type this storage was subscripted with."""
        # __orig_class__ is only present on instances created through a
        # subscripted generic, e.g. SqliteItemStorage[Model](...).
        item_type = get_args(self.__orig_class__)[0]
        return parse_raw_as(item_type, item)

    def _paginate(self, where: str, params: tuple, page: int, per_page: int) -> PaginatedResults[T]:
        """Run a paged SELECT plus the matching COUNT and wrap both in PaginatedResults.

        :param where: SQL fragment appended after the table name ('' or ' WHERE ...')
        :param params: bound parameters used by *where*
        :param page: zero-based page index
        :param per_page: number of items per page
        """
        with self._lock:
            self._cursor.execute(
                f'''SELECT item FROM {self._table_name}{where} LIMIT ? OFFSET ?;''',
                (*params, per_page, page * per_page),
            )
            rows = self._cursor.fetchall()

            items = [self._parse_item(row[0]) for row in rows]

            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name}{where};''', params)
            count = self._cursor.fetchone()[0]

        # Ceiling division: an exact multiple of per_page must not report an
        # extra, empty trailing page.
        page_count = -(-count // per_page)

        return PaginatedResults[T](
            items = items,
            page = page,
            pages = page_count,
            per_page = per_page,
            total = count
        )

    def set(self, item: T):
        """Insert *item*, replacing any stored item with the same id, then notify listeners."""
        with self._lock:
            self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
            # Without a commit the implicit transaction stays open and the
            # write never reaches a file-backed database.
            self._conn.commit()
        self._on_changed(item)

    def get(self, id: str) -> Union[T, None]:
        """Return the item stored under *id*, or None when there is none."""
        with self._lock:
            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
            result = self._cursor.fetchone()

        if not result:
            return None

        return self._parse_item(result[0])

    def delete(self, id: str):
        """Delete the item stored under *id* (if any), then notify listeners."""
        with self._lock:
            self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
            self._conn.commit()
        self._on_deleted(id)

    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        """Return one page of all stored items."""
        return self._paginate('', (), page, per_page)

    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        """Return one page of the items whose JSON text contains *query*.

        The match is a plain SQL ``LIKE '%query%'`` over the serialized item,
        so ``%`` and ``_`` inside *query* act as wildcards.
        """
        return self._paginate(' WHERE item LIKE ?', (f'%{query}%',), page, per_page)
|
@ -1,8 +1,8 @@
|
||||
'''
|
||||
Initialization file for invokeai.backend
|
||||
'''
|
||||
# this is causing circular import issues
|
||||
# from .invoke_ai_web_server import InvokeAIWebServer
|
||||
from .model_manager import ModelManager
|
||||
from .model_management import ModelManager
|
||||
from .generate import Generate
|
||||
|
||||
|
||||
|
||||
|
1347
invokeai/backend/args.py
Normal file
0
invokeai/backend/config/__init__.py
Normal file
860
invokeai/backend/config/invokeai_configure.py
Executable file
@ -0,0 +1,860 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
# Before running stable-diffusion on an internet-isolated machine,
|
||||
# run this script from one with internet connectivity. The
|
||||
# two machines must share a common .cache directory.
|
||||
#
|
||||
# Coauthor: Kevin Turner http://github.com/keturn
|
||||
#
|
||||
print("Loading Python libraries...\n")
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from urllib import request
|
||||
from shutil import get_terminal_size
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub import login as hf_hub_login
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
CLIPSegForImageSegmentation,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from ..args import PRECISION_CHOICES, Args
|
||||
from ..globals import Globals, global_config_dir, global_config_file, global_cache_dir
|
||||
from ...frontend.config.model_install import addModelsForm, process_and_execute
|
||||
from .model_install_backend import (
|
||||
default_dataset,
|
||||
download_from_hf,
|
||||
recommended_datasets,
|
||||
hf_download_with_resume,
|
||||
)
|
||||
from ...frontend.config.widgets import IntTitleSlider, CenteredButtonPress, set_min_terminal_size
|
||||
|
||||
|
||||
# Suppress every Python warning for the rest of this process.
warnings.filterwarnings("ignore")

# Limit the transformers library's own logging to errors.
transformers.logging.set_verbosity_error()

# --------------------------globals-----------------------
# Path components joined below Globals.root when the legacy VAE checkpoint
# is downloaded (see download_vaes).
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"

# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"

# Locations inside the global config directory.
Default_config_file = Path(global_config_dir()) / "models.yaml"
SD_Configs = Path(global_config_dir()) / "stable-diffusion"

# Loaded eagerly at import time from the bundled INITIAL_MODELS.yaml.
Datasets = OmegaConf.load(Dataset_path)

# minimum size for the UI
MIN_COLS = 135
MIN_LINES = 45
|
||||
|
||||
# Header text written to a newly created init file.
# The backslashes in the Windows example path are escaped explicitly: "\d" and
# "\i" are invalid escape sequences in a normal string literal and trigger a
# warning on current Python versions. The resulting text is unchanged.
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
# or renaming it and then running invokeai-configure again.
# Place frequently-used startup commands here, one or more per line.
# Examples:
# --outdir=D:\\data\\images
# --no-nsfw_checker
# --web --host=0.0.0.0
# --steps=20
# -Ak_euler_a -C10.0
"""
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: set = None):
    """
    Print the closing message of the installer.

    :param errors: collection of error descriptions gathered during the run.
        ``None``, an empty collection, or a collection with only falsy entries
        is treated as success; otherwise every entry is listed.
    """
    # `any(None)` raises TypeError, so check for a missing collection first.
    if not errors or not any(errors):
        message = f"""
** INVOKEAI INSTALLATION SUCCESSFUL **
If you installed manually from source or with 'pip install': activate the virtual environment
then run one of the following commands to start InvokeAI.

Web UI:
invokeai --web # (connect to http://localhost:9090)
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)

Command-line interface:
invokeai

If you installed using an installation script, run:
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}

Add the '--help' argument to see all of the command-line switches available for use.
"""

    else:
        message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
        for err in errors:
            message += f"\t - {err}\n"
        message += "Please check the logs above and correct any issues."

    print(message)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
    """
    Ask a yes/no question on the console and return the answer as a bool.

    An empty reply selects the default. With a "yes" default, anything that
    does not start with n/N counts as yes; with a "no" default, only a reply
    starting with y/Y counts as yes.
    """
    fallback = "y" if default_yes else "n"
    reply = input(f"{prompt} [{fallback}] ") or fallback
    initial = reply[0]
    return initial not in ("n", "N") if default_yes else initial in ("y", "Y")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def HfLogin(access_token) -> None:
    """
    Helper for logging in to Huggingface
    The stdout capture is needed to hide the irrelevant "git credential helper" warning

    :param access_token: token passed to the Huggingface hub login call
    :raises Exception: whatever the login call raises; the message is printed
        first and the exception is then re-raised
    """
    # Restore the stream that was active on entry (not sys.__stdout__), so a
    # caller that has already redirected stdout keeps its redirection.
    previous_stdout = sys.stdout
    sys.stdout = io.StringIO()
    try:
        hf_hub_login(token=access_token, add_to_git_credential=False)
    except Exception as exc:
        sys.stdout = previous_stdout
        print(exc)
        raise
    finally:
        sys.stdout = previous_stdout
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class ProgressBar:
    """Progress callback that drives a tqdm bar; passed as the third argument to urlretrieve."""

    def __init__(self, model_name="file"):
        """
        :param model_name: label shown in front of the bar
        """
        self.pbar = None
        self.name = model_name

    def __call__(self, block_num, block_size, total_size):
        """
        Advance the bar by one block, creating it on the first call.

        :param block_num: index of the block just transferred (unused)
        :param block_size: size of one block in bytes
        :param total_size: total size reported for the download
        """
        # Compare against None explicitly: a tqdm bar's truthiness is derived
        # from its total, so `not self.pbar` would build a fresh bar on every
        # call when the reported total is zero or negative (size unknown).
        if self.pbar is None:
            self.pbar = tqdm(
                desc=self.name,
                initial=0,
                unit="iB",
                unit_scale=True,
                unit_divisor=1000,
                total=total_size,
            )
        self.pbar.update(block_size)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
    """
    Download *model_url* to *model_dest* unless the destination already exists.

    Failures are reported on stderr and never raised (best effort). When a
    download fails part-way, the incomplete destination file is removed so
    that a later run does not mistake it for an installed model.

    :param model_url: URL to fetch
    :param model_dest: destination file path; parent directories are created
    :param label: name used in the progress and error messages
    """
    download_started = False
    try:
        print(f"Installing {label} model file {model_url}...", end="", file=sys.stderr)
        if not os.path.exists(model_dest):
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            download_started = True
            request.urlretrieve(
                model_url, model_dest, ProgressBar(os.path.basename(model_dest))
            )
            print("...downloaded successfully", file=sys.stderr)
        else:
            print("...exists", file=sys.stderr)
    except Exception:
        print("...download failed", file=sys.stderr)
        print(f"Error downloading {label} model", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        # Only remove a file this call may have started writing.
        if download_started and os.path.exists(model_dest):
            try:
                os.remove(model_dest)
            except OSError:
                print(f"Could not remove incomplete file {model_dest}", file=sys.stderr)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# this will preload the Bert tokenizer files
|
||||
def download_bert():
    """Fetch the "bert-base-uncased" fast tokenizer via download_from_hf."""
    tokenizer_id = "bert-base-uncased"
    print("Installing bert tokenizer...", file=sys.stderr)
    with warnings.catch_warnings():
        # The transformers import is done here, with DeprecationWarning filtered.
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        from transformers import BertTokenizerFast

        download_from_hf(BertTokenizerFast, tokenizer_id)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_sd1_clip():
    """Fetch the CLIP tokenizer and text model from "openai/clip-vit-large-patch14"."""
    print("Installing SD1 clip model...", file=sys.stderr)
    repo_id = "openai/clip-vit-large-patch14"
    for model_class in (CLIPTokenizer, CLIPTextModel):
        download_from_hf(model_class, repo_id)
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_sd2_clip():
    """Fetch the tokenizer and text encoder subfolders of 'stabilityai/stable-diffusion-2'."""
    repo_id = 'stabilityai/stable-diffusion-2'
    print("Installing SD2 clip model...", file=sys.stderr)
    components = (
        (CLIPTokenizer, 'tokenizer'),
        (CLIPTextModel, 'text_encoder'),
    )
    for model_class, folder in components:
        download_from_hf(model_class, repo_id, subfolder=folder)
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_realesrgan():
    """Download the two RealESRGAN weight files into models/realesrgan under Globals.root."""
    print("Installing models from RealESRGAN...", file=sys.stderr)
    # (source URL, path relative to Globals.root, label used in messages)
    downloads = (
        (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            "models/realesrgan/realesr-general-x4v3.pth",
            "RealESRGAN",
        ),
        (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
            "models/realesrgan/realesr-general-wdn-x4v3.pth",
            "RealESRGANwdn",
        ),
    )
    # Resolve both destinations first, then download in order.
    jobs = [
        (url, os.path.join(Globals.root, relative_path), label)
        for url, relative_path, label in downloads
    ]
    for url, destination, label in jobs:
        download_with_progress_bar(url, destination, label)
|
||||
|
||||
|
||||
def download_gfpgan():
    """Download the GFPGAN model and its two facexlib weight files below Globals.root."""
    print("Installing GFPGAN models...", file=sys.stderr)
    # source URL -> destination relative to Globals.root
    targets = (
        (
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
            "./models/gfpgan/GFPGANv1.4.pth",
        ),
        (
            "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
            "./models/gfpgan/weights/detection_Resnet50_Final.pth",
        ),
        (
            "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
            "./models/gfpgan/weights/parsing_parsenet.pth",
        ),
    )
    for url, relative_path in targets:
        destination = os.path.join(Globals.root, relative_path)
        download_with_progress_bar(url, destination, "GFPGAN weights")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_codeformer():
    """Download codeformer.pth into models/codeformer under Globals.root."""
    print("Installing CodeFormer model file...", file=sys.stderr)
    url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
    destination = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
    download_with_progress_bar(url, destination, "CodeFormer")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_clipseg():
    """
    Fetch the processor and segmentation model of "CIDAS/clipseg-rd64-refined".

    Any exception is printed together with its traceback and not re-raised.
    """
    print("Installing clipseg model for text-based masking...", file=sys.stderr)
    CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
    try:
        for model_class in (AutoProcessor, CLIPSegForImageSegmentation):
            download_from_hf(model_class, CLIPSEG_MODEL)
    except Exception:
        print("Error installing clipseg model:")
        print(traceback.format_exc())
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_safety_checker():
    """
    Fetch the feature extractor and safety checker of
    "CompVis/stable-diffusion-safety-checker".

    If the required modules cannot be imported, the traceback is printed and
    the function returns without downloading anything.
    """
    print("Installing model for NSFW content detection...", file=sys.stderr)
    try:
        from diffusers.pipelines.stable_diffusion.safety_checker import (
            StableDiffusionSafetyChecker,
        )
        from transformers import AutoFeatureExtractor
    except ModuleNotFoundError:
        print("Error installing NSFW checker model:")
        print(traceback.format_exc())
        return

    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    steps = (
        ("AutoFeatureExtractor...", AutoFeatureExtractor),
        ("StableDiffusionSafetyChecker...", StableDiffusionSafetyChecker),
    )
    for announcement, model_class in steps:
        print(announcement, file=sys.stderr)
        download_from_hf(model_class, safety_model_id)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_vaes():
    """
    Download the StabilityAI VAE in both diffusers and legacy checkpoint form.

    A falsy result from either download is turned into an exception; every
    exception is reported on stderr with its traceback and not re-raised.
    """
    print("Installing stabilityai VAE...", file=sys.stderr)
    try:
        # first the diffusers version
        repo_id = "stabilityai/sd-vae-ft-mse"
        diffusers_vae = AutoencoderKL.from_pretrained(
            repo_id,
            cache_dir=global_cache_dir("diffusers"),
        )
        if not diffusers_vae:
            raise Exception(f"download of {repo_id} failed")

        # next the legacy checkpoint version
        repo_id = "stabilityai/sd-vae-ft-mse-original"
        model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
        checkpoint = hf_download_with_resume(
            repo_id=repo_id,
            model_name=model_name,
            model_dir=str(Globals.root / Model_dir / Weights_dir),
        )
        if not checkpoint:
            raise Exception(f"download of {model_name} failed")
    except Exception as e:
        print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
    """
    Pick the root directory to use.

    Precedence: a non-empty *root* argument, then a non-empty INVOKEAI_ROOT
    environment variable, then Globals.root.
    """
    if root:
        return root
    env_root = os.environ.get("INVOKEAI_ROOT")
    if env_root:
        return env_root
    return Globals.root
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(npyscreen.FormMultiPage):
    """
    Console form for the startup options.

    create() builds the widgets, on_ok() validates and stores the result on
    the parent application, marshall_arguments() collects the widget values
    into a Namespace.
    """

    # for responsive resizing - disabled
    # FIX_MINIMUM_SIZE_WHEN_CREATED = False

    def _add_static_lines(self, lines):
        """Add one read-only FixedText widget per entry of *lines*, in order."""
        for text in lines:
            self.add_widget_intelligent(
                npyscreen.FixedText,
                value=text,
                editable=False,
                color="CONTROL",
            )

    def _add_heading(self, title):
        """Add a read-only TitleFixedText widget used as a section heading."""
        self.add_widget_intelligent(
            npyscreen.TitleFixedText,
            name=title,
            begin_entry_at=0,
            editable=False,
            color="CONTROL",
            scroll_exit=True,
        )

    def _add_directory_field(self, initial):
        """Add and return a TitleFilename widget for choosing a directory."""
        return self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="(<tab> autocompletes, ctrl-N advances):",
            value=initial,
            select_dir=True,
            must_exist=False,
            use_two_lines=False,
            labelColor="GOOD",
            begin_entry_at=40,
            scroll_exit=True,
        )

    def create(self):
        """Build all widgets of the form, pre-filled from the parent application's options."""
        program_opts = self.parentApp.program_opts
        old_opts = self.parentApp.invokeai_opts
        first_time = not (Globals.root / Globals.initfile).exists()
        access_token = HfFolder.get_token()
        window_width = get_terminal_size().columns

        self._add_static_lines(
            [
                "Configure startup settings. You can come back and change these later.",
                "Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
                "Use cursor arrows to make a checkbox selection, and space to toggle.",
            ]
        )

        # ---- basic options ----
        self.nextrely += 1
        self._add_heading("== BASIC OPTIONS ==")
        self.nextrely -= 1
        self._add_static_lines(["Select an output directory for images:"])
        self.outdir = self._add_directory_field(
            old_opts.outdir or str(default_output_dir())
        )
        self.nextrely += 1
        self._add_static_lines(
            ["Activate the NSFW checker to blur images showing potential sexual imagery:"]
        )
        self.safety_checker = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="NSFW checker",
            value=old_opts.safety_checker,
            relx=5,
            scroll_exit=True,
        )
        self.nextrely += 1
        self._add_static_lines(
            [
                "If you have an account at HuggingFace you may paste your access token here",
                'to allow InvokeAI to download styles & subjects from the "Concept Library".',
                "See https://huggingface.co/settings/tokens",
            ]
        )
        self.hf_token = self.add_widget_intelligent(
            npyscreen.TitlePassword,
            name="Access Token (ctrl-shift-V pastes):",
            value=access_token,
            begin_entry_at=42,
            use_two_lines=False,
            scroll_exit=True,
        )

        # ---- advanced options ----
        self.nextrely += 1
        self._add_heading("== ADVANCED OPTIONS ==")
        self.nextrely -= 1
        self._add_heading("GPU Management")
        self.nextrely -= 1
        # Each checkbox is stored on the form under the option's own name so
        # that marshall_arguments() can read it back by attribute.
        gpu_checkboxes = (
            ("free_gpu_mem", "Free GPU memory after each generation"),
            ("xformers", "Enable xformers support if available"),
            ("ckpt_convert", "Load legacy checkpoint models into memory as diffusers models"),
            ("always_use_cpu", "Force CPU to be used on GPU systems"),
        )
        for option, caption in gpu_checkboxes:
            checkbox = self.add_widget_intelligent(
                npyscreen.Checkbox,
                name=caption,
                value=getattr(old_opts, option),
                relx=5,
                scroll_exit=True,
            )
            setattr(self, option, checkbox)

        precision = old_opts.precision or (
            "float32" if program_opts.full_precision else "auto"
        )
        self.precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Precision",
            values=PRECISION_CHOICES,
            value=PRECISION_CHOICES.index(precision),
            begin_entry_at=3,
            max_height=len(PRECISION_CHOICES) + 1,
            scroll_exit=True,
        )
        self.max_loaded_models = self.add_widget_intelligent(
            IntTitleSlider,
            name="Number of models to cache in CPU memory (each will use 2-4 GB!)",
            value=old_opts.max_loaded_models,
            out_of=10,
            lowest=1,
            begin_entry_at=4,
            scroll_exit=True,
        )
        self.nextrely += 1
        self._add_static_lines(
            ["Directory containing embedding/textual inversion files:"]
        )
        self.embedding_path = self._add_directory_field(str(default_embedding_dir()))

        # ---- license ----
        self.nextrely += 1
        self._add_heading("== LICENSE ==")
        self.nextrely -= 1
        self._add_static_lines(
            [
                "BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ",
                "AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT",
                "https://huggingface.co/spaces/CompVis/stable-diffusion-license",
            ]
        )
        # Pre-checked whenever the init file already exists.
        self.license_acceptance = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="I accept the CreativeML Responsible AI License",
            value=not first_time,
            relx=2,
            scroll_exit=True,
        )
        self.nextrely += 1

        if program_opts.skip_sd_weights or program_opts.default_only:
            label = "DONE"
        else:
            label = "NEXT"
        self.ok_button = self.add_widget_intelligent(
            CenteredButtonPress,
            name=label,
            relx=(window_width - len(label)) // 2,
            rely=-3,
            when_pressed_function=self.on_ok,
        )

    def on_ok(self):
        """Validate the form; on success store the options and move to the next form."""
        options = self.marshall_arguments()
        if not self.validate_field_values(options):
            # Keep the form open so the user can correct the input.
            self.editing = True
            return

        self.parentApp.new_opts = options
        if hasattr(self.parentApp, "model_select"):
            self.parentApp.setNextForm("MODELS")
        else:
            self.parentApp.setNextForm(None)
        self.editing = False

    def validate_field_values(self, opt: Namespace) -> bool:
        """
        Check *opt* and show a confirmation dialog listing any problems.

        :return: True when no problem was found, False otherwise
        """
        problems = []
        if not opt.license_acceptance:
            problems.append(
                "Please accept the license terms before proceeding to model downloads"
            )

        outdir_parent = Path(opt.outdir).parent
        if not outdir_parent.exists():
            problems.append(
                f"The output directory does not seem to be valid. Please check that {str(outdir_parent)} is an existing directory."
            )

        embedding_parent = Path(opt.embedding_path).parent
        if not embedding_parent.exists():
            problems.append(
                f"The embedding directory does not seem to be valid. Please check that {str(embedding_parent)} is an existing directory."
            )

        if not problems:
            return True

        message = "The following problems were detected and must be corrected:\n"
        message += "".join(f"* {problem}\n" for problem in problems)
        npyscreen.notify_confirm(message)
        return False

    def marshall_arguments(self):
        """Collect the current widget values into a new Namespace."""
        collected = Namespace()

        plain_fields = (
            "outdir",
            "safety_checker",
            "free_gpu_mem",
            "max_loaded_models",
            "xformers",
            "always_use_cpu",
            "embedding_path",
            "ckpt_convert",
        )
        for field_name in plain_fields:
            widget = getattr(self, field_name)
            setattr(collected, field_name, widget.value)

        collected.hf_token = self.hf_token.value
        collected.license_acceptance = self.license_acceptance.value
        # The select-one widget's value is indexed to get the position of the
        # chosen entry in PRECISION_CHOICES.
        selected_index = self.precision.value[0]
        collected.precision = PRECISION_CHOICES[selected_index]

        return collected
|
||||
|
||||
|
||||
class EditOptApplication(npyscreen.NPSAppManaged):
    """npyscreen application hosting the options form and, optionally, the model selection form."""

    def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
        """
        :param program_opts: options of the configure program itself
        :param invokeai_opts: current InvokeAI startup options used to pre-fill the form
        """
        super().__init__()
        self.program_opts = program_opts
        self.invokeai_opts = invokeai_opts
        self.user_cancelled = False
        self.user_selections = default_user_selections(program_opts)

    def onStart(self):
        """Register the forms; the model form is skipped when weights are skipped or default-only is set."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.options = self.addForm(
            "MAIN",
            editOptsForm,
            name="InvokeAI Startup Options",
        )

        opts = self.program_opts
        if opts.skip_sd_weights or opts.default_only:
            return

        self.model_select = self.addForm(
            "MODELS",
            addModelsForm,
            name="Install Stable Diffusion Models",
            multipage=True,
        )

    # NOTE(review): editOptsForm.on_ok() assigns `parentApp.new_opts = options`,
    # which appears to replace this method on the instance with a Namespace.
    def new_opts(self):
        """Return the options currently entered in the main form."""
        return self.options.marshall_arguments()
|
||||
|
||||
|
||||
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
    """
    Run the options editor and return the options the user ended up with.

    :param program_opts: options of the configure program itself
    :param invokeai_opts: current InvokeAI startup options used to pre-fill the form
    """
    editApp = EditOptApplication(program_opts, invokeai_opts)
    editApp.run()

    # editOptsForm.on_ok() stores the validated Namespace in the application's
    # `new_opts` attribute, hiding the method of the same name. Calling it
    # unconditionally would then raise TypeError, so only call it when it is
    # still the method.
    result = editApp.new_opts
    return result() if callable(result) else result
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
    """
    Return the parsed default startup options.

    A relative output directory is made absolute below Globals.root, and the
    safety checker is switched on when *init_file* does not exist yet.
    """
    defaults = Args().parse_args([])

    if not Path(defaults.outdir).is_absolute():
        defaults.outdir = str(Globals.root / defaults.outdir)

    if not init_file.exists():
        defaults.safety_checker = True

    return defaults
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> Namespace:
    """
    Build the initial model-selection state.

    starter_models is the default dataset when ``default_only`` is set,
    otherwise the recommended datasets when ``yes_to_all`` is set, otherwise
    an empty dict. All other fields start as False/None.
    """
    if program_opts.default_only:
        starters = default_dataset()
    elif program_opts.yes_to_all:
        starters = recommended_datasets()
    else:
        starters = dict()

    return Namespace(
        starter_models=starters,
        purge_deleted_models=False,
        scan_directory=None,
        autoscan_on_startup=None,
        import_model_paths=None,
        convert_to_diffusers=None,
    )
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: str, yes_to_all: bool = False):
    """
    Create the standard subdirectories below *root* and copy the bundled configs.

    :param root: runtime root directory
    :param yes_to_all: accepted but not used by this function
    """
    print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")

    subdirectories = (
        "models",
        "configs",
        "embeddings",
        "text-inversion-output",
        "text-inversion-training-data",
    )
    for subdirectory in subdirectories:
        os.makedirs(os.path.join(root, subdirectory), exist_ok=True)

    bundled_configs = Path(configs.__path__[0])
    runtime_configs = Path(root) / "configs"
    # Nothing to copy when the root's configs directory is the bundled one.
    if os.path.samefile(bundled_configs, runtime_configs):
        return
    shutil.copytree(bundled_configs, runtime_configs, dirs_exist_ok=True)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(
    program_opts: Namespace, initfile: Path = None
) -> (Namespace, Namespace):
    """
    Run the console UI and return ``(new_opts, user_selections)`` from the
    application, or ``(None, None)`` when the application reports that the
    user cancelled.
    """
    # parse_args() will read from init file if present
    startup_defaults = default_startup_options(initfile)

    set_min_terminal_size(MIN_COLS, MIN_LINES)
    app = EditOptApplication(program_opts, startup_defaults)
    app.run()

    if not app.user_cancelled:
        return (app.new_opts, app.user_selections)
    return (None, None)
|
||||
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
    """
    Update the invokeai.init file with values from opts Namespace

    Existing lines are kept, except empty lines and lines that look like one
    of the options managed here; the managed options are then appended with
    their new values. The new content is written to ``<init_file>.new`` and
    only moved over *init_file* when writing succeeded. If ``opts.hf_token``
    is set, a Huggingface login is performed afterwards.

    Side effect: ``opts.outdir`` and ``opts.embedding_path`` are rewritten
    in place with forward slashes.

    :param opts: option values to store
    :param init_file: path of the init file to update (created if missing)
    """
    # touch file if it doesn't exist
    if not init_file.exists():
        with open(init_file, "w") as f:
            f.write(INIT_FILE_PREAMBLE)

    # We want to write in the changed arguments without clobbering
    # any other initialization values the user has entered. There is
    # no good way to do this because of the one-way nature of
    # argparse: i.e. --outdir could be --outdir, --out, or -o
    # initfile needs to be replaced with a fully structured format
    # such as yaml; this is a hack that will work much of the time
    args_to_skip = re.compile(
        "^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
    )
    # fix windows paths
    opts.outdir = opts.outdir.replace('\\','/')
    opts.embedding_path = opts.embedding_path.replace('\\','/')
    new_file = f"{init_file}.new"
    try:
        # Read through a context manager so the handle is always closed.
        with open(init_file, "r") as in_file:
            lines = [x.strip() for x in in_file]
        with open(new_file, "w") as out_file:
            for line in lines:
                if len(line) > 0 and not args_to_skip.match(line):
                    out_file.write(line + "\n")
            out_file.write(
                f"""
--outdir={opts.outdir}
--embedding_path={opts.embedding_path}
--precision={opts.precision}
--max_loaded_models={int(opts.max_loaded_models)}
--{'no-' if not opts.safety_checker else ''}nsfw_checker
--{'no-' if not opts.xformers else ''}xformers
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
{'--always_use_cpu' if opts.always_use_cpu else ''}
"""
            )
    except OSError as e:
        print(f"** An error occurred while writing the init file: {str(e)}")
    else:
        # Only swap the file in when it was written completely; after a
        # reported error the temporary file may be missing or truncated.
        os.replace(new_file, init_file)

    if opts.hf_token:
        HfLogin(opts.hf_token)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
    """Return the "outputs" directory below Globals.root."""
    return Globals.root.joinpath("outputs")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def default_embedding_dir() -> Path:
    """Return the "embeddings" directory below Globals.root."""
    return Globals.root.joinpath("embeddings")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
    """
    Write the default startup options, plus the stored Huggingface token, to *initfile*.

    :param program_opts: accepted but not used by this function
    :param initfile: path of the init file to write
    """
    defaults = default_startup_options(initfile)
    defaults.hf_token = HfFolder.get_token()
    write_opts(defaults, initfile)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main():
    """Command-line entry point of the InvokeAI configuration/downloader script.

    Parses the command line, sets ``Globals.root``, initializes the runtime
    directory when the init file or the models config file is missing, writes
    the startup options (defaults with ``--yes``, otherwise whatever the
    console UI returns), and then downloads the support models and the
    selected diffusion weights unless the corresponding ``--skip-*`` flag was
    given.  A ``KeyboardInterrupt`` ends the run with a goodbye message.
    """
    parser = argparse.ArgumentParser(description="InvokeAI model downloader")
    parser.add_argument(
        "--skip-sd-weights",
        dest="skip_sd_weights",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="skip downloading the large Stable Diffusion weight files",
    )
    parser.add_argument(
        "--skip-support-models",
        dest="skip_support_models",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="skip downloading the support models",
    )
    # NOTE: no ``type=`` here.  BooleanOptionalAction takes no value from the
    # command line, so a type converter is never applied; the parameter is
    # deprecated for this action and rejected by newer Python releases.
    parser.add_argument(
        "--full-precision",
        dest="full_precision",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="use 32-bit weights instead of faster 16-bit weights",
    )
    parser.add_argument(
        "--yes",
        "-y",
        dest="yes_to_all",
        action="store_true",
        help='answer "yes" to all prompts',
    )
    parser.add_argument(
        "--default_only",
        action="store_true",
        help="when --yes specified, only install the default model",
    )
    parser.add_argument(
        "--config_file",
        "-c",
        dest="config_file",
        type=str,
        default=None,
        help="path to configuration file to create",
    )
    parser.add_argument(
        "--root_dir",
        dest="root",
        type=str,
        default=None,
        help="path to root of install directory",
    )
    opt = parser.parse_args()

    # setting a global here
    Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))

    errors = set()

    try:
        models_to_download = default_user_selections(opt)

        # Check whether the runtime directory is correctly initialized.
        init_file = Path(Globals.root, Globals.initfile)
        if not init_file.exists() or not global_config_file().exists():
            initialize_rootdir(Globals.root, opt.yes_to_all)

        if opt.yes_to_all:
            write_default_options(opt, init_file)
            init_options = Namespace(
                precision="float32" if opt.full_precision else "float16"
            )
        else:
            init_options, models_to_download = run_console_ui(opt, init_file)
            if init_options:
                write_opts(init_options, init_file)
            else:
                print(
                    '\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
                )
                sys.exit(0)

        if opt.skip_support_models:
            print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
        else:
            print("\n** DOWNLOADING SUPPORT MODELS **")
            download_bert()
            download_sd1_clip()
            download_sd2_clip()
            download_realesrgan()
            download_gfpgan()
            download_codeformer()
            download_clipseg()
            download_safety_checker()
            download_vaes()

        if opt.skip_sd_weights:
            print("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
        elif models_to_download:
            print("\n** DOWNLOADING DIFFUSION WEIGHTS **")
            process_and_execute(opt, models_to_download)

        postscript(errors=errors)
    except KeyboardInterrupt:
        print("\nGoodbye! Come back soon.")
|
||||
|
||||
# -------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main()
|
455
invokeai/backend/config/model_install_backend.py
Normal file
@ -0,0 +1,455 @@
|
||||
"""
|
||||
Utility (backend) functions used by model_install.py
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
|
||||
import invokeai.configs as configs
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir
|
||||
from ..model_management import ModelManager
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
# initial models omegaconf
|
||||
Datasets = None
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
def default_config_file():
    """Return the path of ``models.yaml`` inside the global config directory."""
    config_dir = Path(global_config_dir())
    return config_dir / "models.yaml"
|
||||
|
||||
def sd_configs():
    """Return the ``stable-diffusion`` subdirectory of the global config directory."""
    config_dir = Path(global_config_dir())
    return config_dir / "stable-diffusion"
|
||||
|
||||
def initial_models():
    """Return the starter-model definitions, loading them on first use.

    The result of ``OmegaConf.load(Dataset_path)`` is kept in the module-level
    ``Datasets`` variable and reused for as long as that value is truthy.
    """
    global Datasets
    if not Datasets:
        Datasets = OmegaConf.load(Dataset_path)
    return Datasets
|
||||
|
||||
def install_requested_models(
    install_initial_models: List[str] = None,
    remove_models: List[str] = None,
    scan_directory: Path = None,
    external_models: List[str] = None,
    scan_at_startup: bool = False,
    convert_to_diffusers: bool = False,
    precision: str = "float16",
    purge_deleted: bool = False,
    config_file_path: Path = None,
):
    '''
    Entry point for installing/deleting starter models, or installing external models.

    The steps run in this order:

    1. every name in ``remove_models`` is deleted through the model manager
       (files too when ``purge_deleted`` is set) and the config is committed;
    2. every name in ``install_initial_models`` is downloaded and the config
       file is rewritten via ``update_config_file``;
    3. every entry of ``external_models``, plus ``scan_directory`` when given,
       is passed to ``ModelManager.heuristic_import``;
    4. when ``scan_at_startup`` is set and ``scan_directory`` is a directory,
       an ``--autoimport``/``--autoconvert`` line for it is written to the
       init file, replacing any earlier line for the same option.

    ``config_file_path`` defaults to ``default_config_file()``; an empty file
    is created there when none exists.  The caller's ``external_models`` list
    is never modified.
    '''
    config_file_path = config_file_path or default_config_file()
    if not config_file_path.exists():
        # create an empty config file; the with-block closes the handle at once
        with open(config_file_path, 'w'):
            pass

    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    if remove_models and len(remove_models) > 0:
        print("== DELETING UNCHECKED STARTER MODELS ==")
        for model in remove_models:
            print(f'{model}...')
            model_manager.del_model(model, delete_files=purge_deleted)
        model_manager.commit(config_file_path)

    if install_initial_models and len(install_initial_models) > 0:
        print("== INSTALLING SELECTED STARTER MODELS ==")
        successfully_downloaded = download_weight_datasets(
            models=install_initial_models,
            access_token=None,
            precision=precision,
        )  # FIX: for historical reasons, we don't use model manager here
        update_config_file(successfully_downloaded, config_file_path)
        # download_weight_datasets() records an entry for every requested
        # model, with a falsy path when the download failed, so comparing
        # lengths alone never detects a failure.
        failed = [name for name, path in successfully_downloaded.items() if not path]
        if failed or len(successfully_downloaded) < len(install_initial_models):
            print("** Some of the model downloads were not successful")

    # due to above, we have to reload the model manager because conf file
    # was changed behind its back
    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    # work on a copy so that the caller's list is not modified
    external_models = list(external_models or [])
    if scan_directory:
        external_models.append(str(scan_directory))

    if len(external_models) > 0:
        print("== INSTALLING EXTERNAL MODELS ==")
        for path_url_or_repo in external_models:
            try:
                model_manager.heuristic_import(
                    path_url_or_repo,
                    convert=convert_to_diffusers,
                    commit_to_conf=config_file_path
                )
            except KeyboardInterrupt:
                sys.exit(-1)
            except Exception as e:
                # best effort: report the failure and continue with the next model
                print(f"** Unable to import {path_url_or_repo}: {str(e)}")

    # scan_at_startup without a scan directory has nothing to register
    if scan_at_startup and scan_directory and Path(scan_directory).is_dir():
        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
        initfile = Path(Globals.root, Globals.initfile)
        replacement = Path(Globals.root, f'{Globals.initfile}.new')
        directory = str(scan_directory).replace('\\', '/')
        with open(initfile, 'r') as infile, open(replacement, 'w') as outfile:
            for line in infile:
                if line.startswith(argument):
                    continue
                # Terminate every line, so that the option appended below can
                # never be glued onto an unterminated last line.
                outfile.write(line if line.endswith('\n') else line + '\n')
            outfile.write(f'{argument} {directory}\n')
        os.replace(replacement, initfile)
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
    """Ask a yes/no question on the console and return the answer as a bool.

    The default answer is shown in brackets after *prompt* and is used when
    the reply is empty or consists only of whitespace.  Only the first
    non-blank character of the reply is examined: with ``default_yes`` set,
    anything not starting with ``n``/``N`` counts as yes; otherwise only
    replies starting with ``y``/``Y`` count as yes.
    """
    default = "y" if default_yes else "n"
    # strip so that leading blanks are not mistaken for the answer itself
    response = input(f"{prompt} [{default}] ").strip() or default
    if default_yes:
        return response[0] not in ("n", "N")
    else:
        return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
    """Pick the root directory to use.

    Precedence: a non-empty *root* argument, then a non-empty
    ``INVOKEAI_ROOT`` environment variable, then ``Globals.root``.
    """
    if root:
        return root
    env_root = os.environ.get("INVOKEAI_ROOT")
    if env_root:
        return env_root
    return Globals.root
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def recommended_datasets() -> dict:
    """Map the name of every starter model flagged ``recommended`` to ``True``."""
    models = initial_models()
    return {
        name: True
        for name in models.keys()
        if models[name].get("recommended", False)
    }
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def default_dataset() -> dict:
    """Map the name of every starter model flagged ``default`` to ``True``."""
    models = initial_models()
    return {
        name: True
        for name in models.keys()
        if models[name].get("default", False)
    }
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def all_datasets() -> dict:
    """Map the name of every starter model to ``True``."""
    return {name: True for name in initial_models().keys()}
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
    """Rename a legacy ``model.ckpt`` to the file name used for stable-diffusion-1.4.

    Does nothing unless ``model.ckpt`` exists in the weights directory under
    ``Globals.root``.  The new name is the ``file`` entry of the
    ``stable-diffusion-1.4`` starter model.
    """
    model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
    legacy_ckpt = os.path.join(model_path, "model.ckpt")
    if not os.path.exists(legacy_ckpt):
        return
    new_name = initial_models()["stable-diffusion-1.4"]["file"]
    # f-string, so that the new file name is actually interpolated
    print(
        f'The Stable Diffusion v1.4 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
    )
    print(f"model.ckpt => {new_name}")
    os.replace(legacy_ckpt, os.path.join(model_path, new_name))
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_weight_datasets(
    models: List[str], access_token: str, precision: str = "float32"
):
    """Download each named starter model and return ``{name: download result}``.

    ``migrate_models_ckpt()`` runs first.  Every requested name gets an entry
    in the returned dict, whatever ``_download_repo_or_file`` returned for it.
    """
    migrate_models_ckpt()
    results = dict()
    for name in models:
        print(f"Downloading {name}:")
        model_config = initial_models()[name]
        results[name] = _download_repo_or_file(
            model_config, access_token, precision=precision
        )
    return results
|
||||
|
||||
|
||||
def _download_repo_or_file(
    mconfig: DictConfig, access_token: str, precision: str = "float32"
) -> Path:
    """Download one starter model and return what the chosen downloader returned.

    ``ckpt``-format entries go through ``_download_ckpt_weights``; every other
    format goes through ``_download_diffusion_weights``, followed by a second
    call for the entry's ``vae`` when that has a ``repo_id``.  The result of
    the VAE download is not returned.
    """
    if mconfig["format"] == "ckpt":
        return _download_ckpt_weights(mconfig, access_token)

    downloaded = _download_diffusion_weights(mconfig, access_token, precision=precision)
    has_vae_repo = "vae" in mconfig and "repo_id" in mconfig["vae"]
    if has_vae_repo:
        _download_diffusion_weights(
            mconfig["vae"], access_token, precision=precision
        )
    return downloaded
|
||||
|
||||
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
    """Download the checkpoint file named by *mconfig* into the weights directory.

    The target directory is ``Globals.root / Model_dir / Weights_dir``; the
    transfer itself is done by ``hf_download_with_resume``, whose result is
    returned.
    """
    repo = mconfig["repo_id"]
    weights_file = mconfig["file"]
    target_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
    return hf_download_with_resume(
        repo_id=repo,
        model_dir=target_dir,
        model_name=weights_file,
        access_token=access_token,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
    model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
):
    """Fetch *model_name* with ``model_class.from_pretrained`` into the global cache.

    Extra keyword arguments are forwarded to ``from_pretrained``.  Returns the
    cache directory joined with ``models--<name parts joined by "--">`` when
    ``from_pretrained`` returned a truthy object, otherwise ``None``.
    """
    cache_root = global_cache_dir(cache_subdir)
    loaded = model_class.from_pretrained(
        model_name,
        cache_dir=cache_root,
        resume_download=True,
        **kwargs,
    )
    if not loaded:
        return None
    folder_name = "--".join(["models"] + model_name.split("/"))
    return cache_root / folder_name
|
||||
|
||||
|
||||
def _download_diffusion_weights(
    mconfig: DictConfig, access_token: str, precision: str = "float32"
):
    """Download a diffusers pipeline or VAE described by *mconfig*.

    Entries whose ``format`` is ``diffusers`` are loaded with
    ``StableDiffusionGeneratorPipeline``; anything else with ``AutoencoderKL``.
    For ``float16`` precision the ``fp16`` revision is tried first, then the
    default revision.  Returns the path reported by ``download_from_hf`` for
    the first attempt that succeeds, or ``None`` when every attempt fails.

    *access_token* is accepted but not used by this function.
    """
    repo_id = mconfig["repo_id"]
    model_class = (
        StableDiffusionGeneratorPipeline
        if mconfig.get("format", None) == "diffusers"
        else AutoencoderKL
    )
    extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
    path = None
    for extra_args in extra_arg_list:
        try:
            path = download_from_hf(
                model_class,
                repo_id,
                cache_subdir="diffusers",
                safety_checker=None,
                **extra_args,
            )
        except OSError as e:
            # A missing fp16 revision is expected for some repos: stay quiet
            # and let the loop fall back to the default revision.
            if not str(e).startswith("fp16 is not a valid"):
                print(f"An unexpected error occurred while downloading the model: {e}")
        if path:
            break
    return path
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
    repo_id: str, model_dir: str, model_name: str, access_token: str = None
) -> Path:
    """Download *model_name* from the HuggingFace repo *repo_id* into *model_dir*.

    When the destination file already exists, a ``Range`` request is sent so
    that only the missing tail is fetched and appended.

    Returns the destination path on success, or when the server answers 416
    (nothing left to fetch).  Returns ``None`` when the server reports an
    error, when a fresh download is implausibly small, or when the transfer
    raises.
    """
    model_dest = Path(os.path.join(model_dir, model_name))
    os.makedirs(model_dir, exist_ok=True)

    url = hf_hub_url(repo_id, model_name)

    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    open_mode = "wb"
    exist_size = 0

    if os.path.exists(model_dest):
        exist_size = os.path.getsize(model_dest)
        header["Range"] = f"bytes={exist_size}-"
        open_mode = "ab"

    resp = requests.get(url, headers=header, stream=True)
    total = int(resp.headers.get("content-length", 0))
    status = resp.status_code

    if status == 416:  # "range not satisfiable", which means nothing to return
        print(f"* {model_name}: complete file found. Skipping.")
        return model_dest
    if status not in (200, 206):
        # 206 (Partial Content) is the normal answer to a Range request and is
        # not an error.  Anything else is: stop here so that an error page is
        # never written into the model file.
        print(f"** An error occurred during downloading {model_name}: {resp.reason}")
        return None
    if status == 200 and exist_size > 0:
        # The server ignored the Range header and is sending the whole file:
        # start over instead of appending a full copy to the partial one.
        exist_size = 0
        open_mode = "wb"

    if exist_size > 0:
        print(f"* {model_name}: partial file found. Resuming...")
    else:
        print(f"* {model_name}: Downloading...")

    try:
        # Size sanity check for fresh downloads only: the remainder of a
        # resumed download may legitimately be short.
        if exist_size == 0 and total < 2000:
            print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
            return None

        with open(model_dest, open_mode) as file, tqdm(
            desc=model_name,
            initial=exist_size,
            total=total + exist_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1000,
        ) as bar:
            for data in resp.iter_content(chunk_size=1024):
                size = file.write(data)
                bar.update(size)
    except Exception as e:
        print(f"An error occurred while downloading {model_name}: {str(e)}")
        return None
    return model_dest
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def update_config_file(successfully_downloaded: dict, config_file: Path):
    """Rewrite the models config file to include the downloaded starter models.

    *config_file* defaults to ``default_config_file()`` when ``None``.  An
    existing file is first renamed to ``<stem>.yaml.orig``; when writing the
    new file fails, the error is printed and that backup is moved back.
    """
    config_file = (
        Path(config_file) if config_file is not None else default_config_file()
    )

    # In some cases (incomplete setup, etc), the default configs directory might be missing.
    # Create it if it doesn't exist.
    # this check is ignored if opt.config_file is specified - user is assumed to know what they
    # are doing if they are passing a custom config file from elsewhere.
    #
    # The paths are compared with ``==``: two separately built Path objects
    # are never the same object, so an ``is`` test could not succeed.
    if config_file == default_config_file() and not config_file.parent.exists():
        configs_src = Dataset_path.parent
        configs_dest = default_config_file().parent
        shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)

    yaml = new_config_file_contents(successfully_downloaded, config_file)

    backup = None
    try:
        if os.path.exists(config_file):
            print(
                f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
            )
            backup = config_file.with_suffix(".yaml.orig")
            ## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
            if sys.platform == "win32" and backup.is_file():
                backup.unlink()
            config_file.rename(backup)

        # encode everything before the target is opened, so that an encoding
        # problem cannot leave a truncated config file behind
        payload = Config_preamble.encode() + yaml.encode()
        with open(str(config_file.expanduser().resolve()), "wb") as new_config:
            new_config.write(payload)

    except Exception as e:
        print(f"**Error creating config file {config_file}: {str(e)} **")
        if backup is not None:
            print("restoring previous config file")
            ## workaround, for WinError 183, see above
            if sys.platform == "win32" and config_file.is_file():
                config_file.unlink()
            backup.rename(config_file)
        return

    print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def new_config_file_contents(
    successfully_downloaded: dict, config_file: Path,
) -> str:
    """Return the YAML text of the models config with the downloaded models merged in.

    Starts from the contents of *config_file* when it exists, otherwise from
    an empty config.  One stanza is written per entry of
    *successfully_downloaded* (model name -> download path), built from that
    model's definition in ``initial_models()``.

    Checkpoint models (those with a ``file`` entry) whose download path is
    missing are left out with a message, because their ``weights`` entry
    cannot be computed.  When none of the written models is flagged as
    default, the first written model becomes the default; an empty input
    leaves the config unchanged.
    """
    if config_file.exists():
        conf = OmegaConf.load(str(config_file.expanduser().resolve()))
    else:
        conf = OmegaConf.create()

    default_selected = None
    written = []
    for model, downloaded_path in successfully_downloaded.items():
        mod = initial_models()[model]

        if "file" in mod and not downloaded_path:
            print(f"** {model}: download did not complete; not adding it to the configuration file")
            continue

        # a bit hacky - what we are doing here is seeing whether a checkpoint
        # version of the model was previously defined, and whether the current
        # model is a diffusers (indicated with a path)
        if conf.get(model) and downloaded_path and Path(downloaded_path).is_dir():
            delete_weights(model, conf[model])

        stanza = {}
        stanza["description"] = mod["description"]
        stanza["repo_id"] = mod["repo_id"]
        stanza["format"] = mod["format"]
        # diffusers don't need width and height (probably .ckpt doesn't either)
        # so we no longer require these in INITIAL_MODELS.yaml
        if "width" in mod:
            stanza["width"] = mod["width"]
        if "height" in mod:
            stanza["height"] = mod["height"]
        if "file" in mod:
            stanza["weights"] = os.path.relpath(downloaded_path, start=Globals.root)
            stanza["config"] = os.path.normpath(os.path.join(sd_configs(), mod["config"]))
        if "vae" in mod:
            if "file" in mod["vae"]:
                stanza["vae"] = os.path.normpath(
                    os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
                )
            else:
                stanza["vae"] = mod["vae"]
        if mod.get("default", False):
            stanza["default"] = True
            default_selected = True

        conf[model] = stanza
        written.append(model)

    # if no default model was chosen, then we select the first
    # one in the list
    if not default_selected and written:
        conf[written[0]]["default"] = True

    return OmegaConf.to_yaml(conf)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def delete_weights(model_name: str, conf_stanza: dict):
    """Delete the checkpoint file referenced by a superseded config stanza.

    Nothing happens when the stanza has no ``weights`` entry, or when its
    ``config`` path contains ``VAE``.  A relative weights path is resolved
    against ``Globals.root``.  A failure to delete is printed, not raised.
    """
    if not (weights := conf_stanza.get("weights")):
        return
    # NOTE(review): the previous test was re.match("/VAE/", config), which only
    # matches a string that *begins* with "/VAE/" and raises when there is no
    # config entry.  It is assumed here that the intent was "the config path
    # mentions VAE" -- confirm against the config layout.
    if re.search("VAE", conf_stanza.get("config") or ""):
        return

    print(
        f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
    )

    weights = Path(weights)
    if not weights.is_absolute():
        weights = Path(Globals.root) / weights
    try:
        weights.unlink()
    except OSError as e:
        print(str(e))
|
1349
invokeai/backend/generate.py
Normal file
@ -23,7 +23,7 @@ from tqdm import trange
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
from ..stable_diffusion.diffusion.ddpm import DiffusionWrapper
|
||||
from ..util import rand_perlin_2d
|
||||
from ..util.util import rand_perlin_2d
|
||||
|
||||
downsampling = 8
|
||||
CAUTION_IMG = 'caution.png'
|
||||
|
115
invokeai/backend/globals.py
Normal file
@ -0,0 +1,115 @@
|
||||
'''
|
||||
invokeai.backend.globals defines a small number of global variables that would
|
||||
otherwise have to be passed through long and complex call chains.
|
||||
|
||||
It defines a Namespace object named "Globals" that contains
|
||||
the attributes:
|
||||
|
||||
- root - the root directory under which "models" and "outputs" can be found
|
||||
- initfile - path to the initialization file
|
||||
- try_patchmatch - option to globally disable loading of 'patchmatch' module
|
||||
- always_use_cpu - force use of CPU even if GPU is available
|
||||
'''
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
Globals = Namespace()
|
||||
|
||||
# Where to look for the initialization file and other key components
|
||||
Globals.initfile = 'invokeai.init'
|
||||
Globals.models_file = 'models.yaml'
|
||||
Globals.models_dir = 'models'
|
||||
Globals.config_dir = 'configs'
|
||||
Globals.autoscan_dir = 'weights'
|
||||
Globals.converted_ckpts_dir = 'converted_ckpts'
|
||||
|
||||
# Set the default root directory. This can be overwritten by explicitly
|
||||
# passing the `--root <directory>` argument on the command line.
|
||||
# logic is:
|
||||
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
|
||||
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
|
||||
# 3) use ~/invokeai
|
||||
|
||||
if os.environ.get('INVOKEAI_ROOT'):
|
||||
Globals.root = osp.abspath(os.environ.get('INVOKEAI_ROOT'))
|
||||
elif os.environ.get('VIRTUAL_ENV') and Path(os.environ.get('VIRTUAL_ENV'),'..',Globals.initfile).exists():
|
||||
Globals.root = osp.abspath(osp.join(os.environ.get('VIRTUAL_ENV'), '..'))
|
||||
else:
|
||||
Globals.root = osp.abspath(osp.expanduser('~/invokeai'))
|
||||
|
||||
# Try loading patchmatch
|
||||
Globals.try_patchmatch = True
|
||||
|
||||
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
|
||||
Globals.always_use_cpu = False
|
||||
|
||||
# Whether the internet is reachable for dynamic downloads
|
||||
# The CLI will test connectivity at startup time.
|
||||
Globals.internet_available = True
|
||||
|
||||
# Whether to disable xformers
|
||||
Globals.disable_xformers = False
|
||||
|
||||
# Low-memory tradeoff for guidance calculations.
|
||||
Globals.sequential_guidance = False
|
||||
|
||||
# whether we are forcing full precision
|
||||
Globals.full_precision = False
|
||||
|
||||
# whether we should convert ckpt files into diffusers models on the fly
|
||||
Globals.ckpt_convert = True
|
||||
|
||||
# logging tokenization everywhere
|
||||
Globals.log_tokenization = False
|
||||
|
||||
def global_config_file()->Path:
    '''Return the path of the models file inside the config directory under the root.'''
    return Path(Globals.root) / Globals.config_dir / Globals.models_file
|
||||
|
||||
def global_config_dir()->Path:
    '''Return the config directory under the root.'''
    return Path(Globals.root) / Globals.config_dir
|
||||
|
||||
def global_models_dir()->Path:
    '''Return the models directory under the root.'''
    return Path(Globals.root) / Globals.models_dir
|
||||
|
||||
def global_autoscan_dir()->Path:
    '''Return the autoscan directory under the root.'''
    return Path(Globals.root) / Globals.autoscan_dir
|
||||
|
||||
def global_converted_ckpts_dir()->Path:
    '''Return the converted-checkpoints directory inside the models directory.'''
    models_dir = global_models_dir()
    return Path(models_dir, Globals.converted_ckpts_dir)
|
||||
|
||||
def global_set_root(root_dir:Union[str,Path]):
    '''Store *root_dir*, unchanged, as the global root directory.'''
    Globals.root = root_dir
|
||||
|
||||
def global_cache_dir(subdir:Union[str,Path]='')->Path:
    '''
    Returns Path to the model cache directory. If a subdirectory
    is provided, it will be appended to the end of the path, allowing
    for huggingface-style conventions:
         global_cache_dir('diffusers')
         global_cache_dir('hub')
    Current HuggingFace documentation (mid-Jan 2023) indicates that
    transformers models will be cached into a "transformers" subdirectory,
    but in practice they seem to go into "hub". But if needed:
         global_cache_dir('transformers')
    One other caveat is that HuggingFace is moving some diffusers models
    into the "hub" subdirectory as well, so this will need to be revisited
    from time to time.

    The base directory is chosen in this order: $HF_HOME, then
    $XDG_CACHE_HOME/huggingface, then the "models" directory under the root.
    '''
    hf_home = os.getenv('HF_HOME')
    if hf_home is not None:
        return Path(hf_home, subdir)

    xdg_cache = os.getenv('XDG_CACHE_HOME')
    if xdg_cache is not None:
        # $XDG_CACHE_HOME/huggingface is the default location mentioned in HuggingFace Hub Client Library.
        # See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
        return Path(xdg_cache + os.sep + 'huggingface', subdir)

    return Path(Globals.root, 'models', subdir)
|
@ -9,6 +9,7 @@ from .pngwriter import (PngWriter,
|
||||
retrieve_metadata,
|
||||
write_metadata,
|
||||
)
|
||||
from .seamless import configure_model_padding
|
||||
|
||||
def debug_image(
|
||||
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
|
||||
|
@ -4,7 +4,7 @@ wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
'''
|
||||
from ldm.invoke.globals import Globals
|
||||
from invokeai.backend.globals import Globals
|
||||
import numpy as np
|
||||
|
||||
class PatchMatch:
|
||||
|
31
invokeai/backend/image_util/seamless.py
Normal file
@ -0,0 +1,31 @@
|
||||
import torch.nn as nn
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
|
||||
working = nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
|
||||
return nn.functional.conv2d(working, weight, bias, self.stride, nn.modules.utils._pair(0), self.dilation, self.groups)
|
||||
|
||||
def configure_model_padding(model, seamless, seamless_axes):
    """
    Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.

    With `seamless` set, every Conv2d/ConvTranspose2d module gets per-axis
    padding attributes ('circular' for axes listed in `seamless_axes`,
    'constant' otherwise) and the asymmetric forward patch.  Without it, the
    stock Conv2d._conv_forward is bound again and those attributes are removed.
    """
    # TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
    for layer in model.modules():
        if not isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
            continue

        if not seamless:
            layer._conv_forward = nn.Conv2d._conv_forward.__get__(layer, nn.Conv2d)
            if hasattr(layer, 'asymmetric_padding_mode'):
                del layer.asymmetric_padding_mode
            if hasattr(layer, 'asymmetric_padding'):
                del layer.asymmetric_padding
            continue

        layer.asymmetric_padding_mode = {}
        layer.asymmetric_padding = {}
        layer.asymmetric_padding_mode['x'] = 'circular' if ('x' in seamless_axes) else 'constant'
        layer.asymmetric_padding['x'] = (
            layer._reversed_padding_repeated_twice[0],
            layer._reversed_padding_repeated_twice[1],
            0,
            0,
        )
        layer.asymmetric_padding_mode['y'] = 'circular' if ('y' in seamless_axes) else 'constant'
        layer.asymmetric_padding['y'] = (
            0,
            0,
            layer._reversed_padding_repeated_twice[2],
            layer._reversed_padding_repeated_twice[3],
        )
        layer._conv_forward = _conv_forward_asymmetric.__get__(layer, nn.Conv2d)
|
@ -32,7 +32,7 @@ import numpy as np
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision import transforms
|
||||
from ldm.invoke.globals import global_cache_dir
|
||||
from invokeai.backend.globals import global_cache_dir
|
||||
|
||||
CLIPSEG_MODEL = 'CIDAS/clipseg-rd64-refined'
|
||||
CLIPSEG_SIZE = 352
|
||||
|
8
invokeai/backend/model_management/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
'''
|
||||
Initialization file for invokeai.backend.model_management
|
||||
'''
|
||||
from .model_manager import ModelManager
|
||||
from .convert_ckpt_to_diffusers import (load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
convert_ckpt_to_diffusers)
|
||||
from ...frontend.merge.merge_diffusers import (merge_diffusion_models,
|
||||
merge_diffusion_models_and_commit)
|
1035
invokeai/backend/model_management/convert_ckpt_to_diffusers.py
Normal file
@ -31,14 +31,13 @@ from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from .devices import CPU_DEVICE
|
||||
from ldm.invoke.globals import Globals, global_cache_dir
|
||||
from .util import (
|
||||
from ..util import CPU_DEVICE
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
from ..util import (
|
||||
ask_user,
|
||||
download_with_resume,
|
||||
url_attachment_name,
|
||||
)
|
||||
from .stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = 1
|
||||
@ -416,6 +415,51 @@ class ModelManager(object):
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def _load_ckpt_model(self, model_name, mconfig):
    """Convert a legacy checkpoint model to a diffusers pipeline and load it.

    Relative ``config`` and ``weights`` paths from *mconfig* are resolved
    against ``Globals.root``.  The current model is offloaded first; a VAE
    chosen by ``_choose_diffusers_vae`` replaces the ``vae`` entry of
    *mconfig* when one is found.

    Returns ``(pipeline, width, height, "NOHASH")``.
    """
    config = mconfig.config
    weights = mconfig.weights
    vae = mconfig.get("vae")
    width = mconfig.width
    height = mconfig.height

    root = Globals.root
    if not os.path.isabs(config):
        config = os.path.join(root, config)
    if not os.path.isabs(weights):
        weights = os.path.normpath(os.path.join(root, weights))

    # Convert to diffusers and return a diffusers pipeline
    print(
        f">> Converting legacy checkpoint {model_name} into a diffusers model..."
    )

    from . import load_pipeline_from_original_stable_diffusion_ckpt

    self.offload_model(self.current_model)
    vae_config = self._choose_diffusers_vae(model_name)
    if vae_config:
        vae = self._load_vae(vae_config)
    if self._has_cuda():
        torch.cuda.empty_cache()

    torch_precision = torch.float16 if self.precision == "float16" else torch.float32
    pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoint_path=weights,
        original_config_file=config,
        vae=vae,
        return_generator_pipeline=True,
        precision=torch_precision,
    )

    if self.sequential_offload:
        pipeline.enable_offload_submodels(self.device)
    else:
        pipeline.to(self.device)

    return (pipeline, width, height, "NOHASH")
|
||||
|
||||
def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path:
|
||||
if isinstance(model_name, DictConfig) or isinstance(model_name, dict):
|
||||
mconfig = model_name
|
||||
@ -519,66 +563,6 @@ class ModelManager(object):
|
||||
self.commit(commit_to_conf)
|
||||
return model_name
|
||||
|
||||
def import_ckpt_model(
    self,
    weights: Union[str, Path],
    config: Union[str, Path] = "configs/stable-diffusion/v1-inference.yaml",
    vae: Union[str, Path, None] = None,
    model_name: Union[str, None] = None,
    model_description: Union[str, None] = None,
    commit_to_conf: Union[Path, None] = None,
) -> Union[str, None]:
    """
    Attempts to install the indicated ckpt file and returns the name under
    which it was registered.

    "weights" can be either a path-like object corresponding to a local .ckpt file
    or a http/https URL pointing to a remote model.

    "vae" is a Path or str object pointing to a ckpt or safetensors file to be used
    as the VAE for this model. It is stored in the configuration as a string,
    like the weights and config paths.

    "config" is the model config file to use with this ckpt file. It defaults to
    v1-inference.yaml. If a URL is provided, the config will be downloaded.

    You can optionally provide a model name and/or description. If not provided,
    then these will be derived from the weight file name. If you provide a commit_to_conf
    path to the configuration file, then the new entry will be committed to the
    models.yaml file.

    Return value is the name of the imported model, or None if the weights
    or config file could not be resolved.
    """
    if str(weights).startswith(("http:", "https:")):
        model_name = model_name or url_attachment_name(weights)

    weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
    config_path = self._resolve_path(config, "configs/stable-diffusion")

    if weights_path is None or not weights_path.exists():
        return None
    if config_path is None or not config_path.exists():
        return None

    # note this gives ugly pathnames if used on a URL without a
    # Content-Disposition header
    model_name = model_name or Path(weights).stem
    model_description = (
        model_description or f"Imported stable diffusion weights file {model_name}"
    )
    new_config = dict(
        weights=str(weights_path),
        config=str(config_path),
        description=model_description,
        format="ckpt",
        width=512,
        height=512,
    )
    if vae:
        # Store a plain string, consistent with the weights/config entries.
        new_config["vae"] = str(vae)
    self.add_model(model_name, new_config, True)
    if commit_to_conf:
        self.commit(commit_to_conf)
    return model_name
|
||||
|
||||
@classmethod
|
||||
def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
|
||||
"""
|
||||
@ -746,36 +730,18 @@ class ModelManager(object):
|
||||
)
|
||||
return
|
||||
|
||||
if convert:
|
||||
diffuser_path = Path(
|
||||
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
|
||||
)
|
||||
model_name = self.convert_and_import(
|
||||
model_path,
|
||||
diffusers_path=diffuser_path,
|
||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||
model_name=model_name,
|
||||
model_description=description,
|
||||
original_config_file=model_config_file,
|
||||
commit_to_conf=commit_to_conf,
|
||||
)
|
||||
else:
|
||||
model_name = self.import_ckpt_model(
|
||||
model_path,
|
||||
config=model_config_file,
|
||||
model_name=model_name,
|
||||
model_description=description,
|
||||
vae=str(
|
||||
Path(
|
||||
Globals.root,
|
||||
"models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
|
||||
)
|
||||
),
|
||||
commit_to_conf=commit_to_conf,
|
||||
)
|
||||
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
diffuser_path = Path(
|
||||
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
|
||||
)
|
||||
model_name = self.convert_and_import(
|
||||
model_path,
|
||||
diffusers_path=diffuser_path,
|
||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||
model_name=model_name,
|
||||
model_description=description,
|
||||
original_config_file=model_config_file,
|
||||
commit_to_conf=commit_to_conf,
|
||||
)
|
||||
return model_name
|
||||
|
||||
def convert_and_import(
|
||||
@ -800,7 +766,7 @@ class ModelManager(object):
|
||||
|
||||
new_config = None
|
||||
|
||||
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
|
||||
from . import convert_ckpt_to_diffusers
|
||||
|
||||
if diffusers_path.exists():
|
||||
print(
|
||||
@ -815,7 +781,7 @@ class ModelManager(object):
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model = self._load_vae(vae) if vae else None
|
||||
convert_ckpt_to_diffuser(
|
||||
convert_ckpt_to_diffusers (
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
extract_ema=True,
|
@ -13,9 +13,9 @@ from transformers import CLIPTokenizer, CLIPTextModel
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
|
||||
from ..devices import torch_dtype
|
||||
from ..util import torch_dtype
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.invoke.globals import Globals
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
def get_tokenizer(model) -> CLIPTokenizer:
|
||||
# TODO remove legacy ckpt fallback handling
|
||||
|
4
invokeai/backend/restoration/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for the invokeai.backend.restoration package
|
||||
'''
|
||||
from .base import Restoration
|
38
invokeai/backend/restoration/base.py
Normal file
@ -0,0 +1,38 @@
|
||||
class Restoration():
    """Factory for the face-restoration and upscaling helper objects."""

    def __init__(self) -> None:
        pass

    def load_face_restore_models(self, gfpgan_model_path='./models/gfpgan/GFPGANv1.4.pth'):
        """
        Load GFPGAN and CodeFormer and return them as (gfpgan, codeformer).

        A restorer whose model file is reported missing is replaced by None.
        """
        # Load GFPGAN
        gfpgan = self.load_gfpgan(gfpgan_model_path)
        if not gfpgan.gfpgan_model_exists:
            print('>> GFPGAN Disabled')
            gfpgan = None
        else:
            print('>> GFPGAN Initialized')

        # Load CodeFormer
        codeformer = self.load_codeformer()
        if not codeformer.codeformer_model_exists:
            print('>> CodeFormer Disabled')
            codeformer = None
        else:
            print('>> CodeFormer Initialized')

        return gfpgan, codeformer

    # Face Restore Models
    def load_gfpgan(self, gfpgan_model_path):
        """Return a GFPGAN wrapper for the given model path."""
        from .gfpgan import GFPGAN

        restorer = GFPGAN(gfpgan_model_path)
        return restorer

    def load_codeformer(self):
        """Return a CodeFormerRestoration wrapper built with its defaults."""
        from .codeformer import CodeFormerRestoration

        restorer = CodeFormerRestoration()
        return restorer

    # Upscale Models
    def load_esrgan(self, esrgan_bg_tile=400):
        """Return an ESRGAN wrapper using the given background tile size."""
        from .realesrgan import ESRGAN

        upscaler = ESRGAN(esrgan_bg_tile)
        print('>> ESRGAN Initialized')
        return upscaler
|
108
invokeai/backend/restoration/codeformer.py
Normal file
@ -0,0 +1,108 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
import sys
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
# Download location of the CodeFormer weights; CodeFormerRestoration.process()
# passes it to basicsr's load_file_from_url().
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
|
||||
class CodeFormerRestoration():
    """Face restoration using the CodeFormer model."""

    def __init__(self,
            codeformer_dir='models/codeformer',
            codeformer_model_path='codeformer.pth') -> None:
        """
        Locate the CodeFormer weights.

        A relative `codeformer_dir` is resolved against Globals.root.
        `codeformer_model_exists` records whether the weights file is present.
        """
        if not os.path.isabs(codeformer_dir):
            codeformer_dir = os.path.join(Globals.root, codeformer_dir)

        self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
        self.codeformer_model_exists = os.path.isfile(self.model_path)

        if not self.codeformer_model_exists:
            print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)

        # Avoid growing sys.path with a duplicate entry on every instantiation.
        search_dir = os.path.abspath(codeformer_dir)
        if search_dir not in sys.path:
            sys.path.append(search_dir)

    def process(self, image, strength, device, seed=None, fidelity=0.75):
        """
        Restore the faces found in a PIL `image` and return a new PIL image.

        `strength` below 1.0 blends the restored result with the original;
        `fidelity` is forwarded to the network as its `w` argument; `seed`
        is only used for the progress message.
        """
        if seed is not None:
            print(f'>> CodeFormer - Restoring Faces for image seed:{seed}')
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)

            from basicsr.utils.download_util import load_file_from_url
            from basicsr.utils import img2tensor, tensor2img
            from facexlib.utils.face_restoration_helper import FaceRestoreHelper
            # The architecture lives next to this module in the same package.
            from .codeformer_arch import CodeFormer
            from torchvision.transforms.functional import normalize
            from PIL import Image

            cf = CodeFormer(
                dim_embd=512,
                codebook_size=1024,
                n_head=8,
                n_layers=9,
                connect_list=['32', '64', '128', '256']
            ).to(device)

            # note that this file should already be downloaded and cached at
            # this point
            checkpoint_path = load_file_from_url(url=pretrained_model_url,
                model_dir=os.path.abspath(os.path.dirname(self.model_path)),
                progress=True
            )
            # Deserialize onto the CPU so a checkpoint saved on a GPU still
            # loads on machines without CUDA; load_state_dict() copies the
            # tensors into the model on `device`.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')['params_ema']
            cf.load_state_dict(checkpoint)
            cf.eval()

            image = image.convert('RGB')
            # Codeformer expects a BGR np array; make array and flip channels
            bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]

            face_helper = FaceRestoreHelper(
                upscale_factor=1,
                use_parse=True,
                device=device,
                model_rootpath=os.path.join(Globals.root,'models','gfpgan','weights'),
            )
            face_helper.clean_all()
            face_helper.read_image(bgr_image_array)
            face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
            face_helper.align_warp_face()

            for cropped_face in face_helper.cropped_faces:
                cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

                try:
                    with torch.no_grad():
                        output = cf(cropped_face_t, w=fidelity, adain=True)[0]
                        restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
                    del output
                    torch.cuda.empty_cache()
                except RuntimeError as error:
                    # Fall back to the unrestored crop for this face.
                    print(f'\tFailed inference for CodeFormer: {error}.')
                    restored_face = cropped_face

                restored_face = restored_face.astype('uint8')
                face_helper.add_restored_face(restored_face)

            face_helper.get_inverse_affine(None)

            restored_img = face_helper.paste_faces_to_input_image()

            # Flip the channels back to RGB
            res = Image.fromarray(restored_img[...,::-1])

            if strength < 1.0:
                # Resize the image to the new image if the sizes have changed.
                # Compare PIL sizes: ndarray.size is an element count, not (w, h).
                if res.size != image.size:
                    image = image.resize(res.size)
                res = Image.blend(image, res, strength)

            # Drop the model reference so its memory can be reclaimed.
            cf = None

        return res
|
275
invokeai/backend/restoration/codeformer_arch.py
Normal file
@ -0,0 +1,275 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List
|
||||
|
||||
from .vqgan_arch import *
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def calc_mean_std(feat, eps=1e-5):
    """Return per-channel mean and std, as used by adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        Tuple (mean, std); each is reshaped to (b, c, 1, 1).
    """
    dims = feat.size()
    assert len(dims) == 4, 'The input feature should be 4D tensor.'
    batch, channels = dims[:2]
    variance = feat.view(batch, channels, -1).var(dim=2) + eps
    std = variance.sqrt().view(batch, channels, 1, 1)
    mean = feat.view(batch, channels, -1).mean(dim=2).view(batch, channels, 1, 1)
    return mean, std
|
||||
|
||||
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Re-normalize the reference features so that their per-channel mean and
    std match those of the degraded features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.
    """
    shape = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    centered = content_feat - content_mean.expand(shape)
    whitened = centered / content_std.expand(shape)
    return whitened * style_std.expand(shape) + style_mean.expand(shape)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
    """
    Sine/cosine position embedding in the style of the "Attention is all
    you need" paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        # A custom scale only makes sense together with normalization.
        if normalize is False and scale is not None:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """Return the embedding for `x`, laid out as (batch, 2 * num_pos_feats, H, W)."""
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        valid = ~mask
        # Running counts of unmasked positions along the rows and columns.
        rows = valid.cumsum(1, dtype=torch.float32)
        cols = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            rows = rows / (rows[:, -1:, :] + eps) * self.scale
            cols = cols / (cols[:, :, -1:] + eps) * self.scale

        freq = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        freq = self.temperature ** (2 * (freq // 2) / self.num_pos_feats)

        col_pos = cols[:, :, :, None] / freq
        row_pos = rows[:, :, :, None] / freq
        # Interleave sin of the even features with cos of the odd features.
        col_pos = torch.stack(
            (col_pos[:, :, :, 0::2].sin(), col_pos[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        row_pos = torch.stack(
            (row_pos[:, :, :, 0::2].sin(), row_pos[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        combined = torch.cat((row_pos, col_pos), dim=3)
        return combined.permute(0, 3, 1, 2)
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransformerSALayer(nn.Module):
    """Self-attention layer followed by a two-layer MLP, each with a residual connection."""

    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        """
        Args:
            embed_dim: feature size used by the attention, norms and MLP output.
            nhead: number of heads passed to nn.MultiheadAttention.
            dim_mlp: hidden width of the MLP.
            dropout: dropout probability for the attention and all Dropout modules.
            activation: name resolved by _get_activation_fn.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return `tensor` unchanged when `pos` is None, otherwise `tensor + pos`."""
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """
        Apply the layer to `tgt` and return a tensor of the same shape.

        `tgt_mask` and `tgt_key_padding_mask` are forwarded to the attention
        module as `attn_mask` and `key_padding_mask`; `query_pos`, if given,
        is added to the queries and keys only (not to the values).
        """

        # self attention
        # LayerNorm is applied before the attention, not after it.
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt
|
||||
|
||||
class Fuse_sft_block(nn.Module):
    """Modulates decoder features with a scale and shift predicted from encoder and decoder features."""

    def __init__(self, in_ch, out_ch):
        """
        Args:
            in_ch: channel count of each input feature map; the ResBlock
                receives 2*in_ch channels because both inputs are concatenated.
            out_ch: channel count produced by the ResBlock and both conv stacks.
        """
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        # NOTE(review): both stacks take in_ch channels but are fed the
        # ResBlock output (out_ch channels), so this appears to assume
        # in_ch == out_ch -- confirm against callers.
        self.scale = nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                    nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                    nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        """
        Return `dec_feat + w * (dec_feat * scale + shift)`.

        `scale` and `shift` are computed from the channel-wise concatenation
        of `enc_feat` and `dec_feat`; `w` weights the residual.
        """
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """
    VQAutoEncoder extended with a transformer that predicts codebook indices
    from encoder features, plus optional encoder-to-generator feature fusion.
    """

    # NOTE(review): connect_list and fix_modules use mutable list defaults;
    # connect_list is stored on the instance, so the default list is shared
    # between instances that rely on it.
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                codebook_size=1024, latent_size=256,
                connect_list=['32', '64', '128', '256'],
                fix_modules=['quantize','generator']):
        """
        Args:
            dim_embd: transformer embedding size; the MLP width is 2 * dim_embd.
            n_head: attention heads per transformer layer.
            n_layers: number of TransformerSALayer blocks.
            codebook_size: size of the logits output and of the base codebook.
            latent_size: number of rows of the learned position embedding.
            connect_list: feature sizes (as strings) at which fusion blocks are built.
            fix_modules: attribute names whose parameters get requires_grad=False;
                None disables freezing.
        """
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                    for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        # Channel count of the feature map at each spatial size (keyed by size).
        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        """
        Initialize one module in place: Linear/Embedding weights are drawn from
        N(0, 0.02) with Linear biases zeroed; LayerNorm gets bias 0 and weight 1.

        Not called from within this class as shown here.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        """
        Run encoder, transformer, codebook lookup and generator.

        Args:
            x: input passed through self.encoder.blocks.
            w: fusion weight; encoder features are fused into the generator
                only when w > 0.
            detach_16: detach the looked-up codebook features before decoding.
            code_only: return right after the logits are computed.
            adain: apply adaptive_instance_normalization to the codebook
                features using the encoder output as reference.

        Returns:
            (logits, lq_feat) when code_only is true, otherwise
            (out, logits, lq_feat).
        """
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                # Keep a copy keyed by the feature map's last dimension.
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        # The lookup shape is hard-coded to [batch, 16, 16, 256].
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
|
87
invokeai/backend/restoration/gfpgan.py
Normal file
@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import warnings
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class GFPGAN():
    """Face restoration using the GFPGAN model."""

    def __init__(
            self,
            gfpgan_model_path='models/gfpgan/GFPGANv1.4.pth'
    ) -> None:
        """
        Locate the GFPGAN weights.

        A relative `gfpgan_model_path` is resolved against Globals.root.
        `gfpgan_model_exists` records whether the weights file is present.
        """
        if not os.path.isabs(gfpgan_model_path):
            gfpgan_model_path=os.path.abspath(os.path.join(Globals.root,gfpgan_model_path))
        self.model_path = gfpgan_model_path
        self.gfpgan_model_exists = os.path.isfile(self.model_path)
        # Set up front so process() can test it even if loading fails.
        self.gfpgan = None

        if not self.gfpgan_model_exists:
            print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)

    def model_exists(self):
        """Return True if the weights file currently exists on disk."""
        return os.path.isfile(self.model_path)

    def process(self, image, strength: float, seed: str = None):
        """
        Restore the faces found in a PIL `image` and return a PIL image.

        `strength` below 1.0 blends the restored result with the original;
        `seed` is only used for the progress message. If the GFPGAN model
        cannot be loaded, a warning is printed and `image` is returned
        unchanged.
        """
        if seed is not None:
            print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)
            # The model is loaded with <root>/models as working directory.
            cwd = os.getcwd()
            os.chdir(os.path.join(Globals.root,'models'))
            try:
                from gfpgan import GFPGANer
                self.gfpgan = GFPGANer(
                    model_path=self.model_path,
                    upscale=1,
                    arch='clean',
                    channel_multiplier=2,
                    bg_upsampler=None,
                )
            except Exception:
                import traceback
                print('>> Error loading GFPGAN:', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
            finally:
                # Always restore the caller's working directory.
                os.chdir(cwd)

        if self.gfpgan is None:
            print(
                '>> WARNING: GFPGAN not initialized.'
            )
            print(
                f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}'
            )
            # Nothing to restore with; hand back the input.
            return image

        image = image.convert('RGB')

        # GFPGAN expects a BGR np array; make array and flip channels
        bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]

        _, _, restored_img = self.gfpgan.enhance(
            bgr_image_array,
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
        )

        # Flip the channels back to RGB
        res = Image.fromarray(restored_img[...,::-1])

        if strength < 1.0:
            # Resize the image to the new image if the sizes have changed.
            # Compare PIL sizes: ndarray.size is an element count, not (w, h).
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Release the model between calls.
        self.gfpgan = None

        return res
|
108
invokeai/backend/restoration/outcrop.py
Normal file
@ -0,0 +1,108 @@
|
||||
import warnings
|
||||
import math
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
class Outcrop(object):
    """Extends an image beyond its borders and fills the new area via outpainting."""

    def __init__(
            self,
            image,
            generate,  # current generate object
    ):
        self.image = image
        self.generate = generate

    def process(
            self,
            extents: dict,
            opt,       # current options
            orig_opt,  # ones originally used to generate the image
            image_callback=None,
            prefix=None
    ):
        """
        Extend the image by `extents` and run prompt2image over the result.

        `extents` maps a direction ('top', 'bottom', 'left', 'right') to a
        pixel count. The sampler named in `opt` is used for the call, and the
        previously active sampler object is restored afterwards, also when
        generation raises. Returns whatever prompt2image returns.
        """
        # grow and mask the image
        extended_image = self._extend_all(extents)

        # switch samplers temporarily
        curr_sampler = self.generate.sampler
        self.generate.sampler_name = opt.sampler_name
        self.generate._set_sampler()

        def wrapped_callback(img, seed, **kwargs):
            # Prefer the original image's seed when it is a valid one.
            preferred_seed = orig_opt.seed if orig_opt.seed is not None and orig_opt.seed >= 0 else seed
            image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)

        try:
            result = self.generate.prompt2image(
                opt.prompt,
                seed=opt.seed or orig_opt.seed,
                sampler=self.generate.sampler,
                steps=opt.steps,
                cfg_scale=opt.cfg_scale,
                ddim_eta=self.generate.ddim_eta,
                width=extended_image.width,
                height=extended_image.height,
                init_img=extended_image,
                strength=0.90,
                image_callback=wrapped_callback if image_callback else None,
                seam_size=opt.seam_size or 96,
                seam_blur=opt.seam_blur or 16,
                seam_strength=opt.seam_strength or 0.7,
                seam_steps=20,
                tile_size=32,
                color_match=True,
                force_outpaint=True,  # this just stops the warning about erased regions
            )
        finally:
            # swap sampler back, even if generation failed
            self.generate.sampler = curr_sampler
        return result

    def _extend_all(
            self,
            extents: dict,
    ) -> "Image.Image":
        '''
        Extend the image in direction ('top','bottom','left','right') by
        the indicated value, rounded up to a multiple of 64. The canvas is
        grown and the new rectangular section is left black and fully
        transparent so that it can serve as the outpainting mask.
        '''
        image = self.image
        for direction in extents:
            assert direction in ['top', 'left', 'bottom', 'right'],'Direction must be one of "top", "left", "bottom", "right"'
            pixels = extents[direction]
            # round pixels up to the nearest 64
            pixels = math.ceil(pixels/64) * 64
            print(f'>> extending image {direction}ward by {pixels} pixels')
            image = self._rotate(image, direction)
            image = self._extend(image, pixels)
            image = self._rotate(image, direction, reverse=True)
        return image

    def _rotate(self, image: "Image.Image", direction: str, reverse=False) -> "Image.Image":
        '''
        Rotates image so that the area to extend is always at the top.
        Simplifies logic later. The reverse argument, if true, will undo the
        previous transpose.
        '''
        transposes = {
            'right':  ['ROTATE_90', 'ROTATE_270'],
            'bottom': ['ROTATE_180', 'ROTATE_180'],
            'left':   ['ROTATE_270', 'ROTATE_90']
        }
        if direction not in transposes:
            return image
        transpose = transposes[direction][1 if reverse else 0]
        # Look the enum member up by name.
        return image.transpose(Image.Transpose[transpose])

    def _extend(self, image: "Image.Image", pixels: int) -> "Image.Image":
        '''
        Return an RGBA copy of image with `pixels` extra rows added at the
        top; the added rows are black with zero alpha.
        '''
        extended_img = Image.new('RGBA', (image.width, image.height+pixels))

        extended_img.paste((0,0,0), [0, 0, image.width, image.height+pixels])
        extended_img.paste(image, box=(0, pixels))

        # now make the top part transparent to use as a mask
        alpha = extended_img.getchannel('A')
        alpha.paste(0, (0, 0, extended_img.width, pixels))
        extended_img.putalpha(alpha)

        return extended_img
|
92
invokeai/backend/restoration/outpaint.py
Normal file
@ -0,0 +1,92 @@
|
||||
import warnings
|
||||
import math
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
class Outpaint(object):
    """Prepares an image for outpainting in one direction and runs the generator on it."""

    def __init__(self, image, generate):
        self.image = image
        self.generate = generate

    def process(self, opt, old_opt, image_callback=None, prefix=None):
        """
        Build the outpainting init image from `opt.out_direction` and run
        prompt2image with the seed and prompt of `old_opt`.

        `image_callback`, if given, is invoked with `use_prefix=prefix`;
        when it is None no callback is passed on.
        Returns whatever prompt2image returns.
        """
        image = self._create_outpaint_image(self.image, opt.out_direction)

        seed = old_opt.seed
        prompt = old_opt.prompt

        def wrapped_callback(img, seed, **kwargs):
            image_callback(img, seed, use_prefix=prefix, **kwargs)

        return self.generate.prompt2image(
            prompt,
            seed=seed,
            sampler=self.generate.sampler,
            steps=opt.steps,
            cfg_scale=opt.cfg_scale,
            ddim_eta=self.generate.ddim_eta,
            width=opt.width,
            height=opt.height,
            init_img=image,
            strength=0.83,
            # Only wrap a real callback; wrapping None would fail when called.
            image_callback=wrapped_callback if image_callback else None,
            prefix=prefix,
        )

    def _create_outpaint_image(self, image, direction_args):
        """
        Return an RGBA copy of `image` shifted away from the requested side.

        `direction_args` is [direction] or [direction, pixels]; `pixels`
        defaults to half the (rotated) image height. The vacated strip is
        filled with a mirrored, fully transparent copy of the image and a
        dithered transparent band is drawn around the seam.
        """
        assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'

        if len(direction_args) == 1:
            direction = direction_args[0]
            pixels = None
        elif len(direction_args) == 2:
            direction = direction_args[0]
            pixels = int(direction_args[1])

        assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'

        image = image.convert("RGBA")
        # we always extend top, but rotate to extend along the requested side
        if direction == 'left':
            image = image.transpose(Image.Transpose.ROTATE_270)
        elif direction == 'bottom':
            image = image.transpose(Image.Transpose.ROTATE_180)
        elif direction == 'right':
            image = image.transpose(Image.Transpose.ROTATE_90)

        pixels = image.height//2 if pixels is None else int(pixels)
        assert 0 < pixels < image.height, 'Direction (-D) pixels length must be in the range 0 - image.size'

        # the top part of the image is taken from the source image mirrored
        # coordinates (0,0) are the upper left corner of an image
        top = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).convert("RGBA")
        top = top.crop((0, top.height - pixels, top.width, top.height))

        # setting all alpha of the top part to 0
        alpha = top.getchannel("A")
        alpha.paste(0, (0, 0, top.width, top.height))
        top.putalpha(alpha)

        # taking the bottom from the original image
        bottom = image.crop((0, 0, image.width, image.height - pixels))

        new_img = image.copy()
        new_img.paste(top, (0, 0))
        new_img.paste(bottom, (0, pixels))

        # create a 10% dither in the middle
        dither = min(image.height//10, pixels)
        # Clamp the band to the image: pixels + dither can exceed the height
        # when the extension covers almost the whole image.
        first_row = max(0, pixels - dither)
        last_row = min(new_img.height, pixels + dither)
        for x in range(0, image.width, 2):
            for y in range(first_row, last_row):
                (r, g, b, a) = new_img.getpixel((x, y))
                new_img.putpixel((x, y), (r, g, b, 0))

        # let's rotate back again
        if direction == 'left':
            new_img = new_img.transpose(Image.Transpose.ROTATE_90)
        elif direction == 'bottom':
            new_img = new_img.transpose(Image.Transpose.ROTATE_180)
        elif direction == 'right':
            new_img = new_img.transpose(Image.Transpose.ROTATE_270)

        return new_img
|
||||
|
92
invokeai/backend/restoration/realesrgan.py
Normal file
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
|
||||
class ESRGAN():
    """Upscale PIL images with the Real-ESRGAN ``realesr-general-x4v3`` model.

    A fresh ``RealESRGANer`` is built on every :meth:`process` call and
    released again before the call returns.
    """

    def __init__(self, bg_tile_size=400) -> None:
        """Remember the tile size that is passed to ``RealESRGANer`` as ``tile``."""
        self.bg_tile_size = bg_tile_size

    def load_esrgan_bg_upsampler(self, denoise_str):
        """Build the ``RealESRGANer`` used for upscaling.

        Args:
            denoise_str: weight of the regular model in the ``dni_weight``
                pair; the "wdn" model gets ``1 - denoise_str``.

        Returns:
            A ``RealESRGANer`` configured with both model files found under
            ``Globals.root``.
        """
        # Half precision only when CUDA is present (CPU or MPS on M1 use full precision).
        use_half_precision = torch.cuda.is_available()

        # Imported lazily so the module can be imported without realesrgan installed.
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        from realesrgan import RealESRGANer

        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-x4v3.pth')
        wdn_model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-wdn-x4v3.pth')
        scale = 4

        bg_upsampler = RealESRGANer(
            scale=scale,
            model_path=[model_path, wdn_model_path],
            model=model,
            tile=self.bg_tile_size,
            dni_weight=[denoise_str, 1 - denoise_str],
            tile_pad=10,
            pre_pad=0,
            half=use_half_precision,
        )

        return bg_upsampler

    def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2, denoise_str: float = 0.75):
        """Upscale ``image`` and optionally blend the result with the original.

        Args:
            image: source PIL image; it is converted to RGB before upscaling.
            strength: blend factor; values below 1.0 mix the (resized)
                original with the upscaled result via ``Image.blend``.
            seed: only used for the progress message; nothing is printed
                when it is ``None``.
            upsampler_scale: output scale passed to ``enhance`` as
                ``outscale``; ``0`` means "do not upscale".
            denoise_str: forwarded to :meth:`load_esrgan_bg_upsampler`.

        Returns:
            The upscaled image, or the unmodified input when the scale is 0
            or the Real-ESRGAN model could not be loaded.
        """
        # Reject the no-op request before paying for a model load.
        if upsampler_scale == 0:
            print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
            return image

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)

            try:
                upsampler = self.load_esrgan_bg_upsampler(denoise_str)
            except Exception:
                import traceback
                import sys
                print('>> Error loading Real-ESRGAN:', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                # Without an upsampler there is nothing to run; previously the
                # code fell through and crashed on the unbound name.
                return image

        if seed is not None:
            print(
                f'>> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}'
            )
        # ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
        image = image.convert("RGB")

        # REALSRGAN expects a BGR np array; make array and flip channels
        bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]

        output, _ = upsampler.enhance(
            bgr_image_array,
            outscale=upsampler_scale,
            alpha_upsampler='realesrgan',
        )

        # Flip the channels back to RGB
        res = Image.fromarray(output[...,::-1])

        if strength < 1.0:
            # Resize the original to the upscaled size if the sizes differ.
            # Compare PIL sizes on both sides: ``output`` is a numpy array whose
            # ``.size`` is an element count, not a (width, height) tuple.
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        # Drop the model reference first so empty_cache() can actually release it.
        del upsampler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return res
|
435
invokeai/backend/restoration/vqgan_arch.py
Normal file
@ -0,0 +1,435 @@
|
||||
'''
|
||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def normalize(in_channels):
    """Return an affine 32-group ``GroupNorm`` (eps=1e-6) for ``in_channels`` channels."""
    group_norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
    return group_norm
|
||||
|
||||
|
||||
@torch.jit.script
def swish(x):
    # Swish activation: the input gated by its own sigmoid.
    gate = torch.sigmoid(x)
    return x * gate
|
||||
|
||||
|
||||
# Define VQVAE classes
|
||||
class VectorQuantizer(nn.Module):
    """Nearest-codebook-entry quantizer with a straight-through gradient."""

    def __init__(self, codebook_size, emb_dim, beta):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        # Uniform init in [-1/K, 1/K] where K is the codebook size.
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        """Quantize ``z`` and return ``(z_q, loss, stats)``.

        ``z`` is permuted to channels-last and flattened to rows of
        ``emb_dim`` values; each row is replaced by its closest codebook
        vector. ``stats`` carries the perplexity, the one-hot encodings, the
        chosen indices, a confidence score per row and the mean distance.
        """
        # channels-last layout: (batch, height, width, channel), then flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        flat = z.view(-1, self.emb_dim)

        # squared distances via (z - e)^2 = z^2 + e^2 - 2 e * z
        z_sq = (flat ** 2).sum(dim=1, keepdim=True)
        e_sq = (self.embedding.weight**2).sum(1)
        d = z_sq + e_sq - 2 * torch.matmul(flat, self.embedding.weight.t())

        mean_distance = torch.mean(d)

        # closest codebook entry per row (smallest distance)
        nearest_dist, nearest_idx = torch.topk(d, 1, dim=1, largest=False)
        # map distance to [0-1]: higher score, higher confidence
        nearest_score = torch.exp(-nearest_dist/10)

        one_hot = torch.zeros(nearest_idx.shape[0], self.codebook_size).to(z)
        one_hot.scatter_(1, nearest_idx, 1)

        # quantized latent vectors, back in channels-last shape
        z_q = torch.matmul(one_hot, self.embedding.weight).view(z.shape)

        # codebook loss plus beta-weighted commitment loss
        codebook_term = torch.mean((z_q.detach()-z)**2)
        commitment_term = torch.mean((z_q - z.detach()) ** 2)
        loss = codebook_term + self.beta * commitment_term

        # straight-through estimator: forward uses z_q, gradients flow to z
        z_q = z + (z_q - z).detach()

        # perplexity of the codebook usage distribution
        usage = torch.mean(one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))

        # back to channels-first to match the input layout
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        stats = {
            "perplexity": perplexity,
            "min_encodings": one_hot,
            "min_encoding_indices": nearest_idx,
            "min_encoding_scores": nearest_score,
            "mean_distance": mean_distance
        }
        return z_q, loss, stats

    def get_codebook_feat(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        ``indices`` is flattened to a column. When ``shape`` (batch, height,
        width, channel) is given, the result is reshaped to it and permuted
        to channels-first; otherwise one row per index is returned.
        """
        indices = indices.view(-1,1)
        one_hot = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        one_hot.scatter_(1, indices, 1)

        z_q = torch.matmul(one_hot.float(), self.embedding.weight)
        if shape is None:
            return z_q

        # reshape back to match original input shape
        return z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
    """Quantizer that selects codebook entries with Gumbel-softmax sampling."""

    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # 1x1 conv that projects the last encoder layer to quantized logits
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        """Return ``(z_q, kl_term, {"min_encoding_indices": ...})`` for ``z``."""
        # Hard (one-hot) samples are always used outside training.
        if self.training:
            hard = self.straight_through
        else:
            hard = True

        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        kl_per_position = torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1)
        diff = self.kl_weight * kl_per_position.mean()

        stats = {
            "min_encoding_indices": soft_one_hot.argmax(dim=1)
        }
        return z_q, diff, stats
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """Downsample with a stride-2 3x3 convolution after asymmetric zero padding."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # Pad one zero on the right and bottom only; the conv itself has no padding.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation, then apply a 3x3 conv."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; defaults to ``in_channels``.
            When the two differ, the skip path goes through a 1x1 conv so
            both branches can be added.
    """

    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Use the resolved self.out_channels everywhere below; the raw
        # ``out_channels`` argument may be None.
        self.norm1 = normalize(self.in_channels)
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        """Return ``residual(x_in) + skip(x_in)``."""
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            # Project the skip path to the output channel count.
            x_in = self.conv_out(x_in)

        return x + x_in
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, added residually."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        # Query, key, value and output projections are all 1x1 convolutions.
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.q = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.k = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.v = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, **pointwise)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # compute attention weights between every pair of positions
        b, c, h, w = query.shape
        query = query.reshape(b, c, h*w).permute(0, 2, 1)
        key = key.reshape(b, c, h*w)
        attn = torch.bmm(query, key)
        attn = attn * (int(c)**(-0.5))
        attn = F.softmax(attn, dim=2)

        # attend to values
        value = value.reshape(b, c, h*w)
        attn = attn.permute(0, 2, 1)
        attended = torch.bmm(value, attn).reshape(b, c, h, w)

        return x+self.proj_out(attended)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Convolutional encoder: residual stages with downsampling, then a projection to ``emb_dim``."""

    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,)+tuple(ch_mult)

        # initial convolution
        blocks = [nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)]

        # residual and downsampling blocks, with attention on the listed resolutions
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[level]
            block_out_ch = nf * ch_mult[level]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            # every level except the last halves the resolution
            if level != last_level:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks += [
            ResBlock(block_in_ch, block_in_ch),
            AttnBlock(block_in_ch),
            ResBlock(block_in_ch, block_in_ch),
        ]

        # normalise and convert to latent size
        blocks += [
            normalize(block_in_ch),
            nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1),
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every block in order and return the result."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Convolutional decoder: residual stages with upsampling, ending in a 3-channel conv."""

    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        # start at the coarsest resolution and double it on the way up
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)

        # initial conv
        blocks = [nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)]

        # non-local attention block
        blocks += [
            ResBlock(block_in_ch, block_in_ch),
            AttnBlock(block_in_ch),
            ResBlock(block_in_ch, block_in_ch),
        ]

        for level in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[level]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            # every level except the finest doubles the resolution
            if level != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks += [
            normalize(block_in_ch),
            nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1),
        ]

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every block in order and return the result."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """VQGAN autoencoder: ``Encoder`` -> quantizer -> ``Generator``.

    Args:
        img_size: input resolution handed to the encoder and generator.
        nf: base number of feature channels.
        ch_mult: per-resolution channel multipliers.
        quantizer: ``"nearest"`` (``VectorQuantizer``) or ``"gumbel"``
            (``GumbelQuantizer``).
        res_blocks: residual blocks per resolution level.
        attn_resolutions: resolutions at which attention blocks are added;
            defaults to ``[16]``.
        codebook_size: number of codebook entries.
        emb_dim: dimension of each codebook entry.
        beta: commitment cost for the ``"nearest"`` quantizer.
        gumbel_straight_through: ``straight_through`` flag for the gumbel quantizer.
        gumbel_kl_weight: KL weight for the gumbel quantizer.
        model_path: optional checkpoint; its ``'params_ema'`` entry is
            preferred, falling back to ``'params'``.

    Raises:
        ValueError: if ``quantizer`` is not a supported type, or if the
            checkpoint contains neither ``'params_ema'`` nor ``'params'``.
    """

    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        # A fresh list per instance instead of a shared mutable default.
        if attn_resolutions is None:
            attn_resolutions = [16]
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta #0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        else:
            # Fail at construction rather than with an AttributeError in forward().
            raise ValueError(f'Unknown quantizer type: {quantizer!r}')
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            # Read the checkpoint once and reuse it for the key check and the load.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Encode, quantize and decode ``x``; return ``(x, codebook_loss, quant_stats)``."""
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats
|
||||
|
||||
|
||||
|
||||
# patch based discriminator
|
||||
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    """Patch-based convolutional discriminator producing a 1-channel prediction map.

    Args:
        nc: number of input channels.
        ndf: base number of discriminator filters.
        n_layers: number of strided conv stages (including the first).
        model_path: optional checkpoint; its ``'params_d'`` entry is
            preferred, falling back to ``'params'``.

    Raises:
        ValueError: if the checkpoint contains neither ``'params_d'`` nor ``'params'``.
    """

    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)  # channel multiplier is capped at 8
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        # one more conv stage at stride 1
        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            # Read the checkpoint once and reuse it for the key check and the load.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the prediction map."""
        return self.main(x)
|
@ -10,7 +10,7 @@ import traceback
|
||||
from typing import Callable
|
||||
from urllib import request, error as ul_error
|
||||
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
||||
from ldm.invoke.globals import Globals
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
|
@ -26,11 +26,11 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from ..stable_diffusion.diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, AttentionMapSaver
|
||||
from ..stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
from ..stable_diffusion.offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
|
||||
from ..devices import normalize_device, CPU_DEVICE
|
||||
from invokeai.backend.globals import Globals
|
||||
from .diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, AttentionMapSaver
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
from .offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
|
||||
from ..util import normalize_device, CPU_DEVICE
|
||||
from compel import EmbeddingsProvider
|
||||
|
||||
@dataclass
|
||||
|
@ -15,7 +15,7 @@ from torch import nn
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from ...devices import torch_dtype
|
||||
from ...util import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
|
@ -23,7 +23,7 @@ from omegaconf import ListConfig
|
||||
import urllib
|
||||
|
||||
from ..textual_inversion_manager import TextualInversionManager
|
||||
from ...util import (
|
||||
from ...util.util import (
|
||||
log_txt_as_img,
|
||||
exists,
|
||||
default,
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ...devices import choose_torch_device
|
||||
from ...util import choose_torch_device
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from .sampler import Sampler
|
||||
from ..diffusionmodules.util import noise_like
|
||||
|
@ -7,7 +7,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ...devices import choose_torch_device
|
||||
from ...util import choose_torch_device
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
from ..diffusionmodules.util import (
|
||||
|
@ -8,7 +8,7 @@ import torch
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from invokeai.backend.globals import Globals
|
||||
from .cross_attention_control import Arguments, \
|
||||
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
|
||||
CrossAttentionType, SwapCrossAttnContext
|
||||
|
@ -15,7 +15,7 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from ...util import instantiate_from_config
|
||||
from ...util.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(
|
||||
|
@ -10,7 +10,7 @@ from einops import repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.invoke.globals import global_cache_dir
|
||||
from invokeai.backend.globals import global_cache_dir
|
||||
from ldm.modules.x_transformer import (
|
||||
Encoder,
|
||||
TransformerWrapper,
|
||||
|
4
invokeai/backend/training/__init.py__
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for invokeai.backend.training
|
||||
'''
|
||||
from .textual_inversion_training import do_textual_inversion_training, parse_args
|
1009
invokeai/backend/training/textual_inversion_training.py
Normal file
18
invokeai/backend/util/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
'''
|
||||
Initialization file for invokeai.backend.util
|
||||
'''
|
||||
from .devices import (choose_torch_device,
|
||||
choose_precision,
|
||||
normalize_device,
|
||||
torch_dtype,
|
||||
CPU_DEVICE,
|
||||
CUDA_DEVICE,
|
||||
MPS_DEVICE,
|
||||
)
|
||||
from .util import (ask_user,
|
||||
download_with_resume,
|
||||
instantiate_from_config,
|
||||
url_attachment_name,
|
||||
)
|
||||
from .log import write_log
|
||||
|
@ -5,9 +5,11 @@ from contextlib import nullcontext
|
||||
import torch
|
||||
from torch import autocast
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
CUDA_DEVICE = torch.device("cuda")
|
||||
MPS_DEVICE = torch.device("mps")
|
||||
|
||||
def choose_torch_device() -> torch.device:
|
||||
'''Convenience routine for guessing which GPU device to run model on'''
|
66
invokeai/backend/util/log.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Functions for better format logging
|
||||
write_log -- logs the name of the output image, prompt, and prompt args to the terminal and different types of file
|
||||
1 write_log_message -- Writes a message to the console
|
||||
2 write_log_files -- Writes a message to files
|
||||
2.1 write_log_default -- File in plain text
|
||||
2.2 write_log_txt -- File in txt format
|
||||
2.3 write_log_markdown -- File in markdown format
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def write_log(results, log_path, file_types, output_cntr):
    """
    Log each (path, prompt) pair in ``results`` to the terminal and to one
    log file per entry in ``file_types``; return the updated output counter.
    """
    next_cntr = write_log_message(results, output_cntr)
    write_log_files(results, log_path, file_types)
    return next_cntr
|
||||
|
||||
|
||||
def write_log_message(results, output_cntr):
    """Print ``results`` to the terminal and return the next counter value.

    Nothing is printed (and the counter is returned unchanged) for empty
    ``results``. A single result is tagged ``[N]``; several results are
    tagged ``[N.1]``, ``[N.2]``, ...
    """
    if len(results) == 0:
        return output_cntr

    entries = [f"{path}: {prompt}\n" for path, prompt in results]
    if len(entries) <= 1:
        print(f"[{output_cntr}] {entries[0]}", end="")
    else:
        for subcntr, entry in enumerate(entries, start=1):
            print(f"[{output_cntr}.{subcntr}] {entry}", end="")
    return output_cntr+1
|
||||
|
||||
def write_log_files(results, log_path, file_types):
    """Append ``results`` to one log file per requested type.

    ``"txt"`` and ``"md"``/``"markdown"`` have dedicated writers; any other
    type is reported on the terminal and written as plain text using that
    type as the file extension.
    """
    for file_type in file_types:
        if file_type == "txt":
            write_log_txt(log_path, results)
            continue
        if file_type in ("md", "markdown"):
            write_log_markdown(log_path, results)
            continue
        print(f"'{file_type}' format is not supported, so write in plain text")
        write_log_default(log_path, results, file_type)
|
||||
|
||||
|
||||
def write_log_default(log_path, results, file_type):
    """Append ``path: prompt`` lines to ``<log_path>.<file_type>`` (UTF-8)."""
    lines = []
    for path, prompt in results:
        lines.append(f"{path}: {prompt}\n")
    target = log_path + "." + file_type
    with open(target, "a", encoding="utf-8") as log_file:
        log_file.writelines(lines)
|
||||
|
||||
|
||||
def write_log_txt(log_path, results):
    """Append ``path: prompt`` lines to ``<log_path>.txt`` (UTF-8)."""
    lines = []
    for path, prompt in results:
        lines.append(f"{path}: {prompt}\n")
    target = log_path + ".txt"
    with open(target, "a", encoding="utf-8") as log_file:
        log_file.writelines(lines)
|
||||
|
||||
|
||||
def write_log_markdown(log_path, results):
    """Append one markdown section per result to ``<log_path>.md`` (UTF-8).

    Each section is headed by the base name of the image path and followed
    by the prompt.
    """
    sections = [
        "## {0}\n\n\n{1}\n".format(os.path.basename(path), prompt)
        for path, prompt in results
    ]
    target = log_path + ".md"
    with open(target, "a", encoding="utf-8") as log_file:
        log_file.writelines(sections)
|
4
invokeai/backend/web/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for the web backend.
|
||||
'''
|
||||
from .invoke_ai_web_server import InvokeAIWebServer
|
@ -12,7 +12,7 @@ from threading import Event
|
||||
from uuid import uuid4
|
||||
|
||||
import eventlet
|
||||
import invokeai.frontend.dist as frontend
|
||||
import invokeai.frontend.web.dist as frontend
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
from compel.prompt_parser import Blend
|
||||
@ -20,24 +20,24 @@ from flask import Flask, redirect, send_from_directory, request, make_response
|
||||
from flask_socketio import SocketIO
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from invokeai.backend.modules.get_canvas_generation_mode import (
|
||||
from .modules.get_canvas_generation_mode import (
|
||||
get_canvas_generation_mode,
|
||||
)
|
||||
from .modules.parameters import parameters_to_command
|
||||
from .prompting import (get_tokens_for_prompt_object,
|
||||
get_prompt_structure,
|
||||
get_tokenizer
|
||||
)
|
||||
from .image_util import PngWriter, retrieve_metadata
|
||||
from .generator import infill_methods
|
||||
from .stable_diffusion import PipelineIntermediateState
|
||||
from ..prompting import (get_tokens_for_prompt_object,
|
||||
get_prompt_structure,
|
||||
get_tokenizer
|
||||
)
|
||||
from ..image_util import PngWriter, retrieve_metadata
|
||||
from ..generator import infill_methods
|
||||
from ..stable_diffusion import PipelineIntermediateState
|
||||
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.globals import ( Globals, global_converted_ckpts_dir,
|
||||
global_models_dir
|
||||
)
|
||||
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||
from .. import Generate
|
||||
from ..args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ..globals import ( Globals, global_converted_ckpts_dir,
|
||||
global_models_dir
|
||||
)
|
||||
from ..model_management import merge_diffusion_models
|
||||
|
||||
# Loading Arguments
|
||||
opt = Args()
|
||||
@ -236,7 +236,7 @@ class InvokeAIWebServer:
|
||||
sys.exit(0)
|
||||
else:
|
||||
useSSL = args.certfile or args.keyfile
|
||||
print(">> Started Invoke AI Web Server!")
|
||||
print(">> Started Invoke AI Web Server")
|
||||
if self.host == "0.0.0.0":
|
||||
print(
|
||||
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
0
invokeai/backend/web/modules/__init__.py
Normal file
@ -1,4 +1,4 @@
|
||||
from invokeai.backend.modules.parse_seed_weights import parse_seed_weights
|
||||
from .parse_seed_weights import parse_seed_weights
|
||||
import argparse
|
||||
|
||||
SAMPLER_CHOICES = [
|
Before Width: | Height: | Size: 2.7 KiB After Width: | Height: | Size: 2.7 KiB |
Before Width: | Height: | Size: 292 KiB After Width: | Height: | Size: 292 KiB |
Before Width: | Height: | Size: 164 KiB After Width: | Height: | Size: 164 KiB |
Before Width: | Height: | Size: 9.5 KiB After Width: | Height: | Size: 9.5 KiB |
Before Width: | Height: | Size: 3.4 KiB After Width: | Height: | Size: 3.4 KiB |
1237
invokeai/frontend/CLI/CLI.py
Normal file
4
invokeai/frontend/CLI/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for invokeai.frontend.CLI
|
||||
'''
|
||||
from .CLI import main as invokeai_command_line_interface
|
455
invokeai/frontend/CLI/readline.py
Normal file
@ -0,0 +1,455 @@
|
||||
"""
|
||||
Readline helper functions for invoke.py.
|
||||
You may import the global singleton `completer` to get access to the
|
||||
completer object itself. This is useful when you want to autocomplete
|
||||
seeds:
|
||||
|
||||
from invokeai.frontend.CLI.readline import completer
|
||||
completer.add_seed(18247566)
|
||||
completer.add_seed(9281839)
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import atexit
|
||||
from ...backend.args import Args
|
||||
from ...backend.globals import Globals
|
||||
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
|
||||
|
||||
# ---------------readline utilities---------------------
|
||||
# readline is optional: record whether it could be imported so the rest of
# the module can test the flag.
readline_available = False
try:
    import readline
except ImportError as e:  # ModuleNotFoundError is a subclass of ImportError
    print(f'** An error occurred when loading the readline module: {str(e)}')
else:
    readline_available = True
|
||||
|
||||
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
|
||||
WEIGHT_EXTENSIONS = ('.ckpt','.vae','.safetensors')
|
||||
TEXT_EXTENSIONS = ('.txt','.TXT')
|
||||
CONFIG_EXTENSIONS = ('.yaml','.yml')
|
||||
COMMANDS = (
|
||||
'--steps','-s',
|
||||
'--seed','-S',
|
||||
'--iterations','-n',
|
||||
'--width','-W','--height','-H',
|
||||
'--cfg_scale','-C',
|
||||
'--threshold',
|
||||
'--perlin',
|
||||
'--grid','-g',
|
||||
'--individual','-i',
|
||||
'--save_intermediates',
|
||||
'--init_img','-I',
|
||||
'--init_mask','-M',
|
||||
'--init_color',
|
||||
'--strength','-f',
|
||||
'--variants','-v',
|
||||
'--outdir','-o',
|
||||
'--sampler','-A','-m',
|
||||
'--embedding_path',
|
||||
'--device',
|
||||
'--grid','-g',
|
||||
'--facetool','-ft',
|
||||
'--facetool_strength','-G',
|
||||
'--codeformer_fidelity','-cf',
|
||||
'--upscale','-U',
|
||||
'-save_orig','--save_original',
|
||||
'--log_tokenization','-t',
|
||||
'--hires_fix',
|
||||
'--inpaint_replace','-r',
|
||||
'--png_compression','-z',
|
||||
'--text_mask','-tm',
|
||||
'--h_symmetry_time_pct',
|
||||
'--v_symmetry_time_pct',
|
||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
|
||||
'!mask','!triggers',
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
'!switch',
|
||||
'!edit_model',
|
||||
'!del_model',
|
||||
)
|
||||
CKPT_MODEL_COMMANDS = (
|
||||
'!optimize_model',
|
||||
)
|
||||
# Commands whose argument is a path to a model weights file.
WEIGHT_COMMANDS = (
    '!import_model',
    '!convert_model',
)

# NOTE: the entries of the *_COMMANDS tuples below are joined into regular
# expressions, so anything containing a regex escape is written as a raw
# string.  (A plain '\s' is an invalid string escape and triggers a
# DeprecationWarning / SyntaxWarning on recent Python versions.)

# Switches whose argument is a directory in which images are stored.
IMG_PATH_COMMANDS = (
    r'--outdir[=\s]',
)

# Commands whose argument is a path to a text file.
TEXT_PATH_COMMANDS = (
    '!replay',
)

# Commands and switches whose argument is a path to an image file.
IMG_FILE_COMMANDS = (
    '!fix',
    '!fetch',
    '!mask',
    r'--init_img[=\s]', '-I',
    r'--init_mask[=\s]', '-M',
    r'--init_color[=\s]',
    r'--embedding_path[=\s]',
)

# Each pattern matches "<command><optional whitespace><partial word>" at the
# very end of the line buffer, i.e. the user is currently typing the argument.
path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + r')\s*\S*$'
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + r')\s*\S*$'
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + r')\s*\S*$'
|
||||
|
||||
class Completer(object):
    '''
    Readline-backed tab completer and history helper for the invoke
    command line.

    Depending on what precedes the cursor, `complete()` offers file paths,
    previously used seeds, embedding concepts, model names or the plain
    list of command-line options.
    '''

    def __init__(self, options, models=None):
        '''
        :param options: iterable of option/command strings to complete
        :param models: mapping of model name -> info dict (each info dict is
                       read for its 'format' key); defaults to a fresh dict
        '''
        self.options = sorted(options)
        # a fresh container per instance; a `{}` default would be shared
        self.models = {} if models is None else models
        self.seeds = set()
        self.matches = list()
        self.default_dir = None
        self.linebuffer = None
        self.auto_history_active = True
        self.extensions = None
        self.concepts = None
        self.embedding_terms = set()
        return

    def complete(self, text, state):
        '''
        Completes invoke command line.
        BUG: it doesn't correctly complete files that have spaces in the name.

        The match list is built when state == 0 and then indexed by `state`
        on subsequent calls; None is returned once the list is exhausted.
        '''
        buffer = readline.get_line_buffer()

        if state == 0:

            # extensions defined, so go directly into path completion mode
            if self.extensions is not None:
                self.matches = self._path_completions(text, state, self.extensions)

            # looking for an image file
            elif re.search(path_regexp, buffer):
                # NOTE(review): the '^' binds only to the first alternative
                # here; left unchanged because the intended scope is unclear.
                do_shortcut = re.search('^' + '|'.join(IMG_FILE_COMMANDS), buffer)
                self.matches = self._path_completions(text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut)

            # looking for a seed
            elif re.search(r'(-S\s*|--seed[=\s])\d*$', buffer):
                self.matches = self._seed_completions(text, state)

            # looking for an embedding concept
            elif re.search(r'<[\w-]*$', buffer):
                self.matches = self._concept_completions(text, state)

            # looking for a model
            elif re.match('^' + '|'.join(MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state)

            # looking for a ckpt model
            elif re.match('^' + '|'.join(CKPT_MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state, ckpt_only=True)

            elif re.search(weight_regexp, buffer):
                self.matches = self._path_completions(
                    text,
                    state,
                    WEIGHT_EXTENSIONS,
                    default_dir=Globals.root,
                )

            elif re.search(text_regexp, buffer):
                self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)

            # This is the first time for this text, so build a match list.
            elif text:
                self.matches = [
                    s for s in self.options if s and s.startswith(text)
                ]
            else:
                self.matches = self.options[:]

        # Return the state'th item from the match list,
        # if we have that many.
        try:
            response = self.matches[state]
        except IndexError:
            response = None
        return response

    def complete_extensions(self, extensions: list):
        '''
        If called with a list of extensions, will force completer
        to do file path completions.
        '''
        self.extensions = extensions

    def add_history(self, line):
        '''
        Pass thru to readline. The line is only added manually when
        automatic history has been switched off.
        '''
        if not self.auto_history_active:
            readline.add_history(line)

    def clear_history(self):
        '''
        Pass clear_history() thru to readline
        '''
        readline.clear_history()

    def search_history(self, match: str):
        '''
        Like show_history() but only shows items that
        contain the match string.
        '''
        self.show_history(match)

    def remove_history_item(self, pos):
        '''
        Pass remove_history_item() thru to readline
        '''
        readline.remove_history_item(pos)

    def add_seed(self, seed):
        '''
        Add a seed to the autocomplete list for display when -S is autocompleted.
        '''
        if seed is not None:
            self.seeds.add(str(seed))

    def set_default_dir(self, path):
        '''
        Set the directory whose contents are additionally offered when
        completing a bare file name.
        '''
        self.default_dir = path

    def set_options(self, options):
        '''
        Replace the list of completable options.
        '''
        self.options = options

    def get_line(self, index):
        '''
        Return history item `index`, or None if the lookup raises IndexError.
        '''
        try:
            line = self.get_history_item(index)
        except IndexError:
            return None
        return line

    def get_current_history_length(self):
        return readline.get_current_history_length()

    def get_history_item(self, index):
        return readline.get_history_item(index)

    def show_history(self, match=None):
        '''
        Print the session history using the pydoc pager
        '''
        import pydoc
        lines = list()
        h_len = self.get_current_history_length()
        if h_len < 1:
            print('<empty history>')
            return

        # history items are fetched with 1-based indices
        for i in range(0, h_len):
            line = self.get_history_item(i + 1)
            if match and match not in line:
                continue
            lines.append(f'[{i+1}] {line}')
        pydoc.pager('\n'.join(lines))

    def set_line(self, line) -> None:
        '''
        Set the default string displayed in the next line of input.
        '''
        self.linebuffer = line
        readline.redisplay()

    def update_models(self, models: dict) -> None:
        '''
        update our list of models
        '''
        self.models = models

    def _seed_completions(self, text, state):
        '''
        Return the sorted known seeds that start with the digits typed so
        far, each prefixed with the switch found in `text` (if any).
        '''
        m = re.search(r'(-S\s?|--seed[=\s]?)(\d*)', text)
        if m:
            switch = m.groups()[0]
            partial = m.groups()[1]
        else:
            switch = ''
            partial = text

        matches = [switch + s for s in self.seeds if s.startswith(partial)]
        matches.sort()
        return matches

    def add_embedding_terms(self, terms: list[str]):
        '''
        Replace the set of embedding terms, merging in the cached concepts
        library listing when one has already been loaded.
        '''
        self.embedding_terms = set(terms)
        if self.concepts:
            self.embedding_terms.update(set(self.concepts.list_concepts()))

    def _concept_completions(self, text, state):
        '''
        Return completions for a partially typed '<concept' token.
        '''
        if self.concepts is None:
            # cache Concepts() instance so we can check for updates in concepts_list during runtime.
            self.concepts = HuggingFaceConceptsLibrary()
        self.embedding_terms.update(set(self.concepts.list_concepts()))

        partial = text[1:]  # this removes the leading '<'
        if len(partial) == 0:
            return list(self.embedding_terms)  # whole dump - think if user wants this!

        matches = [
            f'<{concept}>'
            for concept in self.embedding_terms
            if concept.startswith(partial)
        ]
        matches.sort()
        return matches

    def _model_completions(self, text, state, ckpt_only=False):
        '''
        Return the sorted model names starting with the typed text.
        Models whose format is 'vae' are never offered; with `ckpt_only`
        only models whose format is 'ckpt' are offered.
        '''
        m = re.search(r'(!switch\s+)(\w*)', text)
        if m:
            switch = m.groups()[0]
            partial = m.groups()[1]
        else:
            switch = ''
            partial = text
        matches = list()
        for s in self.models:
            model_format = self.models[s]['format']
            if model_format == 'vae':
                continue
            if ckpt_only and model_format != 'ckpt':
                continue
            if s.startswith(partial):
                matches.append(switch + s)
        matches.sort()
        return matches

    def _pre_input_hook(self):
        '''
        Insert the text stored by set_line() into the next input line,
        then clear it.
        '''
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None

    def _path_completions(self, text, state, extensions, shortcut_ok=True, default_dir: str = ''):
        '''
        Return path completions for `text`.

        :param text: word being completed, optionally prefixed by a switch
                     such as '-I' or '--init_img='
        :param extensions: file suffix, or iterable of suffixes, to offer;
                           directories are always offered
        :param shortcut_ok: if truthy and no directory was typed, entries of
                            `self.default_dir` are offered as well
        :param default_dir: directory to list when `text` names no directory
        :return: list of completion strings (empty if the directory cannot
                 be listed)
        '''
        # str.endswith() accepts only a str or a tuple, not a list
        if not isinstance(extensions, str):
            extensions = tuple(extensions)

        # separate the switch from the partial path
        match = re.search(r'^(-\w|--\w+=?)(.*)', text)
        if match is None:
            switch = None
            partial_path = text
        else:
            switch, partial_path = match.groups()

        partial_path = partial_path.lstrip()

        matches = list()
        path = os.path.expanduser(partial_path)

        if os.path.isdir(path):
            directory = path
        elif os.path.dirname(path) != '':
            directory = os.path.dirname(path)
        else:
            # default_dir may be empty or None; only use it if it exists
            directory = default_dir if default_dir and os.path.exists(default_dir) else ''
            path = os.path.join(directory, path)

        try:
            dir_list = os.listdir(directory or '.')
            # self.default_dir stays None until set_default_dir() is called
            if (shortcut_ok and directory == ''
                    and self.default_dir and os.path.exists(self.default_dir)):
                dir_list += os.listdir(self.default_dir)
        except OSError:
            # nonexistent or unreadable directory: nothing to complete
            return matches

        for node in dir_list:
            # skip hidden entries
            if node.startswith('.') and len(node) > 1:
                continue
            full_path = os.path.join(directory, node)
            is_directory = os.path.isdir(full_path)

            if not (node.endswith(extensions) or is_directory):
                continue

            if path and not full_path.startswith(path):
                continue

            if switch is None:
                matches.append(full_path + '/' if is_directory else full_path)
            elif is_directory:
                matches.append(
                    switch + os.path.join(os.path.dirname(full_path), node) + '/'
                )
            elif node.endswith(extensions):
                matches.append(
                    switch + os.path.join(os.path.dirname(full_path), node)
                )

        return matches
|
||||
|
||||
class DummyCompleter(Completer):
    '''
    Stand-in completer that keeps the command history in a plain Python
    list instead of delegating to the readline module.
    '''

    def __init__(self, options):
        super().__init__(options)
        self.history = []

    def add_history(self, line):
        '''Append `line` to the in-memory history.'''
        self.history.append(line)

    def clear_history(self):
        '''Forget all stored history lines.'''
        self.history = []

    def get_current_history_length(self):
        '''Return the number of stored history lines.'''
        return len(self.history)

    def get_history_item(self, index):
        '''Return the history line at 1-based position `index`.'''
        position = index - 1
        return self.history[position]

    def remove_history_item(self, index):
        '''Remove and return the history line at 1-based position `index`.'''
        position = index - 1
        return self.history.pop(position)

    def set_line(self, line):
        '''Print `line` as a comment; there is no input buffer to preload.'''
        print(f'# {line}')
|
||||
|
||||
def generic_completer(commands:list)->Completer:
    '''
    Build a completer for an arbitrary list of commands.

    When readline is available a Completer is created and installed as
    readline's completer; otherwise a DummyCompleter is returned.
    '''
    if not readline_available:
        return DummyCompleter(commands)

    completer = Completer(commands, [])
    readline.set_completer(completer.complete)
    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(' ')
    # readline init-file directives, applied in this order
    for directive in (
        'tab: complete',
        'set print-completions-horizontally off',
        'set page-completions on',
        'set skip-completed-text on',
        'set show-all-if-ambiguous on',
    ):
        readline.parse_and_bind(directive)
    return completer
|
||||
|
||||
def get_completer(opt:Args, models=None)->Completer:
    '''
    Build the completer used by the interactive command line.

    :param opt: parsed arguments; only `opt.outdir` is read, to locate the
                '.invoke_history' file
    :param models: models passed through to the Completer (defaults to an
                   empty list)
    :return: a Completer wired into readline, or a DummyCompleter when the
             readline module could not be imported
    '''
    if not readline_available:
        return DummyCompleter(COMMANDS)

    completer = Completer(COMMANDS, [] if models is None else models)

    readline.set_completer(
        completer.complete
    )
    # pyreadline3 does not have a set_auto_history() method
    try:
        readline.set_auto_history(False)
        completer.auto_history_active = False
    except Exception:
        # best effort: fall back to readline's automatic history
        completer.auto_history_active = True
    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(' ')
    readline.parse_and_bind('tab: complete')
    readline.parse_and_bind('set print-completions-horizontally off')
    readline.parse_and_bind('set page-completions on')
    readline.parse_and_bind('set skip-completed-text on')
    readline.parse_and_bind('set show-all-if-ambiguous on')

    # relative output directories are resolved against the InvokeAI root
    outdir = os.path.expanduser(opt.outdir)
    if os.path.isabs(outdir):
        histfile = os.path.join(outdir,'.invoke_history')
    else:
        histfile = os.path.join(Globals.root, outdir, '.invoke_history')
    try:
        readline.read_history_file(histfile)
    except FileNotFoundError:
        pass
    except OSError: # file likely corrupted
        newname = f'{histfile}.old'
        print(f'## Your history file {histfile} couldn\'t be loaded and may be corrupted. Renaming it to {newname}')
        os.replace(histfile,newname)
    # set the limit even when no history file could be read, so the file
    # written at exit is truncated on a first run as well
    readline.set_history_length(1000)
    atexit.register(readline.write_history_file, histfile)

    return completer
|
3
invokeai/frontend/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
'''
|
||||
Initialization file for invokeai.frontend
|
||||
'''
|
7
invokeai/frontend/config/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
'''
|
||||
Initialization file for invokeai.frontend.config
|
||||
'''
|
||||
from .model_install import main as invokeai_model_install
|
||||
from .invokeai_configure import main as invokeai_configure
|
||||
from .invokeai_update import main as invokeai_update
|
||||
|
4
invokeai/frontend/config/invokeai_configure.py
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Wrapper for invokeai.backend.configure.invokeai_configure
|
||||
'''
|
||||
from ...backend.config.invokeai_configure import main
|
88
invokeai/frontend/config/invokeai_update.py
Normal file
@ -0,0 +1,88 @@
|
||||
'''
|
||||
Minimalist updater script. Prompts user for the tag or branch to update to and runs
|
||||
pip install <path_to_git_source>.
|
||||
'''
|
||||
import os
|
||||
import platform
|
||||
import requests
|
||||
from rich import box, print
|
||||
from rich.console import Console, Group, group
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
from rich.style import Style
|
||||
from rich.syntax import Syntax
|
||||
from rich.text import Text
|
||||
|
||||
from invokeai.version import __version__
|
||||
|
||||
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
|
||||
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
|
||||
OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
|
||||
if OS == "Windows":
|
||||
# Windows terminals look better without a background colour
|
||||
console = Console(style=Style(color="grey74"))
|
||||
else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
def get_versions()->dict:
    '''
    Fetch the InvokeAI release listing from INVOKE_AI_REL and return the
    decoded JSON body.

    Callers in this file index the result as `versions[0]['tag_name']`.
    '''
    # without a timeout requests can block forever on a stalled connection
    return requests.get(url=INVOKE_AI_REL, timeout=30).json()
|
||||
|
||||
def welcome(versions: dict):
    '''
    Print the updater banner: the installed version and the three update
    choices, framed in a panel.

    :param versions: release listing; `versions[0]['tag_name']` is shown as
                     the latest official release
    '''

    @group()
    def banner():
        menu = '\n'.join([
            f"[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])",
            '[2] Update to the bleeding-edge development version ([italic]main[/italic])',
            '[3] Manually enter the tag or branch name you wish to update',
        ])
        yield from (
            f'InvokeAI Version: [bold yellow]{__version__}',
            '',
            'This script will update InvokeAI to the latest release, or to a development version of your choice.',
            '',
            '[bold yellow]Options:',
            menu,
        )

    console.rule()
    panel = Panel(
        title="[bold wheat1]InvokeAI Updater",
        renderable=banner(),
        box=box.DOUBLE,
        expand=True,
        padding=(1, 2),
        style=Style(bgcolor="grey23", color="orange1"),
        subtitle=f"[bold grey39]{OS}-{ARCH}",
    )
    print(panel)
    console.line()
|
||||
|
||||
def main():
    '''
    Ask which release, branch or tag to update to, then install it with pip.
    '''
    # local import: only this function needs subprocess
    import subprocess

    versions = get_versions()
    welcome(versions)

    tag = None
    choice = Prompt.ask('Choice:',choices=['1','2','3'],default='1')

    if choice=='1':
        tag = versions[0]['tag_name']
    elif choice=='2':
        tag = 'main'
    elif choice=='3':
        tag = Prompt.ask('Enter an InvokeAI tag or branch name')

    print(f':crossed_fingers: Upgrading to [yellow]{tag}[/yellow]')
    # Argument list instead of a shell string: the tag may be typed by the
    # user and must not be interpreted by a shell.
    cmd = ['pip', 'install', f'{INVOKE_AI_SRC}/{tag}.zip', '--use-pep517']
    print('')
    print('')
    try:
        succeeded = subprocess.run(cmd).returncode == 0
    except OSError:
        # pip could not be started at all; report it as a failed upgrade
        succeeded = False

    if succeeded:
        print(':heavy_check_mark: Upgrade successful')
    else:
        print(':exclamation: [bold red]Upgrade failed[/red bold]')
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
504
invokeai/frontend/config/model_install.py
Normal file
@ -0,0 +1,504 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
# Before running stable-diffusion on an internet-isolated machine,
|
||||
# run this script from one with internet connectivity. The
|
||||
# two machines must share a common .cache directory.
|
||||
|
||||
"""
|
||||
This is the npyscreen frontend to the model installation application.
|
||||
The work is actually done in backend code in model_install_backend.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
from shutil import get_terminal_size
|
||||
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.globals import Globals, global_config_dir
|
||||
from ...backend.config.model_install_backend import (Dataset_path, default_config_file,
|
||||
default_dataset, get_root,
|
||||
install_requested_models,
|
||||
recommended_datasets,
|
||||
)
|
||||
from .widgets import (MultiSelectColumns, TextBox,
|
||||
OffsetButtonPress, CenteredTitleText,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
# minimum size for the UI
|
||||
MIN_COLS = 120
|
||||
MIN_LINES = 45
|
||||
|
||||
class addModelsForm(npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
#FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
    """
    Load the starter-model catalogue and the current model configuration,
    then hand over to the npyscreen form constructor.

    `starter_model_list` holds the catalogue entries that are not present
    in the existing configuration.
    """
    self.multipage = multipage
    self.initial_models = OmegaConf.load(Dataset_path)
    try:
        self.existing_models = OmegaConf.load(default_config_file())
    except Exception:
        # best effort: a missing or unreadable config means nothing is
        # installed yet (Exception keeps Ctrl-C / SystemExit working)
        self.existing_models = dict()
    self.starter_model_list = [
        x for x in list(self.initial_models.keys()) if x not in self.existing_models
    ]
    self.installed_models = dict()
    super().__init__(parentApp=parentApp, name=name, *args, **keywords)
|
||||
|
||||
def create(self):
|
||||
window_width, window_height = get_terminal_size()
|
||||
starter_model_labels = self._get_starter_model_labels()
|
||||
recommended_models = [
|
||||
x
|
||||
for x in self.starter_model_list
|
||||
if self.initial_models[x].get("recommended", False)
|
||||
]
|
||||
self.installed_models = sorted(
|
||||
[x for x in list(self.initial_models.keys()) if x in self.existing_models]
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields,",
|
||||
editable=False,
|
||||
color='CAUTION',
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use cursor arrows to make a selection, and space to toggle checkboxes.",
|
||||
editable=False,
|
||||
color='CAUTION'
|
||||
)
|
||||
self.nextrely += 1
|
||||
if len(self.installed_models) > 0:
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="== INSTALLED STARTER MODELS ==",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Currently installed starter models. Uncheck to delete:",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
columns = self._get_columns()
|
||||
self.previously_installed_models = self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=columns,
|
||||
values=self.installed_models,
|
||||
value=[x for x in range(0, len(self.installed_models))],
|
||||
max_height=1 + len(self.installed_models) // columns,
|
||||
relx=4,
|
||||
slow_scroll=True,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.purge_deleted = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Purge deleted models from disk",
|
||||
value=False,
|
||||
scroll_exit=True,
|
||||
relx=4,
|
||||
)
|
||||
self.nextrely += 1
|
||||
if len(self.starter_model_list) > 0:
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="== STARTER MODELS (recommended ones selected) ==",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace.",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
# if user has already installed some initial models, then don't patronize them
|
||||
# by showing more recommendations
|
||||
show_recommended = not self.existing_models
|
||||
self.models_selected = self.add_widget_intelligent(
|
||||
npyscreen.MultiSelect,
|
||||
name="Install Starter Models",
|
||||
values=starter_model_labels,
|
||||
value=[
|
||||
self.starter_model_list.index(x)
|
||||
for x in self.starter_model_list
|
||||
if show_recommended and x in recommended_models
|
||||
],
|
||||
max_height=len(starter_model_labels) + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name='== IMPORT LOCAL AND REMOTE MODELS ==',
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely -= 1
|
||||
|
||||
for line in [
|
||||
"In the box below, enter URLs, file paths, or HuggingFace repository IDs.",
|
||||
"Separate model names by lines or whitespace (Use shift-control-V to paste):",
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name=line,
|
||||
editable=False,
|
||||
labelColor="CONTROL",
|
||||
relx = 4,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.import_model_paths = self.add_widget_intelligent(
|
||||
TextBox,
|
||||
max_height=7,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
relx=4
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.show_directory_fields = self.add_widget_intelligent(
|
||||
npyscreen.FormControlCheckbox,
|
||||
name="Select a directory for models to import",
|
||||
value=False,
|
||||
)
|
||||
self.autoload_directory = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="Directory (<tab> autocompletes):",
|
||||
select_dir=True,
|
||||
must_exist=True,
|
||||
use_two_lines=False,
|
||||
labelColor="DANGER",
|
||||
begin_entry_at=34,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoscan_on_startup = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Scan this directory each time InvokeAI starts for new models to import",
|
||||
value=False,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.convert_models = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
|
||||
values=["Keep original format", "Convert to diffusers"],
|
||||
value=0,
|
||||
begin_entry_at=4,
|
||||
max_height=4,
|
||||
hidden=True, # will appear when imported models box is edited
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.cancel = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name="CANCEL",
|
||||
rely=-3,
|
||||
when_pressed_function=self.on_cancel,
|
||||
)
|
||||
done_label = "DONE"
|
||||
back_label = "BACK"
|
||||
button_length = len(done_label)
|
||||
button_offset = 0
|
||||
if self.multipage:
|
||||
button_length += len(back_label) + 1
|
||||
button_offset += len(back_label) + 1
|
||||
self.back_button = self.add_widget_intelligent(
|
||||
OffsetButtonPress,
|
||||
name=back_label,
|
||||
relx=(window_width - button_length) // 2,
|
||||
offset=-3,
|
||||
rely=-3,
|
||||
when_pressed_function=self.on_back,
|
||||
)
|
||||
self.ok_button = self.add_widget_intelligent(
|
||||
OffsetButtonPress,
|
||||
name=done_label,
|
||||
offset=+3,
|
||||
relx=button_offset + 1 + (window_width - button_length) // 2,
|
||||
rely=-3,
|
||||
when_pressed_function=self.on_ok,
|
||||
)
|
||||
|
||||
for i in [self.autoload_directory, self.autoscan_on_startup]:
|
||||
self.show_directory_fields.addVisibleWhenSelected(i)
|
||||
|
||||
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
||||
self.import_model_paths.when_value_edited = self._show_hide_convert
|
||||
self.autoload_directory.when_value_edited = self._show_hide_convert
|
||||
|
||||
def resize(self):
    """Resize the form and rebuild the starter-model labels, if that widget exists."""
    super().resize()
    if not hasattr(self, 'models_selected'):
        return
    self.models_selected.values = self._get_starter_model_labels()
|
||||
|
||||
def _clear_scan_directory(self):
    """Blank the directory field whenever the directory checkbox is unchecked."""
    if self.show_directory_fields.value:
        return
    self.autoload_directory.value = ""
|
||||
|
||||
def _show_hide_convert(self):
    """Hide the conversion selector unless an import path or a scan directory was entered."""
    typed_paths = self.import_model_paths.value or ""
    scan_dir = self.autoload_directory.value or ""
    nothing_to_import = len(typed_paths) == 0 and len(scan_dir) == 0
    self.convert_models.hidden = nothing_to_import
|
||||
|
||||
def _get_starter_model_labels(self) -> List[str]:
    """
    Build one display label per starter model: the name padded to a fixed
    width, followed by its description, which is shortened with '...' when
    it does not fit the current terminal width.
    """
    window_width, _ = get_terminal_size()
    label_width = 25
    checkbox_width = 4
    spacing_width = 2
    description_width = window_width - label_width - checkbox_width - spacing_width

    labels = []
    for model_name in self.starter_model_list:
        blurb = self.initial_models[model_name].description
        if len(blurb) > description_width:
            blurb = blurb[0 : description_width - 3] + "..."
        labels.append(f"%-{label_width}s %s" % (model_name, blurb))
    return labels
|
||||
|
||||
def _get_columns(self) -> int:
    """
    Return how many columns to use for the installed-models list: 1 to 4
    depending on terminal width, but never more than the number of
    installed models.
    """
    window_width, _ = get_terminal_size()
    if window_width > 240:
        cols = 4
    elif window_width > 160:
        cols = 3
    elif window_width > 80:
        cols = 2
    else:
        cols = 1
    return min(cols, len(self.installed_models))
|
||||
|
||||
def on_ok(self):
    """DONE button handler: end the form, mark the run as not cancelled and collect the selections."""
    app = self.parentApp
    app.setNextForm(None)
    self.editing = False
    app.user_cancelled = False
    self.marshall_arguments()
|
||||
|
||||
def on_back(self):
    """BACK button handler: return to the previous form and stop editing this one."""
    app = self.parentApp
    app.switchFormPrevious()
    self.editing = False
|
||||
|
||||
def on_cancel(self):
    """CANCEL button handler: after confirmation, end the form and flag the run as cancelled."""
    confirmed = npyscreen.notify_yes_no(
        "Are you sure you want to cancel?\nYou may re-run this script later using the invoke.sh or invoke.bat command.\n"
    )
    if not confirmed:
        return
    self.parentApp.setNextForm(None)
    self.parentApp.user_cancelled = True
    self.editing = False
|
||||
|
||||
def marshall_arguments(self):
    """
    Assemble arguments and store as attributes of the application:
    .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
       True  => Install
       False => Remove
    .purge_deleted_models: value of the "purge deleted models" checkbox
    .scan_directory: Path to a directory of models to scan and import
    .autoscan_on_startup: True if invokeai should scan and import at startup time
    .import_model_paths: list of URLs, repo_ids and file paths to import
    .convert_to_diffusers: if True, convert legacy checkpoints into diffusers
    """
    # the results are written onto the parent application's
    # `user_selections` namespace
    selections = self.parentApp.user_selections

    # starter models to install/remove
    if hasattr(self, 'models_selected'):
        starter_models = {
            self.starter_model_list[i]: True for i in self.models_selected.value
        }
    else:
        starter_models = dict()
    selections.purge_deleted_models = False
    if hasattr(self, "previously_installed_models"):
        checked = self.previously_installed_models.value
        # installed models the user unchecked are marked for removal
        unchecked = [
            model
            for i, model in enumerate(self.previously_installed_models.values)
            if i not in checked
        ]
        starter_models.update((model, False) for model in unchecked)
        selections.purge_deleted_models = self.purge_deleted.value
    selections.starter_models = starter_models

    # load directory and whether to scan on startup
    if self.show_directory_fields.value:
        selections.scan_directory = self.autoload_directory.value
        selections.autoscan_on_startup = self.autoscan_on_startup.value
    else:
        selections.scan_directory = None
        selections.autoscan_on_startup = False

    # URLs and the like; an untouched text box may hold None
    selections.import_model_paths = (self.import_model_paths.value or "").split()

    # The selector is created with a scalar value=0 but is read as a list
    # once edited; accept both forms (index 1 is "Convert to diffusers").
    choice = self.convert_models.value
    if isinstance(choice, (list, tuple)):
        choice = choice[0] if choice else 0
    selections.convert_to_diffusers = choice == 1
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
    """npyscreen application hosting the model installation form."""

    def __init__(self):
        super().__init__()
        self.user_cancelled = False
        # filled in by the form when the user confirms
        initial_selections = dict(
            starter_models=None,
            purge_deleted_models=False,
            scan_directory=None,
            autoscan_on_startup=None,
            import_model_paths=None,
            convert_to_diffusers=None,
        )
        self.user_selections = Namespace(**initial_selections)

    def onStart(self):
        """Apply the default theme and register the main form."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        form_title = "Install Stable Diffusion Models"
        self.main_form = self.addForm("MAIN", addModelsForm, name=form_title)
|
||||
|
||||
# --------------------------------------------------------
|
||||
def process_and_execute(opt: Namespace, selections: Namespace):
    """
    Translate the user's selections into a single call to
    install_requested_models().

    :param opt: command-line options; `full_precision` and `config_file` are read
    :param selections: namespace of choices collected from the form
    """
    requested = selections.starter_models
    # True means install, False means remove
    to_remove = [name for name in requested if not requested[name]]
    to_install = [name for name in requested if requested[name]]

    scan_dir = selections.scan_directory
    config_file = opt.config_file

    if opt.full_precision:
        precision = "float32"
    else:
        precision = choose_precision(torch.device(choose_torch_device()))

    install_requested_models(
        install_initial_models=to_install,
        remove_models=to_remove,
        scan_directory=Path(scan_dir) if scan_dir else None,
        external_models=selections.import_model_paths,
        scan_at_startup=selections.autoscan_on_startup,
        convert_to_diffusers=selections.convert_to_diffusers,
        precision=precision,
        purge_deleted=selections.purge_deleted_models,
        config_file_path=Path(config_file) if config_file else None,
    )
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace):
    """
    Install models either non-interactively or through the console form.

    opt - parsed command-line options. default_only installs default_dataset(),
          yes_to_all installs recommended_datasets(); otherwise the interactive
          form is shown and its selections are executed unless the user cancelled.
    """
    if opt.full_precision:
        precision = "float32"
    else:
        precision = choose_precision(torch.device(choose_torch_device()))

    # Non-interactive modes: install a fixed set and stop.
    if opt.default_only:
        install_requested_models(
            install_initial_models=default_dataset(),
            precision=precision,
        )
        return
    if opt.yes_to_all:
        install_requested_models(
            install_initial_models=recommended_datasets(),
            precision=precision,
        )
        return

    # Interactive mode: make room for the form, run it, then act on the result.
    set_min_terminal_size(MIN_COLS, MIN_LINES)
    app = AddModelApplication()
    app.run()
    if app.user_cancelled:
        return
    process_and_execute(opt, app.user_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main():
    """
    Command-line entry point for the InvokeAI model downloader.

    Parses the command line, sets Globals.root, hands off to invokeai-configure
    when the root directory has not been set up yet, and otherwise runs
    select_and_download_models(). Exits with status -1 on an AssertionError.
    """
    parser = argparse.ArgumentParser(description="InvokeAI model downloader")
    parser.add_argument(
        "--full-precision",
        dest="full_precision",
        # No ``type=`` here: BooleanOptionalAction takes no value, and passing
        # ``type`` to it is deprecated in Python 3.12 and rejected from 3.14.
        action=argparse.BooleanOptionalAction,
        default=False,
        help="use 32-bit weights instead of faster 16-bit weights",
    )
    parser.add_argument(
        "--yes",
        "-y",
        dest="yes_to_all",
        action="store_true",
        help='answer "yes" to all prompts',
    )
    parser.add_argument(
        "--default_only",
        action="store_true",
        help="only install the default model",
    )
    parser.add_argument(
        "--config_file",
        "-c",
        dest="config_file",
        type=str,
        default=None,
        help="path to configuration file to create",
    )
    parser.add_argument(
        "--root_dir",
        dest="root",
        type=str,
        default=None,
        help="path to root of install directory",
    )
    opt = parser.parse_args()

    # setting a global here
    Globals.root = os.path.expanduser(get_root(opt.root) or "")

    if not global_config_dir().exists():
        print(
            ">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
        )
        # NOTE(review): this still imports from the ``ldm.invoke`` package --
        # confirm the path is valid after the file migration.
        import ldm.invoke.config.invokeai_configure

        ldm.invoke.config.invokeai_configure.main()
        sys.exit(0)

    try:
        select_and_download_models(opt)
    except AssertionError as e:
        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        print("\nGoodbye! Come back soon.")
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
            print(
                "** Insufficient vertical space for the interface. Please make your window taller and try again"
            )
        elif str(e).startswith("addwstr"):
            print(
                "** Insufficient horizontal space for the interface. Please make your window wider and try again."
            )
        else:
            # Previously an unrecognised message was dropped without a trace.
            print(f"** Not enough room for the interface: {e}")
|
||||
|
||||
# -------------------------------------
|
||||
# Run the installer only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
164
invokeai/frontend/config/widgets.py
Normal file
@ -0,0 +1,164 @@
|
||||
'''
|
||||
Widget class definitions used by model_select.py, merge_diffusers.py and textual_inversion.py
|
||||
'''
|
||||
import math
|
||||
import platform
|
||||
import npyscreen
|
||||
import os
|
||||
import sys
|
||||
import curses
|
||||
import struct
|
||||
|
||||
from shutil import get_terminal_size
|
||||
|
||||
# -------------------------------------
|
||||
def set_terminal_size(columns: int, lines: int):
    """
    Ask the terminal to resize itself to `columns` x `lines`.

    On Windows this shells out to `mode con:`. On Darwin and Linux it issues a
    TIOCSWINSZ ioctl on stdout and then writes a resize escape sequence.
    Any other platform is left untouched.
    """
    system_name = platform.uname().system
    if system_name == "Windows":
        os.system(f'mode con: cols={columns} lines={lines}')
        return
    if system_name not in ('Darwin', 'Linux'):
        return

    import termios
    import fcntl

    # Four unsigned shorts: rows, columns, and two zeroed trailing fields.
    packed_size = struct.pack("HHHH", lines, columns, 0, 0)
    fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, packed_size)
    sys.stdout.write(f"\x1b[8;{lines};{columns}t")
    sys.stdout.flush()
|
||||
|
||||
def set_min_terminal_size(min_cols: int, min_lines: int):
    """Enlarge the terminal, if needed, to at least min_cols x min_lines."""
    current = get_terminal_size()
    # Never shrink: per dimension keep whichever of current/minimum is larger.
    set_terminal_size(
        max(current.columns, min_cols),
        max(current.lines, min_lines),
    )
|
||||
|
||||
class IntSlider(npyscreen.Slider):
    """Slider whose label prints the value and the maximum as integers."""

    def translate_value(self):
        """Return the "<value> / <out_of>" label, right-justified."""
        label = "%2d / %2d" % (self.value, self.out_of)
        # Field width: twice the printed length of out_of, plus 4.
        return label.rjust(2 * len(str(self.out_of)) + 4)
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredTitleText(npyscreen.TitleText):
    """TitleText that positions itself horizontally from the length of its name."""

    def __init__(self, *args, **keywords):
        super().__init__(*args, **keywords)
        # Compute the centred position once, right after construction.
        self.resize()

    def resize(self):
        """Recompute relx so the name is centred in the parent's pad."""
        super().resize()
        _, pad_width = self.parent.curses_pad.getmaxyx()
        self.relx = (pad_width - len(self.name)) // 2
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredButtonPress(npyscreen.ButtonPress):
    """ButtonPress that positions itself horizontally from the length of its name."""

    def resize(self):
        """Recompute relx so the button name is centred in the parent's pad."""
        super().resize()
        _, pad_width = self.parent.curses_pad.getmaxyx()
        self.relx = (pad_width - len(self.name)) // 2
|
||||
|
||||
# -------------------------------------
|
||||
class OffsetButtonPress(npyscreen.ButtonPress):
    """ButtonPress placed `offset` columns away from the centred position."""

    def __init__(self, screen, offset=0, *args, **keywords):
        super().__init__(screen, *args, **keywords)
        self.offset = offset

    def resize(self):
        """Set relx to the centred position shifted by self.offset."""
        # Unlike the centred widgets, this does not call the superclass resize().
        _, pad_width = self.parent.curses_pad.getmaxyx()
        centred = (pad_width - len(self.name)) // 2
        self.relx = self.offset + centred
|
||||
|
||||
class IntTitleSlider(npyscreen.TitleText):
    """TitleText subclass that sets its `_entry_type` to IntSlider."""
    _entry_type = IntSlider
|
||||
|
||||
class FloatSlider(npyscreen.Slider):
    """Slider whose label prints the value and the maximum with two decimals."""

    # this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        """Return the "<value> / <out_of>" label, right-justified."""
        label = "%3.2f / %3.2f" % (self.value, self.out_of)
        # Field width: twice the printed length of out_of, plus 4.
        return label.rjust(2 * len(str(self.out_of)) + 4)
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
    """TitleText subclass that sets its `_entry_type` to FloatSlider."""
    _entry_type = FloatSlider
|
||||
|
||||
class MultiSelectColumns(npyscreen.MultiSelect):
    """
    MultiSelect whose entries are laid out in several columns.

    Entries fill the grid column by column: entry h sits in row h % rows and
    column h // rows.
    """

    def __init__(self, screen, columns: int=1, values: list=None, **keywords):
        """
        screen  - parent form, passed through to the superclass
        columns - number of columns to spread the values over
        values  - the selectable entries (defaults to an empty list)
        """
        # A fresh list per instance; a literal [] default would be shared
        # between every widget created without explicit values.
        values = [] if values is None else values
        self.columns = columns
        self.value_cnt = len(values)
        self.rows = math.ceil(self.value_cnt / self.columns)
        super().__init__(screen,values=values, **keywords)

    def make_contained_widgets(self):
        """Create one contained widget per value, positioned on the grid."""
        self._my_widgets = []
        column_width = self.width // self.columns
        for h in range(self.value_cnt):
            self._my_widgets.append(
                self._contained_widgets(self.parent,
                                        rely=self.rely + (h % self.rows) * self._contained_widget_height,
                                        relx=self.relx + (h // self.rows) * column_width,
                                        max_width=column_width,
                                        max_height=self.__class__._contained_widget_height,
                                        )
            )

    def set_up_handlers(self):
        """Bind the up/down arrow keys to single-entry cursor moves."""
        super().set_up_handlers()
        # h_cursor_line_up/down are overridden below to jump a whole column,
        # so the arrow keys are pointed at wrappers around the superclass
        # versions instead.
        self.handlers.update({
            curses.KEY_UP: self.h_cursor_line_left,
            curses.KEY_DOWN: self.h_cursor_line_right,
        }
        )

    def h_cursor_line_down(self, ch):
        """Advance the cursor by self.rows entries (same row, next column)."""
        self.cursor_line += self.rows
        if self.cursor_line >= len(self.values):
            if self.scroll_exit:
                self.cursor_line = len(self.values)-self.rows
                self.h_exit_down(ch)
                return True
            else:
                # Ran off the end: undo the move.
                self.cursor_line -= self.rows
                return True

    def h_cursor_line_up(self, ch):
        """Move the cursor back by self.rows entries (same row, previous column)."""
        self.cursor_line -= self.rows
        if self.cursor_line < 0:
            if self.scroll_exit:
                self.cursor_line = 0
                self.h_exit_up(ch)
            else:
                self.cursor_line = 0

    def h_cursor_line_left(self,ch):
        """Delegate to the superclass h_cursor_line_up."""
        super().h_cursor_line_up(ch)

    def h_cursor_line_right(self,ch):
        """Delegate to the superclass h_cursor_line_down."""
        super().h_cursor_line_down(ch)
|
||||
|
||||
class TextBox(npyscreen.MultiLineEdit):
    """MultiLineEdit that draws a box outline around its own area."""

    def update(self, clear=True):
        """
        Draw the border, then let the superclass render inside it.

        clear - when true, clear the widget before drawing
        """
        if clear:
            self.clear()

        pad = self.parent.curses_pad
        HEIGHT = self.height
        WIDTH = self.width
        top, left = self.rely, self.relx
        bottom, right = top + HEIGHT, left + WIDTH

        # draw box.
        pad.hline(top, left, curses.ACS_HLINE, WIDTH)
        pad.hline(bottom, left, curses.ACS_HLINE, WIDTH)
        pad.vline(top, left, curses.ACS_VLINE, HEIGHT)
        pad.vline(top, right, curses.ACS_VLINE, HEIGHT)

        # draw corners
        pad.addch(top, left, curses.ACS_ULCORNER)
        pad.addch(top, right, curses.ACS_URCORNER)
        pad.addch(bottom, left, curses.ACS_LLCORNER)
        pad.addch(bottom, right, curses.ACS_LRCORNER)

        # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
        saved_geometry = (self.relx, self.rely, self.height, self.width)
        self.relx += 1
        self.rely += 1
        self.height -= 1
        self.width -= 1
        try:
            super().update(clear=False)
        finally:
            # Always put the real geometry back, even if drawing failed,
            # so the widget does not shrink a little more on every error.
            (self.relx, self.rely, self.height, self.width) = saved_geometry
|
4
invokeai/frontend/merge/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
'''
|
||||
Initialization file for invokeai.frontend.merge
|
||||
'''
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers
|
467
invokeai/frontend/merge/merge_diffusers.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""
|
||||
invokeai.frontend.merge.merge_diffusers exports a single function, merge_diffusion_models(),
|
||||
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import npyscreen
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ...frontend.config.widgets import FloatTitleSlider
|
||||
from ...backend.globals import (Globals, global_cache_dir, global_config_file,
|
||||
global_models_dir, global_set_root)
|
||||
from ...backend.model_management import ModelManager
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
|
||||
def merge_diffusion_models(
    model_ids_or_paths: List[Union[str, Path]],
    alpha: float = 0.5,
    interp: str = None,
    force: bool = False,
    **kwargs,
) -> DiffusionPipeline:
    """
    Merge two or three diffusers models and return the merged pipeline.

    model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
    alpha  - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
               would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
    interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
               Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
    force  - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
         cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        verbosity = dlogging.get_verbosity()
        dlogging.set_verbosity_error()
        try:
            pipe = DiffusionPipeline.from_pretrained(
                model_ids_or_paths[0],
                cache_dir=kwargs.get("cache_dir", global_cache_dir()),
                custom_pipeline="checkpoint_merger",
            )
            merged_pipe = pipe.merge(
                pretrained_model_name_or_path_list=model_ids_or_paths,
                alpha=alpha,
                interp=interp,
                force=force,
                **kwargs,
            )
        finally:
            # Restore the caller's diffusers log level even when loading or
            # merging fails; otherwise it would stay silenced for the rest
            # of the process.
            dlogging.set_verbosity(verbosity)
    return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_commit(
    models: List["str"],
    merged_model_name: str,
    alpha: float = 0.5,
    interp: str = None,
    force: bool = False,
    **kwargs,
):
    """
    Merge registered models, save the result under the models directory and
    record it in the InvokeAI configuration file.

    models - up to three models, designated by their InvokeAI models.yaml model name
    merged_model_name - name for the new model
    alpha  - interpolation parameter, passed through to merge_diffusion_models()
    interp - interpolation method, passed through to merge_diffusion_models()
    force  - passed through to merge_diffusion_models()

    **kwargs - forwarded unchanged to merge_diffusion_models()

    Raises AssertionError when a name is unknown or is not a diffusers model.
    """
    config_path = global_config_file()
    manager = ModelManager(OmegaConf.load(config_path))

    # Every input must be a known model in diffusers format.
    for model_name in models:
        assert model_name in manager.model_names(), f'** Unknown model "{model_name}"'
        model_format = manager.model_info(model_name).get("format", None)
        assert (
            model_format == "diffusers"
        ), f"** {model_name} is not a diffusers model. It must be optimized before merging."

    sources = [manager.model_name_or_path(name) for name in models]
    merged_pipe = merge_diffusion_models(sources, alpha, interp, force, **kwargs)

    merged_root = global_models_dir() / DEST_MERGED_MODEL_DIR
    os.makedirs(merged_root, exist_ok=True)
    dump_path = merged_root / merged_model_name
    merged_pipe.save_pretrained(dump_path, safe_serialization=1)

    import_args = {
        "model_name": merged_model_name,
        "description": f'Merge of models {", ".join(models)}',
    }
    # The merged model reuses the VAE configured for the first input, if any.
    vae = manager.config[models[0]].get("vae", None)
    if vae:
        print(f">> Using configured VAE assigned to {models[0]}")
        import_args["vae"] = vae

    manager.import_diffuser_model(dump_path, **import_args)
    manager.commit(config_path)
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
    """Define the merge tool's command-line options and parse the command line."""
    cli = argparse.ArgumentParser(description="InvokeAI model merging")
    add = cli.add_argument

    add(
        "--root_dir",
        type=Path,
        default=Globals.root,
        help="Path to the invokeai runtime directory",
    )
    add(
        "--front_end",
        "--gui",
        dest="front_end",
        action="store_true",
        default=False,
        help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
    )
    add(
        "--models",
        type=str,
        nargs="+",
        help="Two to three model names to be merged",
    )
    add(
        "--merged_model_name",
        "--destination",
        dest="merged_model_name",
        type=str,
        help="Name of the output model. If not specified, will be the concatenation of the input model names.",
    )
    add(
        "--alpha",
        type=float,
        default=0.5,
        help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
    )
    add(
        "--interpolation",
        dest="interp",
        type=str,
        choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
        default="weighted_sum",
        help='Interpolation method to use. If three models are present, only "add_difference" will work.',
    )
    add(
        "--force",
        action="store_true",
        help="Try to merge models even if they are incompatible with each other",
    )
    add(
        "--clobber",
        "--overwrite",
        dest="clobber",
        action="store_true",
        help="Overwrite the merged model if --merged_model_name already exists",
    )
    return cli.parse_args()
|
||||
|
||||
|
||||
# ------------------------- GUI HERE -------------------------
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
    """
    Console form for choosing two or three diffusers models to merge,
    together with the output name, merge method, alpha and force flag.
    On OK the collected arguments are stored on parentApp.merge_arguments.
    """

    # Merge methods offered while only two models are selected.
    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]

    def __init__(self, parentApp, name):
        self.parentApp = parentApp
        self.ALLOW_RESIZE = True
        self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
        super().__init__(parentApp, name)

    @property
    def model_manager(self):
        """The ModelManager owned by the parent application."""
        return self.parentApp.model_manager

    def afterEditing(self):
        """Tell the application there is no form to show after this one."""
        self.parentApp.setNextForm(None)

    def create(self):
        """Build all widgets of the form."""
        _, window_width = curses.initscr().getmaxyx()

        self.model_names = self.get_model_names()
        max_width = max(len(x) for x in self.model_names)
        max_width += 6
        # Put the three model lists side by side only when they all fit.
        horizontal_layout = max_width * 3 < window_width

        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Select two models to merge and optionally a third.",
            editable=False,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
            editable=False,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 1",
            color="GOOD",
            editable=False,
            rely=4 if horizontal_layout else None,
        )
        self.model1 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            values=self.model_names,
            value=0,
            max_height=len(self.model_names),
            max_width=max_width,
            scroll_exit=True,
            rely=5,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 2",
            color="GOOD",
            editable=False,
            relx=max_width + 3 if horizontal_layout else None,
            rely=4 if horizontal_layout else None,
        )
        self.model2 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(2)",
            values=self.model_names,
            value=1,
            max_height=len(self.model_names),
            max_width=max_width,
            relx=max_width + 3 if horizontal_layout else None,
            rely=5 if horizontal_layout else None,
            scroll_exit=True,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 3",
            color="GOOD",
            editable=False,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=4 if horizontal_layout else None,
        )
        # The third list gets a leading "None" entry, so its indices are
        # shifted by one relative to self.model_names.
        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0, "None")
        self.model3 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(3)",
            values=models_plus_none,
            value=0,
            max_height=len(self.model_names) + 1,
            max_width=max_width,
            scroll_exit=True,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=5 if horizontal_layout else None,
        )
        for m in [self.model1, self.model2, self.model3]:
            m.when_value_edited = self.models_changed
        self.merged_model_name = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Name for merged model:",
            labelColor="CONTROL",
            value="",
            scroll_exit=True,
        )
        self.force = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Force merge of incompatible models",
            labelColor="CONTROL",
            value=False,
            scroll_exit=True,
        )
        self.merge_method = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Merge Method:",
            values=self.interpolations,
            value=0,
            labelColor="CONTROL",
            max_height=len(self.interpolations) + 1,
            scroll_exit=True,
        )
        self.alpha = self.add_widget_intelligent(
            FloatTitleSlider,
            name="Weight (alpha) to assign to second and third models:",
            out_of=1.0,
            step=0.01,
            lowest=0,
            value=0.5,
            labelColor="CONTROL",
            scroll_exit=True,
        )
        self.model1.editing = True

    def models_changed(self):
        """Refresh the suggested output name and the merge-method choices."""
        models = self.model1.values
        selected_model1 = self.model1.value[0]
        selected_model2 = self.model2.value[0]
        selected_model3 = self.model3.value[0]
        merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
        self.merged_model_name.value = merged_model_name

        if selected_model3 > 0:
            self.merge_method.values = ['add_difference ( A+(B-C) )']
            self.merged_model_name.value += f"+{models[selected_model3 -1]}"  # In model3 there is one more element in the list (None). So we have to subtract one.
        else:
            self.merge_method.values = self.interpolations
        self.merge_method.value = 0

    def on_ok(self):
        """Validate the form; on success store the arguments and leave the form."""
        if self.validate_field_values() and self.check_for_overwrite():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.merge_arguments = self.marshall_arguments()
            npyscreen.notify("Starting the merge...")
        else:
            self.editing = True

    def on_cancel(self):
        """Exit the program with status 0."""
        sys.exit(0)

    def marshall_arguments(self) -> dict:
        """Return the form contents as keyword arguments for the merge call."""
        model_names = self.model_names
        models = [
            model_names[self.model1.value[0]],
            model_names[self.model2.value[0]],
        ]
        if self.model3.value[0] > 0:
            # Index 0 of the third list is "None", hence the -1.
            models.append(model_names[self.model3.value[0] - 1])
            interp='add_difference'
        else:
            interp=self.interpolations[self.merge_method.value[0]]

        args = dict(
            models=models,
            alpha=self.alpha.value,
            interp=interp,
            force=self.force.value,
            merged_model_name=self.merged_model_name.value,
        )
        return args

    def check_for_overwrite(self) -> bool:
        """Return True unless the output name is taken and the user declines to overwrite."""
        model_out = self.merged_model_name.value
        if model_out not in self.model_names:
            return True
        else:
            return npyscreen.notify_yes_no(
                f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
            )

    def validate_field_values(self) -> bool:
        """Show a dialog and return False when fewer than two distinct models are selected."""
        bad_fields = []
        model_names = self.model_names
        selected_models = set(
            (model_names[self.model1.value[0]], model_names[self.model2.value[0]])
        )
        if self.model3.value[0] > 0:
            selected_models.add(model_names[self.model3.value[0] - 1])
        if len(selected_models) < 2:
            bad_fields.append(
                f"Please select two or three DIFFERENT models to compare. You selected {selected_models}"
            )
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self) -> List[str]:
        """Return the sorted names of all models whose format is "diffusers"."""
        # The former debug print of this list was removed: it wrote to the
        # terminal while the curses form was being built.
        model_names = [
            name
            for name in self.model_manager.model_names()
            if self.model_manager.model_info(name).get("format") == "diffusers"
        ]
        return sorted(model_names)
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
    """npyscreen application that hosts the merge form and owns the ModelManager."""

    def __init__(self):
        super().__init__()
        models_config = OmegaConf.load(global_config_file())
        # precision doesn't really matter here
        self.model_manager = ModelManager(models_config, "cpu", "float16")

    def onStart(self):
        """Select the theme and register the MAIN form."""
        npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
        form_title = "Merge Models Settings"
        self.main = self.addForm("MAIN", mergeModelsForm, name=form_title)
|
||||
|
||||
|
||||
def run_gui(args: Namespace):
    """
    Collect the merge parameters with the console form and run the merge.

    args - parsed command-line options; not consulted here, the form's
           values are used instead
    """
    app = Mergeapp()
    app.run()

    merge_kwargs = app.merge_arguments
    merge_diffusion_models_and_commit(**merge_kwargs)
    new_name = merge_kwargs["merged_model_name"]
    print(f'>> Models merged into new model: "{new_name}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
    """
    Run a merge from command-line options only.

    args - parsed options; alpha, models, merged_model_name and clobber are
           checked here, then every option is forwarded to
           merge_diffusion_models_and_commit()

    Raises AssertionError when alpha is outside [0, 1], when fewer than two or
    more than three models are given, or when the output name already exists
    and --clobber was not passed.
    """
    assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
    # A merge needs at least two inputs; the previous lower bound of one
    # contradicted both this message and the --models help text.
    assert (
        args.models and len(args.models) >= 2 and len(args.models) <= 3
    ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.models)
        print(
            f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
        )

    model_manager = ModelManager(OmegaConf.load(global_config_file()))
    assert (
        args.clobber or args.merged_model_name not in model_manager.model_names()
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

    merge_diffusion_models_and_commit(**vars(args))
    print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point for model merging.

    Parses the options, sets the InvokeAI root, points HF_HOME at the
    diffusers cache, then runs either the console form or the CLI merge.
    Exits with status -1 on any error or on Ctrl-C.
    """
    args = _parse_args()
    global_set_root(args.root_dir)

    diffusers_cache = str(global_cache_dir("diffusers"))
    # because not clear the merge pipeline is honoring cache_dir
    os.environ["HF_HOME"] = diffusers_cache
    args.cache_dir = diffusers_cache

    runner = run_gui if args.front_end else run_cli
    try:
        runner(args)
    except widget.NotEnoughSpaceForWidget as e:
        too_few_models = str(e).startswith("Height of 1 allocated")
        if too_few_models:
            complaint = "** You need to have at least two diffusers models defined in models.yaml in order to merge"
        else:
            complaint = "** Not enough room for the user interface. Try making this window larger."
        print(complaint)
        sys.exit(-1)
    except Exception:
        print(">> An error occurred:")
        traceback.print_exc()
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)
|
||||
|
||||
|
||||
# Run the merge tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
5
invokeai/frontend/training/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
'''
|
||||
Initialization file for invokeai.frontend.training
|
||||
'''
|
||||
from .textual_inversion import main as invokeai_textual_inversion
|
||||
|
461
invokeai/frontend/training/textual_inversion.py
Executable file
@ -0,0 +1,461 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
This is the frontend to "textual_inversion_training.py".
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.backend.globals import Globals, global_set_root
|
||||
from ...backend.training import (
|
||||
do_textual_inversion_training,
|
||||
parse_args,
|
||||
)
|
||||
|
||||
TRAINING_DATA = "text-inversion-training-data"
|
||||
TRAINING_DIR = "text-inversion-output"
|
||||
CONF_FILE = "preferences.conf"
|
||||
|
||||
|
||||
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
resolutions = [512, 768, 1024]
|
||||
lr_schedulers = [
|
||||
"linear",
|
||||
"cosine",
|
||||
"cosine_with_restarts",
|
||||
"polynomial",
|
||||
"constant",
|
||||
"constant_with_warmup",
|
||||
]
|
||||
precisions = ["no", "fp16", "bf16"]
|
||||
learnable_properties = ["object", "style"]
|
||||
|
||||
def __init__(self, parentApp, name, saved_args=None):
    """
    parentApp  - the owning application, passed to the superclass
    name       - form title, passed to the superclass
    saved_args - optional mapping of previously used values; stored as an
                 empty dict when omitted or falsy
    """
    # Must be set before super().__init__().
    # NOTE(review): presumably because the superclass constructor calls create(), which reads it -- confirm.
    self.saved_args = saved_args or {}
    super().__init__(parentApp, name)
|
||||
|
||||
def afterEditing(self):
    """Tell the application there is no form to show after this one."""
    self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
    """
    Build all widgets of the textual-inversion form.

    Initial values come from self.saved_args where a key is present and
    from the hard-coded defaults below otherwise.
    """
    self.model_names, default = self.get_model_names()
    default_initializer_token = "★"
    default_placeholder_token = ""
    saved_args = self.saved_args

    # Prefer the previously used model when it is still available.
    try:
        default = self.model_names.index(saved_args["model"])
    except Exception:
        # Best effort: keep the default from get_model_names(). This was a
        # bare ``except:``, which also swallowed KeyboardInterrupt and
        # SystemExit. The type of saved_args is not visible here, so the
        # clause is not narrowed further.
        pass

    self.add_widget_intelligent(
        npyscreen.FixedText,
        value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
        editable=False,
    )

    self.model = self.add_widget_intelligent(
        npyscreen.TitleSelectOne,
        name="Model Name:",
        values=self.model_names,
        value=default,
        max_height=len(self.model_names) + 1,
        scroll_exit=True,
    )
    self.placeholder_token = self.add_widget_intelligent(
        npyscreen.TitleText,
        name="Trigger Term:",
        value="",  # saved_args.get('placeholder_token',''), # to restore previous term
        scroll_exit=True,
    )
    self.placeholder_token.when_value_edited = self.initializer_changed
    # Place the hint text on the same row as the trigger term, 30 columns in.
    self.nextrely -= 1
    self.nextrelx += 30
    self.prompt_token = self.add_widget_intelligent(
        npyscreen.FixedText,
        name="Trigger term for use in prompt",
        value="",
        editable=False,
        scroll_exit=True,
    )
    self.nextrelx -= 30
    self.initializer_token = self.add_widget_intelligent(
        npyscreen.TitleText,
        name="Initializer:",
        value=saved_args.get("initializer_token", default_initializer_token),
        scroll_exit=True,
    )
    self.resume_from_checkpoint = self.add_widget_intelligent(
        npyscreen.Checkbox,
        name="Resume from last saved checkpoint",
        value=False,
        scroll_exit=True,
    )
    self.learnable_property = self.add_widget_intelligent(
        npyscreen.TitleSelectOne,
        name="Learnable property:",
        values=self.learnable_properties,
        value=self.learnable_properties.index(
            saved_args.get("learnable_property", "object")
        ),
        max_height=4,
        scroll_exit=True,
    )
    self.train_data_dir = self.add_widget_intelligent(
        npyscreen.TitleFilename,
        name="Data Training Directory:",
        select_dir=True,
        must_exist=False,
        value=str(
            saved_args.get(
                "train_data_dir",
                Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
            )
        ),
        scroll_exit=True,
    )
    self.output_dir = self.add_widget_intelligent(
        npyscreen.TitleFilename,
        name="Output Destination Directory:",
        select_dir=True,
        must_exist=False,
        value=str(
            saved_args.get(
                "output_dir",
                Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
            )
        ),
        scroll_exit=True,
    )
    self.resolution = self.add_widget_intelligent(
        npyscreen.TitleSelectOne,
        name="Image resolution (pixels):",
        values=self.resolutions,
        value=self.resolutions.index(saved_args.get("resolution", 512)),
        max_height=4,
        scroll_exit=True,
    )
    self.center_crop = self.add_widget_intelligent(
        npyscreen.Checkbox,
        name="Center crop images before resizing to resolution",
        value=saved_args.get("center_crop", False),
        scroll_exit=True,
    )
    self.mixed_precision = self.add_widget_intelligent(
        npyscreen.TitleSelectOne,
        name="Mixed Precision:",
        values=self.precisions,
        value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
        max_height=4,
        scroll_exit=True,
    )
    self.num_train_epochs = self.add_widget_intelligent(
        npyscreen.TitleSlider,
        name="Number of training epochs:",
        out_of=1000,
        step=50,
        lowest=1,
        value=saved_args.get("num_train_epochs", 100),
        scroll_exit=True,
    )
    self.max_train_steps = self.add_widget_intelligent(
        npyscreen.TitleSlider,
        name="Max Training Steps:",
        out_of=10000,
        step=500,
        lowest=1,
        value=saved_args.get("max_train_steps", 3000),
        scroll_exit=True,
    )
    self.train_batch_size = self.add_widget_intelligent(
        npyscreen.TitleSlider,
        name="Batch Size (reduce if you run out of memory):",
        out_of=50,
        step=1,
        lowest=1,
        value=saved_args.get("train_batch_size", 8),
        scroll_exit=True,
    )
    self.gradient_accumulation_steps = self.add_widget_intelligent(
        npyscreen.TitleSlider,
        name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
        out_of=10,
        step=1,
        lowest=1,
        value=saved_args.get("gradient_accumulation_steps", 4),
        scroll_exit=True,
    )
    self.lr_warmup_steps = self.add_widget_intelligent(
        npyscreen.TitleSlider,
        name="Warmup Steps:",
        out_of=100,
        step=1,
        lowest=0,
        value=saved_args.get("lr_warmup_steps", 0),
        scroll_exit=True,
    )
    self.learning_rate = self.add_widget_intelligent(
        npyscreen.TitleText,
        name="Learning Rate:",
        value=str(
            saved_args.get("learning_rate", "5.0e-04"),
        ),
        scroll_exit=True,
    )
    self.scale_lr = self.add_widget_intelligent(
        npyscreen.Checkbox,
        name="Scale learning rate by number GPUs, steps and batch size",
        value=saved_args.get("scale_lr", True),
        scroll_exit=True,
    )
    self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
        npyscreen.Checkbox,
        name="Use xformers acceleration",
        value=saved_args.get("enable_xformers_memory_efficient_attention", False),
        scroll_exit=True,
    )
    self.lr_scheduler = self.add_widget_intelligent(
        npyscreen.TitleSelectOne,
        name="Learning rate scheduler:",
        values=self.lr_schedulers,
        max_height=7,
        value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
        scroll_exit=True,
    )
    self.model.editing = True
|
||||
|
||||
def initializer_changed(self):
    """Refresh the fields that are derived from the trigger term.

    Reads the current placeholder token, rewrites the prompt hint, the
    training data directory and the output directory from it, and sets the
    resume-from-checkpoint widget to whether that output directory exists.
    """
    token = self.placeholder_token.value
    self.prompt_token.value = f"(Trigger by using <{token}> in your prompts)"

    root = Path(Globals.root)
    self.train_data_dir.value = str(root / TRAINING_DATA / token)
    self.output_dir.value = str(root / TRAINING_DIR / token)

    # Read the directory back from the widget so the check uses the value
    # the widget actually holds.
    self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||
|
||||
def on_ok(self):
    """Handle the OK button.

    If validation fails the form stays in editing mode. Otherwise the next
    form is cleared, editing stops, the marshalled arguments are stored on
    the parent application and a launch notice is displayed.
    """
    if not self.validate_field_values():
        self.editing = True
        return

    app = self.parentApp
    app.setNextForm(None)
    self.editing = False
    app.ti_arguments = self.marshall_arguments()
    npyscreen.notify(
        "Launching textual inversion training. This will take a while..."
    )
|
||||
|
||||
def ok_cancel(self):
    """Handle the Cancel button by terminating the process with status 0."""
    raise SystemExit(0)
|
||||
|
||||
def validate_field_values(self) -> bool:
    """Check the form fields and report every problem found.

    Returns:
        True when all checked fields are acceptable. Otherwise shows a
        confirmation dialog listing each problem and returns False.

    The emptiness checks treat ``None``, ``""`` and whitespace-only text as
    empty, and an empty model selection is rejected, because the selected
    index is later read with ``value[0]`` when the arguments are marshalled.
    """

    def is_blank(value) -> bool:
        # Text widgets can hold "" rather than None when nothing was typed.
        return value is None or not str(value).strip()

    bad_fields = []

    chosen_model = self.model.value
    if chosen_model is None or (
        isinstance(chosen_model, (list, tuple)) and not chosen_model
    ):
        bad_fields.append(
            "Model Name must correspond to a known model in models.yaml"
        )

    # fullmatch avoids "$" accepting a trailing newline; "or ''" keeps a
    # missing value from raising inside the regex call.
    if not re.fullmatch(r"[a-zA-Z0-9.-]+", self.placeholder_token.value or ""):
        bad_fields.append(
            "Trigger term must only contain alphanumeric characters, the dot and hyphen"
        )

    if is_blank(self.train_data_dir.value):
        bad_fields.append("Data Training Directory cannot be empty")

    if is_blank(self.output_dir.value):
        bad_fields.append("The Output Destination Directory cannot be empty")

    if not bad_fields:
        return True

    message = "The following problems were detected and must be corrected:"
    message += "".join(f"\n* {problem}" for problem in bad_fields)
    npyscreen.notify_confirm(message)
    return False
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
    """List the diffusers-format models and pick the default index.

    Loads ``configs/models.yaml`` under ``Globals.root`` and keeps, in sorted
    order, the keys whose entry has ``format == "diffusers"``.

    Returns:
        A tuple ``(model_names, default)`` where ``default`` is the position
        of the first kept model whose entry contains a ``default`` key, or 0
        when no entry has one.
    """
    config_path = os.path.join(Globals.root, "configs/models.yaml")
    conf = OmegaConf.load(config_path)

    model_names = []
    for key in sorted(list(conf.keys())):
        if conf[key].get("format", None) == "diffusers":
            model_names.append(key)

    default = next(
        (pos for pos, key in enumerate(model_names) if "default" in conf[key]),
        0,
    )
    return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
    """Collect the widget values into a dictionary of training arguments.

    Selector widgets contribute the list entry at their first selected
    index, text and checkbox widgets contribute their raw value, the
    listed numeric widgets are converted with ``int`` and the learning rate
    with ``float``. ``resume_from_checkpoint`` is only added (as
    ``"latest"``) when the box is ticked and the output directory exists.
    """
    # the choices
    args = {
        "model": self.model_names[self.model.value[0]],
        "resolution": self.resolutions[self.resolution.value[0]],
        "lr_scheduler": self.lr_schedulers[self.lr_scheduler.value[0]],
        "mixed_precision": self.precisions[self.mixed_precision.value[0]],
        "learnable_property": self.learnable_properties[
            self.learnable_property.value[0]
        ],
    }

    # all the strings and booleans
    passthrough = (
        "initializer_token",
        "placeholder_token",
        "train_data_dir",
        "output_dir",
        "scale_lr",
        "center_crop",
        "enable_xformers_memory_efficient_attention",
    )
    args.update({name: getattr(self, name).value for name in passthrough})

    # all the integers
    integral = (
        "train_batch_size",
        "gradient_accumulation_steps",
        "num_train_epochs",
        "max_train_steps",
        "lr_warmup_steps",
    )
    args.update({name: int(getattr(self, name).value) for name in integral})

    # the floats (just one)
    args["learning_rate"] = float(self.learning_rate.value)

    # a special case
    wants_resume = self.resume_from_checkpoint.value
    if wants_resume and Path(self.output_dir.value).exists():
        args["resume_from_checkpoint"] = "latest"

    return args
|
||||
|
||||
|
||||
class MyApplication(npyscreen.NPSAppManaged):
    """npyscreen application that hosts the textual inversion settings form."""

    def __init__(self, saved_args=None):
        """Remember previously saved arguments and clear the result slot.

        ``ti_arguments`` starts as ``None``; the form's OK handler assigns
        the marshalled arguments to it.
        """
        super().__init__()
        self.saved_args = saved_args
        self.ti_arguments = None

    def onStart(self):
        """Apply the default theme and register the settings form as MAIN."""
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        form_options = {
            "name": "Textual Inversion Settings",
            "saved_args": self.saved_args,
        }
        self.main = self.addForm("MAIN", textualInversionForm, **form_options)
|
||||
|
||||
|
||||
def copy_to_embeddings_folder(args: dict):
    """
    Copy learned_embeds.bin from the output directory into the embeddings
    folder, then ask whether to delete the output directory (training logs
    and intermediate checkpoints). An empty answer counts as "y".
    """
    output_dir = args["output_dir"]
    source = Path(output_dir, "learned_embeds.bin")

    # The embeddings sub-folder is named after the trigger without brackets.
    folder_name = args["placeholder_token"].strip("<>")
    destination = Path(Globals.root, "embeddings", folder_name)
    os.makedirs(destination, exist_ok=True)

    print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
    shutil.copy(source, destination)

    reply = input("Delete training logs and intermediate checkpoints? [y] ") or "y"
    if not reply.startswith(("y", "Y")):
        print(f">> Keeping {output_dir}")
        return
    shutil.rmtree(Path(output_dir))
|
||||
|
||||
|
||||
def save_args(args: dict):
    """
    Write the current argument values to an omegaconf file named CONF_FILE
    inside the training directory under Globals.root, creating that
    directory when needed.
    """
    target_dir = Path(Globals.root) / TRAINING_DIR
    os.makedirs(target_dir, exist_ok=True)
    OmegaConf.save(config=OmegaConf.create(args), f=target_dir / CONF_FILE)
|
||||
|
||||
|
||||
def previous_args() -> dict:
    """
    Get the previous arguments used.

    Loads CONF_FILE from the training directory under Globals.root and
    strips any angle brackets from its ``placeholder_token`` entry.

    Returns the loaded configuration, or None when the file cannot be
    loaded or its ``placeholder_token`` cannot be normalised.
    """
    conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
    try:
        conf = OmegaConf.load(conf_file)
        conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
    except Exception:
        # Best effort: any load/parse problem simply means "no saved
        # arguments". Catching Exception rather than using a bare except
        # lets KeyboardInterrupt and SystemExit propagate.
        conf = None

    return conf
|
||||
|
||||
|
||||
def do_front_end(args: Namespace):
    """Run the settings form and, if it produced arguments, start training.

    The incoming ``args`` are not consulted; the training arguments come
    from the form. They are saved before training starts, and after a
    successful run the learned embedding is copied to the embeddings
    folder. Exceptions raised by training or copying are printed together
    with their traceback rather than propagated.
    """
    app = MyApplication(saved_args=previous_args())
    app.run()

    ti_args = app.ti_arguments
    if not ti_args:
        return

    os.makedirs(ti_args["output_dir"], exist_ok=True)

    # Automatically add angle brackets around the trigger
    token = ti_args["placeholder_token"]
    if re.match("^<.+>$", token) is None:
        ti_args["placeholder_token"] = f"<{token}>"

    ti_args["only_save_embeds"] = True
    save_args(ti_args)

    try:
        do_textual_inversion_training(**ti_args)
        copy_to_embeddings_folder(ti_args)
    except Exception as e:
        report = (
            "** An exception occurred during training. The exception was:",
            str(e),
            "** DETAILS:",
            traceback.format_exc(),
        )
        for line in report:
            print(line)
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run either the front end or training.

    Assertion failures print their message and exit with status -1, a
    keyboard interrupt is ignored, and any other exception prints a
    message chosen from the start of the exception text and exits with
    status -1.
    """
    args = parse_args()
    global_set_root(args.root_dir or Globals.root)
    try:
        if not args.front_end:
            do_textual_inversion_training(**vars(args))
        else:
            do_front_end(args)
    except AssertionError as e:
        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        pass
    except (widget.NotEnoughSpaceForWidget, Exception) as e:
        detail = str(e)
        if detail.startswith("Height of 1 allocated"):
            message = "** You need to have at least one diffusers models defined in models.yaml in order to train"
        elif detail.startswith("addwstr"):
            message = "** Not enough window space for the interface. Please make your window larger and try again."
        else:
            message = f"** An error has occurred: {str(e)}"
        print(message)
        sys.exit(-1)
|
||||
|
||||
|
||||
# Script entry point: only run main() when this file is executed directly.
if __name__ == "__main__":
    main()
|