Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00.
Final phase of source tree restructure (#2833)
All Python code has been moved under `invokeai`. All vestiges of `ldm` and `ldm.invoke` are now gone. ***You will need to run `pip install -e .` before the code will work again!*** Everything seems to be functional, but extensive testing is advised. A guide to where the files have gone is forthcoming.
This commit is contained in: commit b3dccfaeb6
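For downstream code, the practical upshot is the import-path migration visible throughout the hunks below. A minimal sketch, with the right-hand paths taken directly from this diff (anything not shown in the hunks should be checked against the new tree):

# Old (pre-commit) imports and their post-restructure equivalents,
# as evidenced by the workflow and installer hunks in this commit:
#   from ldm.invoke.globals import Globals  ->  invokeai.backend.globals
#   from ldm.invoke import __version__      ->  invokeai.version
from invokeai.backend.globals import Globals
from invokeai.version import __version__

print(f"InvokeAI {__version__}, runtime root: {Globals.root}")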
.github/CODEOWNERS (vendored, 59 lines changed)

@@ -1,51 +1,34 @@
 # continuous integration
-/.github/workflows/ @mauwii @lstein @blessedcoolant
+/.github/workflows/ @mauwii @lstein
 
 # documentation
-/docs/ @lstein @mauwii @tildebyte @blessedcoolant
-mkdocs.yml @lstein @mauwii @blessedcoolant
+/docs/ @lstein @mauwii @tildebyte
+/mkdocs.yml @lstein @mauwii
 
+# nodes
+/invokeai/app/ @Kyle0654 @blessedcoolant
+
 # installation and configuration
-/pyproject.toml @mauwii @lstein @ebr @blessedcoolant
-/docker/ @mauwii @lstein @blessedcoolant
-/scripts/ @ebr @lstein @blessedcoolant
-/installer/ @ebr @lstein @tildebyte @blessedcoolant
-ldm/invoke/config @lstein @ebr @blessedcoolant
-invokeai/assets @lstein @ebr @blessedcoolant
-invokeai/configs @lstein @ebr @blessedcoolant
-/ldm/invoke/_version.py @lstein @blessedcoolant
+/pyproject.toml @mauwii @lstein @blessedcoolant
+/docker/ @mauwii @lstein
+/scripts/ @ebr @lstein
+/installer/ @lstein @ebr
+/invokeai/assets @lstein @ebr
+/invokeai/configs @lstein
+/invokeai/version @lstein @blessedcoolant
 
 # web ui
 /invokeai/frontend @blessedcoolant @psychedelicious @lstein
 /invokeai/backend @blessedcoolant @psychedelicious @lstein
 
-# generation and model management
-/ldm/*.py @lstein @blessedcoolant
-/ldm/generate.py @lstein @keturn @blessedcoolant
-/ldm/invoke/args.py @lstein @blessedcoolant
-/ldm/invoke/ckpt* @lstein @blessedcoolant
-/ldm/invoke/ckpt_generator @lstein @blessedcoolant
-/ldm/invoke/CLI.py @lstein @blessedcoolant
-/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
-/ldm/invoke/generator @keturn @damian0815 @blessedcoolant
-/ldm/invoke/globals.py @lstein @blessedcoolant
-/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
-/ldm/invoke/model_manager.py @lstein @blessedcoolant
-/ldm/invoke/txt2mask.py @lstein @blessedcoolant
-/ldm/invoke/patchmatch.py @Kyle0654 @blessedcoolant @lstein
-/ldm/invoke/restoration @lstein @blessedcoolant
+# generation, model management, postprocessing
+/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
 
-# attention, textual inversion, model configuration
-/ldm/models @damian0815 @keturn @lstein @blessedcoolant
-/ldm/modules @damian0815 @keturn @lstein @blessedcoolant
+# front ends
+/invokeai/frontend/CLI @lstein
+/invokeai/frontend/install @lstein @ebr @mauwii
+/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
+/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
+/invokeai/frontend/web @psychedelicious @blessedcoolant
 
-# Nodes
-apps/ @Kyle0654 @lstein @blessedcoolant
-
-# legacy REST API
-# is CapableWeb still engaged?
-/ldm/invoke/pngwriter.py @CapableWeb @lstein @blessedcoolant
-/ldm/invoke/server_legacy.py @CapableWeb @lstein @blessedcoolant
-/scripts/legacy_api.py @CapableWeb @lstein @blessedcoolant
-/tests/legacy_tests.sh @CapableWeb @lstein @blessedcoolant
.github/workflows/build-container.yml (vendored, 2 lines changed)

@@ -9,7 +9,7 @@ on:
       - 'dev/docker/*'
     paths:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
       - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'
.github/workflows/lint-frontend.yml (vendored, 6 lines changed)

@@ -3,14 +3,14 @@ name: Lint frontend
 on:
   pull_request:
     paths:
-      - 'invokeai/frontend/**'
+      - 'invokeai/frontend/web/**'
   push:
     paths:
-      - 'invokeai/frontend/**'
+      - 'invokeai/frontend/web/**'
 
 defaults:
   run:
-    working-directory: invokeai/frontend
+    working-directory: invokeai/frontend/web
 
 jobs:
   lint-frontend:
.github/workflows/pypi-release.yml (vendored, 2 lines changed)

@@ -3,7 +3,7 @@ name: PyPI Release
 on:
   push:
     paths:
-      - 'ldm/invoke/_version.py'
+      - 'invokeai/version/invokeai_version.py'
   workflow_dispatch:
 
 jobs:
.github/workflows/test-invoke-pip-skip.yml (vendored, 2 lines changed)

@@ -3,7 +3,7 @@ on:
   pull_request:
     paths-ignore:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
       - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'
.github/workflows/test-invoke-pip.yml (vendored, 6 lines changed)

@@ -5,14 +5,14 @@ on:
       - 'main'
     paths:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
       - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'
   pull_request:
     paths:
       - 'pyproject.toml'
-      - 'ldm/**'
+      - 'invokeai/**'
       - 'invokeai/backend/**'
       - 'invokeai/configs/**'
       - 'invokeai/frontend/dist/**'

@@ -112,7 +112,7 @@ jobs:
       - name: set INVOKEAI_OUTDIR
         run: >
           python -c
-          "import os;from ldm.invoke.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
+          "import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
           >> ${{ matrix.github-env }}
 
       - name: run invokeai-configure
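Expanded for readability, the one-liner that workflow now runs is equivalent to the following sketch; it derives the CI output directory from the relocated globals module:

import os

from invokeai.backend.globals import Globals  # moved from ldm.invoke.globals

OUTDIR = os.path.join(Globals.root, "outputs")
print(f"INVOKEAI_OUTDIR={OUTDIR}")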
.gitignore (vendored, 10 lines changed)

@@ -198,7 +198,7 @@ checkpoints
 .DS_Store
 
 # Let the frontend manage its own gitignore
-!invokeai/frontend/*
+!invokeai/frontend/web/*
 
 # Scratch folder
 .scratch/

@@ -213,11 +213,6 @@ gfpgan/
 # config file (will be created by installer)
 configs/models.yaml
 
-# weights (will be created by installer)
-models/ldm/stable-diffusion-v1/*.ckpt
-models/clipseg
-models/gfpgan
-
 # ignore initfile
 .invokeai
 
@@ -232,6 +227,3 @@ installer/install.bat
 installer/install.sh
 installer/update.bat
 installer/update.sh
-
-# no longer stored in source directory
-models
@@ -11,10 +11,10 @@ if [[ -v "VIRTUAL_ENV" ]]; then
     exit -1
 fi
 
-VERSION=$(cd ..; python -c "from ldm.invoke import __version__ as version; print(version)")
+VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
 PATCH=""
 VERSION="v${VERSION}${PATCH}"
-LATEST_TAG="v2.3-latest"
+LATEST_TAG="v3.0-latest"
 
 echo Building installer for version $VERSION
 echo "Be certain that you're in the 'installer' directory before continuing."
@@ -291,7 +291,7 @@ class InvokeAiInstance:
             src = Path(__file__).parents[1].expanduser().resolve()
             # if the above directory contains one of these files, we'll do a source install
             next(src.glob("pyproject.toml"))
-            next(src.glob("ldm"))
+            next(src.glob("invokeai"))
         except StopIteration:
             print("Unable to find a wheel or perform a source install. Giving up.")

@@ -342,14 +342,14 @@ class InvokeAiInstance:
 
         introduction()
 
-        from ldm.invoke.config import invokeai_configure
+        from invokeai.frontend.install import invokeai_configure
 
         # NOTE: currently the config script does its own arg parsing! this means the command-line switches
         # from the installer will also automatically propagate down to the config script.
         # this may change in the future with config refactoring!
         succeeded = False
         try:
-            invokeai_configure.main()
+            invokeai_configure()
             succeeded = True
         except requests.exceptions.ConnectionError as e:
             print(f'\nA network error was encountered during configuration and download: {str(e)}')
@@ -1,3 +1,11 @@
-After version 2.3 is released, the ldm/invoke modules will be migrated to this location
-so that we have a proper invokeai distribution. Currently it is only being used for
-data files.
+Organization of the source tree:
+
+app      -- Home of nodes invocations and services
+assets   -- Images and other data files used by InvokeAI
+backend  -- Non-user facing libraries, including the rendering core.
+configs  -- Configuration files used at install and run times
+frontend -- User-facing scripts, including the CLI and the WebUI
+version  -- Current InvokeAI version string, stored in version/invokeai_version.py
@@ -1,33 +1,31 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
-from argparse import Namespace
 import os
-
-from ..services.processor import DefaultInvocationProcessor
-
-from ..services.graph import GraphExecutionState
-from ..services.sqlite import SqliteItemStorage
+from argparse import Namespace
 
 from ...globals import Globals
-
 from ..services.generate_initializer import get_generate
+from ..services.graph import GraphExecutionState
 from ..services.image_storage import DiskImageStorage
 from ..services.invocation_queue import MemoryInvocationQueue
 from ..services.invocation_services import InvocationServices
 from ..services.invoker import Invoker
-from ..services.generate_initializer import get_generate
+from ..services.processor import DefaultInvocationProcessor
+from ..services.sqlite import SqliteItemStorage
 from .events import FastAPIEventService
 
 
 # TODO: is there a better way to achieve this?
-def check_internet()->bool:
-    '''
+def check_internet() -> bool:
+    """
     Return true if the internet is reachable.
     It does this by pinging huggingface.co.
-    '''
+    """
     import urllib.request
-    host = 'http://huggingface.co'
+
+    host = "http://huggingface.co"
     try:
-        urllib.request.urlopen(host,timeout=1)
+        urllib.request.urlopen(host, timeout=1)
         return True
     except:
         return False

@@ -35,14 +33,11 @@ def check_internet()->bool:
 
 class ApiDependencies:
     """Contains and initializes all dependencies for the API"""
 
     invoker: Invoker = None
 
     @staticmethod
-    def initialize(
-        args,
-        config,
-        event_handler_id: int
-    ):
+    def initialize(args, config, event_handler_id: int):
         Globals.try_patchmatch = args.patchmatch
         Globals.always_use_cpu = args.always_use_cpu
         Globals.internet_available = args.internet_available and check_internet()

@@ -50,26 +45,30 @@ class ApiDependencies:
         Globals.ckpt_convert = args.ckpt_convert
 
         # TODO: Use a logger
-        print(f'>> Internet connectivity is {Globals.internet_available}')
+        print(f">> Internet connectivity is {Globals.internet_available}")
 
         generate = get_generate(args, config)
 
         events = FastAPIEventService(event_handler_id)
 
-        output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs'))
+        output_folder = os.path.abspath(
+            os.path.join(os.path.dirname(__file__), "../../../../outputs")
+        )
 
         images = DiskImageStorage(output_folder)
 
         # TODO: build a file/path manager?
-        db_location = os.path.join(output_folder, 'invokeai.db')
+        db_location = os.path.join(output_folder, "invokeai.db")
 
         services = InvocationServices(
-            generate = generate,
-            events = events,
-            images = images,
-            queue = MemoryInvocationQueue(),
-            graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
-            processor = DefaultInvocationProcessor()
+            generate=generate,
+            events=events,
+            images=images,
+            queue=MemoryInvocationQueue(),
+            graph_execution_manager=SqliteItemStorage[GraphExecutionState](
+                filename=db_location, table_name="graph_executions"
+            ),
+            processor=DefaultInvocationProcessor(),
         )
 
         ApiDependencies.invoker = Invoker(services)
@@ -1,11 +1,14 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 import asyncio
+import threading
 from queue import Empty, Queue
 from typing import Any
+
 from fastapi_events.dispatcher import dispatch
+
 from ..services.events import EventServiceBase
-import threading
 
 
 class FastAPIEventService(EventServiceBase):
     event_handler_id: int

@@ -16,39 +19,34 @@ class FastAPIEventService(EventServiceBase):
         self.event_handler_id = event_handler_id
         self.__queue = Queue()
         self.__stop_event = threading.Event()
-        asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event))
+        asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
 
         super().__init__()
 
-
     def stop(self, *args, **kwargs):
         self.__stop_event.set()
         self.__queue.put(None)
 
-
     def dispatch(self, event_name: str, payload: Any) -> None:
-        self.__queue.put(dict(
-            event_name = event_name,
-            payload = payload
-        ))
-
+        self.__queue.put(dict(event_name=event_name, payload=payload))
 
     async def __dispatch_from_queue(self, stop_event: threading.Event):
         """Get events on from the queue and dispatch them, from the correct thread"""
         while not stop_event.is_set():
             try:
-                event = self.__queue.get(block = False)
-                if not event: # Probably stopping
+                event = self.__queue.get(block=False)
+                if not event:  # Probably stopping
                     continue
 
                 dispatch(
-                    event.get('event_name'),
-                    payload = event.get('payload'),
-                    middleware_id = self.event_handler_id)
+                    event.get("event_name"),
+                    payload=event.get("payload"),
+                    middleware_id=self.event_handler_id,
+                )
 
             except Empty:
                 await asyncio.sleep(0.001)
                 pass
 
             except asyncio.CancelledError as e:
-                raise e # Raise a proper error
+                raise e  # Raise a proper error
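The service above bridges worker threads to the event loop with a polling queue. A standalone sketch of the same pattern, using illustrative names only (no InvokeAI imports), for readers who want to see it in isolation:

import asyncio
import threading
from queue import Empty, Queue

async def drain(queue: Queue, stop: threading.Event) -> None:
    # Poll the thread-safe queue from the event loop, as
    # FastAPIEventService.__dispatch_from_queue does above.
    while not stop.is_set():
        try:
            event = queue.get(block=False)
            if event is None:  # sentinel pushed by stop(); re-check the flag
                continue
            print("dispatch:", event)
        except Empty:
            await asyncio.sleep(0.001)  # yield instead of blocking the loop

async def main() -> None:
    queue: Queue = Queue()
    stop = threading.Event()
    task = asyncio.create_task(drain(queue, stop))
    # Any thread can enqueue; only the loop-side coroutine dispatches.
    threading.Thread(target=queue.put, args=({"event_name": "demo"},)).start()
    await asyncio.sleep(0.05)
    stop.set()
    queue.put(None)  # wake the drainer so it observes the stop flag
    await task

asyncio.run(main())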
invokeai/app/api/routers/images.py (new file, 56 lines)

@@ -0,0 +1,56 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from datetime import datetime, timezone

from fastapi import Path, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image

from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies

images_router = APIRouter(prefix="/v1/images", tags=["images"])


@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
    image_type: ImageType = Path(description="The type of image to get"),
    image_name: str = Path(description="The name of the image to get"),
):
    """Gets a result"""
    # TODO: This is not really secure at all. At least make sure only output results are served
    filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
    return FileResponse(filename)


@images_router.post(
    "/uploads/",
    operation_id="upload_image",
    responses={
        201: {"description": "The image was uploaded successfully"},
        404: {"description": "Session not found"},
    },
)
async def upload_image(file: UploadFile, request: Request):
    if not file.content_type.startswith("image"):
        return Response(status_code=415)

    contents = await file.read()
    try:
        im = Image.open(contents)
    except:
        # Error opening the image
        return Response(status_code=415)

    filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
    ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)

    return Response(
        status_code=201,
        headers={
            "Location": request.url_for(
                "get_image", image_type=ImageType.UPLOAD, image_name=filename
            )
        },
    )
invokeai/app/api/routers/sessions.py (new file, 271 lines)

@@ -0,0 +1,271 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Annotated, List, Optional, Union

from fastapi import Body, Path, Query
from fastapi.responses import Response
from fastapi.routing import APIRouter
from pydantic.fields import Field

from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
    EdgeConnection,
    Graph,
    GraphExecutionState,
    NodeAlreadyExecutedError,
)
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies

session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])


@session_router.post(
    "/",
    operation_id="create_session",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid json"},
    },
)
async def create_session(
    graph: Optional[Graph] = Body(
        default=None, description="The graph to initialize the session with"
    )
) -> GraphExecutionState:
    """Creates a new session, optionally initializing it with an invocation graph"""
    session = ApiDependencies.invoker.create_execution_state(graph)
    return session


@session_router.get(
    "/",
    operation_id="list_sessions",
    responses={200: {"model": PaginatedResults[GraphExecutionState]}},
)
async def list_sessions(
    page: int = Query(default=0, description="The page of results to get"),
    per_page: int = Query(default=10, description="The number of results per page"),
    query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
    """Gets a list of sessions, optionally searching"""
    if filter == "":
        result = ApiDependencies.invoker.services.graph_execution_manager.list(
            page, per_page
        )
    else:
        result = ApiDependencies.invoker.services.graph_execution_manager.search(
            query, page, per_page
        )
    return result


@session_router.get(
    "/{session_id}",
    operation_id="get_session",
    responses={
        200: {"model": GraphExecutionState},
        404: {"description": "Session not found"},
    },
)
async def get_session(
    session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
    """Gets a session"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)
    else:
        return session


@session_router.post(
    "/{session_id}/nodes",
    operation_id="add_node",
    responses={
        200: {"model": str},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_node(
    session_id: str = Path(description="The id of the session"),
    node: Annotated[
        Union[BaseInvocation.get_invocations()], Field(discriminator="type")
    ] = Body(description="The node to add"),
) -> str:
    """Adds a node to the graph"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.add_node(node)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session.id
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.put(
    "/{session_id}/nodes/{node_path}",
    operation_id="update_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def update_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node in the graph"),
    node: Annotated[
        Union[BaseInvocation.get_invocations()], Field(discriminator="type")
    ] = Body(description="The new node"),
) -> GraphExecutionState:
    """Updates a node in the graph and removes all linked edges"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.update_node(node_path, node)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.delete(
    "/{session_id}/nodes/{node_path}",
    operation_id="delete_node",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node to delete"),
) -> GraphExecutionState:
    """Deletes a node in the graph and removes all linked edges"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.delete_node(node_path)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.post(
    "/{session_id}/edges",
    operation_id="add_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def add_edge(
    session_id: str = Path(description="The id of the session"),
    edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
) -> GraphExecutionState:
    """Adds an edge to the graph"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        session.add_edge(edge)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete(
    "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
    operation_id="delete_edge",
    responses={
        200: {"model": GraphExecutionState},
        400: {"description": "Invalid node or link"},
        404: {"description": "Session not found"},
    },
)
async def delete_edge(
    session_id: str = Path(description="The id of the session"),
    from_node_id: str = Path(description="The id of the node the edge is coming from"),
    from_field: str = Path(description="The field of the node the edge is coming from"),
    to_node_id: str = Path(description="The id of the node the edge is going to"),
    to_field: str = Path(description="The field of the node the edge is going to"),
) -> GraphExecutionState:
    """Deletes an edge from the graph"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    try:
        edge = (
            EdgeConnection(node_id=from_node_id, field=from_field),
            EdgeConnection(node_id=to_node_id, field=to_field),
        )
        session.delete_edge(edge)
        ApiDependencies.invoker.services.graph_execution_manager.set(
            session
        )  # TODO: can this be done automatically, or add node through an API?
        return session
    except NodeAlreadyExecutedError:
        return Response(status_code=400)
    except IndexError:
        return Response(status_code=400)


@session_router.put(
    "/{session_id}/invoke",
    operation_id="invoke_session",
    responses={
        200: {"model": None},
        202: {"description": "The invocation is queued"},
        400: {"description": "The session has no invocations ready to invoke"},
        404: {"description": "Session not found"},
    },
)
async def invoke_session(
    session_id: str = Path(description="The id of the session to invoke"),
    all: bool = Query(
        default=False, description="Whether or not to invoke all remaining invocations"
    ),
) -> None:
    """Invokes a session"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
    if session is None:
        return Response(status_code=404)

    if session.is_complete():
        return Response(status_code=400)

    ApiDependencies.invoker.invoke(session, invoke_all=all)
    return Response(status_code=202)
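A hypothetical end-to-end client for these routes: create a session, add a txt2img node (its type and fields come from generate.py later in this diff), and queue it for execution. Host, port, and prefix assume the api_app defaults below:

import requests

base = "http://localhost:9090/api/v1/sessions"

session = requests.post(f"{base}/").json()  # create_session
session_id = session["id"]

node = {"type": "txt2img", "id": "1", "prompt": "a house by the sea"}
requests.post(f"{base}/{session_id}/nodes", json=node)  # add_node

# invoke_session returns 202 once the invocation is queued
resp = requests.put(f"{base}/{session_id}/invoke", params={"all": "true"})
print(resp.status_code)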
@@ -1,36 +1,38 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 from fastapi import FastAPI
-from fastapi_socketio import SocketManager
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event
+from fastapi_socketio import SocketManager
 
 from ..services.events import EventServiceBase
 
 
 class SocketIO:
     __sio: SocketManager
 
     def __init__(self, app: FastAPI):
-        self.__sio = SocketManager(app = app)
-        self.__sio.on('subscribe', handler=self._handle_sub)
-        self.__sio.on('unsubscribe', handler=self._handle_unsub)
+        self.__sio = SocketManager(app=app)
+        self.__sio.on("subscribe", handler=self._handle_sub)
+        self.__sio.on("unsubscribe", handler=self._handle_unsub)
 
         local_handler.register(
-            event_name = EventServiceBase.session_event,
-            _func=self._handle_session_event
+            event_name=EventServiceBase.session_event, _func=self._handle_session_event
         )
 
     async def _handle_session_event(self, event: Event):
         await self.__sio.emit(
-            event = event[1]['event'],
-            data = event[1]['data'],
-            room = event[1]['data']['graph_execution_state_id']
+            event=event[1]["event"],
+            data=event[1]["data"],
+            room=event[1]["data"]["graph_execution_state_id"],
        )
 
     async def _handle_sub(self, sid, data, *args, **kwargs):
-        if 'session' in data:
-            self.__sio.enter_room(sid, data['session'])
+        if "session" in data:
+            self.__sio.enter_room(sid, data["session"])
 
     # @app.sio.on('unsubscribe')
 
     async def _handle_unsub(self, sid, data, *args, **kwargs):
-        if 'session' in data:
-            self.__sio.leave_room(sid, data['session'])
+        if "session" in data:
+            self.__sio.leave_room(sid, data["session"])
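A hedged sketch of a python-socketio client driving the subscribe/unsubscribe handlers above. The event name is an assumption based on emit_generator_progress in the events service (it is not confirmed by this diff), and the session id placeholder is illustrative:

import socketio

sio = socketio.Client()

@sio.on("generator_progress")  # assumed event name, not confirmed by this diff
def on_progress(data):
    print("progress:", data)

sio.connect("http://localhost:9090")
# Joining a room named after the graph_execution_state_id, per _handle_sub above
sio.emit("subscribe", {"session": "your-session-id"})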
@@ -2,36 +2,37 @@
 
 import asyncio
 from inspect import signature
-from fastapi import FastAPI
-from fastapi.openapi.utils import get_openapi
-from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
-from fastapi.staticfiles import StaticFiles
-from fastapi_events.middleware import EventHandlerASGIMiddleware
-from fastapi_events.handlers.local import local_handler
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic.schema import schema
 
 import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
+from fastapi.openapi.utils import get_openapi
+from fastapi.staticfiles import StaticFiles
+from fastapi_events.handlers.local import local_handler
+from fastapi_events.middleware import EventHandlerASGIMiddleware
+from pydantic.schema import schema
 
+from ..args import Args
+from .api.dependencies import ApiDependencies
+from .api.routers import images, sessions
 from .api.sockets import SocketIO
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
-from .api.routers import images, sessions
-from .api.dependencies import ApiDependencies
-from ..args import Args
 
 # Create the app
 # TODO: create this all in a method so configuration/etc. can be passed in?
-app = FastAPI(
-    title = "Invoke AI",
-    docs_url = None,
-    redoc_url = None
-)
+app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
 
 # Add event handler
 event_handler_id: int = id(app)
 app.add_middleware(
     EventHandlerASGIMiddleware,
-    handlers = [local_handler], # TODO: consider doing this in services to support different configurations
-    middleware_id = event_handler_id)
+    handlers=[
+        local_handler
+    ],  # TODO: consider doing this in services to support different configurations
+    middleware_id=event_handler_id,
+)
 
 # Add CORS
 # TODO: use configuration for this

@@ -48,38 +49,34 @@ socket_io = SocketIO(app)
 
 config = {}
 
+
 # Add startup event to load dependencies
-@app.on_event('startup')
+@app.on_event("startup")
 async def startup_event():
     args = Args()
     config = args.parse_args()
 
     ApiDependencies.initialize(
-        args = args,
-        config = config,
-        event_handler_id = event_handler_id
+        args=args, config=config, event_handler_id=event_handler_id
     )
 
+
 # Shut down threads
-@app.on_event('shutdown')
+@app.on_event("shutdown")
 async def shutdown_event():
     ApiDependencies.shutdown()
 
+
 # Include all routers
 # TODO: REMOVE
 # app.include_router(
 #     invocation.invocation_router,
 #     prefix = '/api')
 
-app.include_router(
-    sessions.session_router,
-    prefix = '/api'
-)
+app.include_router(sessions.session_router, prefix="/api")
 
-app.include_router(
-    images.images_router,
-    prefix = '/api'
-)
+app.include_router(images.images_router, prefix="/api")
 
 # Build a custom OpenAPI to include all outputs
 # TODO: can outputs be included on metadata of invocation schemas somehow?

@@ -87,10 +84,10 @@ def custom_openapi():
     if app.openapi_schema:
         return app.openapi_schema
     openapi_schema = get_openapi(
-        title = app.title,
-        description = "An API for invoking AI image operations",
-        version = "1.0.0",
-        routes = app.routes
+        title=app.title,
+        description="An API for invoking AI image operations",
+        version="1.0.0",
+        routes=app.routes,
     )
 
     # Add all outputs

@@ -102,12 +99,12 @@ def custom_openapi():
         output_types.add(output_type)
 
     output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
-    for schema_key, output_schema in output_schemas['definitions'].items():
+    for schema_key, output_schema in output_schemas["definitions"].items():
         openapi_schema["components"]["schemas"][schema_key] = output_schema
 
         # TODO: note that we assume the schema_key here is the TYPE.__name__
         # This could break in some cases, figure out a better way to do it
-        output_type_titles[schema_key] = output_schema['title']
+        output_type_titles[schema_key] = output_schema["title"]
 
     # Add a reference to the output type to additionalProperties of the invoker schema
     for invoker in all_invocations:

@@ -115,46 +112,46 @@ def custom_openapi():
         output_type = signature(invoker.invoke).return_annotation
         output_type_title = output_type_titles[output_type.__name__]
         invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
-        outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' }
-        if 'additionalProperties' not in invoker_schema:
-            invoker_schema['additionalProperties'] = {}
+        outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
+        if "additionalProperties" not in invoker_schema:
+            invoker_schema["additionalProperties"] = {}
 
-        invoker_schema['additionalProperties']['outputs'] = outputs_ref
+        invoker_schema["additionalProperties"]["outputs"] = outputs_ref
 
     app.openapi_schema = openapi_schema
     return app.openapi_schema
 
 
 app.openapi = custom_openapi
 
 # Override API doc favicons
-app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
+app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
 
 
 @app.get("/docs", include_in_schema=False)
 def overridden_swagger():
     return get_swagger_ui_html(
         openapi_url=app.openapi_url,
         title=app.title,
-        swagger_favicon_url="/static/favicon.ico"
+        swagger_favicon_url="/static/favicon.ico",
     )
 
 
 @app.get("/redoc", include_in_schema=False)
 def overridden_redoc():
     return get_redoc_html(
         openapi_url=app.openapi_url,
         title=app.title,
-        redoc_favicon_url="/static/favicon.ico"
+        redoc_favicon_url="/static/favicon.ico",
     )
 
 
 def invoke_api():
     # Start our own event loop for eventing usage
     # TODO: determine if there's a better way to do this
     loop = asyncio.new_event_loop()
-    config = uvicorn.Config(
-        app = app,
-        host = "0.0.0.0",
-        port = 9090,
-        loop = loop)
+    config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
     # Use access_log to turn off logging
 
     server = uvicorn.Server(config)
     loop.run_until_complete(server.serve())
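Launching the API as wired above amounts to calling invoke_api(), which builds its own event loop and serves on 0.0.0.0:9090. The module path is an assumption inferred from the relative imports in this file (invokeai/app/api_app.py):

from invokeai.app.api_app import invoke_api  # assumed module path

if __name__ == "__main__":
    invoke_api()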
@@ -1,33 +1,40 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 import argparse
-import shlex
 import os
+import shlex
 import time
-from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    Literal,
+    Union,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
 
 from pydantic import BaseModel
 from pydantic.fields import Field
 
-from .services.processor import DefaultInvocationProcessor
-
-from .services.graph import EdgeConnection, GraphExecutionState
-
-from .services.sqlite import SqliteItemStorage
-
+from ..args import Args
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 from .invocations.image import ImageField
 from .services.events import EventServiceBase
 from .services.generate_initializer import get_generate
+from .services.graph import EdgeConnection, GraphExecutionState
 from .services.image_storage import DiskImageStorage
 from .services.invocation_queue import MemoryInvocationQueue
 from .services.invocation_services import InvocationServices
 from .services.invoker import Invoker
+from .services.processor import DefaultInvocationProcessor
+from .services.sqlite import SqliteItemStorage
 
 
 class InvocationCommand(BaseModel):
-    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
+    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")  # type: ignore
 
 
 class InvalidArgs(Exception):

@@ -35,47 +42,69 @@ class InvalidArgs(Exception):
 
 
 def get_invocation_parser() -> argparse.ArgumentParser:
-
     # Create invocation parser
     parser = argparse.ArgumentParser()
 
     def exit(*args, **kwargs):
         raise InvalidArgs
 
     parser.exit = exit
 
-    subparsers = parser.add_subparsers(dest='type')
+    subparsers = parser.add_subparsers(dest="type")
     invocation_parsers = dict()
 
     # Add history parser
-    history_parser = subparsers.add_parser('history', help="Shows the invocation history")
-    history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")
+    history_parser = subparsers.add_parser(
+        "history", help="Shows the invocation history"
+    )
+    history_parser.add_argument(
+        "count",
+        nargs="?",
+        default=5,
+        type=int,
+        help="The number of history entries to show",
+    )
 
     # Add default parser
-    default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
-    default_parser.add_argument('input', type=str, help="The input field")
-    default_parser.add_argument('value', help="The default value")
+    default_parser = subparsers.add_parser(
+        "default", help="Define a default value for all inputs with a specified name"
+    )
+    default_parser.add_argument("input", type=str, help="The input field")
+    default_parser.add_argument("value", help="The default value")
 
-    default_parser = subparsers.add_parser('reset_default', help="Resets a default value")
-    default_parser.add_argument('input', type=str, help="The input field")
+    default_parser = subparsers.add_parser(
+        "reset_default", help="Resets a default value"
+    )
+    default_parser.add_argument("input", type=str, help="The input field")
 
     # Create subparsers for each invocation
     invocations = BaseInvocation.get_all_subclasses()
     for invocation in invocations:
         hints = get_type_hints(invocation)
-        cmd_name = get_args(hints['type'])[0]
+        cmd_name = get_args(hints["type"])[0]
         command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
         invocation_parsers[cmd_name] = command_parser
 
         # Add linking capability
-        command_parser.add_argument('--link', '-l', action='append', nargs=3,
-            help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)")
+        command_parser.add_argument(
+            "--link",
+            "-l",
+            action="append",
+            nargs=3,
+            help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
+        )
 
-        command_parser.add_argument('--link_node', '-ln', action='append',
-            help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)")
+        command_parser.add_argument(
+            "--link_node",
+            "-ln",
+            action="append",
+            help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
+        )
 
         # Convert all fields to arguments
         fields = invocation.__fields__
         for name, field in fields.items():
-            if name in ['id', 'type']:
+            if name in ["id", "type"]:
                 continue
 
             if get_origin(field.type_) == Literal:

@@ -84,15 +113,15 @@ def get_invocation_parser() -> argparse.ArgumentParser:
                 for val in allowed_values:
                     allowed_types.add(type(val))
                 allowed_types_list = list(allowed_types)
-                field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
+                field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]  # type: ignore
 
                 command_parser.add_argument(
                     f"--{name}",
                     dest=name,
                     type=field_type,
                     default=field.default,
-                    choices = allowed_values,
-                    help=field.field_info.description
+                    choices=allowed_values,
+                    help=field.field_info.description,
                 )
             else:
                 command_parser.add_argument(

@@ -100,7 +129,7 @@ def get_invocation_parser() -> argparse.ArgumentParser:
                     dest=name,
                     type=field.type_,
                     default=field.default,
-                    help=field.field_info.description
+                    help=field.field_info.description,
                 )
 
     return parser

@@ -110,8 +139,8 @@ def get_invocation_command(invocation) -> str:
     fields = invocation.__fields__.items()
     type_hints = get_type_hints(type(invocation))
     command = [invocation.type]
-    for name,field in fields:
-        if name in ['id', 'type']:
+    for name, field in fields:
+        if name in ["id", "type"]:
             continue
 
         # TODO: add links

@@ -127,17 +156,25 @@ def get_invocation_command(invocation) -> str:
         if type_hint is str or str in get_args(type_hint):
             command.append(f'--{name} "{field_value}"')
         else:
-            command.append(f'--{name} {field_value}')
+            command.append(f"--{name} {field_value}")
 
-    return ' '.join(command)
+    return " ".join(command)
 
 
-def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
+def get_graph_execution_history(
+    graph_execution_state: GraphExecutionState,
+) -> Iterable[str]:
     """Gets the history of fully-executed invocations for a graph execution"""
-    return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
+    return (
+        n
+        for n in reversed(graph_execution_state.executed_history)
+        if n in graph_execution_state.graph.nodes
+    )
 
 
-def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
+def generate_matching_edges(
+    a: BaseInvocation, b: BaseInvocation
+) -> list[tuple[EdgeConnection, EdgeConnection]]:
     """Generates all possible edges between two invocations"""
     atype = type(a)
     btype = type(b)

@@ -150,10 +187,16 @@ def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[
     matching_fields = set(afields.keys()).intersection(bfields.keys())
 
     # Remove invalid fields
-    invalid_fields = set(['type', 'id'])
+    invalid_fields = set(["type", "id"])
     matching_fields = matching_fields.difference(invalid_fields)
 
-    edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields]
+    edges = [
+        (
+            EdgeConnection(node_id=a.id, field=field),
+            EdgeConnection(node_id=b.id, field=field),
+        )
+        for field in matching_fields
+    ]
     return edges

@@ -165,22 +208,26 @@ def invoke_cli():
 
     # NOTE: load model on first use, uncomment to load at startup
    # TODO: Make this a config option?
-    #generate.load_model()
+    # generate.load_model()
 
     events = EventServiceBase()
 
-    output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
+    output_folder = os.path.abspath(
+        os.path.join(os.path.dirname(__file__), "../../../outputs")
+    )
 
     # TODO: build a file/path manager?
-    db_location = os.path.join(output_folder, 'invokeai.db')
+    db_location = os.path.join(output_folder, "invokeai.db")
 
     services = InvocationServices(
-        generate = generate,
-        events = events,
-        images = DiskImageStorage(output_folder),
-        queue = MemoryInvocationQueue(),
-        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
-        processor = DefaultInvocationProcessor()
+        generate=generate,
+        events=events,
+        images=DiskImageStorage(output_folder),
+        queue=MemoryInvocationQueue(),
+        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
+            filename=db_location, table_name="graph_executions"
+        ),
+        processor=DefaultInvocationProcessor(),
    )
 
     invoker = Invoker(services)

@@ -201,10 +248,10 @@ def invoke_cli():
             # Ctrl-c exits
             break
 
-        if cmd_input in ['exit','q']:
-            break;
+        if cmd_input in ["exit", "q"]:
+            break
 
-        if cmd_input in ['--help','help','h','?']:
+        if cmd_input in ["--help", "help", "h", "?"]:
             parser.print_help()
             continue

@@ -214,65 +261,82 @@ def invoke_cli():
             history = list(get_graph_execution_history(session))
 
             # Split the command for piping
-            cmds = cmd_input.split('|')
+            cmds = cmd_input.split("|")
             start_id = len(history)
             current_id = start_id
             new_invocations = list()
             for cmd in cmds:
-                if cmd is None or cmd.strip() == '':
-                    raise InvalidArgs('Empty command')
+                if cmd is None or cmd.strip() == "":
+                    raise InvalidArgs("Empty command")
 
                 # Parse args to create invocation
                 args = vars(parser.parse_args(shlex.split(cmd.strip())))
 
                 # Check for special commands
                 # TODO: These might be better as Pydantic models, similar to the invocations
-                if args['type'] == 'history':
-                    history_count = args['count'] or 5
+                if args["type"] == "history":
+                    history_count = args["count"] or 5
                     for i in range(min(history_count, len(history))):
                         entry_id = history[-1 - i]
                         entry = session.graph.get_node(entry_id)
-                        print(f'{entry_id}: {get_invocation_command(entry.invocation)}')
+                        print(f"{entry_id}: {get_invocation_command(entry.invocation)}")
                     continue
 
-                if args['type'] == 'reset_default':
-                    if args['input'] in defaults:
-                        del defaults[args['input']]
+                if args["type"] == "reset_default":
+                    if args["input"] in defaults:
+                        del defaults[args["input"]]
                     continue
 
-                if args['type'] == 'default':
-                    field = args['input']
-                    field_value = args['value']
+                if args["type"] == "default":
+                    field = args["input"]
+                    field_value = args["value"]
                     defaults[field] = field_value
                     continue
 
                 # Override defaults
-                for field_name,field_default in defaults.items():
+                for field_name, field_default in defaults.items():
                     if field_name in args:
                         args[field_name] = field_default
 
                 # Parse invocation
-                args['id'] = current_id
-                command = InvocationCommand(invocation = args)
+                args["id"] = current_id
+                command = InvocationCommand(invocation=args)
 
                 # Pipe previous command output (if there was a previous command)
                 edges = []
                 if len(history) > 0 or current_id != start_id:
-                    from_id = history[0] if current_id == start_id else str(current_id - 1)
-                    from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id)
-                    matching_edges = generate_matching_edges(from_node, command.invocation)
+                    from_id = (
+                        history[0] if current_id == start_id else str(current_id - 1)
+                    )
+                    from_node = (
+                        next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
+                        if current_id != start_id
+                        else session.graph.get_node(from_id)
+                    )
+                    matching_edges = generate_matching_edges(
+                        from_node, command.invocation
+                    )
                     edges.extend(matching_edges)
 
                 # Parse provided links
-                if 'link_node' in args and args['link_node']:
-                    for link in args['link_node']:
+                if "link_node" in args and args["link_node"]:
+                    for link in args["link_node"]:
                         link_node = session.graph.get_node(link)
-                        matching_edges = generate_matching_edges(link_node, command.invocation)
+                        matching_edges = generate_matching_edges(
+                            link_node, command.invocation
+                        )
                         edges.extend(matching_edges)
 
-                if 'link' in args and args['link']:
-                    for link in args['link']:
-                        edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2])))
+                if "link" in args and args["link"]:
+                    for link in args["link"]:
+                        edges.append(
+                            (
+                                EdgeConnection(node_id=link[1], field=link[0]),
+                                EdgeConnection(
+                                    node_id=command.invocation.id, field=link[2]
+                                ),
+                            )
+                        )
 
                 new_invocations.append((command.invocation, edges))

@@ -286,7 +350,7 @@ def invoke_cli():
                 session.add_edge(edge)
 
             # Execute all available invocations
-            invoker.invoke(session, invoke_all = True)
+            invoker.invoke(session, invoke_all=True)
             while not session.is_complete():
                 # Wait some time
                 session = invoker.services.graph_execution_manager.get(session.id)

@@ -295,7 +359,9 @@ def invoke_cli():
             # Print any errors
             if session.has_error():
                 for n in session.errors:
-                    print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')
+                    print(
+                        f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}"
+                    )
 
             # Start a new session
             print("Creating a new session")
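The piping behavior above hinges on generate_matching_edges: when commands are chained with "|", identically named fields of adjacent nodes are wired together, excluding the bookkeeping fields "type" and "id". A self-contained toy reproducing just that rule (no InvokeAI imports; the names here are illustrative):

from dataclasses import dataclass

@dataclass(frozen=True)
class Edge:
    node_id: str
    field: str

def matching_edges(a_fields: set, a_id: str, b_fields: set, b_id: str):
    # Shared field names, minus the bookkeeping fields, become edges.
    shared = (a_fields & b_fields) - {"type", "id"}
    return [(Edge(a_id, f), Edge(b_id, f)) for f in sorted(shared)]

# "txt2img | img2img": the shared "image" field is linked automatically.
print(matching_edges({"image", "type", "id"}, "0",
                     {"image", "strength", "type", "id"}, "1"))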
@@ -4,5 +4,9 @@ __all__ = []
 
 dirname = os.path.dirname(os.path.abspath(__file__))
 for f in os.listdir(dirname):
-    if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
+    if (
+        f != "__init__.py"
+        and os.path.isfile("%s/%s" % (dirname, f))
+        and f[-3:] == ".py"
+    ):
         __all__.append(f[:-3])
@@ -3,7 +3,9 @@
 from abc import ABC, abstractmethod
 from inspect import signature
 from typing import get_args, get_type_hints
+
 from pydantic import BaseModel, Field
+
 from ..services.invocation_services import InvocationServices
@@ -1,30 +1,37 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 from typing import Literal
-import numpy
-from pydantic import Field
-from PIL import Image, ImageOps
 
 import cv2 as cv
-from .image import ImageField, ImageOutput
-from .baseinvocation import BaseInvocation, InvocationContext
+import numpy
+from PIL import Image, ImageOps
+from pydantic import Field
 
 from ..services.image_storage import ImageType
+from .baseinvocation import BaseInvocation, InvocationContext
+from .image import ImageField, ImageOutput
 
 
 class CvInpaintInvocation(BaseInvocation):
     """Simple inpaint using opencv."""
-    type: Literal['cv_inpaint'] = 'cv_inpaint'
+
+    type: Literal["cv_inpaint"] = "cv_inpaint"
 
     # Inputs
     image: ImageField = Field(default=None, description="The image to inpaint")
-    mask: ImageField = Field(default=None, description="The mask to use when inpainting")
+    mask: ImageField = Field(
+        default=None, description="The mask to use when inpainting"
+    )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get(self.image.image_type, self.image.image_name)
+        image = context.services.images.get(
+            self.image.image_type, self.image.image_name
+        )
         mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
 
         # Convert to cv image/mask
         # TODO: consider making these utility functions
-        cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR)
+        cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
         cv_mask = numpy.array(ImageOps.invert(mask))
 
         # Inpaint

@@ -35,8 +42,10 @@ class CvInpaintInvocation(BaseInvocation):
         image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
 
         image_type = ImageType.INTERMEDIATE
-        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
+        image_name = context.services.images.create_name(
+            context.graph_execution_state_id, self.id
+        )
         context.services.images.save(image_type, image_name, image_inpainted)
         return ImageOutput(
-            image = ImageField(image_type = image_type, image_name = image_name)
+            image=ImageField(image_type=image_type, image_name=image_name)
         )
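The cv.inpaint call itself falls between the two hunks shown, but the operation the invocation wraps can be sketched standalone with synthetic data, so it runs without any asset files:

import cv2 as cv
import numpy

image = numpy.full((64, 64, 3), 200, dtype=numpy.uint8)  # gray canvas
image[16:32, 16:32] = 0                                  # simulated damage
mask = numpy.zeros((64, 64), dtype=numpy.uint8)
mask[16:32, 16:32] = 255                                 # region to repair

restored = cv.inpaint(image, mask, inpaintRadius=3, flags=cv.INPAINT_TELEA)
print(restored[20, 20])  # repaired pixels approach the surrounding gray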
211
invokeai/app/invocations/generate.py
Normal file
211
invokeai/app/invocations/generate.py
Normal file
@ -0,0 +1,211 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union

import numpy as np
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms

from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput

SAMPLER_NAME_VALUES = Literal[
    "ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
]


# Text to image
class TextToImageInvocation(BaseInvocation):
    """Generates an image using text2img."""

    type: Literal["txt2img"] = "txt2img"

    # Inputs
    # TODO: consider making prompt optional to enable providing prompt through a link
    # fmt: off
    prompt: Optional[str] = Field(description="The prompt to generate an image from")
    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    model: str = Field(default="", description="The model to use (currently ignored)")
    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
    # fmt: on

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
        self, context: InvocationContext, sample: Any = None, step: int = 0
    ) -> None:
        context.services.events.emit_generator_progress(
            context.graph_execution_state_id,
            self.id,
            step,
            float(step) / float(self.steps),
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        def step_callback(sample, step=0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        if self.model is None or self.model == "":
            self.model = context.services.generate.model_name

        # Set the model (if already cached, this does nothing)
        context.services.generate.set_model(self.model)

        results = context.services.generate.prompt2image(
            prompt=self.prompt,
            step_callback=step_callback,
            **self.dict(
                exclude={"prompt"}
            ), # Shorthand for passing all of the parameters above manually
        )

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, results[0][0])
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class ImageToImageInvocation(TextToImageInvocation):
    """Generates an image using img2img."""

    type: Literal["img2img"] = "img2img"

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")
    strength: float = Field(
        default=0.75, gt=0, le=1, description="The strength of the original image"
    )
    fit: bool = Field(
        default=True,
        description="Whether or not the result should be fit to the aspect ratio of the input image",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = (
            None
            if self.image is None
            else context.services.images.get(
                self.image.image_type, self.image.image_name
            )
        )
        mask = None

        def step_callback(sample, step=0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        if self.model is None or self.model == "":
            self.model = context.services.generate.model_name

        # Set the model (if already cached, this does nothing)
        context.services.generate.set_model(self.model)

        results = context.services.generate.prompt2image(
            prompt=self.prompt,
            init_img=image,
            init_mask=mask,
            step_callback=step_callback,
            **self.dict(
                exclude={"prompt", "image", "mask"}
            ), # Shorthand for passing all of the parameters above manually
        )

        result_image = results[0][0]

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, result_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class InpaintInvocation(ImageToImageInvocation):
    """Generates an image using inpaint."""

    type: Literal["inpaint"] = "inpaint"

    # Inputs
    mask: Union[ImageField, None] = Field(description="The mask")
    inpaint_replace: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="The amount by which to replace masked areas with latent noise",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = (
            None
            if self.image is None
            else context.services.images.get(
                self.image.image_type, self.image.image_name
            )
        )
        mask = (
            None
            if self.mask is None
            else context.services.images.get(self.mask.image_type, self.mask.image_name)
        )

        def step_callback(sample, step=0):
            self.dispatch_progress(context, sample, step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
        # TODO: How to get the default model name now?
        if self.model is None or self.model == "":
            self.model = context.services.generate.model_name

        # Set the model (if already cached, this does nothing)
        context.services.generate.set_model(self.model)

        results = context.services.generate.prompt2image(
            prompt=self.prompt,
            init_img=image,
            init_mask=mask,
            step_callback=step_callback,
            **self.dict(
                exclude={"prompt", "image", "mask"}
            ), # Shorthand for passing all of the parameters above manually
        )

        result_image = results[0][0]

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, result_image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
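A note on the new `generate.py` nodes above: each invocation is a pydantic model whose fields double as the generation parameters, and `invoke()` forwards everything except the linked fields straight to `Generate.prompt2image` via `self.dict(exclude=...)`. A minimal sketch of constructing one of these nodes (the `id` value and the printed kwargs are illustrative only; the services wiring normally comes from an `Invoker`):

```python
from invokeai.app.invocations.generate import TextToImageInvocation

# Node ids are plain strings supplied by the graph (an assumption here).
node = TextToImageInvocation(
    id="1",
    prompt="a sunlit meadow",
    steps=20,
    cfg_scale=7.5,
    sampler_name="k_lms",
)

# Everything except `prompt` is what invoke() splats into prompt2image:
print(node.dict(exclude={"prompt"}))
```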
@ -2,30 +2,37 @@

from datetime import datetime, timezone
from typing import Literal, Optional

import numpy
from pydantic import Field, BaseModel
from PIL import Image, ImageOps, ImageFilter
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field

from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext


class ImageField(BaseModel):
    """An image field used for passing image objects between invocations"""
    image_type: str = Field(default=ImageType.RESULT, description="The type of the image")

    image_type: str = Field(
        default=ImageType.RESULT, description="The type of the image"
    )
    image_name: Optional[str] = Field(default=None, description="The name of the image")


class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""
    type: Literal['image'] = 'image'

    type: Literal["image"] = "image"

    image: ImageField = Field(default=None, description="The output image")


class MaskOutput(BaseInvocationOutput):
    """Base class for invocations that output a mask"""
    type: Literal['mask'] = 'mask'

    type: Literal["mask"] = "mask"

    mask: ImageField = Field(default=None, description="The output mask")

@ -33,7 +40,8 @@ class MaskOutput(BaseInvocationOutput):
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
    """Load an image from a filename and provide it as output."""
    type: Literal['load_image'] = 'load_image'

    type: Literal["load_image"] = "load_image"

    # Inputs
    image_type: ImageType = Field(description="The type of the image")
@ -41,69 +49,100 @@ class LoadImageInvocation(BaseInvocation):

    def invoke(self, context: InvocationContext) -> ImageOutput:
        return ImageOutput(
            image = ImageField(image_type = self.image_type, image_name = self.image_name)
            image=ImageField(image_type=self.image_type, image_name=self.image_name)
        )


class ShowImageInvocation(BaseInvocation):
    """Displays a provided image, and passes it forward in the pipeline."""
    type: Literal['show_image'] = 'show_image'

    type: Literal["show_image"] = "show_image"

    # Inputs
    image: ImageField = Field(default=None, description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        if image:
            image.show()

        # TODO: how to handle failure?

        return ImageOutput(
            image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name)
            image=ImageField(
                image_type=self.image.image_type, image_name=self.image.image_name
            )
        )


class CropImageInvocation(BaseInvocation):
    """Crops an image to a specified box. The box can be outside of the image."""
    type: Literal['crop'] = 'crop'

    type: Literal["crop"] = "crop"

    # Inputs
    image: ImageField = Field(default=None, description="The image to crop")
    x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
    y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
    width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
    height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
    x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
    y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
    width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
    height: int = Field(
        default=512, gt=0, description="The height of the crop rectangle"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0))
        image_crop = Image.new(
            mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
        )
        image_crop.paste(image, (-self.x, -self.y))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, image_crop)
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class PasteImageInvocation(BaseInvocation):
    """Pastes an image into another image."""
    type: Literal['paste'] = 'paste'

    type: Literal["paste"] = "paste"

    # Inputs
    base_image: ImageField = Field(default=None, description="The base image")
    image: ImageField = Field(default=None, description="The image to paste")
    mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
    x: int = Field(default=0, description="The left x coordinate at which to paste the image")
    y: int = Field(default=0, description="The top y coordinate at which to paste the image")
    base_image: ImageField = Field(default=None, description="The base image")
    image: ImageField = Field(default=None, description="The image to paste")
    mask: Optional[ImageField] = Field(
        default=None, description="The mask to use when pasting"
    )
    x: int = Field(
        default=0, description="The left x coordinate at which to paste the image"
    )
    y: int = Field(
        default=0, description="The top y coordinate at which to paste the image"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name)
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        mask = None if self.mask is None else ImageOps.invert(services.images.get(self.mask.image_type, self.mask.image_name))
        base_image = context.services.images.get(
            self.base_image.image_type, self.base_image.image_name
        )
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        mask = (
            None
            if self.mask is None
            else ImageOps.invert(
                context.services.images.get(self.mask.image_type, self.mask.image_name)
            )
        )
        # TODO: probably shouldn't invert mask here... should user be required to do it?

        min_x = min(0, self.x)
@ -111,67 +150,88 @@ class PasteImageInvocation(BaseInvocation):
        max_x = max(base_image.width, image.width + self.x)
        max_y = max(base_image.height, image.height + self.y)

        new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0))
        new_image = Image.new(
            mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
        )
        new_image.paste(base_image, (abs(min_x), abs(min_y)))
        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask)
        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, new_image)
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class MaskFromAlphaInvocation(BaseInvocation):
    """Extracts the alpha channel of an image as a mask."""
    type: Literal['tomask'] = 'tomask'

    type: Literal["tomask"] = "tomask"

    # Inputs
    image: ImageField = Field(default=None, description="The image to create the mask from")
    image: ImageField = Field(
        default=None, description="The image to create the mask from"
    )
    invert: bool = Field(default=False, description="Whether or not to invert the mask")

    def invoke(self, context: InvocationContext) -> MaskOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_mask = image.split()[-1]
        if self.invert:
            image_mask = ImageOps.invert(image_mask)

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        context.services.images.save(image_type, image_name, image_mask)
        return MaskOutput(
            mask = ImageField(image_type = image_type, image_name = image_name)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, image_mask)
        return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))


class BlurInvocation(BaseInvocation):
    """Blurs an image"""
    type: Literal['blur'] = 'blur'

    type: Literal["blur"] = "blur"

    # Inputs
    image: ImageField = Field(default=None, description="The image to blur")
    radius: float = Field(default=8.0, ge=0, description="The blur radius")
    blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur")
    radius: float = Field(default=8.0, ge=0, description="The blur radius")
    blur_type: Literal["gaussian", "box"] = Field(
        default="gaussian", description="The type of blur"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius)
        blur = (
            ImageFilter.GaussianBlur(self.radius)
            if self.blur_type == "gaussian"
            else ImageFilter.BoxBlur(self.radius)
        )
        blur_image = image.filter(blur)

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, blur_image)
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class LerpInvocation(BaseInvocation):
    """Linear interpolation of all pixels of an image"""
    type: Literal['lerp'] = 'lerp'

    type: Literal["lerp"] = "lerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
@ -179,7 +239,9 @@ class LerpInvocation(BaseInvocation):
    max: int = Field(default=255, ge=0, le=255, description="The maximum output value")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
        image_arr = image_arr * (self.max - self.min) + self.min
@ -187,33 +249,46 @@ class LerpInvocation(BaseInvocation):
        lerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, lerp_image)
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
            image=ImageField(image_type=image_type, image_name=image_name)
        )


class InverseLerpInvocation(BaseInvocation):
    """Inverse linear interpolation of all pixels of an image"""
    type: Literal['ilerp'] = 'ilerp'
    #fmt: off
    type: Literal["ilerp"] = "ilerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_arr = numpy.asarray(image, dtype=numpy.float32)
        image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
        image_arr = (
            numpy.minimum(
                numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
            )
            * 255
        )

        ilerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, ilerp_image)
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
            image=ImageField(image_type=image_type, image_name=image_name)
        )
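For reference on the two interpolation nodes above: `LerpInvocation` maps normalized pixel values into `[min, max]` via `out = in / 255 * (max - min) + min`, and `InverseLerpInvocation` applies the clamped inverse. A quick numpy sketch (values illustrative):

```python
import numpy

min_out, max_out = 50, 200
arr = numpy.asarray([0, 128, 255], dtype=numpy.float32) / 255
lerped = arr * (max_out - min_out) + min_out
print(lerped)  # roughly [50.0, 125.3, 200.0]

# The inverse maps [min_in, max_in] back onto [0, 1] (scaled by 255 in the node):
inv = numpy.minimum(numpy.maximum(lerped - min_out, 0) / float(max_out - min_out), 1)
print(inv)  # roughly [0.0, 0.5, 1.0]
```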
@ -1,9 +1,14 @@
from typing import Literal

from pydantic.fields import Field

from .baseinvocation import BaseInvocationOutput


class PromptOutput(BaseInvocationOutput):
    """Base class for invocations that output a prompt"""
    type: Literal['prompt'] = 'prompt'
    #fmt: off
    type: Literal["prompt"] = "prompt"

    prompt: str = Field(default=None, description="The output prompt")
    #fmt: on
@ -1,36 +1,43 @@
from datetime import datetime, timezone
from typing import Literal, Union

from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext

from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput


class RestoreFaceInvocation(BaseInvocation):
    """Restores faces in an image."""
    type: Literal['restore_face'] = 'restore_face'
    #fmt: off
    type: Literal["restore_face"] = "restore_face"

    # Inputs
    image: Union[ImageField,None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration")

    image: Union[ImageField, None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        results = context.services.generate.upscale_and_reconstruct(
            image_list = [[image, 0]],
            upscale = None,
            strength = self.strength, # GFPGAN strength
            save_original = False,
            image_callback = None,
            image_list=[[image, 0]],
            upscale=None,
            strength=self.strength, # GFPGAN strength
            save_original=False,
            image_callback=None,
        )

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, results[0][0])
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
            image=ImageField(image_type=image_type, image_name=image_name)
        )
@ -2,37 +2,45 @@

from datetime import datetime, timezone
from typing import Literal, Union

from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext

from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput


class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""
    type: Literal['upscale'] = 'upscale'
    #fmt: off
    type: Literal["upscale"] = "upscale"

    # Inputs
    image: Union[ImageField,None] = Field(description="The input image", default=None)
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
    level: Literal[2,4] = Field(default=2, description = "The upscale level")
    image: Union[ImageField, None] = Field(description="The input image", default=None)
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
    level: Literal[2, 4] = Field(default=2, description="The upscale level")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        results = context.services.generate.upscale_and_reconstruct(
            image_list = [[image, 0]],
            upscale = (self.level, self.strength),
            strength = 0.0, # GFPGAN strength
            save_original = False,
            image_callback = None,
            image_list=[[image, 0]],
            upscale=(self.level, self.strength),
            strength=0.0, # GFPGAN strength
            save_original=False,
            image_callback=None,
        )

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, results[0][0])
        return ImageOutput(
            image = ImageField(image_type = image_type, image_name = image_name)
            image=ImageField(image_type=image_type, image_name=image_name)
        )
83
invokeai/app/services/events.py
Normal file
@ -0,0 +1,83 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Any, Dict


class EventServiceBase:
    session_event: str = "session_event"

    """Basic event bus, to have an empty stand-in when not needed"""

    def dispatch(self, event_name: str, payload: Any) -> None:
        pass

    def __emit_session_event(self, event_name: str, payload: Dict) -> None:
        self.dispatch(
            event_name=EventServiceBase.session_event,
            payload=dict(event=event_name, data=payload),
        )

    # Define events here for every event in the system.
    # This will make them easier to integrate until we find a schema generator.
    def emit_generator_progress(
        self,
        graph_execution_state_id: str,
        invocation_id: str,
        step: int,
        percent: float,
    ) -> None:
        """Emitted when there is generation progress"""
        self.__emit_session_event(
            event_name="generator_progress",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
                step=step,
                percent=percent,
            ),
        )

    def emit_invocation_complete(
        self, graph_execution_state_id: str, invocation_id: str, result: Dict
    ) -> None:
        """Emitted when an invocation has completed"""
        self.__emit_session_event(
            event_name="invocation_complete",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
                result=result,
            ),
        )

    def emit_invocation_error(
        self, graph_execution_state_id: str, invocation_id: str, error: str
    ) -> None:
        """Emitted when an invocation encounters an error"""
        self.__emit_session_event(
            event_name="invocation_error",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
                error=error,
            ),
        )

    def emit_invocation_started(
        self, graph_execution_state_id: str, invocation_id: str
    ) -> None:
        """Emitted when an invocation has started"""
        self.__emit_session_event(
            event_name="invocation_started",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
            ),
        )

    def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
        """Emitted when a session has completed all invocations"""
        self.__emit_session_event(
            event_name="graph_execution_state_complete",
            payload=dict(graph_execution_state_id=graph_execution_state_id),
        )
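Since `EventServiceBase.dispatch` is deliberately a no-op, every `emit_*` helper above funnels through it, and a concrete bus only has to override `dispatch`. A minimal sketch (the logging subclass is hypothetical, not part of this commit):

```python
from typing import Any

from invokeai.app.services.events import EventServiceBase


class LoggingEventService(EventServiceBase):
    def dispatch(self, event_name: str, payload: Any) -> None:
        # All session events arrive here wrapped as {"event": ..., "data": ...}
        print(f"[{event_name}] {payload}")


events = LoggingEventService()
events.emit_generator_progress("session-1", "node-1", step=3, percent=0.3)
```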
@ -1,42 +1,47 @@
from argparse import Namespace
import os
import sys
import traceback
from argparse import Namespace

from ...model_manager import ModelManager
import invokeai.version
from invokeai.backend import Generate, ModelManager

from ...globals import Globals
from ....generate import Generate
import ldm.invoke


# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_generate(args, config) -> Generate:
    if not args.conf:
        config_file = os.path.join(Globals.root,'configs','models.yaml')
        config_file = os.path.join(Globals.root, "configs", "models.yaml")
        if not os.path.exists(config_file):
            report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found."))
            report_model_error(
                args, FileNotFoundError(f"The file {config_file} could not be found.")
            )

    print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}')
    print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
    print(f'>> InvokeAI runtime directory is "{Globals.root}"')

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers  # type: ignore

    transformers.logging.set_verbosity_error()
    import diffusers

    diffusers.logging.set_verbosity_error()

    # Loading Face Restoration and ESRGAN Modules
    gfpgan,codeformer,esrgan = load_face_restoration(args)
    gfpgan, codeformer, esrgan = load_face_restoration(args)

    # normalize the config directory relative to root
    if not os.path.isabs(args.conf):
        args.conf = os.path.normpath(os.path.join(Globals.root,args.conf))
        args.conf = os.path.normpath(os.path.join(Globals.root, args.conf))

    if args.embeddings:
        if not os.path.isabs(args.embedding_path):
            embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path))
            embedding_path = os.path.normpath(
                os.path.join(Globals.root, args.embedding_path)
            )
        else:
            embedding_path = args.embedding_path
    else:
@ -49,35 +54,35 @@ def get_generate(args, config) -> Generate:
    if args.infile:
        try:
            if os.path.isfile(args.infile):
                infile = open(args.infile, 'r', encoding='utf-8')
            elif args.infile == '-':  # stdin
                infile = open(args.infile, "r", encoding="utf-8")
            elif args.infile == "-":  # stdin
                infile = sys.stdin
            else:
                raise FileNotFoundError(f'{args.infile} not found.')
                raise FileNotFoundError(f"{args.infile} not found.")
        except (FileNotFoundError, IOError) as e:
            print(f'{e}. Aborting.')
            print(f"{e}. Aborting.")
            sys.exit(-1)

    # creating a Generate object:
    try:
        gen = Generate(
            conf = args.conf,
            model = args.model,
            sampler_name = args.sampler_name,
            embedding_path = embedding_path,
            full_precision = args.full_precision,
            precision = args.precision,
            gfpgan = gfpgan,
            codeformer = codeformer,
            esrgan = esrgan,
            free_gpu_mem = args.free_gpu_mem,
            safety_checker = args.safety_checker,
            max_loaded_models = args.max_loaded_models,
        )
            conf=args.conf,
            model=args.model,
            sampler_name=args.sampler_name,
            embedding_path=embedding_path,
            full_precision=args.full_precision,
            precision=args.precision,
            gfpgan=gfpgan,
            codeformer=codeformer,
            esrgan=esrgan,
            free_gpu_mem=args.free_gpu_mem,
            safety_checker=args.safety_checker,
            max_loaded_models=args.max_loaded_models,
        )
    except (FileNotFoundError, TypeError, AssertionError) as e:
        report_model_error(opt,e)
        report_model_error(args, e)
    except (IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        print(f"{e}. Aborting.")
        sys.exit(-1)

    if args.seamless:
@ -106,51 +111,61 @@ def load_face_restoration(opt):
    try:
        gfpgan, codeformer, esrgan = None, None, None
        if opt.restore or opt.esrgan:
            from ldm.invoke.restoration import Restoration
            from invokeai.backend.restoration import Restoration

            restoration = Restoration()
            if opt.restore:
                gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)
                gfpgan, codeformer = restoration.load_face_restore_models(
                    opt.gfpgan_model_path
                )
            else:
                print('>> Face restoration disabled')
                print(">> Face restoration disabled")
            if opt.esrgan:
                esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
            else:
                print('>> Upscaling disabled')
                print(">> Upscaling disabled")
        else:
            print('>> Face restoration and upscaling disabled')
            print(">> Face restoration and upscaling disabled")
    except (ModuleNotFoundError, ImportError):
        print(traceback.format_exc(), file=sys.stderr)
        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
    return gfpgan,codeformer,esrgan
        print(">> You may need to install the ESRGAN and/or GFPGAN modules")
    return gfpgan, codeformer, esrgan


def report_model_error(opt:Namespace, e:Exception):
def report_model_error(opt: Namespace, e: Exception):
    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
    print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
    yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
    print(
        "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
    )
    yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
    if yes_to_all:
        print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
        print(
            "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
        )
    else:
        response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
        if response.startswith(('n', 'N')):
        response = input(
            "Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
        )
        if response.startswith(("n", "N")):
            return

    print('invokeai-configure is launching....\n')
    print("invokeai-configure is launching....\n")

    # Match arguments that were set on the CLI
    # only the arguments accepted by the configuration script are parsed
    root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
    config = ["--config", opt.conf] if opt.conf is not None else []
    previous_args = sys.argv
    sys.argv = [ 'invokeai-configure' ]
    sys.argv = ["invokeai-configure"]
    sys.argv.extend(root_dir)
    sys.argv.extend(config)
    if yes_to_all is not None:
        for arg in yes_to_all.split():
            sys.argv.append(arg)

    from ldm.invoke.config import invokeai_configure
    invokeai_configure.main()
    from invokeai.frontend.install import invokeai_configure

    invokeai_configure()
    # TODO: Figure out how to restart
    # print('** InvokeAI will now restart')
    # sys.argv = previous_args
@ -161,10 +176,12 @@ def report_model_error(opt:Namespace, e:Exception):
# Temporary initializer for Generate until we migrate off of it
def old_get_generate(args, config) -> Generate:
    # TODO: Remove the need for globals
    from ldm.invoke.globals import Globals
    from invokeai.backend.globals import Globals

    # alert - setting globals here
    Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
    Globals.root = os.path.expanduser(
        args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
    )
    Globals.try_patchmatch = args.patchmatch

    print(f'>> InvokeAI runtime directory is "{Globals.root}"')
@ -172,6 +189,7 @@ def old_get_generate(args, config) -> Generate:
    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers

    transformers.logging.set_verbosity_error()

    # Loading Face Restoration and ESRGAN Modules
@ -179,53 +197,57 @@ def old_get_generate(args, config) -> Generate:
    try:
        if config.restore or config.esrgan:
            from ldm.invoke.restoration import Restoration

            restoration = Restoration()
            if config.restore:
                gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path)
                gfpgan, codeformer = restoration.load_face_restore_models(
                    config.gfpgan_model_path
                )
            else:
                print('>> Face restoration disabled')
                print(">> Face restoration disabled")
            if config.esrgan:
                esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
            else:
                print('>> Upscaling disabled')
                print(">> Upscaling disabled")
        else:
            print('>> Face restoration and upscaling disabled')
            print(">> Face restoration and upscaling disabled")
    except (ModuleNotFoundError, ImportError):
        print(traceback.format_exc(), file=sys.stderr)
        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
        print(">> You may need to install the ESRGAN and/or GFPGAN modules")

    # normalize the config directory relative to root
    if not os.path.isabs(config.conf):
        config.conf = os.path.normpath(os.path.join(Globals.root,config.conf))
        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))

    if config.embeddings:
        if not os.path.isabs(config.embedding_path):
            embedding_path = os.path.normpath(os.path.join(Globals.root,config.embedding_path))
            embedding_path = os.path.normpath(
                os.path.join(Globals.root, config.embedding_path)
            )
        else:
            embedding_path = None


    # TODO: lazy-initialize this by wrapping it
    try:
        generate = Generate(
            conf = config.conf,
            model = config.model,
            sampler_name = config.sampler_name,
            embedding_path = embedding_path,
            full_precision = config.full_precision,
            precision = config.precision,
            gfpgan = gfpgan,
            codeformer = codeformer,
            esrgan = esrgan,
            free_gpu_mem = config.free_gpu_mem,
            safety_checker = config.safety_checker,
            max_loaded_models = config.max_loaded_models,
            conf=config.conf,
            model=config.model,
            sampler_name=config.sampler_name,
            embedding_path=embedding_path,
            full_precision=config.full_precision,
            precision=config.precision,
            gfpgan=gfpgan,
            codeformer=codeformer,
            esrgan=esrgan,
            free_gpu_mem=config.free_gpu_mem,
            safety_checker=config.safety_checker,
            max_loaded_models=config.max_loaded_models,
        )
    except (FileNotFoundError, TypeError, AssertionError):
        #emergency_model_reconfigure() # TODO?
        # emergency_model_reconfigure() # TODO?
        sys.exit(-1)
    except (IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        print(f"{e}. Aborting.")
        sys.exit(-1)

    generate.free_gpu_mem = config.free_gpu_mem
File diff suppressed because it is too large
@ -1,20 +1,22 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from abc import ABC, abstractmethod
from enum import Enum
import datetime
import os
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Dict

from PIL.Image import Image
from ...pngwriter import PngWriter

from invokeai.backend.image_util import PngWriter


class ImageType(str, Enum):
    RESULT = 'results'
    INTERMEDIATE = 'intermediates'
    UPLOAD = 'uploads'
    RESULT = "results"
    INTERMEDIATE = "intermediates"
    UPLOAD = "uploads"


class ImageStorageBase(ABC):
@ -38,14 +40,15 @@ class ImageStorageBase(ABC):
        pass

    def create_name(self, context_id: str, node_id: str) -> str:
        return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png'
        return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"


class DiskImageStorage(ImageStorageBase):
    """Stores images on disk"""

    __output_folder: str
    __pngWriter: PngWriter
    __cache_ids: Queue # TODO: this is an incredibly naive cache
    __cache_ids: Queue  # TODO: this is an incredibly naive cache
    __cache: Dict[str, Image]
    __max_cache_size: int

@ -54,13 +57,15 @@ class DiskImageStorage(ImageStorageBase):
        self.__pngWriter = PngWriter(output_folder)
        self.__cache = dict()
        self.__cache_ids = Queue()
        self.__max_cache_size = 10 # TODO: get this from config
        self.__max_cache_size = 10  # TODO: get this from config

        Path(output_folder).mkdir(parents=True, exist_ok=True)

        # TODO: don't hard-code. get/save/delete should maybe take subpath?
        for image_type in ImageType:
            Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True)
            Path(os.path.join(output_folder, image_type)).mkdir(
                parents=True, exist_ok=True
            )

    def get(self, image_type: ImageType, image_name: str) -> Image:
        image_path = self.get_path(image_type, image_name)
@ -79,7 +84,9 @@ class DiskImageStorage(ImageStorageBase):

    def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
        image_subpath = os.path.join(image_type, image_name)
        self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer
        self.__pngWriter.save_image_and_prompt_to_png(
            image, "", image_subpath, None
        )  # TODO: just pass full path to png writer

        image_path = self.get_path(image_type, image_name)
        self.__set_cache(image_path, image)
@ -98,7 +105,9 @@ class DiskImageStorage(ImageStorageBase):
    def __set_cache(self, image_name: str, image: Image):
        if not image_name in self.__cache:
            self.__cache[image_name] = image
            self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
            self.__cache_ids.put(
                image_name
            )  # TODO: this should refresh position for LRU cache
            if len(self.__cache) > self.__max_cache_size:
                cache_id = self.__cache_ids.get()
                del self.__cache[cache_id]
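A quick sketch of how `DiskImageStorage` above is meant to be used (the output folder and the direct `get_path` call are illustrative); the constructor pre-creates one subfolder per `ImageType`, and `create_name` builds a timestamped filename from the session and node ids:

```python
from PIL import Image

from invokeai.app.services.image_storage import DiskImageStorage, ImageType

storage = DiskImageStorage("./outputs")  # creates results/, intermediates/, uploads/
name = storage.create_name("session-1", "node-1")
storage.save(ImageType.RESULT, name, Image.new("RGB", (64, 64)))
print(storage.get_path(ImageType.RESULT, name))
```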
@ -6,17 +6,19 @@ from queue import Queue

# TODO: make this serializable
class InvocationQueueItem:
    #session_id: str
    # session_id: str
    graph_execution_state_id: str
    invocation_id: str
    invoke_all: bool

    def __init__(self,
        #session_id: str,
    def __init__(
        self,
        # session_id: str,
        graph_execution_state_id: str,
        invocation_id: str,
        invoke_all: bool = False):
        #self.session_id = session_id
        invoke_all: bool = False,
    ):
        # self.session_id = session_id
        self.graph_execution_state_id = graph_execution_state_id
        self.invocation_id = invocation_id
        self.invoke_all = invoke_all
@ -24,12 +26,13 @@ class InvocationQueueItem:

class InvocationQueueABC(ABC):
    """Abstract base class for all invocation queues"""

    @abstractmethod
    def get(self) -> InvocationQueueItem:
        pass

    @abstractmethod
    def put(self, item: InvocationQueueItem|None) -> None:
    def put(self, item: InvocationQueueItem | None) -> None:
        pass


@ -42,5 +45,5 @@ class MemoryInvocationQueue(InvocationQueueABC):
    def get(self) -> InvocationQueueItem:
        return self.__queue.get()

    def put(self, item: InvocationQueueItem|None) -> None:
    def put(self, item: InvocationQueueItem | None) -> None:
        self.__queue.put(item)
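The queue item above is a plain value object, so exercising the in-memory queue is straightforward; a small sketch:

```python
from invokeai.app.services.invocation_queue import (
    InvocationQueueItem,
    MemoryInvocationQueue,
)

queue = MemoryInvocationQueue()
queue.put(
    InvocationQueueItem(
        graph_execution_state_id="session-1",
        invocation_id="1",
        invoke_all=True,
    )
)
print(queue.get().invocation_id)  # "1"
```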
@ -1,29 +1,32 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.backend import Generate

from .events import EventServiceBase
from .image_storage import ImageStorageBase
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .image_storage import ImageStorageBase
from .events import EventServiceBase
from ....generate import Generate


class InvocationServices():
class InvocationServices:
    """Services that can be used by invocations"""
    generate: Generate # TODO: wrap Generate, or split it up from model?

    generate: Generate  # TODO: wrap Generate, or split it up from model?
    events: EventServiceBase
    images: ImageStorageBase
    queue: InvocationQueueABC

    # NOTE: we must forward-declare any types that include invocations, since invocations can use services
    graph_execution_manager: ItemStorageABC['GraphExecutionState']
    processor: 'InvocationProcessorABC'
    graph_execution_manager: ItemStorageABC["GraphExecutionState"]
    processor: "InvocationProcessorABC"

    def __init__(self,
    def __init__(
        self,
        generate: Generate,
        events: EventServiceBase,
        images: ImageStorageBase,
        queue: InvocationQueueABC,
        graph_execution_manager: ItemStorageABC['GraphExecutionState'],
        processor: 'InvocationProcessorABC'
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
        processor: "InvocationProcessorABC",
    ):
        self.generate = generate
        self.events = events
@ -2,11 +2,12 @@

from abc import ABC
from threading import Event, Thread
from .graph import Graph, GraphExecutionState
from .item_storage import ItemStorageABC

from ..invocations.baseinvocation import InvocationContext
from .invocation_services import InvocationServices
from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
from .invocation_services import InvocationServices
from .item_storage import ItemStorageABC


class Invoker:
@ -14,14 +15,13 @@ class Invoker:

    services: InvocationServices

    def __init__(self,
        services: InvocationServices
    ):
    def __init__(self, services: InvocationServices):
        self.services = services
        self._start()

    def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None:
    def invoke(
        self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
    ) -> str | None:
        """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""

        # Get the next invocation
@ -33,38 +33,36 @@ class Invoker:
        self.services.graph_execution_manager.set(graph_execution_state)

        # Queue the invocation
        print(f'queueing item {invocation.id}')
        self.services.queue.put(InvocationQueueItem(
            #session_id = session.id,
            graph_execution_state_id = graph_execution_state.id,
            invocation_id = invocation.id,
            invoke_all = invoke_all
        ))
        print(f"queueing item {invocation.id}")
        self.services.queue.put(
            InvocationQueueItem(
                # session_id = session.id,
                graph_execution_state_id=graph_execution_state.id,
                invocation_id=invocation.id,
                invoke_all=invoke_all,
            )
        )

        return invocation.id

    def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
    def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
        """Creates a new execution state for the given graph"""
        new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
        self.services.graph_execution_manager.set(new_state)
        return new_state

    def __start_service(self, service) -> None:
        # Call start() method on any services that have it
        start_op = getattr(service, 'start', None)
        start_op = getattr(service, "start", None)
        if callable(start_op):
            start_op(self)

    def __stop_service(self, service) -> None:
        # Call stop() method on any services that have it
        stop_op = getattr(service, 'stop', None)
        stop_op = getattr(service, "stop", None)
        if callable(stop_op):
            stop_op(self)

    def _start(self) -> None:
        """Starts the invoker. This is called automatically when the invoker is created."""
        for service in vars(self.services):
@ -73,7 +71,6 @@ class Invoker:
        for service in vars(self.services):
            self.__start_service(getattr(self.services, service))

    def stop(self) -> None:
        """Stops the invoker. A new invoker will have to be created to execute further."""
        # First stop all services
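Putting the pieces together: the `Invoker` above picks the next ready node from the execution state, queues it, and (with `invoke_all=True`) lets the processor drain the whole session. Hypothetical wiring, assuming a fully populated `InvocationServices` bundle and the `Graph.add_node` API from the suppressed `graph.py` diff:

```python
from invokeai.app.invocations.generate import TextToImageInvocation
from invokeai.app.services.invoker import Invoker

# `services` is assumed to be an InvocationServices instance built elsewhere.
invoker = Invoker(services)
state = invoker.create_execution_state()
state.graph.add_node(TextToImageInvocation(id="1", prompt="sunset over water"))
invoker.invoke(state, invoke_all=True)
```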
@ -1,19 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar

from typing import Callable, TypeVar, Generic
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from abc import ABC, abstractmethod

T = TypeVar('T', bound=BaseModel)
T = TypeVar("T", bound=BaseModel)


class PaginatedResults(GenericModel, Generic[T]):
    """Paginated results"""
    items: list[T] = Field(description = "Items")
    page: int = Field(description = "Current Page")
    pages: int = Field(description = "Total number of pages")
    per_page: int = Field(description = "Number of items per page")
    total: int = Field(description = "Total number of items in result")

    #fmt: off
    items: list[T] = Field(description="Items")
    page: int = Field(description="Current Page")
    pages: int = Field(description="Total number of pages")
    per_page: int = Field(description="Number of items per page")
    total: int = Field(description="Total number of items in result")
    #fmt: on

class ItemStorageABC(ABC, Generic[T]):
    _on_changed_callbacks: list[Callable[[T], None]]
@ -24,6 +26,7 @@ class ItemStorageABC(ABC, Generic[T]):
        self._on_deleted_callbacks = list()

    """Base item storage class"""

    @abstractmethod
    def get(self, item_id: str) -> T:
        pass
@ -37,7 +40,9 @@ class ItemStorageABC(ABC, Generic[T]):
        pass

    @abstractmethod
    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        pass

    def on_changed(self, on_changed: Callable[[T], None]) -> None:
@@ -1,5 +1,6 @@
from threading import Event, Thread
import traceback
from threading import Event, Thread

from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
@@ -14,52 +15,62 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
        self.__invoker = invoker
        self.__stop_event = Event()
        self.__invoker_thread = Thread(
            name = "invoker_processor",
            target = self.__process,
            kwargs = dict(stop_event = self.__stop_event)
            name="invoker_processor",
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        self.__invoker_thread.daemon = (
            True  # TODO: probably better to just not use threads?
        )
        self.__invoker_thread.daemon = True  # TODO: probably better to just not use threads?
        self.__invoker_thread.start()

    def stop(self, *args, **kwargs) -> None:
        self.__stop_event.set()

    def __process(self, stop_event: Event):
        try:
            while not stop_event.is_set():
                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                if not queue_item: # Probably stopping
                if not queue_item:  # Probably stopping
                    continue

                graph_execution_state = self.__invoker.services.graph_execution_manager.get(queue_item.graph_execution_state_id)
                invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
                graph_execution_state = (
                    self.__invoker.services.graph_execution_manager.get(
                        queue_item.graph_execution_state_id
                    )
                )
                invocation = graph_execution_state.execution_graph.get_node(
                    queue_item.invocation_id
                )

                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id = graph_execution_state.id,
                    invocation_id = invocation.id
                    graph_execution_state_id=graph_execution_state.id,
                    invocation_id=invocation.id,
                )

                # Invoke
                try:
                    outputs = invocation.invoke(InvocationContext(
                        services = self.__invoker.services,
                        graph_execution_state_id = graph_execution_state.id
                    ))
                    outputs = invocation.invoke(
                        InvocationContext(
                            services=self.__invoker.services,
                            graph_execution_state_id=graph_execution_state.id,
                        )
                    )

                    # Save outputs and history
                    graph_execution_state.complete(invocation.id, outputs)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(graph_execution_state)
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id = graph_execution_state.id,
                        invocation_id = invocation.id,
                        result = outputs.dict()
                        graph_execution_state_id=graph_execution_state.id,
                        invocation_id=invocation.id,
                        result=outputs.dict(),
                    )

                except KeyboardInterrupt:
@@ -72,24 +83,27 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    graph_execution_state.set_node_error(invocation.id, error)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(graph_execution_state)
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id = graph_execution_state.id,
                        invocation_id = invocation.id,
                        error = error
                        graph_execution_state_id=graph_execution_state.id,
                        invocation_id=invocation.id,
                        error=error,
                    )

                    pass

                # Queue any further commands if invoking all
                is_complete = graph_execution_state.is_complete()
                if queue_item.invoke_all and not is_complete:
                    self.__invoker.invoke(graph_execution_state, invoke_all = True)
                    self.__invoker.invoke(graph_execution_state, invoke_all=True)
                elif is_complete:
                    self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
                    self.__invoker.services.events.emit_graph_execution_complete(
                        graph_execution_state.id
                    )

        except KeyboardInterrupt:
            ... # Log something?
            ...  # Log something?
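The processor wires a daemon thread to a blocking queue and a stop event. The same pattern in miniature, stripped of the invoker services (all names here are illustrative, not the project's API):

import queue
from threading import Event, Thread

work_queue = queue.Queue()
stop_event = Event()


def process(stop_event):
    while not stop_event.is_set():
        item = work_queue.get()
        if not item:  # sentinel value: the producer is shutting down
            continue
        print(f"processing {item}")


worker = Thread(
    name="invoker_processor",
    target=process,
    kwargs=dict(stop_event=stop_event),
)
worker.daemon = True  # don't keep the interpreter alive on exit
worker.start()

work_queue.put("job-1")
stop_event.set()
work_queue.put(None)  # wake the blocked get() so the loop re-checks stop_event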
@@ -1,12 +1,15 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args

from pydantic import BaseModel, parse_raw_as

from .item_storage import ItemStorageABC, PaginatedResults

T = TypeVar('T', bound=BaseModel)
T = TypeVar("T", bound=BaseModel)

sqlite_memory = ":memory:"

sqlite_memory = ':memory:'

class SqliteItemStorage(ItemStorageABC, Generic[T]):
    _filename: str
@@ -16,15 +19,17 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    _id_field: str
    _lock: Lock

    def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
        super().__init__()

        self._filename = filename
        self._table_name = table_name
        self._id_field = id_field # TODO: validate that T has this field
        self._id_field = id_field  # TODO: validate that T has this field
        self._lock = Lock()

        self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
        self._conn = sqlite3.connect(
            self._filename, check_same_thread=False
        )  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()

        self._create_table()
@@ -32,10 +37,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def _create_table(self):
        try:
            self._lock.acquire()
            self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
            self._cursor.execute(
                f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
                item TEXT,
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
            self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
            )
            self._cursor.execute(
                f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
            )
        finally:
            self._lock.release()

@@ -46,7 +55,10 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def set(self, item: T):
        try:
            self._lock.acquire()
            self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
            self._cursor.execute(
                f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
                (item.json(),),
            )
        finally:
            self._lock.release()
        self._on_changed(item)
@@ -54,7 +66,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def get(self, id: str) -> Union[T, None]:
        try:
            self._lock.acquire()
            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
            result = self._cursor.fetchone()
        finally:
            self._lock.release()
@@ -67,7 +81,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def delete(self, id: str):
        try:
            self._lock.acquire()
            self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
            self._cursor.execute(
                f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
        finally:
            self._lock.release()
        self._on_deleted(id)
@@ -75,12 +91,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page))
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
                (per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''')
            self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()
@@ -88,22 +107,26 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
        pageCount = int(count / per_page) + 1

        return PaginatedResults[T](
            items = items,
            page = page,
            pages = pageCount,
            per_page = per_page,
            total = count
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )

    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page))
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
                (f"%{query}%", per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',))
            self._cursor.execute(
                f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
                (f"%{query}%",),
            )
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()
@@ -111,9 +134,5 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
        pageCount = int(count / per_page) + 1

        return PaginatedResults[T](
            items = items,
            page = page,
            pages = pageCount,
            per_page = per_page,
            total = count
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )
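Two implementation notes on the storage class above. First, the id column is a SQLite generated column extracted from the JSON payload, which requires SQLite 3.31 or newer. Second, `int(count / per_page) + 1` reports one page too many whenever count is an exact multiple of per_page; `math.ceil(count / per_page)` (or `-(-count // per_page)`) would not. A runnable sketch of the generated-column trick, with a hard-coded table name in place of the class's f-strings:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

# 'id' is derived from the JSON payload itself, so callers only ever insert the blob.
cur.execute(
    """CREATE TABLE IF NOT EXISTS items (
    item TEXT,
    id TEXT GENERATED ALWAYS AS (json_extract(item, '$.id')) VIRTUAL NOT NULL);"""
)
cur.execute("CREATE UNIQUE INDEX IF NOT EXISTS items_id ON items(id);")

cur.execute("INSERT OR REPLACE INTO items (item) VALUES (?);", ('{"id": "a", "n": 1}',))
cur.execute("INSERT OR REPLACE INTO items (item) VALUES (?);", ('{"id": "a", "n": 2}',))  # upserts
conn.commit()

cur.execute("SELECT item FROM items WHERE id = ?;", ("a",))
print(cur.fetchone())  # ('{"id": "a", "n": 2}',)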
@@ -1,5 +1,5 @@
'''
"""
Initialization file for invokeai.backend
'''
from .invoke_ai_web_server import InvokeAIWebServer

"""
from .generate import Generate
from .model_management import ModelManager
File diff suppressed because it is too large
@@ -6,19 +6,20 @@
#
# Coauthor: Kevin Turner http://github.com/keturn
#
print("Loading Python libraries...\n")
import sys

print("Loading Python libraries...\n", file=sys.stderr)

import argparse
import io
import os
import re
import shutil
import sys
import traceback
import warnings
from argparse import Namespace
from pathlib import Path
from urllib import request
from shutil import get_terminal_size
from urllib import request

import npyscreen
import torch
@@ -37,17 +38,20 @@ from transformers import (

import invokeai.configs as configs

from ...frontend.install.model_install import addModelsForm, process_and_execute
from ...frontend.install.widgets import (
    CenteredButtonPress,
    IntTitleSlider,
    set_min_terminal_size,
)
from ..args import PRECISION_CHOICES, Args
from ..globals import Globals, global_config_dir, global_config_file, global_cache_dir
from .model_install import addModelsForm, process_and_execute
from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
from .model_install_backend import (
    default_dataset,
    download_from_hf,
    recommended_datasets,
    hf_download_with_resume,
    recommended_datasets,
)
from .widgets import IntTitleSlider, CenteredButtonPress, set_min_terminal_size


warnings.filterwarnings("ignore")

@@ -82,6 +86,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# -Ak_euler_a -C10.0
"""


# --------------------------------------------
def postscript(errors: None):
    if not any(errors):
@@ -180,13 +185,11 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
# ---------------------------------------------
# this will preload the Bert tokenizer files
def download_bert():
    print(
        "Installing bert tokenizer...",
        file=sys.stderr
    )
    print("Installing bert tokenizer...", file=sys.stderr)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        from transformers import BertTokenizerFast

        download_from_hf(BertTokenizerFast, "bert-base-uncased")


@@ -197,12 +200,14 @@ def download_sd1_clip():
    download_from_hf(CLIPTokenizer, version)
    download_from_hf(CLIPTextModel, version)


# ---------------------------------------------
def download_sd2_clip():
    version = 'stabilityai/stable-diffusion-2'
    version = "stabilityai/stable-diffusion-2"
    print("Installing SD2 clip model...", file=sys.stderr)
    download_from_hf(CLIPTokenizer, version, subfolder='tokenizer')
    download_from_hf(CLIPTextModel, version, subfolder='text_encoder')
    download_from_hf(CLIPTokenizer, version, subfolder="tokenizer")
    download_from_hf(CLIPTextModel, version, subfolder="text_encoder")


# ---------------------------------------------
def download_realesrgan():
@@ -329,7 +334,7 @@ class editOptsForm(npyscreen.FormMultiPage):
        old_opts = self.parentApp.invokeai_opts
        first_time = not (Globals.root / Globals.initfile).exists()
        access_token = HfFolder.get_token()
        window_width,window_height = get_terminal_size()
        window_width, window_height = get_terminal_size()
        for i in [
            "Configure startup settings. You can come back and change these later.",
            "Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
@@ -681,6 +686,7 @@ def run_console_ui(
    else:
        return (editApp.new_opts, editApp.user_selections)


# -------------------------------------
def write_opts(opts: Namespace, init_file: Path):
    """
@@ -701,8 +707,8 @@ def write_opts(opts: Namespace, init_file: Path):
        "^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
    )
    # fix windows paths
    opts.outdir = opts.outdir.replace('\\','/')
    opts.embedding_path = opts.embedding_path.replace('\\','/')
    opts.outdir = opts.outdir.replace("\\", "/")
    opts.embedding_path = opts.embedding_path.replace("\\", "/")
    new_file = f"{init_file}.new"
    try:
        lines = [x.strip() for x in open(init_file, "r").readlines()]
@@ -855,6 +861,7 @@ def main():
    except KeyboardInterrupt:
        print("\nGoodbye! Come back soon.")


# -------------------------------------
if __name__ == "__main__":
    main()
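The download_* helpers above exist to warm the local Hugging Face cache before first launch. Roughly equivalent behaviour using only the public transformers API (a sketch; the repo's own download_from_hf wrapper additionally pins the cache directory):

import sys
import warnings

from transformers import BertTokenizerFast


def prefetch_bert_tokenizer() -> None:
    print("Installing bert tokenizer...", file=sys.stderr)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        # from_pretrained downloads into the local HF cache if not already present
        BertTokenizerFast.from_pretrained("bert-base-uncased")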
@@ -8,6 +8,7 @@ import sys
import warnings
from pathlib import Path
from tempfile import TemporaryFile
from typing import List

import requests
from diffusers import AutoencoderKL
@@ -15,12 +16,12 @@ from huggingface_hub import hf_hub_url
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
from typing import List

import invokeai.configs as configs
from ..generator.diffusers_pipeline import StableDiffusionGeneratorPipeline

from ..globals import Globals, global_cache_dir, global_config_dir
from ..model_manager import ModelManager
from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline

warnings.filterwarnings("ignore")

@@ -44,42 +45,46 @@ Config_preamble = """
# was trained on.
"""


def default_config_file():
    return Path(global_config_dir()) / "models.yaml"


def sd_configs():
    return Path(global_config_dir()) / "stable-diffusion"


def initial_models():
    global Datasets
    if Datasets:
        return Datasets
    return (Datasets := OmegaConf.load(Dataset_path))


def install_requested_models(
    install_initial_models: List[str] = None,
    remove_models: List[str] = None,
    scan_directory: Path = None,
    external_models: List[str] = None,
    scan_at_startup: bool = False,
    convert_to_diffusers: bool = False,
    precision: str = "float16",
    purge_deleted: bool = False,
    config_file_path: Path = None,
):
    '''
    Entry point for installing/deleting starter models, or installing external models.
    '''
    config_file_path=config_file_path or default_config_file()
    if not config_file_path.exists():
        open(config_file_path,'w')

    model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)
def install_requested_models(
    install_initial_models: List[str] = None,
    remove_models: List[str] = None,
    scan_directory: Path = None,
    external_models: List[str] = None,
    scan_at_startup: bool = False,
    convert_to_diffusers: bool = False,
    precision: str = "float16",
    purge_deleted: bool = False,
    config_file_path: Path = None,
):
    """
    Entry point for installing/deleting starter models, or installing external models.
    """
    config_file_path = config_file_path or default_config_file()
    if not config_file_path.exists():
        open(config_file_path, "w")

    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    if remove_models and len(remove_models) > 0:
        print("== DELETING UNCHECKED STARTER MODELS ==")
        for model in remove_models:
            print(f'{model}...')
            print(f"{model}...")
            model_manager.del_model(model, delete_files=purge_deleted)
        model_manager.commit(config_file_path)

@@ -96,20 +101,20 @@ def install_requested_models(

    # due to above, we have to reload the model manager because conf file
    # was changed behind its back
    model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)
    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    external_models = external_models or list()
    if scan_directory:
        external_models.append(str(scan_directory))

    if len(external_models)>0:
    if len(external_models) > 0:
        print("== INSTALLING EXTERNAL MODELS ==")
        for path_url_or_repo in external_models:
            try:
                model_manager.heuristic_import(
                    path_url_or_repo,
                    convert=convert_to_diffusers,
                    commit_to_conf=config_file_path
                    commit_to_conf=config_file_path,
                )
            except KeyboardInterrupt:
                sys.exit(-1)
@@ -117,17 +122,18 @@ def install_requested_models(
                pass

    if scan_at_startup and scan_directory.is_dir():
        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
        argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
        initfile = Path(Globals.root, Globals.initfile)
        replacement = Path(Globals.root, f'{Globals.initfile}.new')
        directory = str(scan_directory).replace('\\','/')
        with open(initfile,'r') as input:
            with open(replacement,'w') as output:
        replacement = Path(Globals.root, f"{Globals.initfile}.new")
        directory = str(scan_directory).replace("\\", "/")
        with open(initfile, "r") as input:
            with open(replacement, "w") as output:
                while line := input.readline():
                    if not line.startswith(argument):
                        output.writelines([line])
                output.writelines([f'{argument} {directory}'])
        os.replace(replacement,initfile)
                output.writelines([f"{argument} {directory}"])
        os.replace(replacement, initfile)


# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
@@ -183,7 +189,9 @@ def migrate_models_ckpt():
    if not os.path.exists(os.path.join(model_path, "model.ckpt")):
        return
    new_name = initial_models()["stable-diffusion-1.4"]["file"]
    print(f'The Stable Diffusion v1.4 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.')
    print(
        f'The Stable Diffusion v1.4 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
    )
    print(f"model.ckpt => {new_name}")
    os.replace(
        os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
@@ -383,7 +391,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):

# ---------------------------------------------
def new_config_file_contents(
    successfully_downloaded: dict, config_file: Path,
    successfully_downloaded: dict,
    config_file: Path,
) -> str:
    if config_file.exists():
        conf = OmegaConf.load(str(config_file.expanduser().resolve()))
@@ -413,7 +422,9 @@ def new_config_file_contents(
        stanza["weights"] = os.path.relpath(
            successfully_downloaded[model], start=Globals.root
        )
        stanza["config"] = os.path.normpath(os.path.join(sd_configs(), mod["config"]))
        stanza["config"] = os.path.normpath(
            os.path.join(sd_configs(), mod["config"])
        )
        if "vae" in mod:
            if "file" in mod["vae"]:
                stanza["vae"] = os.path.normpath(
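The --autoimport/--autoconvert rewrite above writes a sibling .new file and then swaps it into place, so an interrupted write can never truncate the original. The pattern in isolation (file and argument names are illustrative):

import os
from pathlib import Path


def replace_argument_line(initfile: Path, argument: str, new_line: str) -> None:
    """Drop any line starting with `argument`, append `new_line`, then swap atomically."""
    replacement = Path(str(initfile) + ".new")
    with open(initfile, "r") as src, open(replacement, "w") as dst:
        while line := src.readline():
            if not line.startswith(argument):
                dst.write(line)
        dst.write(new_line + "\n")
    os.replace(replacement, initfile)  # atomic swap on both POSIX and Windows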
@@ -25,21 +25,19 @@ from omegaconf import OmegaConf
from PIL import Image, ImageOps
from pytorch_lightning import logging, seed_everything

import ldm.invoke.conditioning
from ldm.invoke.args import metadata_from_png
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.devices import choose_precision, choose_torch_device
from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_cache_dir
from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.model_manager import ModelManager
from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.models.diffusion.plms import PLMSSampler
from .model_management import ModelManager
from .args import metadata_from_png
from .generator import infill_methods
from .globals import Globals, global_cache_dir
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
from .prompting import get_uc_and_c_and_ec
from .stable_diffusion import (
    DDIMSampler,
    HuggingFaceConceptsLibrary,
    KSampler,
    PLMSSampler,
)
from .util import choose_precision, choose_torch_device


def fix_func(orig):
@@ -328,8 +326,8 @@ class Generate:
        variation_amount=0.0,
        threshold=0.0,
        perlin=0.0,
        h_symmetry_time_pct = None,
        v_symmetry_time_pct = None,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        karras_max=None,
        outdir=None,
        # these are specific to img2img and inpaint
@@ -717,7 +715,7 @@ class Generate:
            prompt,
            model=self.model,
            skip_normalize_legacy_blend=opt.skip_normalize,
            log_tokens=ldm.invoke.conditioning.log_tokenization,
            log_tokens=invokeai.backend.prompting.conditioning.log_tokenization,
        )

        if tool in ("gfpgan", "codeformer", "upscale"):
@@ -741,7 +739,7 @@ class Generate:
            )

        elif tool == "outcrop":
            from ldm.invoke.restoration.outcrop import Outcrop
            from .restoration.outcrop import Outcrop

            extend_instructions = {}
            for direction, pixels in _pairwise(opt.outcrop):
@@ -794,7 +792,7 @@ class Generate:
                clear_cuda_cache=self.clear_cuda_cache,
            )
        elif tool == "outpaint":
            from ldm.invoke.restoration.outpaint import Outpaint
            from .restoration.outpaint import Outpaint

            restorer = Outpaint(image, self)
            return restorer.process(opt, args, image_callback=callback, prefix=prefix)
@@ -816,17 +814,12 @@ class Generate:
        hires_fix: bool = False,
        force_outpaint: bool = False,
    ):
        inpainting_model_in_use = self.sampler.uses_inpainting_model()

        if hires_fix:
            return self._make_txt2img2img()

        if embiggen is not None:
            return self._make_embiggen()

        if inpainting_model_in_use:
            return self._make_omnibus()

        if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
            return self._make_inpaint()

@@ -903,16 +896,9 @@ class Generate:
    def _make_inpaint(self):
        return self._load_generator(".inpaint", "Inpaint")

    def _make_omnibus(self):
        return self._load_generator(".omnibus", "Omnibus")

    def _load_generator(self, module, class_name):
        if self.is_legacy_model(self.model_name):
            mn = f"ldm.invoke.ckpt_generator{module}"
            cn = f"Ckpt{class_name}"
        else:
            mn = f"ldm.invoke.generator{module}"
            cn = class_name
        mn = f"invokeai.backend.generator{module}"
        cn = class_name
        module = importlib.import_module(mn)
        constructor = getattr(module, cn)
        return constructor(self.model, self.precision)
@@ -975,7 +961,7 @@ class Generate:

        seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
        if self.embedding_path is not None:
            print(f'>> Loading embeddings from {self.embedding_path}')
            print(f">> Loading embeddings from {self.embedding_path}")
            for root, _, files in os.walk(self.embedding_path):
                for name in files:
                    ti_path = os.path.join(root, name)
@@ -1030,7 +1016,6 @@ class Generate:
        image_callback=None,
        prefix=None,
    ):

        results = []
        for r in image_list:
            image, seed = r
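_load_generator resolves the generator class by name at runtime with importlib. The pattern reduced to its essentials (the module and class names in the comment are examples, assuming the txt2img generator lives at that path):

import importlib


def load_class(module_name: str, class_name: str):
    """Import `module_name` and return the attribute `class_name` from it."""
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# e.g., mirroring _load_generator's f-string construction:
# constructor = load_class("invokeai.backend.generator.txt2img", "Txt2Img")
# generator = constructor(model, precision)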
5
invokeai/backend/generator/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Initialization file for the invokeai.generator package
"""
from .base import Generator
from .inpaint import infill_methods
@@ -1,7 +1,7 @@
'''
Base class for ldm.invoke.generator.*
"""
Base class for invokeai.backend.generator.*
including img2img, txt2img, and inpaint
'''
"""
from __future__ import annotations

import os
@@ -9,24 +9,25 @@ import os.path as osp
import random
import traceback
from contextlib import nullcontext
from pathlib import Path

import cv2
import numpy as np
import torch

from PIL import Image, ImageFilter, ImageChops
from diffusers import DiffusionPipeline
from einops import rearrange
from pathlib import Path
from PIL import Image, ImageChops, ImageFilter
from pytorch_lightning import seed_everything
from tqdm import trange

import invokeai.assets.web as web_assets
from ldm.models.diffusion.ddpm import DiffusionWrapper
from ldm.util import rand_perlin_2d

from ..stable_diffusion.diffusion.ddpm import DiffusionWrapper
from ..util.util import rand_perlin_2d

downsampling = 8
CAUTION_IMG = 'caution.png'
CAUTION_IMG = "caution.png"


class Generator:
    downsampling_factor: int
@@ -39,7 +40,7 @@ class Generator:
        self.precision = precision
        self.seed = None
        self.latent_channels = model.channels
        self.downsampling_factor = downsampling # BUG: should come from model or config
        self.downsampling_factor = downsampling  # BUG: should come from model or config
        self.safety_checker = None
        self.perlin = 0.0
        self.threshold = 0
@@ -50,56 +51,73 @@ class Generator:
        self.caution_img = None

    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
    def get_make_image(self,prompt,**kwargs):
    def get_make_image(self, prompt, **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
        """
        raise NotImplementedError("image_iterator() must be implemented in a descendent class")
        raise NotImplementedError(
            "image_iterator() must be implemented in a descendent class"
        )

    def set_variation(self, seed, variation_amount, with_variations):
        self.seed = seed
        self.seed = seed
        self.variation_amount = variation_amount
        self.with_variations = with_variations
        self.with_variations = with_variations

    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                 h_symmetry_time_pct=None, v_symmetry_time_pct=None,
                 safety_checker:dict=None,
                 free_gpu_mem: bool=False,
                 **kwargs):
    def generate(
        self,
        prompt,
        init_image,
        width,
        height,
        sampler,
        iterations=1,
        seed=None,
        image_callback=None,
        step_callback=None,
        threshold=0.0,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        safety_checker: dict = None,
        free_gpu_mem: bool = False,
        **kwargs,
    ):
        scope = nullcontext
        self.safety_checker = safety_checker
        self.free_gpu_mem = free_gpu_mem
        attention_maps_images = []
        attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
        attention_maps_callback = lambda saver: attention_maps_images.append(
            saver.get_stacked_maps_image()
        )
        make_image = self.get_make_image(
            prompt,
            sampler = sampler,
            init_image = init_image,
            width = width,
            height = height,
            step_callback = step_callback,
            threshold = threshold,
            perlin = perlin,
            h_symmetry_time_pct = h_symmetry_time_pct,
            v_symmetry_time_pct = v_symmetry_time_pct,
            attention_maps_callback = attention_maps_callback,
            **kwargs
            sampler=sampler,
            init_image=init_image,
            width=width,
            height=height,
            step_callback=step_callback,
            threshold=threshold,
            perlin=perlin,
            h_symmetry_time_pct=h_symmetry_time_pct,
            v_symmetry_time_pct=v_symmetry_time_pct,
            attention_maps_callback=attention_maps_callback,
            **kwargs,
        )
        results = []
        seed = seed if seed is not None and seed >= 0 else self.new_seed()
        first_seed = seed
        results = []
        seed = seed if seed is not None and seed >= 0 else self.new_seed()
        first_seed = seed
        seed, initial_noise = self.generate_initial_noise(seed, width, height)

        # There used to be an additional self.model.ema_scope() here, but it breaks
        # the inpaint-1.5 model. Not sure what it did.... ?
        with scope(self.model.device.type):
            for n in trange(iterations, desc='Generating'):
            for n in trange(iterations, desc="Generating"):
                x_T = None
                if self.variation_amount > 0:
                    seed_everything(seed)
                    target_noise = self.get_noise(width,height)
                    target_noise = self.get_noise(width, height)
                    x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
                elif initial_noise is not None:
                    # i.e. we specified particular variations
@@ -107,9 +125,9 @@ class Generator:
                else:
                    seed_everything(seed)
                    try:
                        x_T = self.get_noise(width,height)
                        x_T = self.get_noise(width, height)
                    except:
                        print('** An error occurred while getting initial noise **')
                        print("** An error occurred while getting initial noise **")
                        print(traceback.format_exc())

                image = make_image(x_T)
@@ -120,19 +138,30 @@ class Generator:
                results.append([image, seed])

                if image_callback is not None:
                    attention_maps_image = None if len(attention_maps_images)==0 else attention_maps_images[-1]
                    image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image)
                    attention_maps_image = (
                        None
                        if len(attention_maps_images) == 0
                        else attention_maps_images[-1]
                    )
                    image_callback(
                        image,
                        seed,
                        first_seed=first_seed,
                        attention_maps_image=attention_maps_image,
                    )

                seed = self.new_seed()

                # Free up memory from the last generation.
                clear_cuda_cache = kwargs['clear_cuda_cache'] if 'clear_cuda_cache' in kwargs else None
                clear_cuda_cache = (
                    kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
                )
                if clear_cuda_cache is not None:
                    clear_cuda_cache()

        return results

    def sample_to_image(self,samples)->Image.Image:
    def sample_to_image(self, samples) -> Image.Image:
        """
        Given samples returned from a sampler, converts
        it into a PIL Image
@@ -141,18 +170,30 @@ class Generator:
        image = self.model.decode_latents(samples)
        return self.model.numpy_to_pil(image)[0]

    def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image:
    def repaste_and_color_correct(
        self,
        result: Image.Image,
        init_image: Image.Image,
        init_mask: Image.Image,
        mask_blur_radius: int = 8,
    ) -> Image.Image:
        if init_image is None or init_mask is None:
            return result

        # Get the original alpha channel of the mask if there is one.
        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
        pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L')
        pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist
        pil_init_mask = (
            init_mask.getchannel("A")
            if init_mask.mode == "RGBA"
            else init_mask.convert("L")
        )
        pil_init_image = init_image.convert(
            "RGBA"
        )  # Add an alpha channel if one doesn't exist

        # Build an image with only visible pixels from source to use as reference for color-matching.
        init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8)
        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
        init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
        init_a_pixels = np.asarray(pil_init_image.getchannel("A"), dtype=np.uint8)
        init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)

        # Get numpy version of result
@@ -171,44 +212,70 @@ class Generator:

        # Color correct
        np_matched_result = np_image.copy()
        np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
        matched_result = Image.fromarray(np_matched_result, mode='RGB')
        np_matched_result[:, :, :] = (
            (
                (
                    (
                        np_matched_result[:, :, :].astype(np.float32)
                        - gen_means[None, None, :]
                    )
                    / gen_std[None, None, :]
                )
                * init_std[None, None, :]
                + init_means[None, None, :]
            )
            .clip(0, 255)
            .astype(np.uint8)
        )
        matched_result = Image.fromarray(np_matched_result, mode="RGB")
        else:
            matched_result = Image.fromarray(np_image, mode='RGB')
            matched_result = Image.fromarray(np_image, mode="RGB")

        # Blur the mask out (into init image) by specified amount
        if mask_blur_radius > 0:
            nm = np.asarray(pil_init_mask, dtype=np.uint8)
            nmd = cv2.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
            pmd = Image.fromarray(nmd, mode='L')
            nmd = cv2.erode(
                nm,
                kernel=np.ones((3, 3), dtype=np.uint8),
                iterations=int(mask_blur_radius / 2),
            )
            pmd = Image.fromarray(nmd, mode="L")
            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
        else:
            blurred_init_mask = pil_init_mask

        multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
        multiplied_blurred_init_mask = ImageChops.multiply(
            blurred_init_mask, self.pil_image.split()[-1]
        )

        # Paste original on color-corrected generation (using blurred mask)
        matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
        matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
        return matched_result

    def sample_to_lowres_estimated_image(self,samples):
    def sample_to_lowres_estimated_image(self, samples):
        # originally adapted from code by @erucipe and @keturn here:
        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

        # these updated numbers for v1.5 are from @torridgristle
        v1_5_latent_rgb_factors = torch.tensor([
            #   R        G        B
            [ 0.3444,  0.1385,  0.0670],  # L1
            [ 0.1247,  0.4027,  0.1494],  # L2
            [-0.3192,  0.2513,  0.2103],  # L3
            [-0.1307, -0.1874, -0.7445]   # L4
        ], dtype=samples.dtype, device=samples.device)
        v1_5_latent_rgb_factors = torch.tensor(
            [
                #    R        G        B
                [0.3444, 0.1385, 0.0670],  # L1
                [0.1247, 0.4027, 0.1494],  # L2
                [-0.3192, 0.2513, 0.2103],  # L3
                [-0.1307, -0.1874, -0.7445],  # L4
            ],
            dtype=samples.dtype,
            device=samples.device,
        )

        latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
        latents_ubyte = (((latent_image + 1) / 2)
                         .clamp(0, 1)  # change scale from -1..1 to 0..1
                         .mul(0xFF)  # to 0..255
                         .byte()).cpu()
        latents_ubyte = (
            ((latent_image + 1) / 2)
            .clamp(0, 1)  # change scale from -1..1 to 0..1
            .mul(0xFF)  # to 0..255
            .byte()
        ).cpu()

        return Image.fromarray(latents_ubyte.numpy())
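sample_to_lowres_estimated_image maps the four latent channels straight to RGB with a fixed 4x3 matrix. A runnable sketch with random latents standing in for sampler output:

import torch
from PIL import Image

samples = torch.randn(1, 4, 64, 64)  # stand-in for sampler latents

v1_5_latent_rgb_factors = torch.tensor(
    [
        #    R        G        B
        [0.3444, 0.1385, 0.0670],  # L1
        [0.1247, 0.4027, 0.1494],  # L2
        [-0.3192, 0.2513, 0.2103],  # L3
        [-0.1307, -0.1874, -0.7445],  # L4
    ],
    dtype=samples.dtype,
)

latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors  # HxWx3
latents_ubyte = ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte().cpu()
preview = Image.fromarray(latents_ubyte.numpy())  # 64x64 RGB approximation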
@@ -217,38 +284,45 @@ class Generator:
        if self.variation_amount > 0 or len(self.with_variations) > 0:
            # use fixed initial noise plus random noise per iteration
            seed_everything(seed)
            initial_noise = self.get_noise(width,height)
            initial_noise = self.get_noise(width, height)
            for v_seed, v_weight in self.with_variations:
                seed = v_seed
                seed_everything(seed)
                next_noise = self.get_noise(width,height)
                next_noise = self.get_noise(width, height)
                initial_noise = self.slerp(v_weight, initial_noise, next_noise)
            if self.variation_amount > 0:
                random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
                seed = random.randrange(0,np.iinfo(np.uint32).max)
                random.seed()  # reset RNG to an actually random state, so we can get a random seed for variations
                seed = random.randrange(0, np.iinfo(np.uint32).max)
            return (seed, initial_noise)
        else:
            return (seed, None)

    # returns a tensor filled with random numbers from a normal distribution
    def get_noise(self,width,height):
    def get_noise(self, width, height):
        """
        Returns a tensor filled with random numbers, either from a normal distribution
        (txt2img) or from the latent image (img2img, inpaint)
        """
        raise NotImplementedError("get_noise() must be implemented in a descendent class")
        raise NotImplementedError(
            "get_noise() must be implemented in a descendent class"
        )

    def get_perlin_noise(self,width,height):
        fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
    def get_perlin_noise(self, width, height):
        fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
        # limit noise to only the diffusion image channels, not the mask channels
        input_channels = min(self.latent_channels, 4)
        # round up to the nearest block of 8
        temp_width = int((width + 7) / 8) * 8
        temp_height = int((height + 7) / 8) * 8
        noise = torch.stack([
            rand_perlin_2d((temp_height, temp_width),
                           (8, 8),
                           device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device)
        noise = torch.stack(
            [
                rand_perlin_2d(
                    (temp_height, temp_width), (8, 8), device=self.model.device
                ).to(fixdevice)
                for _ in range(input_channels)
            ],
            dim=0,
        ).to(self.model.device)
        return noise[0:4, 0:height, 0:width]

    def new_seed(self):
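For scale: get_noise produces latents at one eighth of the pixel resolution on each spatial axis (the module-level downsampling factor), e.g.:

import torch

width, height, downsampling_factor, input_channels = 512, 512, 8, 4
x = torch.randn(
    [1, input_channels, height // downsampling_factor, width // downsampling_factor],
    dtype=torch.float32,
)
print(x.shape)  # torch.Size([1, 4, 64, 64]) for a 512x512 image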
@@ -256,7 +330,7 @@ class Generator:
        return self.seed

    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
        '''
        """
        Spherical linear interpolation
        Args:
            t (float/np.ndarray): Float value between 0.0 and 1.0
@@ -266,7 +340,7 @@ class Generator:
                colinear. Not recommended to alter this.
        Returns:
            v2 (np.ndarray): Interpolation vector between v0 and v1
        '''
        """
        inputs_are_torch = False
        if not isinstance(v0, np.ndarray):
            inputs_are_torch = True
@@ -292,15 +366,15 @@ class Generator:

        return v2

    def safety_check(self,image:Image.Image):
        '''
    def safety_check(self, image: Image.Image):
        """
        If the CompViz safety checker flags an NSFW image, we
        blur it out.
        '''
        """
        import diffusers

        checker = self.safety_checker['checker']
        extractor = self.safety_checker['extractor']
        checker = self.safety_checker["checker"]
        extractor = self.safety_checker["extractor"]
        features = extractor([image], return_tensors="pt")
        features.to(self.model.device)
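A compact NumPy rendering of the spherical interpolation that the slerp docstring above describes, falling back to plain lerp when the vectors are nearly colinear. This is a sketch, minus the torch round-tripping the method also does:

import numpy as np


def slerp(t: float, v0: np.ndarray, v1: np.ndarray, DOT_THRESHOLD: float = 0.9995) -> np.ndarray:
    """Spherical linear interpolation between flat vectors v0 and v1."""
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:  # nearly colinear: plain lerp is stable
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)  # angle between the vectors
    theta_t = theta_0 * t
    sin_theta_0 = np.sin(theta_0)
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = np.sin(theta_t) / sin_theta_0
    return s0 * v0 + s1 * v1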
@@ -309,19 +383,23 @@ class Generator:
        x_image = x_image[None].transpose(0, 3, 1, 2)

        diffusers.logging.set_verbosity_error()
        checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
        checked_image, has_nsfw_concept = checker(
            images=x_image, clip_input=features.pixel_values
        )
        if has_nsfw_concept[0]:
            print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
            print(
                "** An image with potential non-safe content has been detected. A blurred image will be returned. **"
            )
            return self.blur(image)
        else:
            return image

    def blur(self,input):
    def blur(self, input):
        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
        try:
            caution = self.get_caution_img()
            if caution:
                blurry.paste(caution,(0,0),caution)
                blurry.paste(caution, (0, 0), caution)
        except FileNotFoundError:
            pass
        return blurry
@@ -332,43 +410,52 @@ class Generator:
            return self.caution_img
        path = Path(web_assets.__path__[0]) / CAUTION_IMG
        caution = Image.open(path)
        self.caution_img = caution.resize((caution.width // 2, caution.height //2))
        self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
        return self.caution_img

    # this is a handy routine for debugging use. Given a generated sample,
    # convert it into a PNG image and store it at the indicated path
    def save_sample(self, sample, filepath):
        image = self.sample_to_image(sample)
        dirname = os.path.dirname(filepath) or '.'
        dirname = os.path.dirname(filepath) or "."
        if not os.path.exists(dirname):
            print(f'** creating directory {dirname}')
            print(f"** creating directory {dirname}")
            os.makedirs(dirname, exist_ok=True)
        image.save(filepath,'PNG')
        image.save(filepath, "PNG")

    def torch_dtype(self)->torch.dtype:
        return torch.float16 if self.precision == 'float16' else torch.float32
    def torch_dtype(self) -> torch.dtype:
        return torch.float16 if self.precision == "float16" else torch.float32

    # returns a tensor filled with random numbers from a normal distribution
    def get_noise(self,width,height):
        device = self.model.device
    def get_noise(self, width, height):
        device = self.model.device
        # limit noise to only the diffusion image channels, not the mask channels
        input_channels = min(self.latent_channels, 4)
        if self.use_mps_noise or device.type == 'mps':
            x = torch.randn([1,
                             input_channels,
                             height // self.downsampling_factor,
                             width // self.downsampling_factor],
                            dtype=self.torch_dtype(),
                            device='cpu').to(device)
        if self.use_mps_noise or device.type == "mps":
            x = torch.randn(
                [
                    1,
                    input_channels,
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ],
                dtype=self.torch_dtype(),
                device="cpu",
            ).to(device)
        else:
            x = torch.randn([1,
                             input_channels,
                             height // self.downsampling_factor,
                             width // self.downsampling_factor],
                            dtype=self.torch_dtype(),
                            device=device)
            x = torch.randn(
                [
                    1,
                    input_channels,
                    height // self.downsampling_factor,
                    width // self.downsampling_factor,
                ],
                dtype=self.torch_dtype(),
                device=device,
            )
        if self.perlin > 0.0:
            perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
            x = (1-self.perlin)*x + self.perlin*perlin_noise
            perlin_noise = self.get_perlin_noise(
                width // self.downsampling_factor, height // self.downsampling_factor
            )
            x = (1 - self.perlin) * x + self.perlin * perlin_noise
        return x
@ -1,37 +1,38 @@
|
||||
'''
|
||||
ldm.invoke.generator.embiggen descends from ldm.invoke.generator
|
||||
and generates with ldm.invoke.generator.img2img
|
||||
'''
|
||||
"""
|
||||
invokeai.backend.generator.embiggen descends from .generator
|
||||
and generates with .generator.img2img
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None
|
||||
self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normallly
|
||||
def generate(self,prompt,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None,
|
||||
**kwargs):
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
step_callback = step_callback,
|
||||
**kwargs
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
iterations=1,
|
||||
seed=None,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
for _ in trange(iterations, desc='Generating'):
|
||||
for _ in trange(iterations, desc="Generating"):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
@ -56,13 +57,15 @@ class Embiggen(Generator):
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
|
||||
assert (
|
||||
not sampler.uses_inpainting_model()
|
||||
), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
@ -70,48 +73,57 @@ class Embiggen(Generator):
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
print(
|
||||
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
|
||||
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
|
||||
print(
|
||||
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
|
||||
print(
|
||||
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles))
|
||||
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.')
|
||||
print(
|
||||
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = Img2Img(self.model,self.precision)
|
||||
gen_img2img = Img2Img(self.model, self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
initsuperimage = Image.open(init_img)
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert('RGB')
|
||||
initsuperimage = img.convert("RGB")
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||
initsuperheight = round(initsuperheight*embiggen[0])
|
||||
initsuperwidth = round(initsuperwidth * embiggen[0])
|
||||
initsuperheight = round(initsuperheight * embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
@ -130,7 +142,8 @@ class Embiggen(Generator):
|
||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
@ -153,23 +166,24 @@ class Embiggen(Generator):
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width,
|
||||
width - overlap_size_x) + 1
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height,
|
||||
height - overlap_size_y) + 1
|
||||
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
|
||||
# Sanity
|
||||
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
|
||||
assert (
|
||||
emb_tiles_x > 1 or emb_tiles_y > 1
|
||||
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
|
||||
|
||||
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = Image.linear_gradient('L').rotate(
|
||||
90).resize((overlap_size_x, height))
|
||||
agradientL = (
|
||||
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
|
||||
)
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
|
||||
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new('L', (256, 256))
|
||||
agradientC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
@ -177,16 +191,16 @@ class Embiggen(Generator):
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
#Place the pixel as invert of distance
|
||||
# Place the pixel as invert of distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new('L', (256, 256))
|
||||
agradientAsymC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x-(255-y)) * (255 / max(1,y)))
|
||||
#Clamp values
|
||||
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
|
||||
# Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
|
||||
@@ -204,80 +218,91 @@ class Embiggen(Generator):
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
alphaLayerTaC = alphaLayerT.copy()
alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
alphaLayerTaC.paste(
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
alphaLayerLTaC = alphaLayerLTC.copy()
alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
alphaLayerLTaC.paste(
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)

if embiggen_tiles:
# Individual unconnected sides
alphaLayerR = Image.new("L", (width, height), 255)
alphaLayerR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerB = Image.new("L", (width, height), 255)
alphaLayerB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerTB = Image.new("L", (width, height), 255)
alphaLayerTB.paste(agradientT, (0, 0))
alphaLayerTB.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerLR = Image.new("L", (width, height), 255)
alphaLayerLR.paste(agradientL, (0, 0))
alphaLayerLR.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))

# Sides and corner Layers
alphaLayerRBC = Image.new("L", (width, height), 255)
alphaLayerRBC.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerRBC.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerRBC.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerLBC = Image.new("L", (width, height), 255)
alphaLayerLBC.paste(agradientL, (0, 0))
alphaLayerLBC.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerLBC.paste(agradientC.rotate(90).resize(
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerLBC.paste(
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
(0, height - overlap_size_y),
)
alphaLayerRTC = Image.new("L", (width, height), 255)
alphaLayerRTC.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientT, (0, 0))
alphaLayerRTC.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
alphaLayerRTC.paste(
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)

# All but X layers
alphaLayerABT = Image.new("L", (width, height), 255)
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
alphaLayerABT.paste(agradientL.rotate(
180), (width - overlap_size_x, 0))
alphaLayerABT.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerABT.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerABL = Image.new("L", (width, height), 255)
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
alphaLayerABL.paste(agradientT.rotate(
180), (0, height - overlap_size_y))
alphaLayerABL.paste(agradientC.rotate(180).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerABL.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerABR = Image.new("L", (width, height), 255)
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
alphaLayerABR.paste(agradientT, (0, 0))
alphaLayerABR.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerABR.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
alphaLayerABB = Image.new("L", (width, height), 255)
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
alphaLayerABB.paste(agradientL, (0, 0))
alphaLayerABB.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerABB.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)

# All-around layer
alphaLayerAA = Image.new("L", (width, height), 255)
alphaLayerAA.paste(alphaLayerABT, (0, 0))
alphaLayerAA.paste(agradientT, (0, 0))
alphaLayerAA.paste(agradientC.resize(
(overlap_size_x, overlap_size_y)), (0, 0))
alphaLayerAA.paste(agradientC.rotate(270).resize(
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
alphaLayerAA.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
alphaLayerAA.paste(
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)

# Clean up temporary gradients
del agradientL
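For reference, these masks are later applied with Image.putalpha() and merged via alpha_composite(). A minimal sketch of one feathered tile blending onto a canvas, not from the commit (sizes and colors are illustrative):

from PIL import Image

canvas = Image.new("RGBA", (1024, 512))
tile = Image.new("RGB", (512, 512), (200, 30, 30)).convert("RGBA")

mask = Image.new("L", (512, 512), 255)  # opaque everywhere...
mask.paste(Image.linear_gradient("L").rotate(90).resize((64, 512)), (0, 0))  # ...except a faded left edge

tile.putalpha(mask)
canvas.alpha_composite(tile, (448, 0))  # overlaps the previous tile by 64 px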
@@ -287,17 +312,20 @@ class Embiggen(Generator):
def make_image():
# Make main tiles -------------------------------------------------
if embiggen_tiles:
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
else:
print(
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
)

emb_tile_store = []
# Although we could use the same seed for every tile for determinism, at higher strengths this may
# produce duplicated structures for each tile and make the tiling effect more obvious
# instead track and iterate a local seed we pass to Img2Img
seed = self.seed
seedintlimit = np.iinfo(np.uint32).max - 1 # only retrieve this one from numpy
seedintlimit = (
np.iinfo(np.uint32).max - 1
) # only retrieve this one from numpy

for tile in range(emb_tiles_x * emb_tiles_y):
# Don't iterate on first tile
@@ -334,37 +362,38 @@ class Embiggen(Generator):

if embiggen_tiles:
print(
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
)
else:
print(
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")

# create a torch tensor from an Image
newinitimage = np.array(
newinitimage).astype(np.float32) / 255.0
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
newinitimage = torch.from_numpy(newinitimage)
newinitimage = 2.0 * newinitimage - 1.0
newinitimage = newinitimage.to(self.model.device)
clear_cuda_cache = kwargs['clear_cuda_cache'] if 'clear_cuda_cache' in kwargs else None
clear_cuda_cache = (
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
)

tile_results = gen_img2img.generate(
prompt,
iterations = 1,
seed = seed,
sampler = sampler,
steps = steps,
cfg_scale = cfg_scale,
conditioning = conditioning,
ddim_eta = ddim_eta,
image_callback = None, # called only after the final image is generated
step_callback = step_callback, # called after each intermediate image is generated
width = width,
height = height,
init_image = newinitimage, # notice that init_image is different from init_img
mask_image = None,
strength = strength,
clear_cuda_cache = clear_cuda_cache
iterations=1,
seed=seed,
sampler=sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=conditioning,
ddim_eta=ddim_eta,
image_callback=None, # called only after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_image=newinitimage, # notice that init_image is different from init_img
mask_image=None,
strength=strength,
clear_cuda_cache=clear_cuda_cache,
)

emb_tile_store.append(tile_results[0][0])
@@ -373,12 +402,14 @@ class Embiggen(Generator):
del newinitimage

# Sanity check we have them all
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
outputsuperimage = Image.new(
"RGBA", (initsuperwidth, initsuperheight))
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
):
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
if embiggen_tiles:
outputsuperimage.alpha_composite(
initsuperimage.convert('RGBA'), (0, 0))
initsuperimage.convert("RGBA"), (0, 0)
)
for tile in range(emb_tiles_x * emb_tiles_y):
if embiggen_tiles:
if tile in embiggen_tiles:
@@ -387,7 +418,7 @@ class Embiggen(Generator):
continue
else:
intileimage = emb_tile_store[tile]
intileimage = intileimage.convert('RGBA')
intileimage = intileimage.convert("RGBA")
# Get row and column entries
emb_row_i = tile // emb_tiles_x
emb_column_i = tile % emb_tiles_x
@@ -399,8 +430,7 @@ class Embiggen(Generator):
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i *
(width - overlap_size_x))
left = round(emb_column_i * (width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
@@ -411,33 +441,43 @@ class Embiggen(Generator):
# top of image
if emb_row_i == 0:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) not in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerB)
# Otherwise do nothing on this tile
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerR)
else:
intileimage.putalpha(alphaLayerRBC)
elif emb_column_i == emb_tiles_x - 1:
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerLR)
else:
intileimage.putalpha(alphaLayerABT)
# bottom of image
elif emb_row_i == emb_tiles_y - 1:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile + 1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerRTC)
@@ -445,34 +485,44 @@ class Embiggen(Generator):
# No tiles to look ahead to
intileimage.putalpha(alphaLayerLTC)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile + 1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABB)
# vertical middle of image
else:
if emb_column_i == 0:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerTB)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerRTC)
else:
intileimage.putalpha(alphaLayerABL)
elif emb_column_i == emb_tiles_x - 1:
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerABR)
else:
if (tile+1) in embiggen_tiles: # Look-ahead right
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABR)
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerABB)
else:
intileimage.putalpha(alphaLayerAA)
@@ -481,21 +531,28 @@ class Embiggen(Generator):
if emb_row_i == 0 and emb_column_i >= 1:
intileimage.putalpha(alphaLayerL)
elif emb_row_i >= 1 and emb_column_i == 0:
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
if (
emb_column_i + 1 == emb_tiles_x
): # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerT)
else:
intileimage.putalpha(alphaLayerTaC)
else:
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
if (
emb_column_i + 1 == emb_tiles_x
): # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerLTaC)
# Layer tile onto final image
outputsuperimage.alpha_composite(intileimage, (left, top))
else:
print('Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')
print(
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
)

# after internal loops and patching up return Embiggen image
return outputsuperimage

# end of function declaration
return make_image
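The placement math in the hunk above steps each tile by (tile size - overlap) and pins the last row and column to the far edge of the supersized canvas. A standalone paraphrase of that arithmetic (names are illustrative, not from the commit):

def tile_origin(col, row, n_cols, n_rows, tile_w, tile_h,
                overlap_x, overlap_y, super_w, super_h):
    # interior tiles advance by (tile - overlap); the final row/column snaps to the border
    left = super_w - tile_w if col + 1 == n_cols else round(col * (tile_w - overlap_x))
    top = super_h - tile_h if row + 1 == n_rows else round(row * (tile_h - overlap_y))
    return left, top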
97  invokeai/backend/generator/img2img.py  Normal file
@@ -0,0 +1,97 @@
"""
invokeai.backend.generator.img2img descends from .generator
"""

import torch
from diffusers import logging

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Img2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None  # by get_noise()

    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image,
        strength,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T):
            # FIXME: use x_T for initial seeded noise
            # We're not at the moment because the pipeline automatically resizes init_image if
            # necessary, which the x_T input might not match.
            logging.set_verbosity_error()  # quench safety check warnings
            pipeline_output = pipeline.img2img_from_embeddings(
                init_image,
                strength,
                steps,
                conditioning_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
            )
            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image

    def get_noise_like(self, like: torch.Tensor):
        device = like.device
        if device.type == "mps":
            x = torch.randn_like(like, device="cpu").to(device)
        else:
            x = torch.randn_like(like, device=device)
        if self.perlin > 0.0:
            shape = like.shape
            x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
                shape[3], shape[2]
            )
        return x
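get_noise_like() above samples on the CPU when the model lives on Apple's MPS device, a workaround for torch's historically unreliable random ops on that backend. A standalone sketch of the pattern, with the Perlin blending elided (not part of the commit):

import torch

def noise_like(like: torch.Tensor) -> torch.Tensor:
    # MPS lacked reliable RNG support, so sample on CPU and move the tensor over
    if like.device.type == "mps":
        return torch.randn_like(like, device="cpu").to(like.device)
    return torch.randn_like(like, device=like.device)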
@@ -1,32 +1,35 @@
'''
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''
"""
invokeai.backend.generator.inpaint descends from .generator
"""
from __future__ import annotations

import math

import PIL
import cv2
import numpy as np
import PIL
import torch
from PIL import Image, ImageFilter, ImageOps, ImageChops
from PIL import Image, ImageChops, ImageFilter, ImageOps

from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline, \
ConditioningData
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.patchmatch import PatchMatch
from ldm.util import debug_image
from ..image_util import PatchMatch, debug_image
from ..stable_diffusion.diffusers_pipeline import (
ConditioningData,
StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor,
)
from .img2img import Img2Img


def infill_methods()->list[str]:
def infill_methods() -> list[str]:
methods = [
"tile",
"solid",
]
if PatchMatch.patchmatch_available():
methods.insert(0, 'patchmatch')
methods.insert(0, "patchmatch")
return methods


class Inpaint(Img2Img):
def __init__(self, model, precision):
self.inpaint_height = 0
@@ -53,11 +56,11 @@ class Inpaint(Img2Img):
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False
writeable=False,
)

def infill_patchmatch(self, im: Image.Image) -> Image:
if im.mode != 'RGBA':
if im.mode != "RGBA":
return im

# Skip patchmatch if patchmatch isn't available
@@ -65,13 +68,17 @@ class Inpaint(Img2Img):
return im

# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
im_patched_np = PatchMatch.inpaint(
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
)
im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched

def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
def tile_fill_missing(
self, im: Image.Image, tile_size: int = 16, seed: int = None
) -> Image:
# Only fill if there's an alpha layer
if im.mode != 'RGBA':
if im.mode != "RGBA":
return im

a = np.asarray(im, dtype=np.uint8)
@@ -79,21 +86,21 @@ class Inpaint(Img2Img):
tile_size = (tile_size, tile_size)

# Get the image as tiles of a specified size
tiles = self.get_tile_images(a,*tile_size).copy()
tiles = self.get_tile_images(a, *tile_size).copy()

# Get the mask as tiles
tiles_mask = tiles[:,:,:,:,3]
tiles_mask = tiles[:, :, :, :, 3]

# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = (tiles_mask > 0)
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = tiles_mask > 0
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)

# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]

if len(filtered_tiles) == 0:
@@ -101,23 +108,32 @@ class Inpaint(Img2Img):

# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed = seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
]

# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1,2)
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
si = Image.fromarray(st, mode='RGBA')
tiles_all = tiles_all.swapaxes(1, 2)
st = tiles_all.reshape(
(
math.prod(tiles_all.shape[0:2]),
math.prod(tiles_all.shape[2:4]),
tiles_all.shape[4],
)
)
si = Image.fromarray(st, mode="RGBA")

return si
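tile_fill_missing() above slices the RGBA array into tiles, keeps the fully opaque ones, and overwrites any tile containing transparency with a randomly chosen opaque tile. A compact numpy sketch of the same idea, not the commit's implementation (assumes the tile size divides the image and at least one fully opaque tile exists):

import numpy as np

def tile_fill(rgba, tile=16, seed=None):
    h, w = rgba.shape[:2]
    grid = rgba.reshape(h // tile, tile, w // tile, tile, 4).swapaxes(1, 2)
    flat = grid.reshape(-1, tile, tile, 4)  # one row per tile (this reshape copies)
    opaque = (flat[..., 3] > 0).reshape(len(flat), -1).all(axis=1)
    rng = np.random.default_rng(seed)
    donors = flat[opaque]
    flat[~opaque] = donors[rng.choice(len(donors), (~opaque).sum())]
    return flat.reshape(h // tile, w // tile, tile, tile, 4).swapaxes(1, 2).reshape(h, w, 4)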
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
npimg = np.asarray(mask, dtype=np.uint8)

# Detect any partially transparent regions
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
npgradient = np.uint8(
255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))
)

# Detect hard edges
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
@@ -126,7 +142,9 @@ class Inpaint(Img2Img):
npmask = npgradient + npedge

# Expand
npmask = cv2.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
npmask = cv2.dilate(
npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)
)

new_mask = Image.fromarray(npmask)

@@ -135,9 +153,22 @@ class Inpaint(Img2Img):

return ImageOps.invert(new_mask)


def seam_paint(self, im: Image.Image, seam_size: int, seam_blur: int, prompt, sampler, steps, cfg_scale, ddim_eta,
conditioning, strength, noise, infill_method, step_callback) -> Image.Image:
def seam_paint(
self,
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
strength,
noise,
infill_method,
step_callback,
) -> Image.Image:
hard_mask = self.pil_image.split()[-1].copy()
mask = self.mask_edge(hard_mask, seam_size, seam_blur)

@@ -148,15 +179,15 @@ class Inpaint(Img2Img):
cfg_scale,
ddim_eta,
conditioning,
init_image = im.copy().convert('RGBA'),
mask_image = mask,
strength = strength,
mask_blur_radius = 0,
seam_size = 0,
step_callback = step_callback,
inpaint_width = im.width,
inpaint_height = im.height,
infill_method = infill_method
init_image=im.copy().convert("RGBA"),
mask_image=mask,
strength=strength,
mask_blur_radius=0,
seam_size=0,
step_callback=step_callback,
inpaint_width=im.width,
inpaint_height=im.height,
infill_method=infill_method,
)

seam_noise = self.get_noise(im.width, im.height)
@@ -165,28 +196,35 @@ class Inpaint(Img2Img):

return result


@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,
init_image: PIL.Image.Image | torch.FloatTensor,
mask_image: PIL.Image.Image | torch.FloatTensor,
strength: float,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False, enable_image_debugging=False,
infill_method = None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill:tuple(int)=(0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
**kwargs):
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_image: PIL.Image.Image | torch.FloatTensor,
mask_image: PIL.Image.Image | torch.FloatTensor,
strength: float,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False,
enable_image_debugging=False,
infill_method=None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at
@@ -204,33 +242,39 @@ class Inpaint(Img2Img):
self.pil_image = init_image.copy()

# Do infill
if infill_method == 'patchmatch' and PatchMatch.patchmatch_available():
if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
init_filled = self.infill_patchmatch(self.pil_image.copy())
elif infill_method == 'tile':
elif infill_method == "tile":
init_filled = self.tile_fill_missing(
self.pil_image.copy(),
seed = self.seed,
tile_size = tile_size
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
)
elif infill_method == 'solid':
elif infill_method == "solid":
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
else:
raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
init_filled.paste(init_image, (0,0), init_image.split()[-1])
raise ValueError(
f"Non-supported infill type {infill_method}", infill_method
)
init_filled.paste(init_image, (0, 0), init_image.split()[-1])

# Resize if requested for inpainting
if inpaint_width and inpaint_height:
init_filled = init_filled.resize((inpaint_width, inpaint_height))

debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
debug_image(
init_filled, "init_filled", debug_status=self.enable_image_debugging
)

# Create init tensor
init_image = image_resized_to_grid_as_tensor(init_filled.convert('RGB'))
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))

if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image.copy()
debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
debug_image(
mask_image,
"mask_image BEFORE multiply with pil_image",
debug_status=self.enable_image_debugging,
)

init_alpha = self.pil_image.getchannel("A")
if mask_image.mode != "L":
@@ -243,8 +287,14 @@ class Inpaint(Img2Img):
if inpaint_width and inpaint_height:
mask_image = mask_image.resize((inpaint_width, inpaint_height))

debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
debug_image(
mask_image,
"mask_image AFTER multiply with pil_image",
debug_status=self.enable_image_debugging,
)
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(
mask_image, normalize=False
)
else:
mask: torch.FloatTensor = mask_image

@@ -256,9 +306,9 @@ class Inpaint(Img2Img):

# todo: support cross-attention control
uc, c, _ = conditioning
conditioning_data = (ConditioningData(uc, c, cfg_scale)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

conditioning_data = ConditioningData(
uc, c, cfg_scale
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

def make_image(x_T):
pipeline_output = pipeline.inpaint_from_embeddings(
@@ -271,43 +321,71 @@ class Inpaint(Img2Img):
callback=step_callback,
)

if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver)

result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
result = self.postprocess_size_and_mask(
pipeline.numpy_to_pil(pipeline_output.images)[0]
)

# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:
old_image = self.pil_image or init_image
old_mask = self.pil_mask or mask_image

result = self.seam_paint(result, seam_size, seam_blur, prompt, sampler, seam_steps, cfg_scale, ddim_eta,
conditioning, seam_strength, x_T, infill_method, step_callback)
result = self.seam_paint(
result,
seam_size,
seam_blur,
prompt,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
conditioning,
seam_strength,
x_T,
infill_method,
step_callback,
)

# Restore original settings
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,
old_image,
old_mask,
strength,
mask_blur_radius, seam_size, seam_blur, seam_strength,
seam_steps, tile_size, step_callback,
inpaint_replace, enable_image_debugging,
inpaint_width = inpaint_width,
inpaint_height = inpaint_height,
infill_method = infill_method,
**kwargs)
self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
old_image,
old_mask,
strength,
mask_blur_radius,
seam_size,
seam_blur,
seam_strength,
seam_steps,
tile_size,
step_callback,
inpaint_replace,
enable_image_debugging,
inpaint_width=inpaint_width,
inpaint_height=inpaint_height,
infill_method=infill_method,
**kwargs,
)

return result

return make_image


def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
def sample_to_image(self, samples) -> Image.Image:
gen_result = super().sample_to_image(samples).convert("RGB")
return self.postprocess_size_and_mask(gen_result)


def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image:
debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)

@@ -318,7 +396,13 @@ class Inpaint(Img2Img):
if self.pil_image is None or self.pil_mask is None:
return gen_result

corrected_result = self.repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)
corrected_result = self.repaste_and_color_correct(
gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius
)
debug_image(
corrected_result,
"corrected_result",
debug_status=self.enable_image_debugging,
)

return corrected_result
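mask_edge() above combines a partial-transparency detector with Canny edges, then dilates the union so the seam-paint pass has a band to work in. A standalone sketch of that masking step, not the commit's code (thresholds are illustrative):

import cv2
import numpy as np
from PIL import Image, ImageOps

def edge_band(mask_l: Image.Image, edge_size: int) -> Image.Image:
    np_mask = np.asarray(mask_l, dtype=np.uint8)
    edges = cv2.Canny(np_mask, threshold1=100, threshold2=200)  # hard edges
    band = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=max(1, edge_size // 2))
    return ImageOps.invert(Image.fromarray(band))  # inverted: band is "paint here"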
81  invokeai/backend/generator/txt2img.py  Normal file
@@ -0,0 +1,81 @@
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""
import PIL.Image
import torch

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Txt2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        width,
        height,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
        kwargs are 'width' and 'height'
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T) -> PIL.Image.Image:
            pipeline_output = pipeline.image_from_embeddings(
                latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
                noise=x_T,
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                callback=step_callback,
            )

            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)

            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image
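As with the other generators, a caller first obtains the closure and then drives it with fresh noise per sample. A hypothetical driver loop, not from this commit (argument values and names are illustrative):

make_image = generator.get_make_image(
    prompt, sampler=scheduler, steps=30, cfg_scale=7.5,
    ddim_eta=0.0, conditioning=(uc, c, extra), width=512, height=512,
)
for _ in range(num_samples):
    x_T = generator.get_noise(512, 512)  # new noise (and seed) each iteration
    image = make_image(x_T)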
@@ -1,6 +1,6 @@
'''
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
'''
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""

import math
from typing import Callable, Optional
@@ -8,21 +8,40 @@ from typing import Callable, Optional
import torch
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error

from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
ConditioningData
from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ..models import PostprocessingSettings
from .base import Generator
from .diffusers_pipeline import (
ConditioningData,
StableDiffusionGeneratorPipeline,
trim_to_multiple_of,
)


class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # for get_noise()
self.init_latent = None  # for get_noise()

def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
conditioning, width:int, height:int, strength:float,
step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0,
h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs):
def get_make_image(
self,
prompt: str,
sampler,
steps: int,
cfg_scale: float,
ddim_eta,
conditioning,
width: int,
height: int,
strength: float,
step_callback: Optional[Callable] = None,
threshold=0.0,
warmup=0.2,
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
@@ -35,19 +54,20 @@ class Txt2Img2Img(Generator):
pipeline.scheduler = sampler

uc, c, extra_conditioning_info = conditioning
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
postprocessing_settings = PostprocessingSettings(
threshold=threshold,
warmup=0.2,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct
)
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
conditioning_data = ConditioningData(
uc,
c,
cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=0.2,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

def make_image(x_T):

first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=torch.zeros_like(x_T),
num_inference_steps=steps,
@@ -61,28 +81,40 @@ class Txt2Img2Img(Generator):
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)

# resizing
resized_latents = torch.nn.functional.interpolate(
first_pass_latent_output,
size=(height // self.downsampling_factor, width // self.downsampling_factor),
mode="bilinear"
size=(
height // self.downsampling_factor,
width // self.downsampling_factor,
),
mode="bilinear",
)

# Free up memory from the last generation.
clear_cuda_cache = kwargs['clear_cuda_cache'] or None
clear_cuda_cache = kwargs["clear_cuda_cache"] or None
if clear_cuda_cache is not None:
clear_cuda_cache()

second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True)
second_pass_noise = self.get_noise_like(
resized_latents, override_perlin=True
)

# Clear symmetry for the second pass
from dataclasses import replace

new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None)
new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None)
new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings)

new_postprocessing_settings = replace(
conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
)
new_postprocessing_settings = replace(
new_postprocessing_settings, v_symmetry_time_pct=None
)
new_conditioning_data = replace(
conditioning_data, postprocessing_settings=new_postprocessing_settings
)

verbosity = get_verbosity()
set_verbosity_error()
@@ -92,15 +124,18 @@ class Txt2Img2Img(Generator):
conditioning_data=new_conditioning_data,
strength=strength,
noise=second_pass_noise,
callback=step_callback)
callback=step_callback,
)
set_verbosity(verbosity)

if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver)

return pipeline.numpy_to_pil(pipeline_output.images)[0]


# FIXME: do we really need something entirely different for the inpainting model?

# in the case of the inpainting model being loaded, the trick of
@@ -111,19 +146,23 @@ class Txt2Img2Img(Generator):

return make_image

def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False):
def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
device = like.device
if device.type == 'mps':
x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
if device.type == "mps":
x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
device
)
else:
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
if self.perlin > 0.0 and override_perlin == False:
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
shape[3], shape[2]
)
return x

# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True):
def get_noise(self, width, height, scale=True):
# print(f"Get noise: {width}x{height}")
if scale:
# Scale the input width and height for the initial generation
@@ -133,7 +172,9 @@ class Txt2Img2Img(Generator):
aspect = width / height
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
model_area = (
dimension * dimension
) # hardcoded for now since all models are trained on square images

if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
@@ -142,7 +183,9 @@ class Txt2Img2Img(Generator):
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect

scaled_width, scaled_height = trim_to_multiple_of(math.floor(init_width), math.floor(init_height))
scaled_width, scaled_height = trim_to_multiple_of(
math.floor(init_width), math.floor(init_height)
)

else:
scaled_width = width
@@ -152,10 +195,14 @@ class Txt2Img2Img(Generator):
channels = self.latent_channels
if channels == 9:
channels = 4 # we don't really want noise for all the mask channels
shape = (1, channels,
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
if self.use_mps_noise or device.type == 'mps':
tensor = torch.empty(size=shape, device='cpu')
shape = (
1,
channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor,
)
if self.use_mps_noise or device.type == "mps":
tensor = torch.empty(size=shape, device="cpu")
tensor = self.get_noise_like(like=tensor).to(device)
else:
tensor = torch.empty(size=shape, device=device)
@@ -1,5 +1,5 @@
'''
ldm.invoke.globals defines a small number of global variables that would
"""
invokeai.backend.globals defines a small number of global variables that would
otherwise have to be passed through long and complex call chains.

It defines a Namespace object named "Globals" that contains
@@ -9,7 +9,7 @@ the attributes:
- initfile - path to the initialization file
- try_patchmatch - option to globally disable loading of 'patchmatch' module
- always_use_cpu - force use of CPU even if GPU is available
'''
"""

import os
import os.path as osp
@@ -20,12 +20,12 @@ from typing import Union
Globals = Namespace()

# Where to look for the initialization file and other key components
Globals.initfile = 'invokeai.init'
Globals.models_file = 'models.yaml'
Globals.models_dir = 'models'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'
Globals.converted_ckpts_dir = 'converted_ckpts'
Globals.initfile = "invokeai.init"
Globals.models_file = "models.yaml"
Globals.models_dir = "models"
Globals.config_dir = "configs"
Globals.autoscan_dir = "weights"
Globals.converted_ckpts_dir = "converted_ckpts"

# Set the default root directory. This can be overwritten by explicitly
# passing the `--root <directory>` argument on the command line.
@@ -34,12 +34,15 @@ Globals.converted_ckpts_dir = 'converted_ckpts'
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
# 3) use ~/invokeai

if os.environ.get('INVOKEAI_ROOT'):
Globals.root = osp.abspath(os.environ.get('INVOKEAI_ROOT'))
elif os.environ.get('VIRTUAL_ENV') and Path(os.environ.get('VIRTUAL_ENV'),'..',Globals.initfile).exists():
Globals.root = osp.abspath(osp.join(os.environ.get('VIRTUAL_ENV'), '..'))
if os.environ.get("INVOKEAI_ROOT"):
Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
elif (
os.environ.get("VIRTUAL_ENV")
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
):
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
else:
Globals.root = osp.abspath(osp.expanduser('~/invokeai'))
Globals.root = osp.abspath(osp.expanduser("~/invokeai"))

# Try loading patchmatch
Globals.try_patchmatch = True
@@ -61,31 +64,38 @@ Globals.sequential_guidance = False
Globals.full_precision = False

# whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = False
Globals.ckpt_convert = True

# logging tokenization everywhere
Globals.log_tokenization = False

def global_config_file()->Path:

def global_config_file() -> Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)

def global_config_dir()->Path:

def global_config_dir() -> Path:
return Path(Globals.root, Globals.config_dir)

def global_models_dir()->Path:

def global_models_dir() -> Path:
return Path(Globals.root, Globals.models_dir)

def global_autoscan_dir()->Path:

def global_autoscan_dir() -> Path:
return Path(Globals.root, Globals.autoscan_dir)

def global_converted_ckpts_dir()->Path:

def global_converted_ckpts_dir() -> Path:
return Path(global_models_dir(), Globals.converted_ckpts_dir)

def global_set_root(root_dir:Union[str,Path]):

def global_set_root(root_dir: Union[str, Path]):
Globals.root = root_dir

def global_cache_dir(subdir:Union[str,Path]='')->Path:
'''

def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
"""
Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing
for huggingface-style conventions:
@@ -98,18 +108,18 @@ def global_cache_dir(subdir:Union[str,Path]='')->Path:
One other caveat is that HuggingFace is moving some diffusers models
into the "hub" subdirectory as well, so this will need to be revisited
from time to time.
'''
home: str = os.getenv('HF_HOME')
"""
home: str = os.getenv("HF_HOME")

if home is None:
home = os.getenv('XDG_CACHE_HOME')
home = os.getenv("XDG_CACHE_HOME")

if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + 'huggingface'
home += os.sep + "huggingface"

if home is not None:
return Path(home,subdir)
return Path(home, subdir)
else:
return Path(Globals.root,'models',subdir)
return Path(Globals.root, "models", subdir)
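global_cache_dir() above resolves, in order, $HF_HOME, then $XDG_CACHE_HOME/huggingface, then <root>/models. An illustrative use (the root path is hypothetical):

from invokeai.backend.globals import global_cache_dir, global_set_root

global_set_root("/opt/invokeai")
print(global_cache_dir("hub"))
# -> $HF_HOME/hub, else $XDG_CACHE_HOME/huggingface/hub, else /opt/invokeai/models/hub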
24  invokeai/backend/image_util/__init__.py  Normal file
@@ -0,0 +1,24 @@
"""
Initialization file for invokeai.backend.image_util methods.
"""
from PIL import ImageDraw  # needed by debug_image() below; absent from the committed file
from .patchmatch import PatchMatch
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata
from .seamless import configure_model_padding
from .txt2mask import Txt2Mask
from .util import InitImageResizer, make_grid


def debug_image(
    debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
    if not debug_status:
        return

    image_copy = debug_image.copy().convert("RGBA")
    ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))

    if debug_show:
        image_copy.show()

    if debug_result:
        return image_copy
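debug_image() above is a no-op unless debug_status is truthy, and it returns the annotated copy only when debug_result is set. A hedged usage sketch (assuming the function is imported from invokeai.backend.image_util):

from PIL import Image

img = Image.new("RGB", (64, 64), "gray")
annotated = debug_image(
    img, "mask before multiply",
    debug_show=False, debug_result=True, debug_status=True,
)  # returns the RGBA copy with the label drawn in the corner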
@@ -1,19 +1,21 @@
'''
"""
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
'''
from ldm.invoke.globals import Globals
import numpy as np
"""
import numpy as np

from invokeai.backend.globals import Globals


class PatchMatch:
'''
"""
Thin class wrapper around the patchmatch function.
'''
"""

patch_match = None
tried_load:bool = False
tried_load: bool = False

def __init__(self):
super().__init__()
@@ -24,21 +26,22 @@ class PatchMatch:
return
if Globals.try_patchmatch:
from patchmatch import patch_match as pm

if pm.patchmatch_available:
print('>> Patchmatch initialized')
print(">> Patchmatch initialized")
else:
print('>> Patchmatch not loaded (nonfatal)')
print(">> Patchmatch not loaded (nonfatal)")
self.patch_match = pm
else:
print('>> Patchmatch loading disabled')
print(">> Patchmatch loading disabled")
self.tried_load = True

@classmethod
def patchmatch_available(self)->bool:
def patchmatch_available(self) -> bool:
self._load_patch_match()
return self.patch_match and self.patch_match.patchmatch_available

@classmethod
def inpaint(self,*args,**kwargs)->np.ndarray:
def inpaint(self, *args, **kwargs) -> np.ndarray:
if self.patchmatch_available():
return self.patch_match.inpaint(*args,**kwargs)
return self.patch_match.inpaint(*args, **kwargs)
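PatchMatch above acts as a class-level lazy singleton: the first call to patchmatch_available() triggers the import, and inpaint() proxies through to the loaded module. A hypothetical call site (the image and mask variables are illustrative):

from invokeai.backend.image_util import PatchMatch

if PatchMatch.patchmatch_available():  # triggers the lazy load once
    filled = PatchMatch.inpaint(rgb_image, inverted_alpha_mask, patch_size=3)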
@ -6,10 +6,11 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds

Exports function retrieve_metadata(path)
"""
import json
import os
import re
import json

from PIL import PngImagePlugin, Image
from PIL import Image, PngImagePlugin

# -------------------image generation utils-----

@ -25,52 +26,57 @@ class PngWriter:
        dirlist = sorted(os.listdir(self.outdir), reverse=True)
        # find the first filename that matches our pattern or return 000000.0.png
        existing_name = next(
            (f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
            '0000000.0.png',
            (f for f in dirlist if re.match("^(\d+)\..*\.png", f)),
            "0000000.0.png",
        )
        basecount = int(existing_name.split('.', 1)[0]) + 1
        return f'{basecount:06}'
        basecount = int(existing_name.split(".", 1)[0]) + 1
        return f"{basecount:06}"

    # saves image named _image_ to outdir/name, writing metadata from prompt
    # returns full path of output
    def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
    def save_image_and_prompt_to_png(
        self, image, dream_prompt, name, metadata=None, compress_level=6
    ):
        path = os.path.join(self.outdir, name)
        info = PngImagePlugin.PngInfo()
        info.add_text('Dream', dream_prompt)
        info.add_text("Dream", dream_prompt)
        if metadata:
            info.add_text('sd-metadata', json.dumps(metadata))
        image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
            info.add_text("sd-metadata", json.dumps(metadata))
        image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
        return path

    def retrieve_metadata(self,img_basename):
        '''
    def retrieve_metadata(self, img_basename):
        """
        Given a PNG filename stored in outdir, returns the "sd-metadata"
        metadata stored there, as a dict
        '''
        path = os.path.join(self.outdir,img_basename)
        """
        path = os.path.join(self.outdir, img_basename)
        all_metadata = retrieve_metadata(path)
        return all_metadata['sd-metadata']
        return all_metadata["sd-metadata"]


def retrieve_metadata(img_path):
    '''
    """
    Given a path to a PNG image, returns the "sd-metadata"
    metadata stored there, as a dict
    '''
    """
    im = Image.open(img_path)
    if hasattr(im, 'text'):
        md = im.text.get('sd-metadata', '{}')
        dream_prompt = im.text.get('Dream', '')
    if hasattr(im, "text"):
        md = im.text.get("sd-metadata", "{}")
        dream_prompt = im.text.get("Dream", "")
    else:
        # When trying to retrieve metadata from images without a 'text' payload, such as JPG images.
        md = '{}'
        dream_prompt = ''
    return {'sd-metadata': json.loads(md), 'Dream': dream_prompt}
        md = "{}"
        dream_prompt = ""
    return {"sd-metadata": json.loads(md), "Dream": dream_prompt}

def write_metadata(img_path:str, meta:dict):

def write_metadata(img_path: str, meta: dict):
    im = Image.open(img_path)
    info = PngImagePlugin.PngInfo()
    info.add_text('sd-metadata', json.dumps(meta))
    im.save(img_path,'PNG',pnginfo=info)
    info.add_text("sd-metadata", json.dumps(meta))
    im.save(img_path, "PNG", pnginfo=info)


class PromptFormatter:
    def __init__(self, t2i, opt):
@ -86,28 +92,30 @@ class PromptFormatter:

        switches = list()
        switches.append(f'"{opt.prompt}"')
        switches.append(f'-s{opt.steps or t2i.steps}')
        switches.append(f'-W{opt.width or t2i.width}')
        switches.append(f'-H{opt.height or t2i.height}')
        switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
        switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
        # to do: put model name into the t2i object
        # switches.append(f'--model{t2i.model_name}')
        switches.append(f"-s{opt.steps or t2i.steps}")
        switches.append(f"-W{opt.width or t2i.width}")
        switches.append(f"-H{opt.height or t2i.height}")
        switches.append(f"-C{opt.cfg_scale or t2i.cfg_scale}")
        switches.append(f"-A{opt.sampler_name or t2i.sampler_name}")
        # to do: put model name into the t2i object
        # switches.append(f'--model{t2i.model_name}')
        if opt.seamless or t2i.seamless:
            switches.append(f'--seamless')
            switches.append(f"--seamless")
        if opt.init_img:
            switches.append(f'-I{opt.init_img}')
            switches.append(f"-I{opt.init_img}")
        if opt.fit:
            switches.append(f'--fit')
            switches.append(f"--fit")
        if opt.strength and opt.init_img is not None:
            switches.append(f'-f{opt.strength or t2i.strength}')
            switches.append(f"-f{opt.strength or t2i.strength}")
        if opt.gfpgan_strength:
            switches.append(f'-G{opt.gfpgan_strength}')
            switches.append(f"-G{opt.gfpgan_strength}")
        if opt.upscale:
            switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
        if opt.variation_amount > 0:
            switches.append(f'-v{opt.variation_amount}')
            switches.append(f"-v{opt.variation_amount}")
        if opt.with_variations:
            formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations)
            switches.append(f'-V{formatted_variations}')
        return ' '.join(switches)
            formatted_variations = ",".join(
                f"{seed}:{weight}" for seed, weight in opt.with_variations
            )
            switches.append(f"-V{formatted_variations}")
        return " ".join(switches)
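The metadata helpers round-trip cleanly. A minimal sketch, assuming PngWriter is constructed with an output directory and the new invokeai.backend.image_util.pngwriter path; the prompt, filename and seed are placeholders:

from PIL import Image
from invokeai.backend.image_util.pngwriter import PngWriter, retrieve_metadata

writer = PngWriter("./outputs")
img = Image.new("RGB", (64, 64))
path = writer.save_image_and_prompt_to_png(
    img, 'banana sushi -s50 -W512 -H512', "000001.1234.png", metadata={"seed": 1234}
)
# Both the Dream prompt and the sd-metadata dict come back out:
print(retrieve_metadata(path))
# {'sd-metadata': {'seed': 1234}, 'Dream': 'banana sushi -s50 -W512 -H512'}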
59
invokeai/backend/image_util/seamless.py
Normal file
@ -0,0 +1,59 @@
import torch.nn as nn


def _conv_forward_asymmetric(self, input, weight, bias):
    """
    Patch for Conv2d._conv_forward that supports asymmetric padding
    """
    working = nn.functional.pad(
        input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]
    )
    working = nn.functional.pad(
        working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
    )
    return nn.functional.conv2d(
        working,
        weight,
        bias,
        self.stride,
        nn.modules.utils._pair(0),
        self.dilation,
        self.groups,
    )


def configure_model_padding(model, seamless, seamless_axes):
    """
    Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
    """
    # TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            if seamless:
                m.asymmetric_padding_mode = {}
                m.asymmetric_padding = {}
                m.asymmetric_padding_mode["x"] = (
                    "circular" if ("x" in seamless_axes) else "constant"
                )
                m.asymmetric_padding["x"] = (
                    m._reversed_padding_repeated_twice[0],
                    m._reversed_padding_repeated_twice[1],
                    0,
                    0,
                )
                m.asymmetric_padding_mode["y"] = (
                    "circular" if ("y" in seamless_axes) else "constant"
                )
                m.asymmetric_padding["y"] = (
                    0,
                    0,
                    m._reversed_padding_repeated_twice[2],
                    m._reversed_padding_repeated_twice[3],
                )
                m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
            else:
                m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
                if hasattr(m, "asymmetric_padding_mode"):
                    del m.asymmetric_padding_mode
                if hasattr(m, "asymmetric_padding"):
                    del m.asymmetric_padding
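configure_model_padding() is the mechanism behind --seamless: along each selected axis it pads convolutions circularly instead of with zeros, so generated textures wrap at the borders. A minimal sketch of applying it to a toy module; the toy Conv2d and tensor shapes are illustrative (in InvokeAI it is applied to the generation model's modules):

import torch
import torch.nn as nn
from invokeai.backend.image_util.seamless import configure_model_padding

toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1))
configure_model_padding(toy, seamless=True, seamless_axes={"x"})

x = torch.randn(1, 3, 16, 16)
y = toy(x)  # left/right edges now see wrapped neighbors; top/bottom remain zero-padded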
@ -1,9 +1,9 @@
'''Makes available the Txt2Mask class, which assists in the automatic
"""Makes available the Txt2Mask class, which assists in the automatic
assignment of masks via text prompt using clipseg.

Here is typical usage:

from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
from invokeai.backend.image_util.txt2mask import Txt2Mask, SegmentedGrayscale
from PIL import Image

txt2mask = Txt2Mask(self.device)
@ -25,31 +25,39 @@ the mask that exceed the indicated confidence threshold. Values range
from 0.0 to 1.0. The higher the threshold, the more confident the
algorithm is. In limited testing, I have found that values around 0.5
work fine.
'''
"""

import numpy as np
import torch
import numpy as np
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from PIL import Image, ImageOps
from torchvision import transforms
from ldm.invoke.globals import global_cache_dir
from transformers import AutoProcessor, CLIPSegForImageSegmentation

CLIPSEG_MODEL = 'CIDAS/clipseg-rd64-refined'
from invokeai.backend.globals import global_cache_dir

CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352


class SegmentedGrayscale(object):
    def __init__(self, image:Image, heatmap:torch.Tensor):
    def __init__(self, image: Image, heatmap: torch.Tensor):
        self.heatmap = heatmap
        self.image = image

    def to_grayscale(self,invert:bool=False)->Image:
        return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
    def to_grayscale(self, invert: bool = False) -> Image:
        return self._rescale(
            Image.fromarray(
                np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)
            )
        )

    def to_mask(self,threshold:float=0.5)->Image:
    def to_mask(self, threshold: float = 0.5) -> Image:
        discrete_heatmap = self.heatmap.lt(threshold).int()
        return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
        return self._rescale(
            Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")
        )

    def to_transparent(self,invert:bool=False)->Image:
    def to_transparent(self, invert: bool = False) -> Image:
        transparent_image = self.image.copy()
        # For img2img, we want the selected regions to be transparent,
        # but to_grayscale() returns the opposite. Thus invert.
@ -58,70 +66,77 @@ class SegmentedGrayscale(object):
        return transparent_image

    # unscales and uncrops the 352x352 heatmap so that it matches the image again
    def _rescale(self, heatmap:Image)->Image:
        size = self.image.width if (self.image.width > self.image.height) else self.image.height
        resized_image = heatmap.resize(
            (size,size),
            resample=Image.Resampling.LANCZOS
    def _rescale(self, heatmap: Image) -> Image:
        size = (
            self.image.width
            if (self.image.width > self.image.height)
            else self.image.height
        )
        return resized_image.crop((0,0,self.image.width,self.image.height))
        resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
        return resized_image.crop((0, 0, self.image.width, self.image.height))


class Txt2Mask(object):
    '''
    """
    Create new Txt2Mask object. The optional device argument can be one of
    'cuda', 'mps' or 'cpu'.
    '''
    def __init__(self,device='cpu',refined=False):
        print('>> Initializing clipseg model for text to mask inference')
    """

    def __init__(self, device="cpu", refined=False):
        print(">> Initializing clipseg model for text to mask inference")

        # BUG: we are not doing anything with the device option at this time
        self.device = device
        self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL,
                                                       cache_dir=global_cache_dir('hub')
                                                       )
        self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL,
                                                                 cache_dir=global_cache_dir('hub')
                                                                 )
        self.processor = AutoProcessor.from_pretrained(
            CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
        )
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
        )

    @torch.no_grad()
    def segment(self, image, prompt:str) -> SegmentedGrayscale:
        '''
    def segment(self, image, prompt: str) -> SegmentedGrayscale:
        """
        Given a prompt string such as "a bagel", tries to identify the object in the
        provided image and returns a SegmentedGrayscale object in which the brighter
        pixels indicate where the object is inferred to be.
        '''
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
        ])
        """
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
                transforms.Resize(
                    (CLIPSEG_SIZE, CLIPSEG_SIZE)
                ),  # must be multiple of 64...
            ]
        )

        if type(image) is str:
            image = Image.open(image).convert('RGB')
            image = Image.open(image).convert("RGB")

        image = ImageOps.exif_transpose(image)
        img = self._scale_and_crop(image)

        inputs = self.processor(text=[prompt],
                                images=[img],
                                padding=True,
                                return_tensors='pt')
        inputs = self.processor(
            text=[prompt], images=[img], padding=True, return_tensors="pt"
        )
        outputs = self.model(**inputs)
        heatmap = torch.sigmoid(outputs.logits)
        return SegmentedGrayscale(image, heatmap)

    def _scale_and_crop(self, image:Image)->Image:
        scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
        if image.width > image.height: # width is constraint
    def _scale_and_crop(self, image: Image) -> Image:
        scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
        if image.width > image.height:  # width is constraint
            scale = CLIPSEG_SIZE / image.width
        else:
            scale = CLIPSEG_SIZE / image.height
        scaled_image.paste(
            image.resize(
                (int(scale * image.width),
                 int(scale * image.height)
                 ),
                resample=Image.Resampling.LANCZOS
            ),box=(0,0)
                (int(scale * image.width), int(scale * image.height)),
                resample=Image.Resampling.LANCZOS,
            ),
            box=(0, 0),
        )
        return scaled_image
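Expanding the "typical usage" from the module docstring into a complete round trip; the image path, prompt and output filename are illustrative, and per the docstring a threshold around 0.5 works fine:

from invokeai.backend.image_util.txt2mask import Txt2Mask

txt2mask = Txt2Mask(device="cpu")
segmented = txt2mask.segment("bagels.png", "a bagel")  # accepts a path or a PIL Image
heat = segmented.to_grayscale()          # grayscale heatmap at the original size
mask = segmented.to_mask(threshold=0.5)  # thresholded L-mode mask
mask.save("bagel_mask.png")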
@ -1,12 +1,15 @@
from math import sqrt, floor, ceil
from math import ceil, floor, sqrt

from PIL import Image

class InitImageResizer():

class InitImageResizer:
    """Simple class to create resized copies of an Image while preserving the aspect ratio."""
    def __init__(self,Image):

    def __init__(self, Image):
        self.image = Image

    def resize(self,width=None,height=None) -> Image:
    def resize(self, width=None, height=None) -> Image:
        """
        Return a copy of the image resized to fit within
        a box width x height. The aspect ratio is
@ -18,37 +21,36 @@ class InitImageResizer():
        Everything is floored to the nearest multiple of 64 so
        that it can be passed to img2img()
        """
        im = self.image
        im = self.image

        ar = im.width/float(im.height)
        ar = im.width / float(im.height)

        # Infer missing values from aspect ratio
        if not(width or height): # both missing
            width = im.width
        if not (width or height):  # both missing
            width = im.width
            height = im.height
        elif not height: # height missing
            height = int(width/ar)
        elif not width: # width missing
            width = int(height*ar)
        elif not height:  # height missing
            height = int(width / ar)
        elif not width:  # width missing
            width = int(height * ar)

        w_scale = width/im.width
        h_scale = height/im.height
        scale = min(w_scale,h_scale)
        (rw,rh) = (int(scale*im.width),int(scale*im.height))
        w_scale = width / im.width
        h_scale = height / im.height
        scale = min(w_scale, h_scale)
        (rw, rh) = (int(scale * im.width), int(scale * im.height))

        #round everything to multiples of 64
        width,height,rw,rh = map(
            lambda x: x-x%64, (width,height,rw,rh)
        )
        # round everything to multiples of 64
        width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))

        # no resize necessary, but return a copy
        if im.width == width and im.height == height:
            return im.copy()

        # otherwise resize the original image so that it fits inside the bounding box
        resized_image = self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS)
        resized_image = self.image.resize((rw, rh), resample=Image.Resampling.LANCZOS)
        return resized_image


def make_grid(image_list, rows=None, cols=None):
    image_cnt = len(image_list)
    if None in (rows, cols):
@ -57,7 +59,7 @@ def make_grid(image_list, rows=None, cols=None):
    width = image_list[0].width
    height = image_list[0].height

    grid_img = Image.new('RGB', (width * cols, height * rows))
    grid_img = Image.new("RGB", (width * cols, height * rows))
    i = 0
    for r in range(0, rows):
        for c in range(0, cols):
@ -67,4 +69,3 @@ def make_grid(image_list, rows=None, cols=None):
            i = i + 1

    return grid_img
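The floor-to-multiple-of-64 rule in resize() is simply x - x % 64; a quick worked check (the sample sizes are arbitrary):

for x in (1000, 700, 512, 383):
    print(x, "->", x - x % 64)  # 1000 -> 960, 700 -> 640, 512 -> 512, 383 -> 320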
10
invokeai/backend/model_management/__init__.py
Normal file
@ -0,0 +1,10 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .convert_ckpt_to_diffusers import (
    convert_ckpt_to_diffusers,
    load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager
from invokeai.frontend.merge import merge_diffusion_models

File diff suppressed because it is too large
@ -9,7 +9,6 @@ from __future__ import annotations
import contextlib
import gc
import hashlib
import io
import os
import re
import sys
@ -32,15 +31,10 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path

from ldm.invoke.devices import CPU_DEVICE
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.globals import Globals, global_cache_dir
from ldm.util import (
    ask_user,
    download_with_resume,
    instantiate_from_config,
    url_attachment_name,
)
from invokeai.backend.globals import Globals, global_cache_dir

from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CPU_DEVICE, ask_user, download_with_resume


class SDLegacyType(Enum):
@ -341,22 +335,9 @@ class ModelManager(object):

        tic = time.time()

        # this does the work
        model_format = mconfig.get("format", "ckpt")
        if model_format == "ckpt":
            weights = mconfig.weights
            print(f">> Loading {model_name} from {weights}")
            model, width, height, model_hash = self._load_ckpt_model(
                model_name, mconfig
            )
        elif model_format == "diffusers":
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model, width, height, model_hash = self._load_diffusers_model(mconfig)
        else:
            raise NotImplementedError(
                f"Unknown model format {model_name}: {model_format}"
            )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model, width, height, model_hash = self._load_diffusers_model(mconfig)

        # usage statistics
        toc = time.time()
@ -370,125 +351,6 @@ class ModelManager(object):
        )
        return model, width, height, model_hash

    def _load_ckpt_model(self, model_name, mconfig):
        config = mconfig.config
        weights = mconfig.weights
        vae = mconfig.get("vae")
        width = mconfig.width
        height = mconfig.height

        if not os.path.isabs(config):
            config = os.path.join(Globals.root, config)
        if not os.path.isabs(weights):
            weights = os.path.normpath(os.path.join(Globals.root, weights))

        # if converting automatically to diffusers, then we do the conversion and return
        # a diffusers pipeline
        if Globals.ckpt_convert:
            print(
                f">> Converting legacy checkpoint {model_name} into a diffusers model..."
            )
            from ldm.invoke.ckpt_to_diffuser import (
                load_pipeline_from_original_stable_diffusion_ckpt,
            )

            self.offload_model(self.current_model)
            if vae_config := self._choose_diffusers_vae(model_name):
                vae = self._load_vae(vae_config)
            if self._has_cuda():
                torch.cuda.empty_cache()
            pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
                checkpoint_path=weights,
                original_config_file=config,
                vae=vae,
                return_generator_pipeline=True,
                precision=torch.float16
                if self.precision == "float16"
                else torch.float32,
            )
            if self.sequential_offload:
                pipeline.enable_offload_submodels(self.device)
            else:
                pipeline.to(self.device)

            return (
                pipeline,
                width,
                height,
                "NOHASH",
            )

        # scan model
        self.scan_model(model_name, weights)

        print(f">> Loading {model_name} from {weights}")

        # for usage statistics
        if self._has_cuda():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()

        # this does the work
        if not os.path.isabs(config):
            config = os.path.join(Globals.root, config)
        omega_config = OmegaConf.load(config)
        with open(weights, "rb") as f:
            weight_bytes = f.read()
        model_hash = self._cached_sha256(weights, weight_bytes)
        sd = None
        if weights.endswith(".safetensors"):
            sd = safetensors.torch.load(weight_bytes)
        else:
            sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
        del weight_bytes
        # merged models from auto11 merge board are flat for some reason
        if "state_dict" in sd:
            sd = sd["state_dict"]

        print(" | Forcing garbage collection prior to loading new model")
        gc.collect()
        model = instantiate_from_config(omega_config.model)
        model.load_state_dict(sd, strict=False)

        if self.precision == "float16":
            print(" | Using faster float16 precision")
            model = model.to(torch.float16)
        else:
            print(" | Using more accurate float32 precision")

        # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
        if vae:
            if not os.path.isabs(vae):
                vae = os.path.normpath(os.path.join(Globals.root, vae))
            if os.path.exists(vae):
                print(f" | Loading VAE weights from: {vae}")
                vae_ckpt = None
                vae_dict = None
                if vae.endswith(".safetensors"):
                    vae_ckpt = safetensors.torch.load_file(vae)
                    vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
                else:
                    vae_ckpt = torch.load(vae, map_location="cpu")
                    vae_dict = {
                        k: v
                        for k, v in vae_ckpt["state_dict"].items()
                        if k[0:4] != "loss"
                    }
                model.first_stage_model.load_state_dict(vae_dict, strict=False)
            else:
                print(f" | VAE file {vae} not found. Skipping.")

        model.to(self.device)
        # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
        model.cond_stage_model.device = self.device

        model.eval()

        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                module._orig_padding_mode = module.padding_mode
        return model, width, height, model_hash

    def _load_diffusers_model(self, mconfig):
        name_or_path = self.model_name_or_path(mconfig)
        using_fp16 = self.precision == "float16"
@ -553,6 +415,47 @@ class ModelManager(object):

        return pipeline, width, height, model_hash

    def _load_ckpt_model(self, model_name, mconfig):
        config = mconfig.config
        weights = mconfig.weights
        vae = mconfig.get("vae")
        width = mconfig.width
        height = mconfig.height

        if not os.path.isabs(config):
            config = os.path.join(Globals.root, config)
        if not os.path.isabs(weights):
            weights = os.path.normpath(os.path.join(Globals.root, weights))

        # Convert to diffusers and return a diffusers pipeline
        print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")

        from . import load_pipeline_from_original_stable_diffusion_ckpt

        self.offload_model(self.current_model)
        if vae_config := self._choose_diffusers_vae(model_name):
            vae = self._load_vae(vae_config)
        if self._has_cuda():
            torch.cuda.empty_cache()
        pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
            checkpoint_path=weights,
            original_config_file=config,
            vae=vae,
            return_generator_pipeline=True,
            precision=torch.float16 if self.precision == "float16" else torch.float32,
        )
        if self.sequential_offload:
            pipeline.enable_offload_submodels(self.device)
        else:
            pipeline.to(self.device)

        return (
            pipeline,
            width,
            height,
            "NOHASH",
        )

    def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path:
        if isinstance(model_name, DictConfig) or isinstance(model_name, dict):
            mconfig = model_name
@ -640,7 +543,9 @@ class ModelManager(object):
        models.yaml file.
        """
        model_name = model_name or Path(repo_or_path).stem
        model_description = model_description or f"Imported diffusers model {model_name}"
        model_description = (
            model_description or f"Imported diffusers model {model_name}"
        )
        new_config = dict(
            description=model_description,
            vae=vae,
@ -656,66 +561,6 @@ class ModelManager(object):
        self.commit(commit_to_conf)
        return model_name

    def import_ckpt_model(
        self,
        weights: Union[str, Path],
        config: Union[str, Path] = "configs/stable-diffusion/v1-inference.yaml",
        vae: Union[str, Path] = None,
        model_name: str = None,
        model_description: str = None,
        commit_to_conf: Path = None,
    ) -> str:
        """
        Attempts to install the indicated ckpt file and returns True if successful.

        "weights" can be either a path-like object corresponding to a local .ckpt file
        or a http/https URL pointing to a remote model.

        "vae" is a Path or str object pointing to a ckpt or safetensors file to be used
        as the VAE for this model.

        "config" is the model config file to use with this ckpt file. It defaults to
        v1-inference.yaml. If a URL is provided, the config will be downloaded.

        You can optionally provide a model name and/or description. If not provided,
        then these will be derived from the weight file name. If you provide a commit_to_conf
        path to the configuration file, then the new entry will be committed to the
        models.yaml file.

        Return value is the name of the imported file, or None if an error occurred.
        """
        if str(weights).startswith(("http:", "https:")):
            model_name = model_name or url_attachment_name(weights)

        weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
        config_path = self._resolve_path(config, "configs/stable-diffusion")

        if weights_path is None or not weights_path.exists():
            return
        if config_path is None or not config_path.exists():
            return

        model_name = (
            model_name or Path(weights).stem
        )  # note this gives ugly pathnames if used on a URL without a Content-Disposition header
        model_description = (
            model_description or f"Imported stable diffusion weights file {model_name}"
        )
        new_config = dict(
            weights=str(weights_path),
            config=str(config_path),
            description=model_description,
            format="ckpt",
            width=512,
            height=512,
        )
        if vae:
            new_config["vae"] = vae
        self.add_model(model_name, new_config, True)
        if commit_to_conf:
            self.commit(commit_to_conf)
        return model_name

    @classmethod
    def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
        """
@ -748,7 +593,7 @@ class ModelManager(object):
    def heuristic_import(
        self,
        path_url_or_repo: str,
        convert: bool = False,
        convert: bool = True,
        model_name: str = None,
        description: str = None,
        commit_to_conf: Path = None,
@ -883,35 +728,18 @@ class ModelManager(object):
            )
            return

        if convert:
            diffuser_path = Path(
                Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
            )
            model_name = self.convert_and_import(
                model_path,
                diffusers_path=diffuser_path,
                vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
                model_name=model_name,
                model_description=description,
                original_config_file=model_config_file,
                commit_to_conf=commit_to_conf,
            )
        else:
            model_name = self.import_ckpt_model(
                model_path,
                config=model_config_file,
                model_name=model_name,
                model_description=description,
                vae=str(
                    Path(
                        Globals.root,
                        "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
                    )
                ),
                commit_to_conf=commit_to_conf,
            )
            if commit_to_conf:
                self.commit(commit_to_conf)
        diffuser_path = Path(
            Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
        )
        model_name = self.convert_and_import(
            model_path,
            diffusers_path=diffuser_path,
            vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
            model_name=model_name,
            model_description=description,
            original_config_file=model_config_file,
            commit_to_conf=commit_to_conf,
        )
        return model_name

    def convert_and_import(
@ -936,7 +764,7 @@ class ModelManager(object):

        new_config = None

        from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
        from . import convert_ckpt_to_diffusers

        if diffusers_path.exists():
            print(
@ -951,7 +779,7 @@ class ModelManager(object):
            # By passing the specified VAE to the conversion function, the autoencoder
            # will be built into the model rather than tacked on afterward via the config file
            vae_model = self._load_vae(vae) if vae else None
            convert_ckpt_to_diffuser(
            convert_ckpt_to_diffusers(
                ckpt_path,
                diffusers_path,
                extract_ema=True,
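With convert now defaulting to True, heuristic_import() always turns a legacy checkpoint into a diffusers model under Globals.converted_ckpts_dir. A hedged usage sketch, assuming an already-constructed ModelManager named manager (its constructor is outside this hunk) and placeholder paths:

from pathlib import Path

model_name = manager.heuristic_import(
    "/path/to/some-model.safetensors",           # placeholder checkpoint
    model_name="some-model",
    description="Imported legacy checkpoint",
    commit_to_conf=Path("configs/models.yaml"),  # persist the new models.yaml entry
)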
10
invokeai/backend/prompting/__init__.py
Normal file
@ -0,0 +1,10 @@
"""
Initialization file for invokeai.backend.prompting
"""
from .conditioning import (
    get_prompt_structure,
    get_tokenizer,
    get_tokens_for_prompt_object,
    get_uc_and_c_and_ec,
    split_weighted_subprompts,
)
@ -1,31 +1,46 @@
'''
"""
This module handles the generation of the conditioning tensors.

Useful function exports:

get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control

'''
"""
import re
from typing import Union, Optional, Any

from transformers import CLIPTokenizer, CLIPTextModel
from typing import Any, Optional, Union

from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
from .devices import torch_dtype
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals
from compel.prompt_parser import (
    Blend,
    CrossAttentionControlSubstitute,
    FlattenedPrompt,
    Fragment,
    PromptParser,
)
from transformers import CLIPTextModel, CLIPTokenizer

from invokeai.backend.globals import Globals

from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype


def get_tokenizer(model) -> CLIPTokenizer:
    # TODO remove legacy ckpt fallback handling
    return (getattr(model, 'tokenizer', None)  # diffusers
            or model.cond_stage_model.tokenizer)  # ldm
    return (
        getattr(model, "tokenizer", None)  # diffusers
        or model.cond_stage_model.tokenizer
    )  # ldm


def get_text_encoder(model) -> Any:
    # TODO remove legacy ckpt fallback handling
    return (getattr(model, 'text_encoder', None)  # diffusers
            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer))  # ldm
    return getattr(
        model, "text_encoder", None
    ) or UnsqueezingLDMTransformer(  # diffusers
        model.cond_stage_model.transformer
    )  # ldm


class UnsqueezingLDMTransformer:
    def __init__(self, ldm_transformer):
@ -40,28 +55,41 @@ class UnsqueezingLDMTransformer:
        return insufficiently_unsqueezed_tensor.unsqueeze(0)


def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
def get_uc_and_c_and_ec(
    prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
):
    # lazy-load any deferred textual inversions.
    # this might take a couple of seconds the first time a textual inversion is used.
    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
        prompt_string
    )

    tokenizer = get_tokenizer(model)
    text_encoder = get_text_encoder(model)
    compel = Compel(tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    textual_inversion_manager=model.textual_inversion_manager,
                    dtype_for_device_getter=torch_dtype)
    compel = Compel(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        textual_inversion_manager=model.textual_inversion_manager,
        dtype_for_device_getter=torch_dtype,
    )

    # get rid of any newline characters
    prompt_string = prompt_string.replace("\n", " ")
    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
    positive_prompt: FlattenedPrompt|Blend
    (
        positive_prompt_string,
        negative_prompt_string,
    ) = split_prompt_to_positive_and_negative(prompt_string)
    legacy_blend = try_parse_legacy_blend(
        positive_prompt_string, skip_normalize_legacy_blend
    )
    positive_prompt: FlattenedPrompt | Blend
    if legacy_blend is not None:
        positive_prompt = legacy_blend
    else:
        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
    negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
    negative_prompt: FlattenedPrompt | Blend = Compel.parse_prompt_string(
        negative_prompt_string
    )

    if log_tokens or getattr(Globals, "log_tokenization", False):
        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
@ -71,42 +99,70 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l

    tokens_count = get_max_token_count(tokenizer, positive_prompt)

    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                         cross_attention_control_args=options.get(
                                                             'cross_attention_control', None))
    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
        tokens_count_including_eos_bos=tokens_count,
        cross_attention_control_args=options.get("cross_attention_control", None),
    )
    return uc, c, ec


def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
        Union[FlattenedPrompt, Blend], FlattenedPrompt):
    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
    positive_prompt: FlattenedPrompt|Blend
def get_prompt_structure(
    prompt_string, skip_normalize_legacy_blend: bool = False
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
    (
        positive_prompt_string,
        negative_prompt_string,
    ) = split_prompt_to_positive_and_negative(prompt_string)
    legacy_blend = try_parse_legacy_blend(
        positive_prompt_string, skip_normalize_legacy_blend
    )
    positive_prompt: FlattenedPrompt | Blend
    if legacy_blend is not None:
        positive_prompt = legacy_blend
    else:
        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
    negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
    negative_prompt: FlattenedPrompt | Blend = Compel.parse_prompt_string(
        negative_prompt_string
    )

    return positive_prompt, negative_prompt


def get_max_token_count(tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=True) -> int:
def get_max_token_count(
    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=True
) -> int:
    if type(prompt) is Blend:
        blend: Blend = prompt
        return max([get_max_token_count(tokenizer, c, truncate_if_too_long) for c in blend.prompts])
        return max(
            [
                get_max_token_count(tokenizer, c, truncate_if_too_long)
                for c in blend.prompts
            ]
        )
    else:
        return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
        return len(
            get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
        )


def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
def get_tokens_for_prompt_object(
    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
    if type(parsed_prompt) is Blend:
        raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
        raise ValueError(
            "Blend is not supported here - you need to get tokens for each of its .children"
        )

    text_fragments = [x.text if type(x) is Fragment else
                      (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
                       str(x))
                      for x in parsed_prompt.children]
    text_fragments = [
        x.text
        if type(x) is Fragment
        else (
            " ".join([f.text for f in x.original])
            if type(x) is CrossAttentionControlSubstitute
            else str(x)
        )
        for x in parsed_prompt.children
    ]
    text = " ".join(text_fragments)
    tokens = tokenizer.tokenize(text)
    if truncate_if_too_long:
@ -116,39 +172,47 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun


def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
    unconditioned_words = ''
    unconditional_regex = r'\[(.*?)\]'
    unconditioned_words = ""
    unconditional_regex = r"\[(.*?)\]"
    unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
    if len(unconditionals) > 0:
        unconditioned_words = ' '.join(unconditionals)
        unconditioned_words = " ".join(unconditionals)

        # Remove Unconditioned Words From Prompt
        unconditional_regex_compile = re.compile(unconditional_regex)
        clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
        prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
        clean_prompt = unconditional_regex_compile.sub(" ", prompt_string_uncleaned)
        prompt_string_cleaned = re.sub(" +", " ", clean_prompt)
    else:
        prompt_string_cleaned = prompt_string_uncleaned
    return prompt_string_cleaned, unconditioned_words


def log_tokenization(positive_prompt: Union[Blend, FlattenedPrompt],
                     negative_prompt: Union[Blend, FlattenedPrompt],
                     tokenizer):
def log_tokenization(
    positive_prompt: Union[Blend, FlattenedPrompt],
    negative_prompt: Union[Blend, FlattenedPrompt],
    tokenizer,
):
    print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
    print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")

    log_tokenization_for_prompt_object(positive_prompt, tokenizer)
    log_tokenization_for_prompt_object(negative_prompt, tokenizer, display_label_prefix="(negative prompt)")
    log_tokenization_for_prompt_object(
        negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
    )


def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
    display_label_prefix = display_label_prefix or ""
    if type(p) is Blend:
        blend: Blend = p
        for i, c in enumerate(blend.prompts):
            log_tokenization_for_prompt_object(
                c, tokenizer,
                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})")
                c,
                tokenizer,
                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
            )
    elif type(p) is FlattenedPrompt:
        flattened_prompt: FlattenedPrompt = p
        if flattened_prompt.wants_cross_attention_control:
@ -163,18 +227,26 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
            edited_fragments.append(f)

            original_text = " ".join([x.text for x in original_fragments])
            log_tokenization_for_text(original_text, tokenizer,
                                      display_label=f"{display_label_prefix}(.swap originals)")
            log_tokenization_for_text(
                original_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap originals)",
            )
            edited_text = " ".join([x.text for x in edited_fragments])
            log_tokenization_for_text(edited_text, tokenizer,
                                      display_label=f"{display_label_prefix}(.swap replacements)")
            log_tokenization_for_text(
                edited_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap replacements)",
            )
        else:
            text = " ".join([x.text for x in flattened_prompt.children])
            log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
            log_tokenization_for_text(
                text, tokenizer, display_label=display_label_prefix
            )


def log_tokenization_for_text(text, tokenizer, display_label=None):
    """ shows how the prompt is tokenized
    """shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
    """
@ -185,7 +257,7 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
    totalTokens = len(tokens)

    for i in range(0, totalTokens):
        token = tokens[i].replace('</w>', ' ')
        token = tokens[i].replace("</w>", " ")
        # alternate color
        s = (usedTokens % 6) + 1
        if i < tokenizer.model_max_length:
@ -196,14 +268,14 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):

    if usedTokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
        print(f'{tokenized}\x1b[0m')
        print(f"{tokenized}\x1b[0m")

    if discarded != "":
        print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):')
        print(f'{discarded}\x1b[0m')
        print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
        print(f"{discarded}\x1b[0m")


def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Blend]:
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
    weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
    if len(weighted_subprompts) <= 1:
        return None
@ -214,10 +286,12 @@ def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Bl
    parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
    flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]

    return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
    return Blend(
        prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
    )


def split_weighted_subprompts(text, skip_normalize=False)->list:
def split_weighted_subprompts(text, skip_normalize=False) -> list:
    """
    Legacy blend parsing.

@ -226,7 +300,8 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
    if ':' has no value defined, defaults to 1.0
    repeats until no text remaining
    """
    prompt_parser = re.compile("""
    prompt_parser = re.compile(
        """
        (?P<prompt> # capture group for 'prompt'
        (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
        ) # end 'prompt'
@ -239,16 +314,20 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
        | # OR
        $ # else, if no ':' then match end of line
        ) # end non-capture group
""", re.VERBOSE)
    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
        """,
        re.VERBOSE,
    )
    parsed_prompts = [
        (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
        for match in re.finditer(prompt_parser, text)
    ]
    if skip_normalize:
        return parsed_prompts
    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
    if weight_sum == 0:
        print(
            "* Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
            "* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
        )
        equal_weight = 1 / max(len(parsed_prompts), 1)
        return [(x[0], equal_weight) for x in parsed_prompts]
    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
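A worked example of the legacy blend parsing above (the prompt text is illustrative). With skip_normalize left False, weights 3 and 1 normalize to 0.75 and 0.25:

from invokeai.backend.prompting import split_weighted_subprompts

parts = split_weighted_subprompts("a red barn:3 a blue sky:1")
for prompt, weight in parts:
    print(repr(prompt), weight)  # weights sum to 1.0 after normalization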
4
invokeai/backend/restoration/__init__.py
Normal file
@ -0,0 +1,4 @@
"""
Initialization file for the invokeai.backend.restoration package
"""
from .base import Restoration
@ -1,38 +1,43 @@
class Restoration():
class Restoration:
    def __init__(self) -> None:
        pass

    def load_face_restore_models(self, gfpgan_model_path='./models/gfpgan/GFPGANv1.4.pth'):
    def load_face_restore_models(
        self, gfpgan_model_path="./models/gfpgan/GFPGANv1.4.pth"
    ):
        # Load GFPGAN
        gfpgan = self.load_gfpgan(gfpgan_model_path)
        if gfpgan.gfpgan_model_exists:
            print('>> GFPGAN Initialized')
            print(">> GFPGAN Initialized")
        else:
            print('>> GFPGAN Disabled')
            print(">> GFPGAN Disabled")
            gfpgan = None

        # Load CodeFormer
        codeformer = self.load_codeformer()
        if codeformer.codeformer_model_exists:
            print('>> CodeFormer Initialized')
            print(">> CodeFormer Initialized")
        else:
            print('>> CodeFormer Disabled')
            print(">> CodeFormer Disabled")
            codeformer = None

        return gfpgan, codeformer

    # Face Restore Models
    def load_gfpgan(self, gfpgan_model_path):
        from ldm.invoke.restoration.gfpgan import GFPGAN
        from .gfpgan import GFPGAN

        return GFPGAN(gfpgan_model_path)

    def load_codeformer(self):
        from ldm.invoke.restoration.codeformer import CodeFormerRestoration
        from .codeformer import CodeFormerRestoration

        return CodeFormerRestoration()

    # Upscale Models
    def load_esrgan(self, esrgan_bg_tile=400):
        from ldm.invoke.restoration.realesrgan import ESRGAN
        from .realesrgan import ESRGAN

        esrgan = ESRGAN(esrgan_bg_tile)
        print('>> ESRGAN Initialized')
        return esrgan;
        print(">> ESRGAN Initialized")
        return esrgan
@ -1,17 +1,21 @@
import os
import torch
import numpy as np
import warnings
import sys
from ldm.invoke.globals import Globals
import warnings

pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
import numpy as np
import torch

class CodeFormerRestoration():
    def __init__(self,
                 codeformer_dir='models/codeformer',
                 codeformer_model_path='codeformer.pth') -> None:
from ..globals import Globals

pretrained_model_url = (
    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)


class CodeFormerRestoration:
    def __init__(
        self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
    ) -> None:
        if not os.path.isabs(codeformer_dir):
            codeformer_dir = os.path.join(Globals.root, codeformer_dir)

@ -19,22 +23,23 @@ class CodeFormerRestoration():
        self.codeformer_model_exists = os.path.isfile(self.model_path)

        if not self.codeformer_model_exists:
            print('## NOT FOUND: CodeFormer model not found at ' + self.model_path)
            print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
        sys.path.append(os.path.abspath(codeformer_dir))

    def process(self, image, strength, device, seed=None, fidelity=0.75):
        if seed is not None:
            print(f'>> CodeFormer - Restoring Faces for image seed:{seed}')
            print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            warnings.filterwarnings("ignore", category=UserWarning)

            from basicsr.utils.download_util import load_file_from_url
            from basicsr.utils import img2tensor, tensor2img
            from basicsr.utils.download_util import load_file_from_url
            from facexlib.utils.face_restoration_helper import FaceRestoreHelper
            from ldm.invoke.restoration.codeformer_arch import CodeFormer
            from torchvision.transforms.functional import normalize
            from PIL import Image
            from torchvision.transforms.functional import normalize

            from .codeformer_arch import CodeFormer

            cf_class = CodeFormer

@ -43,28 +48,31 @@ class CodeFormerRestoration():
                codebook_size=1024,
                n_head=8,
                n_layers=9,
                connect_list=['32', '64', '128', '256']
                connect_list=["32", "64", "128", "256"],
            ).to(device)

            # note that this file should already be downloaded and cached at
            # this point
            checkpoint_path = load_file_from_url(url=pretrained_model_url,
                                                 model_dir=os.path.abspath(os.path.dirname(self.model_path)),
                                                 progress=True
            checkpoint_path = load_file_from_url(
                url=pretrained_model_url,
                model_dir=os.path.abspath(os.path.dirname(self.model_path)),
                progress=True,
            )
            checkpoint = torch.load(checkpoint_path)['params_ema']
            checkpoint = torch.load(checkpoint_path)["params_ema"]
            cf.load_state_dict(checkpoint)
            cf.eval()

            image = image.convert('RGB')
            image = image.convert("RGB")
            # Codeformer expects a BGR np array; make array and flip channels
            bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
            bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]

            face_helper = FaceRestoreHelper(
                upscale_factor=1,
                use_parse=True,
                device=device,
                model_rootpath=os.path.join(Globals.root,'models','gfpgan','weights'),
                model_rootpath=os.path.join(
                    Globals.root, "models", "gfpgan", "weights"
                ),
            )
            face_helper.clean_all()
            face_helper.read_image(bgr_image_array)
@ -72,30 +80,35 @@ class CodeFormerRestoration():
            face_helper.align_warp_face()

            for idx, cropped_face in enumerate(face_helper.cropped_faces):
                cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                cropped_face_t = img2tensor(
                    cropped_face / 255.0, bgr2rgb=True, float32=True
                )
                normalize(
                    cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
                )
                cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

                try:
                    with torch.no_grad():
                        output = cf(cropped_face_t, w=fidelity, adain=True)[0]
                        restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
                        restored_face = tensor2img(
                            output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
                        )
                    del output
                    torch.cuda.empty_cache()
                except RuntimeError as error:
                    print(f'\tFailed inference for CodeFormer: {error}.')
                    print(f"\tFailed inference for CodeFormer: {error}.")
                    restored_face = cropped_face

                restored_face = restored_face.astype('uint8')
                restored_face = restored_face.astype("uint8")
                face_helper.add_restored_face(restored_face)


            face_helper.get_inverse_affine(None)

            restored_img = face_helper.paste_faces_to_input_image()

            # Flip the channels back to RGB
            res = Image.fromarray(restored_img[...,::-1])
            res = Image.fromarray(restored_img[..., ::-1])

            if strength < 1.0:
                # Resize the image to the new image if the sizes have changed
@ -1,13 +1,15 @@
import math
from typing import List, Optional

import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List

from ldm.invoke.restoration.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
from torch import Tensor, nn

from .vqgan_arch import *


def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.
@ -18,7 +20,7 @@ def calc_mean_std(feat, eps=1e-5):
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    assert len(size) == 4, "The input feature should be 4D tensor."
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
@ -39,7 +41,9 @@ def adaptive_instance_normalization(content_feat, style_feat):
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
        size
    )
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


@ -49,7 +53,9 @@ class PositionEmbeddingSine(nn.Module):
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
@ -62,7 +68,9 @@ class PositionEmbeddingSine(nn.Module):

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
            )
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
@ -85,6 +93,7 @@ class PositionEmbeddingSine(nn.Module):
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
@ -93,11 +102,13 @@ def _get_activation_fn(activation):
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
    def __init__(
        self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
@ -115,16 +126,19 @@ class TransformerSALayer(nn.Module):
    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

    def forward(
        self,
        tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # self attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
@ -133,20 +147,23 @@ class TransformerSALayer(nn.Module):
        tgt = tgt + self.dropout2(tgt2)
        return tgt


class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
@ -159,11 +176,19 @@ class Fuse_sft_block(nn.Module):

@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize','generator']):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
    def __init__(
        self,
        dim_embd=512,
        n_head=8,
        n_layers=9,
        codebook_size=1024,
        latent_size=256,
        connect_list=["32", "64", "128", "256"],
        fix_modules=["quantize", "generator"],
    ):
        super(CodeFormer, self).__init__(
            512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
        )

        if fix_modules is not None:
            for module in fix_modules:
@ -173,33 +198,53 @@ class CodeFormer(VQAutoEncoder):
        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2
        self.dim_mlp = dim_embd * 2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])
        self.ft_layers = nn.Sequential(
            *[
                TransformerSALayer(
                    embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
                )
                for _ in range(self.n_layers)
            ]
        )

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))
            nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
        )

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
            "16": 512,
            "32": 256,
            "64": 256,
            "128": 128,
            "256": 128,
            "512": 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        self.fuse_encoder_block = {
            "512": 2,
            "256": 5,
            "128": 8,
            "64": 11,
            "32": 14,
            "16": 18,
        }
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
        self.fuse_generator_block = {
            "16": 6,
            "32": 9,
            "64": 12,
            "128": 15,
            "256": 18,
            "512": 21,
        }

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
@ -228,20 +273,20 @@ class CodeFormer(VQAutoEncoder):
        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
        logits = self.idx_pred_layer(query_emb)  # (hw)bn
        logits = logits.permute(1, 0, 2)  # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
        if code_only:  # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
@ -252,12 +297,14 @@ class CodeFormer(VQAutoEncoder):
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        quant_feat = self.quantize.get_codebook_feat(
            top_idx, shape=[x.shape[0], 16, 16, 256]
        )
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
            quant_feat = quant_feat.detach()  # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

@ -267,10 +314,12 @@ class CodeFormer(VQAutoEncoder):

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
            if i in fuse_list:  # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
                if w > 0:
                    x = self.fuse_convs_dict[f_size](
                        enc_feat_dict[f_size].detach(), x, w
                    )
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
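For orientation, here is a minimal usage sketch of the transformer-augmented VQ autoencoder above. It is not part of the diff: the import path assumes the new `invokeai.backend.restoration` layout, and the `w`/`adain` keywords are inferred from the forward-pass hunks shown.

```python
import torch

# Import path is an assumption based on the restructured tree.
from invokeai.backend.restoration.codeformer_arch import CodeFormer

model = CodeFormer().eval()  # defaults: dim_embd=512, n_head=8, n_layers=9
with torch.no_grad():
    face = torch.randn(1, 3, 512, 512)  # one aligned 512x512 face, NCHW
    # w=0 favors fidelity to the input; w=1 favors codebook quality.
    out, logits, lq_feat = model(face, w=0.5, adain=True)
print(out.shape)  # torch.Size([1, 3, 512, 512])
```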
@ -1,26 +1,25 @@
import torch
import warnings
import os
import sys
import numpy as np
from ldm.invoke.globals import Globals
import warnings

import numpy as np
import torch
from PIL import Image

from invokeai.backend.globals import Globals

class GFPGAN():
    def __init__(
            self,
            gfpgan_model_path='models/gfpgan/GFPGANv1.4.pth'
    ) -> None:

class GFPGAN:
    def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
        if not os.path.isabs(gfpgan_model_path):
            gfpgan_model_path=os.path.abspath(os.path.join(Globals.root,gfpgan_model_path))
            gfpgan_model_path = os.path.abspath(
                os.path.join(Globals.root, gfpgan_model_path)
            )
        self.model_path = gfpgan_model_path
        self.gfpgan_model_exists = os.path.isfile(self.model_path)

        if not self.gfpgan_model_exists:
            print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)
            print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
            return None

    def model_exists(self):
@ -28,40 +27,40 @@ class GFPGAN():

    def process(self, image, strength: float, seed: str = None):
        if seed is not None:
            print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
            print(f">> GFPGAN - Restoring Faces for image seed:{seed}")

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            warnings.filterwarnings("ignore", category=UserWarning)
            cwd = os.getcwd()
            os.chdir(os.path.join(Globals.root,'models'))
            os.chdir(os.path.join(Globals.root, "models"))
            try:
                from gfpgan import GFPGANer

                self.gfpgan = GFPGANer(
                    model_path=self.model_path,
                    upscale=1,
                    arch='clean',
                    arch="clean",
                    channel_multiplier=2,
                    bg_upsampler=None,
                )
            except Exception:
                import traceback
                print('>> Error loading GFPGAN:', file=sys.stderr)

                print(">> Error loading GFPGAN:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
            os.chdir(cwd)

        if self.gfpgan is None:
            print(f">> WARNING: GFPGAN not initialized.")
            print(
                f'>> WARNING: GFPGAN not initialized.'
            )
            print(
                f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}'
                f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
            )

        image = image.convert('RGB')
        image = image.convert("RGB")

        # GFPGAN expects a BGR np array; make array and flip channels
        bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
        bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]

        _, _, restored_img = self.gfpgan.enhance(
            bgr_image_array,
@ -71,7 +70,7 @@ class GFPGAN():
        )

        # Flip the channels back to RGB
        res = Image.fromarray(restored_img[...,::-1])
        res = Image.fromarray(restored_img[..., ::-1])

        if strength < 1.0:
            # Resize the image to the new image if the sizes have changed
@ -79,7 +78,6 @@ class GFPGAN():
            image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.gfpgan = None
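A hedged usage sketch for the wrapper above (not part of this diff; the module path and the assumption that `process()` returns the blended PIL image are inferences from the code shown):

```python
from PIL import Image

from invokeai.backend.restoration.gfpgan import GFPGAN  # assumed path

restorer = GFPGAN()  # expects models/gfpgan/GFPGANv1.4.pth under Globals.root
if restorer.gfpgan_model_exists:
    img = Image.open("portrait.png")
    restored = restorer.process(img, strength=0.8, seed=None)
    restored.save("portrait_restored.png")
```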
119
invokeai/backend/restoration/outcrop.py
Normal file
@ -0,0 +1,119 @@
import math
import warnings

from PIL import Image, ImageFilter


class Outcrop(object):
    def __init__(
        self,
        image,
        generate,  # current generate object
    ):
        self.image = image
        self.generate = generate

    def process(
        self,
        extents: dict,
        opt,  # current options
        orig_opt,  # ones originally used to generate the image
        image_callback=None,
        prefix=None,
    ):
        # grow and mask the image
        extended_image = self._extend_all(extents)

        # switch samplers temporarily
        curr_sampler = self.generate.sampler
        self.generate.sampler_name = opt.sampler_name
        self.generate._set_sampler()

        def wrapped_callback(img, seed, **kwargs):
            preferred_seed = (
                orig_opt.seed
                if orig_opt.seed is not None and orig_opt.seed >= 0
                else seed
            )
            image_callback(img, preferred_seed, use_prefix=prefix, **kwargs)

        result = self.generate.prompt2image(
            opt.prompt,
            seed=opt.seed or orig_opt.seed,
            sampler=self.generate.sampler,
            steps=opt.steps,
            cfg_scale=opt.cfg_scale,
            ddim_eta=self.generate.ddim_eta,
            width=extended_image.width,
            height=extended_image.height,
            init_img=extended_image,
            strength=0.90,
            image_callback=wrapped_callback if image_callback else None,
            seam_size=opt.seam_size or 96,
            seam_blur=opt.seam_blur or 16,
            seam_strength=opt.seam_strength or 0.7,
            seam_steps=20,
            tile_size=32,
            color_match=True,
            force_outpaint=True,  # this just stops the warning about erased regions
        )

        # swap sampler back
        self.generate.sampler = curr_sampler
        return result

    def _extend_all(
        self,
        extents: dict,
    ) -> Image:
        """
        Extend the image in direction ('top','bottom','left','right') by
        the indicated value. The image canvas is extended, and the empty
        rectangular section will be filled with a blurred copy of the
        adjacent image.
        """
        image = self.image
        for direction in extents:
            assert direction in [
                "top",
                "left",
                "bottom",
                "right",
            ], 'Direction must be one of "top", "left", "bottom", "right"'
            pixels = extents[direction]
            # round pixels up to the nearest 64
            pixels = math.ceil(pixels / 64) * 64
            print(f">> extending image {direction}ward by {pixels} pixels")
            image = self._rotate(image, direction)
            image = self._extend(image, pixels)
            image = self._rotate(image, direction, reverse=True)
        return image

    def _rotate(self, image: Image, direction: str, reverse=False) -> Image:
        """
        Rotates image so that the area to extend is always at the top.
        Simplifies logic later. The reverse argument, if true, will undo the
        previous transpose.
        """
        transposes = {
            "right": ["ROTATE_90", "ROTATE_270"],
            "bottom": ["ROTATE_180", "ROTATE_180"],
            "left": ["ROTATE_270", "ROTATE_90"],
        }
        if direction not in transposes:
            return image
        transpose = transposes[direction][1 if reverse else 0]
        return image.transpose(Image.Transpose.__dict__[transpose])

    def _extend(self, image: Image, pixels: int) -> Image:
        extended_img = Image.new("RGBA", (image.width, image.height + pixels))

        extended_img.paste((0, 0, 0), [0, 0, image.width, image.height + pixels])
        extended_img.paste(image, box=(0, pixels))

        # now make the top part transparent to use as a mask
        alpha = extended_img.getchannel("A")
        alpha.paste(0, (0, 0, extended_img.width, pixels))
        extended_img.putalpha(alpha)

        return extended_img
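The `_rotate`/`_extend` pair reduces every outcrop direction to "extend upward". A standalone sketch of that trick with plain Pillow (sizes are illustrative):

```python
from PIL import Image

img = Image.new("RGBA", (128, 128), (255, 0, 0, 255))
rotated = img.transpose(Image.Transpose.ROTATE_90)         # "right" becomes the top
extended = Image.new("RGBA", (rotated.width, rotated.height + 64))
extended.paste(rotated, box=(0, 64))                       # 64 blank rows to fill in
restored = extended.transpose(Image.Transpose.ROTATE_270)  # undo the rotation
```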
@ -1,39 +1,43 @@
import warnings
import math
import warnings

from PIL import Image, ImageFilter


class Outpaint(object):
    def __init__(self, image, generate):
        self.image = image
        self.generate = generate
        self.image = image
        self.generate = generate

    def process(self, opt, old_opt, image_callback = None, prefix = None):
    def process(self, opt, old_opt, image_callback=None, prefix=None):
        image = self._create_outpaint_image(self.image, opt.out_direction)

        seed = old_opt.seed
        seed = old_opt.seed
        prompt = old_opt.prompt

        def wrapped_callback(img,seed,**kwargs):
            image_callback(img,seed,use_prefix=prefix,**kwargs)

        def wrapped_callback(img, seed, **kwargs):
            image_callback(img, seed, use_prefix=prefix, **kwargs)

        return self.generate.prompt2image(
            prompt,
            seed = seed,
            sampler = self.generate.sampler,
            steps = opt.steps,
            cfg_scale = opt.cfg_scale,
            ddim_eta = self.generate.ddim_eta,
            width = opt.width,
            height = opt.height,
            init_img = image,
            strength = 0.83,
            image_callback = wrapped_callback,
            prefix = prefix,
            seed=seed,
            sampler=self.generate.sampler,
            steps=opt.steps,
            cfg_scale=opt.cfg_scale,
            ddim_eta=self.generate.ddim_eta,
            width=opt.width,
            height=opt.height,
            init_img=image,
            strength=0.83,
            image_callback=wrapped_callback,
            prefix=prefix,
        )

    def _create_outpaint_image(self, image, direction_args):
        assert len(direction_args) in [1, 2], 'Direction (-D) must have exactly one or two arguments.'
        assert len(direction_args) in [
            1,
            2,
        ], "Direction (-D) must have exactly one or two arguments."

        if len(direction_args) == 1:
            direction = direction_args[0]
@ -42,19 +46,26 @@ class Outpaint(object):
            direction = direction_args[0]
            pixels = int(direction_args[1])

        assert direction in ['top', 'left', 'bottom', 'right'], 'Direction (-D) must be one of "top", "left", "bottom", "right"'
        assert direction in [
            "top",
            "left",
            "bottom",
            "right",
        ], 'Direction (-D) must be one of "top", "left", "bottom", "right"'

        image = image.convert("RGBA")
        # we always extend top, but rotate to extend along the requested side
        if direction == 'left':
        if direction == "left":
            image = image.transpose(Image.Transpose.ROTATE_270)
        elif direction == 'bottom':
        elif direction == "bottom":
            image = image.transpose(Image.Transpose.ROTATE_180)
        elif direction == 'right':
        elif direction == "right":
            image = image.transpose(Image.Transpose.ROTATE_90)

        pixels = image.height//2 if pixels is None else int(pixels)
        assert 0 < pixels < image.height, 'Direction (-D) pixels length must be in the range 0 - image.size'
        pixels = image.height // 2 if pixels is None else int(pixels)
        assert (
            0 < pixels < image.height
        ), "Direction (-D) pixels length must be in the range 0 - image.size"

        # the top part of the image is taken from the source image mirrored
        # coordinates (0,0) are the upper left corner of an image
@ -74,19 +85,18 @@ class Outpaint(object):
        new_img.paste(bottom, (0, pixels))

        # create a 10% dither in the middle
        dither = min(image.height//10, pixels)
        dither = min(image.height // 10, pixels)
        for x in range(0, image.width, 2):
            for y in range(pixels - dither, pixels + dither):
                (r, g, b, a) = new_img.getpixel((x, y))
                new_img.putpixel((x, y), (r, g, b, 0))

        # let's rotate back again
        if direction == 'left':
        if direction == "left":
            new_img = new_img.transpose(Image.Transpose.ROTATE_90)
        elif direction == 'bottom':
        elif direction == "bottom":
            new_img = new_img.transpose(Image.Transpose.ROTATE_180)
        elif direction == 'right':
        elif direction == "right":
            new_img = new_img.transpose(Image.Transpose.ROTATE_270)

        return new_img
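The seam handling above mirrors the image edge into the new strip and punches a dithered band of transparent pixels around the join so the inpainting pass blends the two. A simplified sketch of the mirrored-strip step (sizes illustrative, not the shipped code):

```python
from PIL import Image, ImageOps

src = Image.open("photo.png").convert("RGBA")
pixels = 64
strip = ImageOps.flip(src.crop((0, 0, src.width, pixels)))  # mirrored top edge
canvas = Image.new("RGBA", (src.width, src.height + pixels))
canvas.paste(strip, (0, 0))     # mirrored content seeds the new region
canvas.paste(src, (0, pixels))  # original image below it
```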
@ -1,13 +1,15 @@
import torch
import warnings
import numpy as np
import os
import warnings

from ldm.invoke.globals import Globals
import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as ImageType

class ESRGAN():
from invokeai.backend.globals import Globals


class ESRGAN:
    def __init__(self, bg_tile_size=400) -> None:
        self.bg_tile_size = bg_tile_size

@ -22,12 +24,23 @@ class ESRGAN():
        else:
            use_half_precision = True

        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact

        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-x4v3.pth')
        wdn_model_path = os.path.join(Globals.root, 'models/realesrgan/realesr-general-wdn-x4v3.pth')
        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=4,
            act_type="prelu",
        )
        model_path = os.path.join(
            Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
        )
        wdn_model_path = os.path.join(
            Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
        )
        scale = 4

        bg_upsampler = RealESRGANer(
@ -43,41 +56,49 @@ class ESRGAN():

        return bg_upsampler

    def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2, denoise_str: float = 0.75):
    def process(
        self,
        image: ImageType,
        strength: float,
        seed: str = None,
        upsampler_scale: int = 2,
        denoise_str: float = 0.75,
    ):
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=DeprecationWarning)
            warnings.filterwarnings('ignore', category=UserWarning)
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            warnings.filterwarnings("ignore", category=UserWarning)

            try:
                upsampler = self.load_esrgan_bg_upsampler(denoise_str)
            except Exception:
                import traceback
                import sys
                print('>> Error loading Real-ESRGAN:', file=sys.stderr)
                import traceback

                print(">> Error loading Real-ESRGAN:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        if upsampler_scale == 0:
            print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
            print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
            return image

        if seed is not None:
            print(
                f'>> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}'
                f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
            )
        # ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
        image = image.convert("RGB")

        # REALSRGAN expects a BGR np array; make array and flip channels
        bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
        bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]

        output, _ = upsampler.enhance(
            bgr_image_array,
            outscale=upsampler_scale,
            alpha_upsampler='realesrgan',
            alpha_upsampler="realesrgan",
        )

        # Flip the channels back to RGB
        res = Image.fromarray(output[...,::-1])
        res = Image.fromarray(output[..., ::-1])

        if strength < 1.0:
            # Resize the image to the new image if the sizes have changed
@ -1,23 +1,27 @@
'''
"""
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

'''
"""
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY


def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


@torch.jit.script
def swish(x):
    return x*torch.sigmoid(x)
    return x * torch.sigmoid(x)


# Define VQVAE classes
@ -28,7 +32,9 @@ class VectorQuantizer(nn.Module):
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
        self.embedding.weight.data.uniform_(
            -1.0 / self.codebook_size, 1.0 / self.codebook_size
        )

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
@ -36,23 +42,32 @@ class VectorQuantizer(nn.Module):
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())
        d = (
            (z_flattened**2).sum(dim=1, keepdim=True)
            + (self.embedding.weight**2).sum(1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        mean_distance = torch.mean(d)
        # find closest encodings
        # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        min_encoding_scores, min_encoding_indices = torch.topk(
            d, 1, dim=1, largest=False
        )
        # [0-1], higher score, higher confidence
        min_encoding_scores = torch.exp(-min_encoding_scores/10)
        min_encoding_scores = torch.exp(-min_encoding_scores / 10)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.codebook_size
        ).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
            (z_q - z.detach()) ** 2
        )
        # preserve gradients
        z_q = z + (z_q - z).detach()

@ -62,18 +77,22 @@ class VectorQuantizer(nn.Module):
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "min_encoding_scores": min_encoding_scores,
            "mean_distance": mean_distance
        }
        return (
            z_q,
            loss,
            {
                "perplexity": perplexity,
                "min_encodings": min_encodings,
                "min_encoding_indices": min_encoding_indices,
                "min_encoding_scores": min_encoding_scores,
                "mean_distance": mean_distance,
            },
        )

    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1,1)
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
@ -86,14 +105,24 @@ class VectorQuantizer(nn.Module):


class GumbelQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
    def __init__(
        self,
        codebook_size,
        emb_dim,
        num_hiddens,
        straight_through=False,
        kl_weight=5e-4,
        temp_init=1.0,
    ):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.proj = nn.Conv2d(
            num_hiddens, codebook_size, 1
        )  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
@ -107,18 +136,21 @@ class GumbelQuantizer(nn.Module):

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        diff = (
            self.kl_weight
            * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        )
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
        }
        return z_q, diff, {"min_encoding_indices": min_encoding_indices}


class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        pad = (0, 1, 0, 1)
@ -130,7 +162,9 @@ class Downsample(nn.Module):
class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
@ -145,11 +179,17 @@ class ResBlock(nn.Module):
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = normalize(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            self.conv_out = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x_in):
        x = x_in
@ -172,32 +212,16 @@ class AttnBlock(nn.Module):

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
@ -209,26 +233,35 @@ class AttnBlock(nn.Module):

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h*w)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x+h_
        return x + h_


class Encoder(nn.Module):
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
    def __init__(
        self,
        in_channels,
        nf,
        emb_dim,
        ch_mult,
        num_res_blocks,
        resolution,
        attn_resolutions,
    ):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
@ -237,7 +270,7 @@ class Encoder(nn.Module):
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        in_ch_mult = (1,) + tuple(ch_mult)

        blocks = []
        # initial convultion
@ -264,7 +297,9 @@ class Encoder(nn.Module):

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        blocks.append(
            nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)
        )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
@ -286,11 +321,13 @@ class Generator(nn.Module):
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)
        curr_res = self.resolution // 2 ** (self.num_resolutions - 1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
        blocks.append(
            nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
        )

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
@ -312,11 +349,14 @@ class Generator(nn.Module):
            curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
        blocks.append(
            nn.Conv2d(
                block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
@ -326,8 +366,21 @@ class Generator(nn.Module):

@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
    def __init__(
        self,
        img_size,
        nf,
        ch_mult,
        quantizer="nearest",
        res_blocks=2,
        attn_resolutions=[16],
        codebook_size=1024,
        emb_dim=256,
        beta=0.25,
        gumbel_straight_through=False,
        gumbel_kl_weight=1e-8,
        model_path=None,
    ):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
@ -346,11 +399,13 @@ class VQAutoEncoder(nn.Module):
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
            self.attn_resolutions,
        )
        if self.quantizer_type == "nearest":
            self.beta = beta #0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(
                self.codebook_size, self.embed_dim, self.beta
            )
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
@ -360,7 +415,7 @@ class VQAutoEncoder(nn.Module):
            self.embed_dim,
            self.gumbel_num_hiddens,
            self.straight_through,
            self.kl_weight
            self.kl_weight,
        )
        self.generator = Generator(
            self.nf,
@ -368,20 +423,23 @@ class VQAutoEncoder(nn.Module):
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
            self.attn_resolutions,
        )

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            chkpt = torch.load(model_path, map_location="cpu")
            if "params_ema" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params_ema"]
                )
                logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
            elif "params" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params"]
                )
                logger.info(f"vqgan is loaded from: {model_path} [params]")
            else:
                raise ValueError(f'Wrong params!')

                raise ValueError(f"Wrong params!")

    def forward(self, x):
        x = self.encoder(x)
@ -390,46 +448,67 @@ class VQAutoEncoder(nn.Module):
        return x, codebook_loss, quant_stats


# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        layers = [
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            ndf_mult = min(2**n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.Conv2d(
                    ndf * ndf_mult_prev,
                    ndf * ndf_mult,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
                nn.LeakyReLU(0.2, True),
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)
        ndf_mult = min(2**n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.Conv2d(
                ndf * ndf_mult_prev,
                ndf * ndf_mult,
                kernel_size=4,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
            nn.LeakyReLU(0.2, True),
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
        ]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
            chkpt = torch.load(model_path, map_location="cpu")
            if "params_d" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params_d"]
                )
            elif "params" in chkpt:
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu")["params"]
                )
            else:
                raise ValueError(f'Wrong params!')
                raise ValueError(f"Wrong params!")

    def forward(self, x):
        return self.main(x)
16
invokeai/backend/stable_diffusion/__init__.py
Normal file
@ -0,0 +1,16 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .concepts_lib import HuggingFaceConceptsLibrary
from .diffusers_pipeline import (
    ConditioningData,
    PipelineIntermediateState,
    StableDiffusionGeneratorPipeline,
)
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.ddim import DDIMSampler
from .diffusion.ksampler import KSampler
from .diffusion.plms import PLMSSampler
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
from .textual_inversion_manager import TextualInversionManager
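These re-exports give downstream code a single import point, e.g.:

```python
from invokeai.backend.stable_diffusion import (
    KSampler,
    StableDiffusionGeneratorPipeline,
)
```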
@ -1,21 +1,22 @@
from inspect import isfunction
import math
from inspect import isfunction
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from torch import einsum, nn

from .diffusion import InvokeAICrossAttentionMixin
from .diffusionmodules.util import checkpoint

from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttentionMixin
from ldm.modules.diffusionmodules.util import checkpoint

def exists(val):
    return val is not None


def uniq(arr):
    return{el: True for el in arr}.keys()
    return {el: True for el in arr}.keys()


def default(val, d):
@ -47,19 +48,18 @@ class GEGLU(nn.Module):


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
@ -76,7 +76,9 @@ def zero_module(module):


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class LinearAttention(nn.Module):
@ -84,17 +86,21 @@ class LinearAttention(nn.Module):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)


@ -104,26 +110,18 @@ class SpatialSelfAttention(nn.Module):
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
@ -133,43 +131,45 @@ class SpatialSelfAttention(nn.Module):
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c)**(-0.5))
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x+h_
        return x + h_


def get_mem_free_total(device):
    #only on cuda
    # only on cuda
    if not torch.cuda.is_available():
        return None
    stats = torch.cuda.memory_stats(device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_active = stats["active_bytes.all.current"]
    mem_reserved = stats["reserved_bytes.all.current"]
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    return mem_free_total


class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        InvokeAICrossAttentionMixin.__init__(self)
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@ -177,8 +177,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
@ -190,7 +189,7 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        # don't apply scale twice
        cached_scale = self.scale
@ -198,29 +197,45 @@ class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
        r = self.get_invokeai_attention_mem_efficient(q, k, v)
        self.scale = cached_scale

        hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
        hidden_states = rearrange(r, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(hidden_states)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
    ):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None):
        x = x.contiguous() if x.device.type == 'mps' else x
        x = x.contiguous() if x.device.type == "mps" else x
        x += self.attn1(self.norm1(x.clone()))
        x += self.attn2(self.norm2(x.clone()), context=context)
        x += self.ff(self.norm3(x.clone()))
@ -235,29 +250,31 @@ class SpatialTransformer(nn.Module):
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):

    def __init__(
        self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
                )
                for d in range(depth)
            ]
        )

        self.proj_out = zero_module(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
@ -265,9 +282,9 @@ class SpatialTransformer(nn.Module):
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = rearrange(x, "b c h w -> b (h w) c")
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)
        return x + x_in
@@ -1,16 +1,13 @@
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import (
DiagonalGaussianDistribution,
)

from ldm.util import instantiate_from_config
from ..util import instantiate_from_config
from .diffusionmodules.model import Decoder, Encoder
from .distributions.distributions import DiagonalGaussianDistribution


class VQModel(pl.LightningModule):
@@ -22,7 +19,7 @@ class VQModel(pl.LightningModule):
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
@@ -46,27 +43,23 @@ class VQModel(pl.LightningModule):
remap=remap,
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
)

self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f'>> Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
print(f">> Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -79,30 +72,30 @@ class VQModel(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f'{context}: Switched to EMA weights')
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f'{context}: Restored training weights')
print(f"{context}: Restored training weights")

def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict']
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print('Deleting key {} from state_dict.'.format(k))
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f'Missing Keys: {missing}')
print(f'Unexpected Keys: {unexpected}')
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")

def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
@@ -140,11 +133,7 @@ class VQModel(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
@@ -156,7 +145,7 @@ class VQModel(pl.LightningModule):
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode='bicubic')
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x

@@ -175,7 +164,7 @@ class VQModel(pl.LightningModule):
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
split="train",
predicted_indices=ind,
)

@@ -197,7 +186,7 @@ class VQModel(pl.LightningModule):
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
split="train",
)
self.log_dict(
log_dict_disc,
@@ -211,12 +200,10 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(
batch, batch_idx, suffix='_ema'
)
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict

def _validation_step(self, batch, batch_idx, suffix=''):
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(
@@ -226,7 +213,7 @@ class VQModel(pl.LightningModule):
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
split="val" + suffix,
predicted_indices=ind,
)

@@ -237,12 +224,12 @@ class VQModel(pl.LightningModule):
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
split="val" + suffix,
predicted_indices=ind,
)
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(
f'val{suffix}/rec_loss',
f"val{suffix}/rec_loss",
rec_loss,
prog_bar=True,
logger=True,
@@ -251,7 +238,7 @@ class VQModel(pl.LightningModule):
sync_dist=True,
)
self.log(
f'val{suffix}/aeloss',
f"val{suffix}/aeloss",
aeloss,
prog_bar=True,
logger=True,
@@ -259,8 +246,8 @@ class VQModel(pl.LightningModule):
on_epoch=True,
sync_dist=True,
)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f'val{suffix}/rec_loss']
if version.parse(pl.__version__) >= version.parse("1.4.0"):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
@@ -268,8 +255,8 @@ class VQModel(pl.LightningModule):
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor * self.learning_rate
print('lr_d', lr_d)
print('lr_g', lr_g)
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
@@ -286,21 +273,17 @@ class VQModel(pl.LightningModule):
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)

print('Setting up LambdaLR scheduler...')
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(
opt_ae, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1,
"scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
{
'scheduler': LambdaLR(
opt_disc, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1,
"scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
]
return [opt_ae, opt_disc], scheduler
@@ -314,7 +297,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log['inputs'] = x
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
@@ -322,22 +305,20 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log['inputs'] = x
log['reconstructions'] = xrec
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
log["reconstructions_ema"] = xrec_ema
return log

def to_rgb(self, x):
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
@@ -372,7 +353,7 @@ class AutoencoderKL(pl.LightningModule):
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
image_key="image",
colorize_nlabels=None,
monitor=None,
):
@@ -381,34 +362,28 @@ class AutoencoderKL(pl.LightningModule):
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(
2 * ddconfig['z_channels'], 2 * embed_dim, 1
)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict']
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print('Deleting key {} from state_dict.'.format(k))
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f'Restored from {path}')
print(f"Restored from {path}")

def encode(self, x):
h = self.encoder(x)
@@ -434,11 +409,7 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x

def training_step(self, batch, batch_idx, optimizer_idx):
@@ -454,10 +425,10 @@ class AutoencoderKL(pl.LightningModule):
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
split="train",
)
self.log(
'aeloss',
"aeloss",
aeloss,
prog_bar=True,
logger=True,
@@ -482,11 +453,11 @@ class AutoencoderKL(pl.LightningModule):
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
split="train",
)

self.log(
'discloss',
"discloss",
discloss,
prog_bar=True,
logger=True,
@@ -512,7 +483,7 @@ class AutoencoderKL(pl.LightningModule):
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
split="val",
)

discloss, log_dict_disc = self.loss(
@@ -522,10 +493,10 @@ class AutoencoderKL(pl.LightningModule):
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
split="val",
)

self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
@@ -560,17 +531,15 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
log['inputs'] = x
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log

def to_rgb(self, x):
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
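The ema_scope() hunk above follows the usual store / copy / restore pattern for evaluating with exponential-moving-average weights. A simplified sketch of that pattern, written independently of LitEma (the function signature here is illustrative, not the class's actual method):

from contextlib import contextmanager

@contextmanager
def ema_scope(model_ema, model, context=None):
    model_ema.store(model.parameters())  # save the live training weights
    model_ema.copy_to(model)             # swap in the EMA weights
    if context is not None:
        print(f"{context}: Switched to EMA weights")
    try:
        yield
    finally:
        model_ema.restore(model.parameters())  # always put the training weights back
        if context is not None:
            print(f"{context}: Restored training weights")

The try/finally guarantees the training weights are restored even if validation raises, which is why validation_step can safely call self._validation_step inside the scope.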
@@ -8,32 +8,50 @@ import os
import re
import traceback
from typing import Callable
from urllib import request, error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals
from urllib import error as ul_error
from urllib import request

from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
ModelSearchArguments,
hf_hub_url,
)

from invokeai.backend.globals import Globals


class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):
'''
"""
Initialize the Concepts object. May optionally pass a root directory.
'''
"""
self.root = root or Globals.root
self.hf_api = HfApi()
self.local_concepts = dict()
self.concept_list = None
self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile('(<[\w\- >]+>)') # trigger is slightly less restrictive than HF concept name
self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_-
self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile(
"(<[\w\- >]+>)"
) # trigger is slightly less restrictive than HF concept name
self.match_concept = re.compile(
"<([\w\-]+)>"
) # HF concept name can only contain A-Za-z0-9_-

def list_concepts(self)->list:
'''
def list_concepts(self) -> list:
"""
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
Also adds local concepts in invokeai/embeddings folder.
'''
local_concepts_now = self.get_local_concepts(os.path.join(self.root, 'embeddings'))
local_concepts_to_add = set(local_concepts_now).difference(set(self.local_concepts))
"""
local_concepts_now = self.get_local_concepts(
os.path.join(self.root, "embeddings")
)
local_concepts_to_add = set(local_concepts_now).difference(
set(self.local_concepts)
)
self.local_concepts.update(local_concepts_now)

if self.concept_list is not None:
@@ -43,83 +61,96 @@ class HuggingFaceConceptsLibrary(object):
return self.concept_list
else:
try:
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
self.concept_list = [a.id.split('/')[1] for a in models]
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")
)
self.concept_list = [a.id.split("/")[1] for a in models]
# when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add))
except Exception as e:
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
print(
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
)
print(
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list

def get_concept_model_path(self, concept_name:str)->str:
'''
def get_concept_model_path(self, concept_name: str) -> str:
"""
Returns the path to the 'learned_embeds.bin' file in
the named concept. Returns None if invalid or cannot
be downloaded.
'''
"""
if not concept_name in self.list_concepts():
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")

def concept_to_trigger(self, concept_name:str)->str:
'''
def concept_to_trigger(self, concept_name: str) -> str:
"""
Given a concept name returns its trigger by looking in the
"token_identifier.txt" file.
'''
"""
if concept_name in self.triggers:
return self.triggers[concept_name]
elif self.concept_is_local(concept_name):
trigger = f'<{concept_name}>'
trigger = f"<{concept_name}>"
self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name
return trigger

file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
file = self.get_concept_file(
concept_name, "token_identifier.txt", local_only=True
)
if not file:
return None
with open(file,'r') as f:
with open(file, "r") as f:
trigger = f.readline()
trigger = trigger.strip()
self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name
return trigger

def trigger_to_concept(self, trigger:str)->str:
'''
def trigger_to_concept(self, trigger: str) -> str:
"""
Given a trigger phrase, maps it to the concept library name.
Only works if concept_to_trigger() has previously been called
on this library. There needs to be a persistent database for
this.
'''
concept = self.concept_names.get(trigger,None)
return f'<{concept}>' if concept else f'{trigger}'
"""
concept = self.concept_names.get(trigger, None)
return f"<{concept}>" if concept else f"{trigger}"

def replace_triggers_with_concepts(self, prompt:str)->str:
'''
def replace_triggers_with_concepts(self, prompt: str) -> str:
"""
Given a prompt string that contains <trigger> tags, replace these
tags with the concept name. The reason for this is so that the
concept names get stored in the prompt metadata. There is no
controlling of colliding triggers in the SD library, so it is
better to store the concept name (unique) than the concept trigger
(not necessarily unique!)
'''
"""
if not prompt:
return prompt
triggers = self.match_trigger.findall(prompt)
if not triggers:
return prompt

def do_replace(match)->str:
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
def do_replace(match) -> str:
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"

return self.match_trigger.sub(do_replace, prompt)

def replace_concepts_with_triggers(self,
prompt:str,
load_concepts_callback: Callable[[list], any],
excluded_tokens:list[str])->str:
'''
def replace_concepts_with_triggers(
self,
prompt: str,
load_concepts_callback: Callable[[list], any],
excluded_tokens: list[str],
) -> str:
"""
Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger.

@@ -128,20 +159,30 @@ class HuggingFaceConceptsLibrary(object):

`excluded_tokens` are any tokens that should not be replaced, typically because they
are trigger tokens from a locally-loaded embedding.
'''
"""
concepts = self.match_concept.findall(prompt)
if not concepts:
return prompt
load_concepts_callback(concepts)

def do_replace(match)->str:
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
return f'<{match.group(1)}>'
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
def do_replace(match) -> str:
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
return f"<{match.group(1)}>"
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"

return self.match_concept.sub(do_replace, prompt)

def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
if not (self.concept_is_downloaded(concept_name) or self.concept_is_local(concept_name) or local_only):
def get_concept_file(
self,
concept_name: str,
file_name: str = "learned_embeds.bin",
local_only: bool = False,
) -> str:
if not (
self.concept_is_downloaded(concept_name)
or self.concept_is_local(concept_name)
or local_only
):
self.download_concept(concept_name)

# get local path in invokeai/embeddings if local concept
@@ -153,19 +194,19 @@ class HuggingFaceConceptsLibrary(object):
path = os.path.join(concept_path, file_name)
return path if os.path.exists(path) else None

def concept_is_local(self, concept_name)->bool:
def concept_is_local(self, concept_name) -> bool:
return concept_name in self.local_concepts

def concept_is_downloaded(self, concept_name)->bool:
def concept_is_downloaded(self, concept_name) -> bool:
concept_directory = self._concept_path(concept_name)
return os.path.exists(concept_directory)

def download_concept(self,concept_name)->bool:
def download_concept(self, concept_name) -> bool:
repo_id = self._concept_id(concept_name)
dest = self._concept_path(concept_name)

access_token = HfFolder.get_token()
header = [("Authorization", f'Bearer {access_token}')] if access_token else []
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
opener = request.build_opener()
opener.addheaders = header
request.install_opener(opener)
@@ -174,45 +215,59 @@ class HuggingFaceConceptsLibrary(object):
succeeded = True

bytes = 0

def tally_download_size(chunk, size, total):
nonlocal bytes
if chunk==0:
if chunk == 0:
bytes += total

print(f'>> Downloading {repo_id}...',end='')
print(f">> Downloading {repo_id}...", end="")
try:
for file in ('README.md','learned_embeds.bin','token_identifier.txt','type_of_concept.txt'):
for file in (
"README.md",
"learned_embeds.bin",
"token_identifier.txt",
"type_of_concept.txt",
):
url = hf_hub_url(repo_id, file)
request.urlretrieve(url, os.path.join(dest,file),reporthook=tally_download_size)
request.urlretrieve(
url, os.path.join(dest, file), reporthook=tally_download_size
)
except ul_error.HTTPError as e:
if e.code==404:
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
if e.code == 404:
print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
print(f'Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)')
print(
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
)
os.rmdir(dest)
return False
except ul_error.URLError as e:
print(f'ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept.')
print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False
print('...{:.2f}Kb'.format(bytes/1024))
print("...{:.2f}Kb".format(bytes / 1024))
return succeeded

def _concept_id(self, concept_name:str)->str:
return f'sd-concepts-library/{concept_name}'
def _concept_id(self, concept_name: str) -> str:
return f"sd-concepts-library/{concept_name}"

def _concept_path(self, concept_name:str)->str:
return os.path.join(self.root,'models','sd-concepts-library',concept_name)
def _concept_path(self, concept_name: str) -> str:
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)

def _concept_local_path(self, concept_name:str)->str:
def _concept_local_path(self, concept_name: str) -> str:
filename = self.local_concepts[concept_name]
return os.path.join(self.root,'embeddings',filename)
return os.path.join(self.root, "embeddings", filename)

def get_local_concepts(self, loc_dir:str):
def get_local_concepts(self, loc_dir: str):
locs_dic = dict()
if os.path.isdir(loc_dir):
for file in os.listdir(loc_dir):
f = os.path.splitext(file)
if f[1] == '.bin' or f[1] == '.pt':
if f[1] == ".bin" or f[1] == ".pt":
locs_dic[f[0]] = file
return locs_dic
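The two compiled patterns at the top of the class differ in what they accept: match_trigger grabs a whole <trigger phrase> (spaces allowed), while match_concept captures only a bare concept name. A quick illustration with a hypothetical trigger:

import re

match_trigger = re.compile(r"(<[\w\- >]+>)")  # whole trigger, spaces allowed
match_concept = re.compile(r"<([\w\-]+)>")    # bare concept name only

prompt = "a painting of <my-style> at dusk"
print(match_trigger.findall(prompt))  # ['<my-style>']
print(match_concept.findall(prompt))  # ['my-style']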
@@ -1,10 +1,6 @@
from abc import abstractmethod
from torch.utils.data import (
Dataset,
ConcatDataset,
ChainDataset,
IterableDataset,
)

from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
@@ -19,9 +15,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
self.sample_ids = valid_ids
self.size = size

print(
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")

def __len__(self):
return self.num_records
@@ -1,31 +1,32 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
import glob
import os
import pickle
import shutil
import tarfile
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import albumentations
import cv2
import numpy as np
import PIL
import taming.data.utils as tdu
import torchvision.transforms.functional as TF
import yaml
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from omegaconf import OmegaConf
from PIL import Image
from taming.data.imagenet import (
str_to_indices,
give_synsets_from_indices,
ImagePaths,
download,
give_synsets_from_indices,
retrieve,
str_to_indices,
)
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import (
degradation_fn_bsr,
degradation_fn_bsr_light,
)
from torch.utils.data import Dataset, Subset
from tqdm import tqdm


def synset2idx(path_to_yaml='data/index_synset.yaml'):
def synset2idx(path_to_yaml="data/index_synset.yaml"):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v, k) for k, v in di2s.items())
@@ -36,9 +37,7 @@ class ImageNetBase(Dataset):
self.config = config or OmegaConf.create()
if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get(
'keep_orig_class_label', False
)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare()
self._prepare_synset_to_human()
@@ -58,21 +57,19 @@ class ImageNetBase(Dataset):
def _filter_relpaths(self, relpaths):
ignore = set(
[
'n06596364_9591.JPEG',
"n06596364_9591.JPEG",
]
)
relpaths = [
rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore
]
if 'sub_indices' in self.config:
indices = str_to_indices(self.config['sub_indices'])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(
indices, path_to_yaml=self.idx2syn
) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
syn = rpath.split('/')[0]
syn = rpath.split("/")[0]
if syn in synsets:
files.append(rpath)
return files
@@ -81,8 +78,8 @@ class ImageNetBase(Dataset):

def _prepare_synset_to_human(self):
SIZE = 2655750
URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
self.human_dict = os.path.join(self.root, 'synset_human.txt')
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (
not os.path.exists(self.human_dict)
or not os.path.getsize(self.human_dict) == SIZE
@@ -90,64 +87,62 @@ class ImageNetBase(Dataset):
download(URL, self.human_dict)

def _prepare_idx_to_synset(self):
URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if not os.path.exists(self.idx2syn):
download(URL, self.idx2syn)

def _prepare_human_to_integer_label(self):
URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(
self.root, 'imagenet1000_clsidx_to_labels.txt'
self.root, "imagenet1000_clsidx_to_labels.txt"
)
if not os.path.exists(self.human2integer):
download(URL, self.human2integer)
with open(self.human2integer, 'r') as f:
with open(self.human2integer, "r") as f:
lines = f.read().splitlines()
assert len(lines) == 1000
self.human2integer_dict = dict()
for line in lines:
value, key = line.split(':')
value, key = line.split(":")
self.human2integer_dict[key] = int(value)

def _load(self):
with open(self.txt_filelist, 'r') as f:
with open(self.txt_filelist, "r") as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print(
'Removed {} files from filelist during filtering.'.format(
"Removed {} files from filelist during filtering.".format(
l1 - len(self.relpaths)
)
)

self.synsets = [p.split('/')[0] for p in self.relpaths]
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

unique_synsets = np.unique(self.synsets)
class_dict = dict(
(synset, i) for i, synset in enumerate(unique_synsets)
)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets]
else:
self.class_labels = [self.synset2idx[s] for s in self.synsets]

with open(self.human_dict, 'r') as f:
with open(self.human_dict, "r") as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)

self.human_labels = [human_dict[s] for s in self.synsets]

labels = {
'relpath': np.array(self.relpaths),
'synsets': np.array(self.synsets),
'class_label': np.array(self.class_labels),
'human_label': np.array(self.human_labels),
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
}

if self.process_images:
self.size = retrieve(self.config, 'size', default=256)
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(
self.abspaths,
labels=labels,
@@ -159,11 +154,11 @@ class ImageNetBase(Dataset):


class ImageNetTrain(ImageNetBase):
NAME = 'ILSVRC2012_train'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
FILES = [
'ILSVRC2012_img_train.tar',
"ILSVRC2012_img_train.tar",
]
SIZES = [
147897477120,
@@ -178,20 +173,18 @@ class ImageNetTrain(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 1281167
self.random_crop = retrieve(
self.config, 'ImageNetTrain/random_crop', default=True
self.config, "ImageNetTrain/random_crop", default=True
)
if not tdu.is_prepared(self.root):
# prep
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
print("Preparing dataset {} in {}".format(self.NAME, self.root))

datadir = self.datadir
if not os.path.exists(datadir):
@@ -205,37 +198,37 @@ class ImageNetTrain(ImageNetBase):
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path

print('Extracting {} to {}'.format(path, datadir))
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, 'r:') as tar:
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)

print('Extracting sub-tars.')
subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
for subpath in tqdm(subpaths):
subdir = subpath[: -len('.tar')]
subdir = subpath[: -len(".tar")]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, 'r:') as tar:
with tarfile.open(subpath, "r:") as tar:
tar.extractall(path=subdir)

filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)

tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
NAME = 'ILSVRC2012_validation'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
FILES = [
'ILSVRC2012_img_val.tar',
'validation_synset.txt',
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
]
SIZES = [
6744924160,
@@ -251,19 +244,17 @@ class ImageNetValidation(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 50000
self.random_crop = retrieve(
self.config, 'ImageNetValidation/random_crop', default=False
self.config, "ImageNetValidation/random_crop", default=False
)
if not tdu.is_prepared(self.root):
# prep
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
print("Preparing dataset {} in {}".format(self.NAME, self.root))

datadir = self.datadir
if not os.path.exists(datadir):
@@ -277,9 +268,9 @@ class ImageNetValidation(ImageNetBase):
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path

print('Extracting {} to {}'.format(path, datadir))
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, 'r:') as tar:
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)

vspath = os.path.join(self.root, self.FILES[1])
@@ -289,11 +280,11 @@ class ImageNetValidation(ImageNetBase):
):
download(self.VS_URL, vspath)

with open(vspath, 'r') as f:
with open(vspath, "r") as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)

print('Reorganizing into synset folders')
print("Reorganizing into synset folders")
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
@@ -302,11 +293,11 @@ class ImageNetValidation(ImageNetBase):
dst = os.path.join(datadir, v)
shutil.move(src, dst)

filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)

tdu.mark_prepared(self.root)
@@ -356,32 +347,28 @@ class ImageNetSR(Dataset):
False # gets reset later if incase interp_op is from pillow
)

if degradation == 'bsrgan':
self.degradation_process = partial(
degradation_fn_bsr, sf=downscale_f
)
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

elif degradation == 'bsrgan_light':
self.degradation_process = partial(
degradation_fn_bsr_light, sf=downscale_f
)
elif degradation == "bsrgan_light":
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

else:
interpolation_fn = {
'cv_nearest': cv2.INTER_NEAREST,
'cv_bilinear': cv2.INTER_LINEAR,
'cv_bicubic': cv2.INTER_CUBIC,
'cv_area': cv2.INTER_AREA,
'cv_lanczos': cv2.INTER_LANCZOS4,
'pil_nearest': PIL.Image.NEAREST,
'pil_bilinear': PIL.Image.BILINEAR,
'pil_bicubic': PIL.Image.BICUBIC,
'pil_box': PIL.Image.BOX,
'pil_hamming': PIL.Image.HAMMING,
'pil_lanczos': PIL.Image.LANCZOS,
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
}[degradation]

self.pil_interpolation = degradation.startswith('pil_')
self.pil_interpolation = degradation.startswith("pil_")

if self.pil_interpolation:
self.degradation_process = partial(
@@ -400,10 +387,10 @@ class ImageNetSR(Dataset):

def __getitem__(self, i):
example = self.base[i]
image = Image.open(example['file_path_'])
image = Image.open(example["file_path_"])

if not image.mode == 'RGB':
image = image.convert('RGB')
if not image.mode == "RGB":
image = image.convert("RGB")

image = np.array(image).astype(np.uint8)

@@ -423,8 +410,8 @@ class ImageNetSR(Dataset):
height=crop_side_len, width=crop_side_len
)

image = self.cropper(image=image)['image']
image = self.image_rescaler(image=image)['image']
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]

if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image)
@@ -432,10 +419,10 @@ class ImageNetSR(Dataset):
LR_image = np.array(LR_image).astype(np.uint8)

else:
LR_image = self.degradation_process(image=image)['image']
LR_image = self.degradation_process(image=image)["image"]

example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)

return example

@@ -445,7 +432,7 @@ class ImageNetSRTrain(ImageNetSR):
super().__init__(**kwargs)

def get_base(self):
with open('data/imagenet_train_hr_indices.p', 'rb') as f:
with open("data/imagenet_train_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetTrain(
process_images=False,
@@ -458,7 +445,7 @@ class ImageNetSRValidation(ImageNetSR):
super().__init__(**kwargs)

def get_base(self):
with open('data/imagenet_val_hr_indices.p', 'rb') as f:
with open("data/imagenet_val_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetValidation(
process_images=False,
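The (x / 127.5 - 1.0) expression that recurs in these datasets maps uint8 pixels in [0, 255] onto the [-1.0, 1.0] range the models consume. A quick check:

import numpy as np

pixels = np.array([0, 127, 255], dtype=np.uint8)
print((pixels / 127.5 - 1.0).astype(np.float32))  # approximately [-1.0, -0.0039, 1.0]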
@@ -1,4 +1,5 @@
import os

import numpy as np
import PIL
from PIL import Image
@@ -12,27 +13,25 @@ class LSUNBase(Dataset):
txt_file,
data_root,
size=None,
interpolation='bicubic',
interpolation="bicubic",
flip_p=0.5,
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, 'r') as f:
with open(self.data_paths, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
'relative_file_path_': [l for l in self.image_paths],
'file_path_': [
os.path.join(self.data_root, l) for l in self.image_paths
],
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l) for l in self.image_paths],
}

self.size = size
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)

@@ -41,14 +40,17 @@ class LSUNBase(Dataset):

def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example['file_path_'])
if not image.mode == 'RGB':
image = image.convert('RGB')
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")

# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = (
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
@@ -59,68 +61,64 @@ class LSUNBase(Dataset):

image = Image.fromarray(img)
if self.size is not None:
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = image.resize((self.size, self.size), resample=self.interpolation)

image = self.flip(image)
image = np.array(image).astype(np.uint8)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example


class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(
txt_file='data/lsun/church_outdoor_train.txt',
data_root='data/lsun/churches',
**kwargs
txt_file="data/lsun/church_outdoor_train.txt",
data_root="data/lsun/churches",
**kwargs,
)


class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/church_outdoor_val.txt',
data_root='data/lsun/churches',
txt_file="data/lsun/church_outdoor_val.txt",
data_root="data/lsun/churches",
flip_p=flip_p,
**kwargs
**kwargs,
)


class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(
txt_file='data/lsun/bedrooms_train.txt',
data_root='data/lsun/bedrooms',
**kwargs
txt_file="data/lsun/bedrooms_train.txt",
data_root="data/lsun/bedrooms",
**kwargs,
)


class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/bedrooms_val.txt',
data_root='data/lsun/bedrooms',
txt_file="data/lsun/bedrooms_val.txt",
data_root="data/lsun/bedrooms",
flip_p=flip_p,
**kwargs
**kwargs,
)


class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(
txt_file='data/lsun/cat_train.txt',
data_root='data/lsun/cats',
**kwargs
txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs
)


class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/cat_val.txt',
data_root='data/lsun/cats',
txt_file="data/lsun/cat_val.txt",
data_root="data/lsun/cats",
flip_p=flip_p,
**kwargs
**kwargs,
)
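The square center-crop arithmetic in __getitem__ above trims (h - crop) // 2 rows from the top and (w - crop) // 2 columns from the left, leaving a centered crop-by-crop window. A tiny check with made-up sizes:

import numpy as np

img = np.zeros((100, 60, 3), dtype=np.uint8)
h, w = img.shape[0], img.shape[1]
crop = min(h, w)  # 60
square = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
assert square.shape == (60, 60, 3)  # centered 60x60 window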
199 invokeai/backend/stable_diffusion/data/personalized.py (new file)
@@ -0,0 +1,199 @@
import os
import random

import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

imagenet_templates_smallest = [
"a photo of a {}",
]

imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]

imagenet_dual_templates_small = [
"a photo of a {} with {}",
"a rendering of a {} with {}",
"a cropped photo of the {} with {}",
"the photo of a {} with {}",
"a photo of a clean {} with {}",
"a photo of a dirty {} with {}",
"a dark photo of the {} with {}",
"a photo of my {} with {}",
"a photo of the cool {} with {}",
"a close-up photo of a {} with {}",
"a bright photo of the {} with {}",
"a cropped photo of a {} with {}",
"a photo of the {} with {}",
"a good photo of the {} with {}",
"a photo of one {} with {}",
"a close-up photo of the {} with {}",
"a rendition of the {} with {}",
"a photo of the clean {} with {}",
"a rendition of a {} with {}",
"a photo of a nice {} with {}",
"a good photo of a {} with {}",
"a photo of the nice {} with {}",
"a photo of the small {} with {}",
"a photo of the weird {} with {}",
"a photo of the large {} with {}",
"a photo of a cool {} with {}",
"a photo of a small {} with {}",
]

per_img_token_list = [
"א",
"ב",
"ג",
"ד",
"ה",
"ו",
"ז",
"ח",
"ט",
"י",
"כ",
"ל",
"מ",
"נ",
"ס",
"ע",
"פ",
"צ",
"ק",
"ר",
"ש",
"ת",
]


class PersonalizedBase(Dataset):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
self.data_root = data_root

self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
if file_path != ".DS_Store"
]

# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images

self.placeholder_token = placeholder_token

self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob

self.coarse_class_text = coarse_class_text

if per_image_tokens:
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

if set == "train":
self._length = self.num_images * repeats

self.size = size
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)

def __len__(self):
return self._length

def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])

if not image.mode == "RGB":
image = image.convert("RGB")

placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"

if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(
placeholder_string, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(placeholder_string)

example["caption"] = text

# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)

if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]

image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)

image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
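At __getitem__ time a caption is built by formatting one of the templates above with the placeholder token, optionally prefixed by coarse_class_text. A self-contained sketch with hypothetical values:

import random

templates = ["a photo of a {}", "a rendering of a {}"]  # abbreviated from the list above
placeholder_token = "*"
coarse_class_text = "dog"  # hypothetical class prefix
placeholder_string = f"{coarse_class_text} {placeholder_token}"
print(random.choice(templates).format(placeholder_string))  # e.g. "a photo of a dog *"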
170 invokeai/backend/stable_diffusion/data/personalized_style.py (new file)
@ -0,0 +1,170 @@
import os
import random

import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

imagenet_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

imagenet_dual_templates_small = [
    "a painting in the style of {} with {}",
    "a rendering in the style of {} with {}",
    "a cropped painting in the style of {} with {}",
    "the painting in the style of {} with {}",
    "a clean painting in the style of {} with {}",
    "a dirty painting in the style of {} with {}",
    "a dark painting in the style of {} with {}",
    "a cool painting in the style of {} with {}",
    "a close-up painting in the style of {} with {}",
    "a bright painting in the style of {} with {}",
    "a cropped painting in the style of {} with {}",
    "a good painting in the style of {} with {}",
    "a painting of one {} in the style of {}",
    "a nice painting in the style of {} with {}",
    "a small painting in the style of {} with {}",
    "a weird painting in the style of {} with {}",
    "a large painting in the style of {} with {}",
]

per_img_token_list = [
    "א",
    "ב",
    "ג",
    "ד",
    "ה",
    "ו",
    "ז",
    "ח",
    "ט",
    "י",
    "כ",
    "ל",
    "מ",
    "נ",
    "ס",
    "ע",
    "פ",
    "צ",
    "ק",
    "ר",
    "ש",
    "ת",
]


class PersonalizedBase(Dataset):
    def __init__(
        self,
        data_root,
        size=None,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        per_image_tokens=False,
        center_crop=False,
    ):
        self.data_root = data_root

        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
            if file_path != ".DS_Store"
        ]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token

        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop

        if per_image_tokens:
            assert self.num_images < len(
                per_img_token_list
            ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        self.interpolation = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        if self.per_image_tokens and np.random.uniform() < 0.25:
            text = random.choice(imagenet_dual_templates_small).format(
                self.placeholder_token, per_img_token_list[i % self.num_images]
            )
        else:
            text = random.choice(imagenet_templates_small).format(
                self.placeholder_token
            )

        example["caption"] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            (
                h,
                w,
            ) = (
                img.shape[0],
                img.shape[1],
            )
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
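Both dataset classes expose the standard torch `Dataset` interface and return dicts with an `"image"` key (float32 HWC array in [-1, 1]) and a `"caption"` key. A minimal sketch of how such a dataset might be driven during textual-inversion training (the `data_root` path is a placeholder, not part of this commit):

from torch.utils.data import DataLoader

# hypothetical usage of the style dataset defined above
dataset = PersonalizedBase(
    data_root="/path/to/style/images",  # placeholder path
    size=512,
    repeats=100,
    placeholder_token="*",
)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape)   # torch.Size([4, 512, 512, 3]) -- HWC, not CHW
print(batch["caption"][:2])   # e.g. ['a painting in the style of *', ...]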
@ -2,22 +2,28 @@ from __future__ import annotations

import dataclasses
import inspect
import psutil
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union

import PIL.Image
import einops
import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    StableDiffusionImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
@ -26,13 +32,16 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec

from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..devices import normalize_device, CPU_DEVICE
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from compel import EmbeddingsProvider
from invokeai.backend.globals import Globals

from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
    AttentionMapSaver,
    InvokeAIDiffuserComponent,
    PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager


@dataclass
@ -51,7 +60,7 @@ _default_personalization_config_params = dict(
    initializer_words=["sculpture"],
    per_image_tokens=False,
    num_vectors_per_token=1,
    progressive_words=False
    progressive_words=False,
)


@ -64,29 +73,34 @@ class AddsMaskLatents:

    This class assumes the same mask and base image should apply to all items in the batch.
    """

    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
    mask: torch.Tensor
    initial_image_latents: torch.Tensor

    def __call__(self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
    def __call__(
        self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor
    ) -> torch.Tensor:
        model_input = self.add_mask_channels(latents)
        return self.forward(model_input, t, text_embeddings)

    def add_mask_channels(self, latents):
        batch_size = latents.size(0)
        # duplicate mask and latents for each batch
        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
        image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
        mask = einops.repeat(
            self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
        )
        image_latents = einops.repeat(
            self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
        )
        # add mask and image as additional channels
        model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
        model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
        return model_input


def are_like_tensors(a: torch.Tensor, b: object) -> bool:
    return (
        isinstance(b, torch.Tensor)
        and (a.size() == b.size())
    )
    return isinstance(b, torch.Tensor) and (a.size() == b.size())


@dataclass
class AddsMaskGuidance:
@ -96,7 +110,9 @@ class AddsMaskGuidance:
    noise: torch.Tensor
    _debug: Optional[Callable] = None

    def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
    def __call__(
        self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning
    ) -> BaseOutput:
        output_class = step_output.__class__  # We'll create a new one with masked data.

        # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
@ -106,30 +122,41 @@ class AddsMaskGuidance:
        prev_sample = step_output[0]
        # Mask anything that has the same shape as prev_sample, return others as-is.
        return output_class(
            {k: (self.apply_mask(v, self._t_for_field(k, t))
                 if are_like_tensors(prev_sample, v) else v)
            for k, v in step_output.items()}
            {
                k: (
                    self.apply_mask(v, self._t_for_field(k, t))
                    if are_like_tensors(prev_sample, v)
                    else v
                )
                for k, v in step_output.items()
            }
        )

    def _t_for_field(self, field_name:str, t):
    def _t_for_field(self, field_name: str, t):
        if field_name == "pred_original_sample":
            return torch.zeros_like(t, dtype=t.dtype)  # it represents t=0
        return t

    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
        batch_size = latents.size(0)
        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
        mask = einops.repeat(
            self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
        )
        if t.dim() == 0:
            # some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, '-> batch', batch=batch_size)
            t = einops.repeat(t, "-> batch", batch=batch_size)
        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
        # get very confused about what is happening from step to step when we do that.
        mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
        mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
        mask_latents = einops.repeat(
            mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
        )
        masked_input = torch.lerp(
            mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)
        )
        if self._debug:
            self._debug(masked_input, f"t={t} lerped")
        return masked_input
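The heart of `AddsMaskGuidance.apply_mask` is the `torch.lerp` call: where the mask is 0 the (noised) original-image latents are kept, and where it is 1 the freshly denoised latents pass through. A toy illustration of that blend, independent of the diff above:

import torch

latents = torch.full((1, 1, 2, 2), 5.0)    # stand-in for denoised latents
mask_latents = torch.zeros((1, 1, 2, 2))   # stand-in for noised original latents
mask = torch.tensor([[[[0.0, 1.0], [0.0, 1.0]]]])

# lerp(a, b, w) = a + w * (b - a): w=0 keeps mask_latents, w=1 keeps latents
blended = torch.lerp(mask_latents, latents, mask)
print(blended)  # 0.0 where mask==0, 5.0 where mask==1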
@ -139,7 +166,9 @@ def trim_to_multiple_of(*args, multiple_of=8):
    return tuple((x - x % multiple_of) for x in args)


def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
def image_resized_to_grid_as_tensor(
    image: PIL.Image.Image, normalize: bool = True, multiple_of=8
) -> torch.FloatTensor:
    """

    :param image: input image
@ -147,10 +176,12 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
    :param multiple_of: resize the input so both dimensions are a multiple of this
    """
    w, h = trim_to_multiple_of(*image.size)
    transformation = T.Compose([
        T.Resize((h, w), T.InterpolationMode.LANCZOS),
        T.ToTensor(),
    ])
    transformation = T.Compose(
        [
            T.Resize((h, w), T.InterpolationMode.LANCZOS),
            T.ToTensor(),
        ]
    )
    tensor = transformation(image)
    if normalize:
        tensor = tensor * 2.0 - 1.0
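`trim_to_multiple_of` simply rounds each dimension down to the nearest multiple (8 by default), which keeps width and height compatible with the VAE's 8x downsampling. A quick check, not from the diff itself:

def trim_to_multiple_of(*args, multiple_of=8):
    return tuple((x - x % multiple_of) for x in args)

print(trim_to_multiple_of(517, 389))  # (512, 384)
print(trim_to_multiple_of(512, 512))  # (512, 512) -- already aligned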
@ -160,9 +191,11 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
def is_inpainting_model(unet: UNet2DConditionModel):
    return unet.conv_in.in_channels == 9

CallbackType = TypeVar('CallbackType')
ReturnType = TypeVar('ReturnType')
ParamType = ParamSpec('ParamType')

CallbackType = TypeVar("CallbackType")
ReturnType = TypeVar("ReturnType")
ParamType = ParamSpec("ParamType")


@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
@ -171,9 +204,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
    generator_method: Callable[ParamType, ReturnType]
    callback_arg_type: Type[CallbackType]

    def __call__(self, *args: ParamType.args,
                 callback:Callable[[CallbackType], Any]=None,
                 **kwargs: ParamType.kwargs) -> ReturnType:
    def __call__(
        self,
        *args: ParamType.args,
        callback: Callable[[CallbackType], Any] = None,
        **kwargs: ParamType.kwargs,
    ) -> ReturnType:
        result = None
        for result in self.generator_method(*args, **kwargs):
            if callback is not None and isinstance(result, self.callback_arg_type):
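`GeneratorToCallbackinator` wraps a generator method so that intermediate yields are forwarded to an optional callback while the final yielded value becomes the return value. A self-contained sketch of the same pattern (names here are illustrative, not from the commit):

def count_up(n):
    for i in range(n):
        yield i  # intermediate progress values

def run_with_callback(gen_method, callback, *args):
    result = None
    for result in gen_method(*args):
        if callback is not None:
            callback(result)   # observe each intermediate state
    return result              # last yielded value is the final result

final = run_with_callback(count_up, print, 3)  # prints 0, 1, 2
print("final:", final)                         # final: 2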
@ -218,6 +254,7 @@ class ConditioningData:
            scheduler_args[name] = value
        return dataclasses.replace(self, scheduler_args=scheduler_args)


@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
    r"""
@ -275,10 +312,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        safety_checker: Optional[StableDiffusionSafetyChecker],
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
        precision: str = 'float32',
        precision: str = "float32",
    ):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            requires_safety_checker,
        )

        self.register_modules(
            vae=vae,
@ -289,27 +334,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
        use_full_precision = (precision == 'float32' or precision == 'autocast')
        self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                 text_encoder=self.text_encoder,
                                                                 full_precision=use_full_precision)
        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            self.unet, self._unet_forward, is_running_diffusers=True
        )
        use_full_precision = precision == "float32" or precision == "autocast"
        self.textual_inversion_manager = TextualInversionManager(
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            full_precision=use_full_precision,
        )
        # InvokeAI's interface for text embeddings and whatnot
        self.embeddings_provider = EmbeddingsProvider(
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            textual_inversion_manager=self.textual_inversion_manager
            textual_inversion_manager=self.textual_inversion_manager,
        )

        self._model_group = FullyLoadedModelGroup(self.unet.device)
        self._model_group.install(*self._submodels)


    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
        """
        if xformers is available, use it, otherwise use sliced attention.
        """
        if torch.cuda.is_available() and is_xformers_available() and not Globals.disable_xformers:
        if (
            torch.cuda.is_available()
            and is_xformers_available()
            and not Globals.disable_xformers
        ):
            self.enable_xformers_memory_efficient_attention()
        else:
            if torch.backends.mps.is_available():
@ -318,25 +370,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
                pass
            else:
                if self.device.type == 'cpu' or self.device.type == 'mps':
                if self.device.type == "cpu" or self.device.type == "mps":
                    mem_free = psutil.virtual_memory().free
                elif self.device.type == 'cuda':
                elif self.device.type == "cuda":
                    mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
                else:
                    raise ValueError(f"unrecognized device {self.device}")
                # input tensor of [1, 4, h/8, w/8]
                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
                bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
                max_size_required_for_baddbmm = \
                    16 * \
                    latents.size(dim=2) * latents.size(dim=3) * latents.size(dim=2) * latents.size(dim=3) * \
                    bytes_per_element_needed_for_baddbmm_duplication
                if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
                    self.enable_attention_slicing(slice_size='max')
                bytes_per_element_needed_for_baddbmm_duplication = (
                    latents.element_size() + 4
                )
                max_size_required_for_baddbmm = (
                    16
                    * latents.size(dim=2)
                    * latents.size(dim=3)
                    * latents.size(dim=2)
                    * latents.size(dim=3)
                    * bytes_per_element_needed_for_baddbmm_duplication
                )
                if max_size_required_for_baddbmm > (
                    mem_free * 3.0 / 4.0
                ):  # 3.3 / 4.0 is from old Invoke code
                    self.enable_attention_slicing(slice_size="max")
                else:
                    self.disable_attention_slicing()

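The reflowed arithmetic above estimates the scratch memory a single baddbmm attention product would need: 16 heads times a (h/8 * w/8) squared attention matrix, at `element_size() + 4` bytes per element. A standalone sketch of the same estimate for a 512x512 image in fp16 (assumed values, for illustration only):

element_size = 2                 # fp16 latents
latent_h = latent_w = 512 // 8   # 64x64 latent grid for a 512x512 image
bytes_per_element = element_size + 4

max_size_required_for_baddbmm = (
    16 * (latent_h * latent_w) ** 2 * bytes_per_element
)
print(f"{max_size_required_for_baddbmm / 2**30:.1f} GiB")  # 1.5 GiB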
    def enable_offload_submodels(self, device: torch.device):
        """
        Offload each submodel when it's not in use.
@ -398,12 +457,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        values = [getattr(self, name) for name in module_names.keys()]
        return [m for m in values if isinstance(m, torch.nn.Module)]

    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                              conditioning_data: ConditioningData,
                              *,
                              noise: torch.Tensor,
                              callback: Callable[[PipelineIntermediateState], None]=None,
                              run_id=None) -> InvokeAIStableDiffusionPipelineOutput:
    def image_from_embeddings(
        self,
        latents: torch.Tensor,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        noise: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        r"""
        Function invoked when calling the pipeline for generation.

@ -417,71 +480,104 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        :param run_id:
        """
        result_latents, result_attention_map_saver = self.latents_from_embeddings(
            latents, num_inference_steps,
            latents,
            num_inference_steps,
            conditioning_data,
            noise=noise,
            run_id=run_id,
            callback=callback)
            callback=callback,
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
            output = InvokeAIStableDiffusionPipelineOutput(
                images=image,
                nsfw_content_detected=[],
                attention_map_saver=result_attention_map_saver,
            )
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

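`InvokeAIStableDiffusionPipelineOutput` extends the standard diffusers output dataclass with an `attention_map_saver` field; the pattern is ordinary dataclass inheritance. A self-contained sketch of the same pattern (simplified stand-in types, not the real classes):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class BasePipelineOutput:          # stand-in for StableDiffusionPipelineOutput
    images: List
    nsfw_content_detected: List[bool]

@dataclass
class ExtendedOutput(BasePipelineOutput):  # stand-in for the InvokeAI subclass
    attention_map_saver: Optional[object] = None

out = ExtendedOutput(images=[], nsfw_content_detected=[], attention_map_saver=None)
print(out)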
    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                                conditioning_data: ConditioningData,
                                *,
                                noise: torch.Tensor,
                                timesteps=None,
                                additional_guidance: List[Callable] = None, run_id=None,
                                callback: Callable[[PipelineIntermediateState], None] = None
                                ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
    def latents_from_embeddings(
        self,
        latents: torch.Tensor,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        noise: torch.Tensor,
        timesteps=None,
        additional_guidance: List[Callable] = None,
        run_id=None,
        callback: Callable[[PipelineIntermediateState], None] = None,
    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if timesteps is None:
            self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
            self.scheduler.set_timesteps(
                num_inference_steps, device=self._model_group.device_for(self.unet)
            )
            timesteps = self.scheduler.timesteps
        infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
        infer_latents_from_embeddings = GeneratorToCallbackinator(
            self.generate_latents_from_embeddings, PipelineIntermediateState
        )
        result: PipelineIntermediateState = infer_latents_from_embeddings(
            latents, timesteps, conditioning_data,
            latents,
            timesteps,
            conditioning_data,
            noise=noise,
            additional_guidance=additional_guidance,
            run_id=run_id,
            callback=callback)
            callback=callback,
        )
        return result.latents, result.attention_map_saver

    def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                                         conditioning_data: ConditioningData,
                                         *,
                                         noise: torch.Tensor,
                                         run_id: str = None,
                                         additional_guidance: List[Callable] = None):
    def generate_latents_from_embeddings(
        self,
        latents: torch.Tensor,
        timesteps,
        conditioning_data: ConditioningData,
        *,
        noise: torch.Tensor,
        run_id: str = None,
        additional_guidance: List[Callable] = None,
    ):
        self._adjust_memory_efficient_attention(latents)
        if run_id is None:
            run_id = secrets.token_urlsafe(self.ID_LENGTH)
        if additional_guidance is None:
            additional_guidance = []
        extra_conditioning_info = conditioning_data.extra
        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
                                                             step_count=len(self.scheduler.timesteps)
                                                             ):

            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                            latents=latents)
        with self.invokeai_diffuser.custom_attention_context(
            extra_conditioning_info=extra_conditioning_info,
            step_count=len(self.scheduler.timesteps),
        ):
            yield PipelineIntermediateState(
                run_id=run_id,
                step=-1,
                timestep=self.scheduler.num_train_timesteps,
                latents=latents,
            )

            batch_size = latents.shape[0]
            batched_t = torch.full((batch_size,), timesteps[0],
                                   dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
            batched_t = torch.full(
                (batch_size,),
                timesteps[0],
                dtype=timesteps.dtype,
                device=self._model_group.device_for(self.unet),
            )
            latents = self.scheduler.add_noise(latents, noise, batched_t)

            attention_map_saver: Optional[AttentionMapSaver] = None

            for i, t in enumerate(self.progress_bar(timesteps)):
                batched_t.fill_(t)
                step_output = self.step(batched_t, latents, conditioning_data,
                                        step_index=i,
                                        total_step_count=len(timesteps),
                                        additional_guidance=additional_guidance)
                step_output = self.step(
                    batched_t,
                    latents,
                    conditioning_data,
                    step_index=i,
                    total_step_count=len(timesteps),
                    additional_guidance=additional_guidance,
                )
                latents = step_output.prev_sample

                latents = self.invokeai_diffuser.do_latent_postprocessing(
@ -489,28 +585,39 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                    latents=latents,
                    sigma=batched_t,
                    step_index=i,
                    total_step_count=len(timesteps)
                    total_step_count=len(timesteps),
                )

                predicted_original = getattr(step_output, 'pred_original_sample', None)
                predicted_original = getattr(step_output, "pred_original_sample", None)

                # TODO resuscitate attention map saving
                #if i == len(timesteps)-1 and extra_conditioning_info is not None:
                # if i == len(timesteps)-1 and extra_conditioning_info is not None:
                #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
                #    attention_map_token_ids = range(1, eos_token_index)
                #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
                #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

                yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                                predicted_original=predicted_original, attention_map_saver=attention_map_saver)
                yield PipelineIntermediateState(
                    run_id=run_id,
                    step=i,
                    timestep=int(t),
                    latents=latents,
                    predicted_original=predicted_original,
                    attention_map_saver=attention_map_saver,
                )

            return latents, attention_map_saver

    @torch.inference_mode()
    def step(self, t: torch.Tensor, latents: torch.Tensor,
             conditioning_data: ConditioningData,
             step_index:int, total_step_count:int,
             additional_guidance: List[Callable] = None):
    def step(
        self,
        t: torch.Tensor,
        latents: torch.Tensor,
        conditioning_data: ConditioningData,
        step_index: int,
        total_step_count: int,
        additional_guidance: List[Callable] = None,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]

@ -523,16 +630,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            latent_model_input, t,
            conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
            latent_model_input,
            t,
            conditioning_data.unconditioned_embeddings,
            conditioning_data.text_embeddings,
            conditioning_data.guidance_scale,
            step_index=step_index,
            total_step_count=total_step_count,
        )

        # compute the previous noisy sample x_t -> x_t-1
        step_output = self.scheduler.step(noise_pred, timestep, latents,
                                          **conditioning_data.scheduler_args)
        step_output = self.scheduler.step(
            noise_pred, timestep, latents, **conditioning_data.scheduler_args
        )

        # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
        #    But the way things are now, scheduler runs _after_ that, so there was
@ -542,7 +652,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        return step_output

    def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
    def _unet_forward(
        self,
        latents,
        t,
        text_embeddings,
        cross_attention_kwargs: Optional[dict[str, Any]] = None,
    ):
        """predict the noise residual"""
        if is_inpainting_model(self.unet) and latents.size(1) == 4:
            # Pad out normal non-inpainting inputs for an inpainting model.
@ -551,67 +667,100 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            #   use of AddsMaskLatents.
            latents = AddsMaskLatents(
                self._unet_forward,
                mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
                initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
                mask=torch.ones_like(
                    latents[:1, :1], device=latents.device, dtype=latents.dtype
                ),
                initial_image_latents=torch.zeros_like(
                    latents[:1], device=latents.device, dtype=latents.dtype
                ),
            ).add_mask_channels(latents)

        # First three args should be positional, not keywords, so torch hooks can see them.
        return self.unet(latents, t, text_embeddings,
                         cross_attention_kwargs=cross_attention_kwargs).sample
        return self.unet(
            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
        ).sample

    def img2img_from_embeddings(self,
                                init_image: Union[torch.FloatTensor, PIL.Image.Image],
                                strength: float,
                                num_inference_steps: int,
                                conditioning_data: ConditioningData,
                                *, callback: Callable[[PipelineIntermediateState], None] = None,
                                run_id=None,
                                noise_func=None
                                ) -> InvokeAIStableDiffusionPipelineOutput:
    def img2img_from_embeddings(
        self,
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        noise_func=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        if isinstance(init_image, PIL.Image.Image):
            init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))

        if init_image.dim() == 3:
            init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
            init_image = einops.rearrange(init_image, "c h w -> 1 c h w")

        # 6. Prepare latent variables
        initial_latents = self.non_noised_latents_from_image(
            init_image, device=self._model_group.device_for(self.unet),
            dtype=self.unet.dtype)
            init_image,
            device=self._model_group.device_for(self.unet),
            dtype=self.unet.dtype,
        )
        noise = noise_func(initial_latents)

        return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
                                                        conditioning_data,
                                                        strength,
                                                        noise, run_id, callback)
        return self.img2img_from_latents_and_embeddings(
            initial_latents,
            num_inference_steps,
            conditioning_data,
            strength,
            noise,
            run_id,
            callback,
        )

    def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps,
                                            conditioning_data: ConditioningData,
                                            strength,
                                            noise: torch.Tensor, run_id=None, callback=None
                                            ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
                                                  device=self._model_group.device_for(self.unet))
    def img2img_from_latents_and_embeddings(
        self,
        initial_latents,
        num_inference_steps,
        conditioning_data: ConditioningData,
        strength,
        noise: torch.Tensor,
        run_id=None,
        callback=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(
            num_inference_steps,
            strength,
            device=self._model_group.device_for(self.unet),
        )
        result_latents, result_attention_maps = self.latents_from_embeddings(
            initial_latents, num_inference_steps, conditioning_data,
            initial_latents,
            num_inference_steps,
            conditioning_data,
            timesteps=timesteps,
            noise=noise,
            run_id=run_id,
            callback=callback)
            callback=callback,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
            output = InvokeAIStableDiffusionPipelineOutput(
                images=image,
                nsfw_content_detected=[],
                attention_map_saver=result_attention_maps,
            )
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device) -> (torch.Tensor, int):
    def get_img2img_timesteps(
        self, num_inference_steps: int, strength: float, device
    ) -> (torch.Tensor, int):
        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
        assert img2img_pipeline.scheduler is self.scheduler
        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, adjusted_steps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
        timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
            num_inference_steps, strength, device=device
        )
        # Workaround for low strength resulting in zero timesteps.
        # TODO: submit upstream fix for zero-step img2img
        if timesteps.numel() == 0:
@ -620,21 +769,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        return timesteps, adjusted_steps

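`get_img2img_timesteps` delegates to the upstream img2img pipeline, which keeps only roughly the final `strength` fraction of the schedule, so low strength values denoise only lightly. An illustrative calculation (an approximation of the diffusers behavior, not its exact internals):

num_inference_steps = 50

for strength in (0.2, 0.75, 1.0):
    # diffusers keeps roughly the final `strength` fraction of the schedule
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    print(strength, "-> steps actually run:", num_inference_steps - t_start)
# 0.2 -> 10, 0.75 -> 37, 1.0 -> 50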
    def inpaint_from_embeddings(
            self,
            init_image: torch.FloatTensor,
            mask: torch.FloatTensor,
            strength: float,
            num_inference_steps: int,
            conditioning_data: ConditioningData,
            *, callback: Callable[[PipelineIntermediateState], None] = None,
            run_id=None,
            noise_func=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        self,
        init_image: torch.FloatTensor,
        mask: torch.FloatTensor,
        strength: float,
        num_inference_steps: int,
        conditioning_data: ConditioningData,
        *,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        noise_func=None,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        device = self._model_group.device_for(self.unet)
        latents_dtype = self.unet.dtype

        if isinstance(init_image, PIL.Image.Image):
            init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
            init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))

        init_image = init_image.to(device=device, dtype=latents_dtype)
        mask = mask.to(device=device, dtype=latents_dtype)
@ -642,18 +792,23 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if init_image.dim() == 3:
            init_image = init_image.unsqueeze(0)

        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, device=device)
        timesteps, _ = self.get_img2img_timesteps(
            num_inference_steps, strength, device=device
        )

        # 6. Prepare latent variables
        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
        # because we have our own noise function
        init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
        init_image_latents = self.non_noised_latents_from_image(
            init_image, device=device, dtype=latents_dtype
        )
        noise = noise_func(init_image_latents)

        if mask.dim() == 3:
            mask = mask.unsqueeze(0)
        latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
            .to(device=device, dtype=latents_dtype)
        latent_mask = tv_resize(
            mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR
        ).to(device=device, dtype=latents_dtype)

        guidance: List[Callable] = []

@ -661,20 +816,30 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
            # (that's why there's a mask!) but it seems to really want that blanked out.
            masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
            masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)
            masked_latents = self.non_noised_latents_from_image(
                masked_init_image, device=device, dtype=latents_dtype
            )

            # TODO: we should probably pass this in so we don't have to try/finally around setting it.
            self.invokeai_diffuser.model_forward_callback = \
                AddsMaskLatents(self._unet_forward, latent_mask, masked_latents)
            self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
                self._unet_forward, latent_mask, masked_latents
            )
        else:
            guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))
            guidance.append(
                AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)
            )

        try:
            result_latents, result_attention_maps = self.latents_from_embeddings(
                init_image_latents, num_inference_steps,
                conditioning_data, noise=noise, timesteps=timesteps,
                init_image_latents,
                num_inference_steps,
                conditioning_data,
                noise=noise,
                timesteps=timesteps,
                additional_guidance=guidance,
                run_id=run_id, callback=callback)
                run_id=run_id,
                callback=callback,
            )
        finally:
            self.invokeai_diffuser.model_forward_callback = self._unet_forward

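For inpainting checkpoints, the branch above blanks the region to be repainted before encoding (`init_image * torch.where(mask < 0.5, 1, 0)`), so the masked area reads as black rather than leaking the original pixels. A toy demonstration of that blanking:

import torch

init_image = torch.ones(1, 3, 4, 4)   # stand-in image, all ones
mask = torch.zeros(1, 1, 4, 4)
mask[..., :, 2:] = 1.0                # right half marked for repaint

masked = init_image * torch.where(mask < 0.5, 1, 0)
print(masked[0, 0])  # left half 1.0 (kept), right half 0.0 (blanked)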
@ -683,13 +848,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
            output = InvokeAIStableDiffusionPipelineOutput(
                images=image,
                nsfw_content_detected=[],
                attention_map_saver=result_attention_maps,
            )
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
        init_image = init_image.to(device=device, dtype=dtype)
        with torch.inference_mode():
            if device.type == 'mps':
            if device.type == "mps":
                # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
                # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
                self.vae.to(CPU_DEVICE)
@ -697,8 +866,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            else:
                self._model_group.load(self.vae)
            init_latent_dist = self.vae.encode(init_image).latent_dist
            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
            if device.type == 'mps':
            init_latents = init_latent_dist.sample().to(
                dtype=dtype
            )  # FIXME: uses torch.randn. make reproducible!
            if device.type == "mps":
                self.vae.to(device)
                init_latents = init_latents.to(device)

@ -707,14 +878,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

    def check_for_safety(self, output, dtype):
        with torch.inference_mode():
            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
            screened_images, has_nsfw_concept = self.run_safety_checker(
                output.images, dtype=dtype
            )
        screened_attention_map_saver = None
        if has_nsfw_concept is None or not has_nsfw_concept:
            screened_attention_map_saver = output.attention_map_saver
        return InvokeAIStableDiffusionPipelineOutput(screened_images,
                                                     has_nsfw_concept,
                                                     # block the attention maps if NSFW content is detected
                                                     attention_map_saver=screened_attention_map_saver)
        return InvokeAIStableDiffusionPipelineOutput(
            screened_images,
            has_nsfw_concept,
            # block the attention maps if NSFW content is detected
            attention_map_saver=screened_attention_map_saver,
        )

    def run_safety_checker(self, image, device=None, dtype=None):
        # overriding to use the model group for device info instead of requiring the caller to know.
@ -723,15 +898,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        return super().run_safety_checker(image, device, dtype)

    @torch.inference_mode()
    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
    def get_learned_conditioning(
        self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
    ):
        """
        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
        Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
        """
        return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
            text_batch=c,
            fragment_weights_batch=fragment_weights,
            should_return_tokens=return_tokens,
            device=self._model_group.device_for(self.unet))
            device=self._model_group.device_for(self.unet),
        )

    @property
    def cond_stage_model(self):
@ -760,6 +938,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    def debug_latents(self, latents, msg):
        with torch.inference_mode():
            from ldm.util import debug_image

            decoded = self.numpy_to_pil(self.decode_latents(latents))
            for i, img in enumerate(decoded):
                debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
                debug_image(
                    img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
                )
6
invokeai/backend/stable_diffusion/diffusion/__init__.py
Normal file
@ -0,0 +1,6 @@
"""
Initialization file for invokeai.backend.stable_diffusion.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
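With this restructure, the diffusion helpers that used to live under `ldm.models.diffusion` and `ldm.modules` are re-exported from this package, so downstream code can import them from one place:

# new-style imports after this commit
from invokeai.backend.stable_diffusion.diffusion import (
    AttentionMapSaver,
    InvokeAIDiffuserComponent,
    PostprocessingSettings,
)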
@ -1,22 +1,19 @@
import os
import torch
from copy import deepcopy
from glob import glob

import pytorch_lightning as pl
import torch
from einops import rearrange
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import default, instantiate_from_config, ismap, log_txt_as_img
from natsort import natsorted
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted

from ldm.modules.diffusionmodules.openaimodel import (
    EncoderUNetModel,
    UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config

__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}


def disabled_train(self, mode=True):
@ -31,13 +28,13 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        diffusion_path,
        num_classes,
        ckpt_path=None,
        pool='attention',
        pool="attention",
        label_key=None,
        diffusion_ckpt_path=None,
        scheduler_config=None,
        weight_decay=1.0e-2,
        log_steps=10,
        monitor='val/loss',
        monitor="val/loss",
        *args,
        **kwargs,
    ):
@ -45,30 +42,26 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(
            glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
            glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
        )[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = (
            self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        )
        self.log_time_interval = (
            self.diffusion_model.num_timesteps // log_steps
        )
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = (
            label_key
            if not hasattr(self.diffusion_model, 'cond_stage_key')
            if not hasattr(self.diffusion_model, "cond_stage_key")
            else self.diffusion_model.cond_stage_key
        )

        assert (
            self.label_key is not None
        ), 'label_key neither in diffusion model nor in model.params'
        ), "label_key neither in diffusion model nor in model.params"

        if self.label_key not in __models__:
            raise NotImplementedError()
@ -80,14 +73,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location='cpu')
        if 'state_dict' in list(sd.keys()):
            sd = sd['state_dict']
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
@ -95,12 +88,12 @@ class NoisyLatentImageClassifier(pl.LightningModule):
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f'Missing Keys: {missing}')
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f'Unexpected Keys: {unexpected}')
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        model = instantiate_from_config(self.diffusion_config)
@ -110,24 +103,22 @@ class NoisyLatentImageClassifier(pl.LightningModule):
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        model_config = deepcopy(
            self.diffusion_config.params.unet_config.params
        )
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = (
            self.diffusion_config.params.unet_config.params.out_channels
        )
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
        if self.label_key == "class_label":
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print(
                '#####################################################################'
                "#####################################################################"
            )
            print(f'load from ckpt "{ckpt_path}"')
            print(
                '#####################################################################'
                "#####################################################################"
            )
            self.init_from_ckpt(ckpt_path)

@ -137,9 +128,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = (
                self.diffusion_model.sample_continuous_noise_level(
                    x.shape[0], t + 1
                )
                self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            )
            # todo: make sure t+1 is correct here

@ -158,7 +147,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = rearrange(x, "b h w c -> b c h w")
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

@ -166,45 +155,41 @@ class NoisyLatentImageClassifier(pl.LightningModule):
    def get_conditioning(self, batch, k=None):
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'
        assert k is not None, "Needs to provide label key"

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
        if self.label_key == "segmentation":
            targets = rearrange(targets, "b h w c -> b c h w")
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(
                    targets, size=(h // 2, w // 2), mode='nearest'
                )
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction='mean'):
    def compute_top_k(self, logits, labels, k, reduction="mean"):
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == 'mean':
            return (
                (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
            )
        elif reduction == 'none':
        if reduction == "mean":
            return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
        elif reduction == "none":
            return (top_ks == labels[:, None]).float().sum(dim=-1)
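`compute_top_k` counts a prediction as correct if the true label appears anywhere in the k highest logits; with `reduction="mean"` that is top-k accuracy over the batch. A small worked example:

import torch

logits = torch.tensor([[0.1, 0.7, 0.2],    # top-2 = classes 1 and 2
                       [0.6, 0.3, 0.1]])   # top-2 = classes 0 and 1
labels = torch.tensor([1, 2])

_, top_ks = torch.topk(logits, k=2, dim=1)
acc_at_2 = (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc_at_2)  # 0.5 -- label 1 is in row 0's top-2, label 2 is not in row 1's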
    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')
        self.diffusion_model.model.to("cpu")

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        log_prefix = 'train' if self.training else 'val'
        log_prefix = "train" if self.training else "val"
        log = {}
        log[f'{log_prefix}/loss'] = loss.mean()
        log[f'{log_prefix}/acc@1'] = self.compute_top_k(
            logits, targets, k=1, reduction='mean'
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f'{log_prefix}/acc@5'] = self.compute_top_k(
            logits, targets, k=5, reduction='mean'
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(
@ -214,19 +199,17 @@ class NoisyLatentImageClassifier(pl.LightningModule):
            on_step=self.training,
            on_epoch=True,
        )
        self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log(
            'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
        )
        self.log(
            'global_step',
            "global_step",
            self.global_step,
            logger=False,
            on_epoch=False,
            prog_bar=True,
        )
        lr = self.optimizers().param_groups[0]['lr']
        lr = self.optimizers().param_groups[0]["lr"]
        self.log(
            'lr_abs',
            "lr_abs",
            lr,
            on_step=True,
            logger=True,
@ -249,13 +232,11 @@ class NoisyLatentImageClassifier(pl.LightningModule):
                device=self.device,
            ).long()
        else:
            t = torch.full(
                size=(x.shape[0],), fill_value=t, device=self.device
            ).long()
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')
        loss = F.cross_entropy(logits, targets, reduction="none")

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

@ -268,7 +249,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):

    def reset_noise_accs(self):
        self.noisy_acc = {
            t: {'acc@1': [], 'acc@5': []}
            t: {"acc@1": [], "acc@5": []}
            for t in range(
                0,
                self.diffusion_model.num_timesteps,
@ -285,11 +266,11 @@ class NoisyLatentImageClassifier(pl.LightningModule):

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(
                self.compute_top_k(logits, targets, k=1, reduction='mean')
            self.noisy_acc[t]["acc@1"].append(
                self.compute_top_k(logits, targets, k=1, reduction="mean")
            )
            self.noisy_acc[t]['acc@5'].append(
                self.compute_top_k(logits, targets, k=5, reduction='mean')
            self.noisy_acc[t]["acc@5"].append(
                self.compute_top_k(logits, targets, k=5, reduction="mean")
            )

        return loss
@ -304,14 +285,12 @@ class NoisyLatentImageClassifier(pl.LightningModule):
        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print('Setting up LambdaLR scheduler...')
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(
                        optimizer, lr_lambda=scheduler.schedule
                    ),
                    'interval': 'step',
                    'frequency': 1,
                    "scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [optimizer], scheduler
@ -322,32 +301,28 @@ class NoisyLatentImageClassifier(pl.LightningModule):
    def log_images(self, batch, N=8, *args, **kwargs):
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x
        log["inputs"] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
            log['labels'] = y
        if self.label_key == "class_label":
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log["labels"] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)
            log["labels"] = self.diffusion_model.to_rgb(y)

        for step in range(self.log_steps):
            current_time = step * self.log_time_interval

            _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

            log[f'inputs@t{current_time}'] = x_noisy
            log[f"inputs@t{current_time}"] = x_noisy

            pred = F.one_hot(
                logits.argmax(dim=1), num_classes=self.num_classes
            )
            pred = rearrange(pred, 'b h w c -> b c h w')
            pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
            pred = rearrange(pred, "b h w c -> b c h w")

            log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
                pred
            )
            log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]
@ -1,21 +1,20 @@

# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl


import enum
import math
from typing import Optional, Callable
from typing import Callable, Optional

import diffusers
import psutil
import torch
import diffusers
from compel.cross_attention_control import Arguments
from diffusers.models.cross_attention import AttnProcessor
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from torch import nn

from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.cross_attention import AttnProcessor
from ldm.invoke.devices import torch_dtype
from ...util import torch_dtype


class CrossAttentionType(enum.Enum):
@ -24,13 +23,12 @@ class CrossAttentionType(enum.Enum):


class Context:

    cross_attention_mask: Optional[torch.Tensor]
    cross_attention_index_map: Optional[torch.Tensor]

    class Action(enum.Enum):
        NONE = 0
        SAVE = 1,
        SAVE = (1,)
        APPLY = 2

    def __init__(self, arguments: Arguments, step_count: int):
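One detail worth noticing in the hunk above: the old `SAVE = 1,` carried a stray trailing comma, so the member's value was always the tuple `(1,)`; black merely makes that explicit as `SAVE = (1,)`. Because enum members are compared by identity, behavior is unaffected, as this small sketch (names hypothetical) shows:

```python
import enum

class Action(enum.Enum):
    NONE = 0
    SAVE = (1,)   # value is a tuple thanks to the historical trailing comma
    APPLY = 2

state = Action.SAVE
print(state.value)            # (1,) -- the tuple, not 1
print(state is Action.SAVE)   # True: identity comparisons still work
```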
@ -53,11 +51,13 @@ class Context:
        self.clear_requests(cleanup=True)

    def register_cross_attention_modules(self, model):
        for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF):
        for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
            if name in self.self_cross_attention_module_identifiers:
                assert False, f"name {name} cannot appear more than once"
            self.self_cross_attention_module_identifiers.append(name)
        for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
        for name, module in get_cross_attention_modules(
            model, CrossAttentionType.TOKENS
        ):
            if name in self.tokens_cross_attention_module_identifiers:
                assert False, f"name {name} cannot appear more than once"
            self.tokens_cross_attention_module_identifiers.append(name)
@ -68,7 +68,9 @@ class Context:
        else:
            self.tokens_cross_attention_action = Context.Action.SAVE

    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
    def request_apply_saved_attention_maps(
        self, cross_attention_type: CrossAttentionType
    ):
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.APPLY
        else:
@ -91,8 +93,9 @@ class Context:
            return self.tokens_cross_attention_action == Context.Action.APPLY
        return False

    def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
            -> list[CrossAttentionType]:
    def get_active_cross_attention_control_types_for_step(
        self, percent_through: float = None
    ) -> list[CrossAttentionType]:
        """
        Should cross-attention control be applied on the given step?
        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
@ -103,50 +106,73 @@ class Context:

        opts = self.arguments.edit_options
        to_control = []
        if opts['s_start'] <= percent_through < opts['s_end']:
        if opts["s_start"] <= percent_through < opts["s_end"]:
            to_control.append(CrossAttentionType.SELF)
        if opts['t_start'] <= percent_through < opts['t_end']:
        if opts["t_start"] <= percent_through < opts["t_end"]:
            to_control.append(CrossAttentionType.TOKENS)
        return to_control
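The method above gates self- and token-attention control onto sub-ranges of the denoising trajectory via half-open `[start, end)` windows over `percent_through`. A standalone sketch of the same gating logic (option values hypothetical):

```python
def active_controls(percent_through: float, opts: dict) -> list[str]:
    # Half-open windows: a control is active while start <= p < end.
    to_control = []
    if opts["s_start"] <= percent_through < opts["s_end"]:
        to_control.append("self")
    if opts["t_start"] <= percent_through < opts["t_end"]:
        to_control.append("tokens")
    return to_control

opts = {"s_start": 0.0, "s_end": 0.3, "t_start": 0.0, "t_end": 1.0}
print(active_controls(0.1, opts))  # ['self', 'tokens']
print(active_controls(0.5, opts))  # ['tokens'] -- self control has lapsed
```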
    def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
                   slice_size: Optional[int]):
    def save_slice(
        self,
        identifier: str,
        slice: torch.Tensor,
        dim: Optional[int],
        offset: int,
        slice_size: Optional[int],
    ):
        if identifier not in self.saved_cross_attention_maps:
            self.saved_cross_attention_maps[identifier] = {
                'dim': dim,
                'slice_size': slice_size,
                'slices': {offset or 0: slice}
                "dim": dim,
                "slice_size": slice_size,
                "slices": {offset or 0: slice},
            }
        else:
            self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice
            self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
    def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
    def get_slice(
        self,
        identifier: str,
        requested_dim: Optional[int],
        requested_offset: int,
        slice_size: int,
    ):
        saved_attention_dict = self.saved_cross_attention_maps[identifier]
        if requested_dim is None:
            if saved_attention_dict['dim'] is not None:
                raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
            return saved_attention_dict['slices'][0]

        if saved_attention_dict['dim'] == requested_dim:
            if slice_size != saved_attention_dict['slice_size']:
            if saved_attention_dict["dim"] is not None:
                raise RuntimeError(
                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
            return saved_attention_dict['slices'][requested_offset]
                    f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}"
                )
            return saved_attention_dict["slices"][0]

        if saved_attention_dict['dim'] is None:
            whole_saved_attention = saved_attention_dict['slices'][0]
        if saved_attention_dict["dim"] == requested_dim:
            if slice_size != saved_attention_dict["slice_size"]:
                raise RuntimeError(
                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
                )
            return saved_attention_dict["slices"][requested_offset]

        if saved_attention_dict["dim"] is None:
            whole_saved_attention = saved_attention_dict["slices"][0]
            if requested_dim == 0:
                return whole_saved_attention[requested_offset:requested_offset + slice_size]
                return whole_saved_attention[
                    requested_offset : requested_offset + slice_size
                ]
            elif requested_dim == 1:
                return whole_saved_attention[:, requested_offset:requested_offset + slice_size]
                return whole_saved_attention[
                    :, requested_offset : requested_offset + slice_size
                ]

        raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
        raise RuntimeError(
            f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}"
        )
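`get_slice` either returns a stored slice directly or carves the requested slice out of a whole saved map when the map was stored unsliced (`dim is None`). A compact sketch of that dim-0 / dim-1 carving:

```python
import torch

whole = torch.arange(24.0).reshape(4, 6)  # a saved, unsliced attention map

def carve(requested_dim: int, offset: int, slice_size: int) -> torch.Tensor:
    if requested_dim == 0:
        return whole[offset : offset + slice_size]
    elif requested_dim == 1:
        return whole[:, offset : offset + slice_size]
    raise RuntimeError(f"cannot slice along dim {requested_dim}")

print(carve(0, 2, 2).shape)  # torch.Size([2, 6])
print(carve(1, 3, 2).shape)  # torch.Size([4, 2])
```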
    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
    def get_slicing_strategy(
        self, identifier: str
    ) -> tuple[Optional[int], Optional[int]]:
        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
        if saved_attention is None:
            return None, None
        return saved_attention['dim'], saved_attention['slice_size']
        return saved_attention["dim"], saved_attention["slice_size"]

    def clear_requests(self, cleanup=True):
        self.tokens_cross_attention_action = Context.Action.NONE
@ -156,9 +182,8 @@ class Context:

    def offload_saved_attention_slices_to_cpu(self):
        for key, map_dict in self.saved_cross_attention_maps.items():
            for offset, slice in map_dict['slices'].items():
                map_dict[offset] = slice.to('cpu')

            for offset, slice in map_dict["slices"].items():
                map_dict[offset] = slice.to("cpu")
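Both the old and new versions of `offload_saved_attention_slices_to_cpu` write the moved tensor to `map_dict[offset]` rather than back into `map_dict["slices"][offset]`, so the slice that `get_slice` later reads appears to stay on its original device; this looks like a pre-existing bug that the reformat preserves. A sketch of what the write-back would presumably need to be (my reading, not a change made by this commit):

```python
def offload_saved_attention_slices_to_cpu(self):
    for key, map_dict in self.saved_cross_attention_maps.items():
        for offset, slice in map_dict["slices"].items():
            # Write back into the "slices" sub-dict so get_slice() sees the
            # CPU copy; writing map_dict[offset] adds a sibling key instead.
            map_dict["slices"][offset] = slice.to("cpu")
```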
class InvokeAICrossAttentionMixin:
@ -167,14 +192,20 @@ class InvokeAICrossAttentionMixin:
    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
    and dynamic slicing strategy selection.
    """

    def __init__(self):
        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
        self.attention_slice_wrangler = None
        self.slicing_strategy_getter = None
        self.attention_slice_calculated_callback = None

    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
        '''
    def set_attention_slice_wrangler(
        self,
        wrangler: Optional[
            Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]
        ],
    ):
        """
        Set custom attention calculator to be called when attention is calculated
        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                         which returns either the suggested_attention_slice or an adjusted equivalent.
@ -185,20 +216,30 @@ class InvokeAICrossAttentionMixin:

        Pass None to use the default attention calculation.
        :return:
        '''
        """
        self.attention_slice_wrangler = wrangler

    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
    def set_slicing_strategy_getter(
        self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]
    ):
        self.slicing_strategy_getter = getter

    def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
    def set_attention_slice_calculated_callback(
        self, callback: Optional[Callable[[torch.Tensor], None]]
    ):
        self.attention_slice_calculated_callback = callback
    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
        # calculate attention scores
        #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
        # attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
@ -206,35 +247,49 @@ class InvokeAICrossAttentionMixin:
        )

        # calculate attention slice by taking the best scores for each latent pixel
        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
        default_attention_slice = attention_scores.softmax(
            dim=-1, dtype=attention_scores.dtype
        )
        attention_slice_wrangler = self.attention_slice_wrangler
        if attention_slice_wrangler is not None:
            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
            attention_slice = attention_slice_wrangler(
                self, default_attention_slice, dim, offset, slice_size
            )
        else:
            attention_slice = default_attention_slice

        if self.attention_slice_calculated_callback is not None:
            self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
            self.attention_slice_calculated_callback(
                attention_slice, dim, offset, slice_size
            )

        hidden_states = torch.bmm(attention_slice, value)
        return hidden_states
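The `torch.baddbmm` call with `beta=0` is a fused way of computing the batched `q @ k^T` score matrix that the commented-out einsum expressed (the `alpha` scale argument presumably sits in the hunk context the diff elides). A quick equivalence check, with hypothetical shapes:

```python
import torch

b, i, j, d = 2, 16, 16, 8
q, k = torch.randn(b, i, d), torch.randn(b, j, d)
scale = d ** -0.5

via_einsum = torch.einsum("b i d, b j d -> b i j", q, k) * scale
via_baddbmm = torch.baddbmm(
    torch.empty(b, i, j, dtype=q.dtype, device=q.device),  # ignored since beta=0
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
assert torch.allclose(via_einsum, via_baddbmm, atol=1e-5)
```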
    def einsum_op_slice_dim0(self, q, k, v, slice_size):
        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        r = torch.zeros(
            q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
        )
        for i in range(0, q.shape[0], slice_size):
            end = i + slice_size
            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
            r[i:end] = self.einsum_lowest_level(
                q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size
            )
        return r

    def einsum_op_slice_dim1(self, q, k, v, slice_size):
        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        r = torch.zeros(
            q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
        )
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
            r[:, i:end] = self.einsum_lowest_level(
                q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size
            )
        return r
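Both helpers tile the same lowest-level kernel over different axes: `dim0` slices the (batch·heads) axis, while `dim1` slices the query-token axis so only a strip of the (tokens × tokens) score matrix is ever materialized. A toy illustration of the dim-1 memory saving (shapes hypothetical):

```python
import torch

q = torch.randn(8, 4096, 40)   # (batch*heads, query tokens, head dim)
k = torch.randn(8, 4096, 40)

slice_size = 512
for i in range(0, q.shape[1], slice_size):
    end = i + slice_size
    # scores for one strip: (8, 512, 4096) instead of the full (8, 4096, 4096)
    scores = torch.bmm(q[:, i:end], k.transpose(-1, -2))
    del scores  # each strip can be processed and freed before the next
```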
    def einsum_op_mps_v1(self, q, k, v):
        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
            return self.einsum_lowest_level(q, k, v, None, None, None)
        else:
            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
@ -272,13 +327,12 @@ class InvokeAICrossAttentionMixin:
        # Divide factor of safety as there's copying and fragmentation
        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))


    def get_invokeai_attention_mem_efficient(self, q, k, v):
        if q.device.type == 'cuda':
            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
        if q.device.type == "cuda":
            # print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
            return self.einsum_op_cuda(q, k, v)

        if q.device.type == 'mps' or q.device.type == 'cpu':
        if q.device.type == "mps" or q.device.type == "cpu":
            if self.mem_total_gb >= 32:
                return self.einsum_op_mps_v1(q, k, v)
            return self.einsum_op_mps_v2(q, k, v)
@ -288,8 +342,11 @@ class InvokeAICrossAttentionMixin:
        return self.einsum_op_tensor_mem(q, k, v, 32)
def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
def restore_default_cross_attention(
    model,
    is_running_diffusers: bool,
    restore_attention_processor: Optional[AttnProcessor] = None,
):
    if is_running_diffusers:
        unet = model
        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
@ -297,7 +354,7 @@ def restore_default_cross_attention(model, is_running_diffusers: bool, restore_a
        remove_attention_function(model)


def override_cross_attention(model, context: Context, is_running_diffusers = False):
def override_cross_attention(model, context: Context, is_running_diffusers=False):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.

@ -316,7 +373,7 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
    indices = torch.arange(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
        if b0 < max_length:
            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
            if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                # these tokens have not been edited
                indices[b0:b1] = indices_target[a0:a1]
                mask[b0:b1] = 1
@ -332,7 +389,14 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
    else:
        # try to re-use an existing slice size
        default_slice_size = 4
        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
        slice_size = next(
            (
                p.slice_size
                for p in old_attn_processors.values()
                if type(p) is SlicedAttnProcessor
            ),
            default_slice_size,
        )
        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
        return old_attn_processors
    else:
@ -341,65 +405,96 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
        return None
def get_cross_attention_modules(
    model, which: CrossAttentionType
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
    from ldm.modules.attention import CrossAttention  # avoid circular import


def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
    from ldm.modules.attention import CrossAttention # avoid circular import
    cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
    cross_attention_class: type = (
        InvokeAIDiffusersCrossAttention
        if isinstance(model, UNet2DConditionModel)
        else CrossAttention
    )
    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
    attention_module_tuples = [(name,module) for name, module in model.named_modules() if
                               isinstance(module, cross_attention_class) and which_attn in name]
    attention_module_tuples = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, cross_attention_class) and which_attn in name
    ]
    cross_attention_modules_in_model_count = len(attention_module_tuples)
    expected_count = 16
    if cross_attention_modules_in_model_count != expected_count:
        # non-fatal error but .swap() won't work.
        print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
              f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
              f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
              f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
              f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
              f"work properly until it is fixed.")
        print(
            f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
            + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
            + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
            + f"work properly until it is fixed."
        )
    return attention_module_tuples
def inject_attention_function(unet, context: Context):
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

    def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):

        #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
    def attention_slice_wrangler(
        module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size
    ):
        # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()

        attention_slice = suggested_attention_slice

        if context.get_should_save_maps(module.identifier):
            #print(module.identifier, "saving suggested_attention_slice of shape",
            # print(module.identifier, "saving suggested_attention_slice of shape",
            #      suggested_attention_slice.shape, "dim", dim, "offset", offset)
            slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
            context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
            slice_to_save = (
                attention_slice.to("cpu") if dim is not None else attention_slice
            )
            context.save_slice(
                module.identifier,
                slice_to_save,
                dim=dim,
                offset=offset,
                slice_size=slice_size,
            )
        elif context.get_should_apply_saved_maps(module.identifier):
            #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
            saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
            # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
            saved_attention_slice = context.get_slice(
                module.identifier, dim, offset, slice_size
            )

            # slice may have been offloaded to CPU
            saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
            saved_attention_slice = saved_attention_slice.to(
                suggested_attention_slice.device
            )

            if context.is_tokens_cross_attention(module.identifier):
                index_map = context.cross_attention_index_map
                remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
                remapped_saved_attention_slice = torch.index_select(
                    saved_attention_slice, -1, index_map
                )
                this_attention_slice = suggested_attention_slice

                mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
                mask = context.cross_attention_mask.to(
                    torch_dtype(suggested_attention_slice.device)
                )
                saved_mask = mask
                this_mask = 1 - mask
                attention_slice = remapped_saved_attention_slice * saved_mask + \
                                  this_attention_slice * this_mask
                attention_slice = (
                    remapped_saved_attention_slice * saved_mask
                    + this_attention_slice * this_mask
                )
            else:
                # just use everything
                attention_slice = saved_attention_slice

        return attention_slice

    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    cross_attention_modules = get_cross_attention_modules(
        unet, CrossAttentionType.TOKENS
    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for identifier, module in cross_attention_modules:
        module.identifier = identifier
        try:
@ -408,56 +503,61 @@ def inject_attention_function(unet, context: Context):
                lambda module: context.get_slicing_strategy(identifier)
            )
        except AttributeError as e:
            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                print(
                    f"TODO: implement set_attention_slice_wrangler for {type(module)}"
                )  # TODO
            else:
                raise
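Inside the wrangler, token-controlled modules blend the remapped saved attention with the freshly computed attention using a per-token mask (1 keeps the original prompt's attention, 0 takes the edited prompt's). A tiny sketch of that blend, with hypothetical shapes:

```python
import torch

tokens = 6
saved = torch.rand(2, 4, tokens)      # remapped attention from the original prompt
current = torch.rand(2, 4, tokens)    # attention from the edited prompt
mask = torch.tensor([1., 1., 0., 0., 1., 0.])  # 1 = token unchanged by the edit

# Per-token linear blend: unchanged tokens keep their original attention,
# edited tokens take the newly computed attention.
blended = saved * mask + current * (1 - mask)
assert blended.shape == current.shape
```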
def remove_attention_function(unet):
    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    cross_attention_modules = get_cross_attention_modules(
        unet, CrossAttentionType.TOKENS
    ) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for identifier, module in cross_attention_modules:
        try:
            # clear wrangler callback
            module.set_attention_slice_wrangler(None)
            module.set_slicing_strategy_getter(None)
        except AttributeError as e:
            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
            if is_attribute_error_about(e, "set_attention_slice_wrangler"):
                print(
                    f"TODO: implement set_attention_slice_wrangler for {type(module)}"
                )
            else:
                raise
def is_attribute_error_about(error: AttributeError, attribute: str):
    if hasattr(error, 'name'): # Python 3.10
    if hasattr(error, "name"):  # Python 3.10
        return error.name == attribute
    else: # Python 3.9
        return attribute in str(error)


def get_mem_free_total(device):
    #only on cuda
    # only on cuda
    if not torch.cuda.is_available():
        return None
    stats = torch.cuda.memory_stats(device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_active = stats["active_bytes.all.current"]
    mem_reserved = stats["reserved_bytes.all.current"]
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    return mem_free_total
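`get_mem_free_total` adds the CUDA driver's free bytes to memory PyTorch has reserved but is not actively using, since the cached-but-idle portion is also reclaimable for new allocations. A hypothetical usage sketch:

```python
import torch

# Hypothetical usage: estimate headroom before sizing attention buffers.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    stats = torch.cuda.memory_stats(device)
    reclaimable = stats["reserved_bytes.all.current"] - stats["active_bytes.all.current"]
    driver_free, _ = torch.cuda.mem_get_info(device)   # (free, total) in bytes
    mem_free_total = driver_free + reclaimable
    print(f"~{mem_free_total / (1 << 20):.0f} MiB available for attention buffers")
```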
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):

class InvokeAIDiffusersCrossAttention(
    diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        InvokeAICrossAttentionMixin.__init__(self)

    def _attention(self, query, key, value, attention_mask=None):
        #default_result = super()._attention(query, key, value)
        # default_result = super()._attention(query, key, value)
        if attention_mask is not None:
            print(f"{type(self).__name__} ignoring passed-in attention_mask")
        attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
@ -466,9 +566,6 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
        return hidden_states


## 🧨diffusers implementation follows


@ -501,25 +598,30 @@ class CrossAttnProcessor:
        return hidden_states

"""
from dataclasses import field, dataclass
from dataclasses import dataclass, field

import torch

from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
from diffusers.models.cross_attention import (
    CrossAttention,
    CrossAttnProcessor,
    SlicedAttnProcessor,
)


@dataclass
class SwapCrossAttnContext:
    modified_text_embeddings: torch.Tensor
    index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
    mask: torch.Tensor # in the target space of the index_map
    index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
    mask: torch.Tensor  # in the target space of the index_map
    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)

    def __int__(self,
                cac_types_to_do: [CrossAttentionType],
                modified_text_embeddings: torch.Tensor,
                index_map: torch.Tensor,
                mask: torch.Tensor):
    def __int__(
        self,
        cac_types_to_do: [CrossAttentionType],
        modified_text_embeddings: torch.Tensor,
        index_map: torch.Tensor,
        mask: torch.Tensor,
    ):
        self.cross_attention_types_to_do = cac_types_to_do
        self.modified_text_embeddings = modified_text_embeddings
        self.index_map = index_map
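A sharp eye will catch that `SwapCrossAttnContext` defines `__int__`, not `__init__`; the method is dead code both before and after the reformat, and redundant anyway, since `@dataclass` already generates the constructor from the field declarations. A sketch of the dataclass-generated construction that actually runs (class name hypothetical):

```python
from dataclasses import dataclass, field

@dataclass
class Ctx:  # stand-in mirroring SwapCrossAttnContext's shape
    modified_text_embeddings: object
    index_map: object
    mask: object
    types_to_do: list = field(default_factory=list)

ctx = Ctx(modified_text_embeddings=None, index_map=None, mask=None)
print(ctx.types_to_do)  # [] -- the generated __init__ fills the default
```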
@ -529,9 +631,9 @@ class SwapCrossAttnContext:
        return attn_type in self.cross_attention_types_to_do

    @classmethod
    def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
            -> tuple[torch.Tensor, torch.Tensor]:

    def make_mask_and_index_map(
        cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
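`make_mask_and_index_map` consumes `SequenceMatcher`-style opcodes `(tag, a0, a1, b0, b1)` relating original-prompt token spans to modified-prompt spans; `"equal"` spans map indices through and set mask=1. A sketch of producing such opcodes for two token sequences (tokens hypothetical):

```python
from difflib import SequenceMatcher

original = ["a", "photo", "of", "a", "cat"]
modified = ["a", "photo", "of", "a", "small", "dog"]

for tag, a0, a1, b0, b1 in SequenceMatcher(None, original, modified).get_opcodes():
    # "equal" spans are tokens the edit left alone; only those keep the
    # original prompt's attention (mask = 1 over [b0, b1)).
    print(tag, original[a0:a1], "->", modified[b0:b1])
```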
@ -547,28 +649,42 @@ class SwapCrossAttnContext:


class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

    # TODO: dynamically pick slice size based on memory conditions

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                 # kwargs
                 swap_cross_attn_context: SwapCrossAttnContext=None):

        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
    def __call__(
        self,
        attn: CrossAttention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        # kwargs
        swap_cross_attn_context: SwapCrossAttnContext = None,
    ):
        attention_type = (
            CrossAttentionType.SELF
            if encoder_hidden_states is None
            else CrossAttentionType.TOKENS
        )

        # if cross-attention control is not in play, just call through to the base implementation.
        if attention_type is CrossAttentionType.SELF or \
                swap_cross_attn_context is None or \
                not swap_cross_attn_context.wants_cross_attention_control(attention_type):
            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
        #else:
        if (
            attention_type is CrossAttentionType.SELF
            or swap_cross_attn_context is None
            or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
        ):
            # print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
            return super().__call__(
                attn, hidden_states, encoder_hidden_states, attention_mask
            )
        # else:
        #     print(f"SwapCrossAttnContext for {attention_type} active")

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask=attention_mask, target_length=sequence_length,
            batch_size=batch_size)
            attention_mask=attention_mask,
            target_length=sequence_length,
            batch_size=batch_size,
        )

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
@ -589,41 +705,51 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
        # compute slices and prepare output tensor
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
            (batch_size_attention, sequence_length, dim // attn.heads),
            device=query.device,
            dtype=query.dtype,
        )

        # do slices
        for i in range(max(1,hidden_states.shape[0] // self.slice_size)):
        for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            original_key_slice = original_text_key[start_idx:end_idx]
            modified_key_slice = modified_text_key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
            attn_mask_slice = (
                attention_mask[start_idx:end_idx]
                if attention_mask is not None
                else None
            )

            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
            original_attn_slice = attn.get_attention_scores(
                query_slice, original_key_slice, attn_mask_slice
            )
            modified_attn_slice = attn.get_attention_scores(
                query_slice, modified_key_slice, attn_mask_slice
            )

            # because the prompt modifications may result in token sequences shifted forwards or backwards,
            # the original attention probabilities must be remapped to account for token index changes in the
            # modified prompt
            remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
                                                              swap_cross_attn_context.index_map)
            remapped_original_attn_slice = torch.index_select(
                original_attn_slice, -1, swap_cross_attn_context.index_map
            )

            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
            mask = swap_cross_attn_context.mask
            inverse_mask = 1 - mask
            attn_slice = \
                remapped_original_attn_slice * mask + \
                modified_attn_slice * inverse_mask
            attn_slice = (
                remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
            )

            del remapped_original_attn_slice, modified_attn_slice

            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
            hidden_states[start_idx:end_idx] = attn_slice


        # done
        hidden_states = attn.batch_to_head_dim(hidden_states)

@ -636,7 +762,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):


class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):

    def __init__(self):
        super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice

        super(SwapCrossAttnProcessor, self).__init__(
            slice_size=int(1e9)
        )  # massive slice size = don't slice
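`SwapCrossAttnProcessor` reuses the sliced implementation with an effectively infinite `slice_size`, so `max(1, n // slice_size)` evaluates to a single iteration and the whole batch is processed unsliced. A sketch of that arithmetic:

```python
slice_size = int(1e9)   # "massive slice size = don't slice"
for n in (1, 8, 512):
    iterations = max(1, n // slice_size)
    assert iterations == 1  # one pass over the full tensor, no slicing
```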
@ -2,17 +2,17 @@ import math

import PIL
import torch
from torchvision.transforms.functional import resize as tv_resize, InterpolationMode
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional import resize as tv_resize

from ldm.models.diffusion.cross_attention_control import get_cross_attention_modules, CrossAttentionType
from .cross_attention_control import CrossAttentionType, get_cross_attention_modules


class AttentionMapSaver():

class AttentionMapSaver:
    def __init__(self, token_ids: range, latents_shape: torch.Size):
        self.token_ids = token_ids
        self.latents_shape = latents_shape
        #self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
        # self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
        self.collated_maps = {}

    def clear_maps(self):
@ -25,7 +25,7 @@ class AttentionMapSaver():
        :param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
        :return: None
        """
        key_and_size = f'{key}_{maps.shape[1]}'
        key_and_size = f"{key}_{maps.shape[1]}"

        # extract desired tokens
        maps = maps[:, :, self.token_ids]
@ -35,12 +35,12 @@ class AttentionMapSaver():

        # store
        if key_and_size not in self.collated_maps:
            self.collated_maps[key_and_size] = torch.zeros_like(maps, device='cpu')
            self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
        self.collated_maps[key_and_size] += maps.cpu()

    def write_maps_to_disk(self, path: str):
        pil_image = self.get_stacked_maps_image()
        pil_image.save(path, 'PNG')
        pil_image.save(path, "PNG")

    def get_stacked_maps_image(self) -> PIL.Image:
        """
@ -57,39 +57,50 @@ class AttentionMapSaver():
        merged = None

        for key, maps in self.collated_maps.items():

            # maps has shape [(H*W), N] for N tokens
            # but we want [N, H, W]
            this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
            this_scale_factor = math.sqrt(
                maps.shape[0] / (latents_width * latents_height)
            )
            this_maps_height = int(float(latents_height) * this_scale_factor)
            this_maps_width = int(float(latents_width) * this_scale_factor)
            # and we need to do some dimension juggling
            maps = torch.reshape(torch.swapdims(maps, 0, 1), [num_tokens, this_maps_height, this_maps_width])
            maps = torch.reshape(
                torch.swapdims(maps, 0, 1),
                [num_tokens, this_maps_height, this_maps_width],
            )

            # scale to output size if necessary
            if this_scale_factor != 1:
                maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
                maps = tv_resize(
                    maps, [latents_height, latents_width], InterpolationMode.BICUBIC
                )

            # normalize
            maps_min = torch.min(maps)
            maps_range = torch.max(maps) - maps_min
            #print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
            # print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
            maps_normalized = (maps - maps_min) / maps_range
            # expand to (-0.1, 1.1) and clamp
            maps_normalized_expanded = maps_normalized * 1.1 - 0.05
            maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
            maps_normalized_expanded_clamped = torch.clamp(
                maps_normalized_expanded, 0, 1
            )

            # merge together, producing a vertical stack
            maps_stacked = torch.reshape(maps_normalized_expanded_clamped, [num_tokens * latents_height, latents_width])
            maps_stacked = torch.reshape(
                maps_normalized_expanded_clamped,
                [num_tokens * latents_height, latents_width],
            )

            if merged is None:
                merged = maps_stacked
            else:
                # screen blend
                merged = 1 - (1 - maps_stacked)*(1 - merged)
                merged = 1 - (1 - maps_stacked) * (1 - merged)

        if merged is None:
            return None

        merged_bytes = merged.mul(0xff).byte()
        return PIL.Image.fromarray(merged_bytes.numpy(), mode='L')
        merged_bytes = merged.mul(0xFF).byte()
        return PIL.Image.fromarray(merged_bytes.numpy(), mode="L")
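The vertical stacks for successive map resolutions are merged with a screen blend, `1 - (1 - a)(1 - b)`, which brightens wherever either layer is bright and stays in [0, 1] for inputs in [0, 1]:

```python
import torch

a = torch.tensor([0.0, 0.2, 0.9])
b = torch.tensor([0.0, 0.8, 0.9])

screen = 1 - (1 - a) * (1 - b)
print(screen)  # tensor([0.0000, 0.8400, 0.9900]) -- >= max(a, b) elementwise
```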
@ -1,77 +1,82 @@
"""SAMPLING ONLY."""

import torch
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like

from ..diffusionmodules.util import noise_like
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent


class DDIMSampler(Sampler):
    def __init__(self, model, schedule='linear', device=None, **kwargs):
        super().__init__(model,schedule,model.num_timesteps,device)
    def __init__(self, model, schedule="linear", device=None, **kwargs):
        super().__init__(model, schedule, model.num_timesteps, device)

        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            self.model,
            model_forward_callback=lambda x, sigma, cond: self.model.apply_model(
                x, sigma, cond
            ),
        )

    def prepare_to_sample(self, t_enc, **kwargs):
        super().prepare_to_sample(t_enc, **kwargs)

        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
        extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
        all_timesteps_count = kwargs.get("all_timesteps_count", t_enc)

        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
        if (
            extra_conditioning_info is not None
            and extra_conditioning_info.wants_cross_attention_control
        ):
            self.invokeai_diffuser.override_cross_attention(
                extra_conditioning_info, step_count=all_timesteps_count
            )
        else:
            self.invokeai_diffuser.restore_default_cross_attention()

    # This is the central routine
    @torch.no_grad()
    def p_sample(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        step_count:int=1000, # total number of steps
        **kwargs,
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        step_count: int = 1000,  # total number of steps
        **kwargs,
    ):
        b, *_, device = *x.shape, x.device

        if (
            unconditional_conditioning is None
            or unconditional_guidance_scale == 1.0
        ):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
            # damian0815 would like to know when/if this code path is used
            e_t = self.model.apply_model(x, t, c)
        else:
            # step_index counts in the opposite direction to index
            step_index = step_count-(index+1)
            step_index = step_count - (index + 1)
            e_t = self.invokeai_diffuser.do_diffusion_step(
                x, t,
                unconditional_conditioning, c,
                x,
                t,
                unconditional_conditioning,
                c,
                unconditional_guidance_scale,
                step_index=step_index
                step_index=step_index,
            )
        if score_corrector is not None:
            assert self.model.parameterization == 'eps'
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(
                self.model, e_t, x, t, c, **corrector_kwargs
            )

        alphas = (
            self.model.alphas_cumprod
            if use_original_steps
            else self.ddim_alphas
        )
        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
@ -101,11 +106,8 @@ class DDIMSampler(Sampler):
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = (
            sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        )
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0, None
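The tail of `p_sample` is the standard DDIM update: with $\bar\alpha$ the cumulative alphas, the step assembles

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta + \sigma_t z,$$

where `pred_x0` is $\hat x_0$, `dir_xt` is the middle "direction pointing to x_t" term, and the noise term vanishes when $\sigma_t = 0$ (deterministic DDIM).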
File diff suppressed because it is too large
@ -8,12 +8,12 @@ from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent


# at this threshold, the scheduler will stop using the Karras
# noise schedule and start using the model's schedule
STEP_THRESHOLD = 30

def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):

def cfg_apply_threshold(result, threshold=0.0, scale=0.7):
    if threshold <= 0.0:
        return result
    maxval = 0.0 + torch.max(result).cpu().numpy()
@ -21,35 +21,43 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
    if maxval < threshold and minval > -threshold:
        return result
    if maxval > threshold:
        maxval = min(max(1, scale*maxval), threshold)
        maxval = min(max(1, scale * maxval), threshold)
    if minval < -threshold:
        minval = max(min(-1, scale*minval), -threshold)
        minval = max(min(-1, scale * minval), -threshold)
    return torch.clamp(result, min=minval, max=maxval)
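`cfg_apply_threshold` soft-clamps denoiser outputs whose dynamic range explodes under high CFG scales: it shrinks the bounds toward `scale * extremum`, but never tighter than ±1 and never looser than ±threshold. A consolidated sketch of the function above (simplified to plain floats) with a worked value:

```python
import torch

def cfg_apply_threshold(result, threshold=0.0, scale=0.7):
    if threshold <= 0.0:
        return result
    maxval = float(result.max())
    minval = float(result.min())
    if maxval < threshold and minval > -threshold:
        return result  # already inside the window, leave untouched
    if maxval > threshold:
        maxval = min(max(1, scale * maxval), threshold)
    if minval < -threshold:
        minval = max(min(-1, scale * minval), -threshold)
    return torch.clamp(result, min=minval, max=maxval)

x = torch.tensor([-4.0, 0.5, 6.0])
print(cfg_apply_threshold(x, threshold=3.0))  # clamped into [-2.8, 3.0]
```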
class CFGDenoiser(nn.Module):
    def __init__(self, model, threshold = 0, warmup = 0):
    def __init__(self, model, threshold=0, warmup=0):
        super().__init__()
        self.inner_model = model
        self.threshold = threshold
        self.warmup_max = warmup
        self.warmup = max(warmup / 10, 1)
        self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
                                                           model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))

        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            model,
            model_forward_callback=lambda x, sigma, cond: self.inner_model(
                x, sigma, cond=cond
            ),
        )

    def prepare_to_sample(self, t_enc, **kwargs):
        extra_conditioning_info = kwargs.get("extra_conditioning_info", None)

        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
        if (
            extra_conditioning_info is not None
            and extra_conditioning_info.wants_cross_attention_control
        ):
            self.invokeai_diffuser.override_cross_attention(
                extra_conditioning_info, step_count=t_enc
            )
        else:
            self.invokeai_diffuser.restore_default_cross_attention()


    def forward(self, x, sigma, uncond, cond, cond_scale):
        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
        next_x = self.invokeai_diffuser.do_diffusion_step(
            x, sigma, uncond, cond, cond_scale
        )
        if self.warmup < self.warmup_max:
            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
            self.warmup += 1
@ -59,8 +67,9 @@ class CFGDenoiser(nn.Module):
            thresh = self.threshold
        return cfg_apply_threshold(next_x, thresh)
class KSampler(Sampler):
    def __init__(self, model, schedule='lms', device=None, **kwargs):
    def __init__(self, model, schedule="lms", device=None, **kwargs):
        denoiser = K.external.CompVisDenoiser(model)
        super().__init__(
            denoiser,
@ -68,45 +77,49 @@ class KSampler(Sampler):
            steps=model.num_timesteps,
        )
        self.sigmas = None
        self.ds = None
        self.s_in = None
        self.karras_max = kwargs.get('karras_max',STEP_THRESHOLD)
        self.ds = None
        self.s_in = None
        self.karras_max = kwargs.get("karras_max", STEP_THRESHOLD)
        if self.karras_max is None:
            self.karras_max = STEP_THRESHOLD

    def make_schedule(
        self,
        ddim_num_steps,
        ddim_discretize='uniform',
        ddim_eta=0.0,
        verbose=False,
        self,
        ddim_num_steps,
        ddim_discretize="uniform",
        ddim_eta=0.0,
        verbose=False,
    ):
        outer_model = self.model
        self.model = outer_model.inner_model
        self.model = outer_model.inner_model
        super().make_schedule(
            ddim_num_steps,
            ddim_discretize='uniform',
            ddim_discretize="uniform",
            ddim_eta=0.0,
            verbose=False,
        )
        self.model = outer_model
        self.model = outer_model
        self.ddim_num_steps = ddim_num_steps
        # we don't need both of these sigmas, but storing them here to make
        # comparison easier later on
        self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
        self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
        self.karras_sigmas = K.sampling.get_sigmas_karras(
            n=ddim_num_steps,
            sigma_min=self.model.sigmas[0].item(),
            sigma_max=self.model.sigmas[-1].item(),
            rho=7.,
            rho=7.0,
            device=self.device,
        )

        if ddim_num_steps >= self.karras_max:
            print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})')
            print(
                f">> Ksampler using model noise schedule (steps >= {self.karras_max})"
            )
            self.sigmas = self.model_sigmas
        else:
            print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
            print(
                f">> Ksampler using karras noise schedule (steps < {self.karras_max})"
            )
            self.sigmas = self.karras_sigmas
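Below `STEP_THRESHOLD`, the sampler swaps in a Karras et al. noise schedule from k-diffusion. A sketch of generating those sigmas directly (`sigma_min`/`sigma_max` values hypothetical; `get_sigmas_karras` returns `n + 1` values, ramping from `sigma_max` down to `sigma_min` with a trailing 0):

```python
import k_diffusion as K

sigmas = K.sampling.get_sigmas_karras(
    n=20, sigma_min=0.03, sigma_max=14.6, rho=7.0, device="cpu"
)
print(len(sigmas), sigmas[0].item(), sigmas[-1].item())  # 21, ~14.6, 0.0
```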
    # ALERT: We are completely overriding the sample() method in the base class, which
@ -116,31 +129,31 @@ class KSampler(Sampler):

    @torch.no_grad()
    def decode(
        self,
        z_enc,
        cond,
        t_enc,
        img_callback=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        init_latent = None,
        mask = None,
        **kwargs
        self,
        z_enc,
        cond,
        t_enc,
        img_callback=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        init_latent=None,
        mask=None,
        **kwargs,
    ):
        samples,_ = self.sample(
            batch_size = 1,
            S = t_enc,
            x_T = z_enc,
            shape = z_enc.shape[1:],
            conditioning = cond,
        samples, _ = self.sample(
            batch_size=1,
            S=t_enc,
            x_T=z_enc,
            shape=z_enc.shape[1:],
            conditioning=cond,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning = unconditional_conditioning,
            img_callback = img_callback,
            x0 = init_latent,
            mask = mask,
            **kwargs
        )
            unconditional_conditioning=unconditional_conditioning,
            img_callback=img_callback,
            x0=init_latent,
            mask=mask,
            **kwargs,
        )
        return samples

    # this is a no-op, provided here for compatibility with ddim and plms samplers
@ -174,26 +187,26 @@ class KSampler(Sampler):
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
        threshold = 0,
        perlin = 0,
        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
        threshold=0,
        perlin=0,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        def route_callback(k_callback_values):
            if img_callback is not None:
                img_callback(k_callback_values['x'],k_callback_values['i'])
                img_callback(k_callback_values["x"], k_callback_values["i"])

        # if make_schedule() hasn't been called, we do it now
        if self.sigmas is None:
            self.make_schedule(
                ddim_num_steps=S,
                ddim_eta = eta,
                verbose = False,
                ddim_eta=eta,
                verbose=False,
            )

        # sigmas are set up in make_schedule - we take the last steps items
        sigmas = self.sigmas[-S-1:]
        sigmas = self.sigmas[-S - 1 :]

        # x_T is variation noise. When an init image is provided (in x0) we need to add
        # more randomness to the starting image.
@ -205,27 +218,40 @@ class KSampler(Sampler):
        else:
            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
        model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
        model_wrap_cfg = CFGDenoiser(
            self.model, threshold=threshold, warmup=max(0.8 * S, S - 10)
        )
        model_wrap_cfg.prepare_to_sample(
            S, extra_conditioning_info=extra_conditioning_info
        )

        # setup attention maps saving. checks for None are because there are multiple code paths to get here.
        attention_map_saver = None
        if attention_maps_callback is not None and extra_conditioning_info is not None:
            eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
            attention_map_token_ids = range(1, eos_token_index)
            attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
            attention_map_saver = AttentionMapSaver(
                token_ids=attention_map_token_ids, latents_shape=x.shape[-2:]
            )
            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(
                attention_map_saver
            )

        extra_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
            "cond": conditioning,
            "uncond": unconditional_conditioning,
            "cond_scale": unconditional_guidance_scale,
        }
        print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
        print(
            f">> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)"
        )
        sampling_result = (
            K.sampling.__dict__[f'sample_{self.schedule}'](
                model_wrap_cfg, x, sigmas, extra_args=extra_args,
                callback=route_callback
            K.sampling.__dict__[f"sample_{self.schedule}"](
                model_wrap_cfg,
                x,
                sigmas,
                extra_args=extra_args,
                callback=route_callback,
            ),
            None,
        )
@ -237,25 +263,25 @@ class KSampler(Sampler):
|
||||
# a workaround is found.
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
self,
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
**kwargs,
|
||||
self,
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
**kwargs,
|
||||
):
|
||||
if self.model_wrap is None:
|
||||
self.model_wrap = CFGDenoiser(self.model)
|
||||
extra_args = {
|
||||
'cond': cond,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
"cond": cond,
|
||||
"uncond": unconditional_conditioning,
|
||||
"cond_scale": unconditional_guidance_scale,
|
||||
}
|
||||
if self.s_in is None:
|
||||
self.s_in = img.new_ones([img.shape[0]])
|
||||
self.s_in = img.new_ones([img.shape[0]])
|
||||
if self.ds is None:
|
||||
self.ds = []
|
||||
|
||||
@ -270,14 +296,16 @@ class KSampler(Sampler):
|
||||
# so the actual formula for indexing into sigmas:
|
||||
# sigma_index = (steps-index)
|
||||
s_index = t_enc - index - 1
|
||||
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
|
||||
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||
self.model_wrap.prepare_to_sample(
|
||||
s_index, extra_conditioning_info=extra_conditioning_info
|
||||
)
|
||||
img = K.sampling.__dict__[f"_{self.schedule}"](
|
||||
self.model_wrap,
|
||||
img,
|
||||
self.sigmas,
|
||||
s_index,
|
||||
s_in = self.s_in,
|
||||
ds = self.ds,
|
||||
s_in=self.s_in,
|
||||
ds=self.ds,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
@ -287,26 +315,25 @@ class KSampler(Sampler):
|
||||
    # we should not be multiplying by self.sigmas[0] if we
    # are at an intermediate step in img2img. See similar in
    # sample() which does work.
-   def get_initial_image(self,x_T,shape,steps):
-       print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
-       x = (torch.randn(shape, device=self.device) * self.sigmas[0])
+   def get_initial_image(self, x_T, shape, steps):
+       print(f"WARNING: ksampler.get_initial_image(): get_initial_image needs testing")
+       x = torch.randn(shape, device=self.device) * self.sigmas[0]
        if x_T is not None:
            return x_T + x
        else:
            return x

-   def prepare_to_sample(self,t_enc,**kwargs):
-       self.t_enc = t_enc
+   def prepare_to_sample(self, t_enc, **kwargs):
+       self.t_enc = t_enc
        self.model_wrap = None
-       self.ds = None
-       self.s_in = None
+       self.ds = None
+       self.s_in = None

-   def q_sample(self,x0,ts):
-       '''
+   def q_sample(self, x0, ts):
+       """
        Overrides parent method to return the q_sample of the inner model.
-       '''
-       return self.model.inner_model.q_sample(x0,ts)
+       """
+       return self.model.inner_model.q_sample(x0, ts)

-   def conditioning_key(self)->str:
+   def conditioning_key(self) -> str:
        return self.model.inner_model.model.conditioning_key
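For orientation, the `extra_args` dict assembled in `p_sample()` above (`cond`, `uncond`, `cond_scale`) is what k-diffusion's sampling functions forward to the CFG model wrapper on every step. A minimal sketch of such a wrapper, with a hypothetical `SketchCFGDenoiser` standing in for the repository's actual `CFGDenoiser`:

# Sketch of the CFG denoiser that extra_args above feeds into.
# Hypothetical illustration only -- not this repository's exact class.
import torch

class SketchCFGDenoiser(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, cond, uncond, cond_scale):
        # Run conditioned and unconditioned predictions in one batch,
        # then apply classifier-free guidance to their difference.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond_out + (cond_out - uncond_out) * cond_scale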
@@ -1,52 +1,58 @@
"""SAMPLING ONLY."""

-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-from ldm.invoke.devices import choose_torch_device
-from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from ldm.models.diffusion.sampler import Sampler
-from ldm.modules.diffusionmodules.util import noise_like
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from ...util import choose_torch_device
+from ..diffusionmodules.util import noise_like
+from .sampler import Sampler
+from .shared_invokeai_diffusion import InvokeAIDiffuserComponent


class PLMSSampler(Sampler):
-   def __init__(self, model, schedule='linear', device=None, **kwargs):
-       super().__init__(model,schedule,model.num_timesteps, device)
+   def __init__(self, model, schedule="linear", device=None, **kwargs):
+       super().__init__(model, schedule, model.num_timesteps, device)

    def prepare_to_sample(self, t_enc, **kwargs):
        super().prepare_to_sample(t_enc, **kwargs)

-       extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-       all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
+       extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
+       all_timesteps_count = kwargs.get("all_timesteps_count", t_enc)

-       if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-           self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
+       if (
+           extra_conditioning_info is not None
+           and extra_conditioning_info.wants_cross_attention_control
+       ):
+           self.invokeai_diffuser.override_cross_attention(
+               extra_conditioning_info, step_count=all_timesteps_count
+           )
        else:
            self.invokeai_diffuser.restore_default_cross_attention()


    # this is the essential routine
    @torch.no_grad()
    def p_sample(
-       self,
-       x, # image, called 'img' elsewhere
-       c, # conditioning, called 'cond' elsewhere
-       t, # timesteps, called 'ts' elsewhere
-       index,
-       repeat_noise=False,
-       use_original_steps=False,
-       quantize_denoised=False,
-       temperature=1.0,
-       noise_dropout=0.0,
-       score_corrector=None,
-       corrector_kwargs=None,
-       unconditional_guidance_scale=1.0,
-       unconditional_conditioning=None,
-       old_eps=[],
-       t_next=None,
-       step_count:int=1000, # total number of steps
-       **kwargs,
+       self,
+       x,  # image, called 'img' elsewhere
+       c,  # conditioning, called 'cond' elsewhere
+       t,  # timesteps, called 'ts' elsewhere
+       index,
+       repeat_noise=False,
+       use_original_steps=False,
+       quantize_denoised=False,
+       temperature=1.0,
+       noise_dropout=0.0,
+       score_corrector=None,
+       corrector_kwargs=None,
+       unconditional_guidance_scale=1.0,
+       unconditional_conditioning=None,
+       old_eps=[],
+       t_next=None,
+       step_count: int = 1000,  # total number of steps
+       **kwargs,
    ):
        b, *_, device = *x.shape, x.device

@@ -59,24 +65,24 @@ class PLMSSampler(Sampler):
            e_t = self.model.apply_model(x, t, c)
        else:
            # step_index counts in the opposite direction to index
-           step_index = step_count-(index+1)
-           e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
-                  unconditional_conditioning, c,
-                  unconditional_guidance_scale,
-                  step_index=step_index)
+           step_index = step_count - (index + 1)
+           e_t = self.invokeai_diffuser.do_diffusion_step(
+               x,
+               t,
+               unconditional_conditioning,
+               c,
+               unconditional_guidance_scale,
+               step_index=step_index,
+           )
        if score_corrector is not None:
-           assert self.model.parameterization == 'eps'
+           assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(
                self.model, e_t, x, t, c, **corrector_kwargs
            )

        return e_t

-       alphas = (
-           self.model.alphas_cumprod
-           if use_original_steps
-           else self.ddim_alphas
-       )
+       alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
@@ -96,9 +102,7 @@ class PLMSSampler(Sampler):
        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
-           a_prev = torch.full(
-               (b, 1, 1, 1), alphas_prev[index], device=device
-           )
+           a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(
                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
@@ -110,11 +114,7 @@ class PLMSSampler(Sampler):
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
-           noise = (
-               sigma_t
-               * noise_like(x.shape, device, repeat_noise)
-               * temperature
-           )
+           noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.0:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
@@ -135,10 +135,7 @@ class PLMSSampler(Sampler):
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (
-               55 * e_t
-               - 59 * old_eps[-1]
-               + 37 * old_eps[-2]
-               - 9 * old_eps[-3]
+               55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
            ) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
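The `(55, -59, 37, -9) / 24` weights in the hunk above are the standard fourth-order Adams-Bashforth coefficients, applied here to successive noise predictions. A self-contained sketch of that combination, assuming `old_eps` holds the three most recent predictions with the newest last:

# Sketch: 4th-order Adams-Bashforth combination of noise predictions.
def plms_e_t_prime(e_t, old_eps):
    assert len(old_eps) >= 3, "needs the three previous predictions"
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24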
@@ -1,31 +1,37 @@
-'''
-ldm.models.diffusion.sampler
+"""
+invokeai.models.diffusion.sampler

-Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
-'''
-import torch
-import numpy as np
-from tqdm import tqdm
+Base class for invokeai.models.diffusion.ddim, invokeai.models.diffusion.ksampler, etc
+"""
from functools import partial
-from ldm.invoke.devices import choose_torch_device
-from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-
-from ldm.modules.diffusionmodules.util import (
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from ...util import choose_torch_device
+from ..diffusionmodules.util import (
+   extract_into_tensor,
    make_ddim_sampling_parameters,
    make_ddim_timesteps,
    noise_like,
-   extract_into_tensor,
)
+from .shared_invokeai_diffusion import InvokeAIDiffuserComponent


class Sampler(object):
-   def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs):
+   def __init__(self, model, schedule="linear", steps=None, device=None, **kwargs):
        self.model = model
        self.ddim_timesteps = None
        self.ddpm_num_timesteps = steps
        self.schedule = schedule
-       self.device = device or choose_torch_device()
-       self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
-           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
+       self.device = device or choose_torch_device()
+       self.invokeai_diffuser = InvokeAIDiffuserComponent(
+           self.model,
+           model_forward_callback=lambda x, sigma, cond: self.model.apply_model(
+               x, sigma, cond
+           ),
+       )

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
@@ -36,11 +42,11 @@ class Sampler(object):
    # This method was copied over from ddim.py and probably does stuff that is
    # ddim-specific. Disentangle at some point.
    def make_schedule(
-       self,
-       ddim_num_steps,
-       ddim_discretize='uniform',
-       ddim_eta=0.0,
-       verbose=False,
+       self,
+       ddim_num_steps,
+       ddim_discretize="uniform",
+       ddim_eta=0.0,
+       verbose=False,
    ):
        self.total_steps = ddim_num_steps
        self.ddim_timesteps = make_ddim_timesteps(
@@ -52,38 +58,33 @@ class Sampler(object):
        alphas_cumprod = self.model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
-       ), 'alphas have to be defined for each timestep'
-       to_torch = (
-           lambda x: x.clone()
-           .detach()
-           .to(torch.float32)
-           .to(self.model.device)
-       )
+       ), "alphas have to be defined for each timestep"
+       to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

-       self.register_buffer('betas', to_torch(self.model.betas))
-       self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+       self.register_buffer("betas", to_torch(self.model.betas))
+       self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
-           'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
+           "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
-           'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
+           "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
-           'sqrt_one_minus_alphas_cumprod',
+           "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
-           'log_one_minus_alphas_cumprod',
+           "log_one_minus_alphas_cumprod",
            to_torch(np.log(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
-           'sqrt_recip_alphas_cumprod',
+           "sqrt_recip_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
        )
        self.register_buffer(
-           'sqrt_recipm1_alphas_cumprod',
+           "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

@@ -98,19 +99,17 @@ class Sampler(object):
            eta=ddim_eta,
            verbose=verbose,
        )
-       self.register_buffer('ddim_sigmas', ddim_sigmas)
-       self.register_buffer('ddim_alphas', ddim_alphas)
-       self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
-       self.register_buffer(
-           'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
-       )
+       self.register_buffer("ddim_sigmas", ddim_sigmas)
+       self.register_buffer("ddim_alphas", ddim_alphas)
+       self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+       self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
-           'ddim_sigmas_for_original_num_steps',
+           "ddim_sigmas_for_original_num_steps",
            sigmas_for_original_sampling_steps,
        )

@@ -129,20 +128,19 @@ class Sampler(object):
        noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
-           + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
-           * noise
+           + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
        )

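The buffers registered above support the usual forward-diffusion identity x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise, which is exactly what the `q_sample` hunk just shown computes via `extract_into_tensor`. A minimal standalone sketch (hypothetical helper that takes the per-timestep scalars directly):

# Sketch of q_sample: blend clean latents with noise per the schedule.
import torch

def q_sample_sketch(x0, sqrt_alpha_bar_t, sqrt_one_minus_alpha_bar_t, noise=None):
    noise = torch.randn_like(x0) if noise is None else noise
    return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise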
    @torch.no_grad()
    def sample(
        self,
-       S, # S is steps
+       S,  # S is steps
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
-       img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
+       img_callback=None,  # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
        quantize_x0=False,
        eta=0.0,
        mask=None,
@@ -159,7 +157,6 @@ class Sampler(object):
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
-
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
@@ -167,17 +164,21 @@ class Sampler(object):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
-                   print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+                   print(
+                       f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                   )
            else:
                if conditioning.shape[0] != batch_size:
-                   print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+                   print(
+                       f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+                   )

        # check to see if make_schedule() has run, and if not, run it
        if self.ddim_timesteps is None:
            self.make_schedule(
                ddim_num_steps=S,
-               ddim_eta = eta,
-               verbose = False,
+               ddim_eta=eta,
+               verbose=False,
            )

        ts = self.get_timesteps(S)
@@ -204,32 +205,32 @@ class Sampler(object):
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            steps=S,
-           **kwargs
+           **kwargs,
        )
        return samples, intermediates

    @torch.no_grad()
    def do_sampling(
-       self,
-       cond,
-       shape,
-       timesteps=None,
-       x_T=None,
-       ddim_use_original_steps=False,
-       callback=None,
-       quantize_denoised=False,
-       mask=None,
-       x0=None,
-       img_callback=None,
-       log_every_t=100,
-       temperature=1.0,
-       noise_dropout=0.0,
-       score_corrector=None,
-       corrector_kwargs=None,
-       unconditional_guidance_scale=1.0,
-       unconditional_conditioning=None,
-       steps=None,
-       **kwargs
+       self,
+       cond,
+       shape,
+       timesteps=None,
+       x_T=None,
+       ddim_use_original_steps=False,
+       callback=None,
+       quantize_denoised=False,
+       mask=None,
+       x0=None,
+       img_callback=None,
+       log_every_t=100,
+       temperature=1.0,
+       noise_dropout=0.0,
+       score_corrector=None,
+       corrector_kwargs=None,
+       unconditional_guidance_scale=1.0,
+       unconditional_conditioning=None,
+       steps=None,
+       **kwargs,
    ):
        b = shape[0]
        time_range = (
@@ -238,29 +239,24 @@ class Sampler(object):
            else np.flip(timesteps)
        )

-       total_steps=steps
+       total_steps = steps

        iterator = tqdm(
            time_range,
-           desc=f'{self.__class__.__name__}',
+           desc=f"{self.__class__.__name__}",
            total=total_steps,
            dynamic_ncols=True,
        )
        old_eps = []
-       self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
-       img = self.get_initial_image(x_T,shape,total_steps)
+       self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=steps, **kwargs)
+       img = self.get_initial_image(x_T, shape, total_steps)

        # probably don't need this at all
-       intermediates = {'x_inter': [img], 'pred_x0': [img]}
+       intermediates = {"x_inter": [img], "pred_x0": [img]}

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
-           ts = torch.full(
-               (b,),
-               step,
-               device=self.device,
-               dtype=torch.long
-           )
+           ts = torch.full((b,), step, device=self.device, dtype=torch.long)
            ts_next = torch.full(
                (b,),
                time_range[min(i + 1, len(time_range) - 1)],
@@ -290,7 +286,7 @@ class Sampler(object):
                unconditional_conditioning=unconditional_conditioning,
                old_eps=old_eps,
                t_next=ts_next,
-               step_count=steps
+               step_count=steps,
            )
            img, pred_x0, e_t = outs

@@ -300,11 +296,11 @@ class Sampler(object):
            if callback:
                callback(i)
            if img_callback:
-               img_callback(img,i)
+               img_callback(img, i)

            if index % log_every_t == 0 or index == total_steps - 1:
-               intermediates['x_inter'].append(img)
-               intermediates['pred_x0'].append(pred_x0)
+               intermediates["x_inter"].append(img)
+               intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

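A note on the bookkeeping in `do_sampling()` above: the iterator walks `time_range` from the noisiest timestep downward, while `index = total_steps - i - 1` counts down in step so it can index the schedule buffers. A small sketch of that mapping (hypothetical helper, for illustration only):

# Sketch: pair each loop iteration with its timestep and descending index.
import numpy as np

def plan_steps(timesteps):
    time_range = np.flip(timesteps)
    total_steps = len(time_range)
    return [(i, int(step), total_steps - i - 1) for i, step in enumerate(time_range)]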
@@ -312,18 +308,18 @@ class Sampler(object):
    # The variable names are changed in order to be confusing.
    @torch.no_grad()
    def decode(
-       self,
-       x_latent,
-       cond,
-       t_start,
-       img_callback=None,
-       unconditional_guidance_scale=1.0,
-       unconditional_conditioning=None,
-       use_original_steps=False,
-       init_latent = None,
-       mask = None,
-       all_timesteps_count = None,
-       **kwargs
+       self,
+       x_latent,
+       cond,
+       t_start,
+       img_callback=None,
+       unconditional_guidance_scale=1.0,
+       unconditional_conditioning=None,
+       use_original_steps=False,
+       init_latent=None,
+       mask=None,
+       all_timesteps_count=None,
+       **kwargs,
    ):
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
@@ -334,12 +330,16 @@ class Sampler(object):

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
-       print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
+       print(
+           f">> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)"
+       )

-       iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
-       x_dec = x_latent
-       x0 = init_latent
-       self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
+       iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+       x_dec = x_latent
+       x0 = init_latent
+       self.prepare_to_sample(
+           t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs
+       )

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
@@ -370,81 +370,85 @@ class Sampler(object):
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
-               t_next = ts_next,
-               step_count=len(self.ddim_timesteps)
+               t_next=ts_next,
+               step_count=len(self.ddim_timesteps),
            )

            x_dec, pred_x0, e_t = outs
            if img_callback:
-               img_callback(x_dec,i)
+               img_callback(x_dec, i)

        return x_dec

-   def get_initial_image(self,x_T,shape,timesteps=None):
+   def get_initial_image(self, x_T, shape, timesteps=None):
        if x_T is None:
            return torch.randn(shape, device=self.device)
        else:
            return x_T

    def p_sample(
-       self,
-       img,
-       cond,
-       ts,
-       index,
-       repeat_noise=False,
-       use_original_steps=False,
-       quantize_denoised=False,
-       temperature=1.0,
-       noise_dropout=0.0,
-       score_corrector=None,
-       corrector_kwargs=None,
-       unconditional_guidance_scale=1.0,
-       unconditional_conditioning=None,
-       old_eps=None,
-       t_next=None,
-       steps=None,
+       self,
+       img,
+       cond,
+       ts,
+       index,
+       repeat_noise=False,
+       use_original_steps=False,
+       quantize_denoised=False,
+       temperature=1.0,
+       noise_dropout=0.0,
+       score_corrector=None,
+       corrector_kwargs=None,
+       unconditional_guidance_scale=1.0,
+       unconditional_conditioning=None,
+       old_eps=None,
+       t_next=None,
+       steps=None,
    ):
-       raise NotImplementedError("p_sample() must be implemented in a descendent class")
+       raise NotImplementedError(
+           "p_sample() must be implemented in a descendent class"
+       )

-   def prepare_to_sample(self,t_enc,**kwargs):
-       '''
+   def prepare_to_sample(self, t_enc, **kwargs):
+       """
        Hook that will be called right before the very first invocation of p_sample()
        to allow subclass to do additional initialization. t_enc corresponds to the actual
        number of steps that will be run, and may be less than total steps if img2img is
        active.
-       '''
+       """
        pass

-   def get_timesteps(self,ddim_steps):
-       '''
+   def get_timesteps(self, ddim_steps):
+       """
        The ddim and plms samplers work on timesteps. This method is called after
        ddim_timesteps are created in make_schedule(), and selects the portion of
        timesteps that will be used for sampling, depending on the t_enc in img2img.
-       '''
+       """
        return self.ddim_timesteps[:ddim_steps]

-   def q_sample(self,x0,ts):
-       '''
+   def q_sample(self, x0, ts):
+       """
        Returns self.model.q_sample(x0,ts). Is overridden in the k* samplers to
        return self.model.inner_model.q_sample(x0,ts)
-       '''
-       return self.model.q_sample(x0,ts)
+       """
+       return self.model.q_sample(x0, ts)

-   def conditioning_key(self)->str:
+   def conditioning_key(self) -> str:
        return self.model.model.conditioning_key

-   def uses_inpainting_model(self)->bool:
-       return self.conditioning_key() in ('hybrid','concat')
+   def uses_inpainting_model(self) -> bool:
+       return self.conditioning_key() in ("hybrid", "concat")

-   def adjust_settings(self,**kwargs):
-       '''
+   def adjust_settings(self, **kwargs):
+       """
        This is a catch-all method for adjusting any instance variables
        after the sampler is instantiated. No type-checking performed
        here, so use with care!
-       '''
+       """
        for k in kwargs.keys():
            try:
-               setattr(self,k,kwargs[k])
+               setattr(self, k, kwargs[k])
            except AttributeError:
-               print(f'** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}')
+               print(
+                   f"** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}"
+               )
@@ -1,25 +1,36 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias

-from ldm.invoke.globals import Globals
-from ldm.models.diffusion.cross_attention_control import Arguments, \
-    restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
-    CrossAttentionType, SwapCrossAttnContext
-from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
+from invokeai.backend.globals import Globals
+
+from .cross_attention_control import (
+    Arguments,
+    Context,
+    CrossAttentionType,
+    SwapCrossAttnContext,
+    get_cross_attention_modules,
+    override_cross_attention,
+    restore_default_cross_attention,
+)
+from .cross_attention_map_saving import AttentionMapSaver

ModelForwardCallback: TypeAlias = Union[
    # x, t, conditioning, Optional[cross-attention kwargs]
-   Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
-   Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+   Callable[
+       [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
+       torch.Tensor,
+   ],
+   Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]


@dataclass(frozen=True)
class PostprocessingSettings:
    threshold: float
@@ -29,20 +40,20 @@ class PostprocessingSettings:


class InvokeAIDiffuserComponent:
-   '''
+   """
    The aim of this component is to provide a single place for code that can be applied identically to
    all InvokeAI diffusion procedures.

    At the moment it includes the following features:
    * Cross attention control ("prompt2prompt")
    * Hybrid conditioning (used for inpainting)
-   '''
+   """

    debug_thresholding = False
    sequential_guidance = False

    @dataclass
    class ExtraConditioningInfo:

        tokens_count_including_eos_bos: int
        cross_attention_control_args: Optional[Arguments] = None

@@ -50,10 +61,12 @@ class InvokeAIDiffuserComponent:
    def wants_cross_attention_control(self):
        return self.cross_attention_control_args is not None


-   def __init__(self, model, model_forward_callback: ModelForwardCallback,
-                is_running_diffusers: bool=False,
-                ):
+   def __init__(
+       self,
+       model,
+       model_forward_callback: ModelForwardCallback,
+       is_running_diffusers: bool = False,
+   ):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -66,23 +79,29 @@ class InvokeAIDiffuserComponent:
        self.sequential_guidance = Globals.sequential_guidance

    @contextmanager
-   def custom_attention_context(self,
-                                extra_conditioning_info: Optional[ExtraConditioningInfo],
-                                step_count: int):
-       do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+   def custom_attention_context(
+       self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+   ):
+       do_swap = (
+           extra_conditioning_info is not None
+           and extra_conditioning_info.wants_cross_attention_control
+       )
        old_attn_processor = None
        if do_swap:
-           old_attn_processor = self.override_cross_attention(extra_conditioning_info,
-                                                              step_count=step_count)
+           old_attn_processor = self.override_cross_attention(
+               extra_conditioning_info, step_count=step_count
+           )
        try:
            yield None
        finally:
            if old_attn_processor is not None:
                self.restore_default_cross_attention(old_attn_processor)
            # TODO resuscitate attention map saving
-           #self.remove_attention_map_saving()
+           # self.remove_attention_map_saving()

-   def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+   def override_cross_attention(
+       self, conditioning: ExtraConditioningInfo, step_count: int
+   ) -> Dict[str, AttnProcessor]:
        """
        setup cross attention .swap control. for diffusers this replaces the attention processor, so
        the previous attention processor is returned so that the caller can restore it later.
@@ -90,18 +109,24 @@ class InvokeAIDiffuserComponent:
        self.conditioning = conditioning
        self.cross_attention_control_context = Context(
            arguments=self.conditioning.cross_attention_control_args,
-           step_count=step_count
+           step_count=step_count,
        )
-       return override_cross_attention(self.model,
-                                       self.cross_attention_control_context,
-                                       is_running_diffusers=self.is_running_diffusers)
+       return override_cross_attention(
+           self.model,
+           self.cross_attention_control_context,
+           is_running_diffusers=self.is_running_diffusers,
+       )

-   def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
+   def restore_default_cross_attention(
+       self, restore_attention_processor: Optional["AttnProcessor"] = None
+   ):
        self.conditioning = None
        self.cross_attention_control_context = None
-       restore_default_cross_attention(self.model,
-                                       is_running_diffusers=self.is_running_diffusers,
-                                       restore_attention_processor=restore_attention_processor)
+       restore_default_cross_attention(
+           self.model,
+           is_running_diffusers=self.is_running_diffusers,
+           restore_attention_processor=restore_attention_processor,
+       )

    def setup_attention_map_saving(self, saver: AttentionMapSaver):
        def callback(slice, dim, offset, slice_size, key):
@@ -110,26 +135,40 @@ class InvokeAIDiffuserComponent:
                return
            saver.add_attention_maps(slice, key)

-       tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+       tokens_cross_attention_modules = get_cross_attention_modules(
+           self.model, CrossAttentionType.TOKENS
+       )
        for identifier, module in tokens_cross_attention_modules:
-           key = ('down' if identifier.startswith('down') else
-                  'up' if identifier.startswith('up') else
-                  'mid')
+           key = (
+               "down"
+               if identifier.startswith("down")
+               else "up"
+               if identifier.startswith("up")
+               else "mid"
+           )
            module.set_attention_slice_calculated_callback(
-               lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
+               lambda slice, dim, offset, slice_size, key=key: callback(
+                   slice, dim, offset, slice_size, key
+               )
+           )

    def remove_attention_map_saving(self):
-       tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+       tokens_cross_attention_modules = get_cross_attention_modules(
+           self.model, CrossAttentionType.TOKENS
+       )
        for _, module in tokens_cross_attention_modules:
            module.set_attention_slice_calculated_callback(None)

-   def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
-                         unconditioning: Union[torch.Tensor,dict],
-                         conditioning: Union[torch.Tensor,dict],
-                         unconditional_guidance_scale: float,
-                         step_index: Optional[int]=None,
-                         total_step_count: Optional[int]=None,
-                         ):
+   def do_diffusion_step(
+       self,
+       x: torch.Tensor,
+       sigma: torch.Tensor,
+       unconditioning: Union[torch.Tensor, dict],
+       conditioning: Union[torch.Tensor, dict],
+       unconditional_guidance_scale: float,
+       step_index: Optional[int] = None,
+       total_step_count: Optional[int] = None,
+   ):
        """
        :param x: current latents
        :param sigma: aka t, passed to the internal model to control how much denoising will occur
@@ -140,33 +179,55 @@ class InvokeAIDiffuserComponent:
        :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
        """

        cross_attention_control_types_to_do = []
        context: Context = self.cross_attention_control_context
        if self.cross_attention_control_context is not None:
-           percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-           cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
+           percent_through = self.calculate_percent_through(
+               sigma, step_index, total_step_count
+           )
+           cross_attention_control_types_to_do = (
+               context.get_active_cross_attention_control_types_for_step(
+                   percent_through
+               )
+           )

-       wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
+       wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
        wants_hybrid_conditioning = isinstance(conditioning, dict)

        if wants_hybrid_conditioning:
-           unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
-                                                                                      conditioning)
+           unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
+               x, sigma, unconditioning, conditioning
+           )
        elif wants_cross_attention_control:
-           unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
-                                                                                                          unconditioning,
-                                                                                                          conditioning,
-                                                                                                          cross_attention_control_types_to_do)
+           (
+               unconditioned_next_x,
+               conditioned_next_x,
+           ) = self._apply_cross_attention_controlled_conditioning(
+               x,
+               sigma,
+               unconditioning,
+               conditioning,
+               cross_attention_control_types_to_do,
+           )
        elif self.sequential_guidance:
-           unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
-               x, sigma, unconditioning, conditioning)
+           (
+               unconditioned_next_x,
+               conditioned_next_x,
+           ) = self._apply_standard_conditioning_sequentially(
+               x, sigma, unconditioning, conditioning
+           )

        else:
-           unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
-               x, sigma, unconditioning, conditioning)
+           (
+               unconditioned_next_x,
+               conditioned_next_x,
+           ) = self._apply_standard_conditioning(
+               x, sigma, unconditioning, conditioning
+           )

-       combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
+       combined_next_x = self._combine(
+           unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
+       )

        return combined_next_x

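The `_combine()` call that ends `do_diffusion_step()` reduces to the usual classifier-free-guidance formula; a minimal sketch (the actual method is defined elsewhere in this file):

# Sketch of CFG combination: move from the unconditioned prediction
# toward the conditioned one, scaled by the guidance strength.
def combine_sketch(unconditioned_next_x, conditioned_next_x, guidance_scale):
    delta = conditioned_next_x - unconditioned_next_x
    return unconditioned_next_x + guidance_scale * delta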
@@ -176,24 +237,33 @@ class InvokeAIDiffuserComponent:
        latents: torch.Tensor,
        sigma,
        step_index,
-       total_step_count
+       total_step_count,
    ) -> torch.Tensor:
        if postprocessing_settings is not None:
-           percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-           latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
-           latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
+           percent_through = self.calculate_percent_through(
+               sigma, step_index, total_step_count
+           )
+           latents = self.apply_threshold(
+               postprocessing_settings, latents, percent_through
+           )
+           latents = self.apply_symmetry(
+               postprocessing_settings, latents, percent_through
+           )
        return latents

    def calculate_percent_through(self, sigma, step_index, total_step_count):
        if step_index is not None and total_step_count is not None:
            # 🧨diffusers codepath
-           percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
+           percent_through = (
+               step_index / total_step_count
+           )  # will never reach 1.0 - this is deliberate
        else:
            # legacy compvis codepath
            # TODO remove when compvis codepath support is dropped
            if step_index is None and sigma is None:
                raise ValueError(
-                   f"Either step_index or sigma is required when doing cross attention control, but both are None.")
+                   f"Either step_index or sigma is required when doing cross attention control, but both are None."
+               )
            percent_through = self.estimate_percent_through(step_index, sigma)
        return percent_through

@@ -204,24 +274,30 @@ class InvokeAIDiffuserComponent:
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = torch.cat([unconditioning, conditioning])
-       both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
+       both_results = self.model_forward_callback(
+           x_twice, sigma_twice, both_conditionings
+       )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
-       if conditioned_next_x.device.type == 'mps':
+       if conditioned_next_x.device.type == "mps":
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

-   def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+   def _apply_standard_conditioning_sequentially(
+       self,
+       x: torch.Tensor,
+       sigma,
+       unconditioning: torch.Tensor,
+       conditioning: torch.Tensor,
+   ):
        # low-memory sequential path
        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
-       if conditioned_next_x.device.type == 'mps':
+       if conditioned_next_x.device.type == "mps":
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
        assert isinstance(conditioning, dict)
        assert isinstance(unconditioning, dict)
@@ -236,48 +312,80 @@ class InvokeAIDiffuserComponent:
                ]
            else:
                both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
-       unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
+       unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
+           x_twice, sigma_twice, both_conditionings
+       ).chunk(2)
        return unconditioned_next_x, conditioned_next_x

-   def _apply_cross_attention_controlled_conditioning(self,
-                                                      x: torch.Tensor,
-                                                      sigma,
-                                                      unconditioning,
-                                                      conditioning,
-                                                      cross_attention_control_types_to_do):
+   def _apply_cross_attention_controlled_conditioning(
+       self,
+       x: torch.Tensor,
+       sigma,
+       unconditioning,
+       conditioning,
+       cross_attention_control_types_to_do,
+   ):
        if self.is_running_diffusers:
-           return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
-                                                                                 conditioning,
-                                                                                 cross_attention_control_types_to_do)
+           return self._apply_cross_attention_controlled_conditioning__diffusers(
+               x,
+               sigma,
+               unconditioning,
+               conditioning,
+               cross_attention_control_types_to_do,
+           )
        else:
-           return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
-                                                                               cross_attention_control_types_to_do)
+           return self._apply_cross_attention_controlled_conditioning__compvis(
+               x,
+               sigma,
+               unconditioning,
+               conditioning,
+               cross_attention_control_types_to_do,
+           )

-   def _apply_cross_attention_controlled_conditioning__diffusers(self,
-                                                                 x: torch.Tensor,
-                                                                 sigma,
-                                                                 unconditioning,
-                                                                 conditioning,
-                                                                 cross_attention_control_types_to_do):
+   def _apply_cross_attention_controlled_conditioning__diffusers(
+       self,
+       x: torch.Tensor,
+       sigma,
+       unconditioning,
+       conditioning,
+       cross_attention_control_types_to_do,
+   ):
        context: Context = self.cross_attention_control_context

-       cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
-                                                           index_map=context.cross_attention_index_map,
-                                                           mask=context.cross_attention_mask,
-                                                           cross_attention_types_to_do=[])
+       cross_attn_processor_context = SwapCrossAttnContext(
+           modified_text_embeddings=context.arguments.edited_conditioning,
+           index_map=context.cross_attention_index_map,
+           mask=context.cross_attention_mask,
+           cross_attention_types_to_do=[],
+       )
        # no cross attention for unconditioning (negative prompt)
-       unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
-                                                          {"swap_cross_attn_context": cross_attn_processor_context})
+       unconditioned_next_x = self.model_forward_callback(
+           x,
+           sigma,
+           unconditioning,
+           {"swap_cross_attn_context": cross_attn_processor_context},
+       )

        # do requested cross attention types for conditioning (positive prompt)
-       cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-       conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
-                                                        {"swap_cross_attn_context": cross_attn_processor_context})
+       cross_attn_processor_context.cross_attention_types_to_do = (
+           cross_attention_control_types_to_do
+       )
+       conditioned_next_x = self.model_forward_callback(
+           x,
+           sigma,
+           conditioning,
+           {"swap_cross_attn_context": cross_attn_processor_context},
+       )
        return unconditioned_next_x, conditioned_next_x

-   def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+   def _apply_cross_attention_controlled_conditioning__compvis(
+       self,
+       x: torch.Tensor,
+       sigma,
+       unconditioning,
+       conditioning,
+       cross_attention_control_types_to_do,
+   ):
        # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
        # slower non-batched path (20% slower on mac MPS)
        # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@@ -287,24 +395,28 @@ class InvokeAIDiffuserComponent:
        # representing batched uncond + cond, but then when it comes to applying the saved attention, the
        # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
        # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-       context:Context = self.cross_attention_control_context
+       context: Context = self.cross_attention_control_context

        try:
            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

            # process x using the original prompt, saving the attention maps
-           #print("saving attention maps for", cross_attention_control_types_to_do)
+           # print("saving attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_save_attention_maps(ca_type)
            _ = self.model_forward_callback(x, sigma, conditioning)
            context.clear_requests(cleanup=False)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-           #print("applying saved attention maps for", cross_attention_control_types_to_do)
+           # print("applying saved attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_apply_saved_attention_maps(ca_type)
-           edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
-           conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
+           edited_conditioning = (
+               self.conditioning.cross_attention_control_args.edited_conditioning
+           )
+           conditioned_next_x = self.model_forward_callback(
+               x, sigma, edited_conditioning
+           )
            context.clear_requests(cleanup=True)

        except:
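Condensed, the compvis path above is a two-pass prompt2prompt step: pass one runs the original conditioning while saving attention maps, pass two replays those maps while denoising with the edited conditioning. A sketch of that control flow (hypothetical `model_forward` and `context` parameters mirroring the calls above):

# Sketch: save-then-apply attention map flow for cross attention control.
def prompt2prompt_step(model_forward, context, x, sigma, cond, edited_cond, ca_types):
    for ca_type in ca_types:
        context.request_save_attention_maps(ca_type)
    _ = model_forward(x, sigma, cond)  # first pass: record attention maps
    context.clear_requests(cleanup=False)
    for ca_type in ca_types:
        context.request_apply_saved_attention_maps(ca_type)
    out = model_forward(x, sigma, edited_cond)  # second pass: replay them
    context.clear_requests(cleanup=True)
    return out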
@@ -323,17 +435,21 @@ class InvokeAIDiffuserComponent:
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
-       percent_through: float
+       percent_through: float,
    ) -> torch.Tensor:

-       if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
+       if (
+           postprocessing_settings.threshold is None
+           or postprocessing_settings.threshold == 0.0
+       ):
            return latents

        threshold = postprocessing_settings.threshold
        warmup = postprocessing_settings.warmup

        if percent_through < warmup:
-           current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
+           current_threshold = threshold + threshold * 5 * (
+               1 - (percent_through / warmup)
+           )
        else:
            current_threshold = threshold

@@ -347,10 +463,14 @@ class InvokeAIDiffuserComponent:

        if self.debug_thresholding:
            std, mean = [i.item() for i in torch.std_mean(latents)]
-           outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-           print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
-                 f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
-                 f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
+           outside = torch.count_nonzero(
+               (latents < -current_threshold) | (latents > current_threshold)
+           )
+           print(
+               f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
+               f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
+               f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
+           )

        if maxval < current_threshold and minval > -current_threshold:
            return latents
@@ -363,17 +483,23 @@ class InvokeAIDiffuserComponent:
            latents = torch.clone(latents)
            maxval = np.clip(maxval * scale, 1, current_threshold)
            num_altered += torch.count_nonzero(latents > maxval)
-           latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
+           latents[latents > maxval] = (
+               torch.rand_like(latents[latents > maxval]) * maxval
+           )

        if minval < -current_threshold:
            latents = torch.clone(latents)
            minval = np.clip(minval * scale, -current_threshold, -1)
            num_altered += torch.count_nonzero(latents < minval)
-           latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
+           latents[latents < minval] = (
+               torch.rand_like(latents[latents < minval]) * minval
+           )

        if self.debug_thresholding:
-           print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
-                 f" | {num_altered / latents.numel() * 100:.2f}% values altered")
+           print(
+               f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
+               f" | {num_altered / latents.numel() * 100:.2f}% values altered"
+           )

        return latents

@@ -381,9 +507,8 @@ class InvokeAIDiffuserComponent:
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
-       percent_through: float
+       percent_through: float,
    ) -> torch.Tensor:
-
        # Reset our last percent through if this is our first step.
        if percent_through == 0.0:
            self.last_percent_through = 0.0
@@ -393,36 +518,52 @@ class InvokeAIDiffuserComponent:

        # Check for out of bounds
        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
-       if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
+       if h_symmetry_time_pct is not None and (
+           h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
+       ):
            h_symmetry_time_pct = None

        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
-       if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
+       if v_symmetry_time_pct is not None and (
+           v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
+       ):
            v_symmetry_time_pct = None

        dev = latents.device.type

-       latents.to(device='cpu')
+       latents.to(device="cpu")

        if (
-           h_symmetry_time_pct != None and
-           self.last_percent_through < h_symmetry_time_pct and
-           percent_through >= h_symmetry_time_pct
+           h_symmetry_time_pct != None
+           and self.last_percent_through < h_symmetry_time_pct
+           and percent_through >= h_symmetry_time_pct
        ):
            # Horizontal symmetry occurs on the 3rd dimension of the latent
            width = latents.shape[3]
            x_flipped = torch.flip(latents, dims=[3])
-           latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
+           latents = torch.cat(
+               [
+                   latents[:, :, :, 0 : int(width / 2)],
+                   x_flipped[:, :, :, int(width / 2) : int(width)],
+               ],
+               dim=3,
+           )

        if (
-           v_symmetry_time_pct != None and
-           self.last_percent_through < v_symmetry_time_pct and
-           percent_through >= v_symmetry_time_pct
+           v_symmetry_time_pct != None
+           and self.last_percent_through < v_symmetry_time_pct
+           and percent_through >= v_symmetry_time_pct
        ):
            # Vertical symmetry occurs on the 2nd dimension of the latent
            height = latents.shape[2]
            y_flipped = torch.flip(latents, dims=[2])
-           latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
+           latents = torch.cat(
+               [
+                   latents[:, :, 0 : int(height / 2)],
+                   y_flipped[:, :, int(height / 2) : int(height)],
+               ],
+               dim=2,
+           )

        self.last_percent_through = percent_through
        return latents.to(device=dev)
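The symmetry splice above keeps one half of the latent and replaces the other half with its mirror image once `percent_through` crosses the configured threshold. A standalone sketch of the horizontal case:

# Sketch: mirror the left half of a BCHW latent onto the right half.
import torch

def mirror_left_half(latents):
    width = latents.shape[3]
    flipped = torch.flip(latents, dims=[3])
    return torch.cat(
        [latents[:, :, :, : width // 2], flipped[:, :, :, width // 2 :]], dim=3
    )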
@@ -430,7 +571,9 @@ class InvokeAIDiffuserComponent:
    def estimate_percent_through(self, step_index, sigma):
        if step_index is not None and self.cross_attention_control_context is not None:
            # percent_through will never reach 1.0 (but this is intended)
-           return float(step_index) / float(self.cross_attention_control_context.step_count)
+           return float(step_index) / float(
+               self.cross_attention_control_context.step_count
+           )
        # find the best possible index of the current sigma in the sigma sequence
        smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
        sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@@ -439,33 +582,38 @@ class InvokeAIDiffuserComponent:
        return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
        # print('estimated percent_through', percent_through, 'from sigma', sigma.item())


    # todo: make this work
    @classmethod
-   def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+   def apply_conjunction(
+       cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
+   ):
        x_in = torch.cat([x] * 2)
-       t_in = torch.cat([t] * 2) # aka sigmas
+       t_in = torch.cat([t] * 2)  # aka sigmas

        deltas = None
        uncond_latents = None
-       weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+       weighted_cond_list = (
+           c_or_weighted_c_list
+           if type(c_or_weighted_c_list) is list
+           else [(c_or_weighted_c_list, 1)]
+       )

        # below is fugly omg
        num_actual_conditionings = len(c_or_weighted_c_list)
-       conditionings = [uc] + [c for c,weight in weighted_cond_list]
-       weights = [1] + [weight for c,weight in weighted_cond_list]
-       chunk_count = ceil(len(conditionings)/2)
+       conditionings = [uc] + [c for c, weight in weighted_cond_list]
+       weights = [1] + [weight for c, weight in weighted_cond_list]
+       chunk_count = ceil(len(conditionings) / 2)
        deltas = None
        for chunk_index in range(chunk_count):
-           offset = chunk_index*2
-           chunk_size = min(2, len(conditionings)-offset)
+           offset = chunk_index * 2
+           chunk_size = min(2, len(conditionings) - offset)

            if chunk_size == 1:
                c_in = conditionings[offset]
                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                latents_b = None
            else:
-               c_in = torch.cat(conditionings[offset:offset+2])
+               c_in = torch.cat(conditionings[offset : offset + 2])
                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
@@ -478,11 +626,15 @@ class InvokeAIDiffuserComponent:
            deltas = torch.cat((deltas, latents_b - uncond_latents))

        # merge the weighted deltas together into a single merged delta
-       per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+       per_delta_weights = torch.tensor(
+           weights[1:], dtype=deltas.dtype, device=deltas.device
+       )
        normalize = False
        if normalize:
            per_delta_weights /= torch.sum(per_delta_weights)
-       reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+       reshaped_weights = per_delta_weights.reshape(
+           per_delta_weights.shape + (1, 1, 1)
+       )
        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
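The merge step above weights each conditioning's delta from the unconditioned latents and sums them into one guidance direction. A minimal sketch of that reduction (assuming `deltas` is stacked along dim 0, one entry per conditioning):

# Sketch: weighted sum of per-conditioning deltas into a single delta.
import torch

def merge_deltas(deltas, weights):
    w = torch.tensor(weights, dtype=deltas.dtype, device=deltas.device)
    return torch.sum(deltas * w.reshape(-1, 1, 1, 1), dim=0, keepdim=True)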
invokeai/backend/stable_diffusion/diffusionmodules/model.py (new file, 1081 lines)
File diff suppressed because it is too large
@ -1,23 +1,22 @@
import math
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.util import (
avg_pool_nd,
checkpoint,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
zero_module,
)
from ldm.modules.attention import SpatialTransformer


# dummy replace
@ -100,9 +99,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""

def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@ -117,10 +114,10 @@ class Upsample(nn.Module):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
@ -151,9 +148,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""

def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@ -237,9 +232,7 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels
if use_scale_shift_norm
else self.out_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
@ -247,9 +240,7 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1
)
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)

@ -260,9 +251,7 @@ class ResBlock(TimestepBlock):
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1
)
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

def forward(self, x, emb):
"""
@ -320,7 +309,7 @@ class AttentionBlock(nn.Module):
else:
assert (
channels % num_head_channels == 0
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
@ -337,7 +326,7 @@ class AttentionBlock(nn.Module):
def forward(self, x):
return checkpoint(
self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch

def _forward(self, x):
@ -387,15 +376,13 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1
)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
'bct,bcs->bts', q * scale, k * scale
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum('bts,bcs->bct', weight, v)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)

@staticmethod
@ -424,14 +411,12 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
'bct,bcs->bts',
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum(
'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)

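Both attention hunks above keep the float16-stability trick: scaling q and k each by 1/sqrt(sqrt(ch)) before the einsum equals dividing the raw scores by sqrt(ch), while keeping intermediates smaller. A small self-contained check of that identity (toy sizes assumed):

import math
import torch as th

bs_heads, ch, length = 2, 64, 16
q = th.randn(bs_heads, ch, length)
k = th.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
# pre-scaled form used in the diff (more stable in f16)
w1 = th.einsum("bct,bcs->bts", q * scale, k * scale)
# algebraically equivalent post-division form
w2 = th.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
assert th.allclose(w1, w2, atol=1e-5)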
@staticmethod
@ -500,12 +485,12 @@ class UNetModel(nn.Module):
if use_spatial_transformer:
assert (
context_dim is not None
), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

if context_dim is not None:
assert (
use_spatial_transformer
), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
from omegaconf.listconfig import ListConfig

if type(context_dim) == ListConfig:
@ -517,12 +502,12 @@ class UNetModel(nn.Module):
if num_heads == -1:
assert (
num_head_channels != -1
), 'Either num_heads or num_head_channels has to be set'
), "Either num_heads or num_head_channels has to be set"

if num_head_channels == -1:
assert (
num_heads != -1
), 'Either num_heads or num_head_channels has to be set'
), "Either num_heads or num_head_channels has to be set"

self.image_size = image_size
self.in_channels = in_channels
@ -641,11 +626,7 @@ class UNetModel(nn.Module):
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
@ -741,9 +722,7 @@ class UNetModel(nn.Module):
up=True,
)
if resblock_updown
else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -752,9 +731,7 @@ class UNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)
),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
@ -790,11 +767,9 @@ class UNetModel(nn.Module):
"""
assert (y is not None) == (
self.num_classes is not None
), 'must specify y if and only if the model is class-conditional'
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

if self.num_classes is not None:
@ -842,7 +817,7 @@ class EncoderUNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool='adaptive',
pool="adaptive",
*args,
**kwargs,
):
@ -962,7 +937,7 @@ class EncoderUNetModel(nn.Module):
)
self._feature_size += ch
self.pool = pool
if pool == 'adaptive':
if pool == "adaptive":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
@ -970,7 +945,7 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == 'attention':
elif pool == "attention":
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
@ -979,13 +954,13 @@ class EncoderUNetModel(nn.Module):
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == 'spatial':
elif pool == "spatial":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == 'spatial_v2':
elif pool == "spatial_v2":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
@ -993,7 +968,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f'Unexpected {pool} pooling')
raise NotImplementedError(f"Unexpected {pool} pooling")

def convert_to_fp16(self):
"""
@ -1016,18 +991,16 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith('spatial'):
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith('spatial'):
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
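In EncoderUNetModel.forward above, the "spatial" pooling variants collect a channel-mean after each input block and again after the middle block, then concatenate those means into one feature vector for self.out. A sketch of that readout pattern (block outputs faked with random tensors):

import torch as th

# stand-ins for feature maps produced by successive blocks: (N, C, H, W)
feature_maps = [th.randn(2, 32, 16, 16), th.randn(2, 64, 8, 8)]

results = []
for h in feature_maps:
    # spatial mean over H and W -> (N, C)
    results.append(h.mean(dim=(2, 3)))

# concatenate the per-block means along the channel axis -> (N, 32 + 64)
readout = th.cat(results, dim=-1)
print(readout.shape)  # torch.Size([2, 96])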
@ -8,20 +8,21 @@
# thanks!


import os
import math
import os

import numpy as np
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from ldm.util import instantiate_from_config
from ...util.util import instantiate_from_config


def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == 'linear':
if schedule == "linear":
betas = (
torch.linspace(
linear_start**0.5,
@ -32,10 +33,9 @@ def make_beta_schedule(
** 2
)

elif schedule == 'cosine':
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
@ -43,15 +43,13 @@ def make_beta_schedule(
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)

elif schedule == 'sqrt_linear':
elif schedule == "sqrt_linear":
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == 'sqrt':
elif schedule == "sqrt":
betas = (
torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
** 0.5
)
else:
@ -62,19 +60,14 @@
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == 'uniform':
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
if c < 1:
c = 1
c = 1
ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int)
elif ddim_discr_method == 'quad':
elif ddim_discr_method == "quad":
ddim_timesteps = (
(
np.linspace(
0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps
)
)
** 2
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
).astype(int)
else:
raise NotImplementedError(
@ -87,18 +80,14 @@ def make_ddim_timesteps(
# steps_out = ddim_timesteps

if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
print(f"Selected timesteps for ddim sampler: {steps_out}")
return steps_out

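make_ddim_timesteps above subsamples the DDPM schedule either uniformly or quadratically; the commented-out steps_out = ddim_timesteps suggests the line hidden by the hunk boundary shifts every index by one (steps_out = ddim_timesteps + 1 in the upstream source). A quick illustration of the two spacings, with assumed step counts:

import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 10

# uniform: every c-th DDPM step
c = max(1, num_ddpm_timesteps // num_ddim_timesteps)
uniform = (np.arange(0, num_ddim_timesteps) * c).astype(int)

# quad: denser sampling near t=0, sparser near t=T
quad = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)

print(uniform)  # [  0 100 200 300 400 500 600 700 800 900]
print(quad)     # roughly [0 9 39 88 158 246 355 483 632 800]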
def make_ddim_sampling_parameters(
alphacums, ddim_timesteps, eta, verbose=True
):
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray(
[alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
)
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
@ -106,11 +95,11 @@
)
if verbose:
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev

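For reference, the sigmas above follow Eq. (16) of https://arxiv.org/abs/2010.02502; the part of the expression hidden by the hunk boundary is filled in below from that equation, so treat this as a reconstruction rather than a quote of the diff:

import numpy as np

def ddim_sigmas(alphacums: np.ndarray, ddim_timesteps: np.ndarray, eta: float):
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
    # eta = 0 gives deterministic DDIM; larger eta moves toward DDPM-like variance
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    return sigmas, alphas, alphas_prev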
@ -150,9 +139,7 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if (
False
): # disabled checkpointing to allow requires_grad = False for main model
if False: # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
@ -172,9 +159,7 @@ class CheckpointFunction(torch.autograd.Function):

@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
@ -216,7 +201,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding

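timestep_embedding, patched above, emits standard sinusoidal features; the repeat(...) branch is just the degenerate repeat_only path. A self-contained sketch of the sinusoidal branch, reconstructed from the surrounding context (an approximation, not the exact diff):

import math
import torch as th

def sinusoidal_timestep_embedding(timesteps: th.Tensor, dim: int, max_period: int = 10000) -> th.Tensor:
    # timesteps: (N,) -> (N, dim)
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(half, dtype=th.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        # odd dims get a zero column, matching the hunk above
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding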
@ -269,7 +254,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
@ -289,21 +274,19 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config
)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}


def noise_like(shape, device, repeat=False):
@ -1,5 +1,5 @@
import torch
import numpy as np
import torch


class AbstractDistribution:
@ -64,9 +64,7 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi
+ self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)

@ -86,7 +84,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, 'at least one argument must be a Tensor'
assert tensor is not None, "at least one argument must be a Tensor"

# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
@ -6,12 +6,12 @@ class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
raise ValueError("Decay must be between 0 and 1")

self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
'num_updates',
"num_updates",
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
@ -20,7 +20,7 @@ class LitEma(nn.Module):
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)

@ -31,9 +31,7 @@ class LitEma(nn.Module):

if self.num_updates >= 0:
self.num_updates += 1
decay = min(
self.decay, (1 + self.num_updates) / (10 + self.num_updates)
)
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

one_minus_decay = 1.0 - decay

@ -44,9 +42,7 @@ class LitEma(nn.Module):
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key]
)
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
@ -58,9 +54,7 @@ class LitEma(nn.Module):
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data
)
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
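The LitEma update above is a plain exponential moving average with warm-up: the effective decay ramps from (1 + n) / (10 + n) toward the configured ceiling as num_updates grows. In bare arithmetic (a sketch, not the class itself):

import torch

decay_ceiling = 0.9999
shadow = torch.zeros(3)                 # EMA copy of one parameter
param = torch.tensor([1.0, 2.0, 3.0])   # live parameter

for num_updates in range(1, 4):
    # warm-up: (1 + n) / (10 + n) stays below the ceiling for small n
    decay = min(decay_ceiling, (1 + num_updates) / (10 + num_updates))
    one_minus_decay = 1.0 - decay
    # shadow <- shadow - (1 - decay) * (shadow - param)
    shadow.sub_(one_minus_decay * (shadow - param))

print(shadow)  # drifting toward param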
@ -7,14 +7,14 @@ import kornia
import torch
import torch.nn as nn
from einops import repeat
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.invoke.devices import choose_torch_device
from ldm.invoke.globals import global_cache_dir
from ldm.modules.x_transformer import (
from ...util import choose_torch_device
from ..globals import global_cache_dir
from ..x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
Encoder,
TransformerWrapper,
) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
)


def _expand_mask(mask, dtype, tgt_len=None):
@ -24,9 +24,7 @@ def _expand_mask(mask, dtype, tgt_len=None):
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len

expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

inverted_mask = 1.0 - expanded_mask

@ -54,7 +52,7 @@ class AbstractEncoder(nn.Module):


class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
def __init__(self, embed_dim, n_classes=1000, key="class"):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
@ -99,20 +97,14 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

def __init__(
self, device=choose_torch_device(), vq_interface=True, max_length=77
):
def __init__(self, device=choose_torch_device(), vq_interface=True, max_length=77):
super().__init__()
from transformers import (
BertTokenizerFast,
)
from transformers import BertTokenizerFast

cache = global_cache_dir('hub')
cache = global_cache_dir("hub")
try:
self.tokenizer = BertTokenizerFast.from_pretrained(
'bert-base-uncased',
cache_dir=cache,
local_files_only=True
"bert-base-uncased", cache_dir=cache, local_files_only=True
)
except OSError:
raise SystemExit(
@ -129,10 +121,10 @@ class BERTTokenizer(AbstractEncoder):
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding['input_ids'].to(self.device)
tokens = batch_encoding["input_ids"].to(self.device)
return tokens

@torch.no_grad()
@ -150,21 +142,19 @@ class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""

def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device=choose_torch_device(),
use_tokenizer=True,
embedding_dropout=0.0,
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device=choose_torch_device(),
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len
)
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
@ -192,7 +182,7 @@ class SpatialRescaler(nn.Module):
def __init__(
self,
n_stages=1,
method='bilinear',
method="bilinear",
multiplier=0.5,
in_channels=3,
out_channels=None,
@ -202,25 +192,21 @@ class SpatialRescaler(nn.Module):
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in [
'nearest',
'linear',
'bilinear',
'trilinear',
'bicubic',
'area',
"nearest",
"linear",
"bilinear",
"trilinear",
"bicubic",
"area",
]
self.multiplier = multiplier
self.interpolator = partial(
torch.nn.functional.interpolate, mode=method
)
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

def forward(self, x):
for stage in range(self.n_stages):
@ -236,27 +222,24 @@ class SpatialRescaler(nn.Module):

class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""

tokenizer: CLIPTokenizer
transformer: CLIPTextModel

def __init__(
self,
version:str='openai/clip-vit-large-patch14',
max_length:int=77,
tokenizer:Optional[CLIPTokenizer]=None,
transformer:Optional[CLIPTextModel]=None,
version: str = "openai/clip-vit-large-patch14",
max_length: int = 77,
tokenizer: Optional[CLIPTokenizer] = None,
transformer: Optional[CLIPTextModel] = None,
):
super().__init__()
cache = global_cache_dir('hub')
cache = global_cache_dir("hub")
self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
version,
cache_dir=cache,
local_files_only=True
version, cache_dir=cache, local_files_only=True
)
self.transformer = transformer or CLIPTextModel.from_pretrained(
version,
cache_dir=cache,
local_files_only=True
version, cache_dir=cache, local_files_only=True
)
self.max_length = max_length
self.freeze()
@ -268,7 +251,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
inputs_embeds=None,
embedding_manager=None,
) -> torch.Tensor:

seq_length = (
input_ids.shape[-1]
if input_ids is not None
@ -289,8 +271,8 @@ class FrozenCLIPEmbedder(AbstractEncoder):

return embeddings

self.transformer.text_model.embeddings.forward = (
embedding_forward.__get__(self.transformer.text_model.embeddings)
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
self.transformer.text_model.embeddings
)

def encoder_forward(
@ -313,9 +295,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
return_dict if return_dict is not None else self.config.use_return_dict
)

encoder_states = () if output_hidden_states else None
@ -368,13 +348,11 @@ class FrozenCLIPEmbedder(AbstractEncoder):
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
return_dict if return_dict is not None else self.config.use_return_dict
)

if input_ids is None:
raise ValueError('You have to specify either input_ids')
raise ValueError("You have to specify either input_ids")

input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
@ -395,9 +373,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(
attention_mask, hidden_states.dtype
)
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
@ -436,9 +412,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
embedding_manager=embedding_manager,
)

self.transformer.forward = transformer_forward.__get__(
self.transformer
)
self.transformer.forward = transformer_forward.__get__(self.transformer)

def freeze(self):
self.transformer = self.transformer.eval()
@ -452,10 +426,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding['input_ids'].to(self.device)
tokens = batch_encoding["input_ids"].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs)

return z
@ -471,25 +445,25 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def device(self, device):
self.transformer.to(device=device)

class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):

class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
fragment_weights_key = "fragment_weights"
return_tokens_key = "return_tokens"

def set_textual_inversion_manager(self, manager): #TextualInversionManager):
def set_textual_inversion_manager(self, manager): # TextualInversionManager):
# TODO all of the weighting and expanding stuff needs be moved out of this class
self.textual_inversion_manager = manager

def forward(self, text: list, **kwargs):
# TODO all of the weighting and expanding stuff needs be moved out of this class
'''
"""

:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
weights shall be applied.
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
'''
"""
if self.fragment_weights_key not in kwargs:
# fallback to base class implementation
return super().forward(text, **kwargs)
@ -507,7 +481,6 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
batch_z = None
batch_tokens = None
for fragments, weights in zip(text, fragment_weights):

# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
# applying a multiplier to the CFG scale on a per-token basis).
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
@ -520,7 +493,9 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):

# handle weights >=1
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
base_embedding = self.build_weighted_embedding_tensor(
tokens, per_token_weights, **kwargs
)

# this is our starting point
embeddings = base_embedding.unsqueeze(0)
@ -536,12 +511,18 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
for index, fragment_weight in enumerate(weights):
if fragment_weight < 1:
fragments_without_this = fragments[:index] + fragments[index+1:]
weights_without_this = weights[:index] + weights[index+1:]
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
fragments_without_this = fragments[:index] + fragments[index + 1 :]
weights_without_this = weights[:index] + weights[index + 1 :]
tokens, per_token_weights = self.get_tokens_and_weights(
fragments_without_this, weights_without_this
)
embedding_without_this = self.build_weighted_embedding_tensor(
tokens, per_token_weights, **kwargs
)

embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
embeddings = torch.cat(
(embeddings, embedding_without_this.unsqueeze(0)), dim=1
)
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
# therefore:
@ -554,29 +535,43 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
# inf at PI/2
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
epsilon = 1e-9
fragment_weight = max(epsilon, fragment_weight) # inf is bad
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
fragment_weight = max(epsilon, fragment_weight) # inf is bad
embedding_lerp_weight = math.tan(
(1.0 - fragment_weight) * math.pi / 2
)
# todo handle negative weight?

per_embedding_weights.append(embedding_lerp_weight)

lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
lerped_embeddings = self.apply_embedding_weights(
embeddings, per_embedding_weights, normalize=True
).squeeze(0)

#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
# print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")

# append to batch
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
batch_z = (
lerped_embeddings.unsqueeze(0)
if batch_z is None
else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
)
batch_tokens = (
tokens.unsqueeze(0)
if batch_tokens is None
else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
)

# should have shape (B, 77, 768)
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
# print(f"assembled all tokens into tensor of shape {batch_z.shape}")

if should_return_tokens:
return batch_z, batch_tokens
else:
return batch_z

def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
def get_token_ids(
self, fragments: list[str], include_start_and_end_markers: bool = True
) -> list[list[int]]:
"""
Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
`[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
@ -594,58 +589,81 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=False,
padding='do_not_pad',
padding="do_not_pad",
return_tensors=None, # just give me lists of ints
)['input_ids']
)["input_ids"]

result = []
for token_ids in token_ids_list:
# trim eos/bos
token_ids = token_ids[1:-1]
# pad for textual inversions with vector length >1
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(
token_ids
)
# restrict length to max_length-2 (leaving room for bos/eos)
token_ids = token_ids[0:self.max_length - 2]
token_ids = token_ids[0 : self.max_length - 2]
# add back eos/bos if requested
if include_start_and_end_markers:
token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
token_ids = (
[self.tokenizer.bos_token_id]
+ token_ids
+ [self.tokenizer.eos_token_id]
)

result.append(token_ids)

return result


@classmethod
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
def apply_embedding_weights(
self,
embeddings: torch.Tensor,
per_embedding_weights: list[float],
normalize: bool,
) -> torch.Tensor:
per_embedding_weights = torch.tensor(
per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device
)
if normalize:
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
per_embedding_weights = per_embedding_weights / torch.sum(
per_embedding_weights
)
reshaped_weights = per_embedding_weights.reshape(
per_embedding_weights.shape
+ (
1,
1,
)
)
# reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
return torch.sum(embeddings * reshaped_weights, dim=1)
# lerped embeddings has shape (77, 768)

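apply_embedding_weights, reformatted above, is a weighted (optionally normalized) sum over a stack of whole-prompt embeddings. A usage-style sketch with made-up shapes:

import torch

# two candidate embeddings competing for one prompt slot: (1, 2, 77, 768)
embeddings = torch.randn(1, 2, 77, 768)
per_embedding_weights = torch.tensor([0.75, 0.25])

weights = per_embedding_weights / per_embedding_weights.sum()  # the normalize=True path
reshaped = weights.reshape(weights.shape + (1, 1))             # (2, 1, 1), broadcasts over (77, 768)
blended = torch.sum(embeddings * reshaped, dim=1)              # -> (1, 77, 768)
print(blended.shape)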
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
'''
def get_tokens_and_weights(
self, fragments: list[str], weights: list[float]
) -> (torch.Tensor, torch.Tensor):
"""

:param fragments:
:param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
:return:
'''
"""
# empty is meaningful
if len(fragments) == 0 and len(weights) == 0:
fragments = ['']
fragments = [""]
weights = [1]
per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
per_fragment_token_ids = self.get_token_ids(
fragments, include_start_and_end_markers=False
)
all_token_ids = []
per_token_weights = []
#print("all fragments:", fragments, weights)
# print("all fragments:", fragments, weights)
for index, fragment in enumerate(per_fragment_token_ids):
weight = float(weights[index])
#print("processing fragment", fragment, weight)
# print("processing fragment", fragment, weight)
this_fragment_token_ids = per_fragment_token_ids[index]
#print("fragment", fragment, "processed to", this_fragment_token_ids)
# print("fragment", fragment, "processed to", this_fragment_token_ids)
# append
all_token_ids += this_fragment_token_ids
# fill out weights tensor with one float per token
@ -654,60 +672,85 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
# leave room for bos/eos
max_token_count_without_bos_eos_markers = self.max_length - 2
if len(all_token_ids) > max_token_count_without_bos_eos_markers:
excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
excess_token_count = (
len(all_token_ids) - max_token_count_without_bos_eos_markers
)
# TODO build nice description string of how the truncation was applied
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
print(
f">> Prompt is {excess_token_count} token(s) too long and has been truncated"
)
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
per_token_weights = per_token_weights[
0:max_token_count_without_bos_eos_markers
]

# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
# (77 = self.max_length)
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
all_token_ids = (
[self.tokenizer.bos_token_id]
+ all_token_ids
+ [self.tokenizer.eos_token_id]
)
per_token_weights = [1.0] + per_token_weights + [1.0]
pad_length = self.max_length - len(all_token_ids)
all_token_ids += [self.tokenizer.pad_token_id] * pad_length
per_token_weights += [1.0] * pad_length

all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(
self.device
)
per_token_weights_tensor = torch.tensor(
per_token_weights, dtype=torch.float32
).to(self.device)
# print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
return all_token_ids_tensor, per_token_weights_tensor

def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
'''
def build_weighted_embedding_tensor(
self,
token_ids: torch.Tensor,
per_token_weights: torch.Tensor,
weight_delta_from_empty=True,
**kwargs,
) -> torch.Tensor:
"""
Build a tensor representing the passed-in tokens, each of which has a weight.
:param token_ids: A tensor of shape (77) containing token ids (integers)
:param per_token_weights: A tensor of shape (77) containing weights (floats)
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
:param kwargs: passed on to self.transformer()
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
'''
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
"""
# print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
if token_ids.shape != torch.Size([self.max_length]):
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
raise ValueError(
f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]"
)

z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs)

batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
batch_weights_expanded = per_token_weights.reshape(
per_token_weights.shape + (1,)
).expand(z.shape)

if weight_delta_from_empty:
empty_tokens = self.tokenizer([''] * z.shape[0],
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='pt'
)['input_ids'].to(self.device)
empty_tokens = self.tokenizer(
[""] * z.shape[0],
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)["input_ids"].to(self.device)
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)

#weighted_z_delta_from_empty = (weighted_z-empty_z)
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
# weighted_z_delta_from_empty = (weighted_z-empty_z)
# print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )

#print("using empty-delta method, first 5 rows:")
#print(weighted_z[:5])
# print("using empty-delta method, first 5 rows:")
# print(weighted_z[:5])

return weighted_z

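build_weighted_embedding_tensor above implements per-token weighting as interpolation away from the embedding of an empty prompt: weight 1.0 leaves a token untouched, weight 0.0 collapses it onto the empty-prompt feature. The core arithmetic with toy tensors (real values come from the CLIP transformer):

import torch

z = torch.randn(1, 77, 768)        # embedding of the actual prompt
empty_z = torch.randn(1, 77, 768)  # embedding of "" padded to the same length
per_token_weights = torch.ones(77)
per_token_weights[5] = 0.5         # de-emphasize token 5

batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
# the weight scales each token's *distance* from the empty embedding, not the raw vector
weighted_z = empty_z + (z - empty_z) * batch_weights_expanded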
@ -716,7 +759,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
z *= batch_weights_expanded
after_weighting_mean = z.mean()
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
mean_correction_factor = original_mean/after_weighting_mean
mean_correction_factor = original_mean / after_weighting_mean
z *= mean_correction_factor
return z

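The alternative branch above (the weight_delta_from_empty=False path) multiplies the embedding directly and then rescales to restore the pre-weighting mean, mirroring the automatic1111 fork. A sketch — the surrounding context sits outside this hunk, so the variable roles are inferred:

import torch

z = torch.randn(1, 77, 768)
batch_weights_expanded = torch.full_like(z, 0.8)

original_mean = z.mean()
z = z * batch_weights_expanded
after_weighting_mean = z.mean()
# rescale so overall activation statistics stay roughly unchanged
z = z * (original_mean / after_weighting_mean)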
@ -728,7 +771,7 @@ class FrozenCLIPTextEmbedder(nn.Module):

def __init__(
self,
version='ViT-L/14',
version="ViT-L/14",
device=choose_torch_device(),
max_length=77,
n_repeat=1,
@ -757,7 +800,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
z = self(text)
if z.ndim == 2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
z = repeat(z, "b 1 d -> b k d", k=self.n_repeat)
return z


@ -779,12 +822,12 @@ class FrozenClipImageEmbedder(nn.Module):
self.antialias = antialias

self.register_buffer(
'mean',
"mean",
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False,
)
self.register_buffer(
'std',
"std",
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False,
)
@ -794,7 +837,7 @@ class FrozenClipImageEmbedder(nn.Module):
x = kornia.geometry.resize(
x,
(224, 224),
interpolation='bicubic',
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
@ -808,8 +851,8 @@ class FrozenClipImageEmbedder(nn.Module):
return self.model.encode_image(self.preprocess(x))


if __name__ == '__main__':
from ldm.util import count_params
if __name__ == "__main__":
from ...util.util import count_params

model = FrozenCLIPEmbedder()
count_params(model, verbose=True)

Some files were not shown because too many files have changed in this diff