Final phase of source tree restructure (#2833)

All Python code has been moved under `invokeai`. All vestiges of `ldm`
and `ldm.invoke` are now gone.

***You will need to run `pip install -e .` before the code will work
again!***

Everything seems to be functional, but extensive testing is advised.

A guide to where the files have gone is forthcoming.
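
In the meantime, the diffs below already show the general pattern: anything that used to be imported from `ldm` or `ldm.invoke` now comes from a subpackage of `invokeai`. A minimal before/after sketch, using only module paths that appear in the changes on this page (the final print line is purely illustrative):

```python
# Old locations (removed by this commit):
#   from ldm.invoke.globals import Globals
#   from ldm.invoke import __version__
#   from ldm.invoke.config import invokeai_configure

# New locations under the invokeai package:
from invokeai.backend.globals import Globals
from invokeai.version import __version__
from invokeai.frontend.install import invokeai_configure

# Illustrative only: confirm the editable install picked up the new layout
print(f"InvokeAI {__version__}, runtime root: {Globals.root}")
```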
Commit b3dccfaeb6 by Lincoln Stein, 2023-03-03 15:05:41 -05:00 (committed by GitHub)
585 changed files with 9987 additions and 10905 deletions

.github/CODEOWNERS

@@ -1,51 +1,34 @@
 # continuous integration
-/.github/workflows/ @mauwii @lstein @blessedcoolant
+/.github/workflows/ @mauwii @lstein

 # documentation
-/docs/ @lstein @mauwii @tildebyte @blessedcoolant
-mkdocs.yml @lstein @mauwii @blessedcoolant
+/docs/ @lstein @mauwii @tildebyte
+/mkdocs.yml @lstein @mauwii
+
+# nodes
+/invokeai/app/ @Kyle0654 @blessedcoolant

 # installation and configuration
-/pyproject.toml @mauwii @lstein @ebr @blessedcoolant
-/docker/ @mauwii @lstein @blessedcoolant
-/scripts/ @ebr @lstein @blessedcoolant
-/installer/ @ebr @lstein @tildebyte @blessedcoolant
-ldm/invoke/config @lstein @ebr @blessedcoolant
-invokeai/assets @lstein @ebr @blessedcoolant
-invokeai/configs @lstein @ebr @blessedcoolant
-/ldm/invoke/_version.py @lstein @blessedcoolant
+/pyproject.toml @mauwii @lstein @blessedcoolant
+/docker/ @mauwii @lstein
+/scripts/ @ebr @lstein
+/installer/ @lstein @ebr
+/invokeai/assets @lstein @ebr
+/invokeai/configs @lstein
+/invokeai/version @lstein @blessedcoolant

 # web ui
 /invokeai/frontend @blessedcoolant @psychedelicious @lstein
 /invokeai/backend @blessedcoolant @psychedelicious @lstein

-# generation and model management
-/ldm/*.py @lstein @blessedcoolant
-/ldm/generate.py @lstein @keturn @blessedcoolant
-/ldm/invoke/args.py @lstein @blessedcoolant
-/ldm/invoke/ckpt* @lstein @blessedcoolant
-/ldm/invoke/ckpt_generator @lstein @blessedcoolant
-/ldm/invoke/CLI.py @lstein @blessedcoolant
-/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
-/ldm/invoke/generator @keturn @damian0815 @blessedcoolant
-/ldm/invoke/globals.py @lstein @blessedcoolant
-/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
-/ldm/invoke/model_manager.py @lstein @blessedcoolant
-/ldm/invoke/txt2mask.py @lstein @blessedcoolant
-/ldm/invoke/patchmatch.py @Kyle0654 @blessedcoolant @lstein
-/ldm/invoke/restoration @lstein @blessedcoolant
+# generation, model management, postprocessing
+/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto

-# attention, textual inversion, model configuration
-/ldm/models @damian0815 @keturn @lstein @blessedcoolant
-/ldm/modules @damian0815 @keturn @lstein @blessedcoolant
+# front ends
+/invokeai/frontend/CLI @lstein
+/invokeai/frontend/install @lstein @ebr @mauwii
+/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
+/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
+/invokeai/frontend/web @psychedelicious @blessedcoolant

-# Nodes
-apps/ @Kyle0654 @lstein @blessedcoolant
-
-# legacy REST API
-# is CapableWeb still engaged?
-/ldm/invoke/pngwriter.py @CapableWeb @lstein @blessedcoolant
-/ldm/invoke/server_legacy.py @CapableWeb @lstein @blessedcoolant
-/scripts/legacy_api.py @CapableWeb @lstein @blessedcoolant
-/tests/legacy_tests.sh @CapableWeb @lstein @blessedcoolant


@@ -9,7 +9,7 @@ on:
       - 'dev/docker/*'
     paths:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
       - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'


@@ -3,14 +3,14 @@ name: Lint frontend
 on:
   pull_request:
     paths:
-      - 'invokeai/frontend/**'
+      - 'invokeai/frontend/web/**'
   push:
     paths:
-      - 'invokeai/frontend/**'
+      - 'invokeai/frontend/web/**'
 defaults:
   run:
-    working-directory: invokeai/frontend
+    working-directory: invokeai/frontend/web
 jobs:
   lint-frontend:


@@ -3,7 +3,7 @@ name: PyPI Release
 on:
   push:
     paths:
-      - 'ldm/invoke/_version.py'
+      - 'invokeai/version/invokeai_version.py'
   workflow_dispatch:
 jobs:


@@ -3,7 +3,7 @@ on:
   pull_request:
     paths-ignore:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
      - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'


@@ -5,14 +5,14 @@ on:
       - 'main'
     paths:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
       - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'
   pull_request:
     paths:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
       - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'
@@ -112,7 +112,7 @@ jobs:
       - name: set INVOKEAI_OUTDIR
         run: >
           python -c
-          "import os;from ldm.invoke.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
+          "import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
           >> ${{ matrix.github-env }}
       - name: run invokeai-configure

.gitignore

@@ -198,7 +198,7 @@ checkpoints
 .DS_Store

 # Let the frontend manage its own gitignore
-!invokeai/frontend/*
+!invokeai/frontend/web/*

 # Scratch folder
 .scratch/
@@ -213,11 +213,6 @@ gfpgan/
 # config file (will be created by installer)
 configs/models.yaml

-# weights (will be created by installer)
-models/ldm/stable-diffusion-v1/*.ckpt
-models/clipseg
-models/gfpgan
-
 # ignore initfile
 .invokeai
@@ -232,6 +227,3 @@ installer/install.bat
 installer/install.sh
 installer/update.bat
 installer/update.sh
-
-# no longer stored in source directory
-models


@@ -11,10 +11,10 @@ if [[ -v "VIRTUAL_ENV" ]]; then
     exit -1
 fi

-VERSION=$(cd ..; python -c "from ldm.invoke import __version__ as version; print(version)")
+VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
 PATCH=""
 VERSION="v${VERSION}${PATCH}"
-LATEST_TAG="v2.3-latest"
+LATEST_TAG="v3.0-latest"

 echo Building installer for version $VERSION
 echo "Be certain that you're in the 'installer' directory before continuing."


@@ -291,7 +291,7 @@ class InvokeAiInstance:
             src = Path(__file__).parents[1].expanduser().resolve()
             # if the above directory contains one of these files, we'll do a source install
             next(src.glob("pyproject.toml"))
-            next(src.glob("ldm"))
+            next(src.glob("invokeai"))
         except StopIteration:
             print("Unable to find a wheel or perform a source install. Giving up.")
@@ -342,14 +342,14 @@ class InvokeAiInstance:
         introduction()

-        from ldm.invoke.config import invokeai_configure
+        from invokeai.frontend.install import invokeai_configure

         # NOTE: currently the config script does its own arg parsing! this means the command-line switches
         # from the installer will also automatically propagate down to the config script.
         # this may change in the future with config refactoring!
         succeeded = False
         try:
-            invokeai_configure.main()
+            invokeai_configure()
             succeeded = True
         except requests.exceptions.ConnectionError as e:
             print(f'\nA network error was encountered during configuration and download: {str(e)}')


@@ -1,3 +1,11 @@
-After version 2.3 is released, the ldm/invoke modules will be migrated to this location
-so that we have a proper invokeai distribution. Currently it is only being used for
-data files.
+Organization of the source tree:
+
+app      -- Home of nodes invocations and services
+assets   -- Images and other data files used by InvokeAI
+backend  -- Non-user facing libraries, including the rendering core
+configs  -- Configuration files used at install and run times
+frontend -- User-facing scripts, including the CLI and the WebUI
+version  -- Current InvokeAI version string, stored in version/invokeai_version.py
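
As a quick illustration of the layering above, non-user-facing code now lives under `invokeai.backend`; the snippet below reproduces, by hand, the same output-directory computation that the `set INVOKEAI_OUTDIR` workflow step performs after this commit (a sketch, not part of the change set):

```python
import os

from invokeai.backend.globals import Globals  # replaces the old ldm.invoke.globals

# Same expression the CI step evaluates to locate the outputs folder
outdir = os.path.join(Globals.root, "outputs")
print(f"INVOKEAI_OUTDIR={outdir}")
```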


@ -1,33 +1,31 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from argparse import Namespace
import os import os
from argparse import Namespace
from ..services.processor import DefaultInvocationProcessor
from ..services.graph import GraphExecutionState
from ..services.sqlite import SqliteItemStorage
from ...globals import Globals from ...globals import Globals
from ..services.generate_initializer import get_generate
from ..services.graph import GraphExecutionState
from ..services.image_storage import DiskImageStorage from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.generate_initializer import get_generate from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from .events import FastAPIEventService from .events import FastAPIEventService
# TODO: is there a better way to achieve this? # TODO: is there a better way to achieve this?
def check_internet()->bool: def check_internet() -> bool:
''' """
Return true if the internet is reachable. Return true if the internet is reachable.
It does this by pinging huggingface.co. It does this by pinging huggingface.co.
''' """
import urllib.request import urllib.request
host = 'http://huggingface.co'
host = "http://huggingface.co"
try: try:
urllib.request.urlopen(host,timeout=1) urllib.request.urlopen(host, timeout=1)
return True return True
except: except:
return False return False
@ -35,14 +33,11 @@ def check_internet()->bool:
class ApiDependencies: class ApiDependencies:
"""Contains and initializes all dependencies for the API""" """Contains and initializes all dependencies for the API"""
invoker: Invoker = None invoker: Invoker = None
@staticmethod @staticmethod
def initialize( def initialize(args, config, event_handler_id: int):
args,
config,
event_handler_id: int
):
Globals.try_patchmatch = args.patchmatch Globals.try_patchmatch = args.patchmatch
Globals.always_use_cpu = args.always_use_cpu Globals.always_use_cpu = args.always_use_cpu
Globals.internet_available = args.internet_available and check_internet() Globals.internet_available = args.internet_available and check_internet()
@ -50,30 +45,34 @@ class ApiDependencies:
Globals.ckpt_convert = args.ckpt_convert Globals.ckpt_convert = args.ckpt_convert
# TODO: Use a logger # TODO: Use a logger
print(f'>> Internet connectivity is {Globals.internet_available}') print(f">> Internet connectivity is {Globals.internet_available}")
generate = get_generate(args, config) generate = get_generate(args, config)
events = FastAPIEventService(event_handler_id) events = FastAPIEventService(event_handler_id)
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs')) output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
images = DiskImageStorage(output_folder) images = DiskImageStorage(output_folder)
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db') db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices( services = InvocationServices(
generate = generate, generate=generate,
events = events, events=events,
images = images, images=images,
queue = MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), graph_execution_manager=SqliteItemStorage[GraphExecutionState](
processor = DefaultInvocationProcessor() filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
) )
ApiDependencies.invoker = Invoker(services) ApiDependencies.invoker = Invoker(services)
@staticmethod @staticmethod
def shutdown(): def shutdown():
if ApiDependencies.invoker: if ApiDependencies.invoker:


@ -1,11 +1,14 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio import asyncio
import threading
from queue import Empty, Queue from queue import Empty, Queue
from typing import Any from typing import Any
from fastapi_events.dispatcher import dispatch from fastapi_events.dispatcher import dispatch
from ..services.events import EventServiceBase from ..services.events import EventServiceBase
import threading
class FastAPIEventService(EventServiceBase): class FastAPIEventService(EventServiceBase):
event_handler_id: int event_handler_id: int
@ -16,39 +19,34 @@ class FastAPIEventService(EventServiceBase):
self.event_handler_id = event_handler_id self.event_handler_id = event_handler_id
self.__queue = Queue() self.__queue = Queue()
self.__stop_event = threading.Event() self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event)) asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__() super().__init__()
def stop(self, *args, **kwargs): def stop(self, *args, **kwargs):
self.__stop_event.set() self.__stop_event.set()
self.__queue.put(None) self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put(dict( self.__queue.put(dict(event_name=event_name, payload=payload))
event_name = event_name,
payload = payload
))
async def __dispatch_from_queue(self, stop_event: threading.Event): async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread""" """Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
event = self.__queue.get(block = False) event = self.__queue.get(block=False)
if not event: # Probably stopping if not event: # Probably stopping
continue continue
dispatch( dispatch(
event.get('event_name'), event.get("event_name"),
payload = event.get('payload'), payload=event.get("payload"),
middleware_id = self.event_handler_id) middleware_id=self.event_handler_id,
)
except Empty: except Empty:
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
pass pass
except asyncio.CancelledError as e: except asyncio.CancelledError as e:
raise e # Raise a proper error raise e # Raise a proper error


@ -0,0 +1,56 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from fastapi import Path, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
404: {"description": "Session not found"},
},
)
async def upload_image(file: UploadFile, request: Request):
if not file.content_type.startswith("image"):
return Response(status_code=415)
contents = await file.read()
try:
im = Image.open(contents)
except:
# Error opening the image
return Response(status_code=415)
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
return Response(
status_code=201,
headers={
"Location": request.url_for(
"get_image", image_type=ImageType.UPLOAD, image_name=filename
)
},
)


@ -0,0 +1,271 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, List, Optional, Union
from fastapi import Body, Path, Query
from fastapi.responses import Response
from fastapi.routing import APIRouter
from pydantic.fields import Field
from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
@session_router.post(
"/",
operation_id="create_session",
responses={
200: {"model": GraphExecutionState},
400: {"description": "Invalid json"},
},
)
async def create_session(
graph: Optional[Graph] = Body(
default=None, description="The graph to initialize the session with"
)
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
return session
@session_router.get(
"/",
operation_id="list_sessions",
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
)
async def list_sessions(
page: int = Query(default=0, description="The page of results to get"),
per_page: int = Query(default=10, description="The number of results per page"),
query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if filter == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(
page, per_page
)
else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(
query, page, per_page
)
return result
@session_router.get(
"/{session_id}",
operation_id="get_session",
responses={
200: {"model": GraphExecutionState},
404: {"description": "Session not found"},
},
)
async def get_session(
session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code=404)
else:
return session
@session_router.post(
"/{session_id}/nodes",
operation_id="add_node",
responses={
200: {"model": str},
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
)
async def add_node(
session_id: str = Path(description="The id of the session"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
] = Body(description="The node to add"),
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code=404)
try:
session.add_node(node)
ApiDependencies.invoker.services.graph_execution_manager.set(
session
) # TODO: can this be done automatically, or add node through an API?
return session.id
except NodeAlreadyExecutedError:
return Response(status_code=400)
except IndexError:
return Response(status_code=400)
@session_router.put(
"/{session_id}/nodes/{node_path}",
operation_id="update_node",
responses={
200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
)
async def update_node(
session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
] = Body(description="The new node"),
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code=404)
try:
session.update_node(node_path, node)
ApiDependencies.invoker.services.graph_execution_manager.set(
session
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code=400)
except IndexError:
return Response(status_code=400)
@session_router.delete(
"/{session_id}/nodes/{node_path}",
operation_id="delete_node",
responses={
200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
)
async def delete_node(
session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node to delete"),
) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code=404)
try:
session.delete_node(node_path)
ApiDependencies.invoker.services.graph_execution_manager.set(
session
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code=400)
except IndexError:
return Response(status_code=400)
@session_router.post(
"/{session_id}/edges",
operation_id="add_edge",
responses={
200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
)
async def add_edge(
session_id: str = Path(description="The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code=404)
try:
session.add_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(
session
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code=400)
except IndexError:
return Response(status_code=400)
# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete(
"/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
operation_id="delete_edge",
responses={
200: {"model": GraphExecutionState},
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
)
async def delete_edge(
session_id: str = Path(description="The id of the session"),
from_node_id: str = Path(description="The id of the node the edge is coming from"),
from_field: str = Path(description="The field of the node the edge is coming from"),
to_node_id: str = Path(description="The id of the node the edge is going to"),
to_field: str = Path(description="The field of the node the edge is going to"),
) -> GraphExecutionState:
"""Deletes an edge from the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code=404)
try:
edge = (
EdgeConnection(node_id=from_node_id, field=from_field),
EdgeConnection(node_id=to_node_id, field=to_field),
)
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(
session
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code=400)
except IndexError:
return Response(status_code=400)
@session_router.put(
"/{session_id}/invoke",
operation_id="invoke_session",
responses={
200: {"model": None},
202: {"description": "The invocation is queued"},
400: {"description": "The session has no invocations ready to invoke"},
404: {"description": "Session not found"},
},
)
async def invoke_session(
session_id: str = Path(description="The id of the session to invoke"),
all: bool = Query(
default=False, description="Whether or not to invoke all remaining invocations"
),
) -> None:
"""Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code=404)
if session.is_complete():
return Response(status_code=400)
ApiDependencies.invoker.invoke(session, invoke_all=all)
return Response(status_code=202)
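
The sessions router above is the REST surface for the new nodes backend. Below is a hypothetical client sketch against it, assuming the server started by `invoke_api()` (defined further down) is listening on its default 0.0.0.0:9090 and that the `requests` package is installed; the `id` field name follows the CLI code's use of `session.id`, and the actual payload shapes depend on which invocation nodes are available:

```python
import requests

# "/api" prefix is added by api_app's include_router call; "/v1/sessions" is the router's own prefix
BASE = "http://localhost:9090/api/v1/sessions"

# Create an empty session (the graph body is optional)
session = requests.post(f"{BASE}/", json=None).json()
session_id = session["id"]

# Queue every invocation that is currently ready to run; 202 means it was accepted
resp = requests.put(f"{BASE}/{session_id}/invoke", params={"all": "true"})
print(resp.status_code)
```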


@ -1,36 +1,38 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import FastAPI from fastapi import FastAPI
from fastapi_socketio import SocketManager
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event from fastapi_events.typing import Event
from fastapi_socketio import SocketManager
from ..services.events import EventServiceBase from ..services.events import EventServiceBase
class SocketIO: class SocketIO:
__sio: SocketManager __sio: SocketManager
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = SocketManager(app = app) self.__sio = SocketManager(app=app)
self.__sio.on('subscribe', handler=self._handle_sub) self.__sio.on("subscribe", handler=self._handle_sub)
self.__sio.on('unsubscribe', handler=self._handle_unsub) self.__sio.on("unsubscribe", handler=self._handle_unsub)
local_handler.register( local_handler.register(
event_name = EventServiceBase.session_event, event_name=EventServiceBase.session_event, _func=self._handle_session_event
_func=self._handle_session_event
) )
async def _handle_session_event(self, event: Event): async def _handle_session_event(self, event: Event):
await self.__sio.emit( await self.__sio.emit(
event = event[1]['event'], event=event[1]["event"],
data = event[1]['data'], data=event[1]["data"],
room = event[1]['data']['graph_execution_state_id'] room=event[1]["data"]["graph_execution_state_id"],
) )
async def _handle_sub(self, sid, data, *args, **kwargs): async def _handle_sub(self, sid, data, *args, **kwargs):
if 'session' in data: if "session" in data:
self.__sio.enter_room(sid, data['session']) self.__sio.enter_room(sid, data["session"])
# @app.sio.on('unsubscribe') # @app.sio.on('unsubscribe')
async def _handle_unsub(self, sid, data, *args, **kwargs): async def _handle_unsub(self, sid, data, *args, **kwargs):
if 'session' in data: if "session" in data:
self.__sio.leave_room(sid, data['session']) self.__sio.leave_room(sid, data["session"])


@ -2,36 +2,37 @@
import asyncio import asyncio
from inspect import signature from inspect import signature
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from fastapi.staticfiles import StaticFiles
from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.handlers.local import local_handler
from fastapi.middleware.cors import CORSMiddleware
from pydantic.schema import schema
import uvicorn import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from ..args import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations import * from .invocations import *
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .api.routers import images, sessions
from .api.dependencies import ApiDependencies
from ..args import Args
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI( app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
title = "Invoke AI",
docs_url = None,
redoc_url = None
)
# Add event handler # Add event handler
event_handler_id: int = id(app) event_handler_id: int = id(app)
app.add_middleware( app.add_middleware(
EventHandlerASGIMiddleware, EventHandlerASGIMiddleware,
handlers = [local_handler], # TODO: consider doing this in services to support different configurations handlers=[
middleware_id = event_handler_id) local_handler
], # TODO: consider doing this in services to support different configurations
middleware_id=event_handler_id,
)
# Add CORS # Add CORS
# TODO: use configuration for this # TODO: use configuration for this
@ -48,38 +49,34 @@ socket_io = SocketIO(app)
config = {} config = {}
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event('startup') @app.on_event("startup")
async def startup_event(): async def startup_event():
args = Args() args = Args()
config = args.parse_args() config = args.parse_args()
ApiDependencies.initialize( ApiDependencies.initialize(
args = args, args=args, config=config, event_handler_id=event_handler_id
config = config,
event_handler_id = event_handler_id
) )
# Shut down threads # Shut down threads
@app.on_event('shutdown') @app.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
ApiDependencies.shutdown() ApiDependencies.shutdown()
# Include all routers # Include all routers
# TODO: REMOVE # TODO: REMOVE
# app.include_router( # app.include_router(
# invocation.invocation_router, # invocation.invocation_router,
# prefix = '/api') # prefix = '/api')
app.include_router( app.include_router(sessions.session_router, prefix="/api")
sessions.session_router,
prefix = '/api' app.include_router(images.images_router, prefix="/api")
)
app.include_router(
images.images_router,
prefix = '/api'
)
# Build a custom OpenAPI to include all outputs # Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow? # TODO: can outputs be included on metadata of invocation schemas somehow?
@ -87,10 +84,10 @@ def custom_openapi():
if app.openapi_schema: if app.openapi_schema:
return app.openapi_schema return app.openapi_schema
openapi_schema = get_openapi( openapi_schema = get_openapi(
title = app.title, title=app.title,
description = "An API for invoking AI image operations", description="An API for invoking AI image operations",
version = "1.0.0", version="1.0.0",
routes = app.routes routes=app.routes,
) )
# Add all outputs # Add all outputs
@ -102,12 +99,12 @@ def custom_openapi():
output_types.add(output_type) output_types.add(output_type)
output_schemas = schema(output_types, ref_prefix="#/components/schemas/") output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
for schema_key, output_schema in output_schemas['definitions'].items(): for schema_key, output_schema in output_schemas["definitions"].items():
openapi_schema["components"]["schemas"][schema_key] = output_schema openapi_schema["components"]["schemas"][schema_key] = output_schema
# TODO: note that we assume the schema_key here is the TYPE.__name__ # TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it # This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema['title'] output_type_titles[schema_key] = output_schema["title"]
# Add a reference to the output type to additionalProperties of the invoker schema # Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations: for invoker in all_invocations:
@ -115,47 +112,47 @@ def custom_openapi():
output_type = signature(invoker.invoke).return_annotation output_type = signature(invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__] output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][invoker_name] invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' } outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
if 'additionalProperties' not in invoker_schema: if "additionalProperties" not in invoker_schema:
invoker_schema['additionalProperties'] = {} invoker_schema["additionalProperties"] = {}
invoker_schema["additionalProperties"]["outputs"] = outputs_ref
invoker_schema['additionalProperties']['outputs'] = outputs_ref
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema
app.openapi = custom_openapi app.openapi = custom_openapi
# Override API doc favicons # Override API doc favicons
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static') app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
@app.get("/docs", include_in_schema=False) @app.get("/docs", include_in_schema=False)
def overridden_swagger(): def overridden_swagger():
return get_swagger_ui_html( return get_swagger_ui_html(
openapi_url=app.openapi_url, openapi_url=app.openapi_url,
title=app.title, title=app.title,
swagger_favicon_url="/static/favicon.ico" swagger_favicon_url="/static/favicon.ico",
) )
@app.get("/redoc", include_in_schema=False) @app.get("/redoc", include_in_schema=False)
def overridden_redoc(): def overridden_redoc():
return get_redoc_html( return get_redoc_html(
openapi_url=app.openapi_url, openapi_url=app.openapi_url,
title=app.title, title=app.title,
redoc_favicon_url="/static/favicon.ico" redoc_favicon_url="/static/favicon.ico",
) )
def invoke_api(): def invoke_api():
# Start our own event loop for eventing usage # Start our own event loop for eventing usage
# TODO: determine if there's a better way to do this # TODO: determine if there's a better way to do this
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
config = uvicorn.Config( config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
app = app, # Use access_log to turn off logging
host = "0.0.0.0",
port = 9090,
loop = loop)
# Use access_log to turn off logging
server = uvicorn.Server(config) server = uvicorn.Server(config)
loop.run_until_complete(server.serve()) loop.run_until_complete(server.serve())


@ -1,33 +1,40 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse import argparse
import shlex
import os import os
import shlex
import time import time
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints from typing import (
Any,
Dict,
Iterable,
Literal,
Union,
get_args,
get_origin,
get_type_hints,
)
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import Field from pydantic.fields import Field
from .services.processor import DefaultInvocationProcessor from ..args import Args
from .invocations import *
from .services.graph import EdgeConnection, GraphExecutionState from .invocations.baseinvocation import BaseInvocation
from .services.sqlite import SqliteItemStorage
from .invocations.image import ImageField from .invocations.image import ImageField
from .services.events import EventServiceBase
from .services.generate_initializer import get_generate from .services.generate_initializer import get_generate
from .services.graph import EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .invocations.baseinvocation import BaseInvocation
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
from .services.invoker import Invoker from .services.invoker import Invoker
from .invocations import * from .services.processor import DefaultInvocationProcessor
from ..args import Args from .services.sqlite import SqliteItemStorage
from .services.events import EventServiceBase
class InvocationCommand(BaseModel): class InvocationCommand(BaseModel):
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
class InvalidArgs(Exception): class InvalidArgs(Exception):
@ -35,72 +42,94 @@ class InvalidArgs(Exception):
def get_invocation_parser() -> argparse.ArgumentParser: def get_invocation_parser() -> argparse.ArgumentParser:
# Create invocation parser # Create invocation parser
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
def exit(*args, **kwargs): def exit(*args, **kwargs):
raise InvalidArgs raise InvalidArgs
parser.exit = exit parser.exit = exit
subparsers = parser.add_subparsers(dest='type') subparsers = parser.add_subparsers(dest="type")
invocation_parsers = dict() invocation_parsers = dict()
# Add history parser # Add history parser
history_parser = subparsers.add_parser('history', help="Shows the invocation history") history_parser = subparsers.add_parser(
history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show") "history", help="Shows the invocation history"
)
history_parser.add_argument(
"count",
nargs="?",
default=5,
type=int,
help="The number of history entries to show",
)
# Add default parser # Add default parser
default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name") default_parser = subparsers.add_parser(
default_parser.add_argument('input', type=str, help="The input field") "default", help="Define a default value for all inputs with a specified name"
default_parser.add_argument('value', help="The default value") )
default_parser.add_argument("input", type=str, help="The input field")
default_parser = subparsers.add_parser('reset_default', help="Resets a default value") default_parser.add_argument("value", help="The default value")
default_parser.add_argument('input', type=str, help="The input field")
default_parser = subparsers.add_parser(
"reset_default", help="Resets a default value"
)
default_parser.add_argument("input", type=str, help="The input field")
# Create subparsers for each invocation # Create subparsers for each invocation
invocations = BaseInvocation.get_all_subclasses() invocations = BaseInvocation.get_all_subclasses()
for invocation in invocations: for invocation in invocations:
hints = get_type_hints(invocation) hints = get_type_hints(invocation)
cmd_name = get_args(hints['type'])[0] cmd_name = get_args(hints["type"])[0]
command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__) command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
invocation_parsers[cmd_name] = command_parser invocation_parsers[cmd_name] = command_parser
# Add linking capability # Add linking capability
command_parser.add_argument('--link', '-l', action='append', nargs=3, command_parser.add_argument(
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)") "--link",
"-l",
action="append",
nargs=3,
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
)
command_parser.add_argument('--link_node', '-ln', action='append', command_parser.add_argument(
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)") "--link_node",
"-ln",
action="append",
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
)
# Convert all fields to arguments # Convert all fields to arguments
fields = invocation.__fields__ fields = invocation.__fields__
for name, field in fields.items(): for name, field in fields.items():
if name in ['id', 'type']: if name in ["id", "type"]:
continue continue
if get_origin(field.type_) == Literal: if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_) allowed_values = get_args(field.type_)
allowed_types = set() allowed_types = set()
for val in allowed_values: for val in allowed_values:
allowed_types.add(type(val)) allowed_types.add(type(val))
allowed_types_list = list(allowed_types) allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument( command_parser.add_argument(
f"--{name}", f"--{name}",
dest=name, dest=name,
type=field_type, type=field_type,
default=field.default, default=field.default,
choices = allowed_values, choices=allowed_values,
help=field.field_info.description help=field.field_info.description,
) )
else: else:
command_parser.add_argument( command_parser.add_argument(
f"--{name}", f"--{name}",
dest=name, dest=name,
type=field.type_, type=field.type_,
default=field.default, default=field.default,
help=field.field_info.description help=field.field_info.description,
) )
return parser return parser
@ -110,8 +139,8 @@ def get_invocation_command(invocation) -> str:
fields = invocation.__fields__.items() fields = invocation.__fields__.items()
type_hints = get_type_hints(type(invocation)) type_hints = get_type_hints(type(invocation))
command = [invocation.type] command = [invocation.type]
for name,field in fields: for name, field in fields:
if name in ['id', 'type']: if name in ["id", "type"]:
continue continue
# TODO: add links # TODO: add links
@ -127,17 +156,25 @@ def get_invocation_command(invocation) -> str:
if type_hint is str or str in get_args(type_hint): if type_hint is str or str in get_args(type_hint):
command.append(f'--{name} "{field_value}"') command.append(f'--{name} "{field_value}"')
else: else:
command.append(f'--{name} {field_value}') command.append(f"--{name} {field_value}")
return ' '.join(command) return " ".join(command)
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]: def get_graph_execution_history(
graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution""" """Gets the history of fully-executed invocations for a graph execution"""
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes) return (
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]: def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]:
"""Generates all possible edges between two invocations""" """Generates all possible edges between two invocations"""
atype = type(a) atype = type(a)
btype = type(b) btype = type(b)
@ -148,12 +185,18 @@ def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[
bfields = get_type_hints(btype) bfields = get_type_hints(btype)
matching_fields = set(afields.keys()).intersection(bfields.keys()) matching_fields = set(afields.keys()).intersection(bfields.keys())
# Remove invalid fields # Remove invalid fields
invalid_fields = set(['type', 'id']) invalid_fields = set(["type", "id"])
matching_fields = matching_fields.difference(invalid_fields) matching_fields = matching_fields.difference(invalid_fields)
edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields] edges = [
(
EdgeConnection(node_id=a.id, field=field),
EdgeConnection(node_id=b.id, field=field),
)
for field in matching_fields
]
return edges return edges
@ -165,27 +208,31 @@ def invoke_cli():
# NOTE: load model on first use, uncomment to load at startup # NOTE: load model on first use, uncomment to load at startup
# TODO: Make this a config option? # TODO: Make this a config option?
#generate.load_model() # generate.load_model()
events = EventServiceBase() events = EventServiceBase()
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db') db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices( services = InvocationServices(
generate = generate, generate=generate,
events = events, events=events,
images = DiskImageStorage(output_folder), images=DiskImageStorage(output_folder),
queue = MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), graph_execution_manager=SqliteItemStorage[GraphExecutionState](
processor = DefaultInvocationProcessor() filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
) )
invoker = Invoker(services) invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state() session: GraphExecutionState = invoker.create_execution_state()
parser = get_invocation_parser() parser = get_invocation_parser()
# Uncomment to print out previous sessions at startup # Uncomment to print out previous sessions at startup
@ -201,10 +248,10 @@ def invoke_cli():
# Ctrl-c exits # Ctrl-c exits
break break
if cmd_input in ['exit','q']: if cmd_input in ["exit", "q"]:
break; break
if cmd_input in ['--help','help','h','?']: if cmd_input in ["--help", "help", "h", "?"]:
parser.print_help() parser.print_help()
continue continue
@ -214,65 +261,82 @@ def invoke_cli():
history = list(get_graph_execution_history(session)) history = list(get_graph_execution_history(session))
# Split the command for piping # Split the command for piping
cmds = cmd_input.split('|') cmds = cmd_input.split("|")
start_id = len(history) start_id = len(history)
current_id = start_id current_id = start_id
new_invocations = list() new_invocations = list()
for cmd in cmds: for cmd in cmds:
if cmd is None or cmd.strip() == '': if cmd is None or cmd.strip() == "":
raise InvalidArgs('Empty command') raise InvalidArgs("Empty command")
# Parse args to create invocation # Parse args to create invocation
args = vars(parser.parse_args(shlex.split(cmd.strip()))) args = vars(parser.parse_args(shlex.split(cmd.strip())))
# Check for special commands # Check for special commands
# TODO: These might be better as Pydantic models, similar to the invocations # TODO: These might be better as Pydantic models, similar to the invocations
if args['type'] == 'history': if args["type"] == "history":
history_count = args['count'] or 5 history_count = args["count"] or 5
for i in range(min(history_count, len(history))): for i in range(min(history_count, len(history))):
entry_id = history[-1 - i] entry_id = history[-1 - i]
entry = session.graph.get_node(entry_id) entry = session.graph.get_node(entry_id)
print(f'{entry_id}: {get_invocation_command(entry.invocation)}') print(f"{entry_id}: {get_invocation_command(entry.invocation)}")
continue continue
if args['type'] == 'reset_default': if args["type"] == "reset_default":
if args['input'] in defaults: if args["input"] in defaults:
del defaults[args['input']] del defaults[args["input"]]
continue continue
if args['type'] == 'default': if args["type"] == "default":
field = args['input'] field = args["input"]
field_value = args['value'] field_value = args["value"]
defaults[field] = field_value defaults[field] = field_value
continue continue
# Override defaults # Override defaults
for field_name,field_default in defaults.items(): for field_name, field_default in defaults.items():
if field_name in args: if field_name in args:
args[field_name] = field_default args[field_name] = field_default
# Parse invocation # Parse invocation
args['id'] = current_id args["id"] = current_id
command = InvocationCommand(invocation = args) command = InvocationCommand(invocation=args)
# Pipe previous command output (if there was a previous command) # Pipe previous command output (if there was a previous command)
edges = [] edges = []
if len(history) > 0 or current_id != start_id: if len(history) > 0 or current_id != start_id:
from_id = history[0] if current_id == start_id else str(current_id - 1) from_id = (
from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id) history[0] if current_id == start_id else str(current_id - 1)
matching_edges = generate_matching_edges(from_node, command.invocation) )
from_node = (
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
if current_id != start_id
else session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.invocation
)
edges.extend(matching_edges) edges.extend(matching_edges)
# Parse provided links # Parse provided links
if 'link_node' in args and args['link_node']: if "link_node" in args and args["link_node"]:
for link in args['link_node']: for link in args["link_node"]:
link_node = session.graph.get_node(link) link_node = session.graph.get_node(link)
matching_edges = generate_matching_edges(link_node, command.invocation) matching_edges = generate_matching_edges(
link_node, command.invocation
)
edges.extend(matching_edges) edges.extend(matching_edges)
if 'link' in args and args['link']: if "link" in args and args["link"]:
for link in args['link']: for link in args["link"]:
edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2]))) edges.append(
(
EdgeConnection(node_id=link[1], field=link[0]),
EdgeConnection(
node_id=command.invocation.id, field=link[2]
),
)
)
new_invocations.append((command.invocation, edges)) new_invocations.append((command.invocation, edges))
@ -286,17 +350,19 @@ def invoke_cli():
session.add_edge(edge) session.add_edge(edge)
# Execute all available invocations # Execute all available invocations
invoker.invoke(session, invoke_all = True) invoker.invoke(session, invoke_all=True)
while not session.is_complete(): while not session.is_complete():
# Wait some time # Wait some time
session = invoker.services.graph_execution_manager.get(session.id) session = invoker.services.graph_execution_manager.get(session.id)
time.sleep(0.1) time.sleep(0.1)
# Print any errors # Print any errors
if session.has_error(): if session.has_error():
for n in session.errors: for n in session.errors:
print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}') print(
f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}"
)
# Start a new session # Start a new session
print("Creating a new session") print("Creating a new session")
session = invoker.create_execution_state() session = invoker.create_execution_state()
@ -307,7 +373,7 @@ def invoke_cli():
except SystemExit: except SystemExit:
continue continue
invoker.stop() invoker.stop()


@ -4,5 +4,9 @@ __all__ = []
dirname = os.path.dirname(os.path.abspath(__file__)) dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname): for f in os.listdir(dirname):
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py": if (
f != "__init__.py"
and os.path.isfile("%s/%s" % (dirname, f))
and f[-3:] == ".py"
):
__all__.append(f[:-3]) __all__.append(f[:-3])


@ -3,7 +3,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import get_args, get_type_hints from typing import get_args, get_type_hints
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices


@ -1,30 +1,37 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal from typing import Literal
import numpy
from pydantic import Field
from PIL import Image, ImageOps
import cv2 as cv import cv2 as cv
from .image import ImageField, ImageOutput import numpy
from .baseinvocation import BaseInvocation, InvocationContext from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class CvInpaintInvocation(BaseInvocation): class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv.""" """Simple inpaint using opencv."""
type: Literal['cv_inpaint'] = 'cv_inpaint'
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to inpaint") image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting") mask: ImageField = Field(
default=None, description="The mask to use when inpainting"
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
mask = context.services.images.get(self.mask.image_type, self.mask.image_name) mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
# Convert to cv image/mask # Convert to cv image/mask
# TODO: consider making these utility functions # TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR) cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask)) cv_mask = numpy.array(ImageOps.invert(mask))
# Inpaint # Inpaint
@ -35,8 +42,10 @@ class CvInpaintInvocation(BaseInvocation):
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_type = ImageType.INTERMEDIATE image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_inpainted) context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )
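
The PIL/OpenCV round trip above is the heart of CvInpaintInvocation; the actual cv.inpaint call sits in the elided part of the hunk. A standalone sketch of the same conversion pattern, where the inpaint radius and algorithm flag are assumptions rather than values taken from the commit:

    import cv2 as cv
    import numpy
    from PIL import Image, ImageOps


    def cv_inpaint(image: Image.Image, mask: Image.Image) -> Image.Image:
        # PIL works in RGB, OpenCV in BGR; the mask is inverted so that white
        # marks the area to be filled, then flattened to a single channel.
        cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
        cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))

        # Radius 3 and INPAINT_TELEA are assumed here; the exact call used by
        # the node is not visible in this hunk.
        cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)

        # Back to a PIL image in RGB order.
        return Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))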

View File

@@ -0,0 +1,211 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
SAMPLER_NAME_VALUES = Literal[
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
]
# Text to image
class TextToImageInvocation(BaseInvocation):
"""Generates an image using text2img."""
type: Literal["txt2img"] = "txt2img"
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0
) -> None:
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
step,
float(step) / float(self.steps),
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt=self.prompt,
step_callback=step_callback,
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class ImageToImageInvocation(TextToImageInvocation):
"""Generates an image using img2img."""
type: Literal["img2img"] = "img2img"
# Inputs
image: Union[ImageField, None] = Field(description="The input image")
strength: float = Field(
default=0.75, gt=0, le=1, description="The strength of the original image"
)
fit: bool = Field(
default=True,
description="Whether or not the result should be fit to the aspect ratio of the input image",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
if self.image is None
else context.services.images.get(
self.image.image_type, self.image.image_name
)
)
mask = None
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
result_image = results[0][0]
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
type: Literal["inpaint"] = "inpaint"
# Inputs
mask: Union[ImageField, None] = Field(description="The mask")
inpaint_replace: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="The amount by which to replace masked areas with latent noise",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
if self.image is None
else context.services.images.get(
self.image.image_type, self.image.image_name
)
)
mask = (
None
if self.mask is None
else context.services.images.get(self.mask.image_type, self.mask.image_name)
)
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
result_image = results[0][0]
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
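
All three of these nodes are thin pydantic wrappers around Generate.prompt2image: every generation parameter is a declared field, and the whole field set minus the prompt (and image/mask, where present) is forwarded as keyword arguments via self.dict(exclude=...). A self-contained sketch of that forwarding mechanism with a stand-in model and a fake prompt2image, both invented for the example:

    from typing import Optional

    from pydantic import BaseModel, Field


    class Txt2ImgParams(BaseModel):
        # A subset of TextToImageInvocation's fields, for illustration only.
        prompt: Optional[str] = Field(description="The prompt to generate an image from")
        steps: int = Field(default=10, gt=0)
        width: int = Field(default=512, multiple_of=64, gt=0)
        height: int = Field(default=512, multiple_of=64, gt=0)
        cfg_scale: float = Field(default=7.5, gt=0)
        sampler_name: str = "k_lms"


    def fake_prompt2image(prompt, **kwargs):
        # Stands in for context.services.generate.prompt2image(...).
        return prompt, kwargs


    params = Txt2ImgParams(prompt="a watercolor fox", steps=20)
    # Everything except the prompt rides along as keyword arguments.
    print(fake_prompt2image(prompt=params.prompt, **params.dict(exclude={"prompt"})))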

View File

@@ -2,30 +2,37 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Literal, Optional from typing import Literal, Optional
import numpy import numpy
from pydantic import Field, BaseModel from PIL import Image, ImageFilter, ImageOps
from PIL import Image, ImageOps, ImageFilter from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class ImageField(BaseModel): class ImageField(BaseModel):
"""An image field used for passing image objects between invocations""" """An image field used for passing image objects between invocations"""
image_type: str = Field(default=ImageType.RESULT, description="The type of the image")
image_type: str = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image") image_name: Optional[str] = Field(default=None, description="The name of the image")
class ImageOutput(BaseInvocationOutput): class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image""" """Base class for invocations that output an image"""
type: Literal['image'] = 'image'
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image") image: ImageField = Field(default=None, description="The output image")
class MaskOutput(BaseInvocationOutput): class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask""" """Base class for invocations that output a mask"""
type: Literal['mask'] = 'mask'
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask") mask: ImageField = Field(default=None, description="The output mask")
@@ -33,7 +40,8 @@ class MaskOutput(BaseInvocationOutput):
# TODO: this isn't really necessary anymore # TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation): class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output.""" """Load an image from a filename and provide it as output."""
type: Literal['load_image'] = 'load_image'
type: Literal["load_image"] = "load_image"
# Inputs # Inputs
image_type: ImageType = Field(description="The type of the image") image_type: ImageType = Field(description="The type of the image")
@@ -41,69 +49,100 @@ class LoadImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput( return ImageOutput(
image = ImageField(image_type = self.image_type, image_name = self.image_name) image=ImageField(image_type=self.image_type, image_name=self.image_name)
) )
class ShowImageInvocation(BaseInvocation): class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline.""" """Displays a provided image, and passes it forward in the pipeline."""
type: Literal['show_image'] = 'show_image'
type: Literal["show_image"] = "show_image"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to show") image: ImageField = Field(default=None, description="The image to show")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
if image: if image:
image.show() image.show()
# TODO: how to handle failure? # TODO: how to handle failure?
return ImageOutput( return ImageOutput(
image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name) image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
) )
class CropImageInvocation(BaseInvocation): class CropImageInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image.""" """Crops an image to a specified box. The box can be outside of the image."""
type: Literal['crop'] = 'crop'
type: Literal["crop"] = "crop"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to crop") image: ImageField = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle") x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle") y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle") width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle") height: int = Field(
default=512, gt=0, description="The height of the crop rectangle"
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0)) image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
)
image_crop.paste(image, (-self.x, -self.y)) image_crop.paste(image, (-self.x, -self.y))
image_type = ImageType.INTERMEDIATE image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop) context.services.images.save(image_type, image_name, image_crop)
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )
class PasteImageInvocation(BaseInvocation): class PasteImageInvocation(BaseInvocation):
"""Pastes an image into another image.""" """Pastes an image into another image."""
type: Literal['paste'] = 'paste'
type: Literal["paste"] = "paste"
# Inputs # Inputs
base_image: ImageField = Field(default=None, description="The base image") base_image: ImageField = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste") image: ImageField = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") mask: Optional[ImageField] = Field(
x: int = Field(default=0, description="The left x coordinate at which to paste the image") default=None, description="The mask to use when pasting"
y: int = Field(default=0, description="The top y coordinate at which to paste the image") )
x: int = Field(
default=0, description="The left x coordinate at which to paste the image"
)
y: int = Field(
default=0, description="The top y coordinate at which to paste the image"
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name) base_image = context.services.images.get(
image = context.services.images.get(self.image.image_type, self.image.image_name) self.base_image.image_type, self.base_image.image_name
mask = None if self.mask is None else ImageOps.invert(services.images.get(self.mask.image_type, self.mask.image_name)) )
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
mask = (
None
if self.mask is None
else ImageOps.invert(
services.images.get(self.mask.image_type, self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it? # TODO: probably shouldn't invert mask here... should user be required to do it?
min_x = min(0, self.x) min_x = min(0, self.x)
@@ -111,67 +150,88 @@ class PasteImageInvocation(BaseInvocation):
max_x = max(base_image.width, image.width + self.x) max_x = max(base_image.width, image.width + self.x)
max_y = max(base_image.height, image.height + self.y) max_y = max(base_image.height, image.height + self.y)
new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0)) new_image = Image.new(
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
)
new_image.paste(base_image, (abs(min_x), abs(min_y))) new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask) new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
image_type = ImageType.RESULT image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image) context.services.images.save(image_type, image_name, new_image)
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )
class MaskFromAlphaInvocation(BaseInvocation): class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask.""" """Extracts the alpha channel of an image as a mask."""
type: Literal['tomask'] = 'tomask'
type: Literal["tomask"] = "tomask"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to create the mask from") image: ImageField = Field(
default=None, description="The image to create the mask from"
)
invert: bool = Field(default=False, description="Whether or not to invert the mask") invert: bool = Field(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_mask = image.split()[-1] image_mask = image.split()[-1]
if self.invert: if self.invert:
image_mask = ImageOps.invert(image_mask) image_mask = ImageOps.invert(image_mask)
image_type = ImageType.INTERMEDIATE image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.services.images.save(image_type, image_name, image_mask) context.graph_execution_state_id, self.id
return MaskOutput(
mask = ImageField(image_type = image_type, image_name = image_name)
) )
context.services.images.save(image_type, image_name, image_mask)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation): class BlurInvocation(BaseInvocation):
"""Blurs an image""" """Blurs an image"""
type: Literal['blur'] = 'blur'
type: Literal["blur"] = "blur"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to blur") image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius") radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur") blur_type: Literal["gaussian", "box"] = Field(
default="gaussian", description="The type of blur"
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius) blur = (
ImageFilter.GaussianBlur(self.radius)
if self.blur_type == "gaussian"
else ImageFilter.BoxBlur(self.radius)
)
blur_image = image.filter(blur) blur_image = image.filter(blur)
image_type = ImageType.INTERMEDIATE image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image) context.services.images.save(image_type, image_name, blur_image)
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )
class LerpInvocation(BaseInvocation): class LerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image""" """Linear interpolation of all pixels of an image"""
type: Literal['lerp'] = 'lerp'
type: Literal["lerp"] = "lerp"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to lerp") image: ImageField = Field(default=None, description="The image to lerp")
@@ -179,7 +239,9 @@ class LerpInvocation(BaseInvocation):
max: int = Field(default=255, ge=0, le=255, description="The maximum output value") max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max image_arr = image_arr * (self.max - self.min) + self.max
@@ -187,33 +249,46 @@ class LerpInvocation(BaseInvocation):
lerp_image = Image.fromarray(numpy.uint8(image_arr)) lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image) context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )
class InverseLerpInvocation(BaseInvocation): class InverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image""" """Inverse linear interpolation of all pixels of an image"""
type: Literal['ilerp'] = 'ilerp' #fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to lerp") image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value") min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value") max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 image_arr = (
numpy.minimum(
numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
)
* 255
)
ilerp_image = Image.fromarray(numpy.uint8(image_arr)) ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image) context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )
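
InverseLerpInvocation normalizes pixel values from the [min, max] input range back to 0..255, clamping anything outside the range. The same arithmetic as a standalone NumPy snippet (the sample values are made up):

    import numpy

    # Sample 8-bit pixel values and the input range to normalize from.
    image_arr = numpy.array([0, 64, 128, 192, 255], dtype=numpy.float32)
    lo, hi = 64, 192

    # Shift, scale, clamp to [0, 1], then rescale to 0..255 -- the same
    # expression InverseLerpInvocation.invoke() applies to the whole image.
    ilerp = numpy.minimum(numpy.maximum(image_arr - lo, 0) / float(hi - lo), 1) * 255
    print(numpy.uint8(ilerp))  # [  0   0 127 255 255]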

View File

@@ -1,9 +1,14 @@
from typing import Literal from typing import Literal
from pydantic.fields import Field from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput from .baseinvocation import BaseInvocationOutput
class PromptOutput(BaseInvocationOutput): class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt""" """Base class for invocations that output a prompt"""
type: Literal['prompt'] = 'prompt' #fmt: off
type: Literal["prompt"] = "prompt"
prompt: str = Field(default=None, description="The output prompt") prompt: str = Field(default=None, description="The output prompt")
#fmt: on

View File

@@ -1,36 +1,43 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Literal, Union from typing import Literal, Union
from pydantic import Field from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class RestoreFaceInvocation(BaseInvocation): class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image.""" """Restores faces in an image."""
type: Literal['restore_face'] = 'restore_face' #fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs # Inputs
image: Union[ImageField,None] = Field(description="The input image") image: Union[ImageField, None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration") strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
results = context.services.generate.upscale_and_reconstruct( results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]], image_list=[[image, 0]],
upscale = None, upscale=None,
strength = self.strength, # GFPGAN strength strength=self.strength, # GFPGAN strength
save_original = False, save_original=False,
image_callback = None, image_callback=None,
) )
# Results are image and seed, unwrap for now # Results are image and seed, unwrap for now
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_type = ImageType.RESULT image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0]) context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )

View File

@@ -2,37 +2,45 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Literal, Union from typing import Literal, Union
from pydantic import Field from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class UpscaleInvocation(BaseInvocation): class UpscaleInvocation(BaseInvocation):
"""Upscales an image.""" """Upscales an image."""
type: Literal['upscale'] = 'upscale' #fmt: off
type: Literal["upscale"] = "upscale"
# Inputs # Inputs
image: Union[ImageField,None] = Field(description="The input image", default=None) image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength") strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description = "The upscale level") level: Literal[2, 4] = Field(default=2, description="The upscale level")
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name) image = context.services.images.get(
self.image.image_type, self.image.image_name
)
results = context.services.generate.upscale_and_reconstruct( results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]], image_list=[[image, 0]],
upscale = (self.level, self.strength), upscale=(self.level, self.strength),
strength = 0.0, # GFPGAN strength strength=0.0, # GFPGAN strength
save_original = False, save_original=False,
image_callback = None, image_callback=None,
) )
# Results are image and seed, unwrap for now # Results are image and seed, unwrap for now
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_type = ImageType.RESULT image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0]) context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput( return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name) image=ImageField(image_type=image_type, image_name=image_name)
) )

View File

@@ -0,0 +1,83 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict
class EventServiceBase:
session_event: str = "session_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
self.dispatch(
event_name=EventServiceBase.session_event,
payload=dict(event=event_name, data=payload),
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_id: str,
step: int,
percent: float,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
step=step,
percent=percent,
),
)
def emit_invocation_complete(
self, graph_execution_state_id: str, invocation_id: str, result: Dict
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
result=result,
),
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_id: str, error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
),
)
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
event_name="graph_execution_state_complete",
payload=dict(graph_execution_state_id=graph_execution_state_id),
)
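
EventServiceBase.dispatch is a no-op stand-in; a concrete service overrides it and receives every emit_* call as a session_event whose payload has the shape {"event": <name>, "data": {...}}. A minimal sketch of a subclass that just prints what it receives (the subclass and the import path are assumptions, not part of the commit):

    from typing import Any

    # Import path assumed from the new invokeai.app layout.
    from invokeai.app.services.events import EventServiceBase


    class PrintEventService(EventServiceBase):
        def dispatch(self, event_name: str, payload: Any) -> None:
            # Every emit_* helper funnels through dispatch().
            print(event_name, payload)


    events = PrintEventService()
    events.emit_generator_progress(
        graph_execution_state_id="exec-1",
        invocation_id="node-1",
        step=3,
        percent=0.3,
    )
    # -> session_event {'event': 'generator_progress', 'data': {...}}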

View File

@@ -1,42 +1,47 @@
from argparse import Namespace
import os import os
import sys import sys
import traceback import traceback
from argparse import Namespace
from ...model_manager import ModelManager import invokeai.version
from invokeai.backend import Generate, ModelManager
from ...globals import Globals from ...globals import Globals
from ....generate import Generate
import ldm.invoke
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated # TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_generate(args, config) -> Generate: def get_generate(args, config) -> Generate:
if not args.conf: if not args.conf:
config_file = os.path.join(Globals.root,'configs','models.yaml') config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file): if not os.path.exists(config_file):
report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found.")) report_model_error(
args, FileNotFoundError(f"The file {config_file} could not be found.")
)
print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}') print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"') print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported # when the frozen CLIP tokenizer is imported
import transformers # type: ignore import transformers # type: ignore
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
import diffusers import diffusers
diffusers.logging.set_verbosity_error() diffusers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules # Loading Face Restoration and ESRGAN Modules
gfpgan,codeformer,esrgan = load_face_restoration(args) gfpgan, codeformer, esrgan = load_face_restoration(args)
# normalize the config directory relative to root # normalize the config directory relative to root
if not os.path.isabs(args.conf): if not os.path.isabs(args.conf):
args.conf = os.path.normpath(os.path.join(Globals.root,args.conf)) args.conf = os.path.normpath(os.path.join(Globals.root, args.conf))
if args.embeddings: if args.embeddings:
if not os.path.isabs(args.embedding_path): if not os.path.isabs(args.embedding_path):
embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path)) embedding_path = os.path.normpath(
os.path.join(Globals.root, args.embedding_path)
)
else: else:
embedding_path = args.embedding_path embedding_path = args.embedding_path
else: else:
@@ -49,35 +54,35 @@ def get_generate(args, config) -> Generate:
if args.infile: if args.infile:
try: try:
if os.path.isfile(args.infile): if os.path.isfile(args.infile):
infile = open(args.infile, 'r', encoding='utf-8') infile = open(args.infile, "r", encoding="utf-8")
elif args.infile == '-': # stdin elif args.infile == "-": # stdin
infile = sys.stdin infile = sys.stdin
else: else:
raise FileNotFoundError(f'{args.infile} not found.') raise FileNotFoundError(f"{args.infile} not found.")
except (FileNotFoundError, IOError) as e: except (FileNotFoundError, IOError) as e:
print(f'{e}. Aborting.') print(f"{e}. Aborting.")
sys.exit(-1) sys.exit(-1)
# creating a Generate object: # creating a Generate object:
try: try:
gen = Generate( gen = Generate(
conf = args.conf, conf=args.conf,
model = args.model, model=args.model,
sampler_name = args.sampler_name, sampler_name=args.sampler_name,
embedding_path = embedding_path, embedding_path=embedding_path,
full_precision = args.full_precision, full_precision=args.full_precision,
precision = args.precision, precision=args.precision,
gfpgan = gfpgan, gfpgan=gfpgan,
codeformer = codeformer, codeformer=codeformer,
esrgan = esrgan, esrgan=esrgan,
free_gpu_mem = args.free_gpu_mem, free_gpu_mem=args.free_gpu_mem,
safety_checker = args.safety_checker, safety_checker=args.safety_checker,
max_loaded_models = args.max_loaded_models, max_loaded_models=args.max_loaded_models,
) )
except (FileNotFoundError, TypeError, AssertionError) as e: except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt,e) report_model_error(opt, e)
except (IOError, KeyError) as e: except (IOError, KeyError) as e:
print(f'{e}. Aborting.') print(f"{e}. Aborting.")
sys.exit(-1) sys.exit(-1)
if args.seamless: if args.seamless:
@@ -98,7 +103,7 @@ def get_generate(args, config) -> Generate:
conf_path=args.conf, conf_path=args.conf,
weights_directory=path, weights_directory=path,
) )
return gen return gen
@@ -106,51 +111,61 @@ def load_face_restoration(opt):
try: try:
gfpgan, codeformer, esrgan = None, None, None gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan: if opt.restore or opt.esrgan:
from ldm.invoke.restoration import Restoration from invokeai.backend.restoration import Restoration
restoration = Restoration() restoration = Restoration()
if opt.restore: if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path) gfpgan, codeformer = restoration.load_face_restore_models(
opt.gfpgan_model_path
)
else: else:
print('>> Face restoration disabled') print(">> Face restoration disabled")
if opt.esrgan: if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile) esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else: else:
print('>> Upscaling disabled') print(">> Upscaling disabled")
else: else:
print('>> Face restoration and upscaling disabled') print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError): except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules') print(">> You may need to install the ESRGAN and/or GFPGAN modules")
return gfpgan,codeformer,esrgan return gfpgan, codeformer, esrgan
def report_model_error(opt:Namespace, e:Exception): def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"') print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.') print(
yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE') "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all: if yes_to_all:
print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE') print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else: else:
response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ') response = input(
if response.startswith(('n', 'N')): "Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
)
if response.startswith(("n", "N")):
return return
print('invokeai-configure is launching....\n') print("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI # Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed # only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else [] config = ["--config", opt.conf] if opt.conf is not None else []
previous_args = sys.argv previous_args = sys.argv
sys.argv = [ 'invokeai-configure' ] sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir) sys.argv.extend(root_dir)
sys.argv.extend(config) sys.argv.extend(config)
if yes_to_all is not None: if yes_to_all is not None:
for arg in yes_to_all.split(): for arg in yes_to_all.split():
sys.argv.append(arg) sys.argv.append(arg)
from ldm.invoke.config import invokeai_configure from invokeai.frontend.install import invokeai_configure
invokeai_configure.main()
invokeai_configure()
# TODO: Figure out how to restart # TODO: Figure out how to restart
# print('** InvokeAI will now restart') # print('** InvokeAI will now restart')
# sys.argv = previous_args # sys.argv = previous_args
@@ -161,17 +176,20 @@ def report_model_error(opt:Namespace, e:Exception):
# Temporary initializer for Generate until we migrate off of it # Temporary initializer for Generate until we migrate off of it
def old_get_generate(args, config) -> Generate: def old_get_generate(args, config) -> Generate:
# TODO: Remove the need for globals # TODO: Remove the need for globals
from ldm.invoke.globals import Globals from invokeai.backend.globals import Globals
# alert - setting globals here # alert - setting globals here
Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.')) Globals.root = os.path.expanduser(
args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
)
Globals.try_patchmatch = args.patchmatch Globals.try_patchmatch = args.patchmatch
print(f'>> InvokeAI runtime directory is "{Globals.root}"') print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported # when the frozen CLIP tokenizer is imported
import transformers import transformers
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules # Loading Face Restoration and ESRGAN Modules
@@ -179,53 +197,57 @@ def old_get_generate(args, config) -> Generate:
try: try:
if config.restore or config.esrgan: if config.restore or config.esrgan:
from ldm.invoke.restoration import Restoration from ldm.invoke.restoration import Restoration
restoration = Restoration() restoration = Restoration()
if config.restore: if config.restore:
gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path) gfpgan, codeformer = restoration.load_face_restore_models(
config.gfpgan_model_path
)
else: else:
print('>> Face restoration disabled') print(">> Face restoration disabled")
if config.esrgan: if config.esrgan:
esrgan = restoration.load_esrgan(config.esrgan_bg_tile) esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
else: else:
print('>> Upscaling disabled') print(">> Upscaling disabled")
else: else:
print('>> Face restoration and upscaling disabled') print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError): except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules') print(">> You may need to install the ESRGAN and/or GFPGAN modules")
# normalize the config directory relative to root # normalize the config directory relative to root
if not os.path.isabs(config.conf): if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root,config.conf)) config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings: if config.embeddings:
if not os.path.isabs(config.embedding_path): if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(os.path.join(Globals.root,config.embedding_path)) embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else: else:
embedding_path = None embedding_path = None
# TODO: lazy-initialize this by wrapping it # TODO: lazy-initialize this by wrapping it
try: try:
generate = Generate( generate = Generate(
conf = config.conf, conf=config.conf,
model = config.model, model=config.model,
sampler_name = config.sampler_name, sampler_name=config.sampler_name,
embedding_path = embedding_path, embedding_path=embedding_path,
full_precision = config.full_precision, full_precision=config.full_precision,
precision = config.precision, precision=config.precision,
gfpgan = gfpgan, gfpgan=gfpgan,
codeformer = codeformer, codeformer=codeformer,
esrgan = esrgan, esrgan=esrgan,
free_gpu_mem = config.free_gpu_mem, free_gpu_mem=config.free_gpu_mem,
safety_checker = config.safety_checker, safety_checker=config.safety_checker,
max_loaded_models = config.max_loaded_models, max_loaded_models=config.max_loaded_models,
) )
except (FileNotFoundError, TypeError, AssertionError): except (FileNotFoundError, TypeError, AssertionError):
#emergency_model_reconfigure() # TODO? # emergency_model_reconfigure() # TODO?
sys.exit(-1) sys.exit(-1)
except (IOError, KeyError) as e: except (IOError, KeyError) as e:
print(f'{e}. Aborting.') print(f"{e}. Aborting.")
sys.exit(-1) sys.exit(-1)
generate.free_gpu_mem = config.free_gpu_mem generate.free_gpu_mem = config.free_gpu_mem

View File

@@ -1,20 +1,22 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from enum import Enum
import datetime import datetime
import os import os
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict from typing import Dict
from PIL.Image import Image from PIL.Image import Image
from ...pngwriter import PngWriter
from invokeai.backend.image_util import PngWriter
class ImageType(str, Enum): class ImageType(str, Enum):
RESULT = 'results' RESULT = "results"
INTERMEDIATE = 'intermediates' INTERMEDIATE = "intermediates"
UPLOAD = 'uploads' UPLOAD = "uploads"
class ImageStorageBase(ABC): class ImageStorageBase(ABC):
@@ -38,14 +40,15 @@ class ImageStorageBase(ABC):
pass pass
def create_name(self, context_id: str, node_id: str) -> str: def create_name(self, context_id: str, node_id: str) -> str:
return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png' return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
class DiskImageStorage(ImageStorageBase): class DiskImageStorage(ImageStorageBase):
"""Stores images on disk""" """Stores images on disk"""
__output_folder: str __output_folder: str
__pngWriter: PngWriter __pngWriter: PngWriter
__cache_ids: Queue # TODO: this is an incredibly naive cache __cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image] __cache: Dict[str, Image]
__max_cache_size: int __max_cache_size: int
@@ -54,13 +57,15 @@ class DiskImageStorage(ImageStorageBase):
self.__pngWriter = PngWriter(output_folder) self.__pngWriter = PngWriter(output_folder)
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True) Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath? # TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType: for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True) Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
def get(self, image_type: ImageType, image_name: str) -> Image: def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
@@ -79,7 +84,9 @@ class DiskImageStorage(ImageStorageBase):
def save(self, image_type: ImageType, image_name: str, image: Image) -> None: def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name) image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
) # TODO: just pass full path to png writer
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image) self.__set_cache(image_path, image)
@@ -88,7 +95,7 @@ class DiskImageStorage(ImageStorageBase):
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
if os.path.exists(image_path): if os.path.exists(image_path):
os.remove(image_path) os.remove(image_path)
if image_path in self.__cache: if image_path in self.__cache:
del self.__cache[image_path] del self.__cache[image_path]
@@ -98,7 +105,9 @@ class DiskImageStorage(ImageStorageBase):
def __set_cache(self, image_name: str, image: Image): def __set_cache(self, image_name: str, image: Image):
if not image_name in self.__cache: if not image_name in self.__cache:
self.__cache[image_name] = image self.__cache[image_name] = image
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size: if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get() cache_id = self.__cache_ids.get()
del self.__cache[cache_id] del self.__cache[cache_id]
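
A rough usage sketch of the disk-backed store. The constructor argument and import path are assumptions (the __init__ signature is partly elided above); the rest follows the methods shown in the hunks:

    from PIL import Image

    # Import path assumed from the new invokeai.app layout.
    from invokeai.app.services.image_storage import DiskImageStorage, ImageType

    storage = DiskImageStorage("/tmp/invokeai-outputs")  # output folder (assumed signature)

    # Names combine the session id, node id and a UTC timestamp into a .png filename.
    name = storage.create_name("exec-1", "node-1")

    img = Image.new("RGB", (64, 64), color="black")
    storage.save(ImageType.RESULT, name, img)            # written under results/ via PngWriter
    round_tripped = storage.get(ImageType.RESULT, name)  # served from the naive in-memory cache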

View File

@@ -6,17 +6,19 @@ from queue import Queue
# TODO: make this serializable # TODO: make this serializable
class InvocationQueueItem: class InvocationQueueItem:
#session_id: str # session_id: str
graph_execution_state_id: str graph_execution_state_id: str
invocation_id: str invocation_id: str
invoke_all: bool invoke_all: bool
def __init__(self, def __init__(
#session_id: str, self,
# session_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
invocation_id: str, invocation_id: str,
invoke_all: bool = False): invoke_all: bool = False,
#self.session_id = session_id ):
# self.session_id = session_id
self.graph_execution_state_id = graph_execution_state_id self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id self.invocation_id = invocation_id
self.invoke_all = invoke_all self.invoke_all = invoke_all
@@ -24,12 +26,13 @@ class InvocationQueueItem:
class InvocationQueueABC(ABC): class InvocationQueueABC(ABC):
"""Abstract base class for all invocation queues""" """Abstract base class for all invocation queues"""
@abstractmethod @abstractmethod
def get(self) -> InvocationQueueItem: def get(self) -> InvocationQueueItem:
pass pass
@abstractmethod @abstractmethod
def put(self, item: InvocationQueueItem|None) -> None: def put(self, item: InvocationQueueItem | None) -> None:
pass pass
@@ -38,9 +41,9 @@ class MemoryInvocationQueue(InvocationQueueABC):
def __init__(self): def __init__(self):
self.__queue = Queue() self.__queue = Queue()
def get(self) -> InvocationQueueItem: def get(self) -> InvocationQueueItem:
return self.__queue.get() return self.__queue.get()
def put(self, item: InvocationQueueItem|None) -> None: def put(self, item: InvocationQueueItem | None) -> None:
self.__queue.put(item) self.__queue.put(item)
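
The queue carries small work items rather than whole graphs; a minimal round trip with the in-memory implementation (import path assumed from the new layout):

    # Import path assumed from the new invokeai.app layout.
    from invokeai.app.services.invocation_queue import (
        InvocationQueueItem,
        MemoryInvocationQueue,
    )

    queue = MemoryInvocationQueue()
    queue.put(
        InvocationQueueItem(
            graph_execution_state_id="exec-1",
            invocation_id="node-1",
            invoke_all=True,
        )
    )

    item = queue.get()  # blocks until an item (or a None sentinel) is available
    print(item.invocation_id, item.invoke_all)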

View File

@@ -1,29 +1,32 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.backend import Generate
from .events import EventServiceBase
from .image_storage import ImageStorageBase
from .invocation_queue import InvocationQueueABC from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC from .item_storage import ItemStorageABC
from .image_storage import ImageStorageBase
from .events import EventServiceBase
from ....generate import Generate
class InvocationServices(): class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
generate: Generate # TODO: wrap Generate, or split it up from model?
generate: Generate # TODO: wrap Generate, or split it up from model?
events: EventServiceBase events: EventServiceBase
images: ImageStorageBase images: ImageStorageBase
queue: InvocationQueueABC queue: InvocationQueueABC
# NOTE: we must forward-declare any types that include invocations, since invocations can use services # NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_execution_manager: ItemStorageABC['GraphExecutionState'] graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: 'InvocationProcessorABC' processor: "InvocationProcessorABC"
def __init__(self, def __init__(
self,
generate: Generate, generate: Generate,
events: EventServiceBase, events: EventServiceBase,
images: ImageStorageBase, images: ImageStorageBase,
queue: InvocationQueueABC, queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC['GraphExecutionState'], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: 'InvocationProcessorABC' processor: "InvocationProcessorABC",
): ):
self.generate = generate self.generate = generate
self.events = events self.events = events

View File

@@ -2,11 +2,12 @@
from abc import ABC from abc import ABC
from threading import Event, Thread from threading import Event, Thread
from .graph import Graph, GraphExecutionState
from .item_storage import ItemStorageABC
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .invocation_services import InvocationServices from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueABC, InvocationQueueItem from .invocation_queue import InvocationQueueABC, InvocationQueueItem
from .invocation_services import InvocationServices
from .item_storage import ItemStorageABC
class Invoker: class Invoker:
@@ -14,14 +15,13 @@ class Invoker:
services: InvocationServices services: InvocationServices
def __init__(self, def __init__(self, services: InvocationServices):
services: InvocationServices
):
self.services = services self.services = services
self._start() self._start()
def invoke(
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None: self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute""" """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
# Get the next invocation # Get the next invocation
@@ -33,38 +33,36 @@ class Invoker:
self.services.graph_execution_manager.set(graph_execution_state) self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation # Queue the invocation
print(f'queueing item {invocation.id}') print(f"queueing item {invocation.id}")
self.services.queue.put(InvocationQueueItem( self.services.queue.put(
#session_id = session.id, InvocationQueueItem(
graph_execution_state_id = graph_execution_state.id, # session_id = session.id,
invocation_id = invocation.id, graph_execution_state_id=graph_execution_state.id,
invoke_all = invoke_all invocation_id=invocation.id,
)) invoke_all=invoke_all,
)
)
return invocation.id return invocation.id
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph""" """Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph = Graph() if graph is None else graph) new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state) self.services.graph_execution_manager.set(new_state)
return new_state return new_state
def __start_service(self, service) -> None: def __start_service(self, service) -> None:
# Call start() method on any services that have it # Call start() method on any services that have it
start_op = getattr(service, 'start', None) start_op = getattr(service, "start", None)
if callable(start_op): if callable(start_op):
start_op(self) start_op(self)
def __stop_service(self, service) -> None: def __stop_service(self, service) -> None:
# Call stop() method on any services that have it # Call stop() method on any services that have it
stop_op = getattr(service, 'stop', None) stop_op = getattr(service, "stop", None)
if callable(stop_op): if callable(stop_op):
stop_op(self) stop_op(self)
def _start(self) -> None: def _start(self) -> None:
"""Starts the invoker. This is called automatically when the invoker is created.""" """Starts the invoker. This is called automatically when the invoker is created."""
for service in vars(self.services): for service in vars(self.services):
@ -73,7 +71,6 @@ class Invoker:
for service in vars(self.services): for service in vars(self.services):
self.__start_service(getattr(self.services, service)) self.__start_service(getattr(self.services, service))
def stop(self) -> None: def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further.""" """Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services # First stop all services
@ -87,4 +84,4 @@ class Invoker:
class InvocationProcessorABC(ABC): class InvocationProcessorABC(ABC):
pass pass
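
For orientation, here is a minimal driver sketch, not part of the diff, showing how the Invoker API above is meant to be used. The module path invokeai.app.services.invoker is an assumption about where the file lands after the restructure, and services (an InvocationServices) and graph (a Graph) are placeholders built elsewhere.

# Hypothetical usage sketch; `services` and `graph` are assumed to exist already.
from invokeai.app.services.invoker import Invoker

def run_graph(services, graph):
    invoker = Invoker(services)  # _start() calls start() on every service that defines it
    state = invoker.create_execution_state(graph=graph)
    # invoke_all=True lets the processor keep queueing nodes until the graph completes
    invoker.invoke(state, invoke_all=True)
    return state.id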


@@ -1,19 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

T = TypeVar("T", bound=BaseModel)


class PaginatedResults(GenericModel, Generic[T]):
    """Paginated results"""

    #fmt: off
    items: list[T] = Field(description="Items")
    page: int = Field(description="Current Page")
    pages: int = Field(description="Total number of pages")
    per_page: int = Field(description="Number of items per page")
    total: int = Field(description="Total number of items in result")
    #fmt: on


class ItemStorageABC(ABC, Generic[T]):
    _on_changed_callbacks: list[Callable[[T], None]]
@@ -24,6 +26,7 @@ class ItemStorageABC(ABC, Generic[T]):
        self._on_deleted_callbacks = list()

    """Base item storage class"""

    @abstractmethod
    def get(self, item_id: str) -> T:
        pass
@@ -37,7 +40,9 @@ class ItemStorageABC(ABC, Generic[T]):
        pass

    @abstractmethod
    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        pass

    def on_changed(self, on_changed: Callable[[T], None]) -> None:
@@ -51,7 +56,7 @@ class ItemStorageABC(ABC, Generic[T]):
    def _on_changed(self, item: T) -> None:
        for callback in self._on_changed_callbacks:
            callback(item)

    def _on_deleted(self, item_id: str) -> None:
        for callback in self._on_deleted_callbacks:
            callback(item_id)
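
As a quick illustration, not from the diff itself, the callback hooks above let other services observe writes to any concrete store; storage is assumed to be some ItemStorageABC subclass, such as the sqlite-backed one further below.

def log_changes(storage: ItemStorageABC) -> None:
    # on_changed() registers an observer; implementations call _on_changed() from set()
    storage.on_changed(lambda item: print(f"stored: {item}"))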


@@ -1,5 +1,6 @@
import traceback
from threading import Event, Thread

from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
@@ -14,52 +15,62 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
        self.__invoker = invoker
        self.__stop_event = Event()
        self.__invoker_thread = Thread(
            name="invoker_processor",
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        self.__invoker_thread.daemon = (
            True  # TODO: probably better to just not use threads?
        )
        self.__invoker_thread.start()

    def stop(self, *args, **kwargs) -> None:
        self.__stop_event.set()

    def __process(self, stop_event: Event):
        try:
            while not stop_event.is_set():
                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                if not queue_item:  # Probably stopping
                    continue

                graph_execution_state = (
                    self.__invoker.services.graph_execution_manager.get(
                        queue_item.graph_execution_state_id
                    )
                )
                invocation = graph_execution_state.execution_graph.get_node(
                    queue_item.invocation_id
                )

                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id=graph_execution_state.id,
                    invocation_id=invocation.id,
                )

                # Invoke
                try:
                    outputs = invocation.invoke(
                        InvocationContext(
                            services=self.__invoker.services,
                            graph_execution_state_id=graph_execution_state.id,
                        )
                    )

                    # Save outputs and history
                    graph_execution_state.complete(invocation.id, outputs)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id=graph_execution_state.id,
                        invocation_id=invocation.id,
                        result=outputs.dict(),
                    )

                except KeyboardInterrupt:
@@ -72,24 +83,27 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    graph_execution_state.set_node_error(invocation.id, error)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id=graph_execution_state.id,
                        invocation_id=invocation.id,
                        error=error,
                    )

                    pass

                # Queue any further commands if invoking all
                is_complete = graph_execution_state.is_complete()
                if queue_item.invoke_all and not is_complete:
                    self.__invoker.invoke(graph_execution_state, invoke_all=True)
                elif is_complete:
                    self.__invoker.services.events.emit_graph_execution_complete(
                        graph_execution_state.id
                    )

        except KeyboardInterrupt:
            ...  # Log something?
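
A small sketch, an assumption rather than part of the diff, of what this worker actually drains from the queue; services, state, and node_id are placeholders for an InvocationServices instance, a GraphExecutionState, and one of its prepared node ids.

# Normally Invoker.invoke() builds and enqueues these; spelled out here only to make the flow explicit.
item = InvocationQueueItem(
    graph_execution_state_id=state.id,
    invocation_id=node_id,
    invoke_all=True,  # the processor keeps re-invoking until the graph is complete
)
services.queue.put(item)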


@@ -1,12 +1,15 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args

from pydantic import BaseModel, parse_raw_as

from .item_storage import ItemStorageABC, PaginatedResults

T = TypeVar("T", bound=BaseModel)

sqlite_memory = ":memory:"


class SqliteItemStorage(ItemStorageABC, Generic[T]):
    _filename: str
@@ -16,15 +19,17 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    _id_field: str
    _lock: Lock

    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
        super().__init__()
        self._filename = filename
        self._table_name = table_name
        self._id_field = id_field  # TODO: validate that T has this field
        self._lock = Lock()
        self._conn = sqlite3.connect(
            self._filename, check_same_thread=False
        )  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()
        self._create_table()
@@ -32,10 +37,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def _create_table(self):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
                item TEXT,
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
            )
            self._cursor.execute(
                f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
            )
        finally:
            self._lock.release()
@@ -46,7 +55,10 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def set(self, item: T):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
                (item.json(),),
            )
        finally:
            self._lock.release()
        self._on_changed(item)
@@ -54,7 +66,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def get(self, id: str) -> Union[T, None]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
            result = self._cursor.fetchone()
        finally:
            self._lock.release()
@@ -67,7 +81,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def delete(self, id: str):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
        finally:
            self._lock.release()
        self._on_deleted(id)
@@ -75,12 +91,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
                (per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()
@@ -88,22 +107,26 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
        pageCount = int(count / per_page) + 1

        return PaginatedResults[T](
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )

    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
                (f"%{query}%", per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(
                f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
                (f"%{query}%",),
            )
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()
@@ -111,9 +134,5 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
        pageCount = int(count / per_page) + 1

        return PaginatedResults[T](
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )
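
A brief usage sketch, not part of the diff, exercising the store above with the in-memory database and a throwaway pydantic model. The generic-subscript construction is an assumption about how callers parameterize it, and the snippet assumes SqliteItemStorage and sqlite_memory are importable from this module.

from pydantic import BaseModel

class Note(BaseModel):
    id: str  # matches the default id_field extracted into the generated sqlite column
    text: str

notes = SqliteItemStorage[Note](filename=sqlite_memory, table_name="notes")
notes.set(Note(id="1", text="hello"))
retrieved = notes.get("1")  # parsed back into a Note via parse_raw_as
assert notes.list(page=0, per_page=10).total == 1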


@@ -1,5 +1,5 @@
"""
Initialization file for invokeai.backend
"""
from .generate import Generate
from .model_management import ModelManager
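
In other words, the expected import surface after this commit is the package root; a sketch, assuming no further renames:

from invokeai.backend import Generate, ModelManager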

File diff suppressed because it is too large


@ -6,19 +6,20 @@
# #
# Coauthor: Kevin Turner http://github.com/keturn # Coauthor: Kevin Turner http://github.com/keturn
# #
print("Loading Python libraries...\n") import sys
print("Loading Python libraries...\n",file=sys.stderr)
import argparse import argparse
import io import io
import os import os
import re import re
import shutil import shutil
import sys
import traceback import traceback
import warnings import warnings
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from urllib import request
from shutil import get_terminal_size from shutil import get_terminal_size
from urllib import request
import npyscreen import npyscreen
import torch import torch
@ -37,17 +38,20 @@ from transformers import (
import invokeai.configs as configs import invokeai.configs as configs
from ...frontend.install.model_install import addModelsForm, process_and_execute
from ...frontend.install.widgets import (
CenteredButtonPress,
IntTitleSlider,
set_min_terminal_size,
)
from ..args import PRECISION_CHOICES, Args from ..args import PRECISION_CHOICES, Args
from ..globals import Globals, global_config_dir, global_config_file, global_cache_dir from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
from .model_install import addModelsForm, process_and_execute
from .model_install_backend import ( from .model_install_backend import (
default_dataset, default_dataset,
download_from_hf, download_from_hf,
recommended_datasets,
hf_download_with_resume, hf_download_with_resume,
recommended_datasets,
) )
from .widgets import IntTitleSlider, CenteredButtonPress, set_min_terminal_size
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -82,6 +86,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# -Ak_euler_a -C10.0 # -Ak_euler_a -C10.0
""" """
# -------------------------------------------- # --------------------------------------------
def postscript(errors: None): def postscript(errors: None):
if not any(errors): if not any(errors):
@ -180,13 +185,11 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
# --------------------------------------------- # ---------------------------------------------
# this will preload the Bert tokenizer files
def download_bert(): def download_bert():
print( print("Installing bert tokenizer...", file=sys.stderr)
"Installing bert tokenizer...",
file=sys.stderr
)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
from transformers import BertTokenizerFast from transformers import BertTokenizerFast
download_from_hf(BertTokenizerFast, "bert-base-uncased") download_from_hf(BertTokenizerFast, "bert-base-uncased")
@ -197,12 +200,14 @@ def download_sd1_clip():
download_from_hf(CLIPTokenizer, version) download_from_hf(CLIPTokenizer, version)
download_from_hf(CLIPTextModel, version) download_from_hf(CLIPTextModel, version)
# --------------------------------------------- # ---------------------------------------------
def download_sd2_clip(): def download_sd2_clip():
version = 'stabilityai/stable-diffusion-2' version = "stabilityai/stable-diffusion-2"
print("Installing SD2 clip model...", file=sys.stderr) print("Installing SD2 clip model...", file=sys.stderr)
download_from_hf(CLIPTokenizer, version, subfolder='tokenizer') download_from_hf(CLIPTokenizer, version, subfolder="tokenizer")
download_from_hf(CLIPTextModel, version, subfolder='text_encoder') download_from_hf(CLIPTextModel, version, subfolder="text_encoder")
# --------------------------------------------- # ---------------------------------------------
def download_realesrgan(): def download_realesrgan():
@ -323,13 +328,13 @@ def get_root(root: str = None) -> str:
class editOptsForm(npyscreen.FormMultiPage): class editOptsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled # for responsive resizing - disabled
# FIX_MINIMUM_SIZE_WHEN_CREATED = False # FIX_MINIMUM_SIZE_WHEN_CREATED = False
def create(self): def create(self):
program_opts = self.parentApp.program_opts program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts old_opts = self.parentApp.invokeai_opts
first_time = not (Globals.root / Globals.initfile).exists() first_time = not (Globals.root / Globals.initfile).exists()
access_token = HfFolder.get_token() access_token = HfFolder.get_token()
window_width,window_height = get_terminal_size() window_width, window_height = get_terminal_size()
for i in [ for i in [
"Configure startup settings. You can come back and change these later.", "Configure startup settings. You can come back and change these later.",
"Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.", "Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
@ -681,6 +686,7 @@ def run_console_ui(
else: else:
return (editApp.new_opts, editApp.user_selections) return (editApp.new_opts, editApp.user_selections)
# ------------------------------------- # -------------------------------------
def write_opts(opts: Namespace, init_file: Path): def write_opts(opts: Namespace, init_file: Path):
""" """
@ -701,8 +707,8 @@ def write_opts(opts: Namespace, init_file: Path):
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)" "^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
) )
# fix windows paths # fix windows paths
opts.outdir = opts.outdir.replace('\\','/') opts.outdir = opts.outdir.replace("\\", "/")
opts.embedding_path = opts.embedding_path.replace('\\','/') opts.embedding_path = opts.embedding_path.replace("\\", "/")
new_file = f"{init_file}.new" new_file = f"{init_file}.new"
try: try:
lines = [x.strip() for x in open(init_file, "r").readlines()] lines = [x.strip() for x in open(init_file, "r").readlines()]
@ -855,6 +861,7 @@ def main():
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nGoodbye! Come back soon.") print("\nGoodbye! Come back soon.")
# ------------------------------------- # -------------------------------------
if __name__ == "__main__": if __name__ == "__main__":
main() main()


@ -8,6 +8,7 @@ import sys
import warnings import warnings
from pathlib import Path from pathlib import Path
from tempfile import TemporaryFile from tempfile import TemporaryFile
from typing import List
import requests import requests
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
@ -15,12 +16,12 @@ from huggingface_hub import hf_hub_url
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from tqdm import tqdm from tqdm import tqdm
from typing import List
import invokeai.configs as configs import invokeai.configs as configs
from ..generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..globals import Globals, global_cache_dir, global_config_dir from ..globals import Globals, global_cache_dir, global_config_dir
from ..model_manager import ModelManager from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -44,45 +45,49 @@ Config_preamble = """
# was trained on. # was trained on.
""" """
def default_config_file(): def default_config_file():
return Path(global_config_dir()) / "models.yaml" return Path(global_config_dir()) / "models.yaml"
def sd_configs(): def sd_configs():
return Path(global_config_dir()) / "stable-diffusion" return Path(global_config_dir()) / "stable-diffusion"
def initial_models(): def initial_models():
global Datasets global Datasets
if Datasets: if Datasets:
return Datasets return Datasets
return (Datasets := OmegaConf.load(Dataset_path)) return (Datasets := OmegaConf.load(Dataset_path))
def install_requested_models( def install_requested_models(
install_initial_models: List[str] = None, install_initial_models: List[str] = None,
remove_models: List[str] = None, remove_models: List[str] = None,
scan_directory: Path = None, scan_directory: Path = None,
external_models: List[str] = None, external_models: List[str] = None,
scan_at_startup: bool = False, scan_at_startup: bool = False,
convert_to_diffusers: bool = False, convert_to_diffusers: bool = False,
precision: str = "float16", precision: str = "float16",
purge_deleted: bool = False, purge_deleted: bool = False,
config_file_path: Path = None, config_file_path: Path = None,
): ):
''' """
Entry point for installing/deleting starter models, or installing external models. Entry point for installing/deleting starter models, or installing external models.
''' """
config_file_path=config_file_path or default_config_file() config_file_path = config_file_path or default_config_file()
if not config_file_path.exists(): if not config_file_path.exists():
open(config_file_path,'w') open(config_file_path, "w")
model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision) model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
if remove_models and len(remove_models) > 0: if remove_models and len(remove_models) > 0:
print("== DELETING UNCHECKED STARTER MODELS ==") print("== DELETING UNCHECKED STARTER MODELS ==")
for model in remove_models: for model in remove_models:
print(f'{model}...') print(f"{model}...")
model_manager.del_model(model, delete_files=purge_deleted) model_manager.del_model(model, delete_files=purge_deleted)
model_manager.commit(config_file_path) model_manager.commit(config_file_path)
if install_initial_models and len(install_initial_models) > 0: if install_initial_models and len(install_initial_models) > 0:
print("== INSTALLING SELECTED STARTER MODELS ==") print("== INSTALLING SELECTED STARTER MODELS ==")
successfully_downloaded = download_weight_datasets( successfully_downloaded = download_weight_datasets(
@ -96,20 +101,20 @@ def install_requested_models(
# due to above, we have to reload the model manager because conf file # due to above, we have to reload the model manager because conf file
# was changed behind its back # was changed behind its back
model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision) model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
external_models = external_models or list() external_models = external_models or list()
if scan_directory: if scan_directory:
external_models.append(str(scan_directory)) external_models.append(str(scan_directory))
if len(external_models)>0: if len(external_models) > 0:
print("== INSTALLING EXTERNAL MODELS ==") print("== INSTALLING EXTERNAL MODELS ==")
for path_url_or_repo in external_models: for path_url_or_repo in external_models:
try: try:
model_manager.heuristic_import( model_manager.heuristic_import(
path_url_or_repo, path_url_or_repo,
convert=convert_to_diffusers, convert=convert_to_diffusers,
commit_to_conf=config_file_path commit_to_conf=config_file_path,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
sys.exit(-1) sys.exit(-1)
@ -117,17 +122,18 @@ def install_requested_models(
pass pass
if scan_at_startup and scan_directory.is_dir(): if scan_at_startup and scan_directory.is_dir():
argument = '--autoconvert' if convert_to_diffusers else '--autoimport' argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
initfile = Path(Globals.root, Globals.initfile) initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f'{Globals.initfile}.new') replacement = Path(Globals.root, f"{Globals.initfile}.new")
directory = str(scan_directory).replace('\\','/') directory = str(scan_directory).replace("\\", "/")
with open(initfile,'r') as input: with open(initfile, "r") as input:
with open(replacement,'w') as output: with open(replacement, "w") as output:
while line := input.readline(): while line := input.readline():
if not line.startswith(argument): if not line.startswith(argument):
output.writelines([line]) output.writelines([line])
output.writelines([f'{argument} {directory}']) output.writelines([f"{argument} {directory}"])
os.replace(replacement,initfile) os.replace(replacement, initfile)
# ------------------------------------- # -------------------------------------
def yes_or_no(prompt: str, default_yes=True): def yes_or_no(prompt: str, default_yes=True):
@ -183,7 +189,9 @@ def migrate_models_ckpt():
if not os.path.exists(os.path.join(model_path, "model.ckpt")): if not os.path.exists(os.path.join(model_path, "model.ckpt")):
return return
new_name = initial_models()["stable-diffusion-1.4"]["file"] new_name = initial_models()["stable-diffusion-1.4"]["file"]
    print(
        f'The Stable Diffusion v1.4 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
    )
print(f"model.ckpt => {new_name}") print(f"model.ckpt => {new_name}")
os.replace( os.replace(
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name) os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
@ -383,7 +391,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
# --------------------------------------------- # ---------------------------------------------
def new_config_file_contents( def new_config_file_contents(
successfully_downloaded: dict, config_file: Path, successfully_downloaded: dict,
config_file: Path,
) -> str: ) -> str:
if config_file.exists(): if config_file.exists():
conf = OmegaConf.load(str(config_file.expanduser().resolve())) conf = OmegaConf.load(str(config_file.expanduser().resolve()))
@ -413,7 +422,9 @@ def new_config_file_contents(
stanza["weights"] = os.path.relpath( stanza["weights"] = os.path.relpath(
successfully_downloaded[model], start=Globals.root successfully_downloaded[model], start=Globals.root
) )
stanza["config"] = os.path.normpath(os.path.join(sd_configs(), mod["config"])) stanza["config"] = os.path.normpath(
os.path.join(sd_configs(), mod["config"])
)
if "vae" in mod: if "vae" in mod:
if "file" in mod["vae"]: if "file" in mod["vae"]:
stanza["vae"] = os.path.normpath( stanza["vae"] = os.path.normpath(
@ -445,7 +456,7 @@ def delete_weights(model_name: str, conf_stanza: dict):
print( print(
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?" f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
) )
weights = Path(weights) weights = Path(weights)
if not weights.is_absolute(): if not weights.is_absolute():
weights = Path(Globals.root) / weights weights = Path(Globals.root) / weights


@ -25,21 +25,19 @@ from omegaconf import OmegaConf
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pytorch_lightning import logging, seed_everything from pytorch_lightning import logging, seed_everything
import ldm.invoke.conditioning from .model_management import ModelManager
from ldm.invoke.args import metadata_from_png from .args import metadata_from_png
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary from .generator import infill_methods
from ldm.invoke.conditioning import get_uc_and_c_and_ec from .globals import Globals, global_cache_dir
from ldm.invoke.devices import choose_precision, choose_torch_device from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
from ldm.invoke.generator.inpaint import infill_methods from .prompting import get_uc_and_c_and_ec
from ldm.invoke.globals import Globals, global_cache_dir from .stable_diffusion import (
from ldm.invoke.image_util import InitImageResizer DDIMSampler,
from ldm.invoke.model_manager import ModelManager HuggingFaceConceptsLibrary,
from ldm.invoke.pngwriter import PngWriter KSampler,
from ldm.invoke.seamless import configure_model_padding PLMSSampler,
from ldm.invoke.txt2mask import Txt2Mask )
from ldm.models.diffusion.ddim import DDIMSampler from .util import choose_precision, choose_torch_device
from ldm.models.diffusion.ksampler import KSampler
from ldm.models.diffusion.plms import PLMSSampler
def fix_func(orig): def fix_func(orig):
@ -328,8 +326,8 @@ class Generate:
variation_amount=0.0, variation_amount=0.0,
threshold=0.0, threshold=0.0,
perlin=0.0, perlin=0.0,
h_symmetry_time_pct = None, h_symmetry_time_pct=None,
v_symmetry_time_pct = None, v_symmetry_time_pct=None,
karras_max=None, karras_max=None,
outdir=None, outdir=None,
# these are specific to img2img and inpaint # these are specific to img2img and inpaint
@ -717,7 +715,7 @@ class Generate:
prompt, prompt,
model=self.model, model=self.model,
skip_normalize_legacy_blend=opt.skip_normalize, skip_normalize_legacy_blend=opt.skip_normalize,
log_tokens=ldm.invoke.conditioning.log_tokenization, log_tokens=invokeai.backend.prompting.conditioning.log_tokenization,
) )
if tool in ("gfpgan", "codeformer", "upscale"): if tool in ("gfpgan", "codeformer", "upscale"):
@ -741,7 +739,7 @@ class Generate:
) )
elif tool == "outcrop": elif tool == "outcrop":
from ldm.invoke.restoration.outcrop import Outcrop from .restoration.outcrop import Outcrop
extend_instructions = {} extend_instructions = {}
for direction, pixels in _pairwise(opt.outcrop): for direction, pixels in _pairwise(opt.outcrop):
@ -794,7 +792,7 @@ class Generate:
clear_cuda_cache=self.clear_cuda_cache, clear_cuda_cache=self.clear_cuda_cache,
) )
elif tool == "outpaint": elif tool == "outpaint":
from ldm.invoke.restoration.outpaint import Outpaint from .restoration.outpaint import Outpaint
restorer = Outpaint(image, self) restorer = Outpaint(image, self)
return restorer.process(opt, args, image_callback=callback, prefix=prefix) return restorer.process(opt, args, image_callback=callback, prefix=prefix)
@ -816,17 +814,12 @@ class Generate:
hires_fix: bool = False, hires_fix: bool = False,
force_outpaint: bool = False, force_outpaint: bool = False,
): ):
inpainting_model_in_use = self.sampler.uses_inpainting_model()
if hires_fix: if hires_fix:
return self._make_txt2img2img() return self._make_txt2img2img()
if embiggen is not None: if embiggen is not None:
return self._make_embiggen() return self._make_embiggen()
if inpainting_model_in_use:
return self._make_omnibus()
if ((init_image is not None) and (mask_image is not None)) or force_outpaint: if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
return self._make_inpaint() return self._make_inpaint()
@ -903,16 +896,9 @@ class Generate:
def _make_inpaint(self): def _make_inpaint(self):
return self._load_generator(".inpaint", "Inpaint") return self._load_generator(".inpaint", "Inpaint")
def _make_omnibus(self):
return self._load_generator(".omnibus", "Omnibus")
def _load_generator(self, module, class_name): def _load_generator(self, module, class_name):
if self.is_legacy_model(self.model_name): mn = f"invokeai.backend.generator{module}"
mn = f"ldm.invoke.ckpt_generator{module}" cn = class_name
cn = f"Ckpt{class_name}"
else:
mn = f"ldm.invoke.generator{module}"
cn = class_name
module = importlib.import_module(mn) module = importlib.import_module(mn)
constructor = getattr(module, cn) constructor = getattr(module, cn)
return constructor(self.model, self.precision) return constructor(self.model, self.precision)
@ -975,7 +961,7 @@ class Generate:
seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None: if self.embedding_path is not None:
print(f'>> Loading embeddings from {self.embedding_path}') print(f">> Loading embeddings from {self.embedding_path}")
for root, _, files in os.walk(self.embedding_path): for root, _, files in os.walk(self.embedding_path):
for name in files: for name in files:
ti_path = os.path.join(root, name) ti_path = os.path.join(root, name)
@ -1030,7 +1016,6 @@ class Generate:
image_callback=None, image_callback=None,
prefix=None, prefix=None,
): ):
results = [] results = []
for r in image_list: for r in image_list:
image, seed = r image, seed = r

View File

@@ -0,0 +1,5 @@
"""
Initialization file for the invokeai.generator package
"""
from .base import Generator
from .inpaint import infill_methods


@ -1,7 +1,7 @@
''' """
Base class for ldm.invoke.generator.* Base class for invokeai.backend.generator.*
including img2img, txt2img, and inpaint including img2img, txt2img, and inpaint
''' """
from __future__ import annotations from __future__ import annotations
import os import os
@ -9,24 +9,25 @@ import os.path as osp
import random import random
import traceback import traceback
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from PIL import Image, ImageFilter, ImageChops
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from einops import rearrange from einops import rearrange
from pathlib import Path from PIL import Image, ImageChops, ImageFilter
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from tqdm import trange from tqdm import trange
import invokeai.assets.web as web_assets import invokeai.assets.web as web_assets
from ldm.models.diffusion.ddpm import DiffusionWrapper
from ldm.util import rand_perlin_2d from ..stable_diffusion.diffusion.ddpm import DiffusionWrapper
from ..util.util import rand_perlin_2d
downsampling = 8 downsampling = 8
CAUTION_IMG = 'caution.png' CAUTION_IMG = "caution.png"
class Generator: class Generator:
downsampling_factor: int downsampling_factor: int
@ -39,7 +40,7 @@ class Generator:
self.precision = precision self.precision = precision
self.seed = None self.seed = None
self.latent_channels = model.channels self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None self.safety_checker = None
self.perlin = 0.0 self.perlin = 0.0
self.threshold = 0 self.threshold = 0
@ -50,56 +51,73 @@ class Generator:
self.caution_img = None self.caution_img = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs): def get_make_image(self, prompt, **kwargs):
""" """
Returns a function returning an image derived from the prompt and the initial image Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it Return value depends on the seed at the time you call it
""" """
raise NotImplementedError("image_iterator() must be implemented in a descendent class") raise NotImplementedError(
"image_iterator() must be implemented in a descendent class"
)
def set_variation(self, seed, variation_amount, with_variations): def set_variation(self, seed, variation_amount, with_variations):
self.seed = seed self.seed = seed
self.variation_amount = variation_amount self.variation_amount = variation_amount
self.with_variations = with_variations self.with_variations = with_variations
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, def generate(
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, self,
h_symmetry_time_pct=None, v_symmetry_time_pct=None, prompt,
safety_checker:dict=None, init_image,
free_gpu_mem: bool=False, width,
**kwargs): height,
sampler,
iterations=1,
seed=None,
image_callback=None,
step_callback=None,
threshold=0.0,
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
safety_checker: dict = None,
free_gpu_mem: bool = False,
**kwargs,
):
scope = nullcontext scope = nullcontext
self.safety_checker = safety_checker self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem self.free_gpu_mem = free_gpu_mem
attention_maps_images = [] attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) attention_maps_callback = lambda saver: attention_maps_images.append(
saver.get_stacked_maps_image()
)
make_image = self.get_make_image( make_image = self.get_make_image(
prompt, prompt,
sampler = sampler, sampler=sampler,
init_image = init_image, init_image=init_image,
width = width, width=width,
height = height, height=height,
step_callback = step_callback, step_callback=step_callback,
threshold = threshold, threshold=threshold,
perlin = perlin, perlin=perlin,
h_symmetry_time_pct = h_symmetry_time_pct, h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct = v_symmetry_time_pct, v_symmetry_time_pct=v_symmetry_time_pct,
attention_maps_callback = attention_maps_callback, attention_maps_callback=attention_maps_callback,
**kwargs **kwargs,
) )
results = [] results = []
seed = seed if seed is not None and seed >= 0 else self.new_seed() seed = seed if seed is not None and seed >= 0 else self.new_seed()
first_seed = seed first_seed = seed
seed, initial_noise = self.generate_initial_noise(seed, width, height) seed, initial_noise = self.generate_initial_noise(seed, width, height)
# There used to be an additional self.model.ema_scope() here, but it breaks # There used to be an additional self.model.ema_scope() here, but it breaks
# the inpaint-1.5 model. Not sure what it did.... ? # the inpaint-1.5 model. Not sure what it did.... ?
with scope(self.model.device.type): with scope(self.model.device.type):
for n in trange(iterations, desc='Generating'): for n in trange(iterations, desc="Generating"):
x_T = None x_T = None
if self.variation_amount > 0: if self.variation_amount > 0:
seed_everything(seed) seed_everything(seed)
target_noise = self.get_noise(width,height) target_noise = self.get_noise(width, height)
x_T = self.slerp(self.variation_amount, initial_noise, target_noise) x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
elif initial_noise is not None: elif initial_noise is not None:
# i.e. we specified particular variations # i.e. we specified particular variations
@ -107,9 +125,9 @@ class Generator:
else: else:
seed_everything(seed) seed_everything(seed)
try: try:
x_T = self.get_noise(width,height) x_T = self.get_noise(width, height)
except: except:
print('** An error occurred while getting initial noise **') print("** An error occurred while getting initial noise **")
print(traceback.format_exc()) print(traceback.format_exc())
image = make_image(x_T) image = make_image(x_T)
@ -120,19 +138,30 @@ class Generator:
results.append([image, seed]) results.append([image, seed])
if image_callback is not None: if image_callback is not None:
attention_maps_image = None if len(attention_maps_images)==0 else attention_maps_images[-1] attention_maps_image = (
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image) None
if len(attention_maps_images) == 0
else attention_maps_images[-1]
)
image_callback(
image,
seed,
first_seed=first_seed,
attention_maps_image=attention_maps_image,
)
seed = self.new_seed() seed = self.new_seed()
# Free up memory from the last generation. # Free up memory from the last generation.
clear_cuda_cache = kwargs['clear_cuda_cache'] if 'clear_cuda_cache' in kwargs else None clear_cuda_cache = (
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
)
if clear_cuda_cache is not None: if clear_cuda_cache is not None:
clear_cuda_cache() clear_cuda_cache()
return results return results
def sample_to_image(self,samples)->Image.Image: def sample_to_image(self, samples) -> Image.Image:
""" """
Given samples returned from a sampler, converts Given samples returned from a sampler, converts
it into a PIL Image it into a PIL Image
@ -141,18 +170,30 @@ class Generator:
image = self.model.decode_latents(samples) image = self.model.decode_latents(samples)
return self.model.numpy_to_pil(image)[0] return self.model.numpy_to_pil(image)[0]
def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image: def repaste_and_color_correct(
self,
result: Image.Image,
init_image: Image.Image,
init_mask: Image.Image,
mask_blur_radius: int = 8,
) -> Image.Image:
if init_image is None or init_mask is None: if init_image is None or init_mask is None:
return result return result
# Get the original alpha channel of the mask if there is one. # Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB') # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L') pil_init_mask = (
pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist init_mask.getchannel("A")
if init_mask.mode == "RGBA"
else init_mask.convert("L")
)
pil_init_image = init_image.convert(
"RGBA"
) # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching. # Build an image with only visible pixels from source to use as reference for color-matching.
init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8) init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) init_a_pixels = np.asarray(pil_init_image.getchannel("A"), dtype=np.uint8)
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
# Get numpy version of result # Get numpy version of result
@ -171,44 +212,70 @@ class Generator:
# Color correct # Color correct
np_matched_result = np_image.copy() np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) np_matched_result[:, :, :] = (
matched_result = Image.fromarray(np_matched_result, mode='RGB') (
(
(
np_matched_result[:, :, :].astype(np.float32)
- gen_means[None, None, :]
)
/ gen_std[None, None, :]
)
* init_std[None, None, :]
+ init_means[None, None, :]
)
.clip(0, 255)
.astype(np.uint8)
)
matched_result = Image.fromarray(np_matched_result, mode="RGB")
else: else:
matched_result = Image.fromarray(np_image, mode='RGB') matched_result = Image.fromarray(np_image, mode="RGB")
# Blur the mask out (into init image) by specified amount # Blur the mask out (into init image) by specified amount
if mask_blur_radius > 0: if mask_blur_radius > 0:
nm = np.asarray(pil_init_mask, dtype=np.uint8) nm = np.asarray(pil_init_mask, dtype=np.uint8)
nmd = cv2.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2)) nmd = cv2.erode(
pmd = Image.fromarray(nmd, mode='L') nm,
kernel=np.ones((3, 3), dtype=np.uint8),
iterations=int(mask_blur_radius / 2),
)
pmd = Image.fromarray(nmd, mode="L")
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius)) blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
else: else:
blurred_init_mask = pil_init_mask blurred_init_mask = pil_init_mask
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1]) multiplied_blurred_init_mask = ImageChops.multiply(
blurred_init_mask, self.pil_image.split()[-1]
)
# Paste original on color-corrected generation (using blurred mask) # Paste original on color-corrected generation (using blurred mask)
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
return matched_result return matched_result
def sample_to_lowres_estimated_image(self,samples): def sample_to_lowres_estimated_image(self, samples):
        # originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle # these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor([ v1_5_latent_rgb_factors = torch.tensor(
# R G B [
[ 0.3444, 0.1385, 0.0670], # L1 # R G B
[ 0.1247, 0.4027, 0.1494], # L2 [0.3444, 0.1385, 0.0670], # L1
[-0.3192, 0.2513, 0.2103], # L3 [0.1247, 0.4027, 0.1494], # L2
[-0.1307, -0.1874, -0.7445] # L4 [-0.3192, 0.2513, 0.2103], # L3
], dtype=samples.dtype, device=samples.device) [-0.1307, -0.1874, -0.7445], # L4
],
dtype=samples.dtype,
device=samples.device,
)
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
latents_ubyte = (((latent_image + 1) / 2) latents_ubyte = (
.clamp(0, 1) # change scale from -1..1 to 0..1 ((latent_image + 1) / 2)
.mul(0xFF) # to 0..255 .clamp(0, 1) # change scale from -1..1 to 0..1
.byte()).cpu() .mul(0xFF) # to 0..255
.byte()
).cpu()
return Image.fromarray(latents_ubyte.numpy()) return Image.fromarray(latents_ubyte.numpy())
@ -217,38 +284,45 @@ class Generator:
if self.variation_amount > 0 or len(self.with_variations) > 0: if self.variation_amount > 0 or len(self.with_variations) > 0:
# use fixed initial noise plus random noise per iteration # use fixed initial noise plus random noise per iteration
seed_everything(seed) seed_everything(seed)
initial_noise = self.get_noise(width,height) initial_noise = self.get_noise(width, height)
for v_seed, v_weight in self.with_variations: for v_seed, v_weight in self.with_variations:
seed = v_seed seed = v_seed
seed_everything(seed) seed_everything(seed)
next_noise = self.get_noise(width,height) next_noise = self.get_noise(width, height)
initial_noise = self.slerp(v_weight, initial_noise, next_noise) initial_noise = self.slerp(v_weight, initial_noise, next_noise)
if self.variation_amount > 0: if self.variation_amount > 0:
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
seed = random.randrange(0,np.iinfo(np.uint32).max) seed = random.randrange(0, np.iinfo(np.uint32).max)
return (seed, initial_noise) return (seed, initial_noise)
else: else:
return (seed, None) return (seed, None)
# returns a tensor filled with random numbers from a normal distribution # returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height): def get_noise(self, width, height):
""" """
        Returns a tensor filled with random numbers, either from a normal distribution
        (txt2img) or from the latent image (img2img, inpaint)
""" """
raise NotImplementedError("get_noise() must be implemented in a descendent class") raise NotImplementedError(
"get_noise() must be implemented in a descendent class"
)
def get_perlin_noise(self,width,height): def get_perlin_noise(self, width, height):
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
# limit noise to only the diffusion image channels, not the mask channels # limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4) input_channels = min(self.latent_channels, 4)
# round up to the nearest block of 8 # round up to the nearest block of 8
temp_width = int((width + 7) / 8) * 8 temp_width = int((width + 7) / 8) * 8
temp_height = int((height + 7) / 8) * 8 temp_height = int((height + 7) / 8) * 8
noise = torch.stack([ noise = torch.stack(
rand_perlin_2d((temp_height, temp_width), [
(8, 8), rand_perlin_2d(
device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device) (temp_height, temp_width), (8, 8), device=self.model.device
).to(fixdevice)
for _ in range(input_channels)
],
dim=0,
).to(self.model.device)
return noise[0:4, 0:height, 0:width] return noise[0:4, 0:height, 0:width]
def new_seed(self): def new_seed(self):
@ -256,7 +330,7 @@ class Generator:
return self.seed return self.seed
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
''' """
Spherical linear interpolation Spherical linear interpolation
Args: Args:
t (float/np.ndarray): Float value between 0.0 and 1.0 t (float/np.ndarray): Float value between 0.0 and 1.0
@ -266,7 +340,7 @@ class Generator:
colineal. Not recommended to alter this. colineal. Not recommended to alter this.
Returns: Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1 v2 (np.ndarray): Interpolation vector between v0 and v1
''' """
inputs_are_torch = False inputs_are_torch = False
        if not isinstance(v0, np.ndarray):
            inputs_are_torch = True
@@ -292,15 +366,15 @@ class Generator:
        return v2

    def safety_check(self, image: Image.Image):
        """
        If the CompViz safety checker flags an NSFW image, we
        blur it out.
        """
        import diffusers

        checker = self.safety_checker["checker"]
        extractor = self.safety_checker["extractor"]
        features = extractor([image], return_tensors="pt")
        features.to(self.model.device)
@@ -309,19 +383,23 @@ class Generator:
        x_image = x_image[None].transpose(0, 3, 1, 2)
        diffusers.logging.set_verbosity_error()
        checked_image, has_nsfw_concept = checker(
            images=x_image, clip_input=features.pixel_values
        )
        if has_nsfw_concept[0]:
            print(
                "** An image with potential non-safe content has been detected. A blurred image will be returned. **"
            )
            return self.blur(image)
        else:
            return image

    def blur(self, input):
        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
        try:
            caution = self.get_caution_img()
            if caution:
                blurry.paste(caution, (0, 0), caution)
        except FileNotFoundError:
            pass
        return blurry
@@ -332,43 +410,52 @@ class Generator:
            return self.caution_img
        path = Path(web_assets.__path__[0]) / CAUTION_IMG
        caution = Image.open(path)
        self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
        return self.caution_img

    # this is a handy routine for debugging use. Given a generated sample,
    # convert it into a PNG image and store it at the indicated path
    def save_sample(self, sample, filepath):
        image = self.sample_to_image(sample)
        dirname = os.path.dirname(filepath) or "."
        if not os.path.exists(dirname):
            print(f"** creating directory {dirname}")
            os.makedirs(dirname, exist_ok=True)
        image.save(filepath, "PNG")

    def torch_dtype(self) -> torch.dtype:
        return torch.float16 if self.precision == "float16" else torch.float32

    # returns a tensor filled with random numbers from a normal distribution
    def get_noise(self, width, height):
        device = self.model.device
        # limit noise to only the diffusion image channels, not the mask channels
        input_channels = min(self.latent_channels, 4)
        if self.use_mps_noise or device.type == "mps":
            x = torch.randn(
                [
                    1,
                    input_channels,
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ],
                dtype=self.torch_dtype(),
                device="cpu",
            ).to(device)
        else:
            x = torch.randn(
                [
                    1,
                    input_channels,
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ],
                dtype=self.torch_dtype(),
                device=device,
            )
        if self.perlin > 0.0:
            perlin_noise = self.get_perlin_noise(
                width // self.downsampling_factor, height // self.downsampling_factor
            )
            x = (1 - self.perlin) * x + self.perlin * perlin_noise
        return x
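
Aside (a minimal sketch, not code from this commit): get_noise() above draws plain Gaussian latent noise and, when perlin > 0, mixes in a structured Perlin field with a simple linear blend. The snippet below illustrates just that blend, with torch.rand standing in for the project's get_perlin_noise() helper:

    import torch

    def blend_noise(gaussian: torch.Tensor, structured: torch.Tensor, weight: float) -> torch.Tensor:
        # Linear interpolation, mirroring x = (1 - self.perlin) * x + self.perlin * perlin_noise above.
        return (1.0 - weight) * gaussian + weight * structured

    latents = torch.randn(1, 4, 64, 64)  # stand-in for the sampled Gaussian latent noise
    perlin = torch.rand(1, 4, 64, 64)    # stand-in for get_perlin_noise() output
    print(blend_noise(latents, perlin, weight=0.2).shape)  # torch.Size([1, 4, 64, 64])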
@@ -1,37 +1,38 @@
"""
invokeai.backend.generator.embiggen descends from .generator
and generates with .generator.img2img
"""

import numpy as np
import torch
from PIL import Image
from tqdm import trange

from .base import Generator
from .img2img import Img2Img


class Embiggen(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None

    # Replace generate because Embiggen doesn't need/use most of what it does normally
    def generate(
        self,
        prompt,
        iterations=1,
        seed=None,
        image_callback=None,
        step_callback=None,
        **kwargs,
    ):
        make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
        results = []
        seed = seed if seed else self.new_seed()

        # Noise will be generated by the Img2Img generator when called
        for _ in trange(iterations, desc="Generating"):
            # make_image will call Img2Img which will do the equivalent of get_noise itself
            image = make_image()
            results.append([image, seed])
@@ -56,13 +57,15 @@ class Embiggen(Generator):
        embiggen,
        embiggen_tiles,
        step_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
        Return value depends on the seed at the time you call it
        """
        assert (
            not sampler.uses_inpainting_model()
        ), "--embiggen is not supported by inpainting models"

        # Construct embiggen arg array, and sanity check arguments
        if embiggen == None:  # embiggen can also be called with just embiggen_tiles
@@ -70,48 +73,57 @@ class Embiggen(Generator):
        elif embiggen[0] < 0:
            embiggen[0] = 1.0
            print(
                ">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
            )
        if len(embiggen) < 2:
            embiggen.append(0.75)
        elif embiggen[1] > 1.0 or embiggen[1] < 0:
            embiggen[1] = 0.75
            print(
                ">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
            )
        if len(embiggen) < 3:
            embiggen.append(0.25)
        elif embiggen[2] < 0:
            embiggen[2] = 0.25
            print(
                ">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
            )

        # Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
        # and then sort them, because... people.
        if embiggen_tiles:
            embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
            embiggen_tiles.sort()

        if strength >= 0.5:
            print(
                f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
            )

        # Prep img2img generator, since we wrap over it
        gen_img2img = Img2Img(self.model, self.precision)

        # Open original init image (not a tensor) to manipulate
        initsuperimage = Image.open(init_img)

        with Image.open(init_img) as img:
            initsuperimage = img.convert("RGB")

        # Size of the target super init image in pixels
        initsuperwidth, initsuperheight = initsuperimage.size

        # Increase by scaling factor if not already resized, using ESRGAN as able
        if embiggen[0] != 1.0:
            initsuperwidth = round(initsuperwidth * embiggen[0])
            initsuperheight = round(initsuperheight * embiggen[0])
            if embiggen[1] > 0:  # No point in ESRGAN upscaling if strength is set zero
                from ..restoration.realesrgan import ESRGAN

                esrgan = ESRGAN()
                print(
                    f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
                )
                if embiggen[0] > 2:
                    initsuperimage = esrgan.process(
                        initsuperimage,
@@ -130,7 +142,8 @@ class Embiggen(Generator):
                # but from personal experience it doesn't greatly improve anything after 4x
            # Resize to target scaling factor resolution
            initsuperimage = initsuperimage.resize(
                (initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
            )

        # Use width and height as tile widths and height
        # Determine buffer size in pixels
@@ -153,23 +166,24 @@ class Embiggen(Generator):
        emb_tiles_x = 1
        emb_tiles_y = 1
        if (initsuperwidth - width) > 0:
            emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
        if (initsuperheight - height) > 0:
            emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
        # Sanity
        assert (
            emb_tiles_x > 1 or emb_tiles_y > 1
        ), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."

        # Prep alpha layers --------------
        # https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
        # agradientL is Left-side transparent
        agradientL = (
            Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
        )
        # agradientT is Top-side transparent
        agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
        # radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
        agradientC = Image.new("L", (256, 256))
        for y in range(256):
            for x in range(256):
                # Find distance to lower right corner (numpy takes arrays)
@@ -177,16 +191,16 @@ class Embiggen(Generator):
                # Clamp values to max 255
                if distanceToLR > 255:
                    distanceToLR = 255
                # Place the pixel as invert of distance
                agradientC.putpixel((x, y), round(255 - distanceToLR))

        # Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
        # Fits for a left-fading gradient on the bottom side and full opacity on the right side.
        agradientAsymC = Image.new("L", (256, 256))
        for y in range(256):
            for x in range(256):
                value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
                # Clamp values
                value = max(0, value)
                value = min(255, value)
                agradientAsymC.putpixel((x, y), value)
@@ -204,80 +218,91 @@ class Embiggen(Generator):
        # make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
        # to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
        alphaLayerTaC = alphaLayerT.copy()
        alphaLayerTaC.paste(
            agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
            (width - overlap_size_x, 0),
        )

        alphaLayerLTaC = alphaLayerLTC.copy()
        alphaLayerLTaC.paste(
            agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
            (width - overlap_size_x, 0),
        )

        if embiggen_tiles:
            # Individual unconnected sides
            alphaLayerR = Image.new("L", (width, height), 255)
            alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
            alphaLayerB = Image.new("L", (width, height), 255)
            alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
            alphaLayerTB = Image.new("L", (width, height), 255)
            alphaLayerTB.paste(agradientT, (0, 0))
            alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
            alphaLayerLR = Image.new("L", (width, height), 255)
            alphaLayerLR.paste(agradientL, (0, 0))
            alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))

            # Sides and corner Layers
            alphaLayerRBC = Image.new("L", (width, height), 255)
            alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
            alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
            alphaLayerRBC.paste(
                agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
                (width - overlap_size_x, height - overlap_size_y),
            )
            alphaLayerLBC = Image.new("L", (width, height), 255)
            alphaLayerLBC.paste(agradientL, (0, 0))
            alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
            alphaLayerLBC.paste(
                agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
                (0, height - overlap_size_y),
            )
            alphaLayerRTC = Image.new("L", (width, height), 255)
            alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
            alphaLayerRTC.paste(agradientT, (0, 0))
            alphaLayerRTC.paste(
                agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
                (width - overlap_size_x, 0),
            )

            # All but X layers
            alphaLayerABT = Image.new("L", (width, height), 255)
            alphaLayerABT.paste(alphaLayerLBC, (0, 0))
            alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
            alphaLayerABT.paste(
                agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
                (width - overlap_size_x, height - overlap_size_y),
            )
            alphaLayerABL = Image.new("L", (width, height), 255)
            alphaLayerABL.paste(alphaLayerRTC, (0, 0))
            alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
            alphaLayerABL.paste(
                agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
                (width - overlap_size_x, height - overlap_size_y),
            )
            alphaLayerABR = Image.new("L", (width, height), 255)
            alphaLayerABR.paste(alphaLayerLBC, (0, 0))
            alphaLayerABR.paste(agradientT, (0, 0))
            alphaLayerABR.paste(
                agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
            )
            alphaLayerABB = Image.new("L", (width, height), 255)
            alphaLayerABB.paste(alphaLayerRTC, (0, 0))
            alphaLayerABB.paste(agradientL, (0, 0))
            alphaLayerABB.paste(
                agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
            )

            # All-around layer
            alphaLayerAA = Image.new("L", (width, height), 255)
            alphaLayerAA.paste(alphaLayerABT, (0, 0))
            alphaLayerAA.paste(agradientT, (0, 0))
            alphaLayerAA.paste(
                agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
            )
            alphaLayerAA.paste(
                agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
                (width - overlap_size_x, 0),
            )

        # Clean up temporary gradients
        del agradientL
@@ -287,17 +312,20 @@ class Embiggen(Generator):
        def make_image():
            # Make main tiles -------------------------------------------------
            if embiggen_tiles:
                print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
            else:
                print(
                    f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
                )

            emb_tile_store = []
            # Although we could use the same seed for every tile for determinism, at higher strengths this may
            # produce duplicated structures for each tile and make the tiling effect more obvious
            # instead track and iterate a local seed we pass to Img2Img
            seed = self.seed
            seedintlimit = (
                np.iinfo(np.uint32).max - 1
            )  # only retrieve this one from numpy

            for tile in range(emb_tiles_x * emb_tiles_y):
                # Don't iterate on first tile
@@ -334,37 +362,38 @@ class Embiggen(Generator):
                if embiggen_tiles:
                    print(
                        f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
                    )
                else:
                    print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")

                # create a torch tensor from an Image
                newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
                newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
                newinitimage = torch.from_numpy(newinitimage)
                newinitimage = 2.0 * newinitimage - 1.0
                newinitimage = newinitimage.to(self.model.device)
                clear_cuda_cache = (
                    kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
                )

                tile_results = gen_img2img.generate(
                    prompt,
                    iterations=1,
                    seed=seed,
                    sampler=sampler,
                    steps=steps,
                    cfg_scale=cfg_scale,
                    conditioning=conditioning,
                    ddim_eta=ddim_eta,
                    image_callback=None,  # called only after the final image is generated
                    step_callback=step_callback,  # called after each intermediate image is generated
                    width=width,
                    height=height,
                    init_image=newinitimage,  # notice that init_image is different from init_img
                    mask_image=None,
                    strength=strength,
                    clear_cuda_cache=clear_cuda_cache,
                )

                emb_tile_store.append(tile_results[0][0])
@@ -373,12 +402,14 @@ class Embiggen(Generator):
                del newinitimage

            # Sanity check we have them all
            if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
                embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
            ):
                outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))

                if embiggen_tiles:
                    outputsuperimage.alpha_composite(
                        initsuperimage.convert("RGBA"), (0, 0)
                    )

                for tile in range(emb_tiles_x * emb_tiles_y):
                    if embiggen_tiles:
                        if tile in embiggen_tiles:
@@ -387,7 +418,7 @@ class Embiggen(Generator):
                            continue
                    else:
                        intileimage = emb_tile_store[tile]
                    intileimage = intileimage.convert("RGBA")

                    # Get row and column entries
                    emb_row_i = tile // emb_tiles_x
                    emb_column_i = tile % emb_tiles_x
@@ -399,8 +430,7 @@ class Embiggen(Generator):
                    if emb_column_i + 1 == emb_tiles_x:
                        left = initsuperwidth - width
                    else:
                        left = round(emb_column_i * (width - overlap_size_x))
                    if emb_row_i + 1 == emb_tiles_y:
                        top = initsuperheight - height
                    else:
@@ -411,33 +441,43 @@ class Embiggen(Generator):
                        # top of image
                        if emb_row_i == 0:
                            if emb_column_i == 0:
                                if (tile + 1) in embiggen_tiles:  # Look-ahead right
                                    if (
                                        tile + emb_tiles_x
                                    ) not in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerB)
                                    # Otherwise do nothing on this tile
                                elif (
                                    tile + emb_tiles_x
                                ) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerR)
                                else:
                                    intileimage.putalpha(alphaLayerRBC)
                            elif emb_column_i == emb_tiles_x - 1:
                                if (
                                    tile + emb_tiles_x
                                ) in embiggen_tiles:  # Look-ahead down
                                    intileimage.putalpha(alphaLayerL)
                                else:
                                    intileimage.putalpha(alphaLayerLBC)
                            else:
                                if (tile + 1) in embiggen_tiles:  # Look-ahead right
                                    if (
                                        tile + emb_tiles_x
                                    ) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerL)
                                    else:
                                        intileimage.putalpha(alphaLayerLBC)
                                elif (
                                    tile + emb_tiles_x
                                ) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerLR)
                                else:
                                    intileimage.putalpha(alphaLayerABT)
                        # bottom of image
                        elif emb_row_i == emb_tiles_y - 1:
                            if emb_column_i == 0:
                                if (tile + 1) in embiggen_tiles:  # Look-ahead right
                                    intileimage.putalpha(alphaLayerTaC)
                                else:
                                    intileimage.putalpha(alphaLayerRTC)
@@ -445,34 +485,44 @@ class Embiggen(Generator):
                                # No tiles to look ahead to
                                intileimage.putalpha(alphaLayerLTC)
                            else:
                                if (tile + 1) in embiggen_tiles:  # Look-ahead right
                                    intileimage.putalpha(alphaLayerLTaC)
                                else:
                                    intileimage.putalpha(alphaLayerABB)
                        # vertical middle of image
                        else:
                            if emb_column_i == 0:
                                if (tile + 1) in embiggen_tiles:  # Look-ahead right
                                    if (
                                        tile + emb_tiles_x
                                    ) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerTaC)
                                    else:
                                        intileimage.putalpha(alphaLayerTB)
                                elif (
                                    tile + emb_tiles_x
                                ) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerRTC)
                                else:
                                    intileimage.putalpha(alphaLayerABL)
                            elif emb_column_i == emb_tiles_x - 1:
                                if (
                                    tile + emb_tiles_x
                                ) in embiggen_tiles:  # Look-ahead down
                                    intileimage.putalpha(alphaLayerLTC)
                                else:
                                    intileimage.putalpha(alphaLayerABR)
                            else:
                                if (tile + 1) in embiggen_tiles:  # Look-ahead right
                                    if (
                                        tile + emb_tiles_x
                                    ) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerLTaC)
                                    else:
                                        intileimage.putalpha(alphaLayerABR)
                                elif (
                                    tile + emb_tiles_x
                                ) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerABB)
                                else:
                                    intileimage.putalpha(alphaLayerAA)
@@ -481,21 +531,28 @@ class Embiggen(Generator):
                        if emb_row_i == 0 and emb_column_i >= 1:
                            intileimage.putalpha(alphaLayerL)
                        elif emb_row_i >= 1 and emb_column_i == 0:
                            if (
                                emb_column_i + 1 == emb_tiles_x
                            ):  # If we don't have anything that can be placed to the right
                                intileimage.putalpha(alphaLayerT)
                            else:
                                intileimage.putalpha(alphaLayerTaC)
                        else:
                            if (
                                emb_column_i + 1 == emb_tiles_x
                            ):  # If we don't have anything that can be placed to the right
                                intileimage.putalpha(alphaLayerLTC)
                            else:
                                intileimage.putalpha(alphaLayerLTaC)
                    # Layer tile onto final image
                    outputsuperimage.alpha_composite(intileimage, (left, top))
            else:
                print(
                    "Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
                )

            # after internal loops and patching up return Embiggen image
            return outputsuperimage

        # end of function declaration
        return make_image
@@ -0,0 +1,97 @@
"""
invokeai.backend.generator.img2img descends from .generator
"""

import torch
from diffusers import logging

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None  # by get_noise()

    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image,
        strength,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T):
            # FIXME: use x_T for initial seeded noise
            # We're not at the moment because the pipeline automatically resizes init_image if
            # necessary, which the x_T input might not match.
            logging.set_verbosity_error()  # quench safety check warnings
            pipeline_output = pipeline.img2img_from_embeddings(
                init_image,
                strength,
                steps,
                conditioning_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
            )
            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image

    def get_noise_like(self, like: torch.Tensor):
        device = like.device
        if device.type == "mps":
            x = torch.randn_like(like, device="cpu").to(device)
        else:
            x = torch.randn_like(like, device=device)
        if self.perlin > 0.0:
            shape = like.shape
            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
                shape[3], shape[2]
            )
        return x
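
Aside (a minimal sketch, not code from this commit): get_noise_like() above draws noise on the CPU when running on Apple's MPS backend and only then moves it to the device, presumably to keep seeded results consistent across backends; otherwise the noise is drawn directly on the target device. The same pattern in isolation:

    import torch

    def noise_like(reference: torch.Tensor) -> torch.Tensor:
        # Draw standard-normal noise shaped like `reference`, sampling on the CPU
        # first for MPS devices (as get_noise_like() does above).
        if reference.device.type == "mps":
            return torch.randn_like(reference, device="cpu").to(reference.device)
        return torch.randn_like(reference, device=reference.device)

    print(noise_like(torch.zeros(1, 4, 64, 64)).std())  # close to 1.0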
@@ -1,32 +1,35 @@
"""
invokeai.backend.generator.inpaint descends from .generator
"""
from __future__ import annotations

import math

import cv2
import numpy as np
import PIL
import torch
from PIL import Image, ImageChops, ImageFilter, ImageOps

from ..image_util import PatchMatch, debug_image
from ..stable_diffusion.diffusers_pipeline import (
    ConditioningData,
    StableDiffusionGeneratorPipeline,
    image_resized_to_grid_as_tensor,
)
from .img2img import Img2Img


def infill_methods() -> list[str]:
    methods = [
        "tile",
        "solid",
    ]
    if PatchMatch.patchmatch_available():
        methods.insert(0, "patchmatch")
    return methods


class Inpaint(Img2Img):
    def __init__(self, model, precision):
        self.inpaint_height = 0
@@ -53,11 +56,11 @@ class Inpaint(Img2Img):
            np.ravel(image),
            shape=(nrows, ncols, height, width, depth),
            strides=(height * _strides[0], width * _strides[1], *_strides),
            writeable=False,
        )

    def infill_patchmatch(self, im: Image.Image) -> Image:
        if im.mode != "RGBA":
            return im

        # Skip patchmatch if patchmatch isn't available
@@ -65,13 +68,17 @@ class Inpaint(Img2Img):
            return im

        # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
        im_patched_np = PatchMatch.inpaint(
            im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
        )
        im_patched = Image.fromarray(im_patched_np, mode="RGB")
        return im_patched

    def tile_fill_missing(
        self, im: Image.Image, tile_size: int = 16, seed: int = None
    ) -> Image:
        # Only fill if there's an alpha layer
        if im.mode != "RGBA":
            return im

        a = np.asarray(im, dtype=np.uint8)
@@ -79,21 +86,21 @@ class Inpaint(Img2Img):
        tile_size = (tile_size, tile_size)

        # Get the image as tiles of a specified size
        tiles = self.get_tile_images(a, *tile_size).copy()

        # Get the mask as tiles
        tiles_mask = tiles[:, :, :, :, 3]

        # Find any mask tiles with any fully transparent pixels (we will be replacing these later)
        tmask_shape = tiles_mask.shape
        tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
        n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
        tiles_mask = tiles_mask > 0
        tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)

        # Get RGB tiles in single array and filter by the mask
        tshape = tiles.shape
        tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
        filtered_tiles = tiles_all[tiles_mask]

        if len(filtered_tiles) == 0:
@@ -101,23 +108,32 @@ class Inpaint(Img2Img):
        # Find all invalid tiles and replace with a random valid tile
        replace_count = (tiles_mask == False).sum()
        rng = np.random.default_rng(seed=seed)
        tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
            rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
        ]

        # Convert back to an image
        tiles_all = tiles_all.reshape(tshape)
        tiles_all = tiles_all.swapaxes(1, 2)
        st = tiles_all.reshape(
            (
                math.prod(tiles_all.shape[0:2]),
                math.prod(tiles_all.shape[2:4]),
                tiles_all.shape[4],
            )
        )
        si = Image.fromarray(st, mode="RGBA")

        return si

    def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
        npimg = np.asarray(mask, dtype=np.uint8)

        # Detect any partially transparent regions
        npgradient = np.uint8(
            255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))
        )

        # Detect hard edges
        npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
@@ -126,7 +142,9 @@ class Inpaint(Img2Img):
        npmask = npgradient + npedge

        # Expand
        npmask = cv2.dilate(
            npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)
        )

        new_mask = Image.fromarray(npmask)
@@ -135,9 +153,22 @@ class Inpaint(Img2Img):
        return ImageOps.invert(new_mask)

    def seam_paint(
        self,
        im: Image.Image,
        seam_size: int,
        seam_blur: int,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        strength,
        noise,
        infill_method,
        step_callback,
    ) -> Image.Image:
        hard_mask = self.pil_image.split()[-1].copy()
        mask = self.mask_edge(hard_mask, seam_size, seam_blur)
@@ -148,15 +179,15 @@ class Inpaint(Img2Img):
            cfg_scale,
            ddim_eta,
            conditioning,
            init_image=im.copy().convert("RGBA"),
            mask_image=mask,
            strength=strength,
            mask_blur_radius=0,
            seam_size=0,
            step_callback=step_callback,
            inpaint_width=im.width,
            inpaint_height=im.height,
            infill_method=infill_method,
        )

        seam_noise = self.get_noise(im.width, im.height)
@@ -165,28 +196,35 @@ class Inpaint(Img2Img):
        return result

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image: PIL.Image.Image | torch.FloatTensor,
        mask_image: PIL.Image.Image | torch.FloatTensor,
        strength: float,
        mask_blur_radius: int = 8,
        # Seam settings - when 0, doesn't fill seam
        seam_size: int = 0,
        seam_blur: int = 0,
        seam_strength: float = 0.7,
        seam_steps: int = 10,
        tile_size: int = 32,
        step_callback=None,
        inpaint_replace=False,
        enable_image_debugging=False,
        infill_method=None,
        inpaint_width=None,
        inpaint_height=None,
        inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and
        the initial image + mask.  Return value depends on the seed at
@@ -204,33 +242,39 @@ class Inpaint(Img2Img):
            self.pil_image = init_image.copy()

            # Do infill
            if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
                init_filled = self.infill_patchmatch(self.pil_image.copy())
            elif infill_method == "tile":
                init_filled = self.tile_fill_missing(
                    self.pil_image.copy(), seed=self.seed, tile_size=tile_size
                )
            elif infill_method == "solid":
                solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
                init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
            else:
                raise ValueError(
                    f"Non-supported infill type {infill_method}", infill_method
                )
            init_filled.paste(init_image, (0, 0), init_image.split()[-1])

            # Resize if requested for inpainting
            if inpaint_width and inpaint_height:
                init_filled = init_filled.resize((inpaint_width, inpaint_height))

            debug_image(
                init_filled, "init_filled", debug_status=self.enable_image_debugging
            )

            # Create init tensor
            init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))

        if isinstance(mask_image, PIL.Image.Image):
            self.pil_mask = mask_image.copy()
            debug_image(
                mask_image,
                "mask_image BEFORE multiply with pil_image",
                debug_status=self.enable_image_debugging,
            )

            init_alpha = self.pil_image.getchannel("A")
            if mask_image.mode != "L":
@@ -243,8 +287,14 @@ class Inpaint(Img2Img):
            if inpaint_width and inpaint_height:
                mask_image = mask_image.resize((inpaint_width, inpaint_height))

            debug_image(
                mask_image,
                "mask_image AFTER multiply with pil_image",
                debug_status=self.enable_image_debugging,
            )
            mask: torch.FloatTensor = image_resized_to_grid_as_tensor(
                mask_image, normalize=False
            )
        else:
            mask: torch.FloatTensor = mask_image
@@ -256,9 +306,9 @@ class Inpaint(Img2Img):
        # todo: support cross-attention control
        uc, c, _ = conditioning
        conditioning_data = ConditioningData(
            uc, c, cfg_scale
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T):
            pipeline_output = pipeline.inpaint_from_embeddings(
@@ -271,43 +321,71 @@ class Inpaint(Img2Img):
                callback=step_callback,
            )

            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)

            result = self.postprocess_size_and_mask(
                pipeline.numpy_to_pil(pipeline_output.images)[0]
            )

            # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
            if seam_size > 0:
                old_image = self.pil_image or init_image
                old_mask = self.pil_mask or mask_image

                result = self.seam_paint(
                    result,
                    seam_size,
                    seam_blur,
                    prompt,
                    sampler,
                    seam_steps,
                    cfg_scale,
                    ddim_eta,
                    conditioning,
                    seam_strength,
                    x_T,
                    infill_method,
                    step_callback,
                )

                # Restore original settings
                self.get_make_image(
                    prompt,
                    sampler,
                    steps,
                    cfg_scale,
                    ddim_eta,
                    conditioning,
                    old_image,
                    old_mask,
                    strength,
                    mask_blur_radius,
                    seam_size,
                    seam_blur,
                    seam_strength,
                    seam_steps,
                    tile_size,
                    step_callback,
                    inpaint_replace,
                    enable_image_debugging,
                    inpaint_width=inpaint_width,
                    inpaint_height=inpaint_height,
                    infill_method=infill_method,
                    **kwargs,
                )

            return result

        return make_image

    def sample_to_image(self, samples) -> Image.Image:
        gen_result = super().sample_to_image(samples).convert("RGB")
        return self.postprocess_size_and_mask(gen_result)

    def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image:
        debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
@@ -318,7 +396,13 @@ class Inpaint(Img2Img):
        if self.pil_image is None or self.pil_mask is None:
            return gen_result

        corrected_result = self.repaste_and_color_correct(
            gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius
        )
        debug_image(
            corrected_result,
            "corrected_result",
            debug_status=self.enable_image_debugging,
        )

        return corrected_result
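
Aside (a minimal sketch, not code from this commit): the mask_edge() expression above maps fully opaque or fully transparent alpha values to 0 and mid-range ("soft") alpha values to 255, which is what isolates the partially transparent seam regions before edge detection. The same formula on a toy row of alpha values:

    import numpy as np

    def partial_transparency_map(alpha: np.ndarray) -> np.ndarray:
        # Same expression as mask_edge(): extremes of the alpha range -> 0, mid-range -> 255.
        return np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(alpha) / 255.0) * 2.0)))

    alpha = np.array([[0, 64, 128, 192, 255]], dtype=np.uint8)
    print(partial_transparency_map(alpha))  # [[  0 255 255 255   0]]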
@@ -0,0 +1,81 @@
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""

import PIL.Image
import torch

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Txt2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        width,
        height,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
        kwargs are 'width' and 'height'
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T) -> PIL.Image.Image:
            pipeline_output = pipeline.image_from_embeddings(
                latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
                noise=x_T,
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                callback=step_callback,
            )
            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image
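
Aside (a hedged sketch, not code from this commit): the txt2img path starts from zeroed latents and feeds the randomness in separately as x_T, so the latent grid is simply the pixel size divided by the VAE downsampling factor. Assuming the usual Stable Diffusion values of 4 latent channels and a factor of 8:

    import torch

    def initial_latents(width: int, height: int, channels: int = 4, factor: int = 8):
        # Zeroed latents plus a separate noise tensor, mirroring the
        # image_from_embeddings() call above.
        shape = (1, channels, height // factor, width // factor)
        return torch.zeros(shape), torch.randn(shape)

    latents, noise = initial_latents(512, 512)
    print(latents.shape, noise.shape)  # torch.Size([1, 4, 64, 64]) twice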
@@ -1,6 +1,6 @@
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""

import math
from typing import Callable, Optional
@@ -8,21 +8,40 @@ from typing import Callable, Optional
import torch
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error

from ..models import PostprocessingSettings
from .base import Generator
from .diffusers_pipeline import (
    ConditioningData,
    StableDiffusionGeneratorPipeline,
    trim_to_multiple_of,
)


class Txt2Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None  # for get_noise()

    def get_make_image(
        self,
        prompt: str,
        sampler,
        steps: int,
        cfg_scale: float,
        ddim_eta,
        conditioning,
        width: int,
        height: int,
        strength: float,
        step_callback: Optional[Callable] = None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
@@ -35,19 +54,20 @@ class Txt2Img2Img(Generator):
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=0.2,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T):
            first_pass_latent_output, _ = pipeline.latents_from_embeddings(
                latents=torch.zeros_like(x_T),
                num_inference_steps=steps,
@@ -61,28 +81,40 @@ class Txt2Img2Img(Generator):
            init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
            init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
            print(
                f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
            )

            # resizing
            resized_latents = torch.nn.functional.interpolate(
                first_pass_latent_output,
                size=(
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ),
                mode="bilinear",
            )

            # Free up memory from the last generation.
            clear_cuda_cache = kwargs["clear_cuda_cache"] or None
            if clear_cuda_cache is not None:
                clear_cuda_cache()

            second_pass_noise = self.get_noise_like(
                resized_latents, override_perlin=True
            )

            # Clear symmetry for the second pass
            from dataclasses import replace

            new_postprocessing_settings = replace(
                conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
            )
            new_postprocessing_settings = replace(
                new_postprocessing_settings, v_symmetry_time_pct=None
            )
            new_conditioning_data = replace(
                conditioning_data, postprocessing_settings=new_postprocessing_settings
            )

            verbosity = get_verbosity()
            set_verbosity_error()
@@ -92,15 +124,18 @@ class Txt2Img2Img(Generator):
                conditioning_data=new_conditioning_data,
                strength=strength,
                noise=second_pass_noise,
                callback=step_callback,
            )
            set_verbosity(verbosity)

            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)

            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        # FIXME: do we really need something entirely different for the inpainting model?

        # in the case of the inpainting model being loaded, the trick of
@@ -111,19 +146,23 @@ class Txt2Img2Img(Generator):
        return make_image

    def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
        device = like.device
        if device.type == "mps":
            x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
                device
            )
        else:
            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
        if self.perlin > 0.0 and override_perlin == False:
            shape = like.shape
            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
                shape[3], shape[2]
            )
        return x

    # returns a tensor filled with random numbers from a normal distribution
    def get_noise(self, width, height, scale=True):
        # print(f"Get noise: {width}x{height}")
        if scale:
            # Scale the input width and height for the initial generation
@@ -133,7 +172,9 @@ class Txt2Img2Img(Generator):
            aspect = width / height
            dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
            min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images model_area = (
dimension * dimension
) # hardcoded for now since all models are trained on square images
if aspect > 1.0: if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect)) init_height = max(min_dimension, math.sqrt(model_area / aspect))
@ -142,7 +183,9 @@ class Txt2Img2Img(Generator):
init_width = max(min_dimension, math.sqrt(model_area * aspect)) init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect init_height = init_width / aspect
scaled_width, scaled_height = trim_to_multiple_of(math.floor(init_width), math.floor(init_height)) scaled_width, scaled_height = trim_to_multiple_of(
math.floor(init_width), math.floor(init_height)
)
else: else:
scaled_width = width scaled_width = width
@ -152,10 +195,14 @@ class Txt2Img2Img(Generator):
channels = self.latent_channels channels = self.latent_channels
if channels == 9: if channels == 9:
channels = 4 # we don't really want noise for all the mask channels channels = 4 # we don't really want noise for all the mask channels
shape = (1, channels, shape = (
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor) 1,
if self.use_mps_noise or device.type == 'mps': channels,
tensor = torch.empty(size=shape, device='cpu') scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor,
)
if self.use_mps_noise or device.type == "mps":
tensor = torch.empty(size=shape, device="cpu")
tensor = self.get_noise_like(like=tensor).to(device) tensor = self.get_noise_like(like=tensor).to(device)
else: else:
tensor = torch.empty(size=shape, device=device) tensor = torch.empty(size=shape, device=device)
View File
@ -1,5 +1,5 @@
''' """
ldm.invoke.globals defines a small number of global variables that would invokeai.backend.globals defines a small number of global variables that would
otherwise have to be passed through long and complex call chains. otherwise have to be passed through long and complex call chains.
It defines a Namespace object named "Globals" that contains It defines a Namespace object named "Globals" that contains
@ -9,7 +9,7 @@ the attributes:
- initfile - path to the initialization file - initfile - path to the initialization file
- try_patchmatch - option to globally disable loading of 'patchmatch' module - try_patchmatch - option to globally disable loading of 'patchmatch' module
- always_use_cpu - force use of CPU even if GPU is available - always_use_cpu - force use of CPU even if GPU is available
''' """
import os import os
import os.path as osp import os.path as osp
@ -20,12 +20,12 @@ from typing import Union
Globals = Namespace() Globals = Namespace()
# Where to look for the initialization file and other key components # Where to look for the initialization file and other key components
Globals.initfile = 'invokeai.init' Globals.initfile = "invokeai.init"
Globals.models_file = 'models.yaml' Globals.models_file = "models.yaml"
Globals.models_dir = 'models' Globals.models_dir = "models"
Globals.config_dir = 'configs' Globals.config_dir = "configs"
Globals.autoscan_dir = 'weights' Globals.autoscan_dir = "weights"
Globals.converted_ckpts_dir = 'converted_ckpts' Globals.converted_ckpts_dir = "converted_ckpts"
# Set the default root directory. This can be overwritten by explicitly # Set the default root directory. This can be overwritten by explicitly
# passing the `--root <directory>` argument on the command line. # passing the `--root <directory>` argument on the command line.
@ -34,12 +34,15 @@ Globals.converted_ckpts_dir = 'converted_ckpts'
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there # 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
# 3) use ~/invokeai # 3) use ~/invokeai
if os.environ.get('INVOKEAI_ROOT'): if os.environ.get("INVOKEAI_ROOT"):
Globals.root = osp.abspath(os.environ.get('INVOKEAI_ROOT')) Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
elif os.environ.get('VIRTUAL_ENV') and Path(os.environ.get('VIRTUAL_ENV'),'..',Globals.initfile).exists(): elif (
Globals.root = osp.abspath(osp.join(os.environ.get('VIRTUAL_ENV'), '..')) os.environ.get("VIRTUAL_ENV")
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
):
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
else: else:
Globals.root = osp.abspath(osp.expanduser('~/invokeai')) Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
# Try loading patchmatch # Try loading patchmatch
Globals.try_patchmatch = True Globals.try_patchmatch = True
@ -61,31 +64,38 @@ Globals.sequential_guidance = False
Globals.full_precision = False Globals.full_precision = False
# whether we should convert ckpt files into diffusers models on the fly # whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = False Globals.ckpt_convert = True
# logging tokenization everywhere # logging tokenization everywhere
Globals.log_tokenization = False Globals.log_tokenization = False
def global_config_file()->Path:
def global_config_file() -> Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file) return Path(Globals.root, Globals.config_dir, Globals.models_file)
def global_config_dir()->Path:
def global_config_dir() -> Path:
return Path(Globals.root, Globals.config_dir) return Path(Globals.root, Globals.config_dir)
def global_models_dir()->Path:
def global_models_dir() -> Path:
return Path(Globals.root, Globals.models_dir) return Path(Globals.root, Globals.models_dir)
def global_autoscan_dir()->Path:
def global_autoscan_dir() -> Path:
return Path(Globals.root, Globals.autoscan_dir) return Path(Globals.root, Globals.autoscan_dir)
def global_converted_ckpts_dir()->Path:
def global_converted_ckpts_dir() -> Path:
return Path(global_models_dir(), Globals.converted_ckpts_dir) return Path(global_models_dir(), Globals.converted_ckpts_dir)
def global_set_root(root_dir:Union[str,Path]):
def global_set_root(root_dir: Union[str, Path]):
Globals.root = root_dir Globals.root = root_dir
def global_cache_dir(subdir:Union[str,Path]='')->Path:
''' def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
"""
Returns Path to the model cache directory. If a subdirectory Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing is provided, it will be appended to the end of the path, allowing
for huggingface-style conventions: for huggingface-style conventions:
@ -98,18 +108,18 @@ def global_cache_dir(subdir:Union[str,Path]='')->Path:
One other caveat is that HuggingFace is moving some diffusers models One other caveat is that HuggingFace is moving some diffusers models
into the "hub" subdirectory as well, so this will need to be revisited into the "hub" subdirectory as well, so this will need to be revisited
from time to time. from time to time.
''' """
home: str = os.getenv('HF_HOME') home: str = os.getenv("HF_HOME")
if home is None: if home is None:
home = os.getenv('XDG_CACHE_HOME') home = os.getenv("XDG_CACHE_HOME")
if home is not None: if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library. # Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome # See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + 'huggingface' home += os.sep + "huggingface"
if home is not None: if home is not None:
return Path(home,subdir) return Path(home, subdir)
else: else:
return Path(Globals.root,'models',subdir) return Path(Globals.root, "models", subdir)
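A quick sketch of how the relocated invokeai.backend.globals resolves the model cache directory. The paths below are hypothetical, and the environment is assumed to start with HF_HOME and XDG_CACHE_HOME unset:

import os
from invokeai.backend.globals import global_cache_dir, global_set_root

global_set_root("/data/invokeai")      # equivalent to passing --root /data/invokeai
print(global_cache_dir("diffusers"))   # /data/invokeai/models/diffusers (no HF_HOME / XDG_CACHE_HOME)
os.environ["HF_HOME"] = "/data/hf"     # HF_HOME is consulted on every call
print(global_cache_dir("hub"))         # /data/hf/hub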
View File
@ -0,0 +1,24 @@
"""
Initialization file for invokeai.backend.image_util methods.
"""
from .patchmatch import PatchMatch
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata
from .seamless import configure_model_padding
from .txt2mask import Txt2Mask
from .util import InitImageResizer, make_grid
def debug_image(
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
if not debug_status:
return
image_copy = debug_image.copy().convert("RGBA")
ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))
if debug_show:
image_copy.show()
if debug_result:
return image_copy
View File
@ -1,20 +1,22 @@
''' """
This module defines a singleton object, "patchmatch" that This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can "try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred be suppressed or deferred
''' """
from ldm.invoke.globals import Globals import numpy as np
import numpy as np
from invokeai.backend.globals import Globals
class PatchMatch: class PatchMatch:
''' """
Thin class wrapper around the patchmatch function. Thin class wrapper around the patchmatch function.
''' """
patch_match = None patch_match = None
tried_load:bool = False tried_load: bool = False
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -24,21 +26,22 @@ class PatchMatch:
return return
if Globals.try_patchmatch: if Globals.try_patchmatch:
from patchmatch import patch_match as pm from patchmatch import patch_match as pm
if pm.patchmatch_available: if pm.patchmatch_available:
print('>> Patchmatch initialized') print(">> Patchmatch initialized")
else: else:
print('>> Patchmatch not loaded (nonfatal)') print(">> Patchmatch not loaded (nonfatal)")
self.patch_match = pm self.patch_match = pm
else: else:
print('>> Patchmatch loading disabled') print(">> Patchmatch loading disabled")
self.tried_load = True self.tried_load = True
@classmethod @classmethod
def patchmatch_available(self)->bool: def patchmatch_available(self) -> bool:
self._load_patch_match() self._load_patch_match()
return self.patch_match and self.patch_match.patchmatch_available return self.patch_match and self.patch_match.patchmatch_available
@classmethod @classmethod
def inpaint(self,*args,**kwargs)->np.ndarray: def inpaint(self, *args, **kwargs) -> np.ndarray:
if self.patchmatch_available(): if self.patchmatch_available():
return self.patch_match.inpaint(*args,**kwargs) return self.patch_match.inpaint(*args, **kwargs)
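Usage sketch for the wrapper above. It assumes the optional pypatchmatch package is installed and Globals.try_patchmatch is left at its default of True; inpaint() forwards its arguments verbatim to pypatchmatch, so the mask shape and convention shown here are assumptions.

import numpy as np
from invokeai.backend.image_util import PatchMatch

if PatchMatch.patchmatch_available():        # first call triggers the lazy import
    image = np.zeros((64, 64, 3), dtype=np.uint8)
    mask = np.zeros((64, 64, 1), dtype=np.uint8)
    mask[16:32, 16:32] = 255                 # region to be filled (assumed convention)
    filled = PatchMatch.inpaint(image, mask)
    print(filled.shape)                      # (64, 64, 3)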
View File
@ -6,10 +6,11 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds
Exports function retrieve_metadata(path) Exports function retrieve_metadata(path)
""" """
import json
import os import os
import re import re
import json
from PIL import PngImagePlugin, Image from PIL import Image, PngImagePlugin
# -------------------image generation utils----- # -------------------image generation utils-----
@ -25,52 +26,57 @@ class PngWriter:
dirlist = sorted(os.listdir(self.outdir), reverse=True) dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png # find the first filename that matches our pattern or return 000000.0.png
existing_name = next( existing_name = next(
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)), (f for f in dirlist if re.match("^(\d+)\..*\.png", f)),
'0000000.0.png', "0000000.0.png",
) )
basecount = int(existing_name.split('.', 1)[0]) + 1 basecount = int(existing_name.split(".", 1)[0]) + 1
return f'{basecount:06}' return f"{basecount:06}"
# saves image named _image_ to outdir/name, writing metadata from prompt # saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output # returns full path of output
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6): def save_image_and_prompt_to_png(
self, image, dream_prompt, name, metadata=None, compress_level=6
):
path = os.path.join(self.outdir, name) path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo() info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt) info.add_text("Dream", dream_prompt)
if metadata: if metadata:
info.add_text('sd-metadata', json.dumps(metadata)) info.add_text("sd-metadata", json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level) image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
return path return path
def retrieve_metadata(self,img_basename): def retrieve_metadata(self, img_basename):
''' """
Given a PNG filename stored in outdir, returns the "sd-metadata" Given a PNG filename stored in outdir, returns the "sd-metadata"
metadata stored there, as a dict metadata stored there, as a dict
''' """
path = os.path.join(self.outdir,img_basename) path = os.path.join(self.outdir, img_basename)
all_metadata = retrieve_metadata(path) all_metadata = retrieve_metadata(path)
return all_metadata['sd-metadata'] return all_metadata["sd-metadata"]
def retrieve_metadata(img_path): def retrieve_metadata(img_path):
''' """
Given a path to a PNG image, returns the "sd-metadata" Given a path to a PNG image, returns the "sd-metadata"
metadata stored there, as a dict metadata stored there, as a dict
''' """
im = Image.open(img_path) im = Image.open(img_path)
if hasattr(im, 'text'): if hasattr(im, "text"):
md = im.text.get('sd-metadata', '{}') md = im.text.get("sd-metadata", "{}")
dream_prompt = im.text.get('Dream', '') dream_prompt = im.text.get("Dream", "")
else: else:
# When trying to retrieve metadata from images without a 'text' payload, such as JPG images. # When trying to retrieve metadata from images without a 'text' payload, such as JPG images.
md = '{}' md = "{}"
dream_prompt = '' dream_prompt = ""
return {'sd-metadata': json.loads(md), 'Dream': dream_prompt} return {"sd-metadata": json.loads(md), "Dream": dream_prompt}
def write_metadata(img_path:str, meta:dict):
def write_metadata(img_path: str, meta: dict):
im = Image.open(img_path) im = Image.open(img_path)
info = PngImagePlugin.PngInfo() info = PngImagePlugin.PngInfo()
info.add_text('sd-metadata', json.dumps(meta)) info.add_text("sd-metadata", json.dumps(meta))
im.save(img_path,'PNG',pnginfo=info) im.save(img_path, "PNG", pnginfo=info)
class PromptFormatter: class PromptFormatter:
def __init__(self, t2i, opt): def __init__(self, t2i, opt):
@ -86,28 +92,30 @@ class PromptFormatter:
switches = list() switches = list()
switches.append(f'"{opt.prompt}"') switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}') switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f'-W{opt.width or t2i.width}') switches.append(f"-W{opt.width or t2i.width}")
switches.append(f'-H{opt.height or t2i.height}') switches.append(f"-H{opt.height or t2i.height}")
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') switches.append(f"-C{opt.cfg_scale or t2i.cfg_scale}")
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}') switches.append(f"-A{opt.sampler_name or t2i.sampler_name}")
# to do: put model name into the t2i object # to do: put model name into the t2i object
# switches.append(f'--model{t2i.model_name}') # switches.append(f'--model{t2i.model_name}')
if opt.seamless or t2i.seamless: if opt.seamless or t2i.seamless:
switches.append(f'--seamless') switches.append(f"--seamless")
if opt.init_img: if opt.init_img:
switches.append(f'-I{opt.init_img}') switches.append(f"-I{opt.init_img}")
if opt.fit: if opt.fit:
switches.append(f'--fit') switches.append(f"--fit")
if opt.strength and opt.init_img is not None: if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}') switches.append(f"-f{opt.strength or t2i.strength}")
if opt.gfpgan_strength: if opt.gfpgan_strength:
switches.append(f'-G{opt.gfpgan_strength}') switches.append(f"-G{opt.gfpgan_strength}")
if opt.upscale: if opt.upscale:
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}') switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
if opt.variation_amount > 0: if opt.variation_amount > 0:
switches.append(f'-v{opt.variation_amount}') switches.append(f"-v{opt.variation_amount}")
if opt.with_variations: if opt.with_variations:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations) formatted_variations = ",".join(
switches.append(f'-V{formatted_variations}') f"{seed}:{weight}" for seed, weight in opt.with_variations
return ' '.join(switches) )
switches.append(f"-V{formatted_variations}")
return " ".join(switches)
View File
@ -0,0 +1,59 @@
import torch.nn as nn
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(
input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]
)
working = nn.functional.pad(
working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
)
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
def configure_model_padding(model, seamless, seamless_axes):
"""
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
"""
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
if seamless:
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = (
"circular" if ("x" in seamless_axes) else "constant"
)
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = (
"circular" if ("y" in seamless_axes) else "constant"
)
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
else:
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
if hasattr(m, "asymmetric_padding_mode"):
del m.asymmetric_padding_mode
if hasattr(m, "asymmetric_padding"):
del m.asymmetric_padding
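Sketch of toggling seamless tiling on a diffusers pipeline with the helper above. The pipeline and its unet/vae attributes are standard diffusers objects rather than anything specific to this commit, so treat the exact model names as assumptions.

from diffusers import StableDiffusionPipeline
from invokeai.backend.image_util import configure_model_padding

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
configure_model_padding(pipe.unet, seamless=True, seamless_axes=["x", "y"])  # tile in both directions
configure_model_padding(pipe.vae, seamless=True, seamless_axes=["x"])        # horizontal tiling only
# ...generate...
configure_model_padding(pipe.unet, seamless=False, seamless_axes=[])         # restore normal padding
configure_model_padding(pipe.vae, seamless=False, seamless_axes=[])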
View File
@ -1,9 +1,9 @@
'''Makes available the Txt2Mask class, which assists in the automatic """Makes available the Txt2Mask class, which assists in the automatic
assignment of masks via text prompt using clipseg. assignment of masks via text prompt using clipseg.
Here is typical usage: Here is typical usage:
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale from invokeai.backend.image_util.txt2mask import Txt2Mask, SegmentedGrayscale
from PIL import Image from PIL import Image
txt2mask = Txt2Mask(self.device) txt2mask = Txt2Mask(self.device)
@ -25,31 +25,39 @@ the mask that exceed the indicated confidence threshold. Values range
from 0.0 to 1.0. The higher the threshold, the more confident the from 0.0 to 1.0. The higher the threshold, the more confident the
algorithm is. In limited testing, I have found that values around 0.5 algorithm is. In limited testing, I have found that values around 0.5
work fine. work fine.
''' """
import numpy as np
import torch import torch
import numpy as np
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from PIL import Image, ImageOps from PIL import Image, ImageOps
from torchvision import transforms from torchvision import transforms
from ldm.invoke.globals import global_cache_dir from transformers import AutoProcessor, CLIPSegForImageSegmentation
CLIPSEG_MODEL = 'CIDAS/clipseg-rd64-refined' from invokeai.backend.globals import global_cache_dir
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352 CLIPSEG_SIZE = 352
class SegmentedGrayscale(object): class SegmentedGrayscale(object):
def __init__(self, image:Image, heatmap:torch.Tensor): def __init__(self, image: Image, heatmap: torch.Tensor):
self.heatmap = heatmap self.heatmap = heatmap
self.image = image self.image = image
def to_grayscale(self,invert:bool=False)->Image: def to_grayscale(self, invert: bool = False) -> Image:
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255))) return self._rescale(
Image.fromarray(
np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)
)
)
def to_mask(self,threshold:float=0.5)->Image: def to_mask(self, threshold: float = 0.5) -> Image:
discrete_heatmap = self.heatmap.lt(threshold).int() discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')) return self._rescale(
Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")
)
def to_transparent(self,invert:bool=False)->Image: def to_transparent(self, invert: bool = False) -> Image:
transparent_image = self.image.copy() transparent_image = self.image.copy()
# For img2img, we want the selected regions to be transparent, # For img2img, we want the selected regions to be transparent,
# but to_grayscale() returns the opposite. Thus invert. # but to_grayscale() returns the opposite. Thus invert.
@ -58,70 +66,77 @@ class SegmentedGrayscale(object):
return transparent_image return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again # unscales and uncrops the 352x352 heatmap so that it matches the image again
def _rescale(self, heatmap:Image)->Image: def _rescale(self, heatmap: Image) -> Image:
size = self.image.width if (self.image.width > self.image.height) else self.image.height size = (
resized_image = heatmap.resize( self.image.width
(size,size), if (self.image.width > self.image.height)
resample=Image.Resampling.LANCZOS else self.image.height
) )
return resized_image.crop((0,0,self.image.width,self.image.height)) resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
return resized_image.crop((0, 0, self.image.width, self.image.height))
class Txt2Mask(object): class Txt2Mask(object):
''' """
Create new Txt2Mask object. The optional device argument can be one of Create new Txt2Mask object. The optional device argument can be one of
'cuda', 'mps' or 'cpu'. 'cuda', 'mps' or 'cpu'.
''' """
def __init__(self,device='cpu',refined=False):
print('>> Initializing clipseg model for text to mask inference') def __init__(self, device="cpu", refined=False):
print(">> Initializing clipseg model for text to mask inference")
# BUG: we are not doing anything with the device option at this time # BUG: we are not doing anything with the device option at this time
self.device = device self.device = device
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, self.processor = AutoProcessor.from_pretrained(
cache_dir=global_cache_dir('hub') CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
) )
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, self.model = CLIPSegForImageSegmentation.from_pretrained(
cache_dir=global_cache_dir('hub') CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
) )
@torch.no_grad() @torch.no_grad()
def segment(self, image, prompt:str) -> SegmentedGrayscale: def segment(self, image, prompt: str) -> SegmentedGrayscale:
''' """
Given a prompt string such as "a bagel", tries to identify the object in the Given a prompt string such as "a bagel", tries to identify the object in the
provided image and returns a SegmentedGrayscale object in which the brighter provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be. pixels indicate where the object is inferred to be.
''' """
transform = transforms.Compose([ transform = transforms.Compose(
transforms.ToTensor(), [
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.ToTensor(),
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64... transforms.Normalize(
]) mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
transforms.Resize(
(CLIPSEG_SIZE, CLIPSEG_SIZE)
), # must be multiple of 64...
]
)
if type(image) is str: if type(image) is str:
image = Image.open(image).convert('RGB') image = Image.open(image).convert("RGB")
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image) img = self._scale_and_crop(image)
inputs = self.processor(text=[prompt], inputs = self.processor(
images=[img], text=[prompt], images=[img], padding=True, return_tensors="pt"
padding=True, )
return_tensors='pt')
outputs = self.model(**inputs) outputs = self.model(**inputs)
heatmap = torch.sigmoid(outputs.logits) heatmap = torch.sigmoid(outputs.logits)
return SegmentedGrayscale(image, heatmap) return SegmentedGrayscale(image, heatmap)
def _scale_and_crop(self, image:Image)->Image: def _scale_and_crop(self, image: Image) -> Image:
scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE)) scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
if image.width > image.height: # width is constraint if image.width > image.height: # width is constraint
scale = CLIPSEG_SIZE / image.width scale = CLIPSEG_SIZE / image.width
else: else:
scale = CLIPSEG_SIZE / image.height scale = CLIPSEG_SIZE / image.height
scaled_image.paste( scaled_image.paste(
image.resize( image.resize(
(int(scale * image.width), (int(scale * image.width), int(scale * image.height)),
int(scale * image.height) resample=Image.Resampling.LANCZOS,
), ),
resample=Image.Resampling.LANCZOS box=(0, 0),
),box=(0,0)
) )
return scaled_image return scaled_image
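Usage sketch following the docstring at the top of this file; the input image path is hypothetical. to_mask() thresholds the raw clipseg heatmap, and to_transparent() makes the matched region transparent in a copy of the original image.

from invokeai.backend.image_util import Txt2Mask

txt2mask = Txt2Mask(device="cpu")                    # clipseg weights come from global_cache_dir("hub")
segmented = txt2mask.segment("kitchen_table.png", "a bagel")   # accepts a path or a PIL image
segmented.to_mask(threshold=0.5).save("bagel_mask.png")        # L-mode mask at the original image size
segmented.to_transparent().save("bagel_cutout.png")            # matched region made transparent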
View File
@ -1,12 +1,15 @@
from math import sqrt, floor, ceil from math import ceil, floor, sqrt
from PIL import Image from PIL import Image
class InitImageResizer():
class InitImageResizer:
"""Simple class to create resized copies of an Image while preserving the aspect ratio.""" """Simple class to create resized copies of an Image while preserving the aspect ratio."""
def __init__(self,Image):
def __init__(self, Image):
self.image = Image self.image = Image
def resize(self,width=None,height=None) -> Image: def resize(self, width=None, height=None) -> Image:
""" """
Return a copy of the image resized to fit within Return a copy of the image resized to fit within
a box width x height. The aspect ratio is a box width x height. The aspect ratio is
@ -18,37 +21,36 @@ class InitImageResizer():
Everything is floored to the nearest multiple of 64 so Everything is floored to the nearest multiple of 64 so
that it can be passed to img2img() that it can be passed to img2img()
""" """
im = self.image im = self.image
ar = im.width/float(im.height) ar = im.width / float(im.height)
# Infer missing values from aspect ratio # Infer missing values from aspect ratio
if not(width or height): # both missing if not (width or height): # both missing
width = im.width width = im.width
height = im.height height = im.height
elif not height: # height missing elif not height: # height missing
height = int(width/ar) height = int(width / ar)
elif not width: # width missing elif not width: # width missing
width = int(height*ar) width = int(height * ar)
w_scale = width/im.width w_scale = width / im.width
h_scale = height/im.height h_scale = height / im.height
scale = min(w_scale,h_scale) scale = min(w_scale, h_scale)
(rw,rh) = (int(scale*im.width),int(scale*im.height)) (rw, rh) = (int(scale * im.width), int(scale * im.height))
#round everything to multiples of 64 # round everything to multiples of 64
width,height,rw,rh = map( width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
lambda x: x-x%64, (width,height,rw,rh)
)
# no resize necessary, but return a copy # no resize necessary, but return a copy
if im.width == width and im.height == height: if im.width == width and im.height == height:
return im.copy() return im.copy()
# otherwise resize the original image so that it fits inside the bounding box # otherwise resize the original image so that it fits inside the bounding box
resized_image = self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS) resized_image = self.image.resize((rw, rh), resample=Image.Resampling.LANCZOS)
return resized_image return resized_image
def make_grid(image_list, rows=None, cols=None): def make_grid(image_list, rows=None, cols=None):
image_cnt = len(image_list) image_cnt = len(image_list)
if None in (rows, cols): if None in (rows, cols):
@ -57,7 +59,7 @@ def make_grid(image_list, rows=None, cols=None):
width = image_list[0].width width = image_list[0].width
height = image_list[0].height height = image_list[0].height
grid_img = Image.new('RGB', (width * cols, height * rows)) grid_img = Image.new("RGB", (width * cols, height * rows))
i = 0 i = 0
for r in range(0, rows): for r in range(0, rows):
for c in range(0, cols): for c in range(0, cols):
@ -67,4 +69,3 @@ def make_grid(image_list, rows=None, cols=None):
i = i + 1 i = i + 1
return grid_img return grid_img
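A short sketch of the two helpers above; the input filename is hypothetical. resize() floors every dimension to a multiple of 64, so the result fits inside 768x512 while preserving the source aspect ratio.

from PIL import Image
from invokeai.backend.image_util import InitImageResizer, make_grid

init = Image.open("photo.jpg")
resized = InitImageResizer(init).resize(width=768, height=512)       # multiples of 64, aspect preserved
print(resized.size)
make_grid([resized] * 4, rows=2, cols=2).save("contact_sheet.png")   # simple 2x2 grid of copies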
View File
@ -0,0 +1,10 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .convert_ckpt_to_diffusers import (
convert_ckpt_to_diffusers,
load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager
from invokeai.frontend.merge import merge_diffusion_models
View File
@ -9,7 +9,6 @@ from __future__ import annotations
import contextlib import contextlib
import gc import gc
import hashlib import hashlib
import io
import os import os
import re import re
import sys import sys
@ -32,15 +31,10 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from ldm.invoke.devices import CPU_DEVICE from invokeai.backend.globals import Globals, global_cache_dir
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.globals import Globals, global_cache_dir from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ldm.util import ( from ..util import CPU_DEVICE, ask_user, download_with_resume
ask_user,
download_with_resume,
instantiate_from_config,
url_attachment_name,
)
class SDLegacyType(Enum): class SDLegacyType(Enum):
@ -341,22 +335,9 @@ class ModelManager(object):
tic = time.time() tic = time.time()
# this does the work with warnings.catch_warnings():
model_format = mconfig.get("format", "ckpt") warnings.simplefilter("ignore")
if model_format == "ckpt": model, width, height, model_hash = self._load_diffusers_model(mconfig)
weights = mconfig.weights
print(f">> Loading {model_name} from {weights}")
model, width, height, model_hash = self._load_ckpt_model(
model_name, mconfig
)
elif model_format == "diffusers":
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, width, height, model_hash = self._load_diffusers_model(mconfig)
else:
raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}"
)
# usage statistics # usage statistics
toc = time.time() toc = time.time()
@ -370,125 +351,6 @@ class ModelManager(object):
) )
return model, width, height, model_hash return model, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig):
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get("vae")
width = mconfig.width
height = mconfig.height
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root, weights))
# if converting automatically to diffusers, then we do the conversion and return
# a diffusers pipeline
if Globals.ckpt_convert:
print(
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
)
from ldm.invoke.ckpt_to_diffuser import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
self.offload_model(self.current_model)
if vae_config := self._choose_diffusers_vae(model_name):
vae = self._load_vae(vae_config)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=weights,
original_config_file=config,
vae=vae,
return_generator_pipeline=True,
precision=torch.float16
if self.precision == "float16"
else torch.float32,
)
if self.sequential_offload:
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)
return (
pipeline,
width,
height,
"NOHASH",
)
# scan model
self.scan_model(model_name, weights)
print(f">> Loading {model_name} from {weights}")
# for usage statistics
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
# this does the work
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
omega_config = OmegaConf.load(config)
with open(weights, "rb") as f:
weight_bytes = f.read()
model_hash = self._cached_sha256(weights, weight_bytes)
sd = None
if weights.endswith(".safetensors"):
sd = safetensors.torch.load(weight_bytes)
else:
sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
del weight_bytes
# merged models from auto11 merge board are flat for some reason
if "state_dict" in sd:
sd = sd["state_dict"]
print(" | Forcing garbage collection prior to loading new model")
gc.collect()
model = instantiate_from_config(omega_config.model)
model.load_state_dict(sd, strict=False)
if self.precision == "float16":
print(" | Using faster float16 precision")
model = model.to(torch.float16)
else:
print(" | Using more accurate float32 precision")
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae:
if not os.path.isabs(vae):
vae = os.path.normpath(os.path.join(Globals.root, vae))
if os.path.exists(vae):
print(f" | Loading VAE weights from: {vae}")
vae_ckpt = None
vae_dict = None
if vae.endswith(".safetensors"):
vae_ckpt = safetensors.torch.load_file(vae)
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
else:
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {
k: v
for k, v in vae_ckpt["state_dict"].items()
if k[0:4] != "loss"
}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f" | VAE file {vae} not found. Skipping.")
model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
model.cond_stage_model.device = self.device
model.eval()
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
module._orig_padding_mode = module.padding_mode
return model, width, height, model_hash
def _load_diffusers_model(self, mconfig): def _load_diffusers_model(self, mconfig):
name_or_path = self.model_name_or_path(mconfig) name_or_path = self.model_name_or_path(mconfig)
using_fp16 = self.precision == "float16" using_fp16 = self.precision == "float16"
@ -553,6 +415,47 @@ class ModelManager(object):
return pipeline, width, height, model_hash return pipeline, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig):
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get("vae")
width = mconfig.width
height = mconfig.height
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root, weights))
# Convert to diffusers and return a diffusers pipeline
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
from . import load_pipeline_from_original_stable_diffusion_ckpt
self.offload_model(self.current_model)
if vae_config := self._choose_diffusers_vae(model_name):
vae = self._load_vae(vae_config)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=weights,
original_config_file=config,
vae=vae,
return_generator_pipeline=True,
precision=torch.float16 if self.precision == "float16" else torch.float32,
)
if self.sequential_offload:
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)
return (
pipeline,
width,
height,
"NOHASH",
)
def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path: def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path:
if isinstance(model_name, DictConfig) or isinstance(model_name, dict): if isinstance(model_name, DictConfig) or isinstance(model_name, dict):
mconfig = model_name mconfig = model_name
@ -640,7 +543,9 @@ class ModelManager(object):
models.yaml file. models.yaml file.
""" """
model_name = model_name or Path(repo_or_path).stem model_name = model_name or Path(repo_or_path).stem
model_description = model_description or f"Imported diffusers model {model_name}" model_description = (
model_description or f"Imported diffusers model {model_name}"
)
new_config = dict( new_config = dict(
description=model_description, description=model_description,
vae=vae, vae=vae,
@ -656,66 +561,6 @@ class ModelManager(object):
self.commit(commit_to_conf) self.commit(commit_to_conf)
return model_name return model_name
def import_ckpt_model(
self,
weights: Union[str, Path],
config: Union[str, Path] = "configs/stable-diffusion/v1-inference.yaml",
vae: Union[str, Path] = None,
model_name: str = None,
model_description: str = None,
commit_to_conf: Path = None,
) -> str:
"""
Attempts to install the indicated ckpt file and returns True if successful.
"weights" can be either a path-like object corresponding to a local .ckpt file
or a http/https URL pointing to a remote model.
"vae" is a Path or str object pointing to a ckpt or safetensors file to be used
as the VAE for this model.
"config" is the model config file to use with this ckpt file. It defaults to
v1-inference.yaml. If a URL is provided, the config will be downloaded.
You can optionally provide a model name and/or description. If not provided,
then these will be derived from the weight file name. If you provide a commit_to_conf
path to the configuration file, then the new entry will be committed to the
models.yaml file.
Return value is the name of the imported file, or None if an error occurred.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return
if config_path is None or not config_path.exists():
return
model_name = (
model_name or Path(weights).stem
) # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"Imported stable diffusion weights file {model_name}"
)
new_config = dict(
weights=str(weights_path),
config=str(config_path),
description=model_description,
format="ckpt",
width=512,
height=512,
)
if vae:
new_config["vae"] = vae
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
return model_name
@classmethod @classmethod
def probe_model_type(self, checkpoint: dict) -> SDLegacyType: def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
""" """
@ -748,7 +593,7 @@ class ModelManager(object):
def heuristic_import( def heuristic_import(
self, self,
path_url_or_repo: str, path_url_or_repo: str,
convert: bool = False, convert: bool = True,
model_name: str = None, model_name: str = None,
description: str = None, description: str = None,
commit_to_conf: Path = None, commit_to_conf: Path = None,
@ -883,35 +728,18 @@ class ModelManager(object):
) )
return return
if convert: diffuser_path = Path(
diffuser_path = Path( Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem )
) model_name = self.convert_and_import(
model_name = self.convert_and_import( model_path,
model_path, diffusers_path=diffuser_path,
diffusers_path=diffuser_path, vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), model_name=model_name,
model_name=model_name, model_description=description,
model_description=description, original_config_file=model_config_file,
original_config_file=model_config_file, commit_to_conf=commit_to_conf,
commit_to_conf=commit_to_conf, )
)
else:
model_name = self.import_ckpt_model(
model_path,
config=model_config_file,
model_name=model_name,
model_description=description,
vae=str(
Path(
Globals.root,
"models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
)
),
commit_to_conf=commit_to_conf,
)
if commit_to_conf:
self.commit(commit_to_conf)
return model_name return model_name
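With convert now defaulting to True, heuristic_import() always routes legacy checkpoints through convert_and_import(). The snippet below just reproduces where the converted model lands, using a hypothetical checkpoint name and the default Globals values shown earlier.

from pathlib import Path
from invokeai.backend.globals import Globals

model_path = Path("/downloads/analog-diffusion-1.0.ckpt")
diffuser_path = Path(Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem)
print(diffuser_path)   # <root>/models/converted_ckpts/analog-diffusion-1.0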
def convert_and_import( def convert_and_import(
@ -936,7 +764,7 @@ class ModelManager(object):
new_config = None new_config = None
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser from . import convert_ckpt_to_diffusers
if diffusers_path.exists(): if diffusers_path.exists():
print( print(
@ -951,7 +779,7 @@ class ModelManager(object):
# By passing the specified VAE to the conversion function, the autoencoder # By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file # will be built into the model rather than tacked on afterward via the config file
vae_model = self._load_vae(vae) if vae else None vae_model = self._load_vae(vae) if vae else None
convert_ckpt_to_diffuser( convert_ckpt_to_diffusers(
ckpt_path, ckpt_path,
diffusers_path, diffusers_path,
extract_ema=True, extract_ema=True,
View File
@ -0,0 +1,10 @@
"""
Initialization file for invokeai.backend.prompting
"""
from .conditioning import (
get_prompt_structure,
get_tokenizer,
get_tokens_for_prompt_object,
get_uc_and_c_and_ec,
split_weighted_subprompts,
)
View File
@ -1,31 +1,46 @@
''' """
This module handles the generation of the conditioning tensors. This module handles the generation of the conditioning tensors.
Useful function exports: Useful function exports:
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
''' """
import re import re
from typing import Union, Optional, Any from typing import Any, Optional, Union
from transformers import CLIPTokenizer, CLIPTextModel
from compel import Compel from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser from compel.prompt_parser import (
from .devices import torch_dtype Blend,
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent CrossAttentionControlSubstitute,
from ldm.invoke.globals import Globals FlattenedPrompt,
Fragment,
PromptParser,
)
from transformers import CLIPTextModel, CLIPTokenizer
from invokeai.backend.globals import Globals
from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
def get_tokenizer(model) -> CLIPTokenizer: def get_tokenizer(model) -> CLIPTokenizer:
# TODO remove legacy ckpt fallback handling # TODO remove legacy ckpt fallback handling
return (getattr(model, 'tokenizer', None) # diffusers return (
or model.cond_stage_model.tokenizer) # ldm getattr(model, "tokenizer", None) # diffusers
or model.cond_stage_model.tokenizer
) # ldm
def get_text_encoder(model) -> Any: def get_text_encoder(model) -> Any:
# TODO remove legacy ckpt fallback handling # TODO remove legacy ckpt fallback handling
return (getattr(model, 'text_encoder', None) # diffusers return getattr(
or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm model, "text_encoder", None
) or UnsqueezingLDMTransformer( # diffusers
model.cond_stage_model.transformer
) # ldm
class UnsqueezingLDMTransformer: class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer): def __init__(self, ldm_transformer):
@ -40,28 +55,41 @@ class UnsqueezingLDMTransformer:
return insufficiently_unsqueezed_tensor.unsqueeze(0) return insufficiently_unsqueezed_tensor.unsqueeze(0)
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False): def get_uc_and_c_and_ec(
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
):
# lazy-load any deferred textual inversions. # lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used. # this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string) model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
prompt_string
)
tokenizer = get_tokenizer(model) tokenizer = get_tokenizer(model)
text_encoder = get_text_encoder(model) text_encoder = get_text_encoder(model)
compel = Compel(tokenizer=tokenizer, compel = Compel(
text_encoder=text_encoder, tokenizer=tokenizer,
textual_inversion_manager=model.textual_inversion_manager, text_encoder=text_encoder,
dtype_for_device_getter=torch_dtype) textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
)
# get rid of any newline characters # get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ") prompt_string = prompt_string.replace("\n", " ")
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string) (
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend) positive_prompt_string,
positive_prompt: FlattenedPrompt|Blend negative_prompt_string,
) = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: FlattenedPrompt | Blend
if legacy_blend is not None: if legacy_blend is not None:
positive_prompt = legacy_blend positive_prompt = legacy_blend
else: else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string) positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string) negative_prompt: FlattenedPrompt | Blend = Compel.parse_prompt_string(
negative_prompt_string
)
if log_tokens or getattr(Globals, "log_tokenization", False): if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer) log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
@ -71,42 +99,70 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
tokens_count = get_max_token_count(tokenizer, positive_prompt) tokens_count = get_max_token_count(tokenizer, positive_prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count, ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=options.get( tokens_count_including_eos_bos=tokens_count,
'cross_attention_control', None)) cross_attention_control_args=options.get("cross_attention_control", None),
)
return uc, c, ec return uc, c, ec
def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> ( def get_prompt_structure(
Union[FlattenedPrompt, Blend], FlattenedPrompt): prompt_string, skip_normalize_legacy_blend: bool = False
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string) ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend) (
positive_prompt: FlattenedPrompt|Blend positive_prompt_string,
negative_prompt_string,
) = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: FlattenedPrompt | Blend
if legacy_blend is not None: if legacy_blend is not None:
positive_prompt = legacy_blend positive_prompt = legacy_blend
else: else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string) positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string) negative_prompt: FlattenedPrompt | Blend = Compel.parse_prompt_string(
negative_prompt_string
)
return positive_prompt, negative_prompt return positive_prompt, negative_prompt
def get_max_token_count(tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=True) -> int:
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=True
) -> int:
if type(prompt) is Blend: if type(prompt) is Blend:
blend: Blend = prompt blend: Blend = prompt
return max([get_max_token_count(tokenizer, c, truncate_if_too_long) for c in blend.prompts]) return max(
[
get_max_token_count(tokenizer, c, truncate_if_too_long)
for c in blend.prompts
]
)
else: else:
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)) return len(
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
)
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]: def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
if type(parsed_prompt) is Blend: if type(parsed_prompt) is Blend:
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children") raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children"
)
text_fragments = [x.text if type(x) is Fragment else text_fragments = [
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else x.text
str(x)) if type(x) is Fragment
for x in parsed_prompt.children] else (
" ".join([f.text for f in x.original])
if type(x) is CrossAttentionControlSubstitute
else str(x)
)
for x in parsed_prompt.children
]
text = " ".join(text_fragments) text = " ".join(text_fragments)
tokens = tokenizer.tokenize(text) tokens = tokenizer.tokenize(text)
if truncate_if_too_long: if truncate_if_too_long:
@ -116,39 +172,47 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str): def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
unconditioned_words = '' unconditioned_words = ""
unconditional_regex = r'\[(.*?)\]' unconditional_regex = r"\[(.*?)\]"
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0: if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals) unconditioned_words = " ".join(unconditionals)
# Remove Unconditioned Words From Prompt # Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex) unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned) clean_prompt = unconditional_regex_compile.sub(" ", prompt_string_uncleaned)
prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) prompt_string_cleaned = re.sub(" +", " ", clean_prompt)
else: else:
prompt_string_cleaned = prompt_string_uncleaned prompt_string_cleaned = prompt_string_uncleaned
return prompt_string_cleaned, unconditioned_words return prompt_string_cleaned, unconditioned_words
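For reference, the bracket syntax handled by split_prompt_to_positive_and_negative() above: anything inside [...] is collected into the negative prompt and stripped from the positive one. A minimal check, importing from the relocated module:

from invokeai.backend.prompting.conditioning import split_prompt_to_positive_and_negative

positive, negative = split_prompt_to_positive_and_negative(
    "a watercolor lighthouse [blurry, low quality]"
)
print(repr(positive))   # 'a watercolor lighthouse ' (bracketed text replaced, whitespace collapsed)
print(repr(negative))   # 'blurry, low quality'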
def log_tokenization(positive_prompt: Union[Blend, FlattenedPrompt], def log_tokenization(
negative_prompt: Union[Blend, FlattenedPrompt], positive_prompt: Union[Blend, FlattenedPrompt],
tokenizer): negative_prompt: Union[Blend, FlattenedPrompt],
tokenizer,
):
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}") print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}") print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
log_tokenization_for_prompt_object(positive_prompt, tokenizer) log_tokenization_for_prompt_object(positive_prompt, tokenizer)
log_tokenization_for_prompt_object(negative_prompt, tokenizer, display_label_prefix="(negative prompt)") log_tokenization_for_prompt_object(
negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
)
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None): def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or "" display_label_prefix = display_label_prefix or ""
if type(p) is Blend: if type(p) is Blend:
blend: Blend = p blend: Blend = p
for i, c in enumerate(blend.prompts): for i, c in enumerate(blend.prompts):
log_tokenization_for_prompt_object( log_tokenization_for_prompt_object(
c, tokenizer, c,
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})") tokenizer,
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
)
elif type(p) is FlattenedPrompt: elif type(p) is FlattenedPrompt:
flattened_prompt: FlattenedPrompt = p flattened_prompt: FlattenedPrompt = p
if flattened_prompt.wants_cross_attention_control: if flattened_prompt.wants_cross_attention_control:
@ -163,18 +227,26 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
edited_fragments.append(f) edited_fragments.append(f)
original_text = " ".join([x.text for x in original_fragments]) original_text = " ".join([x.text for x in original_fragments])
log_tokenization_for_text(original_text, tokenizer, log_tokenization_for_text(
display_label=f"{display_label_prefix}(.swap originals)") original_text,
tokenizer,
display_label=f"{display_label_prefix}(.swap originals)",
)
edited_text = " ".join([x.text for x in edited_fragments]) edited_text = " ".join([x.text for x in edited_fragments])
log_tokenization_for_text(edited_text, tokenizer, log_tokenization_for_text(
display_label=f"{display_label_prefix}(.swap replacements)") edited_text,
tokenizer,
display_label=f"{display_label_prefix}(.swap replacements)",
)
else: else:
text = " ".join([x.text for x in flattened_prompt.children]) text = " ".join([x.text for x in flattened_prompt.children])
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) log_tokenization_for_text(
text, tokenizer, display_label=display_label_prefix
)
def log_tokenization_for_text(text, tokenizer, display_label=None): def log_tokenization_for_text(text, tokenizer, display_label=None):
""" shows how the prompt is tokenized """shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word, # usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' ' # but for readability it has been replaced with ' '
""" """
@ -185,7 +257,7 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
totalTokens = len(tokens) totalTokens = len(tokens)
for i in range(0, totalTokens): for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ') token = tokens[i].replace("</w>", " ")
# alternate color # alternate color
s = (usedTokens % 6) + 1 s = (usedTokens % 6) + 1
if i < tokenizer.model_max_length: if i < tokenizer.model_max_length:
@ -196,14 +268,14 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
if usedTokens > 0: if usedTokens > 0:
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):') print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f'{tokenized}\x1b[0m') print(f"{tokenized}\x1b[0m")
if discarded != "": if discarded != "":
print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):') print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
print(f'{discarded}\x1b[0m') print(f"{discarded}\x1b[0m")
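For reference, the logger above only needs a tokenizer exposing `.tokenize()` and `.model_max_length`. A hedged sketch using the standard CLIP tokenizer from Hugging Face `transformers` (the checkpoint name is an assumption, not something this diff pins down):

```python
# Illustrative only: inspect a prompt the same way log_tokenization_for_text() does.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
tokens = tokenizer.tokenize("a watercolor painting of a lighthouse")
print(len(tokens), tokenizer.model_max_length)    # tokens used vs. the 77-token CLIP limit
print([t.replace("</w>", " ") for t in tokens])   # same '</w>' -> ' ' cleanup as the logger
```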
def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Blend]: def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize) weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1: if len(weighted_subprompts) <= 1:
return None return None
@ -214,10 +286,12 @@ def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Bl
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings] parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize) return Blend(
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
)
def split_weighted_subprompts(text, skip_normalize=False)->list: def split_weighted_subprompts(text, skip_normalize=False) -> list:
""" """
Legacy blend parsing. Legacy blend parsing.
@ -226,7 +300,8 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
if ':' has no value defined, defaults to 1.0 if ':' has no value defined, defaults to 1.0
repeats until no text remaining repeats until no text remaining
""" """
prompt_parser = re.compile(""" prompt_parser = re.compile(
"""
(?P<prompt> # capture group for 'prompt' (?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt' ) # end 'prompt'
@ -239,16 +314,20 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
| # OR | # OR
$ # else, if no ':' then match end of line $ # else, if no ':' then match end of line
) # end non-capture group ) # end non-capture group
""", re.VERBOSE) """,
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( re.VERBOSE,
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] )
parsed_prompts = [
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
for match in re.finditer(prompt_parser, text)
]
if skip_normalize: if skip_normalize:
return parsed_prompts return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts)) weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0: if weight_sum == 0:
print( print(
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead.") "* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
)
equal_weight = 1 / max(len(parsed_prompts), 1) equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts] return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts] return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
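A self-contained sketch of the legacy `prompt:weight` normalization implemented by `split_weighted_subprompts()` above; the simplified regex below approximates the verbose pattern in the diff and the function name is hypothetical:

```python
# Standalone sketch of legacy blend weights: parse "prompt:weight" pairs, then normalize.
import re

def split_weighted_subprompts_sketch(text: str, skip_normalize: bool = False) -> list:
    pairs = [
        (m.group(1).replace("\\:", ":"), float(m.group(2) or 1))
        for m in re.finditer(r"((?:\\:|[^:])+)(?::([-\d.]+)\s*|$)", text)
    ]
    if skip_normalize:
        return pairs
    total = sum(w for _, w in pairs)
    if total == 0:
        equal = 1 / max(len(pairs), 1)      # same even-weight fallback as above
        return [(p, equal) for p, _ in pairs]
    return [(p, w / total) for p, w in pairs]

print(split_weighted_subprompts_sketch("a sunny meadow:1 a stormy sky:3"))
# -> [('a sunny meadow', 0.25), ('a stormy sky', 0.75)]
```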
invokeai/backend/restoration/__init__.py
@ -0,0 +1,4 @@
"""
Initialization file for the invokeai.backend.restoration package
"""
from .base import Restoration
invokeai/backend/restoration/base.py
@ -1,38 +1,43 @@
class Restoration(): class Restoration:
def __init__(self) -> None: def __init__(self) -> None:
pass pass
def load_face_restore_models(
    self, gfpgan_model_path="./models/gfpgan/GFPGANv1.4.pth"
):
# Load GFPGAN # Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path) gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists: if gfpgan.gfpgan_model_exists:
print('>> GFPGAN Initialized') print(">> GFPGAN Initialized")
else: else:
print('>> GFPGAN Disabled') print(">> GFPGAN Disabled")
gfpgan = None gfpgan = None
# Load CodeFormer # Load CodeFormer
codeformer = self.load_codeformer() codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists: if codeformer.codeformer_model_exists:
print('>> CodeFormer Initialized') print(">> CodeFormer Initialized")
else: else:
print('>> CodeFormer Disabled') print(">> CodeFormer Disabled")
codeformer = None codeformer = None
return gfpgan, codeformer return gfpgan, codeformer
# Face Restore Models # Face Restore Models
def load_gfpgan(self, gfpgan_model_path): def load_gfpgan(self, gfpgan_model_path):
from ldm.invoke.restoration.gfpgan import GFPGAN from .gfpgan import GFPGAN
return GFPGAN(gfpgan_model_path) return GFPGAN(gfpgan_model_path)
def load_codeformer(self): def load_codeformer(self):
from ldm.invoke.restoration.codeformer import CodeFormerRestoration from .codeformer import CodeFormerRestoration
return CodeFormerRestoration() return CodeFormerRestoration()
# Upscale Models # Upscale Models
def load_esrgan(self, esrgan_bg_tile=400): def load_esrgan(self, esrgan_bg_tile=400):
from ldm.invoke.restoration.realesrgan import ESRGAN from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile) esrgan = ESRGAN(esrgan_bg_tile)
print('>> ESRGAN Initialized') print(">> ESRGAN Initialized")
return esrgan; return esrgan
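A hedged usage sketch for the loader above. The top-level import works because the package `__init__.py` shown earlier re-exports `Restoration`; the input file name is hypothetical, the weights are assumed to be installed under the InvokeAI root, and each `process()` call is assumed to return the restored PIL image (the return statements fall outside this excerpt):

```python
# Sketch only: load the restoration models and run them on a PIL image.
from PIL import Image
from invokeai.backend.restoration import Restoration

restoration = Restoration()
gfpgan, codeformer = restoration.load_face_restore_models()   # default GFPGANv1.4 path
esrgan = restoration.load_esrgan(esrgan_bg_tile=400)

img = Image.open("face.png")                                   # hypothetical input
if gfpgan is not None:
    img = gfpgan.process(img, strength=0.8, seed=None)
img = esrgan.process(img, strength=1.0, upsampler_scale=2)
```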
invokeai/backend/restoration/codeformer.py
@ -1,17 +1,21 @@
import os
import sys
import warnings

import numpy as np
import torch

from ..globals import Globals

pretrained_model_url = (
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)


class CodeFormerRestoration:
    def __init__(
        self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
    ) -> None:
if not os.path.isabs(codeformer_dir): if not os.path.isabs(codeformer_dir):
codeformer_dir = os.path.join(Globals.root, codeformer_dir) codeformer_dir = os.path.join(Globals.root, codeformer_dir)
@ -19,22 +23,23 @@ class CodeFormerRestoration():
self.codeformer_model_exists = os.path.isfile(self.model_path) self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists: if not self.codeformer_model_exists:
print('## NOT FOUND: CodeFormer model not found at ' + self.model_path) print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir)) sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75): def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None: if seed is not None:
print(f'>> CodeFormer - Restoring Faces for image seed:{seed}') print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import img2tensor, tensor2img from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from ldm.invoke.restoration.codeformer_arch import CodeFormer
from torchvision.transforms.functional import normalize
from PIL import Image from PIL import Image
from torchvision.transforms.functional import normalize
from .codeformer_arch import CodeFormer
cf_class = CodeFormer cf_class = CodeFormer
@ -43,28 +48,31 @@ class CodeFormerRestoration():
codebook_size=1024, codebook_size=1024,
n_head=8, n_head=8,
n_layers=9, n_layers=9,
connect_list=['32', '64', '128', '256'] connect_list=["32", "64", "128", "256"],
).to(device) ).to(device)
# note that this file should already be downloaded and cached at # note that this file should already be downloaded and cached at
# this point # this point
checkpoint_path = load_file_from_url(url=pretrained_model_url, checkpoint_path = load_file_from_url(
model_dir=os.path.abspath(os.path.dirname(self.model_path)), url=pretrained_model_url,
progress=True model_dir=os.path.abspath(os.path.dirname(self.model_path)),
progress=True,
) )
checkpoint = torch.load(checkpoint_path)['params_ema'] checkpoint = torch.load(checkpoint_path)["params_ema"]
cf.load_state_dict(checkpoint) cf.load_state_dict(checkpoint)
cf.eval() cf.eval()
image = image.convert('RGB') image = image.convert("RGB")
# Codeformer expects a BGR np array; make array and flip channels # Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1] bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
face_helper = FaceRestoreHelper( face_helper = FaceRestoreHelper(
upscale_factor=1, upscale_factor=1,
use_parse=True, use_parse=True,
device=device, device=device,
model_rootpath=os.path.join(Globals.root,'models','gfpgan','weights'), model_rootpath=os.path.join(
Globals.root, "models", "gfpgan", "weights"
),
) )
face_helper.clean_all() face_helper.clean_all()
face_helper.read_image(bgr_image_array) face_helper.read_image(bgr_image_array)
@ -72,30 +80,35 @@ class CodeFormerRestoration():
face_helper.align_warp_face() face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces): for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) cropped_face_t = img2tensor(
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device) cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try: try:
with torch.no_grad(): with torch.no_grad():
output = cf(cropped_face_t, w=fidelity, adain=True)[0] output = cf(cropped_face_t, w=fidelity, adain=True)[0]
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
)
del output del output
torch.cuda.empty_cache() torch.cuda.empty_cache()
except RuntimeError as error: except RuntimeError as error:
print(f'\tFailed inference for CodeFormer: {error}.') print(f"\tFailed inference for CodeFormer: {error}.")
restored_face = cropped_face restored_face = cropped_face
restored_face = restored_face.astype('uint8') restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face) face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None) face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image() restored_img = face_helper.paste_faces_to_input_image()
# Flip the channels back to RGB # Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1]) res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0: if strength < 1.0:
# Resize the image to the new image if the sizes have changed # Resize the image to the new image if the sizes have changed
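A hedged sketch of calling the CodeFormer restorer defined above; the import path follows `base.py`'s `from .codeformer import CodeFormerRestoration`, the input file is hypothetical, and the weights are assumed to be present under `Globals.root`:

```python
# Sketch only: restore faces in a PIL image with CodeFormer.
import torch
from PIL import Image
from invokeai.backend.restoration.codeformer import CodeFormerRestoration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cf = CodeFormerRestoration()       # expects models/codeformer/codeformer.pth under Globals.root
img = Image.open("portrait.png")   # hypothetical input
restored = cf.process(img, strength=0.9, device=device, seed=None, fidelity=0.75)
```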
invokeai/backend/restoration/codeformer_arch.py
@ -1,13 +1,15 @@
import math import math
from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from torch import nn, Tensor
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, List
from ldm.invoke.restoration.vqgan_arch import *
from basicsr.utils import get_root_logger from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn
from .vqgan_arch import *
def calc_mean_std(feat, eps=1e-5): def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization. """Calculate mean and std for adaptive_instance_normalization.
@ -18,7 +20,7 @@ def calc_mean_std(feat, eps=1e-5):
divide-by-zero. Default: 1e-5. divide-by-zero. Default: 1e-5.
""" """
size = feat.size() size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.' assert len(size) == 4, "The input feature should be 4D tensor."
b, c = size[:2] b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1) feat_std = feat_var.sqrt().view(b, c, 1, 1)
@ -39,7 +41,9 @@ def adaptive_instance_normalization(content_feat, style_feat):
size = content_feat.size() size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat) style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat) content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
size
)
return normalized_feat * style_std.expand(size) + style_mean.expand(size) return normalized_feat * style_std.expand(size) + style_mean.expand(size)
@ -49,7 +53,9 @@ class PositionEmbeddingSine(nn.Module):
used by the Attention is all you need paper, generalized to work on images. used by the Attention is all you need paper, generalized to work on images.
""" """
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__() super().__init__()
self.num_pos_feats = num_pos_feats self.num_pos_feats = num_pos_feats
self.temperature = temperature self.temperature = temperature
@ -62,7 +68,9 @@ class PositionEmbeddingSine(nn.Module):
def forward(self, x, mask=None): def forward(self, x, mask=None):
if mask is None: if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) mask = torch.zeros(
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
)
not_mask = ~mask not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32) y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32)
@ -85,6 +93,7 @@ class PositionEmbeddingSine(nn.Module):
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos return pos
def _get_activation_fn(activation): def _get_activation_fn(activation):
"""Return an activation function given a string""" """Return an activation function given a string"""
if activation == "relu": if activation == "relu":
@ -93,11 +102,13 @@ def _get_activation_fn(activation):
return F.gelu return F.gelu
if activation == "glu": if activation == "glu":
return F.glu return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.") raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module): class TransformerSALayer(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): def __init__(
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP # Implementation of Feedforward model - MLP
@ -115,16 +126,19 @@ class TransformerSALayer(nn.Module):
def with_pos_embed(self, tensor, pos: Optional[Tensor]): def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos return tensor if pos is None else tensor + pos
def forward(self, tgt, def forward(
tgt_mask: Optional[Tensor] = None, self,
tgt_key_padding_mask: Optional[Tensor] = None, tgt,
query_pos: Optional[Tensor] = None): tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
# self attention # self attention
tgt2 = self.norm1(tgt) tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos) q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, tgt2 = self.self_attn(
key_padding_mask=tgt_key_padding_mask)[0] q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2) tgt = tgt + self.dropout1(tgt2)
# ffn # ffn
@ -133,20 +147,23 @@ class TransformerSALayer(nn.Module):
tgt = tgt + self.dropout2(tgt2) tgt = tgt + self.dropout2(tgt2)
return tgt return tgt
class Fuse_sft_block(nn.Module): class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
super().__init__() super().__init__()
self.encode_enc = ResBlock(2*in_ch, out_ch) self.encode_enc = ResBlock(2 * in_ch, out_ch)
self.scale = nn.Sequential( self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
self.shift = nn.Sequential( self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
)
def forward(self, enc_feat, dec_feat, w=1): def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
@ -159,11 +176,19 @@ class Fuse_sft_block(nn.Module):
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder): class CodeFormer(VQAutoEncoder):
def __init__(
    self,
    dim_embd=512,
    n_head=8,
    n_layers=9,
    codebook_size=1024,
    latent_size=256,
    connect_list=["32", "64", "128", "256"],
    fix_modules=["quantize", "generator"],
):
    super(CodeFormer, self).__init__(
        512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
    )
if fix_modules is not None: if fix_modules is not None:
for module in fix_modules: for module in fix_modules:
@ -173,33 +198,53 @@ class CodeFormer(VQAutoEncoder):
self.connect_list = connect_list self.connect_list = connect_list
self.n_layers = n_layers self.n_layers = n_layers
self.dim_embd = dim_embd self.dim_embd = dim_embd
self.dim_mlp = dim_embd*2 self.dim_mlp = dim_embd * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd) self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer # transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) self.ft_layers = nn.Sequential(
for _ in range(self.n_layers)]) *[
TransformerSALayer(
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
)
for _ in range(self.n_layers)
]
)
# logits_predict head # logits_predict head
self.idx_pred_layer = nn.Sequential( self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
nn.Linear(dim_embd, codebook_size, bias=False)) )
self.channels = { self.channels = {
'16': 512, "16": 512,
'32': 256, "32": 256,
'64': 256, "64": 256,
'128': 128, "128": 128,
'256': 128, "256": 128,
'512': 64, "512": 64,
} }
# after second residual block for > 16, before attn layer for ==16 # after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} self.fuse_encoder_block = {
"512": 2,
"256": 5,
"128": 8,
"64": 11,
"32": 14,
"16": 18,
}
# after first residual block for > 16, before attn layer for ==16 # after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} self.fuse_generator_block = {
"16": 6,
"32": 9,
"64": 12,
"128": 15,
"256": 18,
"512": 21,
}
# fuse_convs_dict # fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict() self.fuse_convs_dict = nn.ModuleDict()
@ -228,20 +273,20 @@ class CodeFormer(VQAutoEncoder):
lq_feat = x lq_feat = x
# ################# Transformer ################### # ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
# BCHW -> BC(HW) -> (HW)BC # BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb query_emb = feat_emb
# Transformer encoder # Transformer encoder
for layer in self.ft_layers: for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb) query_emb = layer(query_emb, query_pos=pos_emb)
# output logits # output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
if code_only: # for training stage II if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss # logits doesn't need softmax before cross_entropy loss
return logits, lq_feat return logits, lq_feat
# ################# Quantization ################### # ################# Quantization ###################
@ -252,12 +297,14 @@ class CodeFormer(VQAutoEncoder):
# ------------ # ------------
soft_one_hot = F.softmax(logits, dim=2) soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2) _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) quant_feat = self.quantize.get_codebook_feat(
top_idx, shape=[x.shape[0], 16, 16, 256]
)
# preserve gradients # preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach() # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16: if detach_16:
quant_feat = quant_feat.detach() # for training stage III quant_feat = quant_feat.detach() # for training stage III
if adain: if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
@ -267,10 +314,12 @@ class CodeFormer(VQAutoEncoder):
for i, block in enumerate(self.generator.blocks): for i, block in enumerate(self.generator.blocks):
x = block(x) x = block(x)
if i in fuse_list: # fuse after i-th block if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1]) f_size = str(x.shape[-1])
if w>0: if w > 0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) x = self.fuse_convs_dict[f_size](
enc_feat_dict[f_size].detach(), x, w
)
out = x out = x
# logits doesn't need softmax before cross_entropy loss # logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat return out, logits, lq_feat
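A quick numerical check of `adaptive_instance_normalization()` as used above: the content features are rescaled so their per-channel mean and std match the style features. This is a standalone sketch that re-derives `calc_mean_std()` instead of importing it:

```python
# Standalone check: AdaIN moves the content statistics onto the style statistics.
import torch

def calc_mean_std(feat, eps=1e-5):
    b, c = feat.size()[:2]
    flat = feat.view(b, c, -1)
    return flat.mean(dim=2).view(b, c, 1, 1), (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)

content = torch.randn(2, 3, 8, 8)
style = torch.randn(2, 3, 8, 8) * 2 + 5
c_mean, c_std = calc_mean_std(content)
s_mean, s_std = calc_mean_std(style)
adain = (content - c_mean) / c_std * s_std + s_mean

print(torch.allclose(calc_mean_std(adain)[0], s_mean, atol=1e-4))  # True: channel means now match the style
```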
invokeai/backend/restoration/gfpgan.py
@ -1,26 +1,25 @@
import torch
import warnings
import os import os
import sys import sys
import numpy as np import warnings
from ldm.invoke.globals import Globals
import numpy as np
import torch
from PIL import Image from PIL import Image
from invokeai.backend.globals import Globals
class GFPGAN():
def __init__(
self,
gfpgan_model_path='models/gfpgan/GFPGANv1.4.pth'
) -> None:
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
if not os.path.isabs(gfpgan_model_path): if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path=os.path.abspath(os.path.join(Globals.root,gfpgan_model_path)) gfpgan_model_path = os.path.abspath(
os.path.join(Globals.root, gfpgan_model_path)
)
self.model_path = gfpgan_model_path self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path) self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists: if not self.gfpgan_model_exists:
print('## NOT FOUND: GFPGAN model not found at ' + self.model_path) print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
return None return None
def model_exists(self): def model_exists(self):
@ -28,40 +27,40 @@ class GFPGAN():
def process(self, image, strength: float, seed: str = None): def process(self, image, strength: float, seed: str = None):
if seed is not None: if seed is not None:
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}') print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd() cwd = os.getcwd()
os.chdir(os.path.join(Globals.root,'models')) os.chdir(os.path.join(Globals.root, "models"))
try: try:
from gfpgan import GFPGANer from gfpgan import GFPGANer
self.gfpgan = GFPGANer( self.gfpgan = GFPGANer(
model_path=self.model_path, model_path=self.model_path,
upscale=1, upscale=1,
arch='clean', arch="clean",
channel_multiplier=2, channel_multiplier=2,
bg_upsampler=None, bg_upsampler=None,
) )
except Exception: except Exception:
import traceback import traceback
print('>> Error loading GFPGAN:', file=sys.stderr)
print(">> Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd) os.chdir(cwd)
if self.gfpgan is None: if self.gfpgan is None:
print(f">> WARNING: GFPGAN not initialized.")
print( print(
f'>> WARNING: GFPGAN not initialized.' f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
print(
f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}'
) )
image = image.convert('RGB') image = image.convert("RGB")
# GFPGAN expects a BGR np array; make array and flip channels # GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1] bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
_, _, restored_img = self.gfpgan.enhance( _, _, restored_img = self.gfpgan.enhance(
bgr_image_array, bgr_image_array,
@ -71,7 +70,7 @@ class GFPGAN():
) )
# Flip the channels back to RGB # Flip the channels back to RGB
res = Image.fromarray(restored_img[...,::-1]) res = Image.fromarray(restored_img[..., ::-1])
if strength < 1.0: if strength < 1.0:
# Resize the image to the new image if the sizes have changed # Resize the image to the new image if the sizes have changed
@ -79,7 +78,6 @@ class GFPGAN():
image = image.resize(res.size) image = image.resize(res.size)
res = Image.blend(image, res, strength) res = Image.blend(image, res, strength)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.gfpgan = None self.gfpgan = None
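A hedged sketch of using the GFPGAN wrapper directly (it is normally constructed through `Restoration.load_gfpgan()`); the import path and the assumption that `process()` returns the restored image follow the surrounding code, and the input file is hypothetical:

```python
# Sketch only: direct GFPGAN face restoration on a PIL image.
from PIL import Image
from invokeai.backend.restoration.gfpgan import GFPGAN

gfpgan = GFPGAN()                  # default models/gfpgan/GFPGANv1.4.pth under Globals.root
if gfpgan.gfpgan_model_exists:
    restored = gfpgan.process(Image.open("portrait.png"), strength=0.75, seed=None)
```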
invokeai/backend/restoration/outcrop.py
@ -0,0 +1,119 @@
import math
import warnings
from PIL import Image, ImageFilter
class Outcrop(object):
def __init__(
self,
image,
generate, # current generate object
):
self.image = image
self.generate = generate
def process(
self,
extents: dict,
opt, # current options
orig_opt, # ones originally used to generate the image
image_callback=None,
prefix=None,
):
# grow and mask the image
extended_image = self._extend_all(extents)
# switch samplers temporarily
curr_sampler = self.generate.sampler
self.generate.sampler_name = opt.sampler_name
self.generate._set_sampler()
def wrapped_callback(img, seed, **kwargs):
preferred_seed = (
orig_opt.seed
if orig_opt.seed is not None and orig_opt.seed >= 0
else seed
)
image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)
result = self.generate.prompt2image(
opt.prompt,
seed=opt.seed or orig_opt.seed,
sampler=self.generate.sampler,
steps=opt.steps,
cfg_scale=opt.cfg_scale,
ddim_eta=self.generate.ddim_eta,
width=extended_image.width,
height=extended_image.height,
init_img=extended_image,
strength=0.90,
image_callback=wrapped_callback if image_callback else None,
seam_size=opt.seam_size or 96,
seam_blur=opt.seam_blur or 16,
seam_strength=opt.seam_strength or 0.7,
seam_steps=20,
tile_size=32,
color_match=True,
force_outpaint=True, # this just stops the warning about erased regions
)
# swap sampler back
self.generate.sampler = curr_sampler
return result
def _extend_all(
self,
extents: dict,
) -> Image:
"""
Extend the image in direction ('top','bottom','left','right') by
the indicated value. The image canvas is extended, and the empty
rectangular section will be filled with a blurred copy of the
adjacent image.
"""
image = self.image
for direction in extents:
assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction must be one of "top", "left", "bottom", "right"'
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
print(f">> extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)
return image
def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
"""
Rotates image so that the area to extend is always at the top.
Simplifies logic later. The reverse argument, if true, will undo the
previous transpose.
"""
transposes = {
"right": ["ROTATE_90", "ROTATE_270"],
"bottom": ["ROTATE_180", "ROTATE_180"],
"left": ["ROTATE_270", "ROTATE_90"],
}
if direction not in transposes:
return image
transpose = transposes[direction][1 if reverse else 0]
return image.transpose(Image.Transpose.__dict__[transpose])
def _extend(self, image: Image, pixels: int) -> Image:
extended_img = Image.new("RGBA", (image.width, image.height + pixels))
extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
extended_img.paste(image, box=(0, pixels))
# now make the top part transparent to use as a mask
alpha = extended_img.getchannel("A")
alpha.paste(0, (0, 0, extended_img.width, pixels))
extended_img.putalpha(alpha)
return extended_img
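A standalone sketch of the canvas growth done by `Outcrop._extend()` above, folding in the 64-pixel rounding that `_extend_all()` applies: the new strip is left transparent so it can act as the region to inpaint. The helper name and sizes are illustrative:

```python
# Sketch only: extend the top of an image by a multiple of 64 px, leaving it transparent.
import math
from PIL import Image

def extend_top(image: Image.Image, pixels: int) -> Image.Image:
    pixels = math.ceil(pixels / 64) * 64                           # round up to nearest 64
    extended = Image.new("RGBA", (image.width, image.height + pixels))
    extended.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
    extended.paste(image, box=(0, pixels))
    alpha = extended.getchannel("A")
    alpha.paste(0, (0, 0, extended.width, pixels))                 # transparent strip = area to fill
    extended.putalpha(alpha)
    return extended

print(extend_top(Image.new("RGB", (512, 512), "gray"), 100).size)  # -> (512, 640)
```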
invokeai/backend/restoration/outpaint.py
@ -1,39 +1,43 @@
import warnings
import math import math
import warnings
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
class Outpaint(object): class Outpaint(object):
def __init__(self, image, generate): def __init__(self, image, generate):
self.image = image self.image = image
self.generate = generate self.generate = generate
def process(self, opt, old_opt, image_callback = None, prefix = None): def process(self, opt, old_opt, image_callback=None, prefix=None):
image = self._create_outpaint_image(self.image, opt.out_direction) image = self._create_outpaint_image(self.image, opt.out_direction)
seed = old_opt.seed seed = old_opt.seed
prompt = old_opt.prompt prompt = old_opt.prompt
def wrapped_callback(img,seed,**kwargs): def wrapped_callback(img, seed, **kwargs):
image_callback(img,seed,use_prefix=prefix,**kwargs) image_callback(img, seed, use_prefix=prefix, **kwargs)
return self.generate.prompt2image( return self.generate.prompt2image(
prompt, prompt,
seed = seed, seed=seed,
sampler = self.generate.sampler, sampler=self.generate.sampler,
steps = opt.steps, steps=opt.steps,
cfg_scale = opt.cfg_scale, cfg_scale=opt.cfg_scale,
ddim_eta = self.generate.ddim_eta, ddim_eta=self.generate.ddim_eta,
width = opt.width, width=opt.width,
height = opt.height, height=opt.height,
init_img = image, init_img=image,
strength = 0.83, strength=0.83,
image_callback = wrapped_callback, image_callback=wrapped_callback,
prefix = prefix, prefix=prefix,
) )
def _create_outpaint_image(self, image, direction_args): def _create_outpaint_image(self, image, direction_args):
assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.' assert len(direction_args) in [
1,
2,
], "Direction (-D) must have exactly one or two arguments."
if len(direction_args) == 1: if len(direction_args) == 1:
direction = direction_args[0] direction = direction_args[0]
@ -42,19 +46,26 @@ class Outpaint(object):
direction = direction_args[0] direction = direction_args[0]
pixels = int(direction_args[1]) pixels = int(direction_args[1])
assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"' assert direction in [
"top",
"left",
"bottom",
"right",
], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
image = image.convert("RGBA") image = image.convert("RGBA")
# we always extend top, but rotate to extend along the requested side # we always extend top, but rotate to extend along the requested side
if direction == 'left': if direction == "left":
image = image.transpose(Image.Transpose.ROTATE_270) image = image.transpose(Image.Transpose.ROTATE_270)
elif direction == 'bottom': elif direction == "bottom":
image = image.transpose(Image.Transpose.ROTATE_180) image = image.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right': elif direction == "right":
image = image.transpose(Image.Transpose.ROTATE_90) image = image.transpose(Image.Transpose.ROTATE_90)
pixels = image.height//2 if pixels is None else int(pixels) pixels = image.height // 2 if pixels is None else int(pixels)
assert 0 < pixels < image.height, 'Direction (-D) pixels length must be in the range 0 - image.size' assert (
0 < pixels < image.height
), "Direction (-D) pixels length must be in the range 0 - image.size"
# the top part of the image is taken from the source image mirrored # the top part of the image is taken from the source image mirrored
# coordinates (0,0) are the upper left corner of an image # coordinates (0,0) are the upper left corner of an image
@ -74,19 +85,18 @@ class Outpaint(object):
new_img.paste(bottom, (0, pixels)) new_img.paste(bottom, (0, pixels))
# create a 10% dither in the middle # create a 10% dither in the middle
dither = min(image.height//10, pixels) dither = min(image.height // 10, pixels)
for x in range(0, image.width, 2): for x in range(0, image.width, 2):
for y in range(pixels - dither, pixels + dither): for y in range(pixels - dither, pixels + dither):
(r, g, b, a) = new_img.getpixel((x, y)) (r, g, b, a) = new_img.getpixel((x, y))
new_img.putpixel((x, y), (r, g, b, 0)) new_img.putpixel((x, y), (r, g, b, 0))
# let's rotate back again # let's rotate back again
if direction == 'left': if direction == "left":
new_img = new_img.transpose(Image.Transpose.ROTATE_90) new_img = new_img.transpose(Image.Transpose.ROTATE_90)
elif direction == 'bottom': elif direction == "bottom":
new_img = new_img.transpose(Image.Transpose.ROTATE_180) new_img = new_img.transpose(Image.Transpose.ROTATE_180)
elif direction == 'right': elif direction == "right":
new_img = new_img.transpose(Image.Transpose.ROTATE_270) new_img = new_img.transpose(Image.Transpose.ROTATE_270)
return new_img return new_img
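A small sketch of the `-D <direction> [pixels]` argument handling in `_create_outpaint_image()` above, showing the default of half the image height; the helper name is hypothetical:

```python
# Sketch only: validate the outpaint direction arguments and resolve the pixel count.
def parse_out_direction(direction_args, image_height):
    assert len(direction_args) in (1, 2), "Direction (-D) must have one or two arguments."
    direction = direction_args[0]
    pixels = int(direction_args[1]) if len(direction_args) == 2 else None
    assert direction in ("top", "left", "bottom", "right")
    pixels = image_height // 2 if pixels is None else int(pixels)
    assert 0 < pixels < image_height
    return direction, pixels

print(parse_out_direction(["left"], 512))        # -> ('left', 256)
print(parse_out_direction(["top", "128"], 512))  # -> ('top', 128)
```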
invokeai/backend/restoration/realesrgan.py
@ -1,13 +1,15 @@
import torch
import warnings
import numpy as np
import os import os
import warnings
from ldm.invoke.globals import Globals import numpy as np
import torch
from PIL import Image from PIL import Image
from PIL.Image import Image as ImageType from PIL.Image import Image as ImageType
class ESRGAN(): from invokeai.backend.globals import Globals
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None: def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size self.bg_tile_size = bg_tile_size
@ -22,12 +24,23 @@ class ESRGAN():
else: else:
use_half_precision = True use_half_precision = True
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') model = SRVGGNetCompact(
model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-x4v3.pth') num_in_ch=3,
wdn_model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-wdn-x4v3.pth') num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
model_path = os.path.join(
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
)
wdn_model_path = os.path.join(
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
)
scale = 4 scale = 4
bg_upsampler = RealESRGANer( bg_upsampler = RealESRGANer(
@ -43,41 +56,49 @@ class ESRGAN():
return bg_upsampler return bg_upsampler
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2, denoise_str: float = 0.75): def process(
self,
image: ImageType,
strength: float,
seed: str = None,
upsampler_scale: int = 2,
denoise_str: float = 0.75,
):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
try: try:
upsampler = self.load_esrgan_bg_upsampler(denoise_str) upsampler = self.load_esrgan_bg_upsampler(denoise_str)
except Exception: except Exception:
import traceback
import sys import sys
print('>> Error loading Real-ESRGAN:', file=sys.stderr) import traceback
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0: if upsampler_scale == 0:
print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.') print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image return image
if seed is not None: if seed is not None:
print( print(
f'>> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}' f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
) )
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB # ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB") image = image.convert("RGB")
# REALSRGAN expects a BGR np array; make array and flip channels # REALSRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1] bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]
output, _ = upsampler.enhance( output, _ = upsampler.enhance(
bgr_image_array, bgr_image_array,
outscale=upsampler_scale, outscale=upsampler_scale,
alpha_upsampler='realesrgan', alpha_upsampler="realesrgan",
) )
# Flip the channels back to RGB # Flip the channels back to RGB
res = Image.fromarray(output[...,::-1]) res = Image.fromarray(output[..., ::-1])
if strength < 1.0: if strength < 1.0:
# Resize the image to the new image if the sizes have changed # Resize the image to the new image if the sizes have changed
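A hedged usage sketch for the Real-ESRGAN wrapper above; the import path follows `base.py`'s `from .realesrgan import ESRGAN`, the input file is hypothetical, and the realesr-general-x4v3 weights are assumed to be installed under `Globals.root/models/realesrgan/`:

```python
# Sketch only: 2x upscale a PIL image with Real-ESRGAN.
from PIL import Image
from invokeai.backend.restoration.realesrgan import ESRGAN

esrgan = ESRGAN(bg_tile_size=400)
img = Image.open("landscape.png")   # hypothetical input
upscaled = esrgan.process(img, strength=1.0, seed=None, upsampler_scale=2, denoise_str=0.75)
```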
invokeai/backend/restoration/vqgan_arch.py
@ -1,23 +1,27 @@
''' """
VQGAN code, adapted from the original created by the Unleashing Transformers authors: VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
''' """
import copy
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels): def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
@torch.jit.script @torch.jit.script
def swish(x): def swish(x):
return x*torch.sigmoid(x) return x * torch.sigmoid(x)
# Define VQVAE classes # Define VQVAE classes
@ -28,7 +32,9 @@ class VectorQuantizer(nn.Module):
self.emb_dim = emb_dim # dimension of embedding self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) self.embedding.weight.data.uniform_(
-1.0 / self.codebook_size, 1.0 / self.codebook_size
)
def forward(self, z): def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten # reshape z -> (batch, height, width, channel) and flatten
@ -36,23 +42,32 @@ class VectorQuantizer(nn.Module):
z_flattened = z.view(-1, self.emb_dim) z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ d = (
2 * torch.matmul(z_flattened, self.embedding.weight.t()) (z_flattened**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight**2).sum(1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
mean_distance = torch.mean(d) mean_distance = torch.mean(d)
# find closest encodings # find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) min_encoding_scores, min_encoding_indices = torch.topk(
d, 1, dim=1, largest=False
)
# [0-1], higher score, higher confidence # [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores/10) min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.codebook_size
).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1) min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors # get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding # compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients # preserve gradients
z_q = z + (z_q - z).detach() z_q = z + (z_q - z).detach()
@ -62,18 +77,22 @@ class VectorQuantizer(nn.Module):
# reshape back to match original input shape # reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous() z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, { return (
"perplexity": perplexity, z_q,
"min_encodings": min_encodings, loss,
"min_encoding_indices": min_encoding_indices, {
"min_encoding_scores": min_encoding_scores, "perplexity": perplexity,
"mean_distance": mean_distance "min_encodings": min_encodings,
} "min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance,
},
)
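A numeric sanity check of the codebook distance used in `VectorQuantizer.forward()` above: the expanded form z^2 + e^2 - 2 z·e equals the squared Euclidean distance, so the nearest-code lookup agrees with a direct computation. Tensor sizes below are made up:

```python
# Standalone check of the expanded squared-distance used for nearest-code lookup.
import torch

z_flat = torch.randn(5, 4)       # 5 latent vectors, emb_dim = 4
codebook = torch.randn(16, 4)    # 16 codebook entries

d = (
    (z_flat**2).sum(dim=1, keepdim=True)
    + (codebook**2).sum(1)
    - 2 * torch.matmul(z_flat, codebook.t())
)
direct = torch.cdist(z_flat, codebook) ** 2

print(torch.allclose(d, direct, atol=1e-4))   # True: both give ||z - e||^2
print(torch.argmin(d, dim=1))                 # index of the nearest code for each vector
```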
def get_codebook_feat(self, indices, shape): def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1 # input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel # shape: batch, height, width, channel
indices = indices.view(-1,1) indices = indices.view(-1, 1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1) min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors # get quantized latent vectors
@ -86,14 +105,24 @@ class VectorQuantizer(nn.Module):
class GumbelQuantizer(nn.Module): class GumbelQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): def __init__(
self,
codebook_size,
emb_dim,
num_hiddens,
straight_through=False,
kl_weight=5e-4,
temp_init=1.0,
):
super().__init__() super().__init__()
self.codebook_size = codebook_size # number of embeddings self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through self.straight_through = straight_through
self.temperature = temp_init self.temperature = temp_init
self.kl_weight = kl_weight self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits self.proj = nn.Conv2d(
num_hiddens, codebook_size, 1
) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim) self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z): def forward(self, z):
@ -107,18 +136,21 @@ class GumbelQuantizer(nn.Module):
# + kl divergence to the prior loss # + kl divergence to the prior loss
qy = F.softmax(logits, dim=1) qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
)
min_encoding_indices = soft_one_hot.argmax(dim=1) min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, { return z_q, diff, {"min_encoding_indices": min_encoding_indices}
"min_encoding_indices": min_encoding_indices
}
class Downsample(nn.Module): class Downsample(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super().__init__() super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x): def forward(self, x):
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
@ -130,7 +162,9 @@ class Downsample(nn.Module):
class Upsample(nn.Module): class Upsample(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super().__init__() super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x): def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest") x = F.interpolate(x, scale_factor=2.0, mode="nearest")
@ -145,11 +179,17 @@ class ResBlock(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels) self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = normalize(out_channels) self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.conv_out = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x_in): def forward(self, x_in):
x = x_in x = x_in
@ -172,32 +212,16 @@ class AttnBlock(nn.Module):
self.norm = normalize(in_channels) self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d( self.q = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels,
kernel_size=1,
stride=1,
padding=0
) )
self.k = torch.nn.Conv2d( self.k = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels,
kernel_size=1,
stride=1,
padding=0
) )
self.v = torch.nn.Conv2d( self.v = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels,
kernel_size=1,
stride=1,
padding=0
) )
self.proj_out = torch.nn.Conv2d( self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels,
kernel_size=1,
stride=1,
padding=0
) )
def forward(self, x): def forward(self, x):
@ -209,26 +233,35 @@ class AttnBlock(nn.Module):
# compute attention # compute attention
b, c, h, w = q.shape b, c, h, w = q.shape
q = q.reshape(b, c, h*w) q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w) k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k) w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5)) w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2) w_ = F.softmax(w_, dim=2)
# attend to values # attend to values
v = v.reshape(b, c, h*w) v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_) h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w) h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
return x+h_ return x + h_
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): def __init__(
self,
in_channels,
nf,
emb_dim,
ch_mult,
num_res_blocks,
resolution,
attn_resolutions,
):
super().__init__() super().__init__()
self.nf = nf self.nf = nf
self.num_resolutions = len(ch_mult) self.num_resolutions = len(ch_mult)
@ -237,7 +270,7 @@ class Encoder(nn.Module):
self.attn_resolutions = attn_resolutions self.attn_resolutions = attn_resolutions
curr_res = self.resolution curr_res = self.resolution
in_ch_mult = (1,)+tuple(ch_mult) in_ch_mult = (1,) + tuple(ch_mult)
blocks = [] blocks = []
# initial convolution # initial convolution
@ -264,7 +297,9 @@ class Encoder(nn.Module):
# normalise and convert to latent size # normalise and convert to latent size
blocks.append(normalize(block_in_ch)) blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) blocks.append(
nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
)
self.blocks = nn.ModuleList(blocks) self.blocks = nn.ModuleList(blocks)
def forward(self, x): def forward(self, x):
@ -286,11 +321,13 @@ class Generator(nn.Module):
self.in_channels = emb_dim self.in_channels = emb_dim
self.out_channels = 3 self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1] block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions-1) curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
blocks = [] blocks = []
# initial conv # initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) blocks.append(
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
)
# non-local attention block # non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch)) blocks.append(ResBlock(block_in_ch, block_in_ch))
@ -312,11 +349,14 @@ class Generator(nn.Module):
curr_res = curr_res * 2 curr_res = curr_res * 2
blocks.append(normalize(block_in_ch)) blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) blocks.append(
nn.Conv2d(
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
)
)
self.blocks = nn.ModuleList(blocks) self.blocks = nn.ModuleList(blocks)
def forward(self, x): def forward(self, x):
for block in self.blocks: for block in self.blocks:
x = block(x) x = block(x)
@ -326,8 +366,21 @@ class Generator(nn.Module):
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module): class VQAutoEncoder(nn.Module):
def __init__(
    self,
    img_size,
    nf,
    ch_mult,
    quantizer="nearest",
    res_blocks=2,
    attn_resolutions=[16],
    codebook_size=1024,
    emb_dim=256,
    beta=0.25,
    gumbel_straight_through=False,
    gumbel_kl_weight=1e-8,
    model_path=None,
):
super().__init__() super().__init__()
logger = get_root_logger() logger = get_root_logger()
self.in_channels = 3 self.in_channels = 3
@ -346,11 +399,13 @@ class VQAutoEncoder(nn.Module):
self.ch_mult, self.ch_mult,
self.n_blocks, self.n_blocks,
self.resolution, self.resolution,
self.attn_resolutions self.attn_resolutions,
) )
if self.quantizer_type == "nearest": if self.quantizer_type == "nearest":
self.beta = beta #0.25 self.beta = beta # 0.25
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) self.quantize = VectorQuantizer(
self.codebook_size, self.embed_dim, self.beta
)
elif self.quantizer_type == "gumbel": elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through self.straight_through = gumbel_straight_through
@ -360,7 +415,7 @@ class VQAutoEncoder(nn.Module):
self.embed_dim, self.embed_dim,
self.gumbel_num_hiddens, self.gumbel_num_hiddens,
self.straight_through, self.straight_through,
self.kl_weight self.kl_weight,
) )
self.generator = Generator( self.generator = Generator(
self.nf, self.nf,
@ -368,20 +423,23 @@ class VQAutoEncoder(nn.Module):
self.ch_mult, self.ch_mult,
self.n_blocks, self.n_blocks,
self.resolution, self.resolution,
self.attn_resolutions self.attn_resolutions,
) )
if model_path is not None: if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu') chkpt = torch.load(model_path, map_location="cpu")
if 'params_ema' in chkpt: if "params_ema" in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) self.load_state_dict(
logger.info(f'vqgan is loaded from: {model_path} [params_ema]') torch.load(model_path, map_location="cpu")["params_ema"]
elif 'params' in chkpt: )
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
logger.info(f'vqgan is loaded from: {model_path} [params]') elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
logger.info(f"vqgan is loaded from: {model_path} [params]")
else: else:
raise ValueError(f'Wrong params!') raise ValueError(f"Wrong params!")
def forward(self, x): def forward(self, x):
x = self.encoder(x) x = self.encoder(x)
@ -390,46 +448,67 @@ class VQAutoEncoder(nn.Module):
return x, codebook_loss, quant_stats return x, codebook_loss, quant_stats
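The checkpoint branch above reads the file with torch.load and then reads it again inside each load_state_dict call. A minimal sketch of the same params_ema/params fallback that reuses the loaded dict; load_vqgan_weights is a hypothetical helper, not part of this diff:

    import torch

    def load_vqgan_weights(module, model_path):
        # Sketch only: same key fallback as above, but the checkpoint is read once.
        chkpt = torch.load(model_path, map_location="cpu")
        for key in ("params_ema", "params"):
            if key in chkpt:
                module.load_state_dict(chkpt[key])
                return key
        raise ValueError(f"neither 'params_ema' nor 'params' found in {model_path}")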
# patch based discriminator # patch based discriminator
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module): class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__() super().__init__()
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] layers = [
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, True),
]
ndf_mult = 1 ndf_mult = 1
ndf_mult_prev = 1 ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8) ndf_mult = min(2**n, 8)
layers += [ layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=2,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult), nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True) nn.LeakyReLU(0.2, True),
] ]
ndf_mult_prev = ndf_mult ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8) ndf_mult = min(2**n_layers, 8)
layers += [ layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), nn.Conv2d(
ndf * ndf_mult_prev,
ndf * ndf_mult,
kernel_size=4,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(ndf * ndf_mult), nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True) nn.LeakyReLU(0.2, True),
] ]
layers += [ layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
] # output 1 channel prediction map
self.main = nn.Sequential(*layers) self.main = nn.Sequential(*layers)
if model_path is not None: if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu') chkpt = torch.load(model_path, map_location="cpu")
if 'params_d' in chkpt: if "params_d" in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) self.load_state_dict(
elif 'params' in chkpt: torch.load(model_path, map_location="cpu")["params_d"]
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) )
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
)
else: else:
raise ValueError(f'Wrong params!') raise ValueError(f"Wrong params!")
def forward(self, x): def forward(self, x):
return self.main(x) return self.main(x)
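A hedged usage sketch for the patch-based discriminator above; the constructor arguments are its defaults and the input size is illustrative. The output is a coarse one-channel map with one real/fake logit per image patch rather than a single score per image:

    import torch

    disc = VQGANDiscriminator(nc=3, ndf=64, n_layers=4)  # assumes the class above is in scope
    fake = torch.randn(1, 3, 256, 256)
    patch_logits = disc(fake)  # shape (1, 1, H', W'): one logit per patch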

@ -0,0 +1,16 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .concepts_lib import HuggingFaceConceptsLibrary
from .diffusers_pipeline import (
ConditioningData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.ddim import DDIMSampler
from .diffusion.ksampler import KSampler
from .diffusion.plms import PLMSSampler
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
from .textual_inversion_manager import TextualInversionManager
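With these re-exports in place, downstream code imports from the package itself rather than from the old ldm.* paths; a small sketch using names from the list above:

    from invokeai.backend.stable_diffusion import (
        InvokeAIDiffuserComponent,
        KSampler,
        StableDiffusionGeneratorPipeline,
    )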

@ -1,21 +1,22 @@
from inspect import isfunction
import math import math
from inspect import isfunction
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import einsum, nn
from .diffusion import InvokeAICrossAttentionMixin
from .diffusionmodules.util import checkpoint
from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
from ldm.modules.diffusionmodules.util import checkpoint
def exists(val): def exists(val):
return val is not None return val is not None
def uniq(arr): def uniq(arr):
return{el: True for el in arr}.keys() return {el: True for el in arr}.keys()
def default(val, d): def default(val, d):
@ -47,19 +48,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = (
nn.Linear(dim, inner_dim), nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
nn.GELU() if not glu
) if not glu else GEGLU(dim, inner_dim) else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
) )
def forward(self, x): def forward(self, x):
@ -76,7 +76,9 @@ def zero_module(module):
def Normalize(in_channels): def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module): class LinearAttention(nn.Module):
@ -84,17 +86,21 @@ class LinearAttention(nn.Module):
super().__init__() super().__init__()
self.heads = heads self.heads = heads
hidden_dim = dim_head * heads hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1) self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1) k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v) context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q) out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out) return self.to_out(out)
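A short sketch of the single-projection q/k/v split used in the forward pass above; the shapes are illustrative, not taken from the diff:

    import torch
    from einops import rearrange

    heads, dim_head, h, w = 4, 16, 8, 8
    qkv = torch.randn(2, 3 * heads * dim_head, h, w)  # output of the 1x1 to_qkv conv
    q, k, v = rearrange(
        qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3
    )
    # q, k and v each have shape (2, heads, dim_head, h * w)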
@ -104,26 +110,18 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, self.q = torch.nn.Conv2d(
in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
kernel_size=1, )
stride=1, self.k = torch.nn.Conv2d(
padding=0) in_channels, in_channels, kernel_size=1, stride=1, padding=0
self.k = torch.nn.Conv2d(in_channels, )
in_channels, self.v = torch.nn.Conv2d(
kernel_size=1, in_channels, in_channels, kernel_size=1, stride=1, padding=0
stride=1, )
padding=0) self.proj_out = torch.nn.Conv2d(
self.v = torch.nn.Conv2d(in_channels, in_channels, in_channels, kernel_size=1, stride=1, padding=0
in_channels, )
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@ -133,43 +131,45 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_) v = self.v(h_)
# compute attention # compute attention
b,c,h,w = q.shape b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c') q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, 'b c h w -> b c (h w)') k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum('bij,bjk->bik', q, k) w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c)**(-0.5)) w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2) w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values # attend to values
v = rearrange(v, 'b c h w -> b c (h w)') v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, 'b i j -> b j i') w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum('bij,bjk->bik', v, w_) h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
return x+h_ return x + h_
def get_mem_free_total(device): def get_mem_free_total(device):
#only on cuda # only on cuda
if not torch.cuda.is_available(): if not torch.cuda.is_available():
return None return None
stats = torch.cuda.memory_stats(device) stats = torch.cuda.memory_stats(device)
mem_active = stats['active_bytes.all.current'] mem_active = stats["active_bytes.all.current"]
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(device) mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total return mem_free_total
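A hedged sketch of how a caller might use get_mem_free_total; the 2 GB threshold and the fallback are hypothetical and only illustrate the None-when-not-CUDA contract:

    import torch

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    free_bytes = get_mem_free_total(device)  # None unless CUDA is available
    if free_bytes is not None and free_bytes < 2 * 1024**3:
        pass  # e.g. fall back to a sliced, memory-efficient attention path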
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin): class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__() super().__init__()
InvokeAICrossAttentionMixin.__init__(self) InvokeAICrossAttentionMixin.__init__(self)
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5 self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@ -177,8 +177,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
nn.Dropout(dropout)
) )
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
@ -190,7 +189,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
v = self.to_v(context) v = self.to_v(context)
del context, x del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
# don't apply scale twice # don't apply scale twice
cached_scale = self.scale cached_scale = self.scale
@ -198,29 +197,45 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
r = self.get_invokeai_attention_mem_efficient(q, k, v) r = self.get_invokeai_attention_mem_efficient(q, k, v)
self.scale = cached_scale self.scale = cached_scale
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) hidden_states = rearrange(r, "(b h) n d -> b n (h d)", h=h)
return self.to_out(hidden_states) return self.to_out(hidden_states)
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
):
super().__init__() super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, self.attn2 = CrossAttention(
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim) self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint self.checkpoint = checkpoint
def forward(self, x, context=None): def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None): def _forward(self, x, context=None):
x = x.contiguous() if x.device.type == 'mps' else x x = x.contiguous() if x.device.type == "mps" else x
x += self.attn1(self.norm1(x.clone())) x += self.attn1(self.norm1(x.clone()))
x += self.attn2(self.norm2(x.clone()), context=context) x += self.attn2(self.norm2(x.clone()), context=context)
x += self.ff(self.norm3(x.clone())) x += self.ff(self.norm3(x.clone()))
@ -235,29 +250,31 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action. Then apply standard transformer action.
Finally, reshape to image Finally, reshape to image
""" """
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None): def __init__(
self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels) self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels, self.proj_in = nn.Conv2d(
inner_dim, in_channels, inner_dim, kernel_size=1, stride=1, padding=0
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
) )
self.proj_out = zero_module(nn.Conv2d(inner_dim, self.transformer_blocks = nn.ModuleList(
in_channels, [
kernel_size=1, BasicTransformerBlock(
stride=1, inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
padding=0)) )
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None): def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention
@ -265,9 +282,9 @@ class SpatialTransformer(nn.Module):
x_in = x x_in = x
x = self.norm(x) x = self.norm(x)
x = self.proj_in(x) x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c') x = rearrange(x, "b c h w -> b (h w) c")
for block in self.transformer_blocks: for block in self.transformer_blocks:
x = block(x, context=context) x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x) x = self.proj_out(x)
return x + x_in return x + x_in
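The reshape round-trip in the forward pass above can be checked in isolation; a minimal sketch with illustrative shapes:

    import torch
    from einops import rearrange

    x = torch.randn(2, 64, 8, 8)                   # b c h w feature map
    tokens = rearrange(x, "b c h w -> b (h w) c")  # (2, 64, 64): one token per spatial position
    back = rearrange(tokens, "b (h w) c -> b c h w", h=8, w=8)
    assert torch.equal(x, back)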

@ -1,16 +1,13 @@
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager from contextlib import contextmanager
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder from ..util import instantiate_from_config
from ldm.modules.distributions.distributions import ( from .diffusionmodules.model import Decoder, Encoder
DiagonalGaussianDistribution, from .distributions.distributions import DiagonalGaussianDistribution
)
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule): class VQModel(pl.LightningModule):
@ -22,7 +19,7 @@ class VQModel(pl.LightningModule):
embed_dim, embed_dim,
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=[],
image_key='image', image_key="image",
colorize_nlabels=None, colorize_nlabels=None,
monitor=None, monitor=None,
batch_resize_range=None, batch_resize_range=None,
@ -46,27 +43,23 @@ class VQModel(pl.LightningModule):
remap=remap, remap=remap,
sane_index_shape=sane_index_shape, sane_index_shape=sane_index_shape,
) )
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d( self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
embed_dim, ddconfig['z_channels'], 1
)
if colorize_nlabels is not None: if colorize_nlabels is not None:
assert type(colorize_nlabels) == int assert type(colorize_nlabels) == int
self.register_buffer( self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
self.batch_resize_range = batch_resize_range self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None: if self.batch_resize_range is not None:
print( print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.' f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
) )
self.use_ema = use_ema self.use_ema = use_ema
if self.use_ema: if self.use_ema:
self.model_ema = LitEma(self) self.model_ema = LitEma(self)
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.') print(f">> Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@ -79,30 +72,30 @@ class VQModel(pl.LightningModule):
self.model_ema.store(self.parameters()) self.model_ema.store(self.parameters())
self.model_ema.copy_to(self) self.model_ema.copy_to(self)
if context is not None: if context is not None:
print(f'{context}: Switched to EMA weights') print(f"{context}: Switched to EMA weights")
try: try:
yield None yield None
finally: finally:
if self.use_ema: if self.use_ema:
self.model_ema.restore(self.parameters()) self.model_ema.restore(self.parameters())
if context is not None: if context is not None:
print(f'{context}: Restored training weights') print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()): def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict'] sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print('Deleting key {} from state_dict.'.format(k)) print("Deleting key {} from state_dict.".format(k))
del sd[k] del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) missing, unexpected = self.load_state_dict(sd, strict=False)
print( print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
) )
if len(missing) > 0: if len(missing) > 0:
print(f'Missing Keys: {missing}') print(f"Missing Keys: {missing}")
print(f'Unexpected Keys: {unexpected}') print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs): def on_train_batch_end(self, *args, **kwargs):
if self.use_ema: if self.use_ema:
@ -140,11 +133,7 @@ class VQModel(pl.LightningModule):
x = batch[k] x = batch[k]
if len(x.shape) == 3: if len(x.shape) == 3:
x = x[..., None] x = x[..., None]
x = ( x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
if self.batch_resize_range is not None: if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0] lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1] upper_size = self.batch_resize_range[1]
@ -156,7 +145,7 @@ class VQModel(pl.LightningModule):
np.arange(lower_size, upper_size + 16, 16) np.arange(lower_size, upper_size + 16, 16)
) )
if new_resize != x.shape[2]: if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode='bicubic') x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach() x = x.detach()
return x return x
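get_input converts the data loader's channels-last batch to the channels-first layout the convolutional encoder expects; a small sketch with a hypothetical batch:

    import torch

    batch = {"image": torch.rand(4, 256, 256, 3)}  # B x H x W x C, as loaded from disk
    x = batch["image"].permute(0, 3, 1, 2)         # -> B x C x H x W
    x = x.to(memory_format=torch.contiguous_format).float()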
@ -175,7 +164,7 @@ class VQModel(pl.LightningModule):
optimizer_idx, optimizer_idx,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='train', split="train",
predicted_indices=ind, predicted_indices=ind,
) )
@ -197,7 +186,7 @@ class VQModel(pl.LightningModule):
optimizer_idx, optimizer_idx,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='train', split="train",
) )
self.log_dict( self.log_dict(
log_dict_disc, log_dict_disc,
@ -211,12 +200,10 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx) log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope(): with self.ema_scope():
log_dict_ema = self._validation_step( log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
batch, batch_idx, suffix='_ema'
)
return log_dict return log_dict
def _validation_step(self, batch, batch_idx, suffix=''): def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key) x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True) xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss( aeloss, log_dict_ae = self.loss(
@ -226,7 +213,7 @@ class VQModel(pl.LightningModule):
0, 0,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='val' + suffix, split="val" + suffix,
predicted_indices=ind, predicted_indices=ind,
) )
@ -237,12 +224,12 @@ class VQModel(pl.LightningModule):
1, 1,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='val' + suffix, split="val" + suffix,
predicted_indices=ind, predicted_indices=ind,
) )
rec_loss = log_dict_ae[f'val{suffix}/rec_loss'] rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log( self.log(
f'val{suffix}/rec_loss', f"val{suffix}/rec_loss",
rec_loss, rec_loss,
prog_bar=True, prog_bar=True,
logger=True, logger=True,
@ -251,7 +238,7 @@ class VQModel(pl.LightningModule):
sync_dist=True, sync_dist=True,
) )
self.log( self.log(
f'val{suffix}/aeloss', f"val{suffix}/aeloss",
aeloss, aeloss,
prog_bar=True, prog_bar=True,
logger=True, logger=True,
@ -259,8 +246,8 @@ class VQModel(pl.LightningModule):
on_epoch=True, on_epoch=True,
sync_dist=True, sync_dist=True,
) )
if version.parse(pl.__version__) >= version.parse('1.4.0'): if version.parse(pl.__version__) >= version.parse("1.4.0"):
del log_dict_ae[f'val{suffix}/rec_loss'] del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae) self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc) self.log_dict(log_dict_disc)
return self.log_dict return self.log_dict
@ -268,8 +255,8 @@ class VQModel(pl.LightningModule):
def configure_optimizers(self): def configure_optimizers(self):
lr_d = self.learning_rate lr_d = self.learning_rate
lr_g = self.lr_g_factor * self.learning_rate lr_g = self.lr_g_factor * self.learning_rate
print('lr_d', lr_d) print("lr_d", lr_d)
print('lr_g', lr_g) print("lr_g", lr_g)
opt_ae = torch.optim.Adam( opt_ae = torch.optim.Adam(
list(self.encoder.parameters()) list(self.encoder.parameters())
+ list(self.decoder.parameters()) + list(self.decoder.parameters())
@ -286,21 +273,17 @@ class VQModel(pl.LightningModule):
if self.scheduler_config is not None: if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config) scheduler = instantiate_from_config(self.scheduler_config)
print('Setting up LambdaLR scheduler...') print("Setting up LambdaLR scheduler...")
scheduler = [ scheduler = [
{ {
'scheduler': LambdaLR( "scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
opt_ae, lr_lambda=scheduler.schedule "interval": "step",
), "frequency": 1,
'interval': 'step',
'frequency': 1,
}, },
{ {
'scheduler': LambdaLR( "scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
opt_disc, lr_lambda=scheduler.schedule "interval": "step",
), "frequency": 1,
'interval': 'step',
'frequency': 1,
}, },
] ]
return [opt_ae, opt_disc], scheduler return [opt_ae, opt_disc], scheduler
@ -314,7 +297,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key) x = self.get_input(batch, self.image_key)
x = x.to(self.device) x = x.to(self.device)
if only_inputs: if only_inputs:
log['inputs'] = x log["inputs"] = x
return log return log
xrec, _ = self(x) xrec, _ = self(x)
if x.shape[1] > 3: if x.shape[1] > 3:
@ -322,22 +305,20 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3 assert xrec.shape[1] > 3
x = self.to_rgb(x) x = self.to_rgb(x)
xrec = self.to_rgb(xrec) xrec = self.to_rgb(xrec)
log['inputs'] = x log["inputs"] = x
log['reconstructions'] = xrec log["reconstructions"] = xrec
if plot_ema: if plot_ema:
with self.ema_scope(): with self.ema_scope():
xrec_ema, _ = self(x) xrec_ema, _ = self(x)
if x.shape[1] > 3: if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema) xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema log["reconstructions_ema"] = xrec_ema
return log return log
def to_rgb(self, x): def to_rgb(self, x):
assert self.image_key == 'segmentation' assert self.image_key == "segmentation"
if not hasattr(self, 'colorize'): if not hasattr(self, "colorize"):
self.register_buffer( self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize) x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x return x
@ -372,7 +353,7 @@ class AutoencoderKL(pl.LightningModule):
embed_dim, embed_dim,
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=[],
image_key='image', image_key="image",
colorize_nlabels=None, colorize_nlabels=None,
monitor=None, monitor=None,
): ):
@ -381,34 +362,28 @@ class AutoencoderKL(pl.LightningModule):
self.encoder = Encoder(**ddconfig) self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig) self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig) self.loss = instantiate_from_config(lossconfig)
assert ddconfig['double_z'] assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d( self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
2 * ddconfig['z_channels'], 2 * embed_dim, 1 self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
self.embed_dim = embed_dim self.embed_dim = embed_dim
if colorize_nlabels is not None: if colorize_nlabels is not None:
assert type(colorize_nlabels) == int assert type(colorize_nlabels) == int
self.register_buffer( self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()): def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict'] sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print('Deleting key {} from state_dict.'.format(k)) print("Deleting key {} from state_dict.".format(k))
del sd[k] del sd[k]
self.load_state_dict(sd, strict=False) self.load_state_dict(sd, strict=False)
print(f'Restored from {path}') print(f"Restored from {path}")
def encode(self, x): def encode(self, x):
h = self.encoder(x) h = self.encoder(x)
@ -434,11 +409,7 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k] x = batch[k]
if len(x.shape) == 3: if len(x.shape) == 3:
x = x[..., None] x = x[..., None]
x = ( x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
return x return x
def training_step(self, batch, batch_idx, optimizer_idx): def training_step(self, batch, batch_idx, optimizer_idx):
@ -454,10 +425,10 @@ class AutoencoderKL(pl.LightningModule):
optimizer_idx, optimizer_idx,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='train', split="train",
) )
self.log( self.log(
'aeloss', "aeloss",
aeloss, aeloss,
prog_bar=True, prog_bar=True,
logger=True, logger=True,
@ -482,11 +453,11 @@ class AutoencoderKL(pl.LightningModule):
optimizer_idx, optimizer_idx,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='train', split="train",
) )
self.log( self.log(
'discloss', "discloss",
discloss, discloss,
prog_bar=True, prog_bar=True,
logger=True, logger=True,
@ -512,7 +483,7 @@ class AutoencoderKL(pl.LightningModule):
0, 0,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='val', split="val",
) )
discloss, log_dict_disc = self.loss( discloss, log_dict_disc = self.loss(
@ -522,10 +493,10 @@ class AutoencoderKL(pl.LightningModule):
1, 1,
self.global_step, self.global_step,
last_layer=self.get_last_layer(), last_layer=self.get_last_layer(),
split='val', split="val",
) )
self.log('val/rec_loss', log_dict_ae['val/rec_loss']) self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae) self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc) self.log_dict(log_dict_disc)
return self.log_dict return self.log_dict
@ -560,17 +531,15 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3 assert xrec.shape[1] > 3
x = self.to_rgb(x) x = self.to_rgb(x)
xrec = self.to_rgb(xrec) xrec = self.to_rgb(xrec)
log['samples'] = self.decode(torch.randn_like(posterior.sample())) log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec log["reconstructions"] = xrec
log['inputs'] = x log["inputs"] = x
return log return log
def to_rgb(self, x): def to_rgb(self, x):
assert self.image_key == 'segmentation' assert self.image_key == "segmentation"
if not hasattr(self, 'colorize'): if not hasattr(self, "colorize"):
self.register_buffer( self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize) x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x return x

@ -8,32 +8,50 @@ import os
import re import re
import traceback import traceback
from typing import Callable from typing import Callable
from urllib import request, error as ul_error from urllib import error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi from urllib import request
from ldm.invoke.globals import Globals
from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
ModelSearchArguments,
hf_hub_url,
)
from invokeai.backend.globals import Globals
class HuggingFaceConceptsLibrary(object): class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None): def __init__(self, root=None):
''' """
Initialize the Concepts object. May optionally pass a root directory. Initialize the Concepts object. May optionally pass a root directory.
''' """
self.root = root or Globals.root self.root = root or Globals.root
self.hf_api = HfApi() self.hf_api = HfApi()
self.local_concepts = dict() self.local_concepts = dict()
self.concept_list = None self.concept_list = None
self.concepts_loaded = dict() self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile('(<[\w\- >]+>)') # trigger is slightly less restrictive than HF concept name self.match_trigger = re.compile(
self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_- "(<[\w\- >]+>)"
) # trigger is slightly less restrictive than HF concept name
self.match_concept = re.compile(
"<([\w\-]+)>"
) # HF concept name can only contain A-Za-z0-9_-
def list_concepts(self)->list: def list_concepts(self) -> list:
''' """
Return a list of all the concepts by name, without the 'sd-concepts-library' part. Return a list of all the concepts by name, without the 'sd-concepts-library' part.
Also adds local concepts in invokeai/embeddings folder. Also adds local concepts in invokeai/embeddings folder.
''' """
local_concepts_now = self.get_local_concepts(os.path.join(self.root, 'embeddings')) local_concepts_now = self.get_local_concepts(
local_concepts_to_add = set(local_concepts_now).difference(set(self.local_concepts)) os.path.join(self.root, "embeddings")
)
local_concepts_to_add = set(local_concepts_now).difference(
set(self.local_concepts)
)
self.local_concepts.update(local_concepts_now) self.local_concepts.update(local_concepts_now)
if self.concept_list is not None: if self.concept_list is not None:
@ -43,83 +61,96 @@ class HuggingFaceConceptsLibrary(object):
return self.concept_list return self.concept_list
else: else:
try: try:
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/')) models = self.hf_api.list_models(
self.concept_list = [a.id.split('/')[1] for a in models] filter=ModelFilter(model_name="sd-concepts-library/")
)
self.concept_list = [a.id.split("/")[1] for a in models]
# when init, add all in dir. when not init, add only concepts added between init and now # when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add)) self.concept_list.extend(list(local_concepts_to_add))
except Exception as e: except Exception as e:
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.') print(
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.') f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
)
print(
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list return self.concept_list
def get_concept_model_path(self, concept_name:str)->str: def get_concept_model_path(self, concept_name: str) -> str:
''' """
Returns the path to the 'learned_embeds.bin' file in Returns the path to the 'learned_embeds.bin' file in
the named concept. Returns None if invalid or cannot the named concept. Returns None if invalid or cannot
be downloaded. be downloaded.
''' """
if not concept_name in self.list_concepts(): if not concept_name in self.list_concepts():
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.') print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None return None
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin') return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
def concept_to_trigger(self, concept_name:str)->str: def concept_to_trigger(self, concept_name: str) -> str:
''' """
Given a concept name returns its trigger by looking in the Given a concept name returns its trigger by looking in the
"token_identifier.txt" file. "token_identifier.txt" file.
''' """
if concept_name in self.triggers: if concept_name in self.triggers:
return self.triggers[concept_name] return self.triggers[concept_name]
elif self.concept_is_local(concept_name): elif self.concept_is_local(concept_name):
trigger = f'<{concept_name}>' trigger = f"<{concept_name}>"
self.triggers[concept_name] = trigger self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name self.concept_names[trigger] = concept_name
return trigger return trigger
file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True) file = self.get_concept_file(
concept_name, "token_identifier.txt", local_only=True
)
if not file: if not file:
return None return None
with open(file,'r') as f: with open(file, "r") as f:
trigger = f.readline() trigger = f.readline()
trigger = trigger.strip() trigger = trigger.strip()
self.triggers[concept_name] = trigger self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name self.concept_names[trigger] = concept_name
return trigger return trigger
def trigger_to_concept(self, trigger:str)->str: def trigger_to_concept(self, trigger: str) -> str:
''' """
Given a trigger phrase, maps it to the concept library name. Given a trigger phrase, maps it to the concept library name.
Only works if concept_to_trigger() has previously been called Only works if concept_to_trigger() has previously been called
on this library. There needs to be a persistent database for on this library. There needs to be a persistent database for
this. this.
''' """
concept = self.concept_names.get(trigger,None) concept = self.concept_names.get(trigger, None)
return f'<{concept}>' if concept else f'{trigger}' return f"<{concept}>" if concept else f"{trigger}"
def replace_triggers_with_concepts(self, prompt:str)->str: def replace_triggers_with_concepts(self, prompt: str) -> str:
''' """
Given a prompt string that contains <trigger> tags, replace these Given a prompt string that contains <trigger> tags, replace these
tags with the concept name. The reason for this is so that the tags with the concept name. The reason for this is so that the
concept names get stored in the prompt metadata. There is no concept names get stored in the prompt metadata. There is no
controlling of colliding triggers in the SD library, so it is controlling of colliding triggers in the SD library, so it is
better to store the concept name (unique) than the concept trigger better to store the concept name (unique) than the concept trigger
(not necessarily unique!) (not necessarily unique!)
''' """
if not prompt: if not prompt:
return prompt return prompt
triggers = self.match_trigger.findall(prompt) triggers = self.match_trigger.findall(prompt)
if not triggers: if not triggers:
return prompt return prompt
def do_replace(match)->str: def do_replace(match) -> str:
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>' return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
return self.match_trigger.sub(do_replace, prompt) return self.match_trigger.sub(do_replace, prompt)
def replace_concepts_with_triggers(self, def replace_concepts_with_triggers(
prompt:str, self,
load_concepts_callback: Callable[[list], any], prompt: str,
excluded_tokens:list[str])->str: load_concepts_callback: Callable[[list], any],
''' excluded_tokens: list[str],
) -> str:
"""
Given a prompt string that contains `<concept_name>` tags, replace Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger. these tags with the appropriate trigger.
@ -128,20 +159,30 @@ class HuggingFaceConceptsLibrary(object):
`excluded_tokens` are any tokens that should not be replaced, typically because they `excluded_tokens` are any tokens that should not be replaced, typically because they
are trigger tokens from a locally-loaded embedding. are trigger tokens from a locally-loaded embedding.
''' """
concepts = self.match_concept.findall(prompt) concepts = self.match_concept.findall(prompt)
if not concepts: if not concepts:
return prompt return prompt
load_concepts_callback(concepts) load_concepts_callback(concepts)
def do_replace(match)->str: def do_replace(match) -> str:
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens: if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
return f'<{match.group(1)}>' return f"<{match.group(1)}>"
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>' return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
return self.match_concept.sub(do_replace, prompt) return self.match_concept.sub(do_replace, prompt)
def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str: def get_concept_file(
if not (self.concept_is_downloaded(concept_name) or self.concept_is_local(concept_name) or local_only): self,
concept_name: str,
file_name: str = "learned_embeds.bin",
local_only: bool = False,
) -> str:
if not (
self.concept_is_downloaded(concept_name)
or self.concept_is_local(concept_name)
or local_only
):
self.download_concept(concept_name) self.download_concept(concept_name)
# get local path in invokeai/embeddings if local concept # get local path in invokeai/embeddings if local concept
@ -153,19 +194,19 @@ class HuggingFaceConceptsLibrary(object):
path = os.path.join(concept_path, file_name) path = os.path.join(concept_path, file_name)
return path if os.path.exists(path) else None return path if os.path.exists(path) else None
def concept_is_local(self, concept_name)->bool: def concept_is_local(self, concept_name) -> bool:
return concept_name in self.local_concepts return concept_name in self.local_concepts
def concept_is_downloaded(self, concept_name)->bool: def concept_is_downloaded(self, concept_name) -> bool:
concept_directory = self._concept_path(concept_name) concept_directory = self._concept_path(concept_name)
return os.path.exists(concept_directory) return os.path.exists(concept_directory)
def download_concept(self,concept_name)->bool: def download_concept(self, concept_name) -> bool:
repo_id = self._concept_id(concept_name) repo_id = self._concept_id(concept_name)
dest = self._concept_path(concept_name) dest = self._concept_path(concept_name)
access_token = HfFolder.get_token() access_token = HfFolder.get_token()
header = [("Authorization", f'Bearer {access_token}')] if access_token else [] header = [("Authorization", f"Bearer {access_token}")] if access_token else []
opener = request.build_opener() opener = request.build_opener()
opener.addheaders = header opener.addheaders = header
request.install_opener(opener) request.install_opener(opener)
@ -174,45 +215,59 @@ class HuggingFaceConceptsLibrary(object):
succeeded = True succeeded = True
bytes = 0 bytes = 0
def tally_download_size(chunk, size, total): def tally_download_size(chunk, size, total):
nonlocal bytes nonlocal bytes
if chunk==0: if chunk == 0:
bytes += total bytes += total
print(f'>> Downloading {repo_id}...',end='') print(f">> Downloading {repo_id}...", end="")
try: try:
for file in ('README.md','learned_embeds.bin','token_identifier.txt','type_of_concept.txt'): for file in (
"README.md",
"learned_embeds.bin",
"token_identifier.txt",
"type_of_concept.txt",
):
url = hf_hub_url(repo_id, file) url = hf_hub_url(repo_id, file)
request.urlretrieve(url, os.path.join(dest,file),reporthook=tally_download_size) request.urlretrieve(
url, os.path.join(dest, file), reporthook=tally_download_size
)
except ul_error.HTTPError as e: except ul_error.HTTPError as e:
if e.code==404: if e.code == 404:
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.') print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
)
else: else:
print(f'Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)') print(
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
)
os.rmdir(dest) os.rmdir(dest)
return False return False
except ul_error.URLError as e: except ul_error.URLError as e:
print(f'ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept.') print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest) os.rmdir(dest)
return False return False
print('...{:.2f}Kb'.format(bytes/1024)) print("...{:.2f}Kb".format(bytes / 1024))
return succeeded return succeeded
def _concept_id(self, concept_name:str)->str: def _concept_id(self, concept_name: str) -> str:
return f'sd-concepts-library/{concept_name}' return f"sd-concepts-library/{concept_name}"
def _concept_path(self, concept_name:str)->str: def _concept_path(self, concept_name: str) -> str:
return os.path.join(self.root,'models','sd-concepts-library',concept_name) return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
def _concept_local_path(self, concept_name:str)->str: def _concept_local_path(self, concept_name: str) -> str:
filename = self.local_concepts[concept_name] filename = self.local_concepts[concept_name]
return os.path.join(self.root,'embeddings',filename) return os.path.join(self.root, "embeddings", filename)
def get_local_concepts(self, loc_dir:str): def get_local_concepts(self, loc_dir: str):
locs_dic = dict() locs_dic = dict()
if os.path.isdir(loc_dir): if os.path.isdir(loc_dir):
for file in os.listdir(loc_dir): for file in os.listdir(loc_dir):
f = os.path.splitext(file) f = os.path.splitext(file)
if f[1] == '.bin' or f[1] == '.pt': if f[1] == ".bin" or f[1] == ".pt":
locs_dic[f[0]] = file locs_dic[f[0]] = file
return locs_dic return locs_dic
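A hedged end-to-end sketch of the trigger/concept round-trip implemented above; the trigger is a placeholder, and unknown names pass through unchanged:

    lib = HuggingFaceConceptsLibrary()                   # uses Globals.root by default
    prompt = "a castle in the style of <some-concept>"   # <some-concept> is hypothetical
    stored = lib.replace_triggers_with_concepts(prompt)  # triggers -> concept names, for metadata
    restored = lib.replace_concepts_with_triggers(
        stored, load_concepts_callback=lambda names: None, excluded_tokens=[]
    )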

@ -1,10 +1,6 @@
from abc import abstractmethod from abc import abstractmethod
from torch.utils.data import (
Dataset, from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
ConcatDataset,
ChainDataset,
IterableDataset,
)
class Txt2ImgIterableBaseDataset(IterableDataset): class Txt2ImgIterableBaseDataset(IterableDataset):
@ -19,9 +15,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
self.sample_ids = valid_ids self.sample_ids = valid_ids
self.size = size self.size = size
print( print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
def __len__(self): def __len__(self):
return self.num_records return self.num_records

@ -1,31 +1,32 @@
import os, yaml, pickle, shutil, tarfile, glob import glob
import cv2 import os
import albumentations import pickle
import PIL import shutil
import numpy as np import tarfile
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import albumentations
import cv2
import numpy as np
import PIL
import taming.data.utils as tdu import taming.data.utils as tdu
import torchvision.transforms.functional as TF
import yaml
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from omegaconf import OmegaConf
from PIL import Image
from taming.data.imagenet import ( from taming.data.imagenet import (
str_to_indices, ImagePaths,
give_synsets_from_indices,
download, download,
give_synsets_from_indices,
retrieve, retrieve,
str_to_indices,
) )
from taming.data.imagenet import ImagePaths from torch.utils.data import Dataset, Subset
from tqdm import tqdm
from ldm.modules.image_degradation import (
degradation_fn_bsr,
degradation_fn_bsr_light,
)
def synset2idx(path_to_yaml='data/index_synset.yaml'): def synset2idx(path_to_yaml="data/index_synset.yaml"):
with open(path_to_yaml) as f: with open(path_to_yaml) as f:
di2s = yaml.load(f) di2s = yaml.load(f)
return dict((v, k) for k, v in di2s.items()) return dict((v, k) for k, v in di2s.items())
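synset2idx inverts the index-to-synset mapping stored in the YAML file; a small sketch with illustrative entries (the real file holds 1000 of them):

    import yaml

    # data/index_synset.yaml maps integer class indices to WordNet synset ids, e.g.
    #   0: n01440764
    #   1: n01443537
    with open("data/index_synset.yaml") as f:
        di2s = yaml.safe_load(f)
    synset_to_index = {v: k for k, v in di2s.items()}  # e.g. {"n01440764": 0, ...}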
@ -36,9 +37,7 @@ class ImageNetBase(Dataset):
self.config = config or OmegaConf.create() self.config = config or OmegaConf.create()
if not type(self.config) == dict: if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config) self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get( self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
'keep_orig_class_label', False
)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare() self._prepare()
self._prepare_synset_to_human() self._prepare_synset_to_human()
@ -58,21 +57,19 @@ class ImageNetBase(Dataset):
def _filter_relpaths(self, relpaths): def _filter_relpaths(self, relpaths):
ignore = set( ignore = set(
[ [
'n06596364_9591.JPEG', "n06596364_9591.JPEG",
] ]
) )
relpaths = [ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore if "sub_indices" in self.config:
] indices = str_to_indices(self.config["sub_indices"])
if 'sub_indices' in self.config:
indices = str_to_indices(self.config['sub_indices'])
synsets = give_synsets_from_indices( synsets = give_synsets_from_indices(
indices, path_to_yaml=self.idx2syn indices, path_to_yaml=self.idx2syn
) # returns a list of strings ) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = [] files = []
for rpath in relpaths: for rpath in relpaths:
syn = rpath.split('/')[0] syn = rpath.split("/")[0]
if syn in synsets: if syn in synsets:
files.append(rpath) files.append(rpath)
return files return files
@ -81,8 +78,8 @@ class ImageNetBase(Dataset):
def _prepare_synset_to_human(self): def _prepare_synset_to_human(self):
SIZE = 2655750 SIZE = 2655750
URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1' URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, 'synset_human.txt') self.human_dict = os.path.join(self.root, "synset_human.txt")
if ( if (
not os.path.exists(self.human_dict) not os.path.exists(self.human_dict)
or not os.path.getsize(self.human_dict) == SIZE or not os.path.getsize(self.human_dict) == SIZE
@ -90,64 +87,62 @@ class ImageNetBase(Dataset):
download(URL, self.human_dict) download(URL, self.human_dict)
def _prepare_idx_to_synset(self): def _prepare_idx_to_synset(self):
URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1' URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, 'index_synset.yaml') self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if not os.path.exists(self.idx2syn): if not os.path.exists(self.idx2syn):
download(URL, self.idx2syn) download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self): def _prepare_human_to_integer_label(self):
URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1' URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join( self.human2integer = os.path.join(
self.root, 'imagenet1000_clsidx_to_labels.txt' self.root, "imagenet1000_clsidx_to_labels.txt"
) )
if not os.path.exists(self.human2integer): if not os.path.exists(self.human2integer):
download(URL, self.human2integer) download(URL, self.human2integer)
with open(self.human2integer, 'r') as f: with open(self.human2integer, "r") as f:
lines = f.read().splitlines() lines = f.read().splitlines()
assert len(lines) == 1000 assert len(lines) == 1000
self.human2integer_dict = dict() self.human2integer_dict = dict()
for line in lines: for line in lines:
value, key = line.split(':') value, key = line.split(":")
self.human2integer_dict[key] = int(value) self.human2integer_dict[key] = int(value)
def _load(self): def _load(self):
with open(self.txt_filelist, 'r') as f: with open(self.txt_filelist, "r") as f:
self.relpaths = f.read().splitlines() self.relpaths = f.read().splitlines()
l1 = len(self.relpaths) l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths) self.relpaths = self._filter_relpaths(self.relpaths)
print( print(
'Removed {} files from filelist during filtering.'.format( "Removed {} files from filelist during filtering.".format(
l1 - len(self.relpaths) l1 - len(self.relpaths)
) )
) )
self.synsets = [p.split('/')[0] for p in self.relpaths] self.synsets = [p.split("/")[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets) unique_synsets = np.unique(self.synsets)
class_dict = dict( class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
(synset, i) for i, synset in enumerate(unique_synsets)
)
if not self.keep_orig_class_label: if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets] self.class_labels = [class_dict[s] for s in self.synsets]
else: else:
self.class_labels = [self.synset2idx[s] for s in self.synsets] self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, 'r') as f: with open(self.human_dict, "r") as f:
human_dict = f.read().splitlines() human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict) human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets] self.human_labels = [human_dict[s] for s in self.synsets]
labels = { labels = {
'relpath': np.array(self.relpaths), "relpath": np.array(self.relpaths),
'synsets': np.array(self.synsets), "synsets": np.array(self.synsets),
'class_label': np.array(self.class_labels), "class_label": np.array(self.class_labels),
'human_label': np.array(self.human_labels), "human_label": np.array(self.human_labels),
} }
if self.process_images: if self.process_images:
self.size = retrieve(self.config, 'size', default=256) self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths( self.data = ImagePaths(
self.abspaths, self.abspaths,
labels=labels, labels=labels,
@ -159,11 +154,11 @@ class ImageNetBase(Dataset):
class ImageNetTrain(ImageNetBase): class ImageNetTrain(ImageNetBase):
NAME = 'ILSVRC2012_train' NAME = "ILSVRC2012_train"
URL = 'http://www.image-net.org/challenges/LSVRC/2012/' URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2' AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
FILES = [ FILES = [
'ILSVRC2012_img_train.tar', "ILSVRC2012_img_train.tar",
] ]
SIZES = [ SIZES = [
147897477120, 147897477120,
@ -178,20 +173,18 @@ class ImageNetTrain(ImageNetBase):
if self.data_root: if self.data_root:
self.root = os.path.join(self.data_root, self.NAME) self.root = os.path.join(self.data_root, self.NAME)
else: else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, 'data') self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, 'filelist.txt') self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 1281167 self.expected_length = 1281167
self.random_crop = retrieve( self.random_crop = retrieve(
self.config, 'ImageNetTrain/random_crop', default=True self.config, "ImageNetTrain/random_crop", default=True
) )
if not tdu.is_prepared(self.root): if not tdu.is_prepared(self.root):
# prep # prep
print('Preparing dataset {} in {}'.format(self.NAME, self.root)) print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir datadir = self.datadir
if not os.path.exists(datadir): if not os.path.exists(datadir):
@ -205,37 +198,37 @@ class ImageNetTrain(ImageNetBase):
atpath = at.get(self.AT_HASH, datastore=self.root) atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path assert atpath == path
print('Extracting {} to {}'.format(path, datadir)) print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True) os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, 'r:') as tar: with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir) tar.extractall(path=datadir)
print('Extracting sub-tars.') print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar'))) subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
for subpath in tqdm(subpaths): for subpath in tqdm(subpaths):
subdir = subpath[: -len('.tar')] subdir = subpath[: -len(".tar")]
os.makedirs(subdir, exist_ok=True) os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, 'r:') as tar: with tarfile.open(subpath, "r:") as tar:
tar.extractall(path=subdir) tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG')) filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist) filelist = sorted(filelist)
filelist = '\n'.join(filelist) + '\n' filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, 'w') as f: with open(self.txt_filelist, "w") as f:
f.write(filelist) f.write(filelist)
tdu.mark_prepared(self.root) tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase): class ImageNetValidation(ImageNetBase):
NAME = 'ILSVRC2012_validation' NAME = "ILSVRC2012_validation"
URL = 'http://www.image-net.org/challenges/LSVRC/2012/' URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5' AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1' VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
FILES = [ FILES = [
'ILSVRC2012_img_val.tar', "ILSVRC2012_img_val.tar",
'validation_synset.txt', "validation_synset.txt",
] ]
SIZES = [ SIZES = [
6744924160, 6744924160,
@ -251,19 +244,17 @@ class ImageNetValidation(ImageNetBase):
if self.data_root: if self.data_root:
self.root = os.path.join(self.data_root, self.NAME) self.root = os.path.join(self.data_root, self.NAME)
else: else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 50000 self.expected_length = 50000
self.random_crop = retrieve( self.random_crop = retrieve(
self.config, 'ImageNetValidation/random_crop', default=False self.config, "ImageNetValidation/random_crop", default=False
) )
if not tdu.is_prepared(self.root): if not tdu.is_prepared(self.root):
# prep # prep
print('Preparing dataset {} in {}'.format(self.NAME, self.root)) print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir datadir = self.datadir
if not os.path.exists(datadir): if not os.path.exists(datadir):
@ -277,9 +268,9 @@ class ImageNetValidation(ImageNetBase):
atpath = at.get(self.AT_HASH, datastore=self.root) atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path assert atpath == path
print('Extracting {} to {}'.format(path, datadir)) print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True) os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, 'r:') as tar: with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir) tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1]) vspath = os.path.join(self.root, self.FILES[1])
@ -289,11 +280,11 @@ class ImageNetValidation(ImageNetBase):
): ):
download(self.VS_URL, vspath) download(self.VS_URL, vspath)
with open(vspath, 'r') as f: with open(vspath, "r") as f:
synset_dict = f.read().splitlines() synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict) synset_dict = dict(line.split() for line in synset_dict)
print('Reorganizing into synset folders') print("Reorganizing into synset folders")
synsets = np.unique(list(synset_dict.values())) synsets = np.unique(list(synset_dict.values()))
for s in synsets: for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True) os.makedirs(os.path.join(datadir, s), exist_ok=True)
@ -302,11 +293,11 @@ class ImageNetValidation(ImageNetBase):
dst = os.path.join(datadir, v) dst = os.path.join(datadir, v)
shutil.move(src, dst) shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG')) filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist) filelist = sorted(filelist)
filelist = '\n'.join(filelist) + '\n' filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, 'w') as f: with open(self.txt_filelist, "w") as f:
f.write(filelist) f.write(filelist)
tdu.mark_prepared(self.root) tdu.mark_prepared(self.root)
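A hedged sketch of the cache-root fallback both ImageNet classes use when no data_root is supplied; default_dataset_root is a hypothetical helper, not part of this commit:

import os
from typing import Optional

def default_dataset_root(name: str, data_root: Optional[str] = None) -> str:
    if data_root:
        return os.path.join(data_root, name)
    cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
    return os.path.join(cachedir, "autoencoders/data", name)

print(default_dataset_root("ILSVRC2012_validation"))
# e.g. /home/user/.cache/autoencoders/data/ILSVRC2012_validation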
@ -356,32 +347,28 @@ class ImageNetSR(Dataset):
False  # gets reset later in case interp_op is from pillow
) )
if degradation == "bsrgan":
    self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
elif degradation == "bsrgan_light":
    self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
else: else:
interpolation_fn = { interpolation_fn = {
'cv_nearest': cv2.INTER_NEAREST, "cv_nearest": cv2.INTER_NEAREST,
'cv_bilinear': cv2.INTER_LINEAR, "cv_bilinear": cv2.INTER_LINEAR,
'cv_bicubic': cv2.INTER_CUBIC, "cv_bicubic": cv2.INTER_CUBIC,
'cv_area': cv2.INTER_AREA, "cv_area": cv2.INTER_AREA,
'cv_lanczos': cv2.INTER_LANCZOS4, "cv_lanczos": cv2.INTER_LANCZOS4,
'pil_nearest': PIL.Image.NEAREST, "pil_nearest": PIL.Image.NEAREST,
'pil_bilinear': PIL.Image.BILINEAR, "pil_bilinear": PIL.Image.BILINEAR,
'pil_bicubic': PIL.Image.BICUBIC, "pil_bicubic": PIL.Image.BICUBIC,
'pil_box': PIL.Image.BOX, "pil_box": PIL.Image.BOX,
'pil_hamming': PIL.Image.HAMMING, "pil_hamming": PIL.Image.HAMMING,
'pil_lanczos': PIL.Image.LANCZOS, "pil_lanczos": PIL.Image.LANCZOS,
}[degradation] }[degradation]
self.pil_interpolation = degradation.startswith('pil_') self.pil_interpolation = degradation.startswith("pil_")
if self.pil_interpolation: if self.pil_interpolation:
self.degradation_process = partial( self.degradation_process = partial(
@ -400,10 +387,10 @@ class ImageNetSR(Dataset):
def __getitem__(self, i): def __getitem__(self, i):
example = self.base[i] example = self.base[i]
image = Image.open(example['file_path_']) image = Image.open(example["file_path_"])
if not image.mode == 'RGB': if not image.mode == "RGB":
image = image.convert('RGB') image = image.convert("RGB")
image = np.array(image).astype(np.uint8) image = np.array(image).astype(np.uint8)
@ -423,8 +410,8 @@ class ImageNetSR(Dataset):
height=crop_side_len, width=crop_side_len height=crop_side_len, width=crop_side_len
) )
image = self.cropper(image=image)['image'] image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)['image'] image = self.image_rescaler(image=image)["image"]
if self.pil_interpolation: if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image) image_pil = PIL.Image.fromarray(image)
@ -432,10 +419,10 @@ class ImageNetSR(Dataset):
LR_image = np.array(LR_image).astype(np.uint8) LR_image = np.array(LR_image).astype(np.uint8)
else: else:
LR_image = self.degradation_process(image=image)['image'] LR_image = self.degradation_process(image=image)["image"]
example['image'] = (image / 127.5 - 1.0).astype(np.float32) example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32) example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)
return example return example
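An illustrative check of the (x / 127.5 - 1.0) rescaling used for both image and LR_image: uint8 pixels in [0, 255] map to float32 values in [-1, 1].

import numpy as np

pixels = np.array([0, 127, 255], dtype=np.uint8)
print((pixels / 127.5 - 1.0).astype(np.float32))  # [-1.  -0.00392157  1. ]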
@ -445,7 +432,7 @@ class ImageNetSRTrain(ImageNetSR):
super().__init__(**kwargs) super().__init__(**kwargs)
def get_base(self): def get_base(self):
with open('data/imagenet_train_hr_indices.p', 'rb') as f: with open("data/imagenet_train_hr_indices.p", "rb") as f:
indices = pickle.load(f) indices = pickle.load(f)
dset = ImageNetTrain( dset = ImageNetTrain(
process_images=False, process_images=False,
@ -458,7 +445,7 @@ class ImageNetSRValidation(ImageNetSR):
super().__init__(**kwargs) super().__init__(**kwargs)
def get_base(self): def get_base(self):
with open('data/imagenet_val_hr_indices.p', 'rb') as f: with open("data/imagenet_val_hr_indices.p", "rb") as f:
indices = pickle.load(f) indices = pickle.load(f)
dset = ImageNetValidation( dset = ImageNetValidation(
process_images=False, process_images=False,


@ -1,4 +1,5 @@
import os import os
import numpy as np import numpy as np
import PIL import PIL
from PIL import Image from PIL import Image
@ -12,27 +13,25 @@ class LSUNBase(Dataset):
txt_file, txt_file,
data_root, data_root,
size=None, size=None,
interpolation='bicubic', interpolation="bicubic",
flip_p=0.5, flip_p=0.5,
): ):
self.data_paths = txt_file self.data_paths = txt_file
self.data_root = data_root self.data_root = data_root
with open(self.data_paths, 'r') as f: with open(self.data_paths, "r") as f:
self.image_paths = f.read().splitlines() self.image_paths = f.read().splitlines()
self._length = len(self.image_paths) self._length = len(self.image_paths)
self.labels = { self.labels = {
'relative_file_path_': [l for l in self.image_paths], "relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l) for l in self.image_paths],
} }
self.size = size self.size = size
self.interpolation = { self.interpolation = {
'linear': PIL.Image.LINEAR, "linear": PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR, "bilinear": PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC, "bicubic": PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS, "lanczos": PIL.Image.LANCZOS,
}[interpolation] }[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -41,14 +40,17 @@ class LSUNBase(Dataset):
def __getitem__(self, i): def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels) example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example['file_path_']) image = Image.open(example["file_path_"])
if not image.mode == 'RGB': if not image.mode == "RGB":
image = image.convert('RGB') image = image.convert("RGB")
# default to score-sde preprocessing # default to score-sde preprocessing
img = np.array(image).astype(np.uint8) img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
(
    h,
    w,
) = (
    img.shape[0],
    img.shape[1],
)
@ -59,68 +61,64 @@ class LSUNBase(Dataset):
image = Image.fromarray(img) image = Image.fromarray(img)
if self.size is not None: if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image) image = self.flip(image)
image = np.array(image).astype(np.uint8) image = np.array(image).astype(np.uint8)
example['image'] = (image / 127.5 - 1.0).astype(np.float32) example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example return example
class LSUNChurchesTrain(LSUNBase): class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__( super().__init__(
txt_file='data/lsun/church_outdoor_train.txt', txt_file="data/lsun/church_outdoor_train.txt",
data_root='data/lsun/churches', data_root="data/lsun/churches",
**kwargs **kwargs,
) )
class LSUNChurchesValidation(LSUNBase): class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs): def __init__(self, flip_p=0.0, **kwargs):
super().__init__( super().__init__(
txt_file='data/lsun/church_outdoor_val.txt', txt_file="data/lsun/church_outdoor_val.txt",
data_root='data/lsun/churches', data_root="data/lsun/churches",
flip_p=flip_p, flip_p=flip_p,
**kwargs **kwargs,
) )
class LSUNBedroomsTrain(LSUNBase): class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__( super().__init__(
txt_file='data/lsun/bedrooms_train.txt', txt_file="data/lsun/bedrooms_train.txt",
data_root='data/lsun/bedrooms', data_root="data/lsun/bedrooms",
**kwargs **kwargs,
) )
class LSUNBedroomsValidation(LSUNBase): class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs): def __init__(self, flip_p=0.0, **kwargs):
super().__init__( super().__init__(
txt_file='data/lsun/bedrooms_val.txt', txt_file="data/lsun/bedrooms_val.txt",
data_root='data/lsun/bedrooms', data_root="data/lsun/bedrooms",
flip_p=flip_p, flip_p=flip_p,
**kwargs **kwargs,
) )
class LSUNCatsTrain(LSUNBase): class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__( super().__init__(
txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs
) )
class LSUNCatsValidation(LSUNBase): class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs): def __init__(self, flip_p=0.0, **kwargs):
super().__init__( super().__init__(
txt_file='data/lsun/cat_val.txt', txt_file="data/lsun/cat_val.txt",
data_root='data/lsun/cats', data_root="data/lsun/cats",
flip_p=flip_p, flip_p=flip_p,
**kwargs **kwargs,
) )
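A hedged usage sketch for the LSUN wrappers above; it assumes the default filelists and image folders already exist on disk:

from torch.utils.data import DataLoader

train_ds = LSUNChurchesTrain(size=256)
loader = DataLoader(train_ds, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([4, 256, 256, 3]), values in [-1, 1]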


@ -0,0 +1,199 @@
import os
import random
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
imagenet_templates_smallest = [
"a photo of a {}",
]
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_dual_templates_small = [
"a photo of a {} with {}",
"a rendering of a {} with {}",
"a cropped photo of the {} with {}",
"the photo of a {} with {}",
"a photo of a clean {} with {}",
"a photo of a dirty {} with {}",
"a dark photo of the {} with {}",
"a photo of my {} with {}",
"a photo of the cool {} with {}",
"a close-up photo of a {} with {}",
"a bright photo of the {} with {}",
"a cropped photo of a {} with {}",
"a photo of the {} with {}",
"a good photo of the {} with {}",
"a photo of one {} with {}",
"a close-up photo of the {} with {}",
"a rendition of the {} with {}",
"a photo of the clean {} with {}",
"a rendition of a {} with {}",
"a photo of a nice {} with {}",
"a good photo of a {} with {}",
"a photo of the nice {} with {}",
"a photo of the small {} with {}",
"a photo of the weird {} with {}",
"a photo of the large {} with {}",
"a photo of a cool {} with {}",
"a photo of a small {} with {}",
]
per_img_token_list = [
"א",
"ב",
"ג",
"ד",
"ה",
"ו",
"ז",
"ח",
"ט",
"י",
"כ",
"ל",
"מ",
"נ",
"ס",
"ע",
"פ",
"צ",
"ק",
"ר",
"ש",
"ת",
]
class PersonalizedBase(Dataset):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
self.data_root = data_root
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(
placeholder_string, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(placeholder_string)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
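A small sketch of how the caption templates above are filled in, with the template and token index fixed instead of drawn at random (assumes the lists defined in this file are in scope):

placeholder_string = "sculpture *"  # coarse_class_text + placeholder_token
print(imagenet_templates_small[0].format(placeholder_string))
# a photo of a sculpture *
print(imagenet_dual_templates_small[0].format(placeholder_string, per_img_token_list[0]))
# a photo of a sculpture * with א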


@ -0,0 +1,170 @@
import os
import random
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
imagenet_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
imagenet_dual_templates_small = [
"a painting in the style of {} with {}",
"a rendering in the style of {} with {}",
"a cropped painting in the style of {} with {}",
"the painting in the style of {} with {}",
"a clean painting in the style of {} with {}",
"a dirty painting in the style of {} with {}",
"a dark painting in the style of {} with {}",
"a cool painting in the style of {} with {}",
"a close-up painting in the style of {} with {}",
"a bright painting in the style of {} with {}",
"a cropped painting in the style of {} with {}",
"a good painting in the style of {} with {}",
"a painting of one {} in the style of {}",
"a nice painting in the style of {} with {}",
"a small painting in the style of {} with {}",
"a weird painting in the style of {} with {}",
"a large painting in the style of {} with {}",
]
per_img_token_list = [
"א",
"ב",
"ג",
"ד",
"ה",
"ו",
"ז",
"ח",
"ט",
"י",
"כ",
"ל",
"מ",
"נ",
"ס",
"ע",
"פ",
"צ",
"ק",
"ר",
"ש",
"ת",
]
class PersonalizedBase(Dataset):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
if per_image_tokens:
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(
self.placeholder_token, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(
self.placeholder_token
)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
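A worked example of the center-crop indexing above for a 640x480 image: crop = 480, so rows 80:560 and columns 0:480 are kept.

h, w = 640, 480
crop = min(h, w)
print((h - crop) // 2, (h + crop) // 2)  # 80 560
print((w - crop) // 2, (w + crop) // 2)  # 0 480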


@ -2,22 +2,28 @@ from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import psutil
import secrets import secrets
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
import PIL.Image
import einops import einops
import PIL.Image
import psutil import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline StableDiffusionPipeline,
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
@ -26,13 +32,16 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from ldm.invoke.globals import Globals from invokeai.backend.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager from ..util import CPU_DEVICE, normalize_device
from ..devices import normalize_device, CPU_DEVICE from .diffusion import (
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup AttentionMapSaver,
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver InvokeAIDiffuserComponent,
from compel import EmbeddingsProvider PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
@dataclass @dataclass
@ -51,7 +60,7 @@ _default_personalization_config_params = dict(
initializer_wods=["sculpture"], initializer_wods=["sculpture"],
per_image_tokens=False, per_image_tokens=False,
num_vectors_per_token=1, num_vectors_per_token=1,
progressive_words=False progressive_words=False,
) )
@ -64,29 +73,34 @@ class AddsMaskLatents:
This class assumes the same mask and base image should apply to all items in the batch. This class assumes the same mask and base image should apply to all items in the batch.
""" """
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.Tensor mask: torch.Tensor
initial_image_latents: torch.Tensor initial_image_latents: torch.Tensor
def __call__(
    self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor
) -> torch.Tensor:
model_input = self.add_mask_channels(latents) model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings) return self.forward(model_input, t, text_embeddings)
def add_mask_channels(self, latents): def add_mask_channels(self, latents):
batch_size = latents.size(0) batch_size = latents.size(0)
# duplicate mask and latents for each batch # duplicate mask and latents for each batch
mask = einops.repeat(
    self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
)
image_latents = einops.repeat(
    self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
)
# add mask and image as additional channels # add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w') model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input return model_input
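A hedged sketch of the channel packing in add_mask_channels: an inpainting UNet expects 4 latent + 1 mask + 4 image-latent channels, i.e. 9 input channels.

import einops
import torch

latents = torch.zeros(2, 4, 64, 64)
mask = torch.ones(2, 1, 64, 64)
image_latents = torch.zeros(2, 4, 64, 64)
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
print(model_input.shape)  # torch.Size([2, 9, 64, 64])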
def are_like_tensors(a: torch.Tensor, b: object) -> bool: def are_like_tensors(a: torch.Tensor, b: object) -> bool:
return isinstance(b, torch.Tensor) and (a.size() == b.size())
@dataclass @dataclass
class AddsMaskGuidance: class AddsMaskGuidance:
@ -96,7 +110,9 @@ class AddsMaskGuidance:
noise: torch.Tensor noise: torch.Tensor
_debug: Optional[Callable] = None _debug: Optional[Callable] = None
def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput: def __call__(
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning
) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data. output_class = step_output.__class__ # We'll create a new one with masked data.
# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it. # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
@ -106,30 +122,41 @@ class AddsMaskGuidance:
prev_sample = step_output[0] prev_sample = step_output[0]
# Mask anything that has the same shape as prev_sample, return others as-is. # Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
    {
        k: (
            self.apply_mask(v, self._t_for_field(k, t))
            if are_like_tensors(prev_sample, v)
            else v
        )
        for k, v in step_output.items()
    }
)
def _t_for_field(self, field_name:str, t): def _t_for_field(self, field_name: str, t):
if field_name == "pred_original_sample": if field_name == "pred_original_sample":
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0 return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
return t return t
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor: def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
batch_size = latents.size(0) batch_size = latents.size(0)
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size) mask = einops.repeat(
self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
)
if t.dim() == 0: if t.dim() == 0:
# some schedulers expect t to be one-dimensional. # some schedulers expect t to be one-dimensional.
# TODO: file diffusers bug about inconsistency? # TODO: file diffusers bug about inconsistency?
t = einops.repeat(t, '-> batch', batch=batch_size) t = einops.repeat(t, "-> batch", batch=batch_size)
# Noise shouldn't be re-randomized between steps here. The multistep schedulers # Noise shouldn't be re-randomized between steps here. The multistep schedulers
# get very confused about what is happening from step to step when we do that. # get very confused about what is happening from step to step when we do that.
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t) mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t) # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size) mask_latents = einops.repeat(
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
)
masked_input = torch.lerp(
mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)
)
if self._debug: if self._debug:
self._debug(masked_input, f"t={t} lerped") self._debug(masked_input, f"t={t} lerped")
return masked_input return masked_input
@ -139,7 +166,9 @@ def trim_to_multiple_of(*args, multiple_of=8):
return tuple((x - x % multiple_of) for x in args) return tuple((x - x % multiple_of) for x in args)
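Worked examples of trim_to_multiple_of: each dimension is rounded down to the nearest multiple of 8 (or the given multiple) so the downstream strides divide it evenly.

print(trim_to_multiple_of(517, 389))                 # (512, 384)
print(trim_to_multiple_of(100, 70, multiple_of=64))  # (64, 64)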
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor: def image_resized_to_grid_as_tensor(
image: PIL.Image.Image, normalize: bool = True, multiple_of=8
) -> torch.FloatTensor:
""" """
:param image: input image :param image: input image
@ -147,10 +176,12 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
:param multiple_of: resize the input so both dimensions are a multiple of this :param multiple_of: resize the input so both dimensions are a multiple of this
""" """
w, h = trim_to_multiple_of(*image.size) w, h = trim_to_multiple_of(*image.size)
transformation = T.Compose(
    [
        T.Resize((h, w), T.InterpolationMode.LANCZOS),
        T.ToTensor(),
    ]
)
tensor = transformation(image) tensor = transformation(image)
if normalize: if normalize:
tensor = tensor * 2.0 - 1.0 tensor = tensor * 2.0 - 1.0
@ -160,9 +191,11 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
def is_inpainting_model(unet: UNet2DConditionModel): def is_inpainting_model(unet: UNet2DConditionModel):
return unet.conv_in.in_channels == 9 return unet.conv_in.in_channels == 9
CallbackType = TypeVar("CallbackType")
ReturnType = TypeVar("ReturnType")
ParamType = ParamSpec("ParamType")
@dataclass(frozen=True) @dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
@ -171,9 +204,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
generator_method: Callable[ParamType, ReturnType] generator_method: Callable[ParamType, ReturnType]
callback_arg_type: Type[CallbackType] callback_arg_type: Type[CallbackType]
def __call__(self, *args: ParamType.args, def __call__(
callback:Callable[[CallbackType], Any]=None, self,
**kwargs: ParamType.kwargs) -> ReturnType: *args: ParamType.args,
callback: Callable[[CallbackType], Any] = None,
**kwargs: ParamType.kwargs,
) -> ReturnType:
result = None result = None
for result in self.generator_method(*args, **kwargs): for result in self.generator_method(*args, **kwargs):
if callback is not None and isinstance(result, self.callback_arg_type): if callback is not None and isinstance(result, self.callback_arg_type):
@ -218,6 +254,7 @@ class ConditioningData:
scheduler_args[name] = value scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args) return dataclasses.replace(self, scheduler_args=scheduler_args)
@dataclass @dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput): class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r""" r"""
@ -275,10 +312,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker], safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor], feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False, requires_safety_checker: bool = False,
precision: str = 'float32', precision: str = "float32",
): ):
super().__init__(
    vae,
    text_encoder,
    tokenizer,
    unet,
    scheduler,
    safety_checker,
    feature_extractor,
    requires_safety_checker,
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
@ -289,27 +334,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.invokeai_diffuser = InvokeAIDiffuserComponent(
    self.unet, self._unet_forward, is_running_diffusers=True
)
use_full_precision = precision == "float32" or precision == "autocast"
self.textual_inversion_manager = TextualInversionManager(
    tokenizer=self.tokenizer,
    text_encoder=self.text_encoder,
    full_precision=use_full_precision,
)
# InvokeAI's interface for text embeddings and whatnot # InvokeAI's interface for text embeddings and whatnot
self.embeddings_provider = EmbeddingsProvider( self.embeddings_provider = EmbeddingsProvider(
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
text_encoder=self.text_encoder, text_encoder=self.text_encoder,
textual_inversion_manager=self.textual_inversion_manager textual_inversion_manager=self.textual_inversion_manager,
) )
self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels) self._model_group.install(*self._submodels)
def _adjust_memory_efficient_attention(self, latents: torch.Tensor): def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
""" """
if xformers is available, use it, otherwise use sliced attention. if xformers is available, use it, otherwise use sliced attention.
""" """
if torch.cuda.is_available() and is_xformers_available() and not Globals.disable_xformers: if (
torch.cuda.is_available()
and is_xformers_available()
and not Globals.disable_xformers
):
self.enable_xformers_memory_efficient_attention() self.enable_xformers_memory_efficient_attention()
else: else:
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
@ -318,25 +370,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline. # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
pass pass
else: else:
if self.device.type == 'cpu' or self.device.type == 'mps': if self.device.type == "cpu" or self.device.type == "mps":
mem_free = psutil.virtual_memory().free mem_free = psutil.virtual_memory().free
elif self.device.type == 'cuda': elif self.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device)) mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
else: else:
raise ValueError(f"unrecognized device {self.device}") raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8] # input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)] # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = (
    latents.element_size() + 4
)
max_size_required_for_baddbmm = (
    16
    * latents.size(dim=2)
    * latents.size(dim=3)
    * latents.size(dim=2)
    * latents.size(dim=3)
    * bytes_per_element_needed_for_baddbmm_duplication
)
if max_size_required_for_baddbmm > (
    mem_free * 3.0 / 4.0
):  # 3.3 / 4.0 is from old Invoke code
    self.enable_attention_slicing(slice_size="max")
else: else:
self.disable_attention_slicing() self.disable_attention_slicing()
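A worked example of the memory estimate above for a 512x512 generation in float16 (element size 2 bytes): the baddbmm buffer estimate comes to about 1.5 GiB, which is then compared against three quarters of free memory.

element_size = 2          # float16 latents
h8 = w8 = 512 // 8        # latent spatial size for a 512x512 image
estimate = 16 * h8 * w8 * h8 * w8 * (element_size + 4)
print(estimate, estimate / 2**30)  # 1610612736 1.5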
def enable_offload_submodels(self, device: torch.device): def enable_offload_submodels(self, device: torch.device):
""" """
Offload each submodel when it's not in use. Offload each submodel when it's not in use.
@ -398,12 +457,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
values = [getattr(self, name) for name in module_names.keys()] values = [getattr(self, name) for name in module_names.keys()]
return [m for m in values if isinstance(m, torch.nn.Module)] return [m for m in values if isinstance(m, torch.nn.Module)]
def image_from_embeddings(
    self,
    latents: torch.Tensor,
    num_inference_steps: int,
    conditioning_data: ConditioningData,
    *,
    noise: torch.Tensor,
    callback: Callable[[PipelineIntermediateState], None] = None,
    run_id=None,
) -> InvokeAIStableDiffusionPipelineOutput:
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
@ -417,71 +480,104 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
:param run_id: :param run_id:
""" """
result_latents, result_attention_map_saver = self.latents_from_embeddings( result_latents, result_attention_map_saver = self.latents_from_embeddings(
latents, num_inference_steps, latents,
num_inference_steps,
conditioning_data, conditioning_data,
noise=noise, noise=noise,
run_id=run_id, run_id=run_id,
callback=callback) callback=callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
with torch.inference_mode(): with torch.inference_mode():
image = self.decode_latents(result_latents) image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(
    images=image,
    nsfw_content_detected=[],
    attention_map_saver=result_attention_map_saver,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype) return self.check_for_safety(output, dtype=conditioning_data.dtype)
def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, def latents_from_embeddings(
conditioning_data: ConditioningData, self,
*, latents: torch.Tensor,
noise: torch.Tensor, num_inference_steps: int,
timesteps=None, conditioning_data: ConditioningData,
additional_guidance: List[Callable] = None, run_id=None, *,
callback: Callable[[PipelineIntermediateState], None] = None noise: torch.Tensor,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: timesteps=None,
additional_guidance: List[Callable] = None,
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if timesteps is None: if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet)) self.scheduler.set_timesteps(
num_inference_steps, device=self._model_group.device_for(self.unet)
)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) infer_latents_from_embeddings = GeneratorToCallbackinator(
self.generate_latents_from_embeddings, PipelineIntermediateState
)
result: PipelineIntermediateState = infer_latents_from_embeddings( result: PipelineIntermediateState = infer_latents_from_embeddings(
latents, timesteps, conditioning_data, latents,
timesteps,
conditioning_data,
noise=noise, noise=noise,
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
run_id=run_id, run_id=run_id,
callback=callback) callback=callback,
)
return result.latents, result.attention_map_saver return result.latents, result.attention_map_saver
def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, def generate_latents_from_embeddings(
conditioning_data: ConditioningData, self,
*, latents: torch.Tensor,
noise: torch.Tensor, timesteps,
run_id: str = None, conditioning_data: ConditioningData,
additional_guidance: List[Callable] = None): *,
noise: torch.Tensor,
run_id: str = None,
additional_guidance: List[Callable] = None,
):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
if run_id is None: if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH) run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None: if additional_guidance is None:
additional_guidance = [] additional_guidance = []
extra_conditioning_info = conditioning_data.extra extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info, with self.invokeai_diffuser.custom_attention_context(
step_count=len(self.scheduler.timesteps) extra_conditioning_info=extra_conditioning_info,
): step_count=len(self.scheduler.timesteps),
):
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, yield PipelineIntermediateState(
latents=latents) run_id=run_id,
step=-1,
timestep=self.scheduler.num_train_timesteps,
latents=latents,
)
batch_size = latents.shape[0] batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0], batched_t = torch.full(
dtype=timesteps.dtype, device=self._model_group.device_for(self.unet)) (batch_size,),
timesteps[0],
dtype=timesteps.dtype,
device=self._model_group.device_for(self.unet),
)
latents = self.scheduler.add_noise(latents, noise, batched_t) latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None attention_map_saver: Optional[AttentionMapSaver] = None
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t) batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data, step_output = self.step(
step_index=i, batched_t,
total_step_count=len(timesteps), latents,
additional_guidance=additional_guidance) conditioning_data,
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
)
latents = step_output.prev_sample latents = step_output.prev_sample
latents = self.invokeai_diffuser.do_latent_postprocessing( latents = self.invokeai_diffuser.do_latent_postprocessing(
@ -489,28 +585,39 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents=latents, latents=latents,
sigma=batched_t, sigma=batched_t,
step_index=i, step_index=i,
total_step_count=len(timesteps) total_step_count=len(timesteps),
) )
predicted_original = getattr(step_output, 'pred_original_sample', None) predicted_original = getattr(step_output, "pred_original_sample", None)
# TODO resuscitate attention map saving # TODO resuscitate attention map saving
#if i == len(timesteps)-1 and extra_conditioning_info is not None: # if i == len(timesteps)-1 and extra_conditioning_info is not None:
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 # eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
# attention_map_token_ids = range(1, eos_token_index) # attention_map_token_ids = range(1, eos_token_index)
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) # attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, yield PipelineIntermediateState(
predicted_original=predicted_original, attention_map_saver=attention_map_saver) run_id=run_id,
step=i,
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
attention_map_saver=attention_map_saver,
)
return latents, attention_map_saver return latents, attention_map_saver
@torch.inference_mode() @torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor, def step(
conditioning_data: ConditioningData, self,
step_index:int, total_step_count:int, t: torch.Tensor,
additional_guidance: List[Callable] = None): latents: torch.Tensor,
conditioning_data: ConditioningData,
step_index: int,
total_step_count: int,
additional_guidance: List[Callable] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
@ -523,16 +630,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# predict the noise residual # predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step( noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input, t, latent_model_input,
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings, t,
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
conditioning_data.guidance_scale, conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, step_output = self.scheduler.step(
**conditioning_data.scheduler_args) noise_pred, timestep, latents, **conditioning_data.scheduler_args
)
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent. # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
# But the way things are now, scheduler runs _after_ that, so there was # But the way things are now, scheduler runs _after_ that, so there was
@ -542,7 +652,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output return step_output
def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None): def _unet_forward(
self,
latents,
t,
text_embeddings,
cross_attention_kwargs: Optional[dict[str, Any]] = None,
):
"""predict the noise residual""" """predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4: if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model. # Pad out normal non-inpainting inputs for an inpainting model.
@ -551,67 +667,100 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# use of AddsMaskLatents. # use of AddsMaskLatents.
latents = AddsMaskLatents( latents = AddsMaskLatents(
self._unet_forward, self._unet_forward,
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype), mask=torch.ones_like(
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype) latents[:1, :1], device=latents.device, dtype=latents.dtype
),
initial_image_latents=torch.zeros_like(
latents[:1], device=latents.device, dtype=latents.dtype
),
).add_mask_channels(latents) ).add_mask_channels(latents)
# First three args should be positional, not keywords, so torch hooks can see them. # First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(latents, t, text_embeddings, return self.unet(
cross_attention_kwargs=cross_attention_kwargs).sample latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
).sample
    def img2img_from_embeddings(
        self,
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        noise_func=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        if isinstance(init_image, PIL.Image.Image):
            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))

        if init_image.dim() == 3:
            init_image = einops.rearrange(init_image, "c h w -> 1 c h w")

        # 6. Prepare latent variables
        initial_latents = self.non_noised_latents_from_image(
            init_image,
            device=self._model_group.device_for(self.unet),
            dtype=self.unet.dtype,
        )
        noise = noise_func(initial_latents)

        return self.img2img_from_latents_and_embeddings(
            initial_latents,
            num_inference_steps,
            conditioning_data,
            strength,
            noise,
            run_id,
            callback,
        )

    def img2img_from_latents_and_embeddings(
        self,
        initial_latents,
        num_inference_steps,
        conditioning_data: ConditioningData,
        strength,
        noise: torch.Tensor,
        run_id=None,
        callback=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(
            num_inference_steps,
            strength,
            device=self._model_group.device_for(self.unet),
        )
        result_latents, result_attention_maps = self.latents_from_embeddings(
            initial_latents,
            num_inference_steps,
            conditioning_data,
            timesteps=timesteps,
            noise=noise,
            run_id=run_id,
            callback=callback,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(
                images=image,
                nsfw_content_detected=[],
                attention_map_saver=result_attention_maps,
            )
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def get_img2img_timesteps(
        self, num_inference_steps: int, strength: float, device
    ) -> (torch.Tensor, int):
        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
        assert img2img_pipeline.scheduler is self.scheduler
        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
            num_inference_steps, strength, device=device
        )
        # Workaround for low strength resulting in zero timesteps.
        # TODO: submit upstream fix for zero-step img2img
        if timesteps.numel() == 0:
@@ -620,21 +769,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        return timesteps, adjusted_steps
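The strength-to-timesteps mapping above, together with the zero-step guard, can be sketched in isolation like this. This is illustrative only; the real method delegates to StableDiffusionImg2ImgPipeline.get_timesteps, and the linear 999..0 schedule here is an assumption made to keep the sketch standalone.

import torch

def sketch_img2img_timesteps(num_inference_steps: int, strength: float):
    # full schedule from most to least noisy, e.g. 999..0 (assumed linear here)
    all_timesteps = torch.linspace(999, 0, num_inference_steps).long()
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = all_timesteps[t_start:]
    if timesteps.numel() == 0:
        # very low strength rounded down to zero steps; run at least one step
        timesteps = all_timesteps[-1:]
    return timesteps, timesteps.numel()

print(sketch_img2img_timesteps(30, 0.02))  # one final step instead of an empty schedule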
    def inpaint_from_embeddings(
        self,
        init_image: torch.FloatTensor,
        mask: torch.FloatTensor,
        strength: float,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        noise_func=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        device = self._model_group.device_for(self.unet)
        latents_dtype = self.unet.dtype

        if isinstance(init_image, PIL.Image.Image):
            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))

        init_image = init_image.to(device=device, dtype=latents_dtype)
        mask = mask.to(device=device, dtype=latents_dtype)
@@ -642,18 +792,23 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if init_image.dim() == 3:
            init_image = init_image.unsqueeze(0)

        timesteps, _ = self.get_img2img_timesteps(
            num_inference_steps, strength, device=device
        )

        # 6. Prepare latent variables
        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
        # because we have our own noise function
        init_image_latents = self.non_noised_latents_from_image(
            init_image, device=device, dtype=latents_dtype
        )
        noise = noise_func(init_image_latents)

        if mask.dim() == 3:
            mask = mask.unsqueeze(0)
        latent_mask = tv_resize(
            mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR
        ).to(device=device, dtype=latents_dtype)

        guidance: List[Callable] = []
@@ -661,20 +816,30 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
            # (that's why there's a mask!) but it seems to really want that blanked out.
            masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
            masked_latents = self.non_noised_latents_from_image(
                masked_init_image, device=device, dtype=latents_dtype
            )

            # TODO: we should probably pass this in so we don't have to try/finally around setting it.
            self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
                self._unet_forward, latent_mask, masked_latents
            )
        else:
            guidance.append(
                AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)
            )

        try:
            result_latents, result_attention_maps = self.latents_from_embeddings(
                init_image_latents,
                num_inference_steps,
                conditioning_data,
                noise=noise,
                timesteps=timesteps,
                additional_guidance=guidance,
                run_id=run_id,
                callback=callback,
            )
        finally:
            self.invokeai_diffuser.model_forward_callback = self._unet_forward
@@ -683,13 +848,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(
                images=image,
                nsfw_content_detected=[],
                attention_map_saver=result_attention_maps,
            )
            return self.check_for_safety(output, dtype=conditioning_data.dtype)
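The AddsMaskGuidance path above preserves the unmasked region by re-noising the original latents and blending them with the denoiser's output through the latent-space mask. A minimal, self-contained sketch of that blending idea follows; the names and the simple alpha-based re-noising are assumptions for illustration, not the pipeline's actual classes (the real guidance uses the scheduler's add_noise for the current timestep).

import torch

def blend_masked_latents(step_latents, init_latents, noise, latent_mask, alpha):
    # latent_mask == 1 where the original image must be preserved
    renoised_init = (1 - alpha) ** 0.5 * init_latents + alpha**0.5 * noise
    return latent_mask * renoised_init + (1 - latent_mask) * step_latents

latents = torch.randn(1, 4, 64, 64)
init = torch.randn(1, 4, 64, 64)
noise = torch.randn(1, 4, 64, 64)
mask = torch.zeros(1, 1, 64, 64)
mask[..., :32] = 1.0  # keep the top half of the source image
blended = blend_masked_latents(latents, init, noise, mask, alpha=0.1)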
    def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
        init_image = init_image.to(device=device, dtype=dtype)
        with torch.inference_mode():
            if device.type == "mps":
                # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
                # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
                self.vae.to(CPU_DEVICE)
@@ -697,8 +866,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            else:
                self._model_group.load(self.vae)
            init_latent_dist = self.vae.encode(init_image).latent_dist
            init_latents = init_latent_dist.sample().to(
                dtype=dtype
            )  # FIXME: uses torch.randn. make reproducible!
            if device.type == "mps":
                self.vae.to(device)
                init_latents = init_latents.to(device)
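A stripped-down sketch of the device-fallback pattern used above: when the active device is MPS, run the encoder on CPU and move the result back afterwards. Names are illustrative; this is not the pipeline's API.

import torch

def encode_with_mps_fallback(encoder, x, device: torch.device):
    # Assumption: `encoder` is any nn.Module whose forward also works on CPU.
    if device.type == "mps":
        encoder.to("cpu")
        out = encoder(x.to("cpu"))
        encoder.to(device)
        return out.to(device)
    return encoder(x.to(device))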
@@ -707,14 +878,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    def check_for_safety(self, output, dtype):
        with torch.inference_mode():
            screened_images, has_nsfw_concept = self.run_safety_checker(
                output.images, dtype=dtype
            )
        screened_attention_map_saver = None
        if has_nsfw_concept is None or not has_nsfw_concept:
            screened_attention_map_saver = output.attention_map_saver
        return InvokeAIStableDiffusionPipelineOutput(
            screened_images,
            has_nsfw_concept,
            # block the attention maps if NSFW content is detected
            attention_map_saver=screened_attention_map_saver,
        )

    def run_safety_checker(self, image, device=None, dtype=None):
        # overriding to use the model group for device info instead of requiring the caller to know.
@@ -723,15 +898,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        return super().run_safety_checker(image, device, dtype)

    @torch.inference_mode()
    def get_learned_conditioning(
        self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
    ):
        """
        Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
        """
        return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
            text_batch=c,
            fragment_weights_batch=fragment_weights,
            should_return_tokens=return_tokens,
            device=self._model_group.device_for(self.unet),
        )

    @property
    def cond_stage_model(self):
@@ -760,6 +938,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    def debug_latents(self, latents, msg):
        with torch.inference_mode():
            from ldm.util import debug_image

            decoded = self.numpy_to_pil(self.decode_latents(latents))
            for i, img in enumerate(decoded):
                debug_image(
                    img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
                )


@@ -0,0 +1,6 @@
"""
Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
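For reference, a hedged usage example of the consolidated exports, assuming the package path matches the docstring above (invokeai.models.diffusion); adjust the import path to wherever this package lives in your checkout.

from invokeai.models.diffusion import (
    AttentionMapSaver,
    InvokeAICrossAttentionMixin,
    InvokeAIDiffuserComponent,
    PostprocessingSettings,
)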


@@ -1,22 +1,19 @@
import os
from copy import deepcopy
from glob import glob

import pytorch_lightning as pl
import torch
from einops import rearrange
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import default, instantiate_from_config, ismap, log_txt_as_img
from natsort import natsorted
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}


def disabled_train(self, mode=True):
@@ -31,13 +28,13 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        diffusion_path,
        num_classes,
        ckpt_path=None,
        pool="attention",
        label_key=None,
        diffusion_ckpt_path=None,
        scheduler_config=None,
        weight_decay=1.0e-2,
        log_steps=10,
        monitor="val/loss",
        *args,
        **kwargs,
    ):
@@ -45,30 +42,26 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(
            glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
        )[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = (
            label_key
            if not hasattr(self.diffusion_model, "cond_stage_key")
            else self.diffusion_model.cond_stage_key
        )

        assert (
            self.label_key is not None
        ), "label_key neither in diffusion model nor in model.params"

        if self.label_key not in __models__:
            raise NotImplementedError()
@@ -80,14 +73,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
@@ -95,12 +88,12 @@ class NoisyLatentImageClassifier(pl.LightningModule):
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        model = instantiate_from_config(self.diffusion_config)
@@ -110,24 +103,22 @@ class NoisyLatentImageClassifier(pl.LightningModule):
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = (
            self.diffusion_config.params.unet_config.params.out_channels
        )
        model_config.out_channels = self.num_classes
        if self.label_key == "class_label":
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print(
                "#####################################################################"
            )
            print(f'load from ckpt "{ckpt_path}"')
            print(
                "#####################################################################"
            )
            self.init_from_ckpt(ckpt_path)
@@ -137,9 +128,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = (
                self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            )
            # todo: make sure t+1 is correct here
@@ -158,7 +147,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, "b h w c -> b c h w")
        x = x.to(memory_format=torch.contiguous_format).float()
        return x
@@ -166,45 +155,41 @@ class NoisyLatentImageClassifier(pl.LightningModule):
    def get_conditioning(self, batch, k=None):
        if k is None:
            k = self.label_key
        assert k is not None, "Needs to provide label key"

        targets = batch[k].to(self.device)

        if self.label_key == "segmentation":
            targets = rearrange(targets, "b h w c -> b c h w")
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == "mean":
            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
        elif reduction == "none":
            return (top_ks == labels[:, None]).float().sum(dim=-1)

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to("cpu")

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        log_prefix = "train" if self.training else "val"
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(
@@ -214,19 +199,17 @@ class NoisyLatentImageClassifier(pl.LightningModule):
            on_step=self.training,
            on_epoch=True,
        )

        self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log(
            "global_step",
            self.global_step,
            logger=False,
            on_epoch=False,
            prog_bar=True,
        )
        lr = self.optimizers().param_groups[0]["lr"]
        self.log(
            "lr_abs",
            lr,
            on_step=True,
            logger=True,
@@ -249,13 +232,11 @@ class NoisyLatentImageClassifier(pl.LightningModule):
                device=self.device,
            ).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction="none")

        self.write_logs(loss.detach(), logits.detach(), targets.detach())
@@ -268,7 +249,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
    def reset_noise_accs(self):
        self.noisy_acc = {
            t: {"acc@1": [], "acc@5": []}
            for t in range(
                0,
                self.diffusion_model.num_timesteps,
@@ -285,11 +266,11 @@ class NoisyLatentImageClassifier(pl.LightningModule):
            for t in self.noisy_acc:
                _, logits, _, targets = self.shared_step(batch, t)
                self.noisy_acc[t]["acc@1"].append(
                    self.compute_top_k(logits, targets, k=1, reduction="mean")
                )
                self.noisy_acc[t]["acc@5"].append(
                    self.compute_top_k(logits, targets, k=5, reduction="mean")
                )

        return loss
@@ -304,14 +285,12 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [optimizer], scheduler
@@ -322,32 +301,28 @@ class NoisyLatentImageClassifier(pl.LightningModule):
    def log_images(self, batch, N=8, *args, **kwargs):
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log["inputs"] = x

        y = self.get_conditioning(batch)

        if self.label_key == "class_label":
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log["labels"] = y

        if ismap(y):
            log["labels"] = self.diffusion_model.to_rgb(y)

            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f"inputs@t{current_time}"] = x_noisy

                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
                pred = rearrange(pred, "b h w c -> b c h w")

                log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]


@@ -1,21 +1,20 @@
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl

import enum
import math
from typing import Callable, Optional

import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.cross_attention import AttnProcessor
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from torch import nn

from ...util import torch_dtype


class CrossAttentionType(enum.Enum):
@@ -24,13 +23,12 @@ class CrossAttentionType(enum.Enum):
class Context:
    cross_attention_mask: Optional[torch.Tensor]
    cross_attention_index_map: Optional[torch.Tensor]

    class Action(enum.Enum):
        NONE = 0
        SAVE = (1,)
        APPLY = 2
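One detail worth noting in the reformatted enum: the trailing comma in the original SAVE = 1, already made the member's value a one-element tuple, and Black's rewrite to SAVE = (1,) only makes that explicit rather than changing it. A tiny standalone illustration (the enum below is a copy for demonstration, not an import from the codebase):

import enum

class Action(enum.Enum):
    NONE = 0
    SAVE = (1,)  # value is the tuple (1,), exactly as `SAVE = 1,` was
    APPLY = 2

print(Action.SAVE.value)            # (1,)
print(Action.SAVE is Action((1,)))  # True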
    def __init__(self, arguments: Arguments, step_count: int):
@@ -53,11 +51,13 @@ class Context:
        self.clear_requests(cleanup=True)

    def register_cross_attention_modules(self, model):
        for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
            if name in self.self_cross_attention_module_identifiers:
                assert False, f"name {name} cannot appear more than once"
            self.self_cross_attention_module_identifiers.append(name)
        for name, module in get_cross_attention_modules(
            model, CrossAttentionType.TOKENS
        ):
            if name in self.tokens_cross_attention_module_identifiers:
                assert False, f"name {name} cannot appear more than once"
            self.tokens_cross_attention_module_identifiers.append(name)
@@ -68,7 +68,9 @@ class Context:
        else:
            self.tokens_cross_attention_action = Context.Action.SAVE

    def request_apply_saved_attention_maps(
        self, cross_attention_type: CrossAttentionType
    ):
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.APPLY
        else:
@@ -91,8 +93,9 @@ class Context:
            return self.tokens_cross_attention_action == Context.Action.APPLY
        return False

    def get_active_cross_attention_control_types_for_step(
        self, percent_through: float = None
    ) -> list[CrossAttentionType]:
        """
        Should cross-attention control be applied on the given step?
        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
@@ -103,50 +106,73 @@ class Context:
        opts = self.arguments.edit_options
        to_control = []
        if opts["s_start"] <= percent_through < opts["s_end"]:
            to_control.append(CrossAttentionType.SELF)
        if opts["t_start"] <= percent_through < opts["t_end"]:
            to_control.append(CrossAttentionType.TOKENS)
        return to_control
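As a concrete illustration of the windowing above: with hypothetical edit options s_start=0.0, s_end=0.3, t_start=0.1, t_end=0.9, a 30-step run applies self-attention control only during the first ~30% of steps and token control between 10% and 90% of steps. The values below are made up for the example; the real ones come from the prompt's .swap() arguments as parsed upstream.

opts = {"s_start": 0.0, "s_end": 0.3, "t_start": 0.1, "t_end": 0.9}
steps = 30
for step in range(steps):
    percent_through = step / steps
    active = []
    if opts["s_start"] <= percent_through < opts["s_end"]:
        active.append("SELF")
    if opts["t_start"] <= percent_through < opts["t_end"]:
        active.append("TOKENS")
    if step in (0, 5, 15, 29):  # print a few representative steps
        print(step, active)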
    def save_slice(
        self,
        identifier: str,
        slice: torch.Tensor,
        dim: Optional[int],
        offset: int,
        slice_size: Optional[int],
    ):
        if identifier not in self.saved_cross_attention_maps:
            self.saved_cross_attention_maps[identifier] = {
                "dim": dim,
                "slice_size": slice_size,
                "slices": {offset or 0: slice},
            }
        else:
            self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice

    def get_slice(
        self,
        identifier: str,
        requested_dim: Optional[int],
        requested_offset: int,
        slice_size: int,
    ):
        saved_attention_dict = self.saved_cross_attention_maps[identifier]
        if requested_dim is None:
            if saved_attention_dict["dim"] is not None:
                raise RuntimeError(
                    f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}"
                )
            return saved_attention_dict["slices"][0]

        if saved_attention_dict["dim"] == requested_dim:
            if slice_size != saved_attention_dict["slice_size"]:
                raise RuntimeError(
                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
                )
            return saved_attention_dict["slices"][requested_offset]

        if saved_attention_dict["dim"] is None:
            whole_saved_attention = saved_attention_dict["slices"][0]
            if requested_dim == 0:
                return whole_saved_attention[
                    requested_offset : requested_offset + slice_size
                ]
            elif requested_dim == 1:
                return whole_saved_attention[
                    :, requested_offset : requested_offset + slice_size
                ]

        raise RuntimeError(
            f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}"
        )

    def get_slicing_strategy(
        self, identifier: str
    ) -> tuple[Optional[int], Optional[int]]:
        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
        if saved_attention is None:
            return None, None
        return saved_attention["dim"], saved_attention["slice_size"]

    def clear_requests(self, cleanup=True):
        self.tokens_cross_attention_action = Context.Action.NONE
@@ -156,9 +182,8 @@ class Context:

    def offload_saved_attention_slices_to_cpu(self):
        for key, map_dict in self.saved_cross_attention_maps.items():
            for offset, slice in map_dict["slices"].items():
                map_dict[offset] = slice.to("cpu")


class InvokeAICrossAttentionMixin:
@@ -167,14 +192,20 @@ class InvokeAICrossAttentionMixin:
    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
    and dymamic slicing strategy selection.
    """

    def __init__(self):
        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
        self.attention_slice_wrangler = None
        self.slicing_strategy_getter = None
        self.attention_slice_calculated_callback = None

    def set_attention_slice_wrangler(
        self,
        wrangler: Optional[
            Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]
        ],
    ):
        """
        Set custom attention calculator to be called when attention is calculated
        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                         which returns either the suggested_attention_slice or an adjusted equivalent.
@@ -185,20 +216,30 @@ class InvokeAICrossAttentionMixin:
        Pass None to use the default attention calculation.
        :return:
        """
        self.attention_slice_wrangler = wrangler

    def set_slicing_strategy_getter(
        self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]
    ):
        self.slicing_strategy_getter = getter

    def set_attention_slice_calculated_callback(
        self, callback: Optional[Callable[[torch.Tensor], None]]
    ):
        self.attention_slice_calculated_callback = callback

    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
        # calculate attention scores
        # attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
@@ -206,35 +247,49 @@ class InvokeAICrossAttentionMixin:
        )

        # calculate attention slice by taking the best scores for each latent pixel
        default_attention_slice = attention_scores.softmax(
            dim=-1, dtype=attention_scores.dtype
        )
        attention_slice_wrangler = self.attention_slice_wrangler
        if attention_slice_wrangler is not None:
            attention_slice = attention_slice_wrangler(
                self, default_attention_slice, dim, offset, slice_size
            )
        else:
            attention_slice = default_attention_slice

        if self.attention_slice_calculated_callback is not None:
            self.attention_slice_calculated_callback(
                attention_slice, dim, offset, slice_size
            )

        hidden_states = torch.bmm(attention_slice, value)
        return hidden_states

    def einsum_op_slice_dim0(self, q, k, v, slice_size):
        r = torch.zeros(
            q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
        )
        for i in range(0, q.shape[0], slice_size):
            end = i + slice_size
            r[i:end] = self.einsum_lowest_level(
                q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size
            )
        return r

    def einsum_op_slice_dim1(self, q, k, v, slice_size):
        r = torch.zeros(
            q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
        )
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            r[:, i:end] = self.einsum_lowest_level(
                q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size
            )
        return r

    def einsum_op_mps_v1(self, q, k, v):
        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
            return self.einsum_lowest_level(q, k, v, None, None, None)
        else:
            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
@@ -272,13 +327,12 @@ class InvokeAICrossAttentionMixin:
        # Divide factor of safety as there's copying and fragmentation
        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

    def get_invokeai_attention_mem_efficient(self, q, k, v):
        if q.device.type == "cuda":
            # print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
            return self.einsum_op_cuda(q, k, v)

        if q.device.type == "mps" or q.device.type == "cpu":
            if self.mem_total_gb >= 32:
                return self.einsum_op_mps_v1(q, k, v)
            return self.einsum_op_mps_v2(q, k, v)
@@ -288,8 +342,11 @@ class InvokeAICrossAttentionMixin:
        return self.einsum_op_tensor_mem(q, k, v, 32)
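A hedged sketch of the memory-driven slicing strategy that the einsum_op_* helpers implement: estimate the size of the full attention-score tensor and, if it exceeds the available budget, slice along the query dimension in power-of-two steps. The helper below is standalone and illustrative, not the mixin's actual einsum_op_tensor_mem, and it omits the 1/sqrt(d_head) scaling for brevity.

import math
import torch

def attention_in_slices(q, k, v, budget_mb: float):
    # bytes needed for the full (batch, q_len, k_len) score tensor, in MB
    needed_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() / (1 << 20)
    steps = 1
    if needed_mb > budget_mb:
        steps = 2 ** math.ceil(math.log2(needed_mb / budget_mb))
    slice_size = max(1, q.shape[1] // steps)
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        q_slice = q[:, i : i + slice_size]
        scores = torch.baddbmm(
            torch.empty(q.shape[0], q_slice.shape[1], k.shape[1], dtype=q.dtype),
            q_slice,
            k.transpose(-1, -2),
            beta=0,
        )
        out[:, i : i + slice_size] = scores.softmax(dim=-1) @ v
    return out

q = torch.randn(2, 64, 40)
k = torch.randn(2, 77, 40)
v = torch.randn(2, 77, 40)
print(attention_in_slices(q, k, v, budget_mb=0.01).shape)  # torch.Size([2, 64, 40])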
def restore_default_cross_attention(
    model,
    is_running_diffusers: bool,
    restore_attention_processor: Optional[AttnProcessor] = None,
):
    if is_running_diffusers:
        unet = model
        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
@@ -297,7 +354,7 @@ def restore_default_cross_attention(model, is_running_diffusers: bool, restore_a
        remove_attention_function(model)


def override_cross_attention(model, context: Context, is_running_diffusers=False):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -316,7 +373,7 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
    indices = torch.arange(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
        if b0 < max_length:
            if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                # these tokens have not been edited
                indices[b0:b1] = indices_target[a0:a1]
                mask[b0:b1] = 1
@@ -332,7 +389,14 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
        else:
            # try to re-use an existing slice size
            default_slice_size = 4
            slice_size = next(
                (
                    p.slice_size
                    for p in old_attn_processors.values()
                    if type(p) is SlicedAttnProcessor
                ),
                default_slice_size,
            )
        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
        return old_attn_processors
    else:
@@ -341,65 +405,96 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
        return None


def get_cross_attention_modules(
    model, which: CrossAttentionType
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
    from ldm.modules.attention import CrossAttention  # avoid circular import

    cross_attention_class: type = (
        InvokeAIDiffusersCrossAttention
        if isinstance(model, UNet2DConditionModel)
        else CrossAttention
    )
    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
    attention_module_tuples = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, cross_attention_class) and which_attn in name
    ]
    cross_attention_modules_in_model_count = len(attention_module_tuples)
    expected_count = 16
    if cross_attention_modules_in_model_count != expected_count:
        # non-fatal error but .swap() won't work.
        print(
            f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
            + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
            + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
            + f"work properly until it is fixed."
        )
    return attention_module_tuples


def inject_attention_function(unet, context: Context):
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

    def attention_slice_wrangler(
        module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size
    ):
        # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()

        attention_slice = suggested_attention_slice

        if context.get_should_save_maps(module.identifier):
            # print(module.identifier, "saving suggested_attention_slice of shape",
            #      suggested_attention_slice.shape, "dim", dim, "offset", offset)
            slice_to_save = (
                attention_slice.to("cpu") if dim is not None else attention_slice
            )
            context.save_slice(
                module.identifier,
                slice_to_save,
                dim=dim,
                offset=offset,
                slice_size=slice_size,
            )
        elif context.get_should_apply_saved_maps(module.identifier):
            # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
            saved_attention_slice = context.get_slice(
                module.identifier, dim, offset, slice_size
            )

            # slice may have been offloaded to CPU
            saved_attention_slice = saved_attention_slice.to(
                suggested_attention_slice.device
            )

            if context.is_tokens_cross_attention(module.identifier):
                index_map = context.cross_attention_index_map
                remapped_saved_attention_slice = torch.index_select(
                    saved_attention_slice, -1, index_map
                )
                this_attention_slice = suggested_attention_slice

                mask = context.cross_attention_mask.to(
                    torch_dtype(suggested_attention_slice.device)
                )
                saved_mask = mask
                this_mask = 1 - mask
                attention_slice = (
                    remapped_saved_attention_slice * saved_mask
                    + this_attention_slice * this_mask
                )
            else:
                # just use everything
                attention_slice = saved_attention_slice

        return attention_slice

    cross_attention_modules = get_cross_attention_modules(
        unet, CrossAttentionType.TOKENS
    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for identifier, module in cross_attention_modules:
        module.identifier = identifier
        try:
@@ -408,56 +503,61 @@ def inject_attention_function(unet, context: Context):
                lambda module: context.get_slicing_strategy(identifier)
            )
        except AttributeError as e:
            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                print(
                    f"TODO: implement set_attention_slice_wrangler for {type(module)}"
                )  # TODO
            else:
                raise


def remove_attention_function(unet):
    cross_attention_modules = get_cross_attention_modules(
        unet, CrossAttentionType.TOKENS
    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for identifier, module in cross_attention_modules:
        try:
            # clear wrangler callback
            module.set_attention_slice_wrangler(None)
            module.set_slicing_strategy_getter(None)
        except AttributeError as e:
            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                print(
                    f"TODO: implement set_attention_slice_wrangler for {type(module)}"
                )
            else:
                raise


def is_attribute_error_about(error: AttributeError, attribute: str):
    if hasattr(error, "name"):  # Python 3.10
        return error.name == attribute
    else:  # Python 3.9
        return attribute in str(error)
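A quick standalone illustration of the version difference this helper papers over: AttributeError gained a structured .name field in Python 3.10, while on 3.9 only the message text is available.

try:
    object().does_not_exist
except AttributeError as e:
    # Python 3.10+: e.name == "does_not_exist"; Python 3.9: fall back to the message.
    print(getattr(e, "name", None) or str(e))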
def get_mem_free_total(device): def get_mem_free_total(device):
#only on cuda # only on cuda
if not torch.cuda.is_available(): if not torch.cuda.is_available():
return None return None
stats = torch.cuda.memory_stats(device) stats = torch.cuda.memory_stats(device)
mem_active = stats['active_bytes.all.current'] mem_active = stats["active_bytes.all.current"]
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(device) mem_free_cuda, _ = torch.cuda.mem_get_info(device)
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total return mem_free_total
class InvokeAIDiffusersCrossAttention(
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin): diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
InvokeAICrossAttentionMixin.__init__(self) InvokeAICrossAttentionMixin.__init__(self)
def _attention(self, query, key, value, attention_mask=None): def _attention(self, query, key, value, attention_mask=None):
#default_result = super()._attention(query, key, value) # default_result = super()._attention(query, key, value)
if attention_mask is not None: if attention_mask is not None:
print(f"{type(self).__name__} ignoring passed-in attention_mask") print(f"{type(self).__name__} ignoring passed-in attention_mask")
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value) attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
@ -466,9 +566,6 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
return hidden_states return hidden_states
## 🧨diffusers implementation follows ## 🧨diffusers implementation follows
@ -501,25 +598,30 @@ class CrossAttnProcessor:
return hidden_states return hidden_states
""" """
from dataclasses import field, dataclass from dataclasses import dataclass, field
import torch import torch
from diffusers.models.cross_attention import (
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor CrossAttention,
CrossAttnProcessor,
SlicedAttnProcessor,
)
@dataclass @dataclass
class SwapCrossAttnContext: class SwapCrossAttnContext:
modified_text_embeddings: torch.Tensor modified_text_embeddings: torch.Tensor
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
mask: torch.Tensor # in the target space of the index_map mask: torch.Tensor # in the target space of the index_map
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list) cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
def __int__(self, def __int__(
cac_types_to_do: [CrossAttentionType], self,
modified_text_embeddings: torch.Tensor, cac_types_to_do: [CrossAttentionType],
index_map: torch.Tensor, modified_text_embeddings: torch.Tensor,
mask: torch.Tensor): index_map: torch.Tensor,
mask: torch.Tensor,
):
self.cross_attention_types_to_do = cac_types_to_do self.cross_attention_types_to_do = cac_types_to_do
self.modified_text_embeddings = modified_text_embeddings self.modified_text_embeddings = modified_text_embeddings
self.index_map = index_map self.index_map = index_map
@ -529,9 +631,9 @@ class SwapCrossAttnContext:
return attn_type in self.cross_attention_types_to_do return attn_type in self.cross_attention_types_to_do
@classmethod @classmethod
def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \ def make_mask_and_index_map(
-> tuple[torch.Tensor, torch.Tensor]: cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
) -> tuple[torch.Tensor, torch.Tensor]:
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
mask = torch.zeros(max_length) mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long) indices_target = torch.arange(max_length, dtype=torch.long)
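The comment states the contract: a 1 in the mask means "keep the original prompt's attention at this position (looked up through the index map)", a 0 means "use the modified prompt's attention". The rest of the method is not shown in this hunk; a hedged sketch of how difflib-style ('equal', i1, i2, j1, j2) opcodes could fill the two tensors, which may differ from the actual implementation:

import torch

def sketch_mask_and_index_map(edit_opcodes, max_length):
    mask = torch.zeros(max_length)
    index_map = torch.arange(max_length, dtype=torch.long)
    for op, i1, i2, j1, j2 in edit_opcodes:
        if op == "equal":  # tokens j1:j2 of the edited prompt match i1:i2 of the original
            index_map[j1:j2] = torch.arange(i1, i2, dtype=torch.long)
            mask[j1:j2] = 1
    return mask, index_map

# "a cat" -> "a small cat": position 0 keeps original token 0, position 2 maps back to original token 1
mask, index_map = sketch_mask_and_index_map([("equal", 0, 1, 0, 1), ("equal", 1, 2, 2, 3)], max_length=4)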
@ -547,28 +649,42 @@ class SwapCrossAttnContext:
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# TODO: dynamically pick slice size based on memory conditions # TODO: dynamically pick slice size based on memory conditions
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, def __call__(
# kwargs self,
swap_cross_attn_context: SwapCrossAttnContext=None): attn: CrossAttention,
hidden_states,
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS encoder_hidden_states=None,
attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
):
attention_type = (
CrossAttentionType.SELF
if encoder_hidden_states is None
else CrossAttentionType.TOKENS
)
# if cross-attention control is not in play, just call through to the base implementation. # if cross-attention control is not in play, just call through to the base implementation.
if attention_type is CrossAttentionType.SELF or \ if (
swap_cross_attn_context is None or \ attention_type is CrossAttentionType.SELF
not swap_cross_attn_context.wants_cross_attention_control(attention_type): or swap_cross_attn_context is None
#print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) ):
#else: # print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(
attn, hidden_states, encoder_hidden_states, attention_mask
)
# else:
# print(f"SwapCrossAttnContext for {attention_type} active") # print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask( attention_mask = attn.prepare_attention_mask(
attention_mask=attention_mask, target_length=sequence_length, attention_mask=attention_mask,
batch_size=batch_size) target_length=sequence_length,
batch_size=batch_size,
)
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states)
dim = query.shape[-1] dim = query.shape[-1]
@ -589,41 +705,51 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# compute slices and prepare output tensor # compute slices and prepare output tensor
batch_size_attention = query.shape[0] batch_size_attention = query.shape[0]
hidden_states = torch.zeros( hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype (batch_size_attention, sequence_length, dim // attn.heads),
device=query.device,
dtype=query.dtype,
) )
# do slices # do slices
for i in range(max(1,hidden_states.shape[0] // self.slice_size)): for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
start_idx = i * self.slice_size start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx] query_slice = query[start_idx:end_idx]
original_key_slice = original_text_key[start_idx:end_idx] original_key_slice = original_text_key[start_idx:end_idx]
modified_key_slice = modified_text_key[start_idx:end_idx] modified_key_slice = modified_text_key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None attn_mask_slice = (
attention_mask[start_idx:end_idx]
if attention_mask is not None
else None
)
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice) original_attn_slice = attn.get_attention_scores(
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice) query_slice, original_key_slice, attn_mask_slice
)
modified_attn_slice = attn.get_attention_scores(
query_slice, modified_key_slice, attn_mask_slice
)
# because the prompt modifications may result in token sequences shifted forwards or backwards, # because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the # the original attention probabilities must be remapped to account for token index changes in the
# modified prompt # modified prompt
remapped_original_attn_slice = torch.index_select(original_attn_slice, -1, remapped_original_attn_slice = torch.index_select(
swap_cross_attn_context.index_map) original_attn_slice, -1, swap_cross_attn_context.index_map
)
# only some tokens are taken from the original attention probabilities; this is controlled by the mask. # only some tokens are taken from the original attention probabilities; this is controlled by the mask.
mask = swap_cross_attn_context.mask mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask inverse_mask = 1 - mask
attn_slice = \ attn_slice = (
remapped_original_attn_slice * mask + \ remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
modified_attn_slice * inverse_mask )
del remapped_original_attn_slice, modified_attn_slice del remapped_original_attn_slice, modified_attn_slice
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice hidden_states[start_idx:end_idx] = attn_slice
# done # done
hidden_states = attn.batch_to_head_dim(hidden_states) hidden_states = attn.batch_to_head_dim(hidden_states)
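In other words, each slice is a per-token mix of two attention maps: the original prompt's probabilities re-indexed into the modified prompt's token space, and the modified prompt's own probabilities, gated by the mask. A small self-contained illustration of the remap-and-blend step on dummy tensors (shapes and values are made up):

import torch

original_attn = torch.tensor([[[0.7, 0.2, 0.1]]])   # (heads, queries, original tokens)
modified_attn = torch.tensor([[[0.3, 0.3, 0.4]]])   # (heads, queries, modified tokens)
index_map = torch.tensor([0, 0, 2])                  # modified token -> original token
mask = torch.tensor([1.0, 0.0, 1.0])                 # 1 = keep original, 0 = use modified

remapped = torch.index_select(original_attn, -1, index_map)
blended = remapped * mask + modified_attn * (1 - mask)
# blended == [[[0.7, 0.3, 0.1]]]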
@ -636,7 +762,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
def __init__(self): def __init__(self):
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice super(SwapCrossAttnProcessor, self).__init__(
slice_size=int(1e9)
) # massive slice size = don't slice
View File
@ -2,17 +2,17 @@ import math
import PIL import PIL
import torch import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize
from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType from .cross_attention_control import CrossAttentionType, get_cross_attention_modules
class AttentionMapSaver(): class AttentionMapSaver:
def __init__(self, token_ids: range, latents_shape: torch.Size): def __init__(self, token_ids: range, latents_shape: torch.Size):
self.token_ids = token_ids self.token_ids = token_ids
self.latents_shape = latents_shape self.latents_shape = latents_shape
#self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]]) # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
self.collated_maps = {} self.collated_maps = {}
def clear_maps(self): def clear_maps(self):
@ -25,7 +25,7 @@ class AttentionMapSaver():
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match. :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
:return: None :return: None
""" """
key_and_size = f'{key}_{maps.shape[1]}' key_and_size = f"{key}_{maps.shape[1]}"
# extract desired tokens # extract desired tokens
maps = maps[:, :, self.token_ids] maps = maps[:, :, self.token_ids]
@ -35,12 +35,12 @@ class AttentionMapSaver():
# store # store
if key_and_size not in self.collated_maps: if key_and_size not in self.collated_maps:
self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu') self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
self.collated_maps[key_and_size] += maps.cpu() self.collated_maps[key_and_size] += maps.cpu()
def write_maps_to_disk(self, path: str): def write_maps_to_disk(self, path: str):
pil_image = self.get_stacked_maps_image() pil_image = self.get_stacked_maps_image()
pil_image.save(path, 'PNG') pil_image.save(path, "PNG")
def get_stacked_maps_image(self) -> PIL.Image: def get_stacked_maps_image(self) -> PIL.Image:
""" """
@ -57,39 +57,50 @@ class AttentionMapSaver():
merged = None merged = None
for key, maps in self.collated_maps.items(): for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens # maps has shape [(H*W), N] for N tokens
# but we want [N, H, W] # but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height)) this_scale_factor = math.sqrt(
maps.shape[0] / (latents_width * latents_height)
)
this_maps_height = int(float(latents_height) * this_scale_factor) this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor) this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling # and we need to do some dimension juggling
maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width]) maps = torch.reshape(
torch.swapdims(maps, 0, 1),
[num_tokens, this_maps_height, this_maps_width],
)
# scale to output size if necessary # scale to output size if necessary
if this_scale_factor != 1: if this_scale_factor != 1:
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC) maps = tv_resize(
maps, [latents_height, latents_width], InterpolationMode.BICUBIC
)
# normalize # normalize
maps_min = torch.min(maps) maps_min = torch.min(maps)
maps_range = torch.max(maps) - maps_min maps_range = torch.max(maps) - maps_min
#print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}") # print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
maps_normalized = (maps - maps_min) / maps_range maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp # expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05 maps_normalized_expanded = maps_normalized * 1.1 - 0.05
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1) maps_normalized_expanded_clamped = torch.clamp(
maps_normalized_expanded, 0, 1
)
# merge together, producing a vertical stack # merge together, producing a vertical stack
maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width]) maps_stacked = torch.reshape(
maps_normalized_expanded_clamped,
[num_tokens * latents_height, latents_width],
)
if merged is None: if merged is None:
merged = maps_stacked merged = maps_stacked
else: else:
# screen blend # screen blend
merged = 1 - (1 - maps_stacked)*(1 - merged) merged = 1 - (1 - maps_stacked) * (1 - merged)
if merged is None: if merged is None:
return None return None
merged_bytes = merged.mul(0xff).byte() merged_bytes = merged.mul(0xFF).byte()
return PIL.Image.fromarray(merged_bytes.numpy(), mode='L') return PIL.Image.fromarray(merged_bytes.numpy(), mode="L")
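The stacking loop composites maps with a screen blend, 1 - (1 - a)(1 - b), which brightens wherever either map has attention and never exceeds 1 for inputs in [0, 1]. A tiny numeric check of that identity:

import torch

a = torch.tensor([0.0, 0.5, 1.0])
b = torch.tensor([0.5, 0.5, 0.5])
screen = 1 - (1 - a) * (1 - b)
# screen == [0.5, 0.75, 1.0]; commutative and monotone, unlike a plain sum, which could exceed 1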
View File
@ -1,77 +1,82 @@
"""SAMPLING ONLY.""" """SAMPLING ONLY."""
import torch import torch
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler from ..diffusionmodules.util import noise_like
from ldm.modules.diffusionmodules.util import noise_like from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
class DDIMSampler(Sampler): class DDIMSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs): def __init__(self, model, schedule="linear", device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps,device) super().__init__(model, schedule, model.num_timesteps, device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, self.invokeai_diffuser = InvokeAIDiffuserComponent(
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) self.model,
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(
x, sigma, cond
),
)
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs) super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None) extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) all_timesteps_count = kwargs.get("all_timesteps_count", t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: if (
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) extra_conditioning_info is not None
and extra_conditioning_info.wants_cross_attention_control
):
self.invokeai_diffuser.override_cross_attention(
extra_conditioning_info, step_count=all_timesteps_count
)
else: else:
self.invokeai_diffuser.restore_default_cross_attention() self.invokeai_diffuser.restore_default_cross_attention()
# This is the central routine # This is the central routine
@torch.no_grad() @torch.no_grad()
def p_sample( def p_sample(
self, self,
x, x,
c, c,
t, t,
index, index,
repeat_noise=False, repeat_noise=False,
use_original_steps=False, use_original_steps=False,
quantize_denoised=False, quantize_denoised=False,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
step_count:int=1000, # total number of steps step_count: int = 1000, # total number of steps
**kwargs, **kwargs,
): ):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
if ( if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used # damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
# step_index counts in the opposite direction to index # step_index counts in the opposite direction to index
step_index = step_count-(index+1) step_index = step_count - (index + 1)
e_t = self.invokeai_diffuser.do_diffusion_step( e_t = self.invokeai_diffuser.do_diffusion_step(
x, t, x,
unconditional_conditioning, c, t,
unconditional_conditioning,
c,
unconditional_guidance_scale, unconditional_guidance_scale,
step_index=step_index step_index=step_index,
) )
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == 'eps' assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score( e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs self.model, e_t, x, t, c, **corrector_kwargs
) )
alphas = ( alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = ( alphas_prev = (
self.model.alphas_cumprod_prev self.model.alphas_cumprod_prev
if use_original_steps if use_original_steps
@ -101,11 +106,8 @@ class DDIMSampler(Sampler):
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = ( noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
)
if noise_dropout > 0.0: if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0, None return x_prev, pred_x0, None
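Assembled, this is the standard DDIM update: x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * epsilon. With ddim_eta = 0 the sigma term vanishes and the step is deterministic; larger eta moves it toward DDPM-style stochastic sampling, and temperature scales only the random component.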
View File
@ -8,12 +8,12 @@ from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
# at this threshold, the scheduler will stop using the Karras # at this threshold, the scheduler will stop using the Karras
# noise schedule and start using the model's schedule # noise schedule and start using the model's schedule
STEP_THRESHOLD = 30 STEP_THRESHOLD = 30
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
def cfg_apply_threshold(result, threshold=0.0, scale=0.7):
if threshold <= 0.0: if threshold <= 0.0:
return result return result
maxval = 0.0 + torch.max(result).cpu().numpy() maxval = 0.0 + torch.max(result).cpu().numpy()
@ -21,35 +21,43 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if maxval < threshold and minval > -threshold: if maxval < threshold and minval > -threshold:
return result return result
if maxval > threshold: if maxval > threshold:
maxval = min(max(1, scale*maxval), threshold) maxval = min(max(1, scale * maxval), threshold)
if minval < -threshold: if minval < -threshold:
minval = max(min(-1, scale*minval), -threshold) minval = max(min(-1, scale * minval), -threshold)
return torch.clamp(result, min=minval, max=maxval) return torch.clamp(result, min=minval, max=maxval)
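Clamping only kicks in when the latent actually exceeds the threshold on that side, and the new bound is the scaled extreme, kept at least at +/-1 and never beyond +/-threshold. A quick worked example with threshold=2.0 and the default scale=0.7:

import torch

result = torch.tensor([-1.0, 0.5, 4.0])
# maxval = 4.0 > 2.0, so the upper bound becomes min(max(1, 0.7 * 4.0), 2.0) = 2.0
# minval = -1.0 is not below -2.0, so the lower bound stays at the tensor's own minimum
clamped = torch.clamp(result, min=-1.0, max=2.0)   # -> [-1.0, 0.5, 2.0]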
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0, warmup = 0): def __init__(self, model, threshold=0, warmup=0):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.threshold = threshold self.threshold = threshold
self.warmup_max = warmup self.warmup_max = warmup
self.warmup = max(warmup / 10, 1) self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(model, self.invokeai_diffuser = InvokeAIDiffuserComponent(
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond)) model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(
x, sigma, cond=cond
),
)
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if (
extra_conditioning_info is not None
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: and extra_conditioning_info.wants_cross_attention_control
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc) ):
self.invokeai_diffuser.override_cross_attention(
extra_conditioning_info, step_count=t_enc
)
else: else:
self.invokeai_diffuser.restore_default_cross_attention() self.invokeai_diffuser.restore_default_cross_attention()
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) next_x = self.invokeai_diffuser.do_diffusion_step(
x, sigma, uncond, cond, cond_scale
)
if self.warmup < self.warmup_max: if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1 self.warmup += 1
@ -59,8 +67,9 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold thresh = self.threshold
return cfg_apply_threshold(next_x, thresh) return cfg_apply_threshold(next_x, thresh)
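While warming up, the effective threshold ramps linearly from 1 toward the configured value, so the earliest, noisiest steps are not clipped as hard. The ramp for threshold=3 and warmup_max=10 (values chosen only for illustration):

threshold, warmup_max = 3.0, 10
ramp = [max(1, 1 + (threshold - 1) * (w / warmup_max)) for w in range(warmup_max + 1)]
# ramp == [1.0, 1.2, 1.4, ..., 2.8, 3.0]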
class KSampler(Sampler): class KSampler(Sampler):
def __init__(self, model, schedule='lms', device=None, **kwargs): def __init__(self, model, schedule="lms", device=None, **kwargs):
denoiser = K.external.CompVisDenoiser(model) denoiser = K.external.CompVisDenoiser(model)
super().__init__( super().__init__(
denoiser, denoiser,
@ -68,45 +77,49 @@ class KSampler(Sampler):
steps=model.num_timesteps, steps=model.num_timesteps,
) )
self.sigmas = None self.sigmas = None
self.ds = None self.ds = None
self.s_in = None self.s_in = None
self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD) self.karras_max = kwargs.get("karras_max", STEP_THRESHOLD)
if self.karras_max is None: if self.karras_max is None:
self.karras_max = STEP_THRESHOLD self.karras_max = STEP_THRESHOLD
def make_schedule( def make_schedule(
self, self,
ddim_num_steps, ddim_num_steps,
ddim_discretize='uniform', ddim_discretize="uniform",
ddim_eta=0.0, ddim_eta=0.0,
verbose=False, verbose=False,
): ):
outer_model = self.model outer_model = self.model
self.model = outer_model.inner_model self.model = outer_model.inner_model
super().make_schedule( super().make_schedule(
ddim_num_steps, ddim_num_steps,
ddim_discretize='uniform', ddim_discretize="uniform",
ddim_eta=0.0, ddim_eta=0.0,
verbose=False, verbose=False,
) )
self.model = outer_model self.model = outer_model
self.ddim_num_steps = ddim_num_steps self.ddim_num_steps = ddim_num_steps
# we don't need both of these sigmas, but storing them here to make # we don't need both of these sigmas, but storing them here to make
# comparison easier later on # comparison easier later on
self.model_sigmas = self.model.get_sigmas(ddim_num_steps) self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
self.karras_sigmas = K.sampling.get_sigmas_karras( self.karras_sigmas = K.sampling.get_sigmas_karras(
n=ddim_num_steps, n=ddim_num_steps,
sigma_min=self.model.sigmas[0].item(), sigma_min=self.model.sigmas[0].item(),
sigma_max=self.model.sigmas[-1].item(), sigma_max=self.model.sigmas[-1].item(),
rho=7., rho=7.0,
device=self.device, device=self.device,
) )
if ddim_num_steps >= self.karras_max: if ddim_num_steps >= self.karras_max:
print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})') print(
f">> Ksampler using model noise schedule (steps >= {self.karras_max})"
)
self.sigmas = self.model_sigmas self.sigmas = self.model_sigmas
else: else:
print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})') print(
f">> Ksampler using karras noise schedule (steps < {self.karras_max})"
)
self.sigmas = self.karras_sigmas self.sigmas = self.karras_sigmas
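Below the threshold the Karras spacing is preferred because it concentrates sigmas near the low-noise end, which matters most when there are few steps; with many steps the model's own schedule is already dense enough. The same decision in isolation, using k-diffusion's get_sigmas_karras (the sigma_min/sigma_max values here are placeholders, not the model's):

import k_diffusion as K
import torch

def choose_sigmas(steps: int, model_sigmas: torch.Tensor, karras_max: int = 30) -> torch.Tensor:
    if steps >= karras_max:
        return model_sigmas                       # model noise schedule
    return K.sampling.get_sigmas_karras(          # Karras et al. spacing
        n=steps, sigma_min=0.03, sigma_max=14.6, rho=7.0, device=model_sigmas.device
    )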
# ALERT: We are completely overriding the sample() method in the base class, which # ALERT: We are completely overriding the sample() method in the base class, which
@ -116,31 +129,31 @@ class KSampler(Sampler):
@torch.no_grad() @torch.no_grad()
def decode( def decode(
self, self,
z_enc, z_enc,
cond, cond,
t_enc, t_enc,
img_callback=None, img_callback=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
use_original_steps=False, use_original_steps=False,
init_latent = None, init_latent=None,
mask = None, mask=None,
**kwargs **kwargs,
): ):
samples,_ = self.sample( samples, _ = self.sample(
batch_size = 1, batch_size=1,
S = t_enc, S=t_enc,
x_T = z_enc, x_T=z_enc,
shape = z_enc.shape[1:], shape=z_enc.shape[1:],
conditioning = cond, conditioning=cond,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning = unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
img_callback = img_callback, img_callback=img_callback,
x0 = init_latent, x0=init_latent,
mask = mask, mask=mask,
**kwargs **kwargs,
) )
return samples return samples
# this is a no-op, provided here for compatibility with ddim and plms samplers # this is a no-op, provided here for compatibility with ddim and plms samplers
@ -174,26 +187,26 @@ class KSampler(Sampler):
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None, extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
threshold = 0, threshold=0,
perlin = 0, perlin=0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs, **kwargs,
): ):
def route_callback(k_callback_values): def route_callback(k_callback_values):
if img_callback is not None: if img_callback is not None:
img_callback(k_callback_values['x'],k_callback_values['i']) img_callback(k_callback_values["x"], k_callback_values["i"])
# if make_schedule() hasn't been called, we do it now # if make_schedule() hasn't been called, we do it now
if self.sigmas is None: if self.sigmas is None:
self.make_schedule( self.make_schedule(
ddim_num_steps=S, ddim_num_steps=S,
ddim_eta = eta, ddim_eta=eta,
verbose = False, verbose=False,
) )
# sigmas are set up in make_schedule - we take the last steps items # sigmas are set up in make_schedule - we take the last steps items
sigmas = self.sigmas[-S-1:] sigmas = self.sigmas[-S - 1 :]
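S new sampling steps consume S+1 sigma values (one per transition boundary, ending at 0), which is why the slice keeps the last S+1 entries of the full schedule. A quick check of the indexing on a stand-in schedule:

sigmas = list(range(10, -1, -1))   # stand-in for an 11-entry noise schedule ending at 0
S = 4
tail = sigmas[-S - 1:]             # [4, 3, 2, 1, 0]: S transitions between S + 1 sigmas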
# x_T is variation noise. When an init image is provided (in x0) we need to add # x_T is variation noise. When an init image is provided (in x0) we need to add
# more randomness to the starting image. # more randomness to the starting image.
@ -205,27 +218,40 @@ class KSampler(Sampler):
else: else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) model_wrap_cfg = CFGDenoiser(
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info) self.model, threshold=threshold, warmup=max(0.8 * S, S - 10)
)
model_wrap_cfg.prepare_to_sample(
S, extra_conditioning_info=extra_conditioning_info
)
# setup attention maps saving. checks for None are because there are multiple code paths to get here. # setup attention maps saving. checks for None are because there are multiple code paths to get here.
attention_map_saver = None attention_map_saver = None
if attention_maps_callback is not None and extra_conditioning_info is not None: if attention_maps_callback is not None and extra_conditioning_info is not None:
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
attention_map_token_ids = range(1, eos_token_index) attention_map_token_ids = range(1, eos_token_index)
attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:]) attention_map_saver = AttentionMapSaver(
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) token_ids=attention_map_token_ids, latents_shape=x.shape[-2:]
)
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(
attention_map_saver
)
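The saved token range deliberately skips position 0 (BOS) and stops before the EOS position, so only the prompt's own tokens get attention maps. For example:

tokens_count_including_eos_bos = 7                     # assumed: BOS + 5 prompt tokens + EOS
eos_token_index = tokens_count_including_eos_bos - 1   # 6
attention_map_token_ids = range(1, eos_token_index)    # positions 1..5, the prompt tokens only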
extra_args = { extra_args = {
'cond': conditioning, "cond": conditioning,
'uncond': unconditional_conditioning, "uncond": unconditional_conditioning,
'cond_scale': unconditional_guidance_scale, "cond_scale": unconditional_guidance_scale,
} }
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)') print(
f">> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)"
)
sampling_result = ( sampling_result = (
K.sampling.__dict__[f'sample_{self.schedule}']( K.sampling.__dict__[f"sample_{self.schedule}"](
model_wrap_cfg, x, sigmas, extra_args=extra_args, model_wrap_cfg,
callback=route_callback x,
sigmas,
extra_args=extra_args,
callback=route_callback,
), ),
None, None,
) )
@ -237,25 +263,25 @@ class KSampler(Sampler):
# a workaround is found. # a workaround is found.
@torch.no_grad() @torch.no_grad()
def p_sample( def p_sample(
self, self,
img, img,
cond, cond,
ts, ts,
index, index,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
extra_conditioning_info=None, extra_conditioning_info=None,
**kwargs, **kwargs,
): ):
if self.model_wrap is None: if self.model_wrap is None:
self.model_wrap = CFGDenoiser(self.model) self.model_wrap = CFGDenoiser(self.model)
extra_args = { extra_args = {
'cond': cond, "cond": cond,
'uncond': unconditional_conditioning, "uncond": unconditional_conditioning,
'cond_scale': unconditional_guidance_scale, "cond_scale": unconditional_guidance_scale,
} }
if self.s_in is None: if self.s_in is None:
self.s_in = img.new_ones([img.shape[0]]) self.s_in = img.new_ones([img.shape[0]])
if self.ds is None: if self.ds is None:
self.ds = [] self.ds = []
@ -270,14 +296,16 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas: # so the actual formula for indexing into sigmas:
# sigma_index = (steps-index) # sigma_index = (steps-index)
s_index = t_enc - index - 1 s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info) self.model_wrap.prepare_to_sample(
img = K.sampling.__dict__[f'_{self.schedule}']( s_index, extra_conditioning_info=extra_conditioning_info
)
img = K.sampling.__dict__[f"_{self.schedule}"](
self.model_wrap, self.model_wrap,
img, img,
self.sigmas, self.sigmas,
s_index, s_index,
s_in = self.s_in, s_in=self.s_in,
ds = self.ds, ds=self.ds,
extra_args=extra_args, extra_args=extra_args,
) )
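The comment above is the whole trick: ddim/plms-style loops count index down from t_enc - 1 to 0, while the sigma table is consumed front to back, so s_index = t_enc - index - 1 is simply the forward step counter. Spelled out for t_enc = 4:

t_enc = 4
for i in range(t_enc):
    index = t_enc - i - 1          # how the ddim/plms-style loop counts down
    s_index = t_enc - index - 1    # position in the sigma schedule, i.e. just i
    # (i, index, s_index): (0, 3, 0), (1, 2, 1), (2, 1, 2), (3, 0, 3)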
@ -287,26 +315,25 @@ class KSampler(Sampler):
# we should not be multiplying by self.sigmas[0] if we # we should not be multiplying by self.sigmas[0] if we
# are at an intermediate step in img2img. See similar in # are at an intermediate step in img2img. See similar in
# sample() which does work. # sample() which does work.
def get_initial_image(self,x_T,shape,steps): def get_initial_image(self, x_T, shape, steps):
print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing') print(f"WARNING: ksampler.get_initial_image(): get_initial_image needs testing")
x = (torch.randn(shape, device=self.device) * self.sigmas[0]) x = torch.randn(shape, device=self.device) * self.sigmas[0]
if x_T is not None: if x_T is not None:
return x_T + x return x_T + x
else: else:
return x return x
def prepare_to_sample(self,t_enc,**kwargs): def prepare_to_sample(self, t_enc, **kwargs):
self.t_enc = t_enc self.t_enc = t_enc
self.model_wrap = None self.model_wrap = None
self.ds = None self.ds = None
self.s_in = None self.s_in = None
def q_sample(self,x0,ts): def q_sample(self, x0, ts):
''' """
Overrides parent method to return the q_sample of the inner model. Overrides parent method to return the q_sample of the inner model.
''' """
return self.model.inner_model.q_sample(x0,ts) return self.model.inner_model.q_sample(x0, ts)
def conditioning_key(self)->str: def conditioning_key(self) -> str:
return self.model.inner_model.model.conditioning_key return self.model.inner_model.model.conditioning_key
View File
@ -1,52 +1,58 @@
"""SAMPLING ONLY.""" """SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent import numpy as np
from ldm.models.diffusion.sampler import Sampler import torch
from ldm.modules.diffusionmodules.util import noise_like from tqdm import tqdm
from ...util import choose_torch_device
from ..diffusionmodules.util import noise_like
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
class PLMSSampler(Sampler): class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs): def __init__(self, model, schedule="linear", device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device) super().__init__(model, schedule, model.num_timesteps, device)
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs) super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None) extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) all_timesteps_count = kwargs.get("all_timesteps_count", t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: if (
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) extra_conditioning_info is not None
and extra_conditioning_info.wants_cross_attention_control
):
self.invokeai_diffuser.override_cross_attention(
extra_conditioning_info, step_count=all_timesteps_count
)
else: else:
self.invokeai_diffuser.restore_default_cross_attention() self.invokeai_diffuser.restore_default_cross_attention()
# this is the essential routine # this is the essential routine
@torch.no_grad() @torch.no_grad()
def p_sample( def p_sample(
self, self,
x, # image, called 'img' elsewhere x, # image, called 'img' elsewhere
c, # conditioning, called 'cond' elsewhere c, # conditioning, called 'cond' elsewhere
t, # timesteps, called 'ts' elsewhere t, # timesteps, called 'ts' elsewhere
index, index,
repeat_noise=False, repeat_noise=False,
use_original_steps=False, use_original_steps=False,
quantize_denoised=False, quantize_denoised=False,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
old_eps=[], old_eps=[],
t_next=None, t_next=None,
step_count:int=1000, # total number of steps step_count: int = 1000, # total number of steps
**kwargs, **kwargs,
): ):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
@ -59,24 +65,24 @@ class PLMSSampler(Sampler):
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
# step_index counts in the opposite direction to index # step_index counts in the opposite direction to index
step_index = step_count-(index+1) step_index = step_count - (index + 1)
e_t = self.invokeai_diffuser.do_diffusion_step(x, t, e_t = self.invokeai_diffuser.do_diffusion_step(
unconditional_conditioning, c, x,
unconditional_guidance_scale, t,
step_index=step_index) unconditional_conditioning,
c,
unconditional_guidance_scale,
step_index=step_index,
)
if score_corrector is not None: if score_corrector is not None:
assert self.model.parameterization == 'eps' assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score( e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs self.model, e_t, x, t, c, **corrector_kwargs
) )
return e_t return e_t
alphas = ( alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = ( alphas_prev = (
self.model.alphas_cumprod_prev self.model.alphas_cumprod_prev
if use_original_steps if use_original_steps
@ -96,9 +102,7 @@ class PLMSSampler(Sampler):
def get_x_prev_and_pred_x0(e_t, index): def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full( a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
(b, 1, 1, 1), alphas_prev[index], device=device
)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full( sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
@ -110,11 +114,7 @@ class PLMSSampler(Sampler):
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = ( noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
sigma_t
* noise_like(x.shape, device, repeat_noise)
* temperature
)
if noise_dropout > 0.0: if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
@ -135,10 +135,7 @@ class PLMSSampler(Sampler):
elif len(old_eps) >= 3: elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth) # 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = ( e_t_prime = (
55 * e_t 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24 ) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
View File
@ -1,31 +1,37 @@
''' """
ldm.models.diffusion.sampler invokeai.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc Base class for invokeai.models.diffusion.ddim, invokeai.models.diffusion.ksampler, etc
''' """
import torch
import numpy as np
from tqdm import tqdm
from functools import partial from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.diffusionmodules.util import ( import numpy as np
import torch
from tqdm import tqdm
from ...util import choose_torch_device
from ..diffusionmodules.util import (
extract_into_tensor,
make_ddim_sampling_parameters, make_ddim_sampling_parameters,
make_ddim_timesteps, make_ddim_timesteps,
noise_like, noise_like,
extract_into_tensor,
) )
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Sampler(object): class Sampler(object):
def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs): def __init__(self, model, schedule="linear", steps=None, device=None, **kwargs):
self.model = model self.model = model
self.ddim_timesteps = None self.ddim_timesteps = None
self.ddpm_num_timesteps = steps self.ddpm_num_timesteps = steps
self.schedule = schedule self.schedule = schedule
self.device = device or choose_torch_device() self.device = device or choose_torch_device()
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, self.invokeai_diffuser = InvokeAIDiffuserComponent(
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) self.model,
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(
x, sigma, cond
),
)
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
@ -36,11 +42,11 @@ class Sampler(object):
# This method was copied over from ddim.py and probably does stuff that is # This method was copied over from ddim.py and probably does stuff that is
# ddim-specific. Disentangle at some point. # ddim-specific. Disentangle at some point.
def make_schedule( def make_schedule(
self, self,
ddim_num_steps, ddim_num_steps,
ddim_discretize='uniform', ddim_discretize="uniform",
ddim_eta=0.0, ddim_eta=0.0,
verbose=False, verbose=False,
): ):
self.total_steps = ddim_num_steps self.total_steps = ddim_num_steps
self.ddim_timesteps = make_ddim_timesteps( self.ddim_timesteps = make_ddim_timesteps(
@ -52,38 +58,33 @@ class Sampler(object):
alphas_cumprod = self.model.alphas_cumprod alphas_cumprod = self.model.alphas_cumprod
assert ( assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep' ), "alphas have to be defined for each timestep"
to_torch = ( to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer( self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
) )
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer( self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
) )
self.register_buffer( self.register_buffer(
'sqrt_one_minus_alphas_cumprod', "sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
) )
self.register_buffer( self.register_buffer(
'log_one_minus_alphas_cumprod', "log_one_minus_alphas_cumprod",
to_torch(np.log(1.0 - alphas_cumprod.cpu())), to_torch(np.log(1.0 - alphas_cumprod.cpu())),
) )
self.register_buffer( self.register_buffer(
'sqrt_recip_alphas_cumprod', "sqrt_recip_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
) )
self.register_buffer( self.register_buffer(
'sqrt_recipm1_alphas_cumprod', "sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
) )
@ -98,19 +99,17 @@ class Sampler(object):
eta=ddim_eta, eta=ddim_eta,
verbose=verbose, verbose=verbose,
) )
self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer( self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) (1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod) / (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
) )
self.register_buffer( self.register_buffer(
'ddim_sigmas_for_original_num_steps', "ddim_sigmas_for_original_num_steps",
sigmas_for_original_sampling_steps, sigmas_for_original_sampling_steps,
) )
@ -129,20 +128,19 @@ class Sampler(object):
noise = torch.randn_like(x0) noise = torch.randn_like(x0)
return ( return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
* noise
) )
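This is the usual forward-diffusion (re-noising) identity, x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise; extract_into_tensor only gathers the per-timestep scalar and broadcasts it over the latent shape. A standalone version without the helper:

import torch

def noised_latent(x0: torch.Tensor, t: torch.Tensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # gather alpha_bar for each sample's timestep and broadcast over (B, C, H, W)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    noise = torch.randn_like(x0)
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise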
@torch.no_grad() @torch.no_grad()
def sample( def sample(
self, self,
S, # S is steps S, # S is steps
batch_size, batch_size,
shape, shape,
conditioning=None, conditioning=None,
callback=None, callback=None,
normals_sequence=None, normals_sequence=None,
img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change. img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
quantize_x0=False, quantize_x0=False,
eta=0.0, eta=0.0,
mask=None, mask=None,
@ -159,7 +157,6 @@ class Sampler(object):
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs, **kwargs,
): ):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]] ctmp = conditioning[list(conditioning.keys())[0]]
@ -167,17 +164,21 @@ class Sampler(object):
ctmp = ctmp[0] ctmp = ctmp[0]
cbs = ctmp.shape[0] cbs = ctmp.shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else: else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
# check to see if make_schedule() has run, and if not, run it # check to see if make_schedule() has run, and if not, run it
if self.ddim_timesteps is None: if self.ddim_timesteps is None:
self.make_schedule( self.make_schedule(
ddim_num_steps=S, ddim_num_steps=S,
ddim_eta = eta, ddim_eta=eta,
verbose = False, verbose=False,
) )
ts = self.get_timesteps(S) ts = self.get_timesteps(S)
@ -204,32 +205,32 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
steps=S, steps=S,
**kwargs **kwargs,
) )
return samples, intermediates return samples, intermediates
@torch.no_grad() @torch.no_grad()
def do_sampling( def do_sampling(
self, self,
cond, cond,
shape, shape,
timesteps=None, timesteps=None,
x_T=None, x_T=None,
ddim_use_original_steps=False, ddim_use_original_steps=False,
callback=None, callback=None,
quantize_denoised=False, quantize_denoised=False,
mask=None, mask=None,
x0=None, x0=None,
img_callback=None, img_callback=None,
log_every_t=100, log_every_t=100,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
steps=None, steps=None,
**kwargs **kwargs,
): ):
b = shape[0] b = shape[0]
time_range = ( time_range = (
@ -238,29 +239,24 @@ class Sampler(object):
else np.flip(timesteps) else np.flip(timesteps)
) )
total_steps=steps total_steps = steps
iterator = tqdm( iterator = tqdm(
time_range, time_range,
desc=f'{self.__class__.__name__}', desc=f"{self.__class__.__name__}",
total=total_steps, total=total_steps,
dynamic_ncols=True, dynamic_ncols=True,
) )
old_eps = [] old_eps = []
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs) self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=steps, **kwargs)
img = self.get_initial_image(x_T,shape,total_steps) img = self.get_initial_image(x_T, shape, total_steps)
# probably don't need this at all # probably don't need this at all
intermediates = {'x_inter': [img], 'pred_x0': [img]} intermediates = {"x_inter": [img], "pred_x0": [img]}
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full( ts = torch.full((b,), step, device=self.device, dtype=torch.long)
(b,),
step,
device=self.device,
dtype=torch.long
)
ts_next = torch.full( ts_next = torch.full(
(b,), (b,),
time_range[min(i + 1, len(time_range) - 1)], time_range[min(i + 1, len(time_range) - 1)],
@ -290,7 +286,7 @@ class Sampler(object):
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, old_eps=old_eps,
t_next=ts_next, t_next=ts_next,
step_count=steps step_count=steps,
) )
img, pred_x0, e_t = outs img, pred_x0, e_t = outs
@ -300,11 +296,11 @@ class Sampler(object):
if callback: if callback:
callback(i) callback(i)
if img_callback: if img_callback:
img_callback(img,i) img_callback(img, i)
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates["x_inter"].append(img)
intermediates['pred_x0'].append(pred_x0) intermediates["pred_x0"].append(pred_x0)
return img, intermediates return img, intermediates
@ -312,18 +308,18 @@ class Sampler(object):
# The variable names are changed in order to be confusing. # The variable names are changed in order to be confusing.
@torch.no_grad() @torch.no_grad()
def decode( def decode(
self, self,
x_latent, x_latent,
cond, cond,
t_start, t_start,
img_callback=None, img_callback=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
use_original_steps=False, use_original_steps=False,
init_latent = None, init_latent=None,
mask = None, mask=None,
all_timesteps_count = None, all_timesteps_count=None,
**kwargs **kwargs,
): ):
timesteps = ( timesteps = (
np.arange(self.ddpm_num_timesteps) np.arange(self.ddpm_num_timesteps)
@ -334,12 +330,16 @@ class Sampler(object):
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)') print(
f">> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)"
)
iterator = tqdm(time_range, desc='Decoding image', total=total_steps) iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent x_dec = x_latent
x0 = init_latent x0 = init_latent
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs) self.prepare_to_sample(
t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs
)
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
@ -370,81 +370,85 @@ class Sampler(object):
use_original_steps=use_original_steps, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
t_next = ts_next, t_next=ts_next,
step_count=len(self.ddim_timesteps) step_count=len(self.ddim_timesteps),
) )
x_dec, pred_x0, e_t = outs x_dec, pred_x0, e_t = outs
if img_callback: if img_callback:
img_callback(x_dec,i) img_callback(x_dec, i)
return x_dec return x_dec
def get_initial_image(self,x_T,shape,timesteps=None): def get_initial_image(self, x_T, shape, timesteps=None):
if x_T is None: if x_T is None:
return torch.randn(shape, device=self.device) return torch.randn(shape, device=self.device)
else: else:
return x_T return x_T
def p_sample( def p_sample(
self, self,
img, img,
cond, cond,
ts, ts,
index, index,
repeat_noise=False, repeat_noise=False,
use_original_steps=False, use_original_steps=False,
quantize_denoised=False, quantize_denoised=False,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
old_eps=None, old_eps=None,
t_next=None, t_next=None,
steps=None, steps=None,
): ):
raise NotImplementedError("p_sample() must be implemented in a descendent class") raise NotImplementedError(
"p_sample() must be implemented in a descendent class"
)
def prepare_to_sample(self,t_enc,**kwargs): def prepare_to_sample(self, t_enc, **kwargs):
''' """
Hook that will be called right before the very first invocation of p_sample() Hook that will be called right before the very first invocation of p_sample()
to allow subclass to do additional initialization. t_enc corresponds to the actual to allow subclass to do additional initialization. t_enc corresponds to the actual
number of steps that will be run, and may be less than total steps if img2img is number of steps that will be run, and may be less than total steps if img2img is
active. active.
''' """
pass pass
def get_timesteps(self,ddim_steps): def get_timesteps(self, ddim_steps):
''' """
The ddim and plms samplers work on timesteps. This method is called after The ddim and plms samplers work on timesteps. This method is called after
ddim_timesteps are created in make_schedule(), and selects the portion of ddim_timesteps are created in make_schedule(), and selects the portion of
timesteps that will be used for sampling, depending on the t_enc in img2img. timesteps that will be used for sampling, depending on the t_enc in img2img.
''' """
return self.ddim_timesteps[:ddim_steps] return self.ddim_timesteps[:ddim_steps]
def q_sample(self,x0,ts): def q_sample(self, x0, ts):
''' """
Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to
return self.model.inner_model.q_sample(x0,ts) return self.model.inner_model.q_sample(x0,ts)
''' """
return self.model.q_sample(x0,ts) return self.model.q_sample(x0, ts)
def conditioning_key(self)->str: def conditioning_key(self) -> str:
return self.model.model.conditioning_key return self.model.model.conditioning_key
def uses_inpainting_model(self)->bool: def uses_inpainting_model(self) -> bool:
return self.conditioning_key() in ('hybrid','concat') return self.conditioning_key() in ("hybrid", "concat")
def adjust_settings(self,**kwargs): def adjust_settings(self, **kwargs):
''' """
This is a catch-all method for adjusting any instance variables This is a catch-all method for adjusting any instance variables
after the sampler is instantiated. No type-checking performed after the sampler is instantiated. No type-checking performed
here, so use with care! here, so use with care!
''' """
for k in kwargs.keys(): for k in kwargs.keys():
try: try:
setattr(self,k,kwargs[k]) setattr(self, k, kwargs[k])
except AttributeError: except AttributeError:
print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}') print(
f"** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}"
)
View File
@@ -1,25 +1,36 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias

from invokeai.backend.globals import Globals

from .cross_attention_control import (
    Arguments,
    Context,
    CrossAttentionType,
    SwapCrossAttnContext,
    get_cross_attention_modules,
    override_cross_attention,
    restore_default_cross_attention,
)
from .cross_attention_map_saving import AttentionMapSaver

ModelForwardCallback: TypeAlias = Union[
    # x, t, conditioning, Optional[cross-attention kwargs]
    Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
        torch.Tensor,
    ],
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]


@dataclass(frozen=True)
class PostprocessingSettings:
    threshold: float

@@ -29,20 +40,20 @@ class PostprocessingSettings:


class InvokeAIDiffuserComponent:
    """
    The aim of this component is to provide a single place for code that can be applied identically to
    all InvokeAI diffusion procedures.

    At the moment it includes the following features:
    * Cross attention control ("prompt2prompt")
    * Hybrid conditioning (used for inpainting)
    """

    debug_thresholding = False
    sequential_guidance = False

    @dataclass
    class ExtraConditioningInfo:
        tokens_count_including_eos_bos: int
        cross_attention_control_args: Optional[Arguments] = None

@@ -50,10 +61,12 @@ class InvokeAIDiffuserComponent:
        def wants_cross_attention_control(self):
            return self.cross_attention_control_args is not None

    def __init__(
        self,
        model,
        model_forward_callback: ModelForwardCallback,
        is_running_diffusers: bool = False,
    ):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -66,23 +79,29 @@ class InvokeAIDiffuserComponent:
        self.sequential_guidance = Globals.sequential_guidance

    @contextmanager
    def custom_attention_context(
        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
    ):
        do_swap = (
            extra_conditioning_info is not None
            and extra_conditioning_info.wants_cross_attention_control
        )
        old_attn_processor = None
        if do_swap:
            old_attn_processor = self.override_cross_attention(
                extra_conditioning_info, step_count=step_count
            )
        try:
            yield None
        finally:
            if old_attn_processor is not None:
                self.restore_default_cross_attention(old_attn_processor)
            # TODO resuscitate attention map saving
            # self.remove_attention_map_saving()

    def override_cross_attention(
        self, conditioning: ExtraConditioningInfo, step_count: int
    ) -> Dict[str, AttnProcessor]:
        """
        setup cross attention .swap control. for diffusers this replaces the attention processor, so
        the previous attention processor is returned so that the caller can restore it later.
@@ -90,18 +109,24 @@ class InvokeAIDiffuserComponent:
        self.conditioning = conditioning
        self.cross_attention_control_context = Context(
            arguments=self.conditioning.cross_attention_control_args,
            step_count=step_count,
        )
        return override_cross_attention(
            self.model,
            self.cross_attention_control_context,
            is_running_diffusers=self.is_running_diffusers,
        )

    def restore_default_cross_attention(
        self, restore_attention_processor: Optional["AttnProcessor"] = None
    ):
        self.conditioning = None
        self.cross_attention_control_context = None
        restore_default_cross_attention(
            self.model,
            is_running_diffusers=self.is_running_diffusers,
            restore_attention_processor=restore_attention_processor,
        )

    def setup_attention_map_saving(self, saver: AttentionMapSaver):
        def callback(slice, dim, offset, slice_size, key):
@@ -110,26 +135,40 @@ class InvokeAIDiffuserComponent:
                return
            saver.add_attention_maps(slice, key)

        tokens_cross_attention_modules = get_cross_attention_modules(
            self.model, CrossAttentionType.TOKENS
        )
        for identifier, module in tokens_cross_attention_modules:
            key = (
                "down"
                if identifier.startswith("down")
                else "up"
                if identifier.startswith("up")
                else "mid"
            )
            module.set_attention_slice_calculated_callback(
                lambda slice, dim, offset, slice_size, key=key: callback(
                    slice, dim, offset, slice_size, key
                )
            )

    def remove_attention_map_saving(self):
        tokens_cross_attention_modules = get_cross_attention_modules(
            self.model, CrossAttentionType.TOKENS
        )
        for _, module in tokens_cross_attention_modules:
            module.set_attention_slice_calculated_callback(None)

    def do_diffusion_step(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        unconditioning: Union[torch.Tensor, dict],
        conditioning: Union[torch.Tensor, dict],
        unconditional_guidance_scale: float,
        step_index: Optional[int] = None,
        total_step_count: Optional[int] = None,
    ):
        """
        :param x: current latents
        :param sigma: aka t, passed to the internal model to control how much denoising will occur
@@ -140,33 +179,55 @@ class InvokeAIDiffuserComponent:
        :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
        """

        cross_attention_control_types_to_do = []
        context: Context = self.cross_attention_control_context
        if self.cross_attention_control_context is not None:
            percent_through = self.calculate_percent_through(
                sigma, step_index, total_step_count
            )
            cross_attention_control_types_to_do = (
                context.get_active_cross_attention_control_types_for_step(
                    percent_through
                )
            )

        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
        wants_hybrid_conditioning = isinstance(conditioning, dict)

        if wants_hybrid_conditioning:
            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
                x, sigma, unconditioning, conditioning
            )
        elif wants_cross_attention_control:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_cross_attention_controlled_conditioning(
                x,
                sigma,
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
            )
        elif self.sequential_guidance:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
                x, sigma, unconditioning, conditioning
            )
        else:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning(
                x, sigma, unconditioning, conditioning
            )

        combined_next_x = self._combine(
            unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
        )

        return combined_next_x

@@ -176,24 +237,33 @@ class InvokeAIDiffuserComponent:
        latents: torch.Tensor,
        sigma,
        step_index,
        total_step_count,
    ) -> torch.Tensor:
        if postprocessing_settings is not None:
            percent_through = self.calculate_percent_through(
                sigma, step_index, total_step_count
            )
            latents = self.apply_threshold(
                postprocessing_settings, latents, percent_through
            )
            latents = self.apply_symmetry(
                postprocessing_settings, latents, percent_through
            )
        return latents

    def calculate_percent_through(self, sigma, step_index, total_step_count):
        if step_index is not None and total_step_count is not None:
            # 🧨diffusers codepath
            percent_through = (
                step_index / total_step_count
            )  # will never reach 1.0 - this is deliberate
        else:
            # legacy compvis codepath
            # TODO remove when compvis codepath support is dropped
            if step_index is None and sigma is None:
                raise ValueError(
                    "Either step_index or sigma is required when doing cross attention control, but both are None."
                )
            percent_through = self.estimate_percent_through(step_index, sigma)
        return percent_through

@@ -204,24 +274,30 @@ class InvokeAIDiffuserComponent:
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = torch.cat([unconditioning, conditioning])
        both_results = self.model_forward_callback(
            x_twice, sigma_twice, both_conditionings
        )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
        if conditioned_next_x.device.type == "mps":
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

    def _apply_standard_conditioning_sequentially(
        self,
        x: torch.Tensor,
        sigma,
        unconditioning: torch.Tensor,
        conditioning: torch.Tensor,
    ):
        # low-memory sequential path
        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
        if conditioned_next_x.device.type == "mps":
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
        assert isinstance(conditioning, dict)
        assert isinstance(unconditioning, dict)
@@ -236,48 +312,80 @@ class InvokeAIDiffuserComponent:
                ]
            else:
                both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
            x_twice, sigma_twice, both_conditionings
        ).chunk(2)
        return unconditioned_next_x, conditioned_next_x

    def _apply_cross_attention_controlled_conditioning(
        self,
        x: torch.Tensor,
        sigma,
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
    ):
        if self.is_running_diffusers:
            return self._apply_cross_attention_controlled_conditioning__diffusers(
                x,
                sigma,
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
            )
        else:
            return self._apply_cross_attention_controlled_conditioning__compvis(
                x,
                sigma,
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
            )

    def _apply_cross_attention_controlled_conditioning__diffusers(
        self,
        x: torch.Tensor,
        sigma,
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
    ):
        context: Context = self.cross_attention_control_context

        cross_attn_processor_context = SwapCrossAttnContext(
            modified_text_embeddings=context.arguments.edited_conditioning,
            index_map=context.cross_attention_index_map,
            mask=context.cross_attention_mask,
            cross_attention_types_to_do=[],
        )
        # no cross attention for unconditioning (negative prompt)
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            unconditioning,
            {"swap_cross_attn_context": cross_attn_processor_context},
        )

        # do requested cross attention types for conditioning (positive prompt)
        cross_attn_processor_context.cross_attention_types_to_do = (
            cross_attention_control_types_to_do
        )
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning,
            {"swap_cross_attn_context": cross_attn_processor_context},
        )
        return unconditioned_next_x, conditioned_next_x

    def _apply_cross_attention_controlled_conditioning__compvis(
        self,
        x: torch.Tensor,
        sigma,
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
    ):
        # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
        # slower non-batched path (20% slower on mac MPS)
        # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@@ -287,24 +395,28 @@ class InvokeAIDiffuserComponent:
        # representing batched uncond + cond, but then when it comes to applying the saved attention, the
        # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
        # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
        context: Context = self.cross_attention_control_context

        try:
            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

            # process x using the original prompt, saving the attention maps
            # print("saving attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_save_attention_maps(ca_type)
            _ = self.model_forward_callback(x, sigma, conditioning)
            context.clear_requests(cleanup=False)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
            # print("applying saved attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_apply_saved_attention_maps(ca_type)
            edited_conditioning = (
                self.conditioning.cross_attention_control_args.edited_conditioning
            )
            conditioned_next_x = self.model_forward_callback(
                x, sigma, edited_conditioning
            )
            context.clear_requests(cleanup=True)

        except:
@@ -323,17 +435,21 @@ class InvokeAIDiffuserComponent:
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float,
    ) -> torch.Tensor:
        if (
            postprocessing_settings.threshold is None
            or postprocessing_settings.threshold == 0.0
        ):
            return latents

        threshold = postprocessing_settings.threshold
        warmup = postprocessing_settings.warmup

        if percent_through < warmup:
            current_threshold = threshold + threshold * 5 * (
                1 - (percent_through / warmup)
            )
        else:
            current_threshold = threshold

@@ -347,10 +463,14 @@ class InvokeAIDiffuserComponent:

        if self.debug_thresholding:
            std, mean = [i.item() for i in torch.std_mean(latents)]
            outside = torch.count_nonzero(
                (latents < -current_threshold) | (latents > current_threshold)
            )
            print(
                f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
                f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
                f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
            )

        if maxval < current_threshold and minval > -current_threshold:
            return latents

@@ -363,17 +483,23 @@ class InvokeAIDiffuserComponent:
            latents = torch.clone(latents)
            maxval = np.clip(maxval * scale, 1, current_threshold)
            num_altered += torch.count_nonzero(latents > maxval)
            latents[latents > maxval] = (
                torch.rand_like(latents[latents > maxval]) * maxval
            )

        if minval < -current_threshold:
            latents = torch.clone(latents)
            minval = np.clip(minval * scale, -current_threshold, -1)
            num_altered += torch.count_nonzero(latents < minval)
            latents[latents < minval] = (
                torch.rand_like(latents[latents < minval]) * minval
            )

        if self.debug_thresholding:
            print(
                f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
                f" | {num_altered / latents.numel() * 100:.2f}% values altered"
            )

        return latents

@@ -381,9 +507,8 @@ class InvokeAIDiffuserComponent:
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float,
    ) -> torch.Tensor:
        # Reset our last percent through if this is our first step.
        if percent_through == 0.0:
            self.last_percent_through = 0.0
@@ -393,36 +518,52 @@ class InvokeAIDiffuserComponent:
        # Check for out of bounds
        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
        if h_symmetry_time_pct is not None and (
            h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
        ):
            h_symmetry_time_pct = None

        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
        if v_symmetry_time_pct is not None and (
            v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
        ):
            v_symmetry_time_pct = None

        dev = latents.device.type
        latents.to(device="cpu")

        if (
            h_symmetry_time_pct != None
            and self.last_percent_through < h_symmetry_time_pct
            and percent_through >= h_symmetry_time_pct
        ):
            # Horizontal symmetry occurs on the 3rd dimension of the latent
            width = latents.shape[3]
            x_flipped = torch.flip(latents, dims=[3])
            latents = torch.cat(
                [
                    latents[:, :, :, 0 : int(width / 2)],
                    x_flipped[:, :, :, int(width / 2) : int(width)],
                ],
                dim=3,
            )

        if (
            v_symmetry_time_pct != None
            and self.last_percent_through < v_symmetry_time_pct
            and percent_through >= v_symmetry_time_pct
        ):
            # Vertical symmetry occurs on the 2nd dimension of the latent
            height = latents.shape[2]
            y_flipped = torch.flip(latents, dims=[2])
            latents = torch.cat(
                [
                    latents[:, :, 0 : int(height / 2)],
                    y_flipped[:, :, int(height / 2) : int(height)],
                ],
                dim=2,
            )

        self.last_percent_through = percent_through
        return latents.to(device=dev)

@@ -430,7 +571,9 @@ class InvokeAIDiffuserComponent:
    def estimate_percent_through(self, step_index, sigma):
        if step_index is not None and self.cross_attention_control_context is not None:
            # percent_through will never reach 1.0 (but this is intended)
            return float(step_index) / float(
                self.cross_attention_control_context.step_count
            )
        # find the best possible index of the current sigma in the sigma sequence
        smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
        sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@@ -439,33 +582,38 @@ class InvokeAIDiffuserComponent:
        return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
        # print('estimated percent_through', percent_through, 'from sigma', sigma.item())

    # todo: make this work
    @classmethod
    def apply_conjunction(
        cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
    ):
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)  # aka sigmas

        deltas = None
        uncond_latents = None
        weighted_cond_list = (
            c_or_weighted_c_list
            if type(c_or_weighted_c_list) is list
            else [(c_or_weighted_c_list, 1)]
        )

        # below is fugly omg
        num_actual_conditionings = len(c_or_weighted_c_list)
        conditionings = [uc] + [c for c, weight in weighted_cond_list]
        weights = [1] + [weight for c, weight in weighted_cond_list]
        chunk_count = ceil(len(conditionings) / 2)
        deltas = None
        for chunk_index in range(chunk_count):
            offset = chunk_index * 2
            chunk_size = min(2, len(conditionings) - offset)

            if chunk_size == 1:
                c_in = conditionings[offset]
                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                latents_b = None
            else:
                c_in = torch.cat(conditionings[offset : offset + 2])
                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
@@ -478,11 +626,15 @@ class InvokeAIDiffuserComponent:
                deltas = torch.cat((deltas, latents_b - uncond_latents))

        # merge the weighted deltas together into a single merged delta
        per_delta_weights = torch.tensor(
            weights[1:], dtype=deltas.dtype, device=deltas.device
        )
        normalize = False
        if normalize:
            per_delta_weights /= torch.sum(per_delta_weights)
        reshaped_weights = per_delta_weights.reshape(
            per_delta_weights.shape + (1, 1, 1)
        )
        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
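Editor's note: `do_diffusion_step()` documents its result as "the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning". The `_combine()` helper itself sits in a collapsed hunk, so the sketch below shows the standard classifier-free guidance arithmetic that wording describes, not the repository's verbatim implementation.

```python
# Illustrative sketch (not from the diff): standard CFG combination of the two predictions.
import torch

def cfg_combine(unconditioned: torch.Tensor, conditioned: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # move from the unconditioned prediction toward the conditioned one,
    # scaled by the guidance strength
    return unconditioned + guidance_scale * (conditioned - unconditioned)

uncond = torch.randn(1, 4, 64, 64)   # hypothetical latent-shaped predictions
cond = torch.randn(1, 4, 64, 64)
combined = cfg_combine(uncond, cond, guidance_scale=7.5)
print(combined.shape)  # torch.Size([1, 4, 64, 64])
```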

File diff suppressed because it is too large Load Diff

View File

@@ -1,23 +1,22 @@
import math
from abc import abstractmethod
from functools import partial
from typing import Iterable

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.util import (
    avg_pool_nd,
    checkpoint,
    conv_nd,
    linear,
    normalization,
    timestep_embedding,
    zero_module,
)

# dummy replace
@@ -100,9 +99,7 @@ class Upsample(nn.Module):
    upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
@@ -117,10 +114,10 @@ class Upsample(nn.Module):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
@@ -151,9 +148,7 @@ class Downsample(nn.Module):
    downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
@@ -237,9 +232,7 @@ class ResBlock(TimestepBlock):
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
@@ -247,9 +240,7 @@ class ResBlock(TimestepBlock):
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

@@ -260,9 +251,7 @@ class ResBlock(TimestepBlock):
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
@@ -320,7 +309,7 @@ class AttentionBlock(nn.Module):
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
@@ -337,7 +326,7 @@ class AttentionBlock(nn.Module):
    def forward(self, x):
        return checkpoint(
            self._forward, (x,), self.parameters(), True
        )  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
@@ -387,15 +376,13 @@ class QKVAttentionLegacy(nn.Module):
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
@@ -424,14 +411,12 @@ class QKVAttention(nn.Module):
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
@@ -500,12 +485,12 @@ class UNetModel(nn.Module):
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if type(context_dim) == ListConfig:
@@ -517,12 +502,12 @@ class UNetModel(nn.Module):
        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
@@ -641,11 +626,7 @@ class UNetModel(nn.Module):
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
@@ -741,9 +722,7 @@ class UNetModel(nn.Module):
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
@@ -752,9 +731,7 @@ class UNetModel(nn.Module):
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
@@ -790,11 +767,9 @@ class UNetModel(nn.Module):
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
@@ -842,7 +817,7 @@ class EncoderUNetModel(nn.Module):
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs,
    ):
@@ -962,7 +937,7 @@ class EncoderUNetModel(nn.Module):
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
@@ -970,7 +945,7 @@ class EncoderUNetModel(nn.Module):
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
@@ -979,13 +954,13 @@ class EncoderUNetModel(nn.Module):
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
@@ -993,7 +968,7 @@ class EncoderUNetModel(nn.Module):
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
@@ -1016,18 +991,16 @@ class EncoderUNetModel(nn.Module):
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = th.cat(results, axis=-1)
        return self.out(h)
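Editor's note: a standalone sketch of the attention arithmetic used by `QKVAttentionLegacy` above: q and k are each scaled by `ch ** -0.25` (equivalent to dividing the logits by `sqrt(ch)`, but more stable in fp16), the softmax runs over the source axis, and the weights are then applied to the values. The shapes are purely illustrative.

```python
# Illustrative sketch (not from the diff) of the two einsums in QKVAttentionLegacy.forward.
import math
import torch

bs_heads, ch, length = 2 * 8, 40, 64          # (batch * heads, channels per head, tokens)
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)
v = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)    # attention logits
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = torch.einsum("bts,bcs->bct", weight, v)                  # weighted values
print(out.shape)  # torch.Size([16, 40, 64])
```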


@@ -8,20 +8,21 @@
# thanks!


import math
import os

import numpy as np
import torch
import torch.nn as nn
from einops import repeat

from ...util.util import instantiate_from_config


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5,
@@ -32,10 +33,9 @@ def make_beta_schedule(
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
@@ -43,15 +43,13 @@ def make_beta_schedule(
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
@@ -62,19 +60,14 @@ def make_beta_schedule(
def make_ddim_timesteps(
    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
    if ddim_discr_method == "uniform":
        c = num_ddpm_timesteps // num_ddim_timesteps
        if c < 1:
            c = 1
        ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int)
    elif ddim_discr_method == "quad":
        ddim_timesteps = (
            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
        ).astype(int)
    else:
        raise NotImplementedError(
@@ -87,18 +80,14 @@ def make_ddim_timesteps(
    # steps_out = ddim_timesteps

    if verbose:
        print(f"Selected timesteps for ddim sampler: {steps_out}")
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt(
@@ -106,11 +95,11 @@ def make_ddim_sampling_parameters(
    )
    if verbose:
        print(
            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
        )
        print(
            f"For the chosen value of eta, which is {eta}, "
            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
        )
    return sigmas, alphas, alphas_prev

@@ -150,9 +139,7 @@ def checkpoint(func, inputs, params, flag):
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if False:  # disabled checkpointing to allow requires_grad = False for main model
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
@@ -172,9 +159,7 @@ class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
@@ -216,7 +201,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding

@@ -269,7 +254,7 @@ def conv_nd(dims, *args, **kwargs):
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
@@ -289,21 +274,19 @@ def avg_pool_nd(dims, *args, **kwargs):
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):
    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}


def noise_like(shape, device, repeat=False):
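Editor's note: a worked numeric sketch of what `make_ddim_sampling_parameters()` computes above. The body of its `np.sqrt(...)` call sits in a collapsed hunk, so the sigma expression below follows eq. (16) of the DDIM paper cited in the code comment (https://arxiv.org/abs/2010.02502) rather than being a verbatim copy; the toy beta schedule matches the "linear" branch shown earlier.

```python
# Illustrative sketch (not from the diff): DDIM sigmas from a linear beta schedule.
import numpy as np
import torch

# toy 1000-step linear beta schedule, matching the "linear" branch of make_beta_schedule
betas = torch.linspace(1e-4**0.5, 2e-2**0.5, 1000, dtype=torch.float64) ** 2
alphacums = torch.cumprod(1.0 - betas, dim=0).numpy()

ddim_timesteps = (np.arange(0, 50) * (1000 // 50)).astype(int) + 1  # "uniform" discretization
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

eta = 0.0  # deterministic DDIM
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
print(sigmas[:3])  # all zeros when eta == 0
```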


@@ -1,5 +1,5 @@
import numpy as np
import torch


class AbstractDistribution:
@@ -64,9 +64,7 @@ class DiagonalGaussianDistribution(object):
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

@@ -86,7 +84,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
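Editor's note: a quick check that the expression inside `DiagonalGaussianDistribution` above, 0.5 * (log 2π + logvar + (x − μ)² / σ²), is the per-element Gaussian negative log-likelihood; the comparison against torch.distributions is my own illustration, not part of the diff.

```python
# Illustrative check (not from the diff) of the Gaussian NLL term used above.
import numpy as np
import torch

mean = torch.randn(4)
logvar = torch.randn(4)
sample = torch.randn(4)
var, std = torch.exp(logvar), torch.exp(0.5 * logvar)

nll_manual = 0.5 * (np.log(2.0 * np.pi) + logvar + (sample - mean) ** 2 / var)
nll_torch = -torch.distributions.Normal(mean, std).log_prob(sample)
print(torch.allclose(nll_manual, nll_torch, atol=1e-6))  # True
```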


@@ -6,12 +6,12 @@ class LitEma(nn.Module):
     def __init__(self, model, decay=0.9999, use_num_upates=True):
         super().__init__()
         if decay < 0.0 or decay > 1.0:
-            raise ValueError('Decay must be between 0 and 1')
+            raise ValueError("Decay must be between 0 and 1")
         self.m_name2s_name = {}
-        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
         self.register_buffer(
-            'num_updates',
+            "num_updates",
             torch.tensor(0, dtype=torch.int)
             if use_num_upates
             else torch.tensor(-1, dtype=torch.int),
@@ -20,7 +20,7 @@ class LitEma(nn.Module):
         for name, p in model.named_parameters():
             if p.requires_grad:
                 # remove as '.'-character is not allowed in buffers
-                s_name = name.replace('.', '')
+                s_name = name.replace(".", "")
                 self.m_name2s_name.update({name: s_name})
                 self.register_buffer(s_name, p.clone().detach().data)
@@ -31,9 +31,7 @@ class LitEma(nn.Module):
         if self.num_updates >= 0:
             self.num_updates += 1
-            decay = min(
-                self.decay, (1 + self.num_updates) / (10 + self.num_updates)
-            )
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
         one_minus_decay = 1.0 - decay
@@ -44,9 +42,7 @@ class LitEma(nn.Module):
             for key in m_param:
                 if m_param[key].requires_grad:
                     sname = self.m_name2s_name[key]
-                    shadow_params[sname] = shadow_params[sname].type_as(
-                        m_param[key]
-                    )
+                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                     shadow_params[sname].sub_(
                         one_minus_decay * (shadow_params[sname] - m_param[key])
                     )
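For reference, the update in `LitEma.forward` is a standard exponential moving average with a warm-up on the decay so early updates track the parameters closely. A small self-contained sketch with hypothetical toy values (not taken from the diff):

```python
import torch

decay_cfg, num_updates = 0.9999, 5
decay = min(decay_cfg, (1 + num_updates) / (10 + num_updates))  # warm-up: 0.4 here
one_minus_decay = 1.0 - decay

shadow = torch.tensor([1.0, 2.0])   # EMA buffer
param = torch.tensor([2.0, 0.0])    # current model parameter
shadow.sub_(one_minus_decay * (shadow - param))  # shadow += (param - shadow) * (1 - decay)
print(shadow)  # tensor([1.6000, 0.8000])
```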
@@ -58,9 +54,7 @@ class LitEma(nn.Module):
         shadow_params = dict(self.named_buffers())
         for key in m_param:
             if m_param[key].requires_grad:
-                m_param[key].data.copy_(
-                    shadow_params[self.m_name2s_name[key]].data
-                )
+                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
             else:
                 assert not key in self.m_name2s_name


@@ -7,14 +7,14 @@ import kornia
 import torch
 import torch.nn as nn
 from einops import repeat
-from transformers import CLIPTokenizer, CLIPTextModel
+from transformers import CLIPTextModel, CLIPTokenizer

-from ldm.invoke.devices import choose_torch_device
-from ldm.invoke.globals import global_cache_dir
-from ldm.modules.x_transformer import (
+from ...util import choose_torch_device
+from ..globals import global_cache_dir
+from ..x_transformer import (  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
     Encoder,
     TransformerWrapper,
-)  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
+)


 def _expand_mask(mask, dtype, tgt_len=None):
@@ -24,9 +24,7 @@ def _expand_mask(mask, dtype, tgt_len=None):
     bsz, src_len = mask.size()
     tgt_len = tgt_len if tgt_len is not None else src_len

-    expanded_mask = (
-        mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-    )
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

     inverted_mask = 1.0 - expanded_mask
@@ -54,7 +52,7 @@ class AbstractEncoder(nn.Module):
 class ClassEmbedder(nn.Module):
-    def __init__(self, embed_dim, n_classes=1000, key='class'):
+    def __init__(self, embed_dim, n_classes=1000, key="class"):
         super().__init__()
         self.key = key
         self.embedding = nn.Embedding(n_classes, embed_dim)
@@ -99,20 +97,14 @@ class TransformerEmbedder(AbstractEncoder):
 class BERTTokenizer(AbstractEncoder):
     """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

-    def __init__(
-        self, device=choose_torch_device(), vq_interface=True, max_length=77
-    ):
+    def __init__(self, device=choose_torch_device(), vq_interface=True, max_length=77):
         super().__init__()
-        from transformers import (
-            BertTokenizerFast,
-        )
+        from transformers import BertTokenizerFast

-        cache = global_cache_dir('hub')
+        cache = global_cache_dir("hub")
         try:
             self.tokenizer = BertTokenizerFast.from_pretrained(
-                'bert-base-uncased',
-                cache_dir=cache,
-                local_files_only=True
+                "bert-base-uncased", cache_dir=cache, local_files_only=True
             )
         except OSError:
             raise SystemExit(
@@ -129,10 +121,10 @@ class BERTTokenizer(AbstractEncoder):
             max_length=self.max_length,
             return_length=True,
             return_overflowing_tokens=False,
-            padding='max_length',
-            return_tensors='pt',
+            padding="max_length",
+            return_tensors="pt",
         )
-        tokens = batch_encoding['input_ids'].to(self.device)
+        tokens = batch_encoding["input_ids"].to(self.device)
         return tokens

     @torch.no_grad()
@@ -150,21 +142,19 @@ class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""

     def __init__(
         self,
         n_embed,
         n_layer,
         vocab_size=30522,
         max_seq_len=77,
         device=choose_torch_device(),
         use_tokenizer=True,
         embedding_dropout=0.0,
     ):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
-            self.tknz_fn = BERTTokenizer(
-                vq_interface=False, max_length=max_seq_len
-            )
+            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
         self.device = device
         self.transformer = TransformerWrapper(
             num_tokens=vocab_size,
@@ -192,7 +182,7 @@ class SpatialRescaler(nn.Module):
     def __init__(
         self,
         n_stages=1,
-        method='bilinear',
+        method="bilinear",
         multiplier=0.5,
         in_channels=3,
         out_channels=None,
@@ -202,25 +192,21 @@ class SpatialRescaler(nn.Module):
         self.n_stages = n_stages
         assert self.n_stages >= 0
         assert method in [
-            'nearest',
-            'linear',
-            'bilinear',
-            'trilinear',
-            'bicubic',
-            'area',
+            "nearest",
+            "linear",
+            "bilinear",
+            "trilinear",
+            "bicubic",
+            "area",
         ]
         self.multiplier = multiplier
-        self.interpolator = partial(
-            torch.nn.functional.interpolate, mode=method
-        )
+        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
         self.remap_output = out_channels is not None
         if self.remap_output:
             print(
-                f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
+                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
             )
-            self.channel_mapper = nn.Conv2d(
-                in_channels, out_channels, 1, bias=bias
-            )
+            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

     def forward(self, x):
         for stage in range(self.n_stages):
@@ -236,27 +222,24 @@ class SpatialRescaler(nn.Module):
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""

     tokenizer: CLIPTokenizer
     transformer: CLIPTextModel

     def __init__(
         self,
-        version:str='openai/clip-vit-large-patch14',
-        max_length:int=77,
-        tokenizer:Optional[CLIPTokenizer]=None,
-        transformer:Optional[CLIPTextModel]=None,
+        version: str = "openai/clip-vit-large-patch14",
+        max_length: int = 77,
+        tokenizer: Optional[CLIPTokenizer] = None,
+        transformer: Optional[CLIPTextModel] = None,
     ):
         super().__init__()
-        cache = global_cache_dir('hub')
+        cache = global_cache_dir("hub")
         self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
-            version,
-            cache_dir=cache,
-            local_files_only=True
+            version, cache_dir=cache, local_files_only=True
         )
         self.transformer = transformer or CLIPTextModel.from_pretrained(
-            version,
-            cache_dir=cache,
-            local_files_only=True
+            version, cache_dir=cache, local_files_only=True
         )
         self.max_length = max_length
         self.freeze()
@@ -268,7 +251,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
             inputs_embeds=None,
             embedding_manager=None,
         ) -> torch.Tensor:
-
             seq_length = (
                 input_ids.shape[-1]
                 if input_ids is not None
@@ -289,8 +271,8 @@ class FrozenCLIPEmbedder(AbstractEncoder):
             return embeddings

-        self.transformer.text_model.embeddings.forward = (
-            embedding_forward.__get__(self.transformer.text_model.embeddings)
+        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
+            self.transformer.text_model.embeddings
         )

         def encoder_forward(
@@ -313,9 +295,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
                 else self.config.output_hidden_states
             )
             return_dict = (
-                return_dict
-                if return_dict is not None
-                else self.config.use_return_dict
+                return_dict if return_dict is not None else self.config.use_return_dict
             )

             encoder_states = () if output_hidden_states else None
@@ -368,13 +348,11 @@ class FrozenCLIPEmbedder(AbstractEncoder):
                 else self.config.output_hidden_states
             )
             return_dict = (
-                return_dict
-                if return_dict is not None
-                else self.config.use_return_dict
+                return_dict if return_dict is not None else self.config.use_return_dict
             )

             if input_ids is None:
-                raise ValueError('You have to specify either input_ids')
+                raise ValueError("You have to specify either input_ids")
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
@@ -395,9 +373,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
             # expand attention_mask
             if attention_mask is not None:
                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                attention_mask = _expand_mask(
-                    attention_mask, hidden_states.dtype
-                )
+                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

             last_hidden_state = self.encoder(
                 inputs_embeds=hidden_states,
@@ -436,9 +412,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
                 embedding_manager=embedding_manager,
             )

-        self.transformer.forward = transformer_forward.__get__(
-            self.transformer
-        )
+        self.transformer.forward = transformer_forward.__get__(self.transformer)

     def freeze(self):
         self.transformer = self.transformer.eval()
@@ -452,10 +426,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
             max_length=self.max_length,
             return_length=True,
             return_overflowing_tokens=False,
-            padding='max_length',
-            return_tensors='pt',
+            padding="max_length",
+            return_tensors="pt",
         )
-        tokens = batch_encoding['input_ids'].to(self.device)
+        tokens = batch_encoding["input_ids"].to(self.device)
         z = self.transformer(input_ids=tokens, **kwargs)
         return z
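The tokenize-then-encode flow in the hunk above is the usual CLIP text-encoder pattern. A rough usage sketch built directly on the stock `transformers` classes (not the patched embedder from this diff; only the model name is taken from it):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

version = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(version)
text_encoder = CLIPTextModel.from_pretrained(version).eval()

batch = tokenizer(
    ["a photograph of an astronaut riding a horse"],
    truncation=True,
    max_length=77,
    padding="max_length",
    return_tensors="pt",
)
with torch.no_grad():
    z = text_encoder(input_ids=batch["input_ids"]).last_hidden_state
print(z.shape)  # torch.Size([1, 77, 768])
```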
@@ -471,25 +445,25 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def device(self, device):
         self.transformer.to(device=device)


 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
     fragment_weights_key = "fragment_weights"
     return_tokens_key = "return_tokens"

-    def set_textual_inversion_manager(self, manager): #TextualInversionManager):
+    def set_textual_inversion_manager(self, manager):  # TextualInversionManager):
         # TODO all of the weighting and expanding stuff needs be moved out of this class
         self.textual_inversion_manager = manager

     def forward(self, text: list, **kwargs):
         # TODO all of the weighting and expanding stuff needs be moved out of this class
-        '''
+        """
        :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
                     weights shall be applied.
        :param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
                       for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
        :return: A tensor of shape (B, 77, 768) containing weighted embeddings
-        '''
+        """
         if self.fragment_weights_key not in kwargs:
             # fallback to base class implementation
             return super().forward(text, **kwargs)
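The docstring above describes the per-fragment weighting interface. A hedged sketch of how a caller might drive it; the `embedder` instance and argument shapes are inferred from the docstring, not from any public API guarantee:

```python
# Hypothetical call shape for WeightedFrozenCLIPEmbedder.forward:
# each prompt is a list of fragments plus a parallel list of weights.
text = [["a mountain man", "wearing a hat"]]   # batch of one prompt, two fragments
fragment_weights = [[1.0, 0.5]]                # de-emphasise the hat

# z = embedder(text, fragment_weights=fragment_weights)                        # -> (1, 77, 768)
# z, tokens = embedder(text, fragment_weights=fragment_weights, return_tokens=True)
```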
@@ -507,7 +481,6 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         batch_z = None
         batch_tokens = None
         for fragments, weights in zip(text, fragment_weights):
-
             # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
             # applying a multiplier to the CFG scale on a per-token basis).
             # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
@@ -520,7 +493,9 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
             # handle weights >=1
             tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
-            base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+            base_embedding = self.build_weighted_embedding_tensor(
+                tokens, per_token_weights, **kwargs
+            )

             # this is our starting point
             embeddings = base_embedding.unsqueeze(0)
@@ -536,12 +511,18 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
             # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
             for index, fragment_weight in enumerate(weights):
                 if fragment_weight < 1:
-                    fragments_without_this = fragments[:index] + fragments[index+1:]
-                    weights_without_this = weights[:index] + weights[index+1:]
-                    tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
-                    embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+                    fragments_without_this = fragments[:index] + fragments[index + 1 :]
+                    weights_without_this = weights[:index] + weights[index + 1 :]
+                    tokens, per_token_weights = self.get_tokens_and_weights(
+                        fragments_without_this, weights_without_this
+                    )
+                    embedding_without_this = self.build_weighted_embedding_tensor(
+                        tokens, per_token_weights, **kwargs
+                    )

-                    embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
+                    embeddings = torch.cat(
+                        (embeddings, embedding_without_this.unsqueeze(0)), dim=1
+                    )
                     # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
                     # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
                     # therefore:
@@ -554,29 +535,43 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
                     # inf at PI/2
                     # -> tan((1-weight)*PI/2) should give us ideal lerp weights
                     epsilon = 1e-9
-                    fragment_weight = max(epsilon, fragment_weight) # inf is bad
-                    embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
+                    fragment_weight = max(epsilon, fragment_weight)  # inf is bad
+                    embedding_lerp_weight = math.tan(
+                        (1.0 - fragment_weight) * math.pi / 2
+                    )
                     # todo handle negative weight?

                     per_embedding_weights.append(embedding_lerp_weight)

-            lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
+            lerped_embeddings = self.apply_embedding_weights(
+                embeddings, per_embedding_weights, normalize=True
+            ).squeeze(0)

-            #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
+            # print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")

             # append to batch
-            batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
-            batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
+            batch_z = (
+                lerped_embeddings.unsqueeze(0)
+                if batch_z is None
+                else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
+            )
+            batch_tokens = (
+                tokens.unsqueeze(0)
+                if batch_tokens is None
+                else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
+            )

         # should have shape (B, 77, 768)
-        #print(f"assembled all tokens into tensor of shape {batch_z.shape}")
+        # print(f"assembled all tokens into tensor of shape {batch_z.shape}")

         if should_return_tokens:
             return batch_z, batch_tokens
         else:
             return batch_z

-    def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
+    def get_token_ids(
+        self, fragments: list[str], include_start_and_end_markers: bool = True
+    ) -> list[list[int]]:
         """
         Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
         `[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
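The `tan((1 - weight) * pi / 2)` trick in the hunk above maps a fragment weight in (0, 1] to a lerp weight in [0, inf): weight 1 contributes nothing of the "without this fragment" embedding, and weights approaching 0 let it overwhelm the base embedding. A tiny numeric sketch of just that mapping:

```python
import math

def lerp_weight(fragment_weight: float, epsilon: float = 1e-9) -> float:
    fragment_weight = max(epsilon, fragment_weight)  # avoid tan(pi/2) -> inf
    return math.tan((1.0 - fragment_weight) * math.pi / 2)

print(lerp_weight(1.0))  # 0.0   -> base embedding only
print(lerp_weight(0.5))  # 1.0   -> half-way between "with" and "without"
print(lerp_weight(0.1))  # ~6.3  -> "without this fragment" dominates
```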
@@ -594,58 +589,81 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
             truncation=True,
             max_length=self.max_length,
             return_overflowing_tokens=False,
-            padding='do_not_pad',
-            return_tensors=None, # just give me lists of ints
-        )['input_ids']
+            padding="do_not_pad",
+            return_tensors=None,  # just give me lists of ints
+        )["input_ids"]

         result = []
         for token_ids in token_ids_list:
             # trim eos/bos
             token_ids = token_ids[1:-1]
             # pad for textual inversions with vector length >1
-            token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
+            token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(
+                token_ids
+            )
             # restrict length to max_length-2 (leaving room for bos/eos)
-            token_ids = token_ids[0:self.max_length - 2]
+            token_ids = token_ids[0 : self.max_length - 2]
             # add back eos/bos if requested
             if include_start_and_end_markers:
-                token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
+                token_ids = (
+                    [self.tokenizer.bos_token_id]
+                    + token_ids
+                    + [self.tokenizer.eos_token_id]
+                )

             result.append(token_ids)

         return result

     @classmethod
-    def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
-        per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
+    def apply_embedding_weights(
+        self,
+        embeddings: torch.Tensor,
+        per_embedding_weights: list[float],
+        normalize: bool,
+    ) -> torch.Tensor:
+        per_embedding_weights = torch.tensor(
+            per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device
+        )
         if normalize:
-            per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
-        reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
-        #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
+            per_embedding_weights = per_embedding_weights / torch.sum(
+                per_embedding_weights
+            )
+        reshaped_weights = per_embedding_weights.reshape(
+            per_embedding_weights.shape
+            + (
+                1,
+                1,
+            )
+        )
+        # reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
         return torch.sum(embeddings * reshaped_weights, dim=1)
         # lerped embeddings has shape (77, 768)

-    def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
-        '''
+    def get_tokens_and_weights(
+        self, fragments: list[str], weights: list[float]
+    ) -> (torch.Tensor, torch.Tensor):
+        """
        :param fragments:
        :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
        :return:
-        '''
+        """
         # empty is meaningful
         if len(fragments) == 0 and len(weights) == 0:
-            fragments = ['']
+            fragments = [""]
             weights = [1]
-        per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
+        per_fragment_token_ids = self.get_token_ids(
+            fragments, include_start_and_end_markers=False
+        )
         all_token_ids = []
         per_token_weights = []
-        #print("all fragments:", fragments, weights)
+        # print("all fragments:", fragments, weights)
         for index, fragment in enumerate(per_fragment_token_ids):
             weight = float(weights[index])
-            #print("processing fragment", fragment, weight)
+            # print("processing fragment", fragment, weight)
             this_fragment_token_ids = per_fragment_token_ids[index]
-            #print("fragment", fragment, "processed to", this_fragment_token_ids)
+            # print("fragment", fragment, "processed to", this_fragment_token_ids)
             # append
             all_token_ids += this_fragment_token_ids
             # fill out weights tensor with one float per token
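`apply_embedding_weights` in the hunk above is an (optionally normalized) weighted average over stacked embeddings. A minimal standalone sketch of the same arithmetic, using toy tensors rather than real CLIP outputs:

```python
import torch

embeddings = torch.stack(
    [torch.ones(77, 768), torch.zeros(77, 768)]
).unsqueeze(0)                                  # (1, 2, 77, 768): base + "without fragment"
per_embedding_weights = torch.tensor([1.0, 3.0])
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)  # normalize
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1))
lerped = torch.sum(embeddings * reshaped_weights, dim=1)   # (1, 77, 768), every value 0.25
print(lerped.shape, lerped[0, 0, 0].item())
```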
@@ -654,60 +672,85 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
         # leave room for bos/eos
         max_token_count_without_bos_eos_markers = self.max_length - 2
         if len(all_token_ids) > max_token_count_without_bos_eos_markers:
-            excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
+            excess_token_count = (
+                len(all_token_ids) - max_token_count_without_bos_eos_markers
+            )
             # TODO build nice description string of how the truncation was applied
             # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
             # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
-            print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
+            print(
+                f">> Prompt is {excess_token_count} token(s) too long and has been truncated"
+            )
             all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
-            per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
+            per_token_weights = per_token_weights[
+                0:max_token_count_without_bos_eos_markers
+            ]

         # pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
         # (77 = self.max_length)
-        all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
+        all_token_ids = (
+            [self.tokenizer.bos_token_id]
+            + all_token_ids
+            + [self.tokenizer.eos_token_id]
+        )
         per_token_weights = [1.0] + per_token_weights + [1.0]

         pad_length = self.max_length - len(all_token_ids)
         all_token_ids += [self.tokenizer.pad_token_id] * pad_length
         per_token_weights += [1.0] * pad_length

-        all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
-        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
-        #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
+        all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(
+            self.device
+        )
+        per_token_weights_tensor = torch.tensor(
+            per_token_weights, dtype=torch.float32
+        ).to(self.device)
+        # print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
         return all_token_ids_tensor, per_token_weights_tensor

-    def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
-        '''
+    def build_weighted_embedding_tensor(
+        self,
+        token_ids: torch.Tensor,
+        per_token_weights: torch.Tensor,
+        weight_delta_from_empty=True,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
         Build a tensor representing the passed-in tokens, each of which has a weight.
         :param token_ids: A tensor of shape (77) containing token ids (integers)
         :param per_token_weights: A tensor of shape (77) containing weights (floats)
         :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
         :param kwargs: passed on to self.transformer()
         :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
-        '''
-        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
+        """
+        # print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")

         if token_ids.shape != torch.Size([self.max_length]):
-            raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
+            raise ValueError(
+                f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]"
+            )

         z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs)
-        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
+        batch_weights_expanded = per_token_weights.reshape(
+            per_token_weights.shape + (1,)
+        ).expand(z.shape)

         if weight_delta_from_empty:
-            empty_tokens = self.tokenizer([''] * z.shape[0],
-                                          truncation=True,
-                                          max_length=self.max_length,
-                                          padding='max_length',
-                                          return_tensors='pt'
-                                          )['input_ids'].to(self.device)
+            empty_tokens = self.tokenizer(
+                [""] * z.shape[0],
+                truncation=True,
+                max_length=self.max_length,
+                padding="max_length",
+                return_tensors="pt",
+            )["input_ids"].to(self.device)
             empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
             z_delta_from_empty = z - empty_z
             weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)

-            #weighted_z_delta_from_empty = (weighted_z-empty_z)
-            #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
+            # weighted_z_delta_from_empty = (weighted_z-empty_z)
+            # print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )

-            #print("using empty-delta method, first 5 rows:")
-            #print(weighted_z[:5])
+            # print("using empty-delta method, first 5 rows:")
+            # print(weighted_z[:5])
             return weighted_z
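The `weight_delta_from_empty` branch weights each token's offset from the embedding of an empty prompt rather than the raw feature vector: weighted_z = empty_z + (z - empty_z) * w. A compact sketch of that arithmetic with toy tensors standing in for the transformer outputs:

```python
import torch

z = torch.randn(1, 77, 768)         # embedding of the real prompt
empty_z = torch.randn(1, 77, 768)   # embedding of ""
per_token_weights = torch.ones(77)
per_token_weights[5] = 0.5          # halve the influence of token 5

w = per_token_weights.reshape(1, 77, 1)           # broadcasts over the feature dim
weighted_z = empty_z + (z - empty_z) * w
# weight 1.0 leaves a token untouched, weight 0.0 collapses it to the "empty" embedding
assert torch.allclose(weighted_z[0, 0], z[0, 0])
assert torch.allclose(weighted_z[0, 5], 0.5 * (z[0, 5] + empty_z[0, 5]))
```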
@@ -716,7 +759,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
             z *= batch_weights_expanded
             after_weighting_mean = z.mean()
             # correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
-            mean_correction_factor = original_mean/after_weighting_mean
+            mean_correction_factor = original_mean / after_weighting_mean
             z *= mean_correction_factor
             return z
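The non-delta branch above scales the feature vectors directly and then rescales so the overall mean matches the pre-weighting mean (the A1111-style correction the comment refers to); a short sketch with toy tensors:

```python
import torch

z = torch.randn(1, 77, 768)
weights = torch.full((1, 77, 1), 0.5)

original_mean = z.mean()
z_weighted = z * weights
mean_correction_factor = original_mean / z_weighted.mean()
z_weighted = z_weighted * mean_correction_factor
assert torch.allclose(z_weighted.mean(), original_mean, atol=1e-6)
```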
@@ -728,7 +771,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
     def __init__(
         self,
-        version='ViT-L/14',
+        version="ViT-L/14",
         device=choose_torch_device(),
         max_length=77,
         n_repeat=1,
@@ -757,7 +800,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
         z = self(text)
         if z.ndim == 2:
             z = z[:, None, :]
-        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+        z = repeat(z, "b 1 d -> b k d", k=self.n_repeat)
         return z
@@ -779,12 +822,12 @@ class FrozenClipImageEmbedder(nn.Module):
         self.antialias = antialias

         self.register_buffer(
-            'mean',
+            "mean",
             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
             persistent=False,
         )
         self.register_buffer(
-            'std',
+            "std",
             torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
             persistent=False,
         )
@@ -794,7 +837,7 @@ class FrozenClipImageEmbedder(nn.Module):
         x = kornia.geometry.resize(
             x,
             (224, 224),
-            interpolation='bicubic',
+            interpolation="bicubic",
             align_corners=True,
             antialias=self.antialias,
         )
@@ -808,8 +851,8 @@ class FrozenClipImageEmbedder(nn.Module):
         return self.model.encode_image(self.preprocess(x))


-if __name__ == '__main__':
-    from ldm.util import count_params
+if __name__ == "__main__":
+    from ...util.util import count_params

     model = FrozenCLIPEmbedder()
     count_params(model, verbose=True)

Some files were not shown because too many files have changed in this diff.