mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Compare commits: dev/ci/upd... ... bugfix/inv... (4 commits)

SHA1:
- b2a0e5890b
- 482fbd8884
- 0411de3651
- 9f0095ea42
.github/workflows/mkdocs-material.yml (vendored): 15 changes

@@ -2,7 +2,8 @@ name: mkdocs-material
 on:
   push:
     branches:
-      - 'refs/heads/v2.3'
+      - 'main'
+      - 'development'
 
 permissions:
   contents: write
@@ -11,10 +12,6 @@ jobs:
   mkdocs-material:
     if: github.event.pull_request.draft == false
     runs-on: ubuntu-latest
-    env:
-      REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
-      REPO_NAME: '${{ github.repository }}'
-      SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
     steps:
       - name: checkout sources
         uses: actions/checkout@v3
@@ -25,15 +22,11 @@ jobs:
         uses: actions/setup-python@v4
         with:
           python-version: '3.10'
-          cache: pip
-          cache-dependency-path: pyproject.toml
 
       - name: install requirements
-        env:
-          PIP_USE_PEP517: 1
         run: |
           python -m \
-            pip install ".[docs]"
+            pip install -r docs/requirements-mkdocs.txt
 
       - name: confirm buildability
         run: |
@@ -43,7 +36,7 @@ jobs:
             --verbose
 
       - name: deploy to gh-pages
-        if: ${{ github.ref == 'refs/heads/v2.3' }}
+        if: ${{ github.ref == 'refs/heads/main' }}
         run: |
           python -m \
             mkdocs gh-deploy \
@@ -33,8 +33,6 @@
 
 </div>
 
-_**Note: The UI is not fully functional on `main`. If you need a stable UI based on `main`, use the `pre-nodes` tag while we [migrate to a new backend](https://github.com/invoke-ai/InvokeAI/discussions/3246).**_
-
 InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.
 
 **Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
@@ -89,7 +89,7 @@ experimental versions later.
 sudo apt update
 sudo apt install -y software-properties-common
 sudo add-apt-repository -y ppa:deadsnakes/ppa
-sudo apt install -y python3.10 python3-pip python3.10-venv
+sudo apt install python3.10 python3-pip python3.10-venv
 sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
 ```
 
@@ -1,12 +1,14 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 import os
+from argparse import Namespace
 
-import invokeai.backend.util.logging as logger
-from typing import types
+from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
 
 from ..services.default_graphs import create_system_graphs
 
 from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
 from ...backend import Globals
 from ..services.model_manager_initializer import get_model_manager
 from ..services.restoration_services import RestorationServices
@@ -17,7 +19,6 @@ from ..services.invocation_services import InvocationServices
 from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
-from ..services.metadata import PngMetadataService
 from .events import FastAPIEventService
 
 
@@ -43,16 +44,15 @@ class ApiDependencies:
     invoker: Invoker = None
 
     @staticmethod
-    def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
+    def initialize(config, event_handler_id: int):
         Globals.try_patchmatch = config.patchmatch
         Globals.always_use_cpu = config.always_use_cpu
         Globals.internet_available = config.internet_available and check_internet()
         Globals.disable_xformers = not config.xformers
         Globals.ckpt_convert = config.ckpt_convert
 
-        # TO DO: Use the config to select the logger rather than use the default
-        # invokeai logging module
-        logger.info(f"Internet connectivity is {Globals.internet_available}")
+        # TODO: Use a logger
+        print(f">> Internet connectivity is {Globals.internet_available}")
 
         events = FastAPIEventService(event_handler_id)
 
@@ -70,9 +70,8 @@ class ApiDependencies:
         db_location = os.path.join(output_folder, "invokeai.db")
 
         services = InvocationServices(
-            model_manager=get_model_manager(config,logger),
+            model_manager=get_model_manager(config),
             events=events,
-            logger=logger,
             latents=latents,
             images=images,
             metadata=metadata,
@@ -84,7 +83,7 @@ class ApiDependencies:
                 filename=db_location, table_name="graph_executions"
             ),
             processor=DefaultInvocationProcessor(),
-            restoration=RestorationServices(config,logger),
+            restoration=RestorationServices(config),
         )
 
         create_system_graphs(services.graph_library)
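A note on the logger hunks in this file: one side of the comparison threads a module-style logger through `ApiDependencies.initialize` and the service constructors, while the other falls back to bare `print(">> ...")` calls. A minimal sketch of the injection pattern using only the stdlib; the real `invokeai.backend.util.logging` wrapper is assumed to expose the same `info`/`warning`/`error` surface:

```python
import logging
import types

logging.basicConfig(level=logging.INFO)

def initialize(config: dict, event_handler_id: int, logger: types.ModuleType = logging) -> None:
    # Anything exposing info()/warning()/error() can be injected here, which is
    # what makes the logger-threading side of this diff easy to test or swap out.
    logger.info("Internet connectivity is %s", config.get("internet_available", False))

initialize({"internet_available": True}, event_handler_id=1)
```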
@@ -32,9 +32,3 @@ class ProgressImage(BaseModel):
     width: int = Field(description="The effective width of the image in pixels")
     height: int = Field(description="The effective height of the image in pixels")
     dataURL: str = Field(description="The image data as a b64 data URL")
-
-
-class SavedImage(BaseModel):
-    image_name: str = Field(description="The name of the saved image")
-    thumbnail_name: str = Field(description="The name of the saved thumbnail")
-    created: int = Field(description="The created timestamp of the saved image")
@@ -6,14 +6,12 @@ import os
 from typing import Any
 import uuid
 
-from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
+from fastapi import HTTPException, Path, Query, Request, UploadFile
 from fastapi.responses import FileResponse, Response
 from fastapi.routing import APIRouter
 from PIL import Image
-from invokeai.app.api.models.images import (
-    ImageResponse,
-    ImageResponseMetadata,
-)
+from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
+from invokeai.app.services.metadata import InvokeAIMetadata
 from invokeai.app.services.item_storage import PaginatedResults
 
 from ...services.image_storage import ImageType
@@ -26,8 +24,8 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
 async def get_image(
     image_type: ImageType = Path(description="The type of image to get"),
     image_name: str = Path(description="The name of the image to get"),
-) -> FileResponse:
-    """Gets an image"""
+) -> FileResponse | Response:
+    """Gets a result"""
 
     path = ApiDependencies.invoker.services.images.get_path(
         image_type=image_type, image_name=image_name
@@ -39,29 +37,17 @@ async def get_image(
         raise HTTPException(status_code=404)
 
 
-@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
-async def delete_image(
-    image_type: ImageType = Path(description="The type of image to delete"),
-    image_name: str = Path(description="The name of the image to delete"),
-) -> None:
-    """Deletes an image and its thumbnail"""
-
-    ApiDependencies.invoker.services.images.delete(
-        image_type=image_type, image_name=image_name
-    )
-
-
 @images_router.get(
-    "/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
+    "/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail"
 )
 async def get_thumbnail(
-    thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
-    thumbnail_name: str = Path(description="The name of the thumbnail to get"),
+    image_type: ImageType = Path(description="The type of image to get"),
+    image_name: str = Path(description="The name of the image to get"),
 ) -> FileResponse | Response:
     """Gets a thumbnail"""
 
     path = ApiDependencies.invoker.services.images.get_path(
-        image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
+        image_type=image_type, image_name=image_name, is_thumbnail=True
     )
 
     if ApiDependencies.invoker.services.images.validate_path(path):
@@ -98,27 +84,19 @@ async def upload_image(
 
     filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
 
-    saved_image = ApiDependencies.invoker.services.images.save(
+    (image_path, thumbnail_path, ctime) = ApiDependencies.invoker.services.images.save(
         ImageType.UPLOAD, filename, img
     )
 
     invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
 
-    image_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name
-    )
-
-    thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name, True
-    )
-
     res = ImageResponse(
         image_type=ImageType.UPLOAD,
-        image_name=saved_image.image_name,
-        image_url=image_url,
-        thumbnail_url=thumbnail_url,
+        image_name=filename,
+        image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
+        thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
         metadata=ImageResponseMetadata(
-            created=saved_image.created,
+            created=ctime,
             width=img.width,
             height=img.height,
             invokeai=invokeai_metadata,
@@ -126,7 +104,9 @@ async def upload_image(
     )
 
     response.status_code = 201
-    response.headers["Location"] = image_url
+    response.headers["Location"] = request.url_for(
+        "get_image", image_type=ImageType.UPLOAD.value, image_name=filename
+    )
 
     return res
 
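A note on the `Location` header change: `request.url_for` resolves a route by name instead of hand-assembling a path, while the `image_url`/`thumbnail_url` f-strings on the other side of the hunk bake the path in. A self-contained sketch with illustrative route and file names:

```python
from fastapi import FastAPI, Request

app = FastAPI()

@app.get("/v1/images/{image_type}/{image_name}")
async def get_image(image_type: str, image_name: str):
    return {"image_type": image_type, "image_name": image_name}

@app.post("/v1/images/uploads")
async def upload_image(request: Request):
    # Resolved from the route named after its handler function, so the link
    # keeps working if the path template changes; a hand-built f-string would not.
    location = request.url_for("get_image", image_type="uploads", image_name="x.png")
    return {"location": str(location)}
```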
@@ -8,6 +8,10 @@ from fastapi.routing import APIRouter, HTTPException
 from pydantic import BaseModel, Field, parse_obj_as
 from pathlib import Path
 from ..dependencies import ApiDependencies
+from invokeai.backend.globals import Globals, global_converted_ckpts_dir
+from invokeai.backend.args import Args
+
+
 
 models_router = APIRouter(prefix="/v1/models", tags=["models"])
 
@@ -108,20 +112,19 @@ async def update_model(
 async def delete_model(model_name: str) -> None:
     """Delete Model"""
     model_names = ApiDependencies.invoker.services.model_manager.model_names()
-    logger = ApiDependencies.invoker.services.logger
     model_exists = model_name in model_names
 
     # check if model exists
-    logger.info(f"Checking for model {model_name}...")
+    print(f">> Checking for model {model_name}...")
 
     if model_exists:
-        logger.info(f"Deleting Model: {model_name}")
+        print(f">> Deleting Model: {model_name}")
         ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
-        logger.info(f"Model Deleted: {model_name}")
+        print(f">> Model Deleted: {model_name}")
         raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
 
     else:
-        logger.error(f"Model not found")
+        print(f">> Model not found")
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
 
 
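One side of the `delete_model` hunk routes status messages through the services logger, the other uses bare `print(">> ...")`. A minimal stdlib equivalent of the logger side, with the `>>` prefix folded into the format string (the model name is illustrative):

```python
import logging

logging.basicConfig(level=logging.INFO, format=">> %(levelname)s: %(message)s")
logger = logging.getLogger("invokeai")

model_name = "stable-diffusion-1.5"  # illustrative
logger.info("Checking for model %s...", model_name)
logger.error("Model not found")
```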
@@ -2,7 +2,8 @@
 
 from typing import Annotated, List, Optional, Union
 
-from fastapi import Body, HTTPException, Path, Query, Response
+from fastapi import Body, Path, Query
+from fastapi.responses import Response
 from fastapi.routing import APIRouter
 from pydantic.fields import Field
 
@@ -75,7 +76,7 @@ async def get_session(
     """Gets a session"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
     else:
         return session
 
@@ -98,7 +99,7 @@ async def add_node(
     """Adds a node to the graph"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
 
     try:
         session.add_node(node)
@@ -107,9 +108,9 @@ async def add_node(
         ) # TODO: can this be done automatically, or add node through an API?
         return session.id
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
 
 
 @session_router.put(
@@ -131,7 +132,7 @@ async def update_node(
     """Updates a node in the graph and removes all linked edges"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
 
     try:
         session.update_node(node_path, node)
@@ -140,9 +141,9 @@ async def update_node(
         ) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
 
 
 @session_router.delete(
@@ -161,7 +162,7 @@ async def delete_node(
     """Deletes a node in the graph and removes all linked edges"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
 
     try:
         session.delete_node(node_path)
@@ -170,9 +171,9 @@ async def delete_node(
         ) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
 
 
 @session_router.post(
@@ -191,7 +192,7 @@ async def add_edge(
     """Adds an edge to the graph"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
 
     try:
         session.add_edge(edge)
@@ -200,9 +201,9 @@ async def add_edge(
         ) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
 
 
 # TODO: the edge being in the path here is really ugly, find a better solution
@@ -225,7 +226,7 @@ async def delete_edge(
     """Deletes an edge from the graph"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
 
     try:
         edge = Edge(
@@ -238,9 +239,9 @@ async def delete_edge(
         ) # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
 
 
 @session_router.put(
@@ -258,14 +259,14 @@ async def invoke_session(
     all: bool = Query(
         default=False, description="Whether or not to invoke all remaining invocations"
     ),
-) -> Response:
+) -> None:
     """Invokes a session"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
 
     if session.is_complete():
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
 
     ApiDependencies.invoker.invoke(session, invoke_all=all)
     return Response(status_code=202)
@@ -280,7 +281,7 @@ async def invoke_session(
 )
 async def cancel_session_invoke(
     session_id: str = Path(description="The id of the session to cancel"),
-) -> Response:
+) -> None:
     """Invokes a session"""
     ApiDependencies.invoker.cancel(session_id)
     return Response(status_code=202)
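Every hunk in this router trades `return Response(status_code=...)` against `raise HTTPException(status_code=...)`. A minimal sketch of why the raising form is usually preferred in FastAPI (the route and the in-memory store are illustrative):

```python
from fastapi import FastAPI, HTTPException

app = FastAPI()

sessions: dict[str, dict] = {}  # illustrative in-memory store

@app.get("/v1/sessions/{session_id}")
async def get_session(session_id: str) -> dict:
    session = sessions.get(session_id)
    if session is None:
        # Raising short-circuits through FastAPI's exception handlers and keeps
        # the declared return type honest; returning a bare Response bypasses
        # response-model validation and any custom 404 handler.
        raise HTTPException(status_code=404)
    return session
```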
@@ -3,7 +3,6 @@ import asyncio
 from inspect import signature
 
 import uvicorn
-import invokeai.backend.util.logging as logger
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@@ -17,6 +16,7 @@ from ..backend import Args
 from .api.dependencies import ApiDependencies
 from .api.routers import images, sessions, models
 from .api.sockets import SocketIO
+from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 
 # Create the app
@@ -56,7 +56,7 @@ async def startup_event():
     config.parse_args()
 
     ApiDependencies.initialize(
-        config=config, event_handler_id=event_handler_id, logger=logger
+        config=config, event_handler_id=event_handler_id
     )
 
 
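The `from .invocations import *` line pulls every invocation module in so that all `BaseInvocation` subclasses exist before the OpenAPI schema is generated. A small sketch of that register-by-import pattern; the class names here are illustrative, not InvokeAI's:

```python
from pydantic import BaseModel

class BaseNode(BaseModel):
    @classmethod
    def all_nodes(cls) -> list[type]:
        # Only subclasses whose defining module has actually been imported
        # appear here, which is why the app imports the whole package eagerly.
        return cls.__subclasses__()

class AddNode(BaseNode):
    a: int = 0
    b: int = 0

print([n.__name__ for n in BaseNode.all_nodes()])  # ['AddNode']
```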
@@ -2,15 +2,14 @@
 
 from abc import ABC, abstractmethod
 import argparse
-from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
+from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
 from pydantic import BaseModel, Field
 import networkx as nx
 import matplotlib.pyplot as plt
 
-import invokeai.backend.util.logging as logger
 from ..invocations.baseinvocation import BaseInvocation
 from ..invocations.image import ImageField
-from ..services.graph import GraphExecutionState, LibraryGraph, Edge
+from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
 from ..services.invoker import Invoker
 
 
@@ -230,7 +229,7 @@ class HistoryCommand(BaseCommand):
         for i in range(min(self.count, len(history))):
             entry_id = history[-1 - i]
             entry = context.get_session().graph.get_node(entry_id)
-            logger.info(f"{entry_id}: {get_invocation_command(entry)}")
+            print(f"{entry_id}: {get_invocation_command(entry)}")
 
 
 class SetDefaultCommand(BaseCommand):
@@ -10,7 +10,6 @@ import shlex
 from pathlib import Path
 from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
 
-import invokeai.backend.util.logging as logger
 from ...backend import ModelManager, Globals
 from ..invocations.baseinvocation import BaseInvocation
 from .commands import BaseCommand
@@ -161,8 +160,8 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
         pass
     except OSError: # file likely corrupted
         newname = f"{histfile}.old"
-        logger.error(
-            f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
+        print(
+            f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
         )
         histfile.replace(Path(newname))
     atexit.register(readline.write_history_file, histfile)
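The hunk above sits inside the readline history setup. A self-contained sketch of the load-or-quarantine pattern it implements (the history path is illustrative):

```python
import atexit
import readline
from pathlib import Path

histfile = Path("~/.example_history").expanduser()  # illustrative path

try:
    readline.read_history_file(histfile)
except FileNotFoundError:
    pass  # first run: no history yet
except OSError:  # file likely corrupted
    newname = f"{histfile}.old"
    print(f"## History file {histfile} couldn't be loaded; renaming it to {newname}")
    histfile.replace(Path(newname))  # quarantine the bad file instead of crashing

# flush the in-memory history back to disk at interpreter exit
atexit.register(readline.write_history_file, histfile)
```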
@@ -13,20 +13,21 @@ from typing import (
 from pydantic import BaseModel
 from pydantic.fields import Field
 
 
-import invokeai.backend.util.logging as logger
 from invokeai.app.services.metadata import PngMetadataService
 
 from .services.default_graphs import create_system_graphs
 
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
 from ..backend import Args
-from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
+from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
 from .cli.completer import set_autocompleter
+from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
 from .services.model_manager_initializer import get_model_manager
 from .services.restoration_services import RestorationServices
-from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
+from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
 from .services.default_graphs import default_text_to_image_graph_id
 from .services.image_storage import DiskImageStorage
 from .services.invocation_queue import MemoryInvocationQueue
@@ -181,7 +182,7 @@ def invoke_all(context: CliContext):
     # Print any errors
     if context.session.has_error():
         for n in context.session.errors:
-            context.invoker.services.logger.error(
+            print(
                 f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
             )
 
@@ -191,13 +192,13 @@ def invoke_all(context: CliContext):
 def invoke_cli():
     config = Args()
     config.parse_args()
-    model_manager = get_model_manager(config,logger=logger)
+    model_manager = get_model_manager(config)
 
     # This initializes the autocompleter and returns it.
     # Currently nothing is done with the returned Completer
     # object, but the object can be used to change autocompletion
     # behavior on the fly, if desired.
-    set_autocompleter(model_manager)
+    completer = set_autocompleter(model_manager)
 
     events = EventServiceBase()
 
@@ -224,8 +225,7 @@ def invoke_cli():
             filename=db_location, table_name="graph_executions"
         ),
         processor=DefaultInvocationProcessor(),
-        restoration=RestorationServices(config,logger=logger),
-        logger=logger,
+        restoration=RestorationServices(config),
     )
 
     system_graphs = create_system_graphs(services.graph_library)
@@ -365,12 +365,12 @@ def invoke_cli():
             invoke_all(context)
 
         except InvalidArgs:
-            invoker.services.logger.warning('Invalid command, use "help" to list commands')
+            print('Invalid command, use "help" to list commands')
             continue
 
         except SessionError:
             # Start a new session
-            invoker.services.logger.warning("Session error: creating a new session")
+            print("Session error: creating a new session")
             context.reset()
 
         except ExitCli:
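The tail of `invoke_cli` is a classic REPL exception ladder. A runnable miniature of the same shape, with stub exception classes standing in for the app's `InvalidArgs`, `SessionError`, and `ExitCli`:

```python
class InvalidArgs(Exception): ...
class SessionError(Exception): ...
class ExitCli(Exception): ...

def dispatch(cmd: str) -> None:
    # stand-in for the real command parser
    if cmd == "exit":
        raise ExitCli
    if cmd not in {"help", "history"}:
        raise InvalidArgs

while True:
    try:
        dispatch(input("invoke> "))
    except InvalidArgs:
        print('Invalid command, use "help" to list commands')
        continue
    except SessionError:
        print("Session error: creating a new session")
    except ExitCli:
        break
```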
@@ -1,245 +0,0 @@
-from typing import Literal, Optional, Union
-from pydantic import BaseModel, Field
-
-from invokeai.app.invocations.util.choose_model import choose_model
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-
-from ...backend.util.devices import choose_torch_device, torch_dtype
-from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
-from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
-
-from compel import Compel
-from compel.prompt_parser import (
-    Blend,
-    CrossAttentionControlSubstitute,
-    FlattenedPrompt,
-    Fragment,
-)
-
-from invokeai.backend.globals import Globals
-
-
-class ConditioningField(BaseModel):
-    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
-    class Config:
-        schema_extra = {"required": ["conditioning_name"]}
-
-
-class CompelOutput(BaseInvocationOutput):
-    """Compel parser output"""
-
-    #fmt: off
-    type: Literal["compel_output"] = "compel_output"
-
-    conditioning: ConditioningField = Field(default=None, description="Conditioning")
-    #fmt: on
-
-
-class CompelInvocation(BaseInvocation):
-    """Parse prompt using compel package to conditioning."""
-
-    type: Literal["compel"] = "compel"
-
-    prompt: str = Field(default="", description="Prompt")
-    model: str = Field(default="", description="Model to use")
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "title": "Prompt (Compel)",
-                "tags": ["prompt", "compel"],
-                "type_hints": {
-                    "model": "model"
-                }
-            },
-        }
-
-    def invoke(self, context: InvocationContext) -> CompelOutput:
-
-        # TODO: load without model
-        model = choose_model(context.services.model_manager, self.model)
-        pipeline = model["model"]
-        tokenizer = pipeline.tokenizer
-        text_encoder = pipeline.text_encoder
-
-        # TODO: global? input?
-        #use_full_precision = precision == "float32" or precision == "autocast"
-        #use_full_precision = False
-
-        # TODO: redo TI when separate model loding implemented
-        #textual_inversion_manager = TextualInversionManager(
-        #    tokenizer=tokenizer,
-        #    text_encoder=text_encoder,
-        #    full_precision=use_full_precision,
-        #)
-
-        def load_huggingface_concepts(concepts: list[str]):
-            pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
-
-        # apply the concepts library to the prompt
-        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
-            self.prompt,
-            lambda concepts: load_huggingface_concepts(concepts),
-            pipeline.textual_inversion_manager.get_all_trigger_strings(),
-        )
-
-        # lazy-load any deferred textual inversions.
-        # this might take a couple of seconds the first time a textual inversion is used.
-        pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-            prompt_str
-        )
-
-        compel = Compel(
-            tokenizer=tokenizer,
-            text_encoder=text_encoder,
-            textual_inversion_manager=pipeline.textual_inversion_manager,
-            dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=True, # TODO:
-        )
-
-        # TODO: support legacy blend?
-
-        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
-
-        if getattr(Globals, "log_tokenization", False):
-            log_tokenization_for_prompt_object(prompt, tokenizer)
-
-        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
-
-        # TODO: long prompt support
-        #if not self.truncate_long_prompts:
-        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
-
-        ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
-            cross_attention_control_args=options.get("cross_attention_control", None),
-        )
-
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-
-        # TODO: hacky but works ;D maybe rename latents somehow?
-        context.services.latents.set(conditioning_name, (c, ec))
-
-        return CompelOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
-
-
-def get_max_token_count(
-    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
-) -> int:
-    if type(prompt) is Blend:
-        blend: Blend = prompt
-        return max(
-            [
-                get_max_token_count(tokenizer, c, truncate_if_too_long)
-                for c in blend.prompts
-            ]
-        )
-    else:
-        return len(
-            get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
-        )
-
-
-def get_tokens_for_prompt_object(
-    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
-) -> [str]:
-    if type(parsed_prompt) is Blend:
-        raise ValueError(
-            "Blend is not supported here - you need to get tokens for each of its .children"
-        )
-
-    text_fragments = [
-        x.text
-        if type(x) is Fragment
-        else (
-            " ".join([f.text for f in x.original])
-            if type(x) is CrossAttentionControlSubstitute
-            else str(x)
-        )
-        for x in parsed_prompt.children
-    ]
-    text = " ".join(text_fragments)
-    tokens = tokenizer.tokenize(text)
-    if truncate_if_too_long:
-        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
-        tokens = tokens[0:max_tokens_length]
-    return tokens
-
-
-def log_tokenization_for_prompt_object(
-    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
-):
-    display_label_prefix = display_label_prefix or ""
-    if type(p) is Blend:
-        blend: Blend = p
-        for i, c in enumerate(blend.prompts):
-            log_tokenization_for_prompt_object(
-                c,
-                tokenizer,
-                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
-            )
-    elif type(p) is FlattenedPrompt:
-        flattened_prompt: FlattenedPrompt = p
-        if flattened_prompt.wants_cross_attention_control:
-            original_fragments = []
-            edited_fragments = []
-            for f in flattened_prompt.children:
-                if type(f) is CrossAttentionControlSubstitute:
-                    original_fragments += f.original
-                    edited_fragments += f.edited
-                else:
-                    original_fragments.append(f)
-                    edited_fragments.append(f)
-
-            original_text = " ".join([x.text for x in original_fragments])
-            log_tokenization_for_text(
-                original_text,
-                tokenizer,
-                display_label=f"{display_label_prefix}(.swap originals)",
-            )
-            edited_text = " ".join([x.text for x in edited_fragments])
-            log_tokenization_for_text(
-                edited_text,
-                tokenizer,
-                display_label=f"{display_label_prefix}(.swap replacements)",
-            )
-        else:
-            text = " ".join([x.text for x in flattened_prompt.children])
-            log_tokenization_for_text(
-                text, tokenizer, display_label=display_label_prefix
-            )
-
-
-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
-    """shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
-    # but for readability it has been replaced with ' '
-    """
-    tokens = tokenizer.tokenize(text)
-    tokenized = ""
-    discarded = ""
-    usedTokens = 0
-    totalTokens = len(tokens)
-
-    for i in range(0, totalTokens):
-        token = tokens[i].replace("</w>", " ")
-        # alternate color
-        s = (usedTokens % 6) + 1
-        if truncate_if_too_long and i >= tokenizer.model_max_length:
-            discarded = discarded + f"\x1b[0;3{s};40m{token}"
-        else:
-            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
-            usedTokens += 1
-
-    if usedTokens > 0:
-        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
-        print(f"{tokenized}\x1b[0m")
-
-    if discarded != "":
-        print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
-        print(f"{discarded}\x1b[0m")
@@ -46,8 +46,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
     seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
-    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
-    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
+    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
+    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@@ -150,9 +150,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         )
         mask = None
 
-        if self.fit:
-            image = image.resize((self.width, self.height))
-
         # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)
 
@@ -250,8 +247,8 @@ class InpaintInvocation(ImageToImageInvocation):
 
         outputs = Inpaint(model).generate(
             prompt=self.prompt,
-            init_image=image,
-            mask_image=mask,
+            init_img=image,
+            init_mask=mask,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
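The width/height hunks swap `multiple_of=8` for `multiple_of=64`: 8 px matches the VAE's latent stride, while 64 px is the older, stricter alignment. A minimal pydantic sketch of how that constraint behaves:

```python
from pydantic import BaseModel, Field, ValidationError

class ImageDims(BaseModel):
    # multiple_of maps to JSON-schema multipleOf and is checked at construction
    width: int = Field(default=512, multiple_of=8, gt=0)
    height: int = Field(default=512, multiple_of=8, gt=0)

print(ImageDims(width=520).width)  # 520 is a multiple of 8: accepted

try:
    ImageDims(width=513)
except ValidationError as err:
    print(err)  # rejected: 513 is not a multiple of 8
```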
@ -13,13 +13,13 @@ from ...backend.model_management.model_manager import ModelManager
|
|||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput, build_image_output
|
from .image import ImageField, ImageOutput, build_image_output
|
||||||
from .compel import ConditioningField
|
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
import diffusers
|
import diffusers
|
||||||
@ -113,8 +113,8 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@ -138,16 +138,19 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class TextToLatentsInvocation(BaseInvocation):
|
class TextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from a prompt."""
|
||||||
|
|
||||||
type: Literal["t2l"] = "t2l"
|
type: Literal["t2l"] = "t2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
|
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||||
# fmt: off
|
# fmt: off
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||||
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
@ -203,10 +206,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
||||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
uc,
|
uc,
|
||||||
c,
|
c,
|
||||||
@ -233,7 +234,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(context, model)
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
@@ -362,74 +363,9 @@ class LatentsToImageInvocation(BaseInvocation):
             session_id=context.graph_execution_state_id, node=self
         )

-        torch.cuda.empty_cache()
-
         context.services.images.save(image_type, image_name, image, metadata)
         return build_image_output(
-            image_type=image_type, image_name=image_name, image=image
+            image_type=image_type,
+            image_name=image_name,
+            image=image
         )
-
-
-LATENTS_INTERPOLATION_MODE = Literal[
-    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
-]
-
-
-class ResizeLatentsInvocation(BaseInvocation):
-    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
-
-    type: Literal["lresize"] = "lresize"
-
-    # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
-    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
-
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.services.latents.get(self.latents.latents_name)
-
-        resized_latents = torch.nn.functional.interpolate(
-            latents,
-            size=(self.height // 8, self.width // 8),
-            mode=self.mode,
-            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
-
-
-class ScaleLatentsInvocation(BaseInvocation):
-    """Scales latents by a given factor."""
-
-    type: Literal["lscale"] = "lscale"
-
-    # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to scale")
-    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
-
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.services.latents.get(self.latents.latents_name)
-
-        # resizing
-        resized_latents = torch.nn.functional.interpolate(
-            latents,
-            scale_factor=self.scale_factor,
-            mode=self.mode,
-            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
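The `ResizeLatentsInvocation`/`ScaleLatentsInvocation` nodes removed above are thin wrappers around `torch.nn.functional.interpolate` applied to the 4-D latents tensor. A standalone sketch of the two calls they make (the tensor is a random stand-in for real latents):

```python
import torch
import torch.nn.functional as F

# Stand-in for Stable Diffusion latents: (batch, 4 channels, height/8, width/8).
latents = torch.randn(1, 4, 64, 64)

# lresize: target pixel dimensions are floor-divided by 8, as in the node above.
resized = F.interpolate(latents, size=(768 // 8, 512 // 8), mode="bilinear", antialias=False)

# lscale: scale by a factor instead of a target size.
scaled = F.interpolate(latents, scale_factor=2.0, mode="bilinear", antialias=False)

print(resized.shape)  # torch.Size([1, 4, 96, 64])
print(scaled.shape)   # torch.Size([1, 4, 128, 128])
```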
@@ -3,11 +3,12 @@ from invokeai.backend.model_management.model_manager import ModelManager

 def choose_model(model_manager: ModelManager, model_name: str):
     """Returns the default model if the `model_name` is not a valid model, else returns the selected model."""
-    logger = model_manager.logger
     if model_manager.valid_model(model_name):
         model = model_manager.get_model(model_name)
     else:
         model = model_manager.get_model()
-        logger.warning(f"'{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
+        print(
+            f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
+        )

     return model

@@ -1,5 +1,4 @@
 from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
-from ..invocations.compel import CompelInvocation
 from ..invocations.params import ParamIntInvocation
 from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
 from .item_storage import ItemStorageABC
@@ -17,32 +16,26 @@ def create_text_to_image() -> LibraryGraph:
             nodes={
                 'width': ParamIntInvocation(id='width', a=512),
                 'height': ParamIntInvocation(id='height', a=512),
-                'seed': ParamIntInvocation(id='seed', a=-1),
                 '3': NoiseInvocation(id='3'),
-                '4': CompelInvocation(id='4'),
-                '5': CompelInvocation(id='5'),
-                '6': TextToLatentsInvocation(id='6'),
-                '7': LatentsToImageInvocation(id='7'),
+                '4': TextToLatentsInvocation(id='4'),
+                '5': LatentsToImageInvocation(id='5')
             },
             edges=[
                 Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
                 Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
-                Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
-                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
-                Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
-                Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
-                Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
+                Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
+                Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
+                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
+                Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
             ]
         ),
         exposed_inputs=[
-            ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
-            ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
+            ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
             ExposedNodeInput(node_path='width', field='a', alias='width'),
-            ExposedNodeInput(node_path='height', field='a', alias='height'),
-            ExposedNodeInput(node_path='seed', field='a', alias='seed'),
+            ExposedNodeInput(node_path='height', field='a', alias='height')
         ],
         exposed_outputs=[
-            ExposedNodeOutput(node_path='7', field='image', alias='image')
+            ExposedNodeOutput(node_path='5', field='image', alias='image')
         ])
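Both sides of this hunk wire the graph with the same `Edge(source=EdgeConnection(...), destination=EdgeConnection(...))` pattern: each edge routes one output field of one node into one input field of another. A minimal sketch of that shape using dataclass stand-ins (illustrative only; the real classes come from `.graph`):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class EdgeConnection:
    node_id: str
    field: str

@dataclass(frozen=True)
class Edge:
    source: EdgeConnection
    destination: EdgeConnection

# The noise -> text-to-latents -> image chain from the right-hand side above.
edges = [
    Edge(EdgeConnection("3", "noise"), EdgeConnection("4", "noise")),
    Edge(EdgeConnection("4", "latents"), EdgeConnection("5", "latents")),
]
print(edges[0].destination.field)  # noise
```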
@@ -5,16 +5,11 @@ from glob import glob
 from abc import ABC, abstractmethod
 from pathlib import Path
 from queue import Queue
-from typing import Dict, List
+from typing import Dict, List, Tuple

 from PIL.Image import Image
 import PIL.Image as PILImage
-from send2trash import send2trash
-from invokeai.app.api.models.images import (
-    ImageResponse,
-    ImageResponseMetadata,
-    SavedImage,
-)
+from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
 from invokeai.app.models.image import ImageType
 from invokeai.app.services.metadata import (
     InvokeAIMetadata,

@@ -46,15 +41,7 @@ class ImageStorageBase(ABC):
     def get_path(
         self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
     ) -> str:
-        """Gets the internal path to an image or its thumbnail."""
-        pass
-
-    # TODO: make this a bit more flexible for e.g. cloud storage
-    @abstractmethod
-    def get_uri(
-        self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
-    ) -> str:
-        """Gets the external URI to an image or its thumbnail."""
+        """Gets the path to an image or its thumbnail."""
         pass

     # TODO: make this a bit more flexible for e.g. cloud storage
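The `# TODO` kept on both sides points at cloud storage as the motivation for `get_uri`. A purely hypothetical sketch of what an object-store-backed implementation of that signature might return; the bucket name and URL layout here are invented, not part of the codebase:

```python
# Hypothetical only: a cloud-flavored get_uri() matching the abstract
# signature above. Bucket name and URL scheme are assumptions.
def get_uri(image_type: str, image_name: str, is_thumbnail: bool = False,
            bucket: str = "invokeai-images") -> str:
    folder = f"{image_type}/thumbnails" if is_thumbnail else image_type
    return f"https://{bucket}.s3.amazonaws.com/{folder}/{image_name}"

print(get_uri("results", "abc123.png"))
print(get_uri("results", "abc123.png", is_thumbnail=True))
```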
@@ -70,8 +57,8 @@ class ImageStorageBase(ABC):
         image_name: str,
         image: Image,
         metadata: InvokeAIMetadata | None = None,
-    ) -> SavedImage:
-        """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
+    ) -> Tuple[str, str, int]:
+        """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image path, thumbnail path, and created timestamp."""
         pass

     @abstractmethod

@@ -139,8 +126,8 @@ class DiskImageStorage(ImageStorageBase):
                 image_type=image_type.value,
                 image_name=filename,
                 # TODO: DiskImageStorage should not be building URLs...?
-                image_url=self.get_uri(image_type, filename),
-                thumbnail_url=self.get_uri(image_type, filename, True),
+                image_url=f"api/v1/images/{image_type.value}/{filename}",
+                thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
                 # TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
                 metadata=ImageResponseMetadata(
                     created=int(os.path.getctime(path)),

@@ -187,23 +174,7 @@ class DiskImageStorage(ImageStorageBase):
         else:
             path = os.path.join(self.__output_folder, image_type, basename)

-        abspath = os.path.abspath(path)
-
-        return abspath
-
-    def get_uri(
-        self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
-    ) -> str:
-        # strip out any relative path shenanigans
-        basename = os.path.basename(image_name)
-
-        if is_thumbnail:
-            thumbnail_basename = get_thumbnail_name(basename)
-            uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
-        else:
-            uri = f"api/v1/images/{image_type.value}/{basename}"
-
-        return uri
+        return path

     def validate_path(self, path: str) -> bool:
         try:

@@ -218,7 +189,7 @@ class DiskImageStorage(ImageStorageBase):
         image_name: str,
         image: Image,
         metadata: InvokeAIMetadata | None = None,
-    ) -> SavedImage:
+    ) -> Tuple[str, str, int]:
         image_path = self.get_path(image_type, image_name)

         # TODO: Reading the image and then saving it strips the metadata...

@@ -226,7 +197,7 @@ class DiskImageStorage(ImageStorageBase):
             pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
             image.save(image_path, "PNG", pnginfo=pnginfo)
         else:
             image.save(image_path)  # this saved image has an empty info

         thumbnail_name = get_thumbnail_name(image_name)
         thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)

@@ -236,30 +207,24 @@ class DiskImageStorage(ImageStorageBase):
         self.__set_cache(image_path, image)
         self.__set_cache(thumbnail_path, thumbnail_image)

-        return SavedImage(
-            image_name=image_name,
-            thumbnail_name=thumbnail_name,
-            created=int(os.path.getctime(image_path)),
-        )
+        return (image_path, thumbnail_path, int(os.path.getctime(image_path)))

     def delete(self, image_type: ImageType, image_name: str) -> None:
-        basename = os.path.basename(image_name)
-        image_path = self.get_path(image_type, basename)
+        image_path = self.get_path(image_type, image_name)
+        thumbnail_path = self.get_path(image_type, image_name, True)

         if os.path.exists(image_path):
-            send2trash(image_path)
+            os.remove(image_path)

         if image_path in self.__cache:
             del self.__cache[image_path]

-        thumbnail_name = get_thumbnail_name(image_name)
-        thumbnail_path = self.get_path(image_type, thumbnail_name, True)
-
         if os.path.exists(thumbnail_path):
-            send2trash(thumbnail_path)
+            os.remove(thumbnail_path)

         if thumbnail_path in self.__cache:
             del self.__cache[thumbnail_path]

-    def __get_cache(self, image_name: str) -> Image | None:
+    def __get_cache(self, image_name: str) -> Image:
         return None if image_name not in self.__cache else self.__cache[image_name]

     def __set_cache(self, image_name: str, image: Image):
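One behavioral difference in `delete()` above: `send2trash` moves files to the platform trash so a deletion is recoverable, while `os.remove` unlinks immediately. A standalone sketch of the two calls against a throwaway temp file (`pip install send2trash` for the first):

```python
import os
import tempfile

from send2trash import send2trash

fd, path = tempfile.mkstemp(suffix=".png")
os.close(fd)

send2trash(path)   # recoverable from the trash/recycle bin
# os.remove(path)  # the alternative shown above: removed immediately
```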
@@ -1,6 +1,4 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from typing import types
 from invokeai.app.services.metadata import MetadataServiceBase
 from invokeai.backend import ModelManager

@@ -31,7 +29,6 @@ class InvocationServices:
         self,
         model_manager: ModelManager,
         events: EventServiceBase,
-        logger: types.ModuleType,
         latents: LatentsStorageBase,
         images: ImageStorageBase,
         metadata: MetadataServiceBase,

@@ -43,7 +40,6 @@ class InvocationServices:
     ):
         self.model_manager = model_manager
         self.events = events
-        self.logger = logger
         self.latents = latents
         self.images = images
         self.metadata = metadata

@@ -71,12 +71,18 @@ class Invoker:
         for service in vars(self.services):
             self.__start_service(getattr(self.services, service))

+        for service in vars(self.services):
+            self.__start_service(getattr(self.services, service))
+
     def stop(self) -> None:
         """Stops the invoker. A new invoker will have to be created to execute further."""
         # First stop all services
         for service in vars(self.services):
             self.__stop_service(getattr(self.services, service))

+        for service in vars(self.services):
+            self.__stop_service(getattr(self.services, service))
+
         self.services.queue.put(None)

@@ -5,7 +5,6 @@ from argparse import Namespace
 from invokeai.backend import Args
 from omegaconf import OmegaConf
 from pathlib import Path
-from typing import types

 import invokeai.version
 from ...backend import ModelManager

@@ -13,16 +12,16 @@ from ...backend.util import choose_precision, choose_torch_device
 from ...backend import Globals

 # TODO: Replace with an abstract class base ModelManagerBase
-def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
+def get_model_manager(config: Args) -> ModelManager:
     if not config.conf:
         config_file = os.path.join(Globals.root, "configs", "models.yaml")
         if not os.path.exists(config_file):
             report_model_error(
-                config, FileNotFoundError(f"The file {config_file} could not be found."), logger
+                config, FileNotFoundError(f"The file {config_file} could not be found.")
             )

-    logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
-    logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
+    print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
+    print(f'>> InvokeAI runtime directory is "{Globals.root}"')

     # these two lines prevent a horrible warning message from appearing
     # when the frozen CLIP tokenizer is imported

@@ -63,12 +62,11 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
             device_type=device,
             max_loaded_models=config.max_loaded_models,
             embedding_path = Path(embedding_path),
-            logger = logger,
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
-        report_model_error(config, e, logger)
+        report_model_error(config, e)
     except (IOError, KeyError) as e:
-        logger.error(f"{e}. Aborting.")
+        print(f"{e}. Aborting.")
         sys.exit(-1)

     # try to autoconvert new models

@@ -78,18 +76,18 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
             conf_path=config.conf,
             weights_directory=path,
         )
-    logger.info('Model manager initialized')
     return model_manager

-def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
-    logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
-    logger.error(
-        "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
+def report_model_error(opt: Namespace, e: Exception):
+    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
+    print(
+        "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
     )
     yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
     if yes_to_all:
-        logger.warning(
-            "Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
+        print(
+            "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
         )
     else:
         response = input(

@@ -98,12 +96,13 @@ def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
     if response.startswith(("n", "N")):
         return

-    logger.info("invokeai-configure is launching....\n")
+    print("invokeai-configure is launching....\n")

     # Match arguments that were set on the CLI
     # only the arguments accepted by the configuration script are parsed
     root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
     config = ["--config", opt.conf] if opt.conf is not None else []
+    previous_config = sys.argv
     sys.argv = ["invokeai-configure"]
     sys.argv.extend(root_dir)
     sys.argv.extend(config.to_dict())
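One side of the hunks above routes messages through a shared logger module with `info`/`warning`/`error` functions, while the other uses bare `print` with `>>` and `**` prefixes. A minimal sketch of such a module-level logger built on the standard `logging` package (illustrative only, not the project's actual `invokeai.backend.util.logging`):

```python
import logging
import sys

# Illustrative stand-in for a module-level logger with the same call shape
# (logger.info/warning/error) used on one side of the hunks above.
_handler = logging.StreamHandler(sys.stderr)
_handler.setFormatter(logging.Formatter(">> %(levelname)s: %(message)s"))

_logger = logging.getLogger("invokeai")
_logger.addHandler(_handler)
_logger.setLevel(logging.INFO)

info = _logger.info
warning = _logger.warning
error = _logger.error

info("InvokeAI runtime directory is ...")  # prints ">> INFO: InvokeAI runtime directory is ..."
```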
@@ -1,5 +1,5 @@
 import traceback
-from threading import Event, Thread, BoundedSemaphore
+from threading import Event, Thread

 from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem

@@ -10,11 +10,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
     __stop_event: Event
     __invoker: Invoker
-    __threadLimit: BoundedSemaphore

     def start(self, invoker) -> None:
-        # if we do want multithreading at some point, we could make this configurable
-        self.__threadLimit = BoundedSemaphore(1)
         self.__invoker = invoker
         self.__stop_event = Event()
         self.__invoker_thread = Thread(

@@ -23,7 +20,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
             kwargs=dict(stop_event=self.__stop_event),
         )
         self.__invoker_thread.daemon = (
-            True  # TODO: make async and do not use threads
+            True  # TODO: probably better to just not use threads?
         )
         self.__invoker_thread.start()

@@ -32,7 +29,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

     def __process(self, stop_event: Event):
         try:
-            self.__threadLimit.acquire()
             while not stop_event.is_set():
                 queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                 if not queue_item:  # Probably stopping

@@ -131,6 +127,4 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
             )

         except KeyboardInterrupt:
-            pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor
-        finally:
-            self.__threadLimit.release()
+            ...  # Log something?
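The `BoundedSemaphore(1)` that one side of these hunks adds is the standard way to cap how many threads run the processing loop at once; unlike a plain `Semaphore`, a bounded one raises `ValueError` when released more often than acquired, which surfaces pairing bugs early. A standalone sketch of the pattern:

```python
from threading import BoundedSemaphore, Thread
import time

limit = BoundedSemaphore(1)  # raise the count to allow concurrent workers

def worker(n: int) -> None:
    limit.acquire()
    try:
        print(f"worker {n} running")
        time.sleep(0.1)
    finally:
        limit.release()  # releasing more times than acquired raises ValueError

threads = [Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```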
@@ -1,7 +1,6 @@
 import sys
 import traceback
 import torch
-from typing import types
 from ...backend.restoration import Restoration
 from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE

@@ -11,7 +10,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
 class RestorationServices:
     '''Face restoration and upscaling'''

-    def __init__(self,args,logger:types.ModuleType):
+    def __init__(self,args):
         try:
             gfpgan, codeformer, esrgan = None, None, None
             if args.restore or args.esrgan:

@@ -21,22 +20,20 @@ class RestorationServices:
                         args.gfpgan_model_path
                     )
                 else:
-                    logger.info("Face restoration disabled")
+                    print(">> Face restoration disabled")
                 if args.esrgan:
                     esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
                 else:
-                    logger.info("Upscaling disabled")
+                    print(">> Upscaling disabled")
             else:
-                logger.info("Face restoration and upscaling disabled")
+                print(">> Face restoration and upscaling disabled")
         except (ModuleNotFoundError, ImportError):
             print(traceback.format_exc(), file=sys.stderr)
-            logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
+            print(">> You may need to install the ESRGAN and/or GFPGAN modules")
         self.device = torch.device(choose_torch_device())
         self.gfpgan = gfpgan
         self.codeformer = codeformer
         self.esrgan = esrgan
-        self.logger = logger
-        self.logger.info('Face restoration initialized')

     # note that this one method does gfpgan and codepath reconstruction, as well as
     # esrgan upscaling

@@ -61,15 +58,15 @@ class RestorationServices:
         if self.gfpgan is not None or self.codeformer is not None:
             if facetool == "gfpgan":
                 if self.gfpgan is None:
-                    self.logger.info(
-                        "GFPGAN not found. Face restoration is disabled."
+                    print(
+                        ">> GFPGAN not found. Face restoration is disabled."
                     )
                 else:
                     image = self.gfpgan.process(image, strength, seed)
             if facetool == "codeformer":
                 if self.codeformer is None:
-                    self.logger.info(
-                        "CodeFormer not found. Face restoration is disabled."
+                    print(
+                        ">> CodeFormer not found. Face restoration is disabled."
                     )
                 else:
                     cf_device = (

@@ -83,7 +80,7 @@ class RestorationServices:
                         fidelity=codeformer_fidelity,
                     )
         else:
-            self.logger.info("Face Restoration is disabled.")
+            print(">> Face Restoration is disabled.")
         if upscale is not None:
             if self.esrgan is not None:
                 if len(upscale) < 2:

@@ -96,10 +93,10 @@ class RestorationServices:
                     denoise_str=upscale_denoise_str,
                 )
             else:
-                self.logger.info("ESRGAN is disabled. Image not upscaled.")
+                print(">> ESRGAN is disabled. Image not upscaled.")
         except Exception as e:
-            self.logger.info(
-                f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
+            print(
+                f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
             )

         if image_callback is not None:

@@ -96,7 +96,6 @@ from pathlib import Path
 from typing import List

 import invokeai.version
-import invokeai.backend.util.logging as logger
 from invokeai.backend.image_util import retrieve_metadata

 from .globals import Globals

@@ -190,7 +189,7 @@ class Args(object):
             print(f"{APP_NAME} {APP_VERSION}")
             sys.exit(0)

-        logger.info("Initializing, be patient...")
+        print("* Initializing, be patient...")
         Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
         Globals.try_patchmatch = switches.patchmatch

@@ -198,13 +197,14 @@ class Args(object):
         initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
         legacyinit = os.path.expanduser("~/.invokeai")
         if os.path.exists(initfile):
-            logger.info(
-                f"Initialization file {initfile} found. Loading...",
+            print(
+                f">> Initialization file {initfile} found. Loading...",
+                file=sys.stderr,
             )
             sysargs.insert(0, f"@{initfile}")
         elif os.path.exists(legacyinit):
-            logger.warning(
-                f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
+            print(
+                f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
             )
             sysargs.insert(0, f"@{legacyinit}")
         Globals.log_tokenization = self._arg_parser.parse_args(

@@ -214,7 +214,7 @@ class Args(object):
             self._arg_switches = self._arg_parser.parse_args(sysargs)
             return self._arg_switches
         except Exception as e:
-            logger.error(f"An exception has occurred: {e}")
+            print(f"An exception has occurred: {e}")
             return None

     def parse_cmd(self, cmd_string):

@@ -1154,7 +1154,7 @@ class Args(object):


 def format_metadata(**kwargs):
-    logger.warning("format_metadata() is deprecated. Please use metadata_dumps()")
+    print("format_metadata() is deprecated. Please use metadata_dumps()")
     return metadata_dumps(kwargs)


@@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
         import sys
         import traceback

-        logger.error("Could not read metadata")
+        print(">> could not read metadata", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
     return results

@@ -27,7 +27,6 @@ from diffusers.utils.import_utils import is_xformers_available
 from omegaconf import OmegaConf
 from pathlib import Path

-import invokeai.backend.util.logging as logger
 from .args import metadata_from_png
 from .generator import infill_methods
 from .globals import Globals, global_cache_dir

@@ -196,12 +195,12 @@ class Generate:
         # device to Generate(). However the device was then ignored, so
         # it wasn't actually doing anything. This logic could be reinstated.
         self.device = torch.device(choose_torch_device())
-        logger.info(f"Using device_type {self.device.type}")
+        print(f">> Using device_type {self.device.type}")
         if full_precision:
             if self.precision != "auto":
                 raise ValueError("Remove --full_precision / -F if using --precision")
-            logger.warning("Please remove deprecated --full_precision / -F")
-            logger.warning("If auto config does not work you can use --precision=float32")
+            print("Please remove deprecated --full_precision / -F")
+            print("If auto config does not work you can use --precision=float32")
             self.precision = "float32"
         if self.precision == "auto":
             self.precision = choose_precision(self.device)

@@ -209,13 +208,13 @@ class Generate:

         if is_xformers_available():
             if torch.cuda.is_available() and not Globals.disable_xformers:
-                logger.info("xformers memory-efficient attention is available and enabled")
+                print(">> xformers memory-efficient attention is available and enabled")
             else:
-                logger.info(
-                    "xformers memory-efficient attention is available but disabled"
+                print(
+                    ">> xformers memory-efficient attention is available but disabled"
                 )
         else:
-            logger.info("xformers not installed")
+            print(">> xformers not installed")

         # model caching system for fast switching
         self.model_manager = ModelManager(

@@ -230,8 +229,8 @@ class Generate:
         fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
         model = model or fallback
         if not self.model_manager.valid_model(model):
-            logger.warning(
-                f'"{model}" is not a known model name; falling back to {fallback}.'
+            print(
+                f'** "{model}" is not a known model name; falling back to {fallback}.'
             )
             model = None
         self.model_name = model or fallback

@@ -247,10 +246,10 @@ class Generate:

         # load safety checker if requested
         if safety_checker:
-            logger.info("Initializing NSFW checker")
+            print(">> Initializing NSFW checker")
             self.safety_checker = SafetyChecker(self.device)
         else:
-            logger.info("NSFW checker is disabled")
+            print(">> NSFW checker is disabled")

     def prompt2png(self, prompt, outdir, **kwargs):
         """

@@ -568,7 +567,7 @@ class Generate:
                 self.clear_cuda_cache()

                 if catch_interrupts:
-                    logger.warning("Interrupted** Partial results will be returned.")
+                    print("**Interrupted** Partial results will be returned.")
                 else:
                     raise KeyboardInterrupt
         except RuntimeError:

@@ -576,11 +575,11 @@ class Generate:
             self.clear_cuda_cache()

             print(traceback.format_exc(), file=sys.stderr)
-            logger.info("Could not generate image.")
+            print(">> Could not generate image.")

         toc = time.time()
-        logger.info("Usage stats:")
-        logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
+        print("\n>> Usage stats:")
+        print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
         self.print_cuda_stats()
         return results

@@ -610,16 +609,16 @@ class Generate:
     def print_cuda_stats(self):
         if self._has_cuda():
             self.gather_cuda_stats()
-            logger.info(
-                "Max VRAM used for this generation: "+
-                "%4.2fG. " % (self.max_memory_allocated / 1e9)+
-                "Current VRAM utilization: "+
-                "%4.2fG" % (self.memory_allocated / 1e9)
+            print(
+                ">> Max VRAM used for this generation:",
+                "%4.2fG." % (self.max_memory_allocated / 1e9),
+                "Current VRAM utilization:",
+                "%4.2fG" % (self.memory_allocated / 1e9),
             )

-            logger.info(
-                "Max VRAM used since script start: " +
-                "%4.2fG" % (self.session_peakmem / 1e9)
+            print(
+                ">> Max VRAM used since script start: ",
+                "%4.2fG" % (self.session_peakmem / 1e9),
             )

     # this needs to be generalized to all sorts of postprocessors, which should be wrapped
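The VRAM figures `print_cuda_stats()` reports come from the `torch.cuda` memory counters, which return byte counts for the current device. A standalone sketch of the same readout:

```python
import torch

# The counters behind the "Max VRAM used" / "Current VRAM utilization" lines
# above; both report bytes for the current CUDA device.
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated()
    peak = torch.cuda.max_memory_allocated()
    print(f">> Max VRAM used for this generation: {peak / 1e9:4.2f}G. "
          f"Current VRAM utilization: {allocated / 1e9:4.2f}G")
```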
@@ -648,7 +647,7 @@ class Generate:
         seed = random.randrange(0, np.iinfo(np.uint32).max)

         prompt = opt.prompt or args.prompt or ""
-        logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
+        print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')

         # try to reuse the same filename prefix as the original file.
         # we take everything up to the first period

@@ -697,8 +696,8 @@ class Generate:
                 try:
                     extend_instructions[direction] = int(pixels)
                 except ValueError:
-                    logger.warning(
-                        'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
+                    print(
+                        '** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
                     )

         opt.seed = seed

@@ -721,8 +720,8 @@ class Generate:
             # fetch the metadata from the image
             generator = self.select_generator(embiggen=True)
             opt.strength = opt.embiggen_strength or 0.40
-            logger.info(
-                f"Setting img2img strength to {opt.strength} for happy embiggening"
+            print(
+                f">> Setting img2img strength to {opt.strength} for happy embiggening"
             )
             generator.generate(
                 prompt,

@@ -749,12 +748,12 @@ class Generate:
             return restorer.process(opt, args, image_callback=callback, prefix=prefix)

         elif tool is None:
-            logger.warning(
-                "please provide at least one postprocessing option, such as -G or -U"
+            print(
+                "* please provide at least one postprocessing option, such as -G or -U"
             )
             return None
         else:
-            logger.warning(f"postprocessing tool {tool} is not yet supported")
+            print(f"* postprocessing tool {tool} is not yet supported")
             return None

     def select_generator(

@@ -798,8 +797,8 @@ class Generate:
         image = self._load_img(img)

         if image.width < self.width and image.height < self.height:
-            logger.warning(
-                f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
+            print(
+                f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
             )

         # if image has a transparent area and no mask was provided, then try to generate mask

@@ -810,8 +809,8 @@ class Generate:
             if (image.width * image.height) > (
                 self.width * self.height
             ) and self.size_matters:
-                logger.info(
-                    "This input is larger than your defaults. If you run out of memory, please use a smaller image."
+                print(
+                    ">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
                 )
                 self.size_matters = False

@@ -892,11 +891,11 @@ class Generate:
         try:
             model_data = cache.get_model(model_name)
         except Exception as e:
-            logger.warning(f"model {model_name} could not be loaded: {str(e)}")
+            print(f"** model {model_name} could not be loaded: {str(e)}")
             print(traceback.format_exc(), file=sys.stderr)
             if previous_model_name is None:
                 raise e
-            logger.warning("trying to reload previous model")
+            print("** trying to reload previous model")
             model_data = cache.get_model(previous_model_name)  # load previous
             if model_data is None:
                 raise e

@@ -963,15 +962,15 @@ class Generate:
         if self.gfpgan is not None or self.codeformer is not None:
             if facetool == "gfpgan":
                 if self.gfpgan is None:
-                    logger.info(
-                        "GFPGAN not found. Face restoration is disabled."
+                    print(
+                        ">> GFPGAN not found. Face restoration is disabled."
                     )
                 else:
                     image = self.gfpgan.process(image, strength, seed)
             if facetool == "codeformer":
                 if self.codeformer is None:
-                    logger.info(
-                        "CodeFormer not found. Face restoration is disabled."
+                    print(
+                        ">> CodeFormer not found. Face restoration is disabled."
                     )
                 else:
                     cf_device = (

@@ -985,7 +984,7 @@ class Generate:
                         fidelity=codeformer_fidelity,
                     )
         else:
-            logger.info("Face Restoration is disabled.")
+            print(">> Face Restoration is disabled.")
         if upscale is not None:
             if self.esrgan is not None:
                 if len(upscale) < 2:

@@ -998,10 +997,10 @@ class Generate:
                     denoise_str=upscale_denoise_str,
                 )
             else:
-                logger.info("ESRGAN is disabled. Image not upscaled.")
+                print(">> ESRGAN is disabled. Image not upscaled.")
         except Exception as e:
-            logger.info(
-                f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
+            print(
+                f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
             )

         if image_callback is not None:

@@ -1067,17 +1066,17 @@ class Generate:
         if self.sampler_name in scheduler_map:
             sampler_class = scheduler_map[self.sampler_name]
             msg = (
-                f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
+                f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
             )
             self.sampler = sampler_class.from_config(self.model.scheduler.config)
         else:
             msg = (
-                f" Unsupported Sampler: {self.sampler_name} "+
+                f">> Unsupported Sampler: {self.sampler_name} "
                 f"Defaulting to {default}"
             )
             self.sampler = default

-        logger.info(msg)
+        print(msg)

         if not hasattr(self.sampler, "uses_inpainting_model"):
             # FIXME: terrible kludge!

@@ -1086,17 +1085,17 @@ class Generate:
     def _load_img(self, img) -> Image:
         if isinstance(img, Image.Image):
             image = img
-            logger.info(f"using provided input image of size {image.width}x{image.height}")
+            print(f">> using provided input image of size {image.width}x{image.height}")
         elif isinstance(img, str):
-            assert os.path.exists(img), f"{img}: File not found"
+            assert os.path.exists(img), f">> {img}: File not found"

             image = Image.open(img)
-            logger.info(
-                f"loaded input image of size {image.width}x{image.height} from {img}"
+            print(
+                f">> loaded input image of size {image.width}x{image.height} from {img}"
             )
         else:
             image = Image.open(img)
-            logger.info(f"loaded input image of size {image.width}x{image.height}")
+            print(f">> loaded input image of size {image.width}x{image.height}")
         image = ImageOps.exif_transpose(image)
         return image

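`_load_img()` above ends with `ImageOps.exif_transpose(image)`, which rotates or flips a PIL image according to its EXIF Orientation tag so photos come in the way the camera displayed them. A standalone sketch (the file path is illustrative):

```python
from PIL import Image, ImageOps

image = Image.open("init_image.png")    # any image file; path is illustrative
image = ImageOps.exif_transpose(image)  # returns an untransposed copy when there is no EXIF tag
print(image.width, image.height)
```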
@@ -1184,14 +1183,14 @@ class Generate:

     def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
         if not mask:
-            logger.info(
-                "Initial image has transparent areas. Will inpaint in these regions."
+            print(
+                ">> Initial image has transparent areas. Will inpaint in these regions."
             )
             if (not force_outpaint) and self._check_for_erasure(image):
-                logger.info(
-                    "Colors underneath the transparent region seem to have been erased.\n" +
-                    "Inpainting will be suboptimal. Please preserve the colors when making\n" +
-                    "a transparency mask, or provide mask explicitly using --init_mask (-M)."
+                print(
+                    ">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
+                    ">> Inpainting will be suboptimal. Please preserve the colors when making\n",
+                    ">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
                 )

     def _squeeze_image(self, image):

@@ -1202,11 +1201,11 @@ class Generate:

     def _fit_image(self, image, max_dimensions):
         w, h = max_dimensions
-        logger.info(f"image will be resized to fit inside a box {w}x{h} in size.")
+        print(f">> image will be resized to fit inside a box {w}x{h} in size.")
         # note that InitImageResizer does the multiple of 64 truncation internally
         image = InitImageResizer(image).resize(width=w, height=h)
-        logger.info(
-            f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
+        print(
+            f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
         )
         return image

@@ -1217,8 +1216,8 @@ class Generate:
         ) # resize to integer multiple of 64
         if h != height or w != width:
             if log:
-                logger.info(
-                    f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
+                print(
+                    f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
                 )
             height = h
             width = w

@@ -25,7 +25,6 @@ from typing import Callable, List, Iterator, Optional, Type
 from dataclasses import dataclass, field
 from diffusers.schedulers import SchedulerMixin as Scheduler

-import invokeai.backend.util.logging as logger
 from ..image_util import configure_model_padding
 from ..util.util import rand_perlin_2d
 from ..safety_checker import SafetyChecker

@@ -373,7 +372,7 @@ class Generator:
         try:
             x_T = self.get_noise(width, height)
         except:
-            logger.error("An error occurred while getting initial noise")
+            print("** An error occurred while getting initial noise **")
             print(traceback.format_exc())

         # Pass on the seed in case a layer beneath us needs to generate noise on its own.

@@ -608,7 +607,7 @@ class Generator:
         image = self.sample_to_image(sample)
         dirname = os.path.dirname(filepath) or "."
         if not os.path.exists(dirname):
-            logger.info(f"creating directory {dirname}")
+            print(f"** creating directory {dirname}")
         os.makedirs(dirname, exist_ok=True)
         image.save(filepath, "PNG")

@@ -8,11 +8,10 @@ import torch
 from PIL import Image
 from tqdm import trange

-import invokeai.backend.util.logging as logger

 from .base import Generator
 from .img2img import Img2Img


 class Embiggen(Generator):
     def __init__(self, model, precision):
         super().__init__(model, precision)

@@ -73,22 +72,22 @@ class Embiggen(Generator):
             embiggen = [1.0]  # If not specified, assume no scaling
         elif embiggen[0] < 0:
             embiggen[0] = 1.0
-            logger.warning(
-                "Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
+            print(
+                ">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
             )
         if len(embiggen) < 2:
             embiggen.append(0.75)
         elif embiggen[1] > 1.0 or embiggen[1] < 0:
             embiggen[1] = 0.75
-            logger.warning(
-                "Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
+            print(
+                ">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
             )
         if len(embiggen) < 3:
             embiggen.append(0.25)
         elif embiggen[2] < 0:
             embiggen[2] = 0.25
-            logger.warning(
-                "Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
+            print(
+                ">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
             )

         # Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
@ -98,8 +97,8 @@ class Embiggen(Generator):
|
|||||||
embiggen_tiles.sort()
|
embiggen_tiles.sort()
|
||||||
|
|
||||||
if strength >= 0.5:
|
if strength >= 0.5:
|
||||||
logger.warning(
|
print(
|
||||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prep img2img generator, since we wrap over it
|
# Prep img2img generator, since we wrap over it
|
||||||
@ -122,8 +121,8 @@ class Embiggen(Generator):
|
|||||||
from ..restoration.realesrgan import ESRGAN
|
from ..restoration.realesrgan import ESRGAN
|
||||||
|
|
||||||
esrgan = ESRGAN()
|
esrgan = ESRGAN()
|
||||||
logger.info(
|
print(
|
||||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||||
)
|
)
|
||||||
if embiggen[0] > 2:
|
if embiggen[0] > 2:
|
||||||
initsuperimage = esrgan.process(
|
initsuperimage = esrgan.process(
|
||||||
@ -313,10 +312,10 @@ class Embiggen(Generator):
|
|||||||
def make_image():
|
def make_image():
|
||||||
# Make main tiles -------------------------------------------------
|
# Make main tiles -------------------------------------------------
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||||
else:
|
else:
|
||||||
logger.info(
|
print(
|
||||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||||
)
|
)
|
||||||
|
|
||||||
emb_tile_store = []
|
emb_tile_store = []
|
||||||
@ -362,11 +361,11 @@ class Embiggen(Generator):
|
|||||||
# newinitimage.save(newinitimagepath)
|
# newinitimage.save(newinitimagepath)
|
||||||
|
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
logger.debug(
|
print(
|
||||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||||
|
|
||||||
# create a torch tensor from an Image
|
# create a torch tensor from an Image
|
||||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||||
@ -548,8 +547,8 @@ class Embiggen(Generator):
|
|||||||
# Layer tile onto final image
|
# Layer tile onto final image
|
||||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||||
else:
|
else:
|
||||||
logger.error(
|
print(
|
||||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||||
)
|
)
|
||||||
|
|
||||||
# after internal loops and patching up return Embiggen image
|
# after internal loops and patching up return Embiggen image
|
||||||
|
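Every print-versus-logger hunk in this comparison follows the same pattern: one side of the branch pair routes messages through a shared `invokeai.backend.util.logging` module (dropping the hand-written `>>`/`**`/`###` sigils from the message text), while the other emits bare `print()` calls. The module's own source is not part of this diff; purely as a hypothetical sketch, a module-level shim that would satisfy the call sites seen in these hunks can be as small as:

```python
# Hypothetical sketch only -- the real invokeai/backend/util/logging.py is not
# shown in this diff. It illustrates how module-level info()/warning()/etc.
# aliases can stand in for bare print() calls.
import logging

_logger = logging.getLogger("InvokeAI")
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter(">> %(message)s"))
_logger.addHandler(_handler)
_logger.setLevel(logging.DEBUG)

# The call sites in these hunks use exactly these five functions.
debug = _logger.debug
info = _logger.info
warning = _logger.warning
error = _logger.error
critical = _logger.critical
```

Centralizing output this way lets severity levels (`debug` for the `| ...` detail lines, `warning` for `** ...` lines, `critical` for `### ...` lines) replace the ad-hoc sigils that the `print()` side encodes by hand.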
```diff
@@ -14,8 +14,6 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
 from ..stable_diffusion.diffusers_pipeline import ConditioningData
 from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of

-import invokeai.backend.util.logging as logger
-
 class Txt2Img2Img(Generator):
     def __init__(self, model, precision):
         super().__init__(model, precision)
@@ -79,8 +77,8 @@ class Txt2Img2Img(Generator):
         # the message below is accurate.
         init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
         init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
-        logger.info(
-            f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
+        print(
+            f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
         )

         # resizing
```
```diff
@@ -5,9 +5,10 @@ wraps the actual patchmatch object. It respects the global
 be suppressed or deferred
 """
 import numpy as np
-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals


 class PatchMatch:
     """
     Thin class wrapper around the patchmatch function.
@@ -27,12 +28,12 @@ class PatchMatch:
             from patchmatch import patch_match as pm

             if pm.patchmatch_available:
-                logger.info("Patchmatch initialized")
+                print(">> Patchmatch initialized")
             else:
-                logger.info("Patchmatch not loaded (nonfatal)")
+                print(">> Patchmatch not loaded (nonfatal)")
             self.patch_match = pm
         else:
-            logger.info("Patchmatch loading disabled")
+            print(">> Patchmatch loading disabled")
         self.tried_load = True

     @classmethod
```
```diff
@@ -30,9 +30,9 @@ work fine.
 import numpy as np
 import torch
 from PIL import Image, ImageOps
+from torchvision import transforms
 from transformers import AutoProcessor, CLIPSegForImageSegmentation

-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import global_cache_dir

 CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
@@ -83,7 +83,7 @@ class Txt2Mask(object):
     """

     def __init__(self, device="cpu", refined=False):
-        logger.info("Initializing clipseg model for text to mask inference")
+        print(">> Initializing clipseg model for text to mask inference")

         # BUG: we are not doing anything with the device option at this time
         self.device = device
@@ -101,6 +101,18 @@ class Txt2Mask(object):
         provided image and returns a SegmentedGrayscale object in which the brighter
         pixels indicate where the object is inferred to be.
         """
+        transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+                transforms.Resize(
+                    (CLIPSEG_SIZE, CLIPSEG_SIZE)
+                ),  # must be multiple of 64...
+            ]
+        )
+
         if type(image) is str:
             image = Image.open(image).convert("RGB")

```
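The `@@ -101,6 +101,18 @@` hunk above adds a torchvision preprocessing pipeline in front of CLIPSeg inference. Below is a self-contained sketch of what that pipeline does to an input image; `CLIPSEG_SIZE = 352` is an assumption (the constant is defined outside the hunk), and the diff's trailing comment says only that it must be a multiple of 64:

```python
from PIL import Image
from torchvision import transforms

CLIPSEG_SIZE = 352  # assumed value; the constant is defined outside the hunk shown above

transform = transforms.Compose(
    [
        transforms.ToTensor(),  # PIL RGB image -> float tensor in [0, 1], shape (3, H, W)
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),  # ImageNet channel statistics, as in the hunk
        transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)),  # fixed square input for the model
    ]
)

image = Image.new("RGB", (640, 480))  # stand-in for a user-supplied image
batch = transform(image).unsqueeze(0)  # add a batch dimension
print(batch.shape)  # torch.Size([1, 3, 352, 352])
```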
```diff
@@ -25,7 +25,6 @@ from typing import Union
 import torch
 from safetensors.torch import load_file

-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import global_cache_dir, global_config_dir

 from .model_manager import ModelManager, SDLegacyType
@@ -373,9 +372,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     unet_key = "model.diffusion_model."
     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100:
-        logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
         if extract_ema:
-            logger.debug("Extracting EMA weights (usually better for inference)")
+            print(" | Extracting EMA weights (usually better for inference)")
             for key in keys:
                 if key.startswith("model.diffusion_model"):
                     flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
@@ -393,8 +392,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                         key
                     )
         else:
-            logger.debug(
-                "Extracting only the non-EMA weights (usually better for fine-tuning)"
+            print(
+                " | Extracting only the non-EMA weights (usually better for fine-tuning)"
             )

     for key in keys:
@@ -1116,7 +1115,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        logger.debug("global_step key not found in model")
+        print(" | global_step key not found in model")
         global_step = None

     # sometimes there is a state_dict key and sometimes not
@@ -1230,15 +1229,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     # If a replacement VAE path was specified, we'll incorporate that into
     # the checkpoint model and then convert it
     if vae_path:
-        logger.debug(f"Converting VAE {vae_path}")
+        print(f" | Converting VAE {vae_path}")
         replace_checkpoint_vae(checkpoint,vae_path)
     # otherwise we use the original VAE, provided that
     # an externally loaded diffusers VAE was not passed
     elif not vae:
-        logger.debug("Using checkpoint model's original VAE")
+        print(" | Using checkpoint model's original VAE")

     if vae:
-        logger.debug("Using replacement diffusers VAE")
+        print(" | Using replacement diffusers VAE")
     else:  # convert the original or replacement VAE
         vae_config = create_vae_diffusers_config(
             original_config, image_size=image_size
```
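The `@@ -373,9 +372,9 @@` and `@@ -393,8 +392,8 @@` hunks sit inside the routine that picks EMA or non-EMA UNet weights out of a legacy checkpoint. Here is a simplified sketch of that selection, built only from the key manipulation visible in the context lines; the real function does considerably more bookkeeping:

```python
def select_unet_weights(checkpoint: dict, extract_ema: bool) -> dict:
    """Simplified sketch: pull UNet weights out of a flat legacy checkpoint."""
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    unet_state_dict = {}

    # Heuristic from the hunk: at least 100 keys starting with "model_ema"
    # means the checkpoint carries a separate set of EMA weights.
    has_ema = sum(k.startswith("model_ema") for k in keys) > 100

    if has_ema and extract_ema:
        # EMA copies are stored flat: "model_ema." plus the dotted key with its dots removed.
        for key in keys:
            if key.startswith(unet_key):
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key[len(unet_key):]] = checkpoint[flat_ema_key]
    else:
        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key[len(unet_key):]] = checkpoint[key]
    return unet_state_dict
```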
```diff
@@ -18,13 +18,12 @@ import warnings
 from enum import Enum, auto
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Any, Optional, Union, Callable, types
+from typing import Any, Optional, Union, Callable

 import safetensors
 import safetensors.torch
 import torch
 import transformers
-import invokeai.backend.util.logging as logger
 from diffusers import (
     AutoencoderKL,
     UNet2DConditionModel,
@@ -76,8 +75,6 @@ class ModelManager(object):
     Model manager handles loading, caching, importing, deleting, converting, and editing models.
     """

-    logger: types.ModuleType = logger
-
     def __init__(
         self,
         config: OmegaConf | Path,
@@ -86,7 +83,6 @@ class ModelManager(object):
         max_loaded_models=DEFAULT_MAX_MODELS,
         sequential_offload=False,
         embedding_path: Path = None,
-        logger: types.ModuleType = logger,
     ):
         """
         Initialize with the path to the models.yaml config file or
@@ -108,7 +104,6 @@ class ModelManager(object):
         self.current_model = None
         self.sequential_offload = sequential_offload
         self.embedding_path = embedding_path
-        self.logger = logger

     def valid_model(self, model_name: str) -> bool:
         """
@@ -137,8 +132,8 @@ class ModelManager(object):
         )

         if not self.valid_model(model_name):
-            self.logger.error(
-                f'"{model_name}" is not a known model name. Please check your models.yaml file'
+            print(
+                f'** "{model_name}" is not a known model name. Please check your models.yaml file'
             )
             return self.current_model

@@ -149,7 +144,7 @@ class ModelManager(object):

         if model_name in self.models:
             requested_model = self.models[model_name]["model"]
-            self.logger.info(f"Retrieving model {model_name} from system RAM cache")
+            print(f">> Retrieving model {model_name} from system RAM cache")
             requested_model.ready()
             width = self.models[model_name]["width"]
             height = self.models[model_name]["height"]
@@ -384,7 +379,7 @@ class ModelManager(object):
         """
         omega = self.config
         if model_name not in omega:
-            self.logger.error(f"Unknown model {model_name}")
+            print(f"** Unknown model {model_name}")
             return
         # save these for use in deletion later
         conf = omega[model_name]
@@ -397,13 +392,13 @@ class ModelManager(object):
             self.stack.remove(model_name)
         if delete_files:
             if weights:
-                self.logger.info(f"Deleting file {weights}")
+                print(f"** Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
             elif path:
-                self.logger.info(f"Deleting directory {path}")
+                print(f"** Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
             elif repo_id:
-                self.logger.info(f"Deleting the cached model directory for {repo_id}")
+                print(f"** Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)

     def add_model(
@@ -444,7 +439,7 @@ class ModelManager(object):
     def _load_model(self, model_name: str):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if model_name not in self.config:
-            self.logger.error(
+            print(
                 f'"{model_name}" is not a known model name. Please check your models.yaml file'
             )
             return
@@ -462,7 +457,7 @@ class ModelManager(object):
         model_format = mconfig.get("format", "ckpt")
         if model_format == "ckpt":
             weights = mconfig.weights
-            self.logger.info(f"Loading {model_name} from {weights}")
+            print(f">> Loading {model_name} from {weights}")
             model, width, height, model_hash = self._load_ckpt_model(
                 model_name, mconfig
             )
@@ -478,15 +473,13 @@ class ModelManager(object):

         # usage statistics
         toc = time.time()
-        self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
+        print(">> Model loaded in", "%4.2fs" % (toc - tic))
         if self._has_cuda():
-            self.logger.info(
-                "Max VRAM used to load the model: "+
-                "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
-            )
-            self.logger.info(
-                "Current VRAM usage: "+
-                "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
+            print(
+                ">> Max VRAM used to load the model:",
+                "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
+                "\n>> Current VRAM usage:",
+                "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
             )
         return model, width, height, model_hash

@@ -494,11 +487,11 @@ class ModelManager(object):
         name_or_path = self.model_name_or_path(mconfig)
         using_fp16 = self.precision == "float16"

-        self.logger.info(f"Loading diffusers model from {name_or_path}")
+        print(f">> Loading diffusers model from {name_or_path}")
         if using_fp16:
-            self.logger.debug("Using faster float16 precision")
+            print(" | Using faster float16 precision")
         else:
-            self.logger.debug("Using more accurate float32 precision")
+            print(" | Using more accurate float32 precision")

         # TODO: scan weights maybe?
         pipeline_args: dict[str, Any] = dict(
@@ -530,8 +523,8 @@ class ModelManager(object):
             if str(e).startswith("fp16 is not a valid"):
                 pass
             else:
-                self.logger.error(
-                    f"An unexpected error occurred while downloading the model: {e})"
+                print(
+                    f"** An unexpected error occurred while downloading the model: {e})"
                 )
             if pipeline:
                 break
@@ -549,7 +542,7 @@ class ModelManager(object):
         # square images???
         width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
         height = width
-        self.logger.debug(f"Default image dimensions = {width} x {height}")
+        print(f" | Default image dimensions = {width} x {height}")

         return pipeline, width, height, model_hash

@@ -566,14 +559,14 @@ class ModelManager(object):
         weights = os.path.normpath(os.path.join(Globals.root, weights))

         # Convert to diffusers and return a diffusers pipeline
-        self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
+        print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")

         from . import load_pipeline_from_original_stable_diffusion_ckpt

         try:
             if self.list_models()[self.current_model]["status"] == "active":
                 self.offload_model(self.current_model)
-        except Exception:
+        except Exception as e:
             pass

         vae_path = None
@@ -631,7 +624,7 @@ class ModelManager(object):
         if model_name not in self.models:
             return

-        self.logger.info(f"Offloading {model_name} to CPU")
+        print(f">> Offloading {model_name} to CPU")
         model = self.models[model_name]["model"]
         model.offload_all()
         self.current_model = None
@@ -647,26 +640,30 @@ class ModelManager(object):
         and option to exit if an infected file is identified.
         """
         # scan model
-        self.logger.debug(f"Scanning Model: {model_name}")
+        print(f" | Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
-                self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
-                self.logger.critical("The model you are trying to load seems to be infected.")
-                self.logger.critical("For your safety, InvokeAI will not load this model.")
-                self.logger.critical("Please use checkpoints from trusted sources.")
-                self.logger.critical("Exiting InvokeAI")
+                print(f"\n### Issues Found In Model: {scan_result.issues_count}")
+                print(
+                    "### WARNING: The model you are trying to load seems to be infected."
+                )
+                print("### For your safety, InvokeAI will not load this model.")
+                print("### Please use checkpoints from trusted sources.")
+                print("### Exiting InvokeAI")
                 sys.exit()
             else:
-                self.logger.warning("InvokeAI was unable to scan the model you are using.")
+                print(
+                    "\n### WARNING: InvokeAI was unable to scan the model you are using."
+                )
                 model_safe_check_fail = ask_user(
                     "Do you want to to continue loading the model?", ["y", "n"]
                 )
                 if model_safe_check_fail.lower() != "y":
-                    self.logger.critical("Exiting InvokeAI")
+                    print("### Exiting InvokeAI")
                     sys.exit()
         else:
-            self.logger.debug("Model scanned ok")
+            print(" | Model scanned ok")

     def import_diffuser_model(
         self,
@@ -783,24 +780,26 @@ class ModelManager(object):
         model_path: Path = None
         thing = path_url_or_repo  # to save typing

-        self.logger.info(f"Probing {thing} for import")
+        print(f">> Probing {thing} for import")

         if thing.startswith(("http:", "https:", "ftp:")):
-            self.logger.info(f"{thing} appears to be a URL")
+            print(f" | {thing} appears to be a URL")
             model_path = self._resolve_path(
                 thing, "models/ldm/stable-diffusion-v1"
             )  # _resolve_path does a download if needed

         elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
             if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
-                self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
+                print(
+                    f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
+                )
                 return
             else:
-                self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
+                print(f" | {thing} appears to be a checkpoint file on disk")
                 model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")

         elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
-            self.logger.debug(f"{thing} appears to be a diffusers file on disk")
+            print(f" | {thing} appears to be a diffusers file on disk")
             model_name = self.import_diffuser_model(
                 thing,
                 vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@@ -811,30 +810,34 @@ class ModelManager(object):

         elif Path(thing).is_dir():
             if (Path(thing) / "model_index.json").exists():
-                self.logger.debug(f"{thing} appears to be a diffusers model.")
+                print(f" | {thing} appears to be a diffusers model.")
                 model_name = self.import_diffuser_model(
                     thing, commit_to_conf=commit_to_conf
                 )
             else:
-                self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
+                print(
+                    f" |{thing} appears to be a directory. Will scan for models to import"
+                )
                 for m in list(Path(thing).rglob("*.ckpt")) + list(
                     Path(thing).rglob("*.safetensors")
                 ):
                     if model_name := self.heuristic_import(
                         str(m), commit_to_conf=commit_to_conf
                     ):
-                        self.logger.info(f"{model_name} successfully imported")
+                        print(f" >> {model_name} successfully imported")
                 return model_name

         elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
-            self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
+            print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
             model_name = self.import_diffuser_model(
                 thing, commit_to_conf=commit_to_conf
             )
             pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
             return model_name
         else:
-            self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
+            print(
+                f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
+            )

         # Model_path is set in the event of a legacy checkpoint file.
         # If not set, we're all done
@@ -842,7 +845,7 @@ class ModelManager(object):
             return

         if model_path.stem in self.config:  # already imported
-            self.logger.debug("Already imported. Skipping")
+            print(" | Already imported. Skipping")
             return model_path.stem

         # another round of heuristics to guess the correct config file.
@@ -858,39 +861,39 @@ class ModelManager(object):
         # look for a like-named .yaml file in same directory
         if model_path.with_suffix(".yaml").exists():
             model_config_file = model_path.with_suffix(".yaml")
-            self.logger.debug(f"Using config file {model_config_file.name}")
+            print(f" | Using config file {model_config_file.name}")

         else:
             model_type = self.probe_model_type(checkpoint)
             if model_type == SDLegacyType.V1:
-                self.logger.debug("SD-v1 model detected")
+                print(" | SD-v1 model detected")
                 model_config_file = Path(
                     Globals.root, "configs/stable-diffusion/v1-inference.yaml"
                 )
             elif model_type == SDLegacyType.V1_INPAINT:
-                self.logger.debug("SD-v1 inpainting model detected")
+                print(" | SD-v1 inpainting model detected")
                 model_config_file = Path(
                     Globals.root,
                     "configs/stable-diffusion/v1-inpainting-inference.yaml",
                 )
             elif model_type == SDLegacyType.V2_v:
-                self.logger.debug("SD-v2-v model detected")
+                print(" | SD-v2-v model detected")
                 model_config_file = Path(
                     Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
                 )
             elif model_type == SDLegacyType.V2_e:
-                self.logger.debug("SD-v2-e model detected")
+                print(" | SD-v2-e model detected")
                 model_config_file = Path(
                     Globals.root, "configs/stable-diffusion/v2-inference.yaml"
                 )
             elif model_type == SDLegacyType.V2:
-                self.logger.warning(
-                    f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
+                print(
+                    f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
                 )
                 return
             else:
-                self.logger.warning(
-                    f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
+                print(
+                    f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
                 )
                 return

@@ -906,7 +909,7 @@ class ModelManager(object):
         for suffix in ["pt", "ckpt", "safetensors"]:
             if (model_path.with_suffix(f".vae.{suffix}")).exists():
                 vae_path = model_path.with_suffix(f".vae.{suffix}")
-                self.logger.debug(f"Using VAE file {vae_path.name}")
+                print(f" | Using VAE file {vae_path.name}")
         vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")

         diffuser_path = Path(
@@ -952,14 +955,14 @@ class ModelManager(object):
         from . import convert_ckpt_to_diffusers

         if diffusers_path.exists():
-            self.logger.error(
-                f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
+            print(
+                f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
             )
             return

         model_name = model_name or diffusers_path.name
         model_description = model_description or f"Converted version of {model_name}"
-        self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
+        print(f" | Converting {model_name} to diffusers (30-60s)")
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
@@ -976,10 +979,10 @@ class ModelManager(object):
                 vae_path=vae_path,
                 scan_needed=scan_needed,
             )
-            self.logger.debug(
-                f"Success. Converted model is now located at {str(diffusers_path)}"
+            print(
+                f" | Success. Converted model is now located at {str(diffusers_path)}"
             )
-            self.logger.debug(f"Writing new config file entry for {model_name}")
+            print(f" | Writing new config file entry for {model_name}")
             new_config = dict(
                 path=str(diffusers_path),
                 description=model_description,
@@ -990,17 +993,17 @@ class ModelManager(object):
             self.add_model(model_name, new_config, True)
             if commit_to_conf:
                 self.commit(commit_to_conf)
-            self.logger.debug("Conversion succeeded")
+            print(" | Conversion succeeded")
         except Exception as e:
-            self.logger.warning(f"Conversion failed: {str(e)}")
-            self.logger.warning(
-                "If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
+            print(f"** Conversion failed: {str(e)}")
+            print(
+                "** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
            )

         return model_name

     def search_models(self, search_folder):
-        self.logger.info(f"Finding Models In: {search_folder}")
+        print(f">> Finding Models In: {search_folder}")
         models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
         models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")

@@ -1024,8 +1027,8 @@ class ModelManager(object):
         num_loaded_models = len(self.models)
         if num_loaded_models >= self.max_loaded_models:
             least_recent_model = self._pop_oldest_model()
-            self.logger.info(
-                f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
+            print(
+                f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
             )
             if least_recent_model is not None:
                 del self.models[least_recent_model]
@@ -1033,8 +1036,8 @@ class ModelManager(object):

     def print_vram_usage(self) -> None:
         if self._has_cuda:
-            self.logger.info(
-                "Current VRAM usage:"+
+            print(
+                ">> Current VRAM usage: ",
                 "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
             )

@@ -1123,10 +1126,10 @@ class ModelManager(object):
             dest = hub / model.stem
             if dest.exists() and not source.exists():
                 continue
-            cls.logger.info(f"{source} => {dest}")
+            print(f"** {source} => {dest}")
             if source.exists():
                 if dest.is_symlink():
-                    logger.warning(f"Found symlink at {dest.name}. Not migrating.")
+                    print(f"** Found symlink at {dest.name}. Not migrating.")
                 elif dest.exists():
                     if source.is_dir():
                         rmtree(source)
@@ -1143,7 +1146,7 @@ class ModelManager(object):
         ]
         for d in empty:
             os.rmdir(d)
-        cls.logger.info("Migration is done. Continuing...")
+        print("** Migration is done. Continuing...")

     def _resolve_path(
         self, source: Union[str, Path], dest_directory: str
@@ -1186,15 +1189,15 @@ class ModelManager(object):

     def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
         if self.embedding_path is not None:
-            self.logger.info(f"Loading embeddings from {self.embedding_path}")
+            print(f">> Loading embeddings from {self.embedding_path}")
             for root, _, files in os.walk(self.embedding_path):
                 for name in files:
                     ti_path = os.path.join(root, name)
                     model.textual_inversion_manager.load_textual_inversion(
                         ti_path, defer_injecting_tokens=True
                     )
-            self.logger.info(
-                f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
+            print(
+                f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
             )

     def _has_cuda(self) -> bool:
@@ -1216,7 +1219,7 @@ class ModelManager(object):
             with open(hashpath) as f:
                 hash = f.read()
             return hash
-        self.logger.debug("Calculating sha256 hash of model files")
+        print(" | Calculating sha256 hash of model files")
         tic = time.time()
         sha = hashlib.sha256()
         count = 0
@@ -1228,7 +1231,7 @@ class ModelManager(object):
                     sha.update(chunk)
         hash = sha.hexdigest()
         toc = time.time()
-        self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
+        print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
         with open(hashpath, "w") as f:
             f.write(hash)
         return hash
@@ -1246,13 +1249,13 @@ class ModelManager(object):
                 hash = f.read()
             return hash

-        self.logger.debug("Calculating sha256 hash of weights file")
+        print(" | Calculating sha256 hash of weights file")
         tic = time.time()
         sha = hashlib.sha256()
         sha.update(data)
         hash = sha.hexdigest()
         toc = time.time()
-        self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
+        print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))

         with open(hashpath, "w") as f:
             f.write(hash)
@@ -1273,12 +1276,12 @@ class ModelManager(object):
             local_files_only=not Globals.internet_available,
         )

-        self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
+        print(f" | Loading diffusers VAE from {name_or_path}")
         if using_fp16:
             vae_args.update(torch_dtype=torch.float16)
             fp_args_list = [{"revision": "fp16"}, {}]
         else:
-            self.logger.debug("Using more accurate float32 precision")
+            print(" | Using more accurate float32 precision")
             fp_args_list = [{}]

         vae = None
@@ -1302,12 +1305,12 @@ class ModelManager(object):
                     break

         if not vae and deferred_error:
-            self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
+            print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")

         return vae

-    @classmethod
-    def _delete_model_from_cache(cls,repo_id):
+    @staticmethod
+    def _delete_model_from_cache(repo_id):
         cache_info = scan_cache_dir(global_cache_dir("hub"))

         # I'm sure there is a way to do this with comprehensions
@@ -1318,8 +1321,8 @@ class ModelManager(object):
             for revision in repo.revisions:
                 hashes_to_delete.add(revision.commit_hash)
         strategy = cache_info.delete_revisions(*hashes_to_delete)
-        cls.logger.warning(
-            f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
+        print(
+            f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
         )
        strategy.execute()

```
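Many of the `ModelManager` hunks above show the same structural change on the non-`print()` side: the class keeps a `logger` attribute (declared as `logger: types.ModuleType = logger` and also accepted as a constructor argument) and calls `self.logger.info(...)` rather than writing to stdout directly. A minimal, hypothetical sketch of that injection pattern:

```python
import logging
import types


class ManagedCache:
    """Hypothetical class illustrating the injected-logger pattern in these hunks."""

    def __init__(self, logger: types.ModuleType = logging):
        # The default is a module whose info/warning/error/debug functions follow
        # the stdlib logging API; tests can pass any stand-in with the same surface.
        self.logger = logger

    def purge(self, model_name: str) -> None:
        self.logger.warning("Cache limit reached. Purging %s", model_name)


cache = ManagedCache()  # uses the module-level default
cache.purge("stable-diffusion-1.5")
```

Because the dependency is a default argument rather than a hard-coded global, call sites keep working unchanged while tests or embedding applications can redirect output.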
```diff
@@ -18,7 +18,6 @@ from compel.prompt_parser import (
     PromptParser,
 )

-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals

 from ..stable_diffusion import InvokeAIDiffuserComponent
@@ -163,8 +162,8 @@ def log_tokenization(
     negative_prompt: Union[Blend, FlattenedPrompt],
     tokenizer,
 ):
-    logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
-    logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
+    print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
+    print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")

     log_tokenization_for_prompt_object(positive_prompt, tokenizer)
     log_tokenization_for_prompt_object(
@@ -238,12 +237,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
             usedTokens += 1

     if usedTokens > 0:
-        logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
-        logger.debug(f"{tokenized}\x1b[0m")
+        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
+        print(f"{tokenized}\x1b[0m")

     if discarded != "":
-        logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
-        logger.debug(f"{discarded}\x1b[0m")
+        print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
+        print(f"{discarded}\x1b[0m")


 def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
@@ -296,8 +295,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
         return parsed_prompts
     weight_sum = sum(map(lambda x: x[1], parsed_prompts))
     if weight_sum == 0:
-        logger.warning(
-            "Subprompt weights add up to zero. Discarding and using even weights instead."
+        print(
+            "* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
         )
         equal_weight = 1 / max(len(parsed_prompts), 1)
         return [(x[0], equal_weight) for x in parsed_prompts]
```
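The final hunk above guards `split_weighted_subprompts` against a zero weight sum before normalizing. A small sketch of that normalization step, using simplified names for the `(text, weight)` pairs seen in the context lines:

```python
def normalize_weights(parsed_prompts: list[tuple[str, float]]) -> list[tuple[str, float]]:
    """Rescale subprompt weights to sum to 1, with the even-weight fallback (sketch)."""
    weight_sum = sum(weight for _, weight in parsed_prompts)
    if weight_sum == 0:
        # Matches the warning branch in the hunk: degenerate weights are discarded.
        equal_weight = 1 / max(len(parsed_prompts), 1)
        return [(text, equal_weight) for text, _ in parsed_prompts]
    return [(text, weight / weight_sum) for text, weight in parsed_prompts]


print(normalize_weights([("a cat", 1.0), ("a dog", 3.0)]))  # [('a cat', 0.25), ('a dog', 0.75)]
print(normalize_weights([("a cat", 0.0), ("a dog", 0.0)]))  # even 0.5/0.5 fallback
```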
```diff
@@ -1,5 +1,3 @@
-import invokeai.backend.util.logging as logger
-
 class Restoration:
     def __init__(self) -> None:
         pass
@@ -10,17 +8,17 @@ class Restoration:
         # Load GFPGAN
         gfpgan = self.load_gfpgan(gfpgan_model_path)
         if gfpgan.gfpgan_model_exists:
-            logger.info("GFPGAN Initialized")
+            print(">> GFPGAN Initialized")
         else:
-            logger.info("GFPGAN Disabled")
+            print(">> GFPGAN Disabled")
             gfpgan = None

         # Load CodeFormer
         codeformer = self.load_codeformer()
         if codeformer.codeformer_model_exists:
-            logger.info("CodeFormer Initialized")
+            print(">> CodeFormer Initialized")
         else:
-            logger.info("CodeFormer Disabled")
+            print(">> CodeFormer Disabled")
             codeformer = None

         return gfpgan, codeformer
@@ -41,5 +39,5 @@ class Restoration:
         from .realesrgan import ESRGAN

         esrgan = ESRGAN(esrgan_bg_tile)
-        logger.info("ESRGAN Initialized")
+        print(">> ESRGAN Initialized")
         return esrgan
```
```diff
@@ -5,7 +5,6 @@ import warnings
 import numpy as np
 import torch

-import invokeai.backend.util.logging as logger
 from ..globals import Globals

 pretrained_model_url = (
@@ -24,12 +23,12 @@ class CodeFormerRestoration:
         self.codeformer_model_exists = os.path.isfile(self.model_path)

         if not self.codeformer_model_exists:
-            logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
+            print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
         sys.path.append(os.path.abspath(codeformer_dir))

     def process(self, image, strength, device, seed=None, fidelity=0.75):
         if seed is not None:
-            logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
+            print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=DeprecationWarning)
             warnings.filterwarnings("ignore", category=UserWarning)
@@ -98,7 +97,7 @@ class CodeFormerRestoration:
                     del output
                     torch.cuda.empty_cache()
                 except RuntimeError as error:
-                    logger.error(f"Failed inference for CodeFormer: {error}.")
+                    print(f"\tFailed inference for CodeFormer: {error}.")
                     restored_face = cropped_face

                 restored_face = restored_face.astype("uint8")
```
```diff
@@ -6,9 +6,9 @@ import numpy as np
 import torch
 from PIL import Image

-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals


 class GFPGAN:
     def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
         if not os.path.isabs(gfpgan_model_path):
@@ -19,7 +19,7 @@ class GFPGAN:
         self.gfpgan_model_exists = os.path.isfile(self.model_path)

         if not self.gfpgan_model_exists:
-            logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
+            print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
             return None

     def model_exists(self):
@@ -27,7 +27,7 @@ class GFPGAN:

     def process(self, image, strength: float, seed: str = None):
         if seed is not None:
-            logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
+            print(f">> GFPGAN - Restoring Faces for image seed:{seed}")

         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -47,14 +47,14 @@ class GFPGAN:
         except Exception:
             import traceback

-            logger.error("Error loading GFPGAN:", file=sys.stderr)
+            print(">> Error loading GFPGAN:", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
         os.chdir(cwd)

         if self.gfpgan is None:
-            logger.warning("WARNING: GFPGAN not initialized.")
-            logger.warning(
-                f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
+            print(f">> WARNING: GFPGAN not initialized.")
+            print(
+                f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
             )

         image = image.convert("RGB")
```
```diff
@@ -1,7 +1,7 @@
 import math

 from PIL import Image
-import invokeai.backend.util.logging as logger

 class Outcrop(object):
     def __init__(
@@ -82,7 +82,7 @@ class Outcrop(object):
             pixels = extents[direction]
             # round pixels up to the nearest 64
             pixels = math.ceil(pixels / 64) * 64
-            logger.info(f"extending image {direction}ward by {pixels} pixels")
+            print(f">> extending image {direction}ward by {pixels} pixels")
             image = self._rotate(image, direction)
             image = self._extend(image, pixels)
             image = self._rotate(image, direction, reverse=True)
```
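One detail worth noting in the outcrop hunk above: extension amounts are rounded up to the next multiple of 64 before generation, since Stable Diffusion pipelines typically expect spatial dimensions divisible by 64. The arithmetic in isolation:

```python
import math


def round_up_to_64(pixels: int) -> int:
    # Same expression as the hunk: ceil-divide by 64, then scale back up.
    return math.ceil(pixels / 64) * 64


assert round_up_to_64(1) == 64
assert round_up_to_64(64) == 64
assert round_up_to_64(65) == 128
```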
```diff
@@ -6,13 +6,18 @@ import torch
 from PIL import Image
 from PIL.Image import Image as ImageType

-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals


 class ESRGAN:
     def __init__(self, bg_tile_size=400) -> None:
         self.bg_tile_size = bg_tile_size

+        if not torch.cuda.is_available():  # CPU or MPS on M1
+            use_half_precision = False
+        else:
+            use_half_precision = True
+
     def load_esrgan_bg_upsampler(self, denoise_str):
         if not torch.cuda.is_available():  # CPU or MPS on M1
             use_half_precision = False
@@ -69,16 +74,16 @@ class ESRGAN:
             import sys
             import traceback

-            logger.error("Error loading Real-ESRGAN:")
+            print(">> Error loading Real-ESRGAN:", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)

         if upsampler_scale == 0:
-            logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
+            print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
             return image

         if seed is not None:
-            logger.info(
-                f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
+            print(
+                f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
             )
         # ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
         image = image.convert("RGB")
```
@ -14,7 +14,6 @@ from PIL import Image, ImageFilter
|
|||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .globals import global_cache_dir
|
from .globals import global_cache_dir
|
||||||
from .util import CPU_DEVICE
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
@ -41,8 +40,8 @@ class SafetyChecker(object):
|
|||||||
cache_dir=safety_model_path,
|
cache_dir=safety_model_path,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(
|
print(
|
||||||
"An error was encountered while installing the safety checker:"
|
"** An error was encountered while installing the safety checker:"
|
||||||
)
|
)
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
@ -66,8 +65,8 @@ class SafetyChecker(object):
|
|||||||
)
|
)
|
||||||
self.safety_checker.to(CPU_DEVICE) # offload
|
self.safety_checker.to(CPU_DEVICE) # offload
|
||||||
if has_nsfw_concept[0]:
|
if has_nsfw_concept[0]:
|
||||||
logger.warning(
|
print(
|
||||||
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||||
)
|
)
|
||||||
return self.blur(image)
|
return self.blur(image)
|
||||||
else:
|
else:
|
||||||
|
@ -17,7 +17,6 @@ from huggingface_hub import (
|
|||||||
hf_hub_url,
|
hf_hub_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
@ -67,11 +66,11 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
print(
|
||||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||||
)
|
)
|
||||||
logger.warning(
|
print(
|
||||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||||
)
|
)
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
else:
|
else:
|
||||||
@ -84,7 +83,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
be downloaded.
|
be downloaded.
|
||||||
"""
|
"""
|
||||||
if not concept_name in self.list_concepts():
|
if not concept_name in self.list_concepts():
|
||||||
logger.warning(
|
print(
|
||||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@ -222,7 +221,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
if chunk == 0:
|
if chunk == 0:
|
||||||
bytes += total
|
bytes += total
|
||||||
|
|
||||||
logger.info(f"Downloading {repo_id}...", end="")
|
print(f">> Downloading {repo_id}...", end="")
|
||||||
try:
|
try:
|
||||||
for file in (
|
for file in (
|
||||||
"README.md",
|
"README.md",
|
||||||
@ -236,22 +235,22 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
)
|
)
|
||||||
except ul_error.HTTPError as e:
|
except ul_error.HTTPError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
logger.warning(
|
print(
|
||||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
print(
|
||||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
except ul_error.URLError as e:
|
except ul_error.URLError as e:
|
||||||
logger.error(
|
print(
|
||||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
print("...{:.2f}Kb".format(bytes / 1024))
|
||||||
return succeeded
|
return succeeded
|
||||||
|
|
||||||
def _concept_id(self, concept_name: str) -> str:
|
def _concept_id(self, concept_name: str) -> str:
|
||||||
|
@ -13,9 +13,9 @@ from compel.cross_attention_control import Arguments
|
|||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.attention_processor import AttentionProcessor
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from ...util import torch_dtype
|
from ...util import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(enum.Enum):
|
class CrossAttentionType(enum.Enum):
|
||||||
SELF = 1
|
SELF = 1
|
||||||
TOKENS = 2
|
TOKENS = 2
|
||||||
@ -421,7 +421,7 @@ def get_cross_attention_modules(
|
|||||||
expected_count = 16
|
expected_count = 16
|
||||||
if cross_attention_modules_in_model_count != expected_count:
|
if cross_attention_modules_in_model_count != expected_count:
|
||||||
# non-fatal error but .swap() won't work.
|
# non-fatal error but .swap() won't work.
|
||||||
logger.error(
|
print(
|
||||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||||
|
@ -8,7 +8,6 @@ import torch
|
|||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.attention_processor import AttentionProcessor
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
@ -467,14 +466,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
outside = torch.count_nonzero(
|
outside = torch.count_nonzero(
|
||||||
(latents < -current_threshold) | (latents > current_threshold)
|
(latents < -current_threshold) | (latents > current_threshold)
|
||||||
)
|
)
|
||||||
logger.info(
|
print(
|
||||||
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||||
)
|
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||||
logger.debug(
|
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if maxval < current_threshold and minval > -current_threshold:
|
if maxval < current_threshold and minval > -current_threshold:
|
||||||
@ -501,11 +496,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.debug_thresholding:
|
if self.debug_thresholding:
|
||||||
logger.debug(
|
print(
|
||||||
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||||
)
|
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
|
||||||
logger.debug(
|
|
||||||
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
@ -10,7 +10,7 @@ from torchvision.utils import make_grid
|
|||||||
|
|
||||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
|
|
||||||
|
|
||||||
@ -191,7 +191,7 @@ def mkdirs(paths):
|
|||||||
def mkdir_and_rename(path):
|
def mkdir_and_rename(path):
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
new_name = path + "_archived_" + get_timestamp()
|
new_name = path + "_archived_" + get_timestamp()
|
||||||
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
print("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||||
os.replace(path, new_name)
|
os.replace(path, new_name)
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
|
|
||||||
|
@ -10,7 +10,6 @@ from compel.embeddings_provider import BaseTextualInversionManager
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -60,12 +59,12 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||||
): # in case a token with literal angle brackets encountered
|
): # in case a token with literal angle brackets encountered
|
||||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
print(f">> Loaded local embedding for trigger {concept_name}")
|
||||||
continue
|
continue
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
if not bin_file:
|
if not bin_file:
|
||||||
continue
|
continue
|
||||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
print(f">> Loaded remote embedding for trigger {concept_name}")
|
||||||
self.load_textual_inversion(bin_file)
|
self.load_textual_inversion(bin_file)
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||||
|
|
||||||
@ -86,8 +85,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||||
for embedding_info in embedding_list:
|
for embedding_info in embedding_list:
|
||||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||||
logger.warning(
|
print(
|
||||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -106,8 +105,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if ckpt_path.name == "learned_embeds.bin"
|
if ckpt_path.name == "learned_embeds.bin"
|
||||||
else f"<{ckpt_path.stem}>"
|
else f"<{ckpt_path.stem}>"
|
||||||
)
|
)
|
||||||
logger.info(
|
print(
|
||||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||||
)
|
)
|
||||||
trigger_str = replacement_trigger_str
|
trigger_str = replacement_trigger_str
|
||||||
|
|
||||||
@ -121,8 +120,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||||
logger.debug(f"The error was {str(e)}")
|
print(f" | The error was {str(e)}")
|
||||||
|
|
||||||
def _add_textual_inversion(
|
def _add_textual_inversion(
|
||||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||||
@ -134,8 +133,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
:return: The token id for the added embedding, either existing or newly-added.
|
:return: The token id for the added embedding, either existing or newly-added.
|
||||||
"""
|
"""
|
||||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
logger.warning(
|
print(
|
||||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if not self.full_precision:
|
if not self.full_precision:
|
||||||
@ -156,11 +155,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if str(e).startswith("Warning"):
|
if str(e).startswith("Warning"):
|
||||||
logger.warning(f"{str(e)}")
|
print(f">> {str(e)}")
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logger.error(
|
print(
|
||||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -220,16 +219,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
for ti in self.textual_inversions:
|
for ti in self.textual_inversions:
|
||||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||||
if ti.embedding_vector_length > 1:
|
if ti.embedding_vector_length > 1:
|
||||||
logger.info(
|
print(
|
||||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self._inject_tokens_and_assign_embeddings(ti)
|
self._inject_tokens_and_assign_embeddings(ti)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.debug(
|
print(
|
||||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||||
)
|
)
|
||||||
logger.debug(f"The error was {str(e)}")
|
print(f" | The error was {str(e)}")
|
||||||
continue
|
continue
|
||||||
injected_token_ids.append(ti.trigger_token_id)
|
injected_token_ids.append(ti.trigger_token_id)
|
||||||
injected_token_ids.extend(ti.pad_token_ids)
|
injected_token_ids.extend(ti.pad_token_ids)
|
||||||
@ -307,16 +306,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if suffix in [".pt",".ckpt",".bin"]:
|
if suffix in [".pt",".ckpt",".bin"]:
|
||||||
scan_result = scan_file_path(embedding_file)
|
scan_result = scan_file_path(embedding_file)
|
||||||
if scan_result.infected_files > 0:
|
if scan_result.infected_files > 0:
|
||||||
logger.critical(
|
print(
|
||||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||||
)
|
)
|
||||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||||
return list()
|
return list()
|
||||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||||
else:
|
else:
|
||||||
ckpt = safetensors.torch.load_file(embedding_file)
|
ckpt = safetensors.torch.load_file(embedding_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||||
return list()
|
return list()
|
||||||
|
|
||||||
# try to figure out what kind of embedding file it is and parse accordingly
|
# try to figure out what kind of embedding file it is and parse accordingly
|
||||||
@ -335,7 +334,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
print(f' | Loading v1 embedding file: {basename}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
token_counter = -1
|
token_counter = -1
|
||||||
@ -343,7 +342,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if token_counter < 0:
|
if token_counter < 0:
|
||||||
trigger = embedding_ckpt["name"]
|
trigger = embedding_ckpt["name"]
|
||||||
elif token_counter == 0:
|
elif token_counter == 0:
|
||||||
trigger = '<basename>'
|
trigger = f'<basename>'
|
||||||
else:
|
else:
|
||||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||||
token_counter += 1
|
token_counter += 1
|
||||||
@ -366,7 +365,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
This handles embedding .pt file variant #2.
|
This handles embedding .pt file variant #2.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
print(f' | Loading v2 embedding file: {basename}')
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
@ -385,7 +384,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
)
|
)
|
||||||
embeddings.append(embedding_info)
|
embeddings.append(embedding_info)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
print(f" ** {basename}: Unrecognized embedding format")
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@ -394,7 +393,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
print(f' | Loading v3 embedding file: {basename}')
|
||||||
embedding = embedding_ckpt['emb_params']
|
embedding = embedding_ckpt['emb_params']
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
name = f'<{basename}>',
|
name = f'<{basename}>',
|
||||||
@ -412,11 +411,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
basename = Path(filepath).stem
|
basename = Path(filepath).stem
|
||||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||||
|
|
||||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
print(f' | Loading v4 embedding file: {short_path}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
if list(embedding_ckpt.keys()) == 0:
|
if list(embedding_ckpt.keys()) == 0:
|
||||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
print(f" ** Invalid embeddings file: {short_path}")
|
||||||
else:
|
else:
|
||||||
for token,embedding in embedding_ckpt.items():
|
for token,embedding in embedding_ckpt.items():
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
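For reference, the @@ -307,16 +306,16 @@ hunk above relies on a scan-before-load pattern: pickle-based embedding files (.pt, .ckpt, .bin) are checked with picklescan before torch.load is allowed to unpickle them, while safetensors files are loaded directly. A minimal, self-contained sketch of that pattern, assuming the same libraries; the function name and return convention are illustrative, not code from the repository:

    from pathlib import Path

    import safetensors.torch
    import torch
    from picklescan.scanner import scan_file_path

    def load_embedding_file(embedding_file: str):
        # Hypothetical helper: refuse to unpickle anything picklescan flags.
        if Path(embedding_file).suffix in [".pt", ".ckpt", ".bin"]:
            scan_result = scan_file_path(embedding_file)
            if scan_result.infected_files > 0:
                return None  # for safety, do not load a flagged embed
            return torch.load(embedding_file, map_location="cpu")
        # safetensors files hold plain tensors; no pickle scan is needed.
        return safetensors.torch.load_file(embedding_file)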
@@ -1,109 +0,0 @@
-# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
-
-"""invokeai.util.logging
-
-Logging class for InvokeAI that produces console messages that follow
-the conventions established in InvokeAI 1.X through 2.X.
-
-
-One way to use it:
-
-from invokeai.backend.util.logging import InvokeAILogger
-
-logger = InvokeAILogger.getLogger(__name__)
-logger.critical('this is critical')
-logger.error('this is an error')
-logger.warning('this is a warning')
-logger.info('this is info')
-logger.debug('this is debugging')
-
-Console messages:
-### this is critical
-*** this is an error ***
-** this is a warning
->> this is info
- | this is debugging
-
-Another way:
-import invokeai.backend.util.logging as ialog
-ialogger.debug('this is a debugging message')
-"""
-import logging
-
-# module level functions
-def debug(msg, *args, **kwargs):
-    InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
-
-def info(msg, *args, **kwargs):
-    InvokeAILogger.getLogger().info(msg, *args, **kwargs)
-
-def warning(msg, *args, **kwargs):
-    InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
-
-def error(msg, *args, **kwargs):
-    InvokeAILogger.getLogger().error(msg, *args, **kwargs)
-
-def critical(msg, *args, **kwargs):
-    InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
-
-def log(level, msg, *args, **kwargs):
-    InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
-
-def disable(level=logging.CRITICAL):
-    InvokeAILogger.getLogger().disable(level)
-
-def basicConfig(**kwargs):
-    InvokeAILogger.getLogger().basicConfig(**kwargs)
-
-def getLogger(name: str=None)->logging.Logger:
-    return InvokeAILogger.getLogger(name)
-
-class InvokeAILogFormatter(logging.Formatter):
-    '''
-    Repurposed from:
-    https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
-    '''
-    crit_fmt = "### %(msg)s"
-    err_fmt = "*** %(msg)s"
-    warn_fmt = "** %(msg)s"
-    info_fmt = ">> %(msg)s"
-    dbg_fmt = " | %(msg)s"
-
-    def __init__(self):
-        super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
-
-    def format(self, record):
-        # Remember the format used when the logging module
-        # was installed (in the event that this formatter is
-        # used with the vanilla logging module.
-        format_orig = self._style._fmt
-        if record.levelno == logging.DEBUG:
-            self._style._fmt = InvokeAILogFormatter.dbg_fmt
-        if record.levelno == logging.INFO:
-            self._style._fmt = InvokeAILogFormatter.info_fmt
-        if record.levelno == logging.WARNING:
-            self._style._fmt = InvokeAILogFormatter.warn_fmt
-        if record.levelno == logging.ERROR:
-            self._style._fmt = InvokeAILogFormatter.err_fmt
-        if record.levelno == logging.CRITICAL:
-            self._style._fmt = InvokeAILogFormatter.crit_fmt
-
-        # parent class does the work
-        result = super().format(record)
-        self._style._fmt = format_orig
-        return result
-
-class InvokeAILogger(object):
-    loggers = dict()
-
-    @classmethod
-    def getLogger(self, name:str='invokeai')->logging.Logger:
-        if name not in self.loggers:
-            logger = logging.getLogger(name)
-            logger.setLevel(logging.DEBUG)
-            ch = logging.StreamHandler()
-            fmt = InvokeAILogFormatter()
-            ch.setFormatter(fmt)
-            logger.addHandler(ch)
-            self.loggers[name] = logger
-        return self.loggers[name]
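The InvokeAILogFormatter removed above selects a per-level prefix by temporarily swapping self._style._fmt and restoring it after formatting. An equivalent design keeps one format string per level in a dict; the sketch below is a hypothetical alternative, not repository code, but it reproduces the same console prefixes and avoids mutating formatter state:

    import logging

    LEVEL_FORMATS = {
        logging.DEBUG: " | %(msg)s",
        logging.INFO: ">> %(msg)s",
        logging.WARNING: "** %(msg)s",
        logging.ERROR: "*** %(msg)s",
        logging.CRITICAL: "### %(msg)s",
    }

    class PerLevelFormatter(logging.Formatter):
        def format(self, record: logging.LogRecord) -> str:
            # Pick the prefix for this record's level; fall back to the INFO style.
            fmt = LEVEL_FORMATS.get(record.levelno, LEVEL_FORMATS[logging.INFO])
            return logging.Formatter(fmt).format(record)

    handler = logging.StreamHandler()
    handler.setFormatter(PerLevelFormatter())
    demo = logging.getLogger("demo")
    demo.setLevel(logging.DEBUG)
    demo.addHandler(handler)
    demo.info("this is info")          # prints: >> this is info
    demo.warning("this is a warning")  # prints: ** this is a warning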
@@ -18,7 +18,6 @@ import torch
 from PIL import Image, ImageDraw, ImageFont
 from tqdm import tqdm

-import invokeai.backend.util.logging as logger
 from .devices import torch_dtype


@@ -39,7 +38,7 @@ def log_txt_as_img(wh, xc, size=10):
         try:
             draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            logger.warning("Cant encode string for logging. Skipping.")
+            print("Cant encode string for logging. Skipping.")

         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)

@@ -81,8 +80,8 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        logger.debug(
-            f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
+        print(
+            f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
         )
     return total_params

@@ -133,8 +132,8 @@ def parallel_data_prefetch(
         raise ValueError("list expected but function got ndarray.")
     elif isinstance(data, abc.Iterable):
         if isinstance(data, dict):
-            logger.warning(
-                '"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+            print(
+                'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
             )
             data = list(data.values())
     if target_data_type == "ndarray":

@@ -176,7 +175,7 @@ def parallel_data_prefetch(
         processes += [p]

     # start processes
-    logger.info("Start prefetching...")
+    print("Start prefetching...")
     import time

     start = time.time()

@@ -195,7 +194,7 @@ def parallel_data_prefetch(
                 gather_res[res[0]] = res[1]

     except Exception as e:
-        logger.error("Exception: ", e)
+        print("Exception: ", e)
         for p in processes:
             p.terminate()

@@ -203,7 +202,7 @@ def parallel_data_prefetch(
     finally:
         for p in processes:
             p.join()
-        logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
+        print(f"Prefetching complete. [{time.time() - start} sec.]")

     if target_data_type == "ndarray":
         if not isinstance(gather_res[0], np.ndarray):

@@ -319,23 +318,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
         resp = requests.get(url, headers=header, stream=True)  # new request with range

     if exist_size > content_length:
-        logger.warning("corrupt existing file found. re-downloading")
+        print("* corrupt existing file found. re-downloading")
         os.remove(dest)
         exist_size = 0

     if resp.status_code == 416 or exist_size == content_length:
-        logger.warning(f"{dest}: complete file found. Skipping.")
+        print(f"* {dest}: complete file found. Skipping.")
         return dest
     elif resp.status_code == 206 or exist_size > 0:
-        logger.warning(f"{dest}: partial file found. Resuming...")
+        print(f"* {dest}: partial file found. Resuming...")
     elif resp.status_code != 200:
-        logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
+        print(f"** An error occurred during downloading {dest}: {resp.reason}")
     else:
-        logger.error(f"{dest}: Downloading...")
+        print(f"* {dest}: Downloading...")

     try:
         if content_length < 2000:
-            logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
+            print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
             return None

         with open(dest, open_mode) as file, tqdm(

@@ -350,7 +349,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
                 size = file.write(data)
                 bar.update(size)
     except Exception as e:
-        logger.error(f"An error occurred while downloading {dest}: {str(e)}")
+        print(f"An error occurred while downloading {dest}: {str(e)}")
         return None

     return dest
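The download_with_resume hunks above resume interrupted transfers: the size of an existing partial file drives an HTTP range request, a 206 response means the server honored the range (so the file is opened for append), and 416 means the file is already complete. A compact sketch of the same idea, assuming requests and tqdm; the function name and the Range-header construction are illustrative, since the header-building code sits outside the hunks shown:

    from pathlib import Path

    import requests
    from tqdm import tqdm

    def fetch_with_resume(url: str, dest: Path) -> Path | None:
        exist_size = dest.stat().st_size if dest.exists() else 0
        header = {"Range": f"bytes={exist_size}-"} if exist_size > 0 else {}
        resp = requests.get(url, headers=header, stream=True)

        if resp.status_code == 416:  # range not satisfiable: already complete
            return dest
        if resp.status_code not in (200, 206):
            return None

        # 206 means the server honored the range, so append; otherwise start over.
        open_mode = "ab" if resp.status_code == 206 else "wb"
        with open(dest, open_mode) as file, tqdm(
            desc=dest.name, initial=exist_size, unit="iB", unit_scale=True
        ) as bar:
            for data in resp.iter_content(chunk_size=1024):
                bar.update(file.write(data))
        return dest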
@ -19,7 +19,6 @@ from PIL import Image
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
import invokeai.frontend.web.dist as frontend
|
import invokeai.frontend.web.dist as frontend
|
||||||
|
|
||||||
from .. import Generate
|
from .. import Generate
|
||||||
@ -78,6 +77,7 @@ class InvokeAIWebServer:
|
|||||||
mimetypes.add_type("application/javascript", ".js")
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type("text/css", ".css")
|
mimetypes.add_type("text/css", ".css")
|
||||||
# Socket IO
|
# Socket IO
|
||||||
|
logger = True if args.web_verbose else False
|
||||||
engineio_logger = True if args.web_verbose else False
|
engineio_logger = True if args.web_verbose else False
|
||||||
max_http_buffer_size = 10000000
|
max_http_buffer_size = 10000000
|
||||||
|
|
||||||
@ -213,7 +213,7 @@ class InvokeAIWebServer:
|
|||||||
self.load_socketio_listeners(self.socketio)
|
self.load_socketio_listeners(self.socketio)
|
||||||
|
|
||||||
if args.gui:
|
if args.gui:
|
||||||
logger.info("Launching Invoke AI GUI")
|
print(">> Launching Invoke AI GUI")
|
||||||
try:
|
try:
|
||||||
from flaskwebgui import FlaskUI
|
from flaskwebgui import FlaskUI
|
||||||
|
|
||||||
@ -231,17 +231,17 @@ class InvokeAIWebServer:
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
useSSL = args.certfile or args.keyfile
|
useSSL = args.certfile or args.keyfile
|
||||||
logger.info("Started Invoke AI Web Server")
|
print(">> Started Invoke AI Web Server")
|
||||||
if self.host == "0.0.0.0":
|
if self.host == "0.0.0.0":
|
||||||
logger.info(
|
print(
|
||||||
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
print(
|
||||||
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||||
)
|
)
|
||||||
logger.info(
|
print(
|
||||||
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||||
)
|
)
|
||||||
if not useSSL:
|
if not useSSL:
|
||||||
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
||||||
@ -273,7 +273,7 @@ class InvokeAIWebServer:
|
|||||||
# path for thumbnail images
|
# path for thumbnail images
|
||||||
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
||||||
# txt log
|
# txt log
|
||||||
self.log_path = os.path.join(self.result_path, "invoke_logger.txt")
|
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
|
||||||
# make all output paths
|
# make all output paths
|
||||||
[
|
[
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
@ -290,7 +290,7 @@ class InvokeAIWebServer:
|
|||||||
def load_socketio_listeners(self, socketio):
|
def load_socketio_listeners(self, socketio):
|
||||||
@socketio.on("requestSystemConfig")
|
@socketio.on("requestSystemConfig")
|
||||||
def handle_request_capabilities():
|
def handle_request_capabilities():
|
||||||
logger.info("System config requested")
|
print(">> System config requested")
|
||||||
config = self.get_system_config()
|
config = self.get_system_config()
|
||||||
config["model_list"] = self.generate.model_manager.list_models()
|
config["model_list"] = self.generate.model_manager.list_models()
|
||||||
config["infill_methods"] = infill_methods()
|
config["infill_methods"] = infill_methods()
|
||||||
@ -330,7 +330,7 @@ class InvokeAIWebServer:
|
|||||||
if model_name in current_model_list:
|
if model_name in current_model_list:
|
||||||
update = True
|
update = True
|
||||||
|
|
||||||
logger.info(f"Adding New Model: {model_name}")
|
print(f">> Adding New Model: {model_name}")
|
||||||
|
|
||||||
self.generate.model_manager.add_model(
|
self.generate.model_manager.add_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -348,14 +348,14 @@ class InvokeAIWebServer:
|
|||||||
"update": update,
|
"update": update,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"New Model Added: {model_name}")
|
print(f">> New Model Added: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("deleteModel")
|
@socketio.on("deleteModel")
|
||||||
def handle_delete_model(model_name: str):
|
def handle_delete_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Deleting Model: {model_name}")
|
print(f">> Deleting Model: {model_name}")
|
||||||
self.generate.model_manager.del_model(model_name)
|
self.generate.model_manager.del_model(model_name)
|
||||||
self.generate.model_manager.commit(opt.conf)
|
self.generate.model_manager.commit(opt.conf)
|
||||||
updated_model_list = self.generate.model_manager.list_models()
|
updated_model_list = self.generate.model_manager.list_models()
|
||||||
@ -366,14 +366,14 @@ class InvokeAIWebServer:
|
|||||||
"model_list": updated_model_list,
|
"model_list": updated_model_list,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Model Deleted: {model_name}")
|
print(f">> Model Deleted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("requestModelChange")
|
@socketio.on("requestModelChange")
|
||||||
def handle_set_model(model_name: str):
|
def handle_set_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Model change requested: {model_name}")
|
print(f">> Model change requested: {model_name}")
|
||||||
model = self.generate.set_model(model_name)
|
model = self.generate.set_model(model_name)
|
||||||
model_list = self.generate.model_manager.list_models()
|
model_list = self.generate.model_manager.list_models()
|
||||||
if model is None:
|
if model is None:
|
||||||
@ -454,7 +454,7 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Model Converted: {model_name}")
|
print(f">> Model Converted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@ -490,7 +490,7 @@ class InvokeAIWebServer:
|
|||||||
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||||
"vae", None
|
"vae", None
|
||||||
):
|
):
|
||||||
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
|
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||||
merged_model_config.update(vae=vae)
|
merged_model_config.update(vae=vae)
|
||||||
|
|
||||||
self.generate.model_manager.import_diffuser_model(
|
self.generate.model_manager.import_diffuser_model(
|
||||||
@ -507,8 +507,8 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Models Merged: {models_to_merge}")
|
print(f">> Models Merged: {models_to_merge}")
|
||||||
logger.info(f"New Model Added: {model_merge_info['merged_model_name']}")
|
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@ -698,7 +698,7 @@ class InvokeAIWebServer:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"Unable to load {path}")
|
print(f">> Unable to load {path}")
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
||||||
)
|
)
|
||||||
@ -735,9 +735,9 @@ class InvokeAIWebServer:
|
|||||||
printable_parameters["init_mask"][:64] + "..."
|
printable_parameters["init_mask"][:64] + "..."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
|
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||||
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
|
print(f">> ESRGAN Parameters: {esrgan_parameters}")
|
||||||
logger.info(f"Facetool Parameters: {facetool_parameters}")
|
print(f">> Facetool Parameters: {facetool_parameters}")
|
||||||
|
|
||||||
self.generate_images(
|
self.generate_images(
|
||||||
generation_parameters,
|
generation_parameters,
|
||||||
@ -750,8 +750,8 @@ class InvokeAIWebServer:
|
|||||||
@socketio.on("runPostprocessing")
|
@socketio.on("runPostprocessing")
|
||||||
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
||||||
try:
|
try:
|
||||||
logger.info(
|
print(
|
||||||
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||||
)
|
)
|
||||||
|
|
||||||
progress = Progress()
|
progress = Progress()
|
||||||
@ -861,14 +861,14 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
@socketio.on("cancel")
|
@socketio.on("cancel")
|
||||||
def handle_cancel():
|
def handle_cancel():
|
||||||
logger.info("Cancel processing requested")
|
print(">> Cancel processing requested")
|
||||||
self.canceled.set()
|
self.canceled.set()
|
||||||
|
|
||||||
# TODO: I think this needs a safety mechanism.
|
# TODO: I think this needs a safety mechanism.
|
||||||
@socketio.on("deleteImage")
|
@socketio.on("deleteImage")
|
||||||
def handle_delete_image(url, thumbnail, uuid, category):
|
def handle_delete_image(url, thumbnail, uuid, category):
|
||||||
try:
|
try:
|
||||||
logger.info(f'Delete requested "{url}"')
|
print(f'>> Delete requested "{url}"')
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
path = self.get_image_path_from_url(url)
|
path = self.get_image_path_from_url(url)
|
||||||
@ -1263,7 +1263,7 @@ class InvokeAIWebServer:
|
|||||||
image, os.path.basename(path), self.thumbnail_image_path
|
image, os.path.basename(path), self.thumbnail_image_path
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Image generated: "{path}"\n')
|
print(f'\n\n>> Image generated: "{path}"\n')
|
||||||
self.write_log_message(f'[Generated] "{path}": {command}')
|
self.write_log_message(f'[Generated] "{path}": {command}')
|
||||||
|
|
||||||
if progress.total_iterations > progress.current_iteration:
|
if progress.total_iterations > progress.current_iteration:
|
||||||
@ -1329,7 +1329,7 @@ class InvokeAIWebServer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Clear the CUDA cache on an exception
|
# Clear the CUDA cache on an exception
|
||||||
self.empty_cuda_cache()
|
self.empty_cuda_cache()
|
||||||
logger.error(e)
|
print(e)
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
def empty_cuda_cache(self):
|
def empty_cuda_cache(self):
|
||||||
|
@ -16,7 +16,6 @@ if sys.platform == "darwin":
|
|||||||
import pyparsing # type: ignore
|
import pyparsing # type: ignore
|
||||||
|
|
||||||
import invokeai.version as invokeai
|
import invokeai.version as invokeai
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
from ...backend import Generate, ModelManager
|
from ...backend import Generate, ModelManager
|
||||||
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
||||||
@ -70,7 +69,7 @@ def main():
|
|||||||
# run any post-install patches needed
|
# run any post-install patches needed
|
||||||
run_patches()
|
run_patches()
|
||||||
|
|
||||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||||
|
|
||||||
if not args.conf:
|
if not args.conf:
|
||||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||||
@ -79,8 +78,8 @@ def main():
|
|||||||
opt, FileNotFoundError(f"The file {config_file} could not be found.")
|
opt, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
|
print(f">> {invokeai.__app_name__}, version {invokeai.__version__}")
|
||||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||||
|
|
||||||
# loading here to avoid long delays on startup
|
# loading here to avoid long delays on startup
|
||||||
# these two lines prevent a horrible warning message from appearing
|
# these two lines prevent a horrible warning message from appearing
|
||||||
@ -122,7 +121,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"{opt.infile} not found.")
|
raise FileNotFoundError(f"{opt.infile} not found.")
|
||||||
except (FileNotFoundError, IOError) as e:
|
except (FileNotFoundError, IOError) as e:
|
||||||
logger.critical('Aborted',exc_info=True)
|
print(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
# creating a Generate object:
|
# creating a Generate object:
|
||||||
@ -143,12 +142,12 @@ def main():
|
|||||||
)
|
)
|
||||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
report_model_error(opt, e)
|
report_model_error(opt, e)
|
||||||
except (IOError, KeyError):
|
except (IOError, KeyError) as e:
|
||||||
logger.critical("Aborted",exc_info=True)
|
print(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
if opt.seamless:
|
if opt.seamless:
|
||||||
logger.info("Changed to seamless tiling mode")
|
print(">> changed to seamless tiling mode")
|
||||||
|
|
||||||
# preload the model
|
# preload the model
|
||||||
try:
|
try:
|
||||||
@ -181,7 +180,9 @@ def main():
|
|||||||
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("An error occurred",exc_info=True)
|
print(">> An error occurred:")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||||
def main_loop(gen, opt):
|
def main_loop(gen, opt):
|
||||||
@ -247,7 +248,7 @@ def main_loop(gen, opt):
|
|||||||
if not opt.prompt:
|
if not opt.prompt:
|
||||||
oldargs = metadata_from_png(opt.init_img)
|
oldargs = metadata_from_png(opt.init_img)
|
||||||
opt.prompt = oldargs.prompt
|
opt.prompt = oldargs.prompt
|
||||||
logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||||
except (OSError, AttributeError, KeyError):
|
except (OSError, AttributeError, KeyError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -264,9 +265,9 @@ def main_loop(gen, opt):
|
|||||||
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
||||||
try:
|
try:
|
||||||
opt.init_img = last_results[int(opt.init_img)][0]
|
opt.init_img = last_results[int(opt.init_img)][0]
|
||||||
logger.info(f"Reusing previous image {opt.init_img}")
|
print(f">> Reusing previous image {opt.init_img}")
|
||||||
except IndexError:
|
except IndexError:
|
||||||
logger.info(f"No previous initial image at position {opt.init_img} found")
|
print(f">> No previous initial image at position {opt.init_img} found")
|
||||||
opt.init_img = None
|
opt.init_img = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -287,9 +288,9 @@ def main_loop(gen, opt):
|
|||||||
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
||||||
try:
|
try:
|
||||||
opt.seed = last_results[opt.seed][1]
|
opt.seed = last_results[opt.seed][1]
|
||||||
logger.info(f"Reusing previous seed {opt.seed}")
|
print(f">> Reusing previous seed {opt.seed}")
|
||||||
except IndexError:
|
except IndexError:
|
||||||
logger.info(f"No previous seed at position {opt.seed} found")
|
print(f">> No previous seed at position {opt.seed} found")
|
||||||
opt.seed = None
|
opt.seed = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -308,7 +309,7 @@ def main_loop(gen, opt):
|
|||||||
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
||||||
current_outdir = os.path.join(opt.outdir, subdir)
|
current_outdir = os.path.join(opt.outdir, subdir)
|
||||||
|
|
||||||
logger.info('Writing files to directory: "' + current_outdir + '"')
|
print('Writing files to directory: "' + current_outdir + '"')
|
||||||
|
|
||||||
# make sure the output directory exists
|
# make sure the output directory exists
|
||||||
if not os.path.exists(current_outdir):
|
if not os.path.exists(current_outdir):
|
||||||
@ -437,14 +438,15 @@ def main_loop(gen, opt):
|
|||||||
catch_interrupts=catch_ctrl_c,
|
catch_interrupts=catch_ctrl_c,
|
||||||
**vars(opt),
|
**vars(opt),
|
||||||
)
|
)
|
||||||
except (PromptParser.ParsingException, pyparsing.ParseException):
|
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
|
||||||
logger.error("An error occurred while processing your prompt",exc_info=True)
|
print("** An error occurred while processing your prompt **")
|
||||||
|
print(f"** {str(e)} **")
|
||||||
elif operation == "postprocess":
|
elif operation == "postprocess":
|
||||||
logger.info(f"fixing {opt.prompt}")
|
print(f">> fixing {opt.prompt}")
|
||||||
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
||||||
|
|
||||||
elif operation == "mask":
|
elif operation == "mask":
|
||||||
logger.info(f"generating masks from {opt.prompt}")
|
print(f">> generating masks from {opt.prompt}")
|
||||||
do_textmask(gen, opt, image_writer)
|
do_textmask(gen, opt, image_writer)
|
||||||
|
|
||||||
if opt.grid and len(grid_images) > 0:
|
if opt.grid and len(grid_images) > 0:
|
||||||
@ -467,12 +469,12 @@ def main_loop(gen, opt):
|
|||||||
)
|
)
|
||||||
results = [[path, formatted_dream_prompt]]
|
results = [[path, formatted_dream_prompt]]
|
||||||
|
|
||||||
except AssertionError:
|
except AssertionError as e:
|
||||||
logger.error(e)
|
print(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.error(e)
|
print(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print("Outputs:")
|
print("Outputs:")
|
||||||
@ -511,7 +513,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
gen.set_model(model_name)
|
gen.set_model(model_name)
|
||||||
add_embedding_terms(gen, completer)
|
add_embedding_terms(gen, completer)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(e)
|
print(str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
report_model_error(opt, e)
|
report_model_error(opt, e)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@ -525,8 +527,8 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!import"):
|
elif command.startswith("!import"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning(
|
print(
|
||||||
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
"** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@ -539,7 +541,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith(("!convert", "!optimize")):
|
elif command.startswith(("!convert", "!optimize")):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning("please provide the path to a .ckpt or .safetensors model")
|
print("** please provide the path to a .ckpt or .safetensors model")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_model(path[1], gen, opt, completer)
|
convert_model(path[1], gen, opt, completer)
|
||||||
@ -551,7 +553,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!edit"):
|
elif command.startswith("!edit"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning("please provide the name of a model")
|
print("** please provide the name of a model")
|
||||||
else:
|
else:
|
||||||
edit_model(path[1], gen, opt, completer)
|
edit_model(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@ -560,7 +562,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!del"):
|
elif command.startswith("!del"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning("please provide the name of a model")
|
print("** please provide the name of a model")
|
||||||
else:
|
else:
|
||||||
del_config(path[1], gen, opt, completer)
|
del_config(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@ -640,8 +642,8 @@ def import_model(model_path: str, gen, opt, completer):
|
|||||||
try:
|
try:
|
||||||
default_name = url_attachment_name(model_path)
|
default_name = url_attachment_name(model_path)
|
||||||
default_name = Path(default_name).stem
|
default_name = Path(default_name).stem
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
|
print(f"** URL: {str(e)}")
|
||||||
model_name, model_desc = _get_model_name_and_desc(
|
model_name, model_desc = _get_model_name_and_desc(
|
||||||
gen.model_manager,
|
gen.model_manager,
|
||||||
completer,
|
completer,
|
||||||
@@ -662,11 +664,11 @@ def import_model(model_path: str, gen, opt, completer):
         model_config_file=config_file,
     )
     if not imported_name:
-        logger.error("Aborting import.")
+        print("** Aborting import.")
         return

     if not _verify_load(imported_name, gen):
-        logger.error("model failed to load. Discarding configuration entry")
+        print("** model failed to load. Discarding configuration entry")
         gen.model_manager.del_model(imported_name)
         return
     if click.confirm("Make this the default model?", default=False):
@@ -674,7 +676,7 @@ def import_model(model_path: str, gen, opt, completer):

     gen.model_manager.commit(opt.conf)
     completer.update_models(gen.model_manager.list_models())
-    logger.info(f"{imported_name} successfully installed")
+    print(f">> {imported_name} successfully installed")

 def _pick_configuration_file(completer)->Path:
     print(
@@ -718,21 +720,21 @@ Please select the type of this model:
     return choice

 def _verify_load(model_name: str, gen) -> bool:
-    logger.info("Verifying that new model loads...")
+    print(">> Verifying that new model loads...")
     current_model = gen.model_name
     try:
         if not gen.set_model(model_name):
             return
     except Exception as e:
-        logger.warning(f"model failed to load: {str(e)}")
-        logger.warning(
+        print(f"** model failed to load: {str(e)}")
+        print(
             "** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
         )
         return False
     if click.confirm("Keep model loaded?", default=True):
         gen.set_model(model_name)
     else:
-        logger.info("Restoring previous model")
+        print(">> Restoring previous model")
         gen.set_model(current_model)
     return True

@@ -755,7 +757,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
     ckpt_path = None
     original_config_file = None
     if model_name_or_path == gen.model_name:
-        logger.warning("Can't convert the active model. !switch to another model first. **")
+        print("** Can't convert the active model. !switch to another model first. **")
         return
     elif model_info := manager.model_info(model_name_or_path):
         if "weights" in model_info:
@@ -765,7 +767,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
             model_description = model_info["description"]
             vae_path = model_info.get("vae")
         else:
-            logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
+            print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
             return
     model_name = manager.convert_and_import(
         ckpt_path,
@@ -786,16 +788,16 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
     manager.commit(opt.conf)
     if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
         ckpt_path.unlink(missing_ok=True)
-        logger.warning(f"{ckpt_path} deleted")
+        print(f"{ckpt_path} deleted")


 def del_config(model_name: str, gen, opt, completer):
     current_model = gen.model_name
     if model_name == current_model:
-        logger.warning("Can't delete active model. !switch to another model first. **")
+        print("** Can't delete active model. !switch to another model first. **")
         return
     if model_name not in gen.model_manager.config:
-        logger.warning(f"Unknown model {model_name}")
+        print(f"** Unknown model {model_name}")
         return

     if not click.confirm(
@@ -808,17 +810,17 @@ def del_config(model_name: str, gen, opt, completer):
     )
     gen.model_manager.del_model(model_name, delete_files=delete_completely)
     gen.model_manager.commit(opt.conf)
-    logger.warning(f"{model_name} deleted")
+    print(f"** {model_name} deleted")
     completer.update_models(gen.model_manager.list_models())


 def edit_model(model_name: str, gen, opt, completer):
     manager = gen.model_manager
     if not (info := manager.model_info(model_name)):
-        logger.warning(f"** Unknown model {model_name}")
+        print(f"** Unknown model {model_name}")
         return
-    print()
-    logger.info(f"Editing model {model_name} from configuration file {opt.conf}")
+    print(f"\n>> Editing model {model_name} from configuration file {opt.conf}")
     new_name = _get_model_name(manager.list_models(), completer, model_name)

     for attribute in info.keys():
@@ -856,7 +858,7 @@ def edit_model(model_name: str, gen, opt, completer):
         manager.set_default_model(new_name)
     manager.commit(opt.conf)
     completer.update_models(manager.list_models())
-    logger.info("Model successfully updated")
+    print(">> Model successfully updated")


 def _get_model_name(existing_names, completer, default_name: str = "") -> str:
@@ -867,11 +869,11 @@ def _get_model_name(existing_names, completer, default_name: str = "") -> str:
         if len(model_name) == 0:
             model_name = default_name
         if not re.match("^[\w._+:/-]+$", model_name):
-            logger.warning(
-                'model name must contain only words, digits and the characters "._+:/-" **'
+            print(
+                '** model name must contain only words, digits and the characters "._+:/-" **'
             )
         elif model_name != default_name and model_name in existing_names:
-            logger.warning(f"the name {model_name} is already in use. Pick another.")
+            print(f"** the name {model_name} is already in use. Pick another.")
         else:
             done = True
     return model_name
@@ -938,10 +940,11 @@ def do_postprocess(gen, opt, callback):
            opt=opt,
        )
    except OSError:
-        logger.error(f"{file_path}: file could not be read",exc_info=True)
+        print(traceback.format_exc(), file=sys.stderr)
+        print(f"** {file_path}: file could not be read")
        return
    except (KeyError, AttributeError):
-        logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
+        print(traceback.format_exc(), file=sys.stderr)
        return
    return opt.last_operation

@@ -996,13 +999,13 @@ def prepare_image_metadata(
        try:
            filename = opt.fnformat.format(**wildcards)
        except KeyError as e:
-            logger.error(
-                f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
+            print(
+                f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
            )
            filename = f"{prefix}.{seed}.png"
        except IndexError:
-            logger.error(
-                "The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
+            print(
+                "** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
            )
            filename = f"{prefix}.{seed}.png"

@@ -1091,14 +1094,14 @@ def split_variations(variations_string) -> list:
    for part in variations_string.split(","):
        seed_and_weight = part.split(":")
        if len(seed_and_weight) != 2:
-            logger.warning(f'Could not parse with_variation part "{part}"')
+            print(f'** Could not parse with_variation part "{part}"')
            broken = True
            break
        try:
            seed = int(seed_and_weight[0])
            weight = float(seed_and_weight[1])
        except ValueError:
-            logger.warning(f'Could not parse with_variation part "{part}"')
+            print(f'** Could not parse with_variation part "{part}"')
            broken = True
            break
        parts.append([seed, weight])
@@ -1122,23 +1125,23 @@ def load_face_restoration(opt):
                    opt.gfpgan_model_path
                )
            else:
-                logger.info("Face restoration disabled")
+                print(">> Face restoration disabled")
            if opt.esrgan:
                esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
            else:
-                logger.info("Upscaling disabled")
+                print(">> Upscaling disabled")
        else:
-            logger.info("Face restoration and upscaling disabled")
+            print(">> Face restoration and upscaling disabled")
    except (ModuleNotFoundError, ImportError):
        print(traceback.format_exc(), file=sys.stderr)
-        logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
+        print(">> You may need to install the ESRGAN and/or GFPGAN modules")
    return gfpgan, codeformer, esrgan


def make_step_callback(gen, opt, prefix):
    destination = os.path.join(opt.outdir, "intermediates", prefix)
    os.makedirs(destination, exist_ok=True)
-    logger.info(f"Intermediate images will be written into {destination}")
+    print(f">> Intermediate images will be written into {destination}")

    def callback(state: PipelineIntermediateState):
        latents = state.latents
@@ -1180,20 +1183,21 @@ def retrieve_dream_command(opt, command, completer):
    try:
        cmd = dream_cmd_from_png(path)
    except OSError:
-        logger.error(f"{tokens[0]}: file could not be read")
+        print(f"## {tokens[0]}: file could not be read")
    except (KeyError, AttributeError, IndexError):
-        logger.error(f"{tokens[0]}: file has no metadata")
+        print(f"## {tokens[0]}: file has no metadata")
    except:
-        logger.error(f"{tokens[0]}: file could not be processed")
+        print(f"## {tokens[0]}: file could not be processed")
    if len(cmd) > 0:
        completer.set_line(cmd)


def write_commands(opt, file_path: str, outfilepath: str):
    dir, basename = os.path.split(file_path)
    try:
        paths = sorted(list(Path(dir).glob(basename)))
    except ValueError:
-        logger.error(f'"{basename}": unacceptable pattern')
+        print(f'## "{basename}": unacceptable pattern')
        return

    commands = []
@@ -1202,9 +1206,9 @@ def write_commands(opt, file_path: str, outfilepath: str):
        try:
            cmd = dream_cmd_from_png(path)
        except (KeyError, AttributeError, IndexError):
-            logger.error(f"{path}: file has no metadata")
+            print(f"## {path}: file has no metadata")
        except:
-            logger.error(f"{path}: file could not be processed")
+            print(f"## {path}: file could not be processed")
        if cmd:
            commands.append(f"# {path}")
            commands.append(cmd)
@@ -1214,18 +1218,18 @@ def write_commands(opt, file_path: str, outfilepath: str):
        outfilepath = os.path.join(opt.outdir, basename)
    with open(outfilepath, "w", encoding="utf-8") as f:
        f.write("\n".join(commands))
-    logger.info(f"File {outfilepath} with commands created")
+    print(f">> File {outfilepath} with commands created")


def report_model_error(opt: Namespace, e: Exception):
-    logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
-    logger.warning(
-        "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
+    print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
+    print(
+        "** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
    )
    yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
    if yes_to_all:
-        logger.warning(
-            "Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
+        print(
+            "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
        )
    else:
        if not click.confirm(
@@ -1234,7 +1238,7 @@ def report_model_error(opt: Namespace, e: Exception):
        ):
            return

-    logger.info("invokeai-configure is launching....\n")
+    print("invokeai-configure is launching....\n")

    # Match arguments that were set on the CLI
    # only the arguments accepted by the configuration script are parsed
@@ -1251,7 +1255,7 @@ def report_model_error(opt: Namespace, e: Exception):
    from ..install import invokeai_configure

    invokeai_configure()
-    logger.warning("InvokeAI will now restart")
+    print("** InvokeAI will now restart")
    sys.argv = previous_args
    main()  # would rather do a os.exec(), but doesn't exist?
    sys.exit(0)
@@ -22,7 +22,6 @@ import torch
 from npyscreen import widget
 from omegaconf import OmegaConf

-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals, global_config_dir

 from ...backend.config.model_install_backend import (
@@ -456,8 +455,8 @@ def main():
    Globals.root = os.path.expanduser(get_root(opt.root) or "")

    if not global_config_dir().exists():
-        logger.info(
-            "Your InvokeAI root directory is not set up. Calling invokeai-configure."
+        print(
+            ">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
        )
        from invokeai.frontend.install import invokeai_configure

@@ -467,18 +466,18 @@ def main():
    try:
        select_and_download_models(opt)
    except AssertionError as e:
-        logger.error(e)
+        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
-        logger.info("Goodbye! Come back soon.")
+        print("\nGoodbye! Come back soon.")
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
-            logger.error(
-                "Insufficient vertical space for the interface. Please make your window taller and try again"
+            print(
+                "** Insufficient vertical space for the interface. Please make your window taller and try again"
            )
        elif str(e).startswith("addwstr"):
-            logger.error(
-                "Insufficient horizontal space for the interface. Please make your window wider and try again."
+            print(
+                "** Insufficient horizontal space for the interface. Please make your window wider and try again."
            )


@@ -27,8 +27,6 @@ from ...backend.globals import (
     global_models_dir,
     global_set_root,
 )

-import invokeai.backend.util.logging as logger
-
 from ...backend.model_management import ModelManager
 from ...frontend.install.widgets import FloatTitleSlider
@@ -115,7 +113,7 @@ def merge_diffusion_models_and_commit(
        model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
    )
    if vae := model_manager.config[models[0]].get("vae", None):
-        logger.info(f"Using configured VAE assigned to {models[0]}")
+        print(f">> Using configured VAE assigned to {models[0]}")
        import_args.update(vae=vae)
    model_manager.import_diffuser_model(dump_path, **import_args)
    model_manager.commit(config_file)
@@ -393,8 +391,10 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
            for name in self.model_manager.model_names()
            if self.model_manager.model_info(name).get("format") == "diffusers"
        ]
+        print(model_names)
        return sorted(model_names)


class Mergeapp(npyscreen.NPSAppManaged):
    def __init__(self):
        super().__init__()
@@ -414,7 +414,7 @@ def run_gui(args: Namespace):

    args = mergeapp.merge_arguments
    merge_diffusion_models_and_commit(**args)
-    logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
+    print(f'>> Models merged into new model: "{args["merged_model_name"]}".')


def run_cli(args: Namespace):
@@ -425,8 +425,8 @@ def run_cli(args: Namespace):

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.models)
-        logger.info(
-            f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
+        print(
+            f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
        )

    model_manager = ModelManager(OmegaConf.load(global_config_file()))
@@ -435,7 +435,7 @@ def run_cli(args: Namespace):
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

    merge_diffusion_models_and_commit(**vars(args))
-    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
+    print(f'>> Models merged into new model: "{args.merged_model_name}".')


def main():
@@ -455,16 +455,17 @@ def main():
        run_cli(args)
    except widget.NotEnoughSpaceForWidget as e:
        if str(e).startswith("Height of 1 allocated"):
-            logger.error(
-                "You need to have at least two diffusers models defined in models.yaml in order to merge"
+            print(
+                "** You need to have at least two diffusers models defined in models.yaml in order to merge"
            )
        else:
-            logger.error(
-                "Not enough room for the user interface. Try making this window larger."
+            print(
+                "** Not enough room for the user interface. Try making this window larger."
            )
        sys.exit(-1)
-    except Exception as e:
-        logger.error(e)
+    except Exception:
+        print(">> An error occurred:")
+        traceback.print_exc()
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)
@@ -20,7 +20,6 @@ import npyscreen
 from npyscreen import widget
 from omegaconf import OmegaConf

-import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals, global_set_root

 from ...backend.training import do_textual_inversion_training, parse_args
@@ -369,14 +368,14 @@ def copy_to_embeddings_folder(args: dict):
    dest_dir_name = args["placeholder_token"].strip("<>")
    destination = Path(Globals.root, "embeddings", dest_dir_name)
    os.makedirs(destination, exist_ok=True)
-    logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
+    print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
    shutil.copy(source, destination)
    if (
        input("Delete training logs and intermediate checkpoints? [y] ") or "y"
    ).startswith(("y", "Y")):
        shutil.rmtree(Path(args["output_dir"]))
    else:
-        logger.info(f'Keeping {args["output_dir"]}')
+        print(f'>> Keeping {args["output_dir"]}')


def save_args(args: dict):
@@ -423,10 +422,10 @@ def do_front_end(args: Namespace):
        do_textual_inversion_training(**args)
        copy_to_embeddings_folder(args)
    except Exception as e:
-        logger.error("An exception occurred during training. The exception was:")
-        logger.error(str(e))
-        logger.error("DETAILS:")
-        logger.error(traceback.format_exc())
+        print("** An exception occurred during training. The exception was:")
+        print(str(e))
+        print("** DETAILS:")
+        print(traceback.format_exc())


def main():
@@ -438,21 +437,21 @@ def main():
        else:
            do_textual_inversion_training(**vars(args))
    except AssertionError as e:
-        logger.error(e)
+        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        pass
    except (widget.NotEnoughSpaceForWidget, Exception) as e:
        if str(e).startswith("Height of 1 allocated"):
-            logger.error(
-                "You need to have at least one diffusers models defined in models.yaml in order to train"
+            print(
+                "** You need to have at least one diffusers models defined in models.yaml in order to train"
            )
        elif str(e).startswith("addwstr"):
-            logger.error(
-                "Not enough window space for the interface. Please make your window larger and try again."
+            print(
+                "** Not enough window space for the interface. Please make your window larger and try again."
            )
        else:
-            logger.error(e)
+            print(f"** An error has occurred: {str(e)}")
        sys.exit(-1)


@@ -1,40 +0,0 @@
-import react from '@vitejs/plugin-react-swc';
-import { visualizer } from 'rollup-plugin-visualizer';
-import { PluginOption, UserConfig } from 'vite';
-import eslint from 'vite-plugin-eslint';
-import tsconfigPaths from 'vite-tsconfig-paths';
-
-export const appConfig: UserConfig = {
-  base: './',
-  plugins: [
-    react(),
-    eslint(),
-    tsconfigPaths(),
-    visualizer() as unknown as PluginOption,
-  ],
-  build: {
-    chunkSizeWarningLimit: 1500,
-  },
-  server: {
-    // Proxy HTTP requests to the flask server
-    proxy: {
-      // Proxy socket.io to the nodes socketio server
-      '/ws/socket.io': {
-        target: 'ws://127.0.0.1:9090',
-        ws: true,
-      },
-      // Proxy openapi schema definiton
-      '/openapi.json': {
-        target: 'http://127.0.0.1:9090/openapi.json',
-        rewrite: (path) => path.replace(/^\/openapi.json/, ''),
-        changeOrigin: true,
-      },
-      // proxy nodes api
-      '/api/v1': {
-        target: 'http://127.0.0.1:9090/api/v1',
-        rewrite: (path) => path.replace(/^\/api\/v1/, ''),
-        changeOrigin: true,
-      },
-    },
-  },
-};
@@ -1,47 +0,0 @@
-import react from '@vitejs/plugin-react-swc';
-import path from 'path';
-import { visualizer } from 'rollup-plugin-visualizer';
-import { PluginOption, UserConfig } from 'vite';
-import dts from 'vite-plugin-dts';
-import eslint from 'vite-plugin-eslint';
-import tsconfigPaths from 'vite-tsconfig-paths';
-
-export const packageConfig: UserConfig = {
-  base: './',
-  plugins: [
-    react(),
-    eslint(),
-    tsconfigPaths(),
-    visualizer() as unknown as PluginOption,
-    dts({
-      insertTypesEntry: true,
-    }),
-  ],
-  build: {
-    chunkSizeWarningLimit: 1500,
-    lib: {
-      entry: path.resolve(__dirname, '../src/index.ts'),
-      name: 'InvokeAIUI',
-      fileName: (format) => `invoke-ai-ui.${format}.js`,
-    },
-    rollupOptions: {
-      external: ['react', 'react-dom', '@emotion/react'],
-      output: {
-        globals: {
-          react: 'React',
-          'react-dom': 'ReactDOM',
-        },
-      },
-    },
-  },
-  resolve: {
-    alias: {
-      app: path.resolve(__dirname, '../src/app'),
-      assets: path.resolve(__dirname, '../src/assets'),
-      common: path.resolve(__dirname, '../src/common'),
-      features: path.resolve(__dirname, '../src/features'),
-      services: path.resolve(__dirname, '../src/services'),
-      theme: path.resolve(__dirname, '../src/theme'),
-    },
-  },
-};
invokeai/frontend/web/index.d.ts (vendored, new file, 98 lines)
@@ -0,0 +1,98 @@
+import React, { PropsWithChildren } from 'react';
+import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
+import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
+import { InvokeTabName } from 'features/ui/store/tabMap';
+
+export {};
+
+declare module 'redux-socket.io-middleware';
+
+declare global {
+  /* eslint-disable @typescript-eslint/no-explicit-any */
+  interface Array<T> {
+    /**
+     * Returns the value of the last element in the array where predicate is true, and undefined
+     * otherwise.
+     * @param predicate findLast calls predicate once for each element of the array, in descending
+     * order, until it finds one where predicate returns true. If such an element is found, findLast
+     * immediately returns that element value. Otherwise, findLast returns undefined.
+     * @param thisArg If provided, it will be used as the this value for each invocation of
+     * predicate. If it is not provided, undefined is used instead.
+     */
+    findLast<S extends T>(
+      predicate: (value: T, index: number, array: T[]) => value is S,
+      thisArg?: any
+    ): S | undefined;
+    findLast(
+      predicate: (value: T, index: number, array: T[]) => unknown,
+      thisArg?: any
+    ): T | undefined;
+
+    /**
+     * Returns the index of the last element in the array where predicate is true, and -1
+     * otherwise.
+     * @param predicate findLastIndex calls predicate once for each element of the array, in descending
+     * order, until it finds one where predicate returns true. If such an element is found,
+     * findLastIndex immediately returns that element index. Otherwise, findLastIndex returns -1.
+     * @param thisArg If provided, it will be used as the this value for each invocation of
+     * predicate. If it is not provided, undefined is used instead.
+     */
+    findLastIndex(
+      predicate: (value: T, index: number, array: T[]) => unknown,
+      thisArg?: any
+    ): number;
+  }
+  /* eslint-enable @typescript-eslint/no-explicit-any */
+}
+
+declare module '@invoke-ai/invoke-ai-ui' {
+  declare class ThemeChanger extends React.Component<ThemeChangerProps> {
+    public constructor(props: ThemeChangerProps);
+  }
+
+  declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
+    public constructor(props: InvokeAILogoComponentProps);
+  }
+
+  declare class IAIPopover extends React.Component<IAIPopoverProps> {
+    public constructor(props: IAIPopoverProps);
+  }
+
+  declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
+    public constructor(props: IAIIconButtonProps);
+  }
+
+  declare class SettingsModal extends React.Component<SettingsModalProps> {
+    public constructor(props: SettingsModalProps);
+  }
+
+  declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
+    public constructor(props: StatusIndicatorProps);
+  }
+
+  declare class ModelSelect extends React.Component<ModelSelectProps> {
+    public constructor(props: ModelSelectProps);
+  }
+}
+
+interface InvokeProps extends PropsWithChildren {
+  apiUrl?: string;
+  disabledPanels?: string[];
+  disabledTabs?: InvokeTabName[];
+  token?: string;
+  shouldTransformUrls?: boolean;
+  shouldFetchImages?: boolean;
+}
+
+declare function Invoke(props: InvokeProps): JSX.Element;
+
+export {
+  ThemeChanger,
+  InvokeAiLogoComponent,
+  IAIPopover,
+  IAIIconButton,
+  SettingsModal,
+  StatusIndicator,
+  ModelSelect,
+};
+export = Invoke;
@@ -1,23 +1,7 @@
 {
-  "name": "@invoke-ai/invoke-ai-ui",
+  "name": "invoke-ai-ui",
   "private": true,
   "version": "0.0.1",
-  "publishConfig": {
-    "access": "restricted",
-    "registry": "https://npm.pkg.github.com"
-  },
-  "main": "./dist/invoke-ai-ui.umd.js",
-  "module": "./dist/invoke-ai-ui.es.js",
-  "exports": {
-    ".": {
-      "import": "./dist/invoke-ai-ui.es.js",
-      "require": "./dist/invoke-ai-ui.umd.js"
-    }
-  },
-  "types": "./dist/index.d.ts",
-  "files": [
-    "dist"
-  ],
   "scripts": {
     "prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
     "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
@@ -56,96 +40,80 @@
   },
   "dependencies": {
     "@chakra-ui/anatomy": "^2.1.1",
-    "@chakra-ui/icons": "^2.0.19",
-    "@chakra-ui/react": "^2.6.0",
-    "@chakra-ui/styled-system": "^2.9.0",
+    "@chakra-ui/cli": "^2.3.0",
+    "@chakra-ui/icons": "^2.0.17",
+    "@chakra-ui/react": "^2.5.1",
+    "@chakra-ui/styled-system": "^2.6.1",
     "@chakra-ui/theme-tools": "^2.0.16",
     "@dagrejs/graphlib": "^2.1.12",
     "@emotion/react": "^11.10.6",
     "@emotion/styled": "^11.10.6",
     "@fontsource/inter": "^4.5.15",
-    "@reduxjs/toolkit": "^1.9.5",
-    "@roarr/browser-log-writer": "^1.1.5",
+    "@reduxjs/toolkit": "^1.9.3",
     "chakra-ui-contextmenu": "^1.0.5",
     "dateformat": "^5.0.3",
     "formik": "^2.2.9",
-    "framer-motion": "^10.12.4",
+    "framer-motion": "^9.0.4",
     "fuse.js": "^6.6.2",
-    "i18next": "^22.4.15",
+    "i18next": "^22.4.10",
     "i18next-browser-languagedetector": "^7.0.1",
-    "i18next-http-backend": "^2.2.0",
-    "konva": "^9.0.1",
-    "lodash-es": "^4.17.21",
-    "overlayscrollbars": "^2.1.1",
-    "overlayscrollbars-react": "^0.5.0",
-    "patch-package": "^7.0.0",
+    "i18next-http-backend": "^2.1.1",
+    "konva": "^8.4.2",
+    "lodash": "^4.17.21",
+    "patch-package": "^6.5.1",
     "re-resizable": "^6.9.9",
     "react": "^18.2.0",
     "react-colorful": "^5.6.1",
     "react-dom": "^18.2.0",
     "react-dropzone": "^14.2.3",
-    "react-hotkeys-hook": "4.4.0",
-    "react-i18next": "^12.2.2",
+    "react-hotkeys-hook": "4.3.5",
+    "react-i18next": "^12.1.5",
     "react-icons": "^4.7.1",
-    "react-konva": "^18.2.7",
-    "react-konva-utils": "^1.0.4",
+    "react-konva": "^18.2.4",
+    "react-konva-utils": "^0.3.2",
     "react-redux": "^8.0.5",
-    "react-rnd": "^10.4.1",
     "react-transition-group": "^4.4.5",
-    "react-use": "^17.4.0",
-    "react-virtuoso": "^4.3.5",
-    "react-zoom-pan-pinch": "^3.0.7",
+    "react-zoom-pan-pinch": "^2.6.1",
     "reactflow": "^11.7.0",
     "redux-deep-persist": "^1.0.7",
     "redux-dynamic-middlewares": "^2.2.0",
     "redux-persist": "^6.0.0",
-    "roarr": "^7.15.0",
-    "serialize-error": "^11.0.0",
     "socket.io-client": "^4.6.0",
     "use-image": "^1.1.0",
     "uuid": "^9.0.0"
   },
-  "peerDependencies": {
-    "@chakra-ui/cli": "^2.4.0",
-    "react": "^18.2.0",
-    "react-dom": "^18.2.0",
-    "ts-toolbelt": "^9.6.0"
-  },
   "devDependencies": {
-    "@chakra-ui/cli": "^2.4.0",
     "@types/dateformat": "^5.0.0",
-    "@types/lodash-es": "^4.14.194",
-    "@types/node": "^18.16.2",
-    "@types/react": "^18.2.0",
-    "@types/react-dom": "^18.2.1",
+    "@types/lodash": "^4.14.194",
+    "@types/react": "^18.0.28",
+    "@types/react-dom": "^18.0.11",
     "@types/react-transition-group": "^4.4.5",
     "@types/uuid": "^9.0.0",
-    "@typescript-eslint/eslint-plugin": "^5.59.1",
-    "@typescript-eslint/parser": "^5.59.1",
-    "@vitejs/plugin-react-swc": "^3.3.0",
-    "axios": "^1.4.0",
+    "@typescript-eslint/eslint-plugin": "^5.52.0",
+    "@typescript-eslint/parser": "^5.52.0",
+    "@vitejs/plugin-react-swc": "^3.2.0",
+    "axios": "^1.3.4",
     "babel-plugin-transform-imports": "^2.0.0",
-    "concurrently": "^8.0.1",
-    "eslint": "^8.39.0",
-    "eslint-config-prettier": "^8.8.0",
+    "concurrently": "^7.6.0",
+    "eslint": "^8.34.0",
+    "eslint-config-prettier": "^8.6.0",
     "eslint-plugin-prettier": "^4.2.1",
     "eslint-plugin-react": "^7.32.2",
     "eslint-plugin-react-hooks": "^4.6.0",
     "form-data": "^4.0.0",
     "husky": "^8.0.3",
-    "lint-staged": "^13.2.2",
+    "lint-staged": "^13.1.2",
    "madge": "^6.0.0",
     "openapi-types": "^12.1.0",
-    "openapi-typescript-codegen": "^0.24.0",
+    "openapi-typescript-codegen": "^0.23.0",
     "postinstall-postinstall": "^2.1.0",
-    "prettier": "^2.8.8",
+    "prettier": "^2.8.4",
     "rollup-plugin-visualizer": "^5.9.0",
-    "terser": "^5.17.1",
-    "ts-toolbelt": "^9.6.0",
-    "vite": "^4.3.3",
-    "vite-plugin-dts": "^2.3.0",
+    "terser": "^5.16.4",
+    "typescript": "4.9.5",
+    "vite": "^4.1.2",
     "vite-plugin-eslint": "^1.8.1",
-    "vite-tsconfig-paths": "^4.2.0",
+    "vite-tsconfig-paths": "^4.0.5",
     "yarn": "^1.22.19"
   }
 }
@@ -63,7 +63,7 @@
     "postProcessDesc3": "The Invoke AI Command Line Interface offers various other features including Embiggen.",
     "training": "Training",
     "trainingDesc1": "A dedicated workflow for training your own embeddings and checkpoints using Textual Inversion and Dreambooth from the web interface.",
-    "trainingDesc2": "InvokeAI already supports training custom embeddourings using Textual Inversion using the main script.",
+    "trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
     "upload": "Upload",
     "close": "Close",
     "cancel": "Cancel",
@@ -97,12 +97,7 @@
     "statusMergedModels": "Models Merged",
     "pinOptionsPanel": "Pin Options Panel",
     "loading": "Loading",
-    "loadingInvokeAI": "Loading Invoke AI",
-    "random": "Random",
-    "generate": "Generate",
-    "openInNewTab": "Open in New Tab",
-    "dontAskMeAgain": "Don't ask me again",
-    "areYouSure": "Are you sure?"
+    "loadingInvokeAI": "Loading Invoke AI"
   },
   "gallery": {
     "generations": "Generations",
|
|||||||
"pinGallery": "Pin Gallery",
|
"pinGallery": "Pin Gallery",
|
||||||
"allImagesLoaded": "All Images Loaded",
|
"allImagesLoaded": "All Images Loaded",
|
||||||
"loadMore": "Load More",
|
"loadMore": "Load More",
|
||||||
"noImagesInGallery": "No Images In Gallery",
|
"noImagesInGallery": "No Images In Gallery"
|
||||||
"deleteImage": "Delete Image",
|
|
||||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
|
||||||
"deleteImagePermanent": "Deleted images cannot be restored."
|
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Keyboard Shortcuts",
|
"keyboardShortcuts": "Keyboard Shortcuts",
|
||||||
@ -513,6 +505,7 @@
|
|||||||
"useAll": "Use All",
|
"useAll": "Use All",
|
||||||
"useInitImg": "Use Initial Image",
|
"useInitImg": "Use Initial Image",
|
||||||
"info": "Info",
|
"info": "Info",
|
||||||
|
"deleteImage": "Delete Image",
|
||||||
"initialImage": "Initial Image",
|
"initialImage": "Initial Image",
|
||||||
"showOptionsPanel": "Show Options Panel",
|
"showOptionsPanel": "Show Options Panel",
|
||||||
"hidePreview": "Hide Preview",
|
"hidePreview": "Hide Preview",
|
||||||
@ -527,15 +520,10 @@
|
|||||||
"useCanvasBeta": "Use Canvas Beta Layout",
|
"useCanvasBeta": "Use Canvas Beta Layout",
|
||||||
"enableImageDebugging": "Enable Image Debugging",
|
"enableImageDebugging": "Enable Image Debugging",
|
||||||
"useSlidersForAll": "Use Sliders For All Options",
|
"useSlidersForAll": "Use Sliders For All Options",
|
||||||
"autoShowProgress": "Auto Show Progress Images",
|
|
||||||
"resetWebUI": "Reset Web UI",
|
"resetWebUI": "Reset Web UI",
|
||||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||||
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
|
"resetComplete": "Web UI has been reset. Refresh the page to reload."
|
||||||
"consoleLogLevel": "Log Level",
|
|
||||||
"shouldLogToConsole": "Console Logging",
|
|
||||||
"developer": "Developer",
|
|
||||||
"general": "General"
|
|
||||||
},
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
"serverError": "Server Error",
|
"serverError": "Server Error",
|
||||||
@ -646,9 +634,5 @@
|
|||||||
"betaDarkenOutside": "Darken Outside",
|
"betaDarkenOutside": "Darken Outside",
|
||||||
"betaLimitToBox": "Limit To Box",
|
"betaLimitToBox": "Limit To Box",
|
||||||
"betaPreserveMasked": "Preserve Masked"
|
"betaPreserveMasked": "Preserve Masked"
|
||||||
},
|
|
||||||
"ui": {
|
|
||||||
"showProgressImages": "Show Progress Images",
|
|
||||||
"hideProgressImages": "Hide Progress Images"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
invokeai/frontend/web/src/Loading.tsx (new file, 39 lines)
@@ -0,0 +1,39 @@
+import { Flex, Spinner, Text } from '@chakra-ui/react';
+import { useTranslation } from 'react-i18next';
+
+interface LoaderProps {
+  showText?: boolean;
+  text?: string;
+}
+
+// This component loads before the theme so we cannot use theme tokens here
+
+const Loading = (props: LoaderProps) => {
+  const { t } = useTranslation();
+  const { showText = false, text = t('common.loadingInvokeAI') } = props;
+
+  return (
+    <Flex
+      width="100vw"
+      height="100vh"
+      alignItems="center"
+      justifyContent="center"
+      bg="#121212"
+      flexDirection="column"
+      rowGap={4}
+    >
+      <Spinner color="grey" w="5rem" h="5rem" />
+      {showText && (
+        <Text
+          color="grey"
+          fontWeight="semibold"
+          fontFamily="'Inter', sans-serif"
+        >
+          {text}
+        </Text>
+      )}
+    </Flex>
+  );
+};
+
+export default Loading;
invokeai/frontend/web/src/app/App.tsx (new file, 100 lines)
@@ -0,0 +1,100 @@
+import ImageUploader from 'common/components/ImageUploader';
+import Console from 'features/system/components/Console';
+import ProgressBar from 'features/system/components/ProgressBar';
+import SiteHeader from 'features/system/components/SiteHeader';
+import InvokeTabs from 'features/ui/components/InvokeTabs';
+import { keepGUIAlive } from './utils';
+
+import useToastWatcher from 'features/system/hooks/useToastWatcher';
+
+import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
+import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
+import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
+import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
+import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
+import Lightbox from 'features/lightbox/components/Lightbox';
+import { useAppDispatch, useAppSelector } from './storeHooks';
+import { PropsWithChildren, useEffect } from 'react';
+import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
+import { InvokeTabName } from 'features/ui/store/tabMap';
+import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
+import { setShouldFetchImages } from 'features/gallery/store/resultsSlice';
+
+keepGUIAlive();
+
+interface Props extends PropsWithChildren {
+  options: {
+    disabledPanels: string[];
+    disabledTabs: InvokeTabName[];
+    shouldTransformUrls?: boolean;
+    shouldFetchImages: boolean;
+  };
+}
+
+const App = (props: Props) => {
+  useToastWatcher();
+
+  const currentTheme = useAppSelector((state) => state.ui.currentTheme);
+  const { setColorMode } = useColorMode();
+  const dispatch = useAppDispatch();
+
+  useEffect(() => {
+    dispatch(setDisabledPanels(props.options.disabledPanels));
+  }, [dispatch, props.options.disabledPanels]);
+
+  useEffect(() => {
+    dispatch(setDisabledTabs(props.options.disabledTabs));
+  }, [dispatch, props.options.disabledTabs]);
+
+  useEffect(() => {
+    dispatch(
+      shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
+    );
+  }, [dispatch, props.options.shouldTransformUrls]);
+
+  useEffect(() => {
+    dispatch(setShouldFetchImages(props.options.shouldFetchImages));
+  }, [dispatch, props.options.shouldFetchImages]);
+
+  useEffect(() => {
+    setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
+  }, [setColorMode, currentTheme]);
+
+  return (
+    <Grid w="100vw" h="100vh">
+      <Lightbox />
+      <ImageUploader>
+        <ProgressBar />
+        <Grid
+          gap={4}
+          p={4}
+          gridAutoRows="min-content auto"
+          w={APP_WIDTH}
+          h={APP_HEIGHT}
+        >
+          {props.children || <SiteHeader />}
+          <Flex
+            gap={4}
+            w={{ base: '100vw', xl: 'full' }}
+            h="full"
+            flexDir={{ base: 'column', xl: 'row' }}
+          >
+            <InvokeTabs />
+            <ImageGalleryPanel />
+          </Flex>
+        </Grid>
+        <Box>
+          <Console />
+        </Box>
+      </ImageUploader>
+      <Portal>
+        <FloatingParametersPanelButtons />
+      </Portal>
+      <Portal>
+        <FloatingGalleryButton />
+      </Portal>
+    </Grid>
+  );
+};
+
+export default App;
@ -2,8 +2,8 @@ import { ChakraProvider, extendTheme } from '@chakra-ui/react';
|
|||||||
import { ReactNode, useEffect } from 'react';
|
import { ReactNode, useEffect } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { theme as invokeAITheme } from 'theme/theme';
|
import { theme as invokeAITheme } from 'theme/theme';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from './store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from './storeHooks';
|
||||||
|
|
||||||
import { greenTeaThemeColors } from 'theme/colors/greenTea';
|
import { greenTeaThemeColors } from 'theme/colors/greenTea';
|
||||||
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
|
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
|
||||||
@ -18,8 +18,6 @@ import '@fontsource/inter/600.css';
|
|||||||
import '@fontsource/inter/700.css';
|
import '@fontsource/inter/700.css';
|
||||||
import '@fontsource/inter/800.css';
|
import '@fontsource/inter/800.css';
|
||||||
import '@fontsource/inter/900.css';
|
import '@fontsource/inter/900.css';
|
||||||
import 'overlayscrollbars/overlayscrollbars.css';
|
|
||||||
import 'theme/css/overlayscrollbars.css';
|
|
||||||
|
|
||||||
type ThemeLocaleProviderProps = {
|
type ThemeLocaleProviderProps = {
|
||||||
children: ReactNode;
|
children: ReactNode;
|
@@ -1,129 +0,0 @@
-import ImageUploader from 'common/components/ImageUploader';
-import ProgressBar from 'features/system/components/ProgressBar';
-import SiteHeader from 'features/system/components/SiteHeader';
-import InvokeTabs from 'features/ui/components/InvokeTabs';
-
-import useToastWatcher from 'features/system/hooks/useToastWatcher';
-
-import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
-import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
-import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
-import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
-import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
-import Lightbox from 'features/lightbox/components/Lightbox';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import {
-  memo,
-  PropsWithChildren,
-  useCallback,
-  useEffect,
-  useState,
-} from 'react';
-import { motion, AnimatePresence } from 'framer-motion';
-import Loading from 'common/components/Loading/Loading';
-import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
-import { PartialAppConfig } from 'app/types/invokeai';
-import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
-import { configChanged } from 'features/system/store/configSlice';
-import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
-import { useLogger } from 'app/logging/useLogger';
-import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
-
-const DEFAULT_CONFIG = {};
-
-interface Props extends PropsWithChildren {
-  config?: PartialAppConfig;
-}
-
-const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
-  useToastWatcher();
-  useGlobalHotkeys();
-  const log = useLogger();
-
-  const currentTheme = useAppSelector((state) => state.ui.currentTheme);
-
-  const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
-
-  const isApplicationReady = useIsApplicationReady();
-
-  const [loadingOverridden, setLoadingOverridden] = useState(false);
-
-  const { setColorMode } = useColorMode();
-  const dispatch = useAppDispatch();
-
-  useEffect(() => {
-    log.info({ namespace: 'App', data: config }, 'Received config');
-    dispatch(configChanged(config));
-  }, [dispatch, config, log]);
-
-  useEffect(() => {
-    setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
-  }, [setColorMode, currentTheme]);
-
-  const handleOverrideClicked = useCallback(() => {
-    setLoadingOverridden(true);
-  }, []);
-
-  return (
-    <Grid w="100vw" h="100vh" position="relative" overflow="hidden">
-      {isLightboxEnabled && <Lightbox />}
-      <ImageUploader>
-        <ProgressBar />
-        <Grid
-          gap={4}
-          p={4}
-          gridAutoRows="min-content auto"
-          w={APP_WIDTH}
-          h={APP_HEIGHT}
-        >
-          {children || <SiteHeader />}
-          <Flex
-            gap={4}
-            w={{ base: '100vw', xl: 'full' }}
-            h="full"
-            flexDir={{ base: 'column', xl: 'row' }}
-          >
-            <InvokeTabs />
-            <ImageGalleryPanel />
-          </Flex>
-        </Grid>
-      </ImageUploader>
-
-      <AnimatePresence>
-        {!isApplicationReady && !loadingOverridden && (
-          <motion.div
-            key="loading"
-            initial={{ opacity: 1 }}
-            animate={{ opacity: 1 }}
-            exit={{ opacity: 0 }}
-            transition={{ duration: 0.3 }}
-            style={{ zIndex: 3 }}
-          >
-            <Box position="absolute" top={0} left={0} w="100vw" h="100vh">
-              <Loading />
-            </Box>
-            <Box
-              onClick={handleOverrideClicked}
-              position="absolute"
-              top={0}
-              right={0}
-              cursor="pointer"
-              w="2rem"
-              h="2rem"
-            />
-          </motion.div>
-        )}
-      </AnimatePresence>
-
-      <Portal>
-        <FloatingParametersPanelButtons />
-      </Portal>
-      <Portal>
-        <FloatingGalleryButton />
-      </Portal>
-      <ProgressImagePreview />
-    </Grid>
-  );
-};
-
-export default memo(App);
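The deleted App component above centers on one pattern worth noting: a full-screen loading overlay that stays mounted until the app reports ready, plus an invisible corner hotspot that lets the user dismiss it manually. A minimal sketch of that pattern, with `useIsReady` and `Spinner` as hypothetical stand-ins for the real readiness hook and loader:

```tsx
import { useCallback, useState } from 'react';
import { AnimatePresence, motion } from 'framer-motion';

// Hypothetical stand-ins for the app's readiness hook and spinner component.
declare function useIsReady(): boolean;
declare function Spinner(): JSX.Element;

export const LoadingGate = () => {
  const isReady = useIsReady();
  // Lets a stuck loading screen be dismissed by clicking the corner hotspot.
  const [overridden, setOverridden] = useState(false);
  const dismiss = useCallback(() => setOverridden(true), []);

  return (
    <AnimatePresence>
      {!isReady && !overridden && (
        <motion.div key="loading" exit={{ opacity: 0 }} transition={{ duration: 0.3 }}>
          <Spinner />
          {/* Invisible 2rem click target in the corner, as in the deleted component */}
          <div
            onClick={dismiss}
            style={{ position: 'absolute', top: 0, right: 0, width: '2rem', height: '2rem', cursor: 'pointer' }}
          />
        </motion.div>
      )}
    </AnimatePresence>
  );
};
```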
@@ -1,6 +1,23 @@
 // TODO: use Enums?
 
-export const DIFFUSERS_SCHEDULERS: Array<string> = [
+import { InProgressImageType } from 'features/system/store/systemSlice';
+
+// Valid samplers
+export const SAMPLERS: Array<string> = [
+  'ddim',
+  'plms',
+  'k_lms',
+  'k_dpm_2',
+  'k_dpm_2_a',
+  'k_dpmpp_2',
+  'k_dpmpp_2_a',
+  'k_euler',
+  'k_euler_a',
+  'k_heun',
+];
+
+// Valid Diffusers Samplers
+export const DIFFUSERS_SAMPLERS: Array<string> = [
   'ddim',
   'plms',
   'k_lms',
@@ -31,8 +48,17 @@ export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
 
 export const NUMPY_RAND_MIN = 0;
 
-export const NUMPY_RAND_MAX = 2147483647;
+export const NUMPY_RAND_MAX = 4294967295;
 
 export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;
 
+export const IN_PROGRESS_IMAGE_TYPES: Array<{
+  key: string;
+  value: InProgressImageType;
+}> = [
+  { key: 'None', value: 'none' },
+  { key: 'Fast', value: 'latents' },
+  { key: 'Accurate', value: 'full-res' },
+];
+
 export const NODE_MIN_WIDTH = 250;
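Since `SAMPLERS` and `DIFFUSERS_SAMPLERS` are plain `Array<string>` values, a runtime membership check is the natural way to validate a sampler name coming from user input or image metadata. A small sketch; the `isValidSampler` helper is illustrative, not part of the diff, and the import path is assumed:

```ts
import { DIFFUSERS_SAMPLERS, SAMPLERS } from 'app/constants';

// Narrow an arbitrary string to a known sampler name at runtime.
export const isValidSampler = (name: string, diffusers = false): boolean =>
  (diffusers ? DIFFUSERS_SAMPLERS : SAMPLERS).includes(name);

isValidSampler('k_euler_a'); // true
isValidSampler('k_euler_ancestral'); // false - not in the lists above
```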
336
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
Normal file
@@ -0,0 +1,336 @@
+/**
+ * Types for images, the things they are made of, and the things
+ * they make up.
+ *
+ * Generated images are txt2img and img2img images. They may have
+ * had additional postprocessing done on them when they were first
+ * generated.
+ *
+ * Postprocessed images are images which were not generated here
+ * but only postprocessed by the app. They only get postprocessing
+ * metadata and have a different image type, e.g. 'esrgan' or
+ * 'gfpgan'.
+ */
+
+import { InvokeTabName } from 'features/ui/store/tabMap';
+import { IRect } from 'konva/lib/types';
+import { ImageMetadata, ImageType } from 'services/api';
+import { AnyInvocation } from 'services/events/types';
+
+/**
+ * TODO:
+ * Once an image has been generated, if it is postprocessed again,
+ * additional postprocessing steps are added to its postprocessing
+ * array.
+ *
+ * TODO: Better documentation of types.
+ */
+
+export declare type PromptItem = {
+  prompt: string;
+  weight: number;
+};
+
+// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type
+export declare type Prompt = Array<PromptItem> | string;
+
+export declare type SeedWeightPair = {
+  seed: number;
+  weight: number;
+};
+
+export declare type SeedWeights = Array<SeedWeightPair>;
+
+// All generated images contain these metadata.
+export declare type CommonGeneratedImageMetadata = {
+  postprocessing: null | Array<ESRGANMetadata | GFPGANMetadata>;
+  sampler:
+    | 'ddim'
+    | 'k_dpm_2_a'
+    | 'k_dpm_2'
+    | 'k_dpmpp_2_a'
+    | 'k_dpmpp_2'
+    | 'k_euler_a'
+    | 'k_euler'
+    | 'k_heun'
+    | 'k_lms'
+    | 'plms';
+  prompt: Prompt;
+  seed: number;
+  variations: SeedWeights;
+  steps: number;
+  cfg_scale: number;
+  width: number;
+  height: number;
+  seamless: boolean;
+  hires_fix: boolean;
+  extra: null | Record<string, never>; // Pending development of RFC #266
+};
+
+// txt2img and img2img images have some unique attributes.
+export declare type Txt2ImgMetadata = GeneratedImageMetadata & {
+  type: 'txt2img';
+};
+
+export declare type Img2ImgMetadata = GeneratedImageMetadata & {
+  type: 'img2img';
+  orig_hash: string;
+  strength: number;
+  fit: boolean;
+  init_image_path: string;
+  mask_image_path?: string;
+};
+
+// Superset of generated image metadata types.
+export declare type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
+
+// All post processed images contain these metadata.
+export declare type CommonPostProcessedImageMetadata = {
+  orig_path: string;
+  orig_hash: string;
+};
+
+// esrgan and gfpgan images have some unique attributes.
+export declare type ESRGANMetadata = CommonPostProcessedImageMetadata & {
+  type: 'esrgan';
+  scale: 2 | 4;
+  strength: number;
+  denoise_str: number;
+};
+
+export declare type FacetoolMetadata = CommonPostProcessedImageMetadata & {
+  type: 'gfpgan' | 'codeformer';
+  strength: number;
+  fidelity?: number;
+};
+
+// Superset of all postprocessed image metadata types..
+export declare type PostProcessedImageMetadata =
+  | ESRGANMetadata
+  | FacetoolMetadata;
+
+// Metadata includes the system config and image metadata.
+export declare type Metadata = SystemGenerationMetadata & {
+  image: GeneratedImageMetadata | PostProcessedImageMetadata;
+};
+
+// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
+export declare type _Image = {
+  uuid: string;
+  url: string;
+  thumbnail: string;
+  mtime: number;
+  metadata?: Metadata;
+  width: number;
+  height: number;
+  category: GalleryCategory;
+  isBase64?: boolean;
+  dreamPrompt?: 'string';
+  name?: string;
+};
+
+/**
+ * ResultImage
+ */
+export declare type Image = {
+  name: string;
+  type: ImageType;
+  url: string;
+  thumbnail: string;
+  metadata: ImageMetadata;
+};
+
+// GalleryImages is an array of Image.
+export declare type GalleryImages = {
+  images: Array<_Image>;
+};
+
+/**
+ * Types related to the system status.
+ */
+
+// This represents the processing status of the backend.
+export declare type SystemStatus = {
+  isProcessing: boolean;
+  currentStep: number;
+  totalSteps: number;
+  currentIteration: number;
+  totalIterations: number;
+  currentStatus: string;
+  currentStatusHasSteps: boolean;
+  hasError: boolean;
+};
+
+export declare type SystemGenerationMetadata = {
+  model: string;
+  model_weights?: string;
+  model_id?: string;
+  model_hash: string;
+  app_id: string;
+  app_version: string;
+};
+
+export declare type SystemConfig = SystemGenerationMetadata & {
+  model_list: ModelList;
+  infill_methods: string[];
+};
+
+export declare type ModelStatus = 'active' | 'cached' | 'not loaded';
+
+export declare type Model = {
+  status: ModelStatus;
+  description: string;
+  weights: string;
+  config?: string;
+  vae?: string;
+  width?: number;
+  height?: number;
+  default?: boolean;
+  format?: string;
+};
+
+export declare type DiffusersModel = {
+  status: ModelStatus;
+  description: string;
+  repo_id?: string;
+  path?: string;
+  vae?: {
+    repo_id?: string;
+    path?: string;
+  };
+  format?: string;
+  default?: boolean;
+};
+
+export declare type ModelList = Record<string, Model & DiffusersModel>;
+
+export declare type FoundModel = {
+  name: string;
+  location: string;
+};
+
+export declare type InvokeModelConfigProps = {
+  name: string | undefined;
+  description: string | undefined;
+  config: string | undefined;
+  weights: string | undefined;
+  vae: string | undefined;
+  width: number | undefined;
+  height: number | undefined;
+  default: boolean | undefined;
+  format: string | undefined;
+};
+
+export declare type InvokeDiffusersModelConfigProps = {
+  name: string | undefined;
+  description: string | undefined;
+  repo_id: string | undefined;
+  path: string | undefined;
+  default: boolean | undefined;
+  format: string | undefined;
+  vae: {
+    repo_id: string | undefined;
+    path: string | undefined;
+  };
+};
+
+export declare type InvokeModelConversionProps = {
+  model_name: string;
+  save_location: string;
+  custom_location: string | null;
+};
+
+export declare type InvokeModelMergingProps = {
+  models_to_merge: string[];
+  alpha: number;
+  interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
+  force: boolean;
+  merged_model_name: string;
+  model_merge_save_path: string | null;
+};
+
+/**
+ * These types type data received from the server via socketio.
+ */
+
+export declare type ModelChangeResponse = {
+  model_name: string;
+  model_list: ModelList;
+};
+
+export declare type ModelConvertedResponse = {
+  converted_model_name: string;
+  model_list: ModelList;
+};
+
+export declare type ModelsMergedResponse = {
+  merged_models: string[];
+  merged_model_name: string;
+  model_list: ModelList;
+};
+
+export declare type ModelAddedResponse = {
+  new_model_name: string;
+  model_list: ModelList;
+  update: boolean;
+};
+
+export declare type ModelDeletedResponse = {
+  deleted_model_name: string;
+  model_list: ModelList;
+};
+
+export declare type FoundModelResponse = {
+  search_folder: string;
+  found_models: FoundModel[];
+};
+
+export declare type SystemStatusResponse = SystemStatus;
+
+export declare type SystemConfigResponse = SystemConfig;
+
+export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
+  boundingBox?: IRect;
+  generationMode: InvokeTabName;
+};
+
+export declare type ImageUploadResponse = {
+  // image: Omit<Image, 'uuid' | 'metadata' | 'category'>;
+  url: string;
+  mtime: number;
+  width: number;
+  height: number;
+  thumbnail: string;
+  // bbox: [number, number, number, number];
+};
+
+export declare type ErrorResponse = {
+  message: string;
+  additionalData?: string;
+};
+
+export declare type GalleryImagesResponse = {
+  images: Array<Omit<_Image, 'uuid'>>;
+  areMoreImagesAvailable: boolean;
+  category: GalleryCategory;
+};
+
+export declare type ImageDeletedResponse = {
+  uuid: string;
+  url: string;
+  category: GalleryCategory;
+};
+
+export declare type ImageUrlResponse = {
+  url: string;
+};
+
+export declare type UploadImagePayload = {
+  file: File;
+  destination?: ImageUploadDestination;
+};
+
+export declare type UploadOutpaintingMergeImagePayload = {
+  dataURL: string;
+  name: string;
+};
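All of the metadata variants in this new declaration file carry a literal `type` field, so they form discriminated unions that TypeScript can narrow without casts. A hedged sketch of consuming code, assuming the generated variants ultimately expose the fields of `CommonGeneratedImageMetadata` as the file's comments indicate:

```ts
import * as InvokeAI from 'app/invokeai';

type AnyImageMetadata =
  | InvokeAI.GeneratedImageMetadata
  | InvokeAI.PostProcessedImageMetadata;

// The literal `type` field discriminates the union, so each branch is fully typed.
const describe = (m: AnyImageMetadata): string => {
  switch (m.type) {
    case 'txt2img':
    case 'img2img':
      return `generated in ${m.steps} steps at cfg ${m.cfg_scale}`;
    case 'esrgan':
      return `upscaled ${m.scale}x from ${m.orig_path}`;
    case 'gfpgan':
    case 'codeformer':
      return `face-restored (${m.type}) at strength ${m.strength}`;
  }
};
```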
@@ -1,94 +0,0 @@
-import { createSelector } from '@reduxjs/toolkit';
-import { useAppSelector } from 'app/store/storeHooks';
-import { systemSelector } from 'features/system/store/systemSelectors';
-import { isEqual } from 'lodash-es';
-import { useEffect } from 'react';
-import { LogLevelName, ROARR, Roarr } from 'roarr';
-import { createLogWriter } from '@roarr/browser-log-writer';
-
-// Base logging context includes only the package name
-const baseContext = { package: '@invoke-ai/invoke-ai-ui' };
-
-// Create browser log writer
-ROARR.write = createLogWriter();
-
-// Module-scoped logger - can be imported and used anywhere
-export let log = Roarr.child(baseContext);
-
-// Translate human-readable log levels to numbers, used for log filtering
-export const LOG_LEVEL_MAP: Record<LogLevelName, number> = {
-  trace: 10,
-  debug: 20,
-  info: 30,
-  warn: 40,
-  error: 50,
-  fatal: 60,
-};
-
-export const VALID_LOG_LEVELS = [
-  'trace',
-  'debug',
-  'info',
-  'warn',
-  'error',
-  'fatal',
-] as const;
-
-export type InvokeLogLevel = (typeof VALID_LOG_LEVELS)[number];
-
-const selector = createSelector(
-  systemSelector,
-  (system) => {
-    const { app_version, consoleLogLevel, shouldLogToConsole } = system;
-
-    return {
-      version: app_version,
-      consoleLogLevel,
-      shouldLogToConsole,
-    };
-  },
-  {
-    memoizeOptions: {
-      resultEqualityCheck: isEqual,
-    },
-  }
-);
-
-export const useLogger = () => {
-  const { version, consoleLogLevel, shouldLogToConsole } =
-    useAppSelector(selector);
-
-  // The provided Roarr browser log writer uses localStorage to config logging to console
-  useEffect(() => {
-    if (shouldLogToConsole) {
-      // Enable console log output
-      localStorage.setItem('ROARR_LOG', 'true');
-
-      // Use a filter to show only logs of the given level
-      localStorage.setItem(
-        'ROARR_FILTER',
-        `context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`
-      );
-    } else {
-      // Disable console log output
-      localStorage.setItem('ROARR_LOG', 'false');
-    }
-    ROARR.write = createLogWriter();
-  }, [consoleLogLevel, shouldLogToConsole]);
-
-  // Update the module-scoped logger context as needed
-  useEffect(() => {
-    const newContext: Record<string, any> = {
-      ...baseContext,
-    };
-
-    if (version) {
-      newContext.version = version;
-    }
-
-    log = Roarr.child(newContext);
-  }, [version]);
-
-  // Use the logger within components - no different than just importing it directly
-  return log;
-};
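The deleted hook drives Roarr entirely through `localStorage`: `ROARR_LOG` toggles console output on or off, and `ROARR_FILTER` applies a filter expression over each message's serialized context. A condensed sketch of that mechanism outside React, using the same keys (the threshold value here is just an example):

```ts
import { ROARR, Roarr } from 'roarr';
import { createLogWriter } from '@roarr/browser-log-writer';

// The browser log writer reads these keys on each write.
localStorage.setItem('ROARR_LOG', 'true'); // master on/off switch
localStorage.setItem('ROARR_FILTER', 'context.logLevel:>=30'); // info and above

ROARR.write = createLogWriter();

const log = Roarr.child({ package: '@invoke-ai/invoke-ai-ui' });
log.info('visible');
log.debug('suppressed by the filter above');
```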
@@ -4,7 +4,7 @@ import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelector
 import { generationSelector } from 'features/parameters/store/generationSelectors';
 import { systemSelector } from 'features/system/store/systemSelectors';
 import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
-import { isEqual } from 'lodash-es';
+import { isEqual } from 'lodash';
 
 export const readinessSelector = createSelector(
   [
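Both sides of this hunk rely on reselect's per-selector memoization options with lodash's deep equality; only the package name changes. For reference, a minimal selector in that style, with the result shape assumed for illustration:

```ts
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; // 'lodash-es' on the other side of this diff
import { systemSelector } from 'features/system/store/systemSelectors';

// Recomputes subscribers only when the derived object changes by deep equality,
// not merely by reference.
export const exampleReadinessSelector = createSelector(
  [systemSelector],
  (system) => ({ isProcessing: system.isProcessing }),
  { memoizeOptions: { resultEqualityCheck: isEqual } }
);
```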
@@ -1,67 +1,65 @@
-// import { createAction } from '@reduxjs/toolkit';
+import { createAction } from '@reduxjs/toolkit';
-// import * as InvokeAI from 'app/types/invokeai';
+import * as InvokeAI from 'app/invokeai';
-// import { GalleryCategory } from 'features/gallery/store/gallerySlice';
+import { GalleryCategory } from 'features/gallery/store/gallerySlice';
-// import { InvokeTabName } from 'features/ui/store/tabMap';
+import { InvokeTabName } from 'features/ui/store/tabMap';
 
-// /**
+/**
-//  * We can't use redux-toolkit's createSlice() to make these actions,
+ * We can't use redux-toolkit's createSlice() to make these actions,
-//  * because they have no associated reducer. They only exist to dispatch
+ * because they have no associated reducer. They only exist to dispatch
-//  * requests to the server via socketio. These actions will be handled
+ * requests to the server via socketio. These actions will be handled
-//  * by the middleware.
+ * by the middleware.
-//  */
+ */
 
-// export const generateImage = createAction<InvokeTabName>(
+export const generateImage = createAction<InvokeTabName>(
-//   'socketio/generateImage'
+  'socketio/generateImage'
-// );
+);
-// export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
+export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
-// export const runFacetool = createAction<InvokeAI._Image>(
+export const runFacetool = createAction<InvokeAI._Image>(
-//   'socketio/runFacetool'
+  'socketio/runFacetool'
-// );
+);
-// export const deleteImage = createAction<InvokeAI._Image>(
+export const deleteImage = createAction<InvokeAI._Image>(
-//   'socketio/deleteImage'
+  'socketio/deleteImage'
-// );
+);
-// export const requestImages = createAction<GalleryCategory>(
+export const requestImages = createAction<GalleryCategory>(
-//   'socketio/requestImages'
+  'socketio/requestImages'
-// );
+);
-// export const requestNewImages = createAction<GalleryCategory>(
+export const requestNewImages = createAction<GalleryCategory>(
-//   'socketio/requestNewImages'
+  'socketio/requestNewImages'
-// );
+);
-// export const cancelProcessing = createAction<undefined>(
+export const cancelProcessing = createAction<undefined>(
-//   'socketio/cancelProcessing'
+  'socketio/cancelProcessing'
-// );
+);
 
-// export const requestSystemConfig = createAction<undefined>(
+export const requestSystemConfig = createAction<undefined>(
-//   'socketio/requestSystemConfig'
+  'socketio/requestSystemConfig'
-// );
+);
 
-// export const searchForModels = createAction<string>('socketio/searchForModels');
+export const searchForModels = createAction<string>('socketio/searchForModels');
 
-// export const addNewModel = createAction<
+export const addNewModel = createAction<
-//   InvokeAI.InvokeModelConfigProps | InvokeAI.InvokeDiffusersModelConfigProps
+  InvokeAI.InvokeModelConfigProps | InvokeAI.InvokeDiffusersModelConfigProps
-// >('socketio/addNewModel');
+>('socketio/addNewModel');
 
-// export const deleteModel = createAction<string>('socketio/deleteModel');
+export const deleteModel = createAction<string>('socketio/deleteModel');
 
-// export const convertToDiffusers =
+export const convertToDiffusers =
-//   createAction<InvokeAI.InvokeModelConversionProps>(
+  createAction<InvokeAI.InvokeModelConversionProps>(
-//     'socketio/convertToDiffusers'
+    'socketio/convertToDiffusers'
-//   );
+  );
 
-// export const mergeDiffusersModels =
+export const mergeDiffusersModels =
-//   createAction<InvokeAI.InvokeModelMergingProps>(
+  createAction<InvokeAI.InvokeModelMergingProps>(
-//     'socketio/mergeDiffusersModels'
+    'socketio/mergeDiffusersModels'
-//   );
+  );
 
-// export const requestModelChange = createAction<string>(
+export const requestModelChange = createAction<string>(
-//   'socketio/requestModelChange'
+  'socketio/requestModelChange'
-// );
+);
 
-// export const saveStagingAreaImageToGallery = createAction<string>(
+export const saveStagingAreaImageToGallery = createAction<string>(
-//   'socketio/saveStagingAreaImageToGallery'
+  'socketio/saveStagingAreaImageToGallery'
-// );
+);
 
-// export const emptyTempFolder = createAction<undefined>(
+export const emptyTempFolder = createAction<undefined>(
-//   'socketio/requestEmptyTempFolder'
+  'socketio/requestEmptyTempFolder'
-// );
 
-export default {};
+);
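These are reducer-less actions: `createAction` only manufactures typed action creators whose type strings the socketio middleware matches on. A sketch of how the middleware side might match them, using the real action creators from the hunk above inside a hypothetical forwarding middleware:

```ts
import { AnyAction, Middleware } from '@reduxjs/toolkit';
import { cancelProcessing, generateImage } from './actions';

// Illustrative middleware: forward matched actions to the socket, pass the rest on.
export const makeSocketioForwarder =
  (emit: (event: string, payload?: unknown) => void): Middleware =>
  () =>
  (next) =>
  (action: AnyAction) => {
    if (generateImage.match(action)) {
      emit('generateImage', action.payload); // payload is the active tab name
    } else if (cancelProcessing.match(action)) {
      emit('cancel');
    }
    return next(action);
  };
```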
@@ -1,209 +1,208 @@
-// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
+import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
-// import * as InvokeAI from 'app/types/invokeai';
+import * as InvokeAI from 'app/invokeai';
-// import type { RootState } from 'app/store/store';
+import type { RootState } from 'app/store';
-// import {
+import {
-//   frontendToBackendParameters,
+  frontendToBackendParameters,
-//   FrontendToBackendParametersConfig,
+  FrontendToBackendParametersConfig,
-// } from 'common/util/parameterTranslation';
+} from 'common/util/parameterTranslation';
-// import dateFormat from 'dateformat';
+import dateFormat from 'dateformat';
-// import {
+import {
-//   GalleryCategory,
+  GalleryCategory,
-//   GalleryState,
+  GalleryState,
-//   removeImage,
+  removeImage,
-// } from 'features/gallery/store/gallerySlice';
+} from 'features/gallery/store/gallerySlice';
-// import {
-//   generationRequested,
-//   modelChangeRequested,
-//   modelConvertRequested,
-//   modelMergingRequested,
-//   setIsProcessing,
-// } from 'features/system/store/systemSlice';
-// import { InvokeTabName } from 'features/ui/store/tabMap';
-// import { Socket } from 'socket.io-client';
+import {
+  addLogEntry,
+  generationRequested,
+  modelChangeRequested,
+  modelConvertRequested,
+  modelMergingRequested,
+  setIsProcessing,
+} from 'features/system/store/systemSlice';
+import { InvokeTabName } from 'features/ui/store/tabMap';
+import { Socket } from 'socket.io-client';
 
-// /**
+/**
-//  * Returns an object containing all functions which use `socketio.emit()`.
+ * Returns an object containing all functions which use `socketio.emit()`.
-//  * i.e. those which make server requests.
+ * i.e. those which make server requests.
-//  */
+ */
-// const makeSocketIOEmitters = (
+const makeSocketIOEmitters = (
-//   store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
+  store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
-//   socketio: Socket
+  socketio: Socket
-// ) => {
+) => {
-//   // We need to dispatch actions to redux and get pieces of state from the store.
+  // We need to dispatch actions to redux and get pieces of state from the store.
-//   const { dispatch, getState } = store;
+  const { dispatch, getState } = store;
 
-//   return {
+  return {
-//     emitGenerateImage: (generationMode: InvokeTabName) => {
+    emitGenerateImage: (generationMode: InvokeTabName) => {
-//       dispatch(setIsProcessing(true));
+      dispatch(setIsProcessing(true));
 
-//       const state: RootState = getState();
+      const state: RootState = getState();
 
-//       const {
+      const {
-//         generation: generationState,
+        generation: generationState,
-//         postprocessing: postprocessingState,
+        postprocessing: postprocessingState,
-//         system: systemState,
+        system: systemState,
-//         canvas: canvasState,
+        canvas: canvasState,
-//       } = state;
+      } = state;
 
-//       const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
+      const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
-//         {
+        {
-//           generationMode,
+          generationMode,
-//           generationState,
+          generationState,
-//           postprocessingState,
+          postprocessingState,
-//           canvasState,
+          canvasState,
-//           systemState,
+          systemState,
-//         };
+        };
 
-//       dispatch(generationRequested());
+      dispatch(generationRequested());
 
-//       const { generationParameters, esrganParameters, facetoolParameters } =
+      const { generationParameters, esrganParameters, facetoolParameters } =
-//         frontendToBackendParameters(frontendToBackendParametersConfig);
+        frontendToBackendParameters(frontendToBackendParametersConfig);
 
-//       socketio.emit(
+      socketio.emit(
-//         'generateImage',
+        'generateImage',
-//         generationParameters,
+        generationParameters,
-//         esrganParameters,
+        esrganParameters,
-//         facetoolParameters
+        facetoolParameters
-//       );
+      );
 
-//       // we need to truncate the init_mask base64 else it takes up the whole log
+      // we need to truncate the init_mask base64 else it takes up the whole log
-//       // TODO: handle maintaining masks for reproducibility in future
+      // TODO: handle maintaining masks for reproducibility in future
-//       if (generationParameters.init_mask) {
+      if (generationParameters.init_mask) {
-//         generationParameters.init_mask = generationParameters.init_mask
+        generationParameters.init_mask = generationParameters.init_mask
-//           .substr(0, 64)
+          .substr(0, 64)
-//           .concat('...');
+          .concat('...');
-//       }
+      }
-//       if (generationParameters.init_img) {
+      if (generationParameters.init_img) {
-//         generationParameters.init_img = generationParameters.init_img
+        generationParameters.init_img = generationParameters.init_img
-//           .substr(0, 64)
+          .substr(0, 64)
-//           .concat('...');
+          .concat('...');
-//       }
+      }
 
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `Image generation requested: ${JSON.stringify({
+          message: `Image generation requested: ${JSON.stringify({
-//             ...generationParameters,
+            ...generationParameters,
-//             ...esrganParameters,
+            ...esrganParameters,
-//             ...facetoolParameters,
+            ...facetoolParameters,
-//           })}`,
+          })}`,
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
+    emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
-//       dispatch(setIsProcessing(true));
+      dispatch(setIsProcessing(true));
 
-//       const {
+      const {
-//         postprocessing: {
+        postprocessing: {
-//           upscalingLevel,
+          upscalingLevel,
-//           upscalingDenoising,
+          upscalingDenoising,
-//           upscalingStrength,
+          upscalingStrength,
-//         },
+        },
-//       } = getState();
+      } = getState();
 
-//       const esrganParameters = {
+      const esrganParameters = {
-//         upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
+        upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
-//       };
+      };
-//       socketio.emit('runPostprocessing', imageToProcess, {
+      socketio.emit('runPostprocessing', imageToProcess, {
-//         type: 'esrgan',
+        type: 'esrgan',
-//         ...esrganParameters,
+        ...esrganParameters,
-//       });
+      });
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `ESRGAN upscale requested: ${JSON.stringify({
+          message: `ESRGAN upscale requested: ${JSON.stringify({
-//             file: imageToProcess.url,
+            file: imageToProcess.url,
-//             ...esrganParameters,
+            ...esrganParameters,
-//           })}`,
+          })}`,
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
+    emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
-//       dispatch(setIsProcessing(true));
+      dispatch(setIsProcessing(true));
 
-//       const {
+      const {
-//         postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
+        postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
-//       } = getState();
+      } = getState();
 
-//       const facetoolParameters: Record<string, unknown> = {
+      const facetoolParameters: Record<string, unknown> = {
-//         facetool_strength: facetoolStrength,
+        facetool_strength: facetoolStrength,
-//       };
+      };
 
-//       if (facetoolType === 'codeformer') {
+      if (facetoolType === 'codeformer') {
-//         facetoolParameters.codeformer_fidelity = codeformerFidelity;
+        facetoolParameters.codeformer_fidelity = codeformerFidelity;
-//       }
+      }
 
-//       socketio.emit('runPostprocessing', imageToProcess, {
+      socketio.emit('runPostprocessing', imageToProcess, {
-//         type: facetoolType,
+        type: facetoolType,
-//         ...facetoolParameters,
+        ...facetoolParameters,
-//       });
+      });
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
+          message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
-//             {
+            {
-//               file: imageToProcess.url,
+              file: imageToProcess.url,
-//               ...facetoolParameters,
+              ...facetoolParameters,
-//             }
+            }
-//           )}`,
+          )}`,
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
+    emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
-//       const { url, uuid, category, thumbnail } = imageToDelete;
+      const { url, uuid, category, thumbnail } = imageToDelete;
-//       dispatch(removeImage(imageToDelete));
+      dispatch(removeImage(imageToDelete));
-//       socketio.emit('deleteImage', url, thumbnail, uuid, category);
+      socketio.emit('deleteImage', url, thumbnail, uuid, category);
-//     },
+    },
-//     emitRequestImages: (category: GalleryCategory) => {
+    emitRequestImages: (category: GalleryCategory) => {
-//       const gallery: GalleryState = getState().gallery;
+      const gallery: GalleryState = getState().gallery;
-//       const { earliest_mtime } = gallery.categories[category];
+      const { earliest_mtime } = gallery.categories[category];
-//       socketio.emit('requestImages', category, earliest_mtime);
+      socketio.emit('requestImages', category, earliest_mtime);
-//     },
+    },
-//     emitRequestNewImages: (category: GalleryCategory) => {
+    emitRequestNewImages: (category: GalleryCategory) => {
-//       const gallery: GalleryState = getState().gallery;
+      const gallery: GalleryState = getState().gallery;
-//       const { latest_mtime } = gallery.categories[category];
+      const { latest_mtime } = gallery.categories[category];
-//       socketio.emit('requestLatestImages', category, latest_mtime);
+      socketio.emit('requestLatestImages', category, latest_mtime);
-//     },
+    },
-//     emitCancelProcessing: () => {
+    emitCancelProcessing: () => {
-//       socketio.emit('cancel');
+      socketio.emit('cancel');
-//     },
+    },
-//     emitRequestSystemConfig: () => {
+    emitRequestSystemConfig: () => {
-//       socketio.emit('requestSystemConfig');
+      socketio.emit('requestSystemConfig');
-//     },
+    },
-//     emitSearchForModels: (modelFolder: string) => {
+    emitSearchForModels: (modelFolder: string) => {
-//       socketio.emit('searchForModels', modelFolder);
+      socketio.emit('searchForModels', modelFolder);
-//     },
+    },
-//     emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
+    emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
-//       socketio.emit('addNewModel', modelConfig);
+      socketio.emit('addNewModel', modelConfig);
-//     },
+    },
-//     emitDeleteModel: (modelName: string) => {
+    emitDeleteModel: (modelName: string) => {
-//       socketio.emit('deleteModel', modelName);
+      socketio.emit('deleteModel', modelName);
-//     },
+    },
-//     emitConvertToDiffusers: (
+    emitConvertToDiffusers: (
-//       modelToConvert: InvokeAI.InvokeModelConversionProps
+      modelToConvert: InvokeAI.InvokeModelConversionProps
-//     ) => {
+    ) => {
-//       dispatch(modelConvertRequested());
+      dispatch(modelConvertRequested());
-//       socketio.emit('convertToDiffusers', modelToConvert);
+      socketio.emit('convertToDiffusers', modelToConvert);
-//     },
+    },
-//     emitMergeDiffusersModels: (
+    emitMergeDiffusersModels: (
-//       modelMergeInfo: InvokeAI.InvokeModelMergingProps
+      modelMergeInfo: InvokeAI.InvokeModelMergingProps
-//     ) => {
+    ) => {
-//       dispatch(modelMergingRequested());
+      dispatch(modelMergingRequested());
-//       socketio.emit('mergeDiffusersModels', modelMergeInfo);
+      socketio.emit('mergeDiffusersModels', modelMergeInfo);
-//     },
+    },
-//     emitRequestModelChange: (modelName: string) => {
+    emitRequestModelChange: (modelName: string) => {
-//       dispatch(modelChangeRequested());
+      dispatch(modelChangeRequested());
-//       socketio.emit('requestModelChange', modelName);
+      socketio.emit('requestModelChange', modelName);
-//     },
+    },
-//     emitSaveStagingAreaImageToGallery: (url: string) => {
+    emitSaveStagingAreaImageToGallery: (url: string) => {
-//       socketio.emit('requestSaveStagingAreaImageToGallery', url);
+      socketio.emit('requestSaveStagingAreaImageToGallery', url);
-//     },
+    },
-//     emitRequestEmptyTempFolder: () => {
+    emitRequestEmptyTempFolder: () => {
-//       socketio.emit('requestEmptyTempFolder');
+      socketio.emit('requestEmptyTempFolder');
-//     },
+    },
-//   };
+  };
-// };
+};
 
-// export default makeSocketIOEmitters;
+export default makeSocketIOEmitters;
 
-export default {};
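The factory above returns plain functions closed over `store` and `socketio`; middleware then maps incoming redux actions to these emitters. A minimal sketch of that wiring, with the shapes simplified and `makeSocketIOEmitters` declared rather than implemented:

```ts
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import { Socket } from 'socket.io-client';

type Emitters = { emitCancelProcessing: () => void };
declare function makeSocketIOEmitters(
  store: MiddlewareAPI<Dispatch<AnyAction>>,
  socketio: Socket
): Emitters;

// Each dispatched 'socketio/*' action is routed to its emitter before passing on.
export const makeMiddleware =
  (socketio: Socket) =>
  (store: MiddlewareAPI<Dispatch<AnyAction>>) => {
    const emitters = makeSocketIOEmitters(store, socketio);
    return (next: Dispatch<AnyAction>) => (action: AnyAction) => {
      if (action.type === 'socketio/cancelProcessing') {
        emitters.emitCancelProcessing();
      }
      return next(action);
    };
  };
```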
@@ -1,502 +1,501 @@
-// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
+import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
-// import dateFormat from 'dateformat';
+import dateFormat from 'dateformat';
-// import i18n from 'i18n';
+import i18n from 'i18n';
-// import { v4 as uuidv4 } from 'uuid';
+import { v4 as uuidv4 } from 'uuid';
 
-// import * as InvokeAI from 'app/types/invokeai';
+import * as InvokeAI from 'app/invokeai';
 
-// import {
-//   addToast,
-//   errorOccurred,
-//   processingCanceled,
-//   setCurrentStatus,
-//   setFoundModels,
-//   setIsCancelable,
-//   setIsConnected,
-//   setIsProcessing,
-//   setModelList,
-//   setSearchFolder,
-//   setSystemConfig,
-//   setSystemStatus,
-// } from 'features/system/store/systemSlice';
+import {
+  addLogEntry,
+  addToast,
+  errorOccurred,
+  processingCanceled,
+  setCurrentStatus,
+  setFoundModels,
+  setIsCancelable,
+  setIsConnected,
+  setIsProcessing,
+  setModelList,
+  setSearchFolder,
+  setSystemConfig,
+  setSystemStatus,
+} from 'features/system/store/systemSlice';
 
-// import {
+import {
-//   addGalleryImages,
+  addGalleryImages,
-//   addImage,
+  addImage,
-//   clearIntermediateImage,
+  clearIntermediateImage,
-//   GalleryState,
+  GalleryState,
-//   removeImage,
+  removeImage,
-//   setIntermediateImage,
+  setIntermediateImage,
-// } from 'features/gallery/store/gallerySlice';
+} from 'features/gallery/store/gallerySlice';
 
-// import type { RootState } from 'app/store/store';
+import type { RootState } from 'app/store';
-// import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
+import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
-// import {
+import {
-//   clearInitialImage,
+  clearInitialImage,
-//   initialImageSelected,
+  initialImageSelected,
-//   setInfillMethod,
+  setInfillMethod,
-//   // setInitialImage,
+  // setInitialImage,
-//   setMaskPath,
+  setMaskPath,
-// } from 'features/parameters/store/generationSlice';
+} from 'features/parameters/store/generationSlice';
-// import { tabMap } from 'features/ui/store/tabMap';
+import { tabMap } from 'features/ui/store/tabMap';
-// import {
+import {
-//   requestImages,
+  requestImages,
-//   requestNewImages,
+  requestNewImages,
-//   requestSystemConfig,
+  requestSystemConfig,
-// } from './actions';
+} from './actions';
 
-// /**
+/**
-//  * Returns an object containing listener callbacks for socketio events.
+ * Returns an object containing listener callbacks for socketio events.
-//  * TODO: This file is large, but simple. Should it be split up further?
+ * TODO: This file is large, but simple. Should it be split up further?
-//  */
+ */
-// const makeSocketIOListeners = (
+const makeSocketIOListeners = (
-//   store: MiddlewareAPI<Dispatch<AnyAction>, RootState>
+  store: MiddlewareAPI<Dispatch<AnyAction>, RootState>
-// ) => {
+) => {
-//   const { dispatch, getState } = store;
+  const { dispatch, getState } = store;
 
-//   return {
+  return {
-//     /**
+    /**
-//      * Callback to run when we receive a 'connect' event.
+     * Callback to run when we receive a 'connect' event.
-//      */
+     */
-//     onConnect: () => {
+    onConnect: () => {
-//       try {
+      try {
-//         dispatch(setIsConnected(true));
+        dispatch(setIsConnected(true));
-//         dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
+        dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
-//         dispatch(requestSystemConfig());
+        dispatch(requestSystemConfig());
-//         const gallery: GalleryState = getState().gallery;
+        const gallery: GalleryState = getState().gallery;
 
-//         if (gallery.categories.result.latest_mtime) {
+        if (gallery.categories.result.latest_mtime) {
-//           dispatch(requestNewImages('result'));
+          dispatch(requestNewImages('result'));
-//         } else {
+        } else {
-//           dispatch(requestImages('result'));
+          dispatch(requestImages('result'));
-//         }
+        }
 
-//         if (gallery.categories.user.latest_mtime) {
+        if (gallery.categories.user.latest_mtime) {
-//           dispatch(requestNewImages('user'));
+          dispatch(requestNewImages('user'));
-//         } else {
+        } else {
-//           dispatch(requestImages('user'));
+          dispatch(requestImages('user'));
-//         }
+        }
-//       } catch (e) {
+      } catch (e) {
-//         console.error(e);
+        console.error(e);
-//       }
+      }
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'disconnect' event.
+     * Callback to run when we receive a 'disconnect' event.
-//      */
+     */
-//     onDisconnect: () => {
+    onDisconnect: () => {
-//       try {
+      try {
-//         dispatch(setIsConnected(false));
+        dispatch(setIsConnected(false));
-//         dispatch(setCurrentStatus(i18n.t('common.statusDisconnected')));
+        dispatch(setCurrentStatus(i18n.t('common.statusDisconnected')));
 
-//         dispatch(
+        dispatch(
-//           addLogEntry({
+          addLogEntry({
-//             timestamp: dateFormat(new Date(), 'isoDateTime'),
+            timestamp: dateFormat(new Date(), 'isoDateTime'),
-//             message: `Disconnected from server`,
+            message: `Disconnected from server`,
-//             level: 'warning',
+            level: 'warning',
-//           })
+          })
-//         );
+        );
-//       } catch (e) {
+      } catch (e) {
-//         console.error(e);
+        console.error(e);
-//       }
+      }
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'generationResult' event.
+     * Callback to run when we receive a 'generationResult' event.
-//      */
+     */
-//     onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
+    onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
-//       try {
+      try {
-//         const state = getState();
+        const state = getState();
-//         const { activeTab } = state.ui;
+        const { activeTab } = state.ui;
-//         const { shouldLoopback } = state.postprocessing;
+        const { shouldLoopback } = state.postprocessing;
-//         const { boundingBox: _, generationMode, ...rest } = data;
+        const { boundingBox: _, generationMode, ...rest } = data;
 
-//         const newImage = {
+        const newImage = {
-//           uuid: uuidv4(),
+          uuid: uuidv4(),
-//           ...rest,
+          ...rest,
-//         };
+        };
 
-//         if (['txt2img', 'img2img'].includes(generationMode)) {
+        if (['txt2img', 'img2img'].includes(generationMode)) {
-//           dispatch(
+          dispatch(
-//             addImage({
+            addImage({
-//               category: 'result',
+              category: 'result',
-//               image: { ...newImage, category: 'result' },
+              image: { ...newImage, category: 'result' },
-//             })
+            })
-//           );
+          );
-//         }
+        }
 
-//         if (generationMode === 'unifiedCanvas' && data.boundingBox) {
+        if (generationMode === 'unifiedCanvas' && data.boundingBox) {
-//           const { boundingBox } = data;
+          const { boundingBox } = data;
-//           dispatch(
+          dispatch(
-//             addImageToStagingArea({
+            addImageToStagingArea({
-//               image: { ...newImage, category: 'temp' },
+              image: { ...newImage, category: 'temp' },
-//               boundingBox,
+              boundingBox,
-//             })
+            })
-//           );
+          );
 
-//           if (state.canvas.shouldAutoSave) {
+          if (state.canvas.shouldAutoSave) {
-//             dispatch(
+            dispatch(
-//               addImage({
+              addImage({
-//                 image: { ...newImage, category: 'result' },
+                image: { ...newImage, category: 'result' },
-//                 category: 'result',
+                category: 'result',
-//               })
+              })
-//             );
+            );
-//           }
+          }
-//         }
+        }
 
-//         // TODO: fix
+        // TODO: fix
-//         // if (shouldLoopback) {
+        // if (shouldLoopback) {
-//         //   const activeTabName = tabMap[activeTab];
+        //   const activeTabName = tabMap[activeTab];
-//         //   switch (activeTabName) {
+        //   switch (activeTabName) {
-//         //     case 'img2img': {
+        //     case 'img2img': {
-//         //       dispatch(initialImageSelected(newImage.uuid));
+        //       dispatch(initialImageSelected(newImage.uuid));
-//         //       // dispatch(setInitialImage(newImage));
+        //       // dispatch(setInitialImage(newImage));
-//         //       break;
+        //       break;
-//         //     }
+        //     }
-//         //   }
+        //   }
-//         // }
+        // }
 
-//         dispatch(clearIntermediateImage());
+        dispatch(clearIntermediateImage());
 
-//         dispatch(
+        dispatch(
-//           addLogEntry({
+          addLogEntry({
-//             timestamp: dateFormat(new Date(), 'isoDateTime'),
+            timestamp: dateFormat(new Date(), 'isoDateTime'),
-//             message: `Image generated: ${data.url}`,
+            message: `Image generated: ${data.url}`,
-//           })
+          })
-//         );
+        );
-//       } catch (e) {
+      } catch (e) {
-//         console.error(e);
+        console.error(e);
-//       }
+      }
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'intermediateResult' event.
+     * Callback to run when we receive a 'intermediateResult' event.
-//      */
+     */
-//     onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
+    onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
-//       try {
+      try {
-//         dispatch(
+        dispatch(
-//           setIntermediateImage({
+          setIntermediateImage({
-//             uuid: uuidv4(),
+            uuid: uuidv4(),
-//             ...data,
+            ...data,
-//             category: 'result',
+            category: 'result',
-//           })
+          })
-//         );
+        );
-//         if (!data.isBase64) {
+        if (!data.isBase64) {
-//           dispatch(
+          dispatch(
-//             addLogEntry({
+            addLogEntry({
-//               timestamp: dateFormat(new Date(), 'isoDateTime'),
+              timestamp: dateFormat(new Date(), 'isoDateTime'),
-//               message: `Intermediate image generated: ${data.url}`,
+              message: `Intermediate image generated: ${data.url}`,
-//             })
+            })
-//           );
+          );
-//         }
+        }
-//       } catch (e) {
+      } catch (e) {
-//         console.error(e);
+        console.error(e);
-//       }
+      }
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive an 'esrganResult' event.
+     * Callback to run when we receive an 'esrganResult' event.
-//      */
+     */
-//     onPostprocessingResult: (data: InvokeAI.ImageResultResponse) => {
+    onPostprocessingResult: (data: InvokeAI.ImageResultResponse) => {
-//       try {
+      try {
-//         dispatch(
+        dispatch(
-//           addImage({
+          addImage({
-//             category: 'result',
+            category: 'result',
-//             image: {
+            image: {
-//               uuid: uuidv4(),
+              uuid: uuidv4(),
-//               ...data,
+              ...data,
-//               category: 'result',
+              category: 'result',
-//             },
+            },
-//           })
+          })
-//         );
+        );
 
-//         dispatch(
+        dispatch(
-//           addLogEntry({
+          addLogEntry({
-//             timestamp: dateFormat(new Date(), 'isoDateTime'),
+            timestamp: dateFormat(new Date(), 'isoDateTime'),
-//             message: `Postprocessed: ${data.url}`,
+            message: `Postprocessed: ${data.url}`,
-//           })
+          })
-//         );
+        );
-//       } catch (e) {
+      } catch (e) {
-//         console.error(e);
+        console.error(e);
-//       }
+      }
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'progressUpdate' event.
+     * Callback to run when we receive a 'progressUpdate' event.
-//      * TODO: Add additional progress phases
+     * TODO: Add additional progress phases
-//      */
+     */
-//     onProgressUpdate: (data: InvokeAI.SystemStatus) => {
+    onProgressUpdate: (data: InvokeAI.SystemStatus) => {
-//       try {
+      try {
-//         dispatch(setIsProcessing(true));
+        dispatch(setIsProcessing(true));
-//         dispatch(setSystemStatus(data));
+        dispatch(setSystemStatus(data));
-//       } catch (e) {
+      } catch (e) {
-//         console.error(e);
+        console.error(e);
-//       }
+      }
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'progressUpdate' event.
+     * Callback to run when we receive a 'progressUpdate' event.
-//      */
+     */
-//     onError: (data: InvokeAI.ErrorResponse) => {
+    onError: (data: InvokeAI.ErrorResponse) => {
-//       const { message, additionalData } = data;
+      const { message, additionalData } = data;
 
-//       if (additionalData) {
+      if (additionalData) {
-//         // TODO: handle more data than short message
+        // TODO: handle more data than short message
-//       }
+      }
 
-//       try {
+      try {
-//         dispatch(
+        dispatch(
-//           addLogEntry({
+          addLogEntry({
-//             timestamp: dateFormat(new Date(), 'isoDateTime'),
+            timestamp: dateFormat(new Date(), 'isoDateTime'),
-//             message: `Server error: ${message}`,
+            message: `Server error: ${message}`,
-//             level: 'error',
+            level: 'error',
-//           })
+          })
-//         );
+        );
-//         dispatch(errorOccurred());
+        dispatch(errorOccurred());
-//         dispatch(clearIntermediateImage());
+        dispatch(clearIntermediateImage());
-//       } catch (e) {
+      } catch (e) {
-//         console.error(e);
+        console.error(e);
-//       }
+      }
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'galleryImages' event.
+     * Callback to run when we receive a 'galleryImages' event.
-//      */
+     */
-//     onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
+    onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
-//       const { images, areMoreImagesAvailable, category } = data;
+      const { images, areMoreImagesAvailable, category } = data;
 
-//       /**
+      /**
-//        * the logic here ideally would be in the reducer but we have a side effect:
+       * the logic here ideally would be in the reducer but we have a side effect:
-//        * generating a uuid. so the logic needs to be here, outside redux.
+       * generating a uuid. so the logic needs to be here, outside redux.
-//        */
+       */
 
-//       // Generate a UUID for each image
+      // Generate a UUID for each image
-//       const preparedImages = images.map((image): InvokeAI._Image => {
+      const preparedImages = images.map((image): InvokeAI._Image => {
-//         return {
+        return {
-//           uuid: uuidv4(),
+          uuid: uuidv4(),
-//           ...image,
+          ...image,
-//         };
+        };
-//       });
+      });
 
-//       dispatch(
+      dispatch(
-//         addGalleryImages({
+        addGalleryImages({
-//           images: preparedImages,
+          images: preparedImages,
-//           areMoreImagesAvailable,
+          areMoreImagesAvailable,
-//           category,
+          category,
-//         })
+        })
-//       );
+      );
 
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `Loaded ${images.length} images`,
+          message: `Loaded ${images.length} images`,
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'processingCanceled' event.
+     * Callback to run when we receive a 'processingCanceled' event.
-//      */
+     */
-//     onProcessingCanceled: () => {
+    onProcessingCanceled: () => {
-//       dispatch(processingCanceled());
+      dispatch(processingCanceled());
 
-//       const { intermediateImage } = getState().gallery;
+      const { intermediateImage } = getState().gallery;
 
-//       if (intermediateImage) {
+      if (intermediateImage) {
-//         if (!intermediateImage.isBase64) {
+        if (!intermediateImage.isBase64) {
-//           dispatch(
+          dispatch(
-//             addImage({
+            addImage({
-//               category: 'result',
+              category: 'result',
-//               image: intermediateImage,
+              image: intermediateImage,
-//             })
+            })
-//           );
+          );
-//           dispatch(
+          dispatch(
-//             addLogEntry({
+            addLogEntry({
-//               timestamp: dateFormat(new Date(), 'isoDateTime'),
+              timestamp: dateFormat(new Date(), 'isoDateTime'),
-//               message: `Intermediate image saved: ${intermediateImage.url}`,
+              message: `Intermediate image saved: ${intermediateImage.url}`,
-//             })
+            })
-//           );
+          );
-//         }
+        }
-//         dispatch(clearIntermediateImage());
+        dispatch(clearIntermediateImage());
-//       }
+      }
 
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `Processing canceled`,
+          message: `Processing canceled`,
-//           level: 'warning',
+          level: 'warning',
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     /**
+    /**
-//      * Callback to run when we receive a 'imageDeleted' event.
+     * Callback to run when we receive a 'imageDeleted' event.
-//      */
+     */
-//     onImageDeleted: (data: InvokeAI.ImageDeletedResponse) => {
+    onImageDeleted: (data: InvokeAI.ImageDeletedResponse) => {
-//       const { url } = data;
+      const { url } = data;
 
-//       // remove image from gallery
+      // remove image from gallery
-//       dispatch(removeImage(data));
+      dispatch(removeImage(data));
 
-//       // remove references to image in options
+      // remove references to image in options
-//       const {
+      const {
-//         generation: { initialImage, maskPath },
+        generation: { initialImage, maskPath },
-//       } = getState();
+      } = getState();
 
-//       if (
+      if (
-//         initialImage === url ||
+        initialImage === url ||
-//         (initialImage as InvokeAI._Image)?.url === url
+        (initialImage as InvokeAI._Image)?.url === url
-//       ) {
+      ) {
-//         dispatch(clearInitialImage());
+        dispatch(clearInitialImage());
-//       }
+      }
 
-//       if (maskPath === url) {
+      if (maskPath === url) {
-//         dispatch(setMaskPath(''));
+        dispatch(setMaskPath(''));
-//       }
+      }
 
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `Image deleted: ${url}`,
+          message: `Image deleted: ${url}`,
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     onSystemConfig: (data: InvokeAI.SystemConfig) => {
+    onSystemConfig: (data: InvokeAI.SystemConfig) => {
-//       dispatch(setSystemConfig(data));
+      dispatch(setSystemConfig(data));
-//       if (!data.infill_methods.includes('patchmatch')) {
+      if (!data.infill_methods.includes('patchmatch')) {
-//         dispatch(setInfillMethod(data.infill_methods[0]));
+        dispatch(setInfillMethod(data.infill_methods[0]));
-//       }
+      }
-//     },
+    },
-//     onFoundModels: (data: InvokeAI.FoundModelResponse) => {
+    onFoundModels: (data: InvokeAI.FoundModelResponse) => {
-//       const { search_folder, found_models } = data;
+      const { search_folder, found_models } = data;
-//       dispatch(setSearchFolder(search_folder));
+      dispatch(setSearchFolder(search_folder));
-//       dispatch(setFoundModels(found_models));
+      dispatch(setFoundModels(found_models));
-//     },
+    },
-//     onNewModelAdded: (data: InvokeAI.ModelAddedResponse) => {
+    onNewModelAdded: (data: InvokeAI.ModelAddedResponse) => {
-//       const { new_model_name, model_list, update } = data;
+      const { new_model_name, model_list, update } = data;
-//       dispatch(setModelList(model_list));
+      dispatch(setModelList(model_list));
-//       dispatch(setIsProcessing(false));
+      dispatch(setIsProcessing(false));
-//       dispatch(setCurrentStatus(i18n.t('modelManager.modelAdded')));
+      dispatch(setCurrentStatus(i18n.t('modelManager.modelAdded')));
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `Model Added: ${new_model_name}`,
+          message: `Model Added: ${new_model_name}`,
-//           level: 'info',
+          level: 'info',
-//         })
+        })
-//       );
+      );
-//       dispatch(
+      dispatch(
-//         addToast({
+        addToast({
-//           title: !update
+          title: !update
-//             ? `${i18n.t('modelManager.modelAdded')}: ${new_model_name}`
+            ? `${i18n.t('modelManager.modelAdded')}: ${new_model_name}`
-//             : `${i18n.t('modelManager.modelUpdated')}: ${new_model_name}`,
+            : `${i18n.t('modelManager.modelUpdated')}: ${new_model_name}`,
-//           status: 'success',
+          status: 'success',
-//           duration: 2500,
+          duration: 2500,
-//           isClosable: true,
+          isClosable: true,
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     onModelDeleted: (data: InvokeAI.ModelDeletedResponse) => {
+    onModelDeleted: (data: InvokeAI.ModelDeletedResponse) => {
-//       const { deleted_model_name, model_list } = data;
+      const { deleted_model_name, model_list } = data;
-//       dispatch(setModelList(model_list));
+      dispatch(setModelList(model_list));
-//       dispatch(setIsProcessing(false));
+      dispatch(setIsProcessing(false));
-//       dispatch(
+      dispatch(
-//         addLogEntry({
+        addLogEntry({
-//           timestamp: dateFormat(new Date(), 'isoDateTime'),
+          timestamp: dateFormat(new Date(), 'isoDateTime'),
-//           message: `${i18n.t(
+          message: `${i18n.t(
-//             'modelManager.modelAdded'
+            'modelManager.modelAdded'
-//           )}: ${deleted_model_name}`,
+          )}: ${deleted_model_name}`,
-//           level: 'info',
+          level: 'info',
-//         })
+        })
-//       );
+      );
-//       dispatch(
+      dispatch(
-//         addToast({
+        addToast({
-//           title: `${i18n.t(
+          title: `${i18n.t(
-//             'modelManager.modelEntryDeleted'
+            'modelManager.modelEntryDeleted'
-//           )}: ${deleted_model_name}`,
+          )}: ${deleted_model_name}`,
-//           status: 'success',
+          status: 'success',
-//           duration: 2500,
+          duration: 2500,
-//           isClosable: true,
+          isClosable: true,
-//         })
+        })
-//       );
+      );
-//     },
+    },
-//     onModelConverted: (data: InvokeAI.ModelConvertedResponse) => {
+    onModelConverted: (data: InvokeAI.ModelConvertedResponse) => {
-//       const { converted_model_name, model_list } = data;
+      const { converted_model_name, model_list } = data;
-//       dispatch(setModelList(model_list));
+      dispatch(setModelList(model_list));
-//       dispatch(setCurrentStatus(i18n.t('common.statusModelConverted')));
||||||
// dispatch(setCurrentStatus(i18n.t('common.statusModelConverted')));
|
dispatch(setCurrentStatus(i18n.t('common.statusModelConverted')));
|
||||||
// dispatch(setIsProcessing(false));
|
dispatch(setIsProcessing(false));
|
||||||
// dispatch(setIsCancelable(true));
|
dispatch(setIsCancelable(true));
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addLogEntry({
|
addLogEntry({
|
||||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
// message: `Model converted: ${converted_model_name}`,
|
message: `Model converted: ${converted_model_name}`,
|
||||||
// level: 'info',
|
level: 'info',
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addToast({
|
addToast({
|
||||||
// title: `${i18n.t(
|
title: `${i18n.t(
|
||||||
// 'modelManager.modelConverted'
|
'modelManager.modelConverted'
|
||||||
// )}: ${converted_model_name}`,
|
)}: ${converted_model_name}`,
|
||||||
// status: 'success',
|
status: 'success',
|
||||||
// duration: 2500,
|
duration: 2500,
|
||||||
// isClosable: true,
|
isClosable: true,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => {
|
onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => {
|
||||||
// const { merged_models, merged_model_name, model_list } = data;
|
const { merged_models, merged_model_name, model_list } = data;
|
||||||
// dispatch(setModelList(model_list));
|
dispatch(setModelList(model_list));
|
||||||
// dispatch(setCurrentStatus(i18n.t('common.statusMergedModels')));
|
dispatch(setCurrentStatus(i18n.t('common.statusMergedModels')));
|
||||||
// dispatch(setIsProcessing(false));
|
dispatch(setIsProcessing(false));
|
||||||
// dispatch(setIsCancelable(true));
|
dispatch(setIsCancelable(true));
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addLogEntry({
|
addLogEntry({
|
||||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
// message: `Models merged: ${merged_models}`,
|
message: `Models merged: ${merged_models}`,
|
||||||
// level: 'info',
|
level: 'info',
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addToast({
|
addToast({
|
||||||
// title: `${i18n.t('modelManager.modelsMerged')}: ${merged_model_name}`,
|
title: `${i18n.t('modelManager.modelsMerged')}: ${merged_model_name}`,
|
||||||
// status: 'success',
|
status: 'success',
|
||||||
// duration: 2500,
|
duration: 2500,
|
||||||
// isClosable: true,
|
isClosable: true,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// onModelChanged: (data: InvokeAI.ModelChangeResponse) => {
|
onModelChanged: (data: InvokeAI.ModelChangeResponse) => {
|
||||||
// const { model_name, model_list } = data;
|
const { model_name, model_list } = data;
|
||||||
// dispatch(setModelList(model_list));
|
dispatch(setModelList(model_list));
|
||||||
// dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
|
dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
|
||||||
// dispatch(setIsProcessing(false));
|
dispatch(setIsProcessing(false));
|
||||||
// dispatch(setIsCancelable(true));
|
dispatch(setIsCancelable(true));
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addLogEntry({
|
addLogEntry({
|
||||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
// message: `Model changed: ${model_name}`,
|
message: `Model changed: ${model_name}`,
|
||||||
// level: 'info',
|
level: 'info',
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// onModelChangeFailed: (data: InvokeAI.ModelChangeResponse) => {
|
onModelChangeFailed: (data: InvokeAI.ModelChangeResponse) => {
|
||||||
// const { model_name, model_list } = data;
|
const { model_name, model_list } = data;
|
||||||
// dispatch(setModelList(model_list));
|
dispatch(setModelList(model_list));
|
||||||
// dispatch(setIsProcessing(false));
|
dispatch(setIsProcessing(false));
|
||||||
// dispatch(setIsCancelable(true));
|
dispatch(setIsCancelable(true));
|
||||||
// dispatch(errorOccurred());
|
dispatch(errorOccurred());
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addLogEntry({
|
addLogEntry({
|
||||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
// message: `Model change failed: ${model_name}`,
|
message: `Model change failed: ${model_name}`,
|
||||||
// level: 'error',
|
level: 'error',
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// onTempFolderEmptied: () => {
|
onTempFolderEmptied: () => {
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addToast({
|
addToast({
|
||||||
// title: i18n.t('toast.tempFoldersEmptied'),
|
title: i18n.t('toast.tempFoldersEmptied'),
|
||||||
// status: 'success',
|
status: 'success',
|
||||||
// duration: 2500,
|
duration: 2500,
|
||||||
// isClosable: true,
|
isClosable: true,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// };
|
};
|
||||||
// };
|
};
|
||||||
|
|
||||||
// export default makeSocketIOListeners;
|
export default makeSocketIOListeners;
|
||||||
|
|
||||||
export default {};
|
|
@ -1,248 +1,246 @@
import { Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';

import makeSocketIOEmitters from './emitters';
import makeSocketIOListeners from './listeners';

-// import * as InvokeAI from 'app/types/invokeai';
+import * as InvokeAI from 'app/invokeai';

/**
 * Creates a socketio middleware to handle communication with server.
 *
 * Special `socketio/actionName` actions are created in actions.ts and
 * exported for use by the application, which treats them like any old
 * action, using `dispatch` to dispatch them.
 *
 * These actions are intercepted here, where `socketio.emit()` calls are
 * made on their behalf - see `emitters.ts`. The emitter functions
 * are the outbound communication to the server.
 *
 * Listeners are also established here - see `listeners.ts`. The listener
 * functions receive communication from the server and usually dispatch
 * some new action to handle whatever data was sent from the server.
 */
export const socketioMiddleware = () => {
  const { origin } = new URL(window.location.href);

  const socketio = io(origin, {
    timeout: 60000,
    path: `${window.location.pathname}socket.io`,
  });

  socketio.disconnect();

  let areListenersSet = false;

  const middleware: Middleware = (store) => (next) => (action) => {
    const {
      onConnect,
      onDisconnect,
      onError,
      onPostprocessingResult,
      onGenerationResult,
      onIntermediateResult,
      onProgressUpdate,
      onGalleryImages,
      onProcessingCanceled,
      onImageDeleted,
      onSystemConfig,
      onModelChanged,
      onFoundModels,
      onNewModelAdded,
      onModelDeleted,
      onModelConverted,
      onModelsMerged,
      onModelChangeFailed,
      onTempFolderEmptied,
    } = makeSocketIOListeners(store);

    const {
      emitGenerateImage,
      emitRunESRGAN,
      emitRunFacetool,
      emitDeleteImage,
      emitRequestImages,
      emitRequestNewImages,
      emitCancelProcessing,
      emitRequestSystemConfig,
      emitSearchForModels,
      emitAddNewModel,
      emitDeleteModel,
      emitConvertToDiffusers,
      emitMergeDiffusersModels,
      emitRequestModelChange,
      emitSaveStagingAreaImageToGallery,
      emitRequestEmptyTempFolder,
    } = makeSocketIOEmitters(store, socketio);

    /**
     * If this is the first time the middleware has been called (e.g. during store setup),
     * initialize all our socket.io listeners.
     */
    if (!areListenersSet) {
      socketio.on('connect', () => onConnect());

      socketio.on('disconnect', () => onDisconnect());

      socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));

      socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
        onGenerationResult(data)
      );

      socketio.on(
        'postprocessingResult',
        (data: InvokeAI.ImageResultResponse) => onPostprocessingResult(data)
      );

      socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
        onIntermediateResult(data)
      );

      socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
        onProgressUpdate(data)
      );

      socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
        onGalleryImages(data)
      );

      socketio.on('processingCanceled', () => {
        onProcessingCanceled();
      });

      socketio.on('imageDeleted', (data: InvokeAI.ImageDeletedResponse) => {
        onImageDeleted(data);
      });

      socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
        onSystemConfig(data);
      });

      socketio.on('foundModels', (data: InvokeAI.FoundModelResponse) => {
        onFoundModels(data);
      });

      socketio.on('newModelAdded', (data: InvokeAI.ModelAddedResponse) => {
        onNewModelAdded(data);
      });

      socketio.on('modelDeleted', (data: InvokeAI.ModelDeletedResponse) => {
        onModelDeleted(data);
      });

      socketio.on('modelConverted', (data: InvokeAI.ModelConvertedResponse) => {
        onModelConverted(data);
      });

      socketio.on('modelsMerged', (data: InvokeAI.ModelsMergedResponse) => {
        onModelsMerged(data);
      });

      socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => {
        onModelChanged(data);
      });

      socketio.on('modelChangeFailed', (data: InvokeAI.ModelChangeResponse) => {
        onModelChangeFailed(data);
      });

      socketio.on('tempFolderEmptied', () => {
        onTempFolderEmptied();
      });

      areListenersSet = true;
    }

    /**
     * Handle redux actions caught by middleware.
     */
    switch (action.type) {
      case 'socketio/generateImage': {
        emitGenerateImage(action.payload);
        break;
      }

      case 'socketio/runESRGAN': {
        emitRunESRGAN(action.payload);
        break;
      }

      case 'socketio/runFacetool': {
        emitRunFacetool(action.payload);
        break;
      }

      case 'socketio/deleteImage': {
        emitDeleteImage(action.payload);
        break;
      }

      case 'socketio/requestImages': {
        emitRequestImages(action.payload);
        break;
      }

      case 'socketio/requestNewImages': {
        emitRequestNewImages(action.payload);
        break;
      }

      case 'socketio/cancelProcessing': {
        emitCancelProcessing();
        break;
      }

      case 'socketio/requestSystemConfig': {
        emitRequestSystemConfig();
        break;
      }

      case 'socketio/searchForModels': {
        emitSearchForModels(action.payload);
        break;
      }

      case 'socketio/addNewModel': {
        emitAddNewModel(action.payload);
        break;
      }

      case 'socketio/deleteModel': {
        emitDeleteModel(action.payload);
        break;
      }

      case 'socketio/convertToDiffusers': {
        emitConvertToDiffusers(action.payload);
        break;
      }

      case 'socketio/mergeDiffusersModels': {
        emitMergeDiffusersModels(action.payload);
        break;
      }

      case 'socketio/requestModelChange': {
        emitRequestModelChange(action.payload);
        break;
      }

      case 'socketio/saveStagingAreaImageToGallery': {
        emitSaveStagingAreaImageToGallery(action.payload);
        break;
      }

      case 'socketio/requestEmptyTempFolder': {
        emitRequestEmptyTempFolder();
        break;
      }
    }

    next(action);
  };

  return middleware;
};

-export default {};
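The `socketio/*` cases above are plain Redux actions that never reach a reducer; the middleware swallows them and calls the matching emitter. A minimal usage sketch (illustrative only; the real action creators live in `actions.ts` and follow the `socketio/actionName` convention the comment describes):

// Illustrative sketch - dispatching a `socketio/*` action is intercepted by the
// middleware above, which calls the matching emitter instead of a reducer.
import { store } from 'app/store';

// Ask the server for its system config; the middleware routes this to
// emitRequestSystemConfig(), and the 'systemConfig' listener handles the reply.
store.dispatch({ type: 'socketio/requestSystemConfig' });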
@ -13,23 +13,21 @@ import lightboxReducer from 'features/lightbox/store/lightboxSlice';
 import generationReducer from 'features/parameters/store/generationSlice';
 import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
 import systemReducer from 'features/system/store/systemSlice';
-import configReducer from 'features/system/store/configSlice';
 import uiReducer from 'features/ui/store/uiSlice';
-import hotkeysReducer from 'features/ui/store/hotkeysSlice';
 import modelsReducer from 'features/system/store/modelSlice';
 import nodesReducer from 'features/nodes/store/nodesSlice';

-import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist';
-import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
-import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
-import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
-import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
-import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
-import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
-import { systemDenylist } from 'features/system/store/systemPersistDenylist';
-import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
-import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
-import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
+import { socketioMiddleware } from './socketio/middleware';
+import { socketMiddleware } from 'services/events/middleware';
+import { canvasBlacklist } from 'features/canvas/store/canvasPersistBlacklist';
+import { galleryBlacklist } from 'features/gallery/store/galleryPersistBlacklist';
+import { generationBlacklist } from 'features/parameters/store/generationPersistBlacklist';
+import { lightboxBlacklist } from 'features/lightbox/store/lightboxPersistBlacklist';
+import { modelsBlacklist } from 'features/system/store/modelsPersistBlacklist';
+import { nodesBlacklist } from 'features/nodes/store/nodesPersistBlacklist';
+import { postprocessingBlacklist } from 'features/parameters/store/postprocessingPersistBlacklist';
+import { systemBlacklist } from 'features/system/store/systemPersistsBlacklist';
+import { uiBlacklist } from 'features/ui/store/uiPersistBlacklist';

 /**
  * redux-persist provides an easy and reliable way to persist state across reloads.
@ -40,9 +38,9 @@ import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
  * - Connection/processing status
  * - Availability of external libraries like ESRGAN/GFPGAN
  *
- * These can be denylisted in redux-persist.
+ * These can be blacklisted in redux-persist.
  *
- * The necesssary nested persistors with denylists are configured below.
+ * The necesssary nested persistors with blacklists are configured below.
  */

 const rootReducer = combineReducers({
@ -55,10 +53,8 @@ const rootReducer = combineReducers({
   postprocessing: postprocessingReducer,
   results: resultsReducer,
   system: systemReducer,
-  config: configReducer,
   ui: uiReducer,
   uploads: uploadsReducer,
-  hotkeys: hotkeysReducer,
 });

 const rootPersistConfig = getPersistConfig({
@ -66,34 +62,33 @@ const rootPersistConfig = getPersistConfig({
   storage,
   rootReducer,
   blacklist: [
-    ...canvasDenylist,
-    ...galleryDenylist,
-    ...generationDenylist,
-    ...lightboxDenylist,
-    ...modelsDenylist,
-    ...nodesDenylist,
-    ...postprocessingDenylist,
-    // ...resultsDenylist,
+    ...canvasBlacklist,
+    ...galleryBlacklist,
+    ...generationBlacklist,
+    ...lightboxBlacklist,
+    ...modelsBlacklist,
+    ...nodesBlacklist,
+    ...postprocessingBlacklist,
+    // ...resultsBlacklist,
     'results',
-    ...systemDenylist,
-    ...uiDenylist,
-    // ...uploadsDenylist,
+    ...systemBlacklist,
+    ...uiBlacklist,
+    // ...uploadsBlacklist,
     'uploads',
-    'hotkeys',
-    'config',
   ],
+  debounce: 300,
 });

 const persistedReducer = persistReducer(rootPersistConfig, rootReducer);

 // TODO: rip the old middleware out when nodes is complete
-// export function buildMiddleware() {
-//   if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
-//     return socketMiddleware();
-//   } else {
-//     return socketioMiddleware();
-//   }
-// }
+export function buildMiddleware() {
+  if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
+    return socketMiddleware();
+  } else {
+    return socketioMiddleware();
+  }
+}

 export const store = configureStore({
   reducer: persistedReducer,
@ -113,7 +108,6 @@ export const store = configureStore({
         'canvas/setBoundingBoxDimensions',
         'canvas/setIsDrawing',
         'canvas/addPointToCurrentLine',
-        'socket/generatorProgress',
       ],
     },
 });
@ -1,5 +1,5 @@
 import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
-import { AppDispatch, RootState } from 'app/store/store';
+import { AppDispatch, RootState } from './store';

 // Use throughout your app instead of plain `useDispatch` and `useSelector`
 export const useAppDispatch: () => AppDispatch = useDispatch;
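A brief usage sketch of the pre-typed hooks (illustrative only; the selector path is an example, not part of this diff):

// In any component: useAppSelector is a TypedUseSelectorHook<RootState>, so
// `state` needs no annotation, and useAppDispatch returns the typed AppDispatch.
const isProcessing = useAppSelector((state) => state.system.isProcessing);
const dispatch = useAppDispatch();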
@ -1,5 +1,5 @@
 import { createAsyncThunk } from '@reduxjs/toolkit';
-import { AppDispatch, RootState } from 'app/store/store';
+import { AppDispatch, RootState } from './store';

 // https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
 export const createAppAsyncThunk = createAsyncThunk.withTypes<{
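Following the linked Redux Toolkit pattern, the pre-typed helper gives every thunk a typed `getState`/`dispatch` without repeating the generics at each call site. A minimal sketch (the thunk name and body are hypothetical):

// Hypothetical thunk for illustration: inside the payload creator,
// getState() is typed as RootState and dispatch as AppDispatch, with no
// per-thunk generic annotations needed.
const exampleThunk = createAppAsyncThunk(
  'example/doWork',
  async (_arg: void, { getState }) => {
    const state = getState(); // RootState
    return state.system;
  }
);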
@ -1,393 +0,0 @@
/**
 * Types for images, the things they are made of, and the things
 * they make up.
 *
 * Generated images are txt2img and img2img images. They may have
 * had additional postprocessing done on them when they were first
 * generated.
 *
 * Postprocessed images are images which were not generated here
 * but only postprocessed by the app. They only get postprocessing
 * metadata and have a different image type, e.g. 'esrgan' or
 * 'gfpgan'.
 */

import { GalleryCategory } from 'features/gallery/store/gallerySlice';
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
import { O } from 'ts-toolbelt';

/**
 * TODO:
 * Once an image has been generated, if it is postprocessed again,
 * additional postprocessing steps are added to its postprocessing
 * array.
 *
 * TODO: Better documentation of types.
 */

export type PromptItem = {
  prompt: string;
  weight: number;
};

// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type
export type Prompt = Array<PromptItem> | string;

export type SeedWeightPair = {
  seed: number;
  weight: number;
};

export type SeedWeights = Array<SeedWeightPair>;

// All generated images contain these metadata.
export type CommonGeneratedImageMetadata = {
  postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
  sampler:
    | 'ddim'
    | 'k_dpm_2_a'
    | 'k_dpm_2'
    | 'k_dpmpp_2_a'
    | 'k_dpmpp_2'
    | 'k_euler_a'
    | 'k_euler'
    | 'k_heun'
    | 'k_lms'
    | 'plms';
  prompt: Prompt;
  seed: number;
  variations: SeedWeights;
  steps: number;
  cfg_scale: number;
  width: number;
  height: number;
  seamless: boolean;
  hires_fix: boolean;
  extra: null | Record<string, never>; // Pending development of RFC #266
};

// txt2img and img2img images have some unique attributes.
export type Txt2ImgMetadata = CommonGeneratedImageMetadata & {
  type: 'txt2img';
};

export type Img2ImgMetadata = CommonGeneratedImageMetadata & {
  type: 'img2img';
  orig_hash: string;
  strength: number;
  fit: boolean;
  init_image_path: string;
  mask_image_path?: string;
};

// Superset of generated image metadata types.
export type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;

// All post processed images contain these metadata.
export type CommonPostProcessedImageMetadata = {
  orig_path: string;
  orig_hash: string;
};

// esrgan and gfpgan images have some unique attributes.
export type ESRGANMetadata = CommonPostProcessedImageMetadata & {
  type: 'esrgan';
  scale: 2 | 4;
  strength: number;
  denoise_str: number;
};

export type FacetoolMetadata = CommonPostProcessedImageMetadata & {
  type: 'gfpgan' | 'codeformer';
  strength: number;
  fidelity?: number;
};

// Superset of all postprocessed image metadata types..
export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;

// Metadata includes the system config and image metadata.
// export type Metadata = SystemGenerationMetadata & {
//   image: GeneratedImageMetadata | PostProcessedImageMetadata;
// };

/**
 * ResultImage
 */
export type Image = {
  name: string;
  type: ImageType;
  url: string;
  thumbnail: string;
  metadata: ImageResponseMetadata;
};

/**
 * Types related to the system status.
 */

// // This represents the processing status of the backend.
// export type SystemStatus = {
//   isProcessing: boolean;
//   currentStep: number;
//   totalSteps: number;
//   currentIteration: number;
//   totalIterations: number;
//   currentStatus: string;
//   currentStatusHasSteps: boolean;
//   hasError: boolean;
// };

// export type SystemGenerationMetadata = {
//   model: string;
//   model_weights?: string;
//   model_id?: string;
//   model_hash: string;
//   app_id: string;
//   app_version: string;
// };

// export type SystemConfig = SystemGenerationMetadata & {
//   model_list: ModelList;
//   infill_methods: string[];
// };

export type ModelStatus = 'active' | 'cached' | 'not loaded';

export type Model = {
  status: ModelStatus;
  description: string;
  weights: string;
  config?: string;
  vae?: string;
  width?: number;
  height?: number;
  default?: boolean;
  format?: string;
};

export type DiffusersModel = {
  status: ModelStatus;
  description: string;
  repo_id?: string;
  path?: string;
  vae?: {
    repo_id?: string;
    path?: string;
  };
  format?: string;
  default?: boolean;
};

export type ModelList = Record<string, Model & DiffusersModel>;

export type FoundModel = {
  name: string;
  location: string;
};

export type InvokeModelConfigProps = {
  name: string | undefined;
  description: string | undefined;
  config: string | undefined;
  weights: string | undefined;
  vae: string | undefined;
  width: number | undefined;
  height: number | undefined;
  default: boolean | undefined;
  format: string | undefined;
};

export type InvokeDiffusersModelConfigProps = {
  name: string | undefined;
  description: string | undefined;
  repo_id: string | undefined;
  path: string | undefined;
  default: boolean | undefined;
  format: string | undefined;
  vae: {
    repo_id: string | undefined;
    path: string | undefined;
  };
};

export type InvokeModelConversionProps = {
  model_name: string;
  save_location: string;
  custom_location: string | null;
};

export type InvokeModelMergingProps = {
  models_to_merge: string[];
  alpha: number;
  interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
  force: boolean;
  merged_model_name: string;
  model_merge_save_path: string | null;
};

/**
 * These types type data received from the server via socketio.
 */

export type ModelChangeResponse = {
  model_name: string;
  model_list: ModelList;
};

export type ModelConvertedResponse = {
  converted_model_name: string;
  model_list: ModelList;
};

export type ModelsMergedResponse = {
  merged_models: string[];
  merged_model_name: string;
  model_list: ModelList;
};

export type ModelAddedResponse = {
  new_model_name: string;
  model_list: ModelList;
  update: boolean;
};

export type ModelDeletedResponse = {
  deleted_model_name: string;
  model_list: ModelList;
};

export type FoundModelResponse = {
  search_folder: string;
  found_models: FoundModel[];
};

// export type SystemStatusResponse = SystemStatus;

// export type SystemConfigResponse = SystemConfig;

export type ImageResultResponse = Omit<_Image, 'uuid'> & {
  boundingBox?: IRect;
  generationMode: InvokeTabName;
};

export type ImageUploadResponse = {
  // image: Omit<Image, 'uuid' | 'metadata' | 'category'>;
  url: string;
  mtime: number;
  width: number;
  height: number;
  thumbnail: string;
  // bbox: [number, number, number, number];
};

export type ErrorResponse = {
  message: string;
  additionalData?: string;
};

export type ImageUrlResponse = {
  url: string;
};

export type UploadOutpaintingMergeImagePayload = {
  dataURL: string;
  name: string;
};

/**
 * A disable-able application feature
 */
export type AppFeature =
  | 'faceRestore'
  | 'upscaling'
  | 'lightbox'
  | 'modelManager'
  | 'githubLink'
  | 'discordLink'
  | 'bugLink'
  | 'localization';

/**
 * A disable-able Stable Diffusion feature
 */
export type StableDiffusionFeature =
  | 'noiseConfig'
  | 'variations'
  | 'symmetry'
  | 'tiling'
  | 'hires';

/**
 * Configuration options for the InvokeAI UI.
 * Distinct from system settings which may be changed inside the app.
 */
export type AppConfig = {
  /**
   * Whether or not URLs should be transformed to use a different host
   */
  shouldTransformUrls: boolean;
  /**
   * Whether or not we need to re-fetch images
   */
  shouldFetchImages: boolean;
  disabledTabs: InvokeTabName[];
  disabledFeatures: AppFeature[];
  canRestoreDeletedImagesFromBin: boolean;
  sd: {
    iterations: {
      initial: number;
      min: number;
      sliderMax: number;
      inputMax: number;
      fineStep: number;
      coarseStep: number;
    };
    width: {
      initial: number;
      min: number;
      sliderMax: number;
      inputMax: number;
      fineStep: number;
      coarseStep: number;
    };
    height: {
      initial: number;
      min: number;
      sliderMax: number;
      inputMax: number;
      fineStep: number;
      coarseStep: number;
    };
    steps: {
      initial: number;
      min: number;
      sliderMax: number;
      inputMax: number;
      fineStep: number;
      coarseStep: number;
    };
    guidance: {
      initial: number;
      min: number;
      sliderMax: number;
      inputMax: number;
      fineStep: number;
      coarseStep: number;
    };
    img2imgStrength: {
      initial: number;
      min: number;
      sliderMax: number;
      inputMax: number;
      fineStep: number;
      coarseStep: number;
    };
  };
};

export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;
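To illustrate the TECHDEBT note in the deleted file above, both of the following are valid `Prompt` values (illustrative snippet, not part of the diff):

// A Prompt may be a plain string or an array of weighted PromptItem entries.
const plainPrompt: Prompt = 'a lighthouse at dawn';
const weightedPrompt: Prompt = [
  { prompt: 'a lighthouse at dawn', weight: 0.8 },
  { prompt: 'oil painting', weight: 0.2 },
];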
25
invokeai/frontend/web/src/app/utils.ts
Normal file

@ -0,0 +1,25 @@
export function keepGUIAlive() {
  async function getRequest(url = '') {
    const response = await fetch(url, {
      method: 'GET',
      cache: 'no-cache',
    });
    return response;
  }

  const keepAliveServer = () => {
    const url = document.location;
    const route = '/flaskwebgui-keep-server-alive';
    getRequest(url + route).then((data) => {
      return data;
    });
  };

  if (!import.meta.env.NODE_ENV || import.meta.env.NODE_ENV === 'production') {
    document.addEventListener('DOMContentLoaded', () => {
      const intervalRequest = 3 * 1000;
      keepAliveServer();
      setInterval(keepAliveServer, intervalRequest);
    });
  }
}
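`keepGUIAlive` polls a flaskwebgui keep-alive route every three seconds so the bundled desktop server does not shut down while the page stays open. A minimal sketch of how it might be wired up (the entry-module location is an assumption, not shown in this diff):

// e.g. in the app entry module (assumed location):
import { keepGUIAlive } from 'app/utils';

keepGUIAlive(); // registers the DOMContentLoaded poller in production builds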
@ -8,7 +8,7 @@ import {
 } from '@chakra-ui/react';
 import { createSelector } from '@reduxjs/toolkit';
 import { Feature, useFeatureHelpInfo } from 'app/features';
-import { useAppSelector } from 'app/store/storeHooks';
+import { useAppSelector } from 'app/storeHooks';
 import { systemSelector } from 'features/system/store/systemSelectors';
 import { SystemState } from 'features/system/store/systemSlice';
 import { memo, ReactElement } from 'react';
@ -14,7 +14,7 @@ import {
   Tooltip,
   TooltipProps,
 } from '@chakra-ui/react';
-import { clamp } from 'lodash-es';
+import { clamp } from 'lodash';

 import { FocusEvent, memo, useEffect, useState } from 'react';
@ -16,23 +16,13 @@ type IAISelectProps = SelectProps & {
   validValues:
     | Array<number | string>
     | Array<{ key: string; value: string | number }>;
-  horizontal?: boolean;
-  spaceEvenly?: boolean;
 };
 /**
  * Customized Chakra FormControl + Select multi-part component.
  */
 const IAISelect = (props: IAISelectProps) => {
-  const {
-    label,
-    isDisabled,
-    validValues,
-    tooltip,
-    tooltipProps,
-    horizontal,
-    spaceEvenly,
-    ...rest
-  } = props;
+  const { label, isDisabled, validValues, tooltip, tooltipProps, ...rest } =
+    props;
   return (
     <FormControl
       isDisabled={isDisabled}
@ -42,28 +32,10 @@ const IAISelect = (props: IAISelectProps) => {
         e.nativeEvent.stopPropagation();
         e.nativeEvent.cancelBubble = true;
       }}
-      sx={
-        horizontal
-          ? {
-              display: 'flex',
-              flexDirection: 'row',
-              alignItems: 'center',
-              justifyContent: 'space-between',
-              gap: 4,
-            }
-          : {}
-      }
     >
-      {label && (
-        <FormLabel sx={spaceEvenly ? { flexBasis: 0, flexGrow: 1 } : {}}>
-          {label}
-        </FormLabel>
-      )}
+      {label && <FormLabel>{label}</FormLabel>}
       <Tooltip label={tooltip} {...tooltipProps}>
-        <Select
-          {...rest}
-          rootProps={{ sx: spaceEvenly ? { flexBasis: 0, flexGrow: 1 } : {} }}
-        >
+        <Select {...rest}>
           {validValues.map((opt) => {
             return typeof opt === 'string' || typeof opt === 'number' ? (
               <IAIOption key={opt} value={opt}>
@ -23,21 +23,12 @@ import {
   Tooltip,
   TooltipProps,
 } from '@chakra-ui/react';
-import { clamp } from 'lodash-es';
+import { clamp } from 'lodash';

 import { useTranslation } from 'react-i18next';
-import {
-  FocusEvent,
-  memo,
-  MouseEvent,
-  useCallback,
-  useEffect,
-  useMemo,
-  useState,
-} from 'react';
+import { FocusEvent, memo, useEffect, useMemo, useState } from 'react';
 import { BiReset } from 'react-icons/bi';
 import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
-import { roundDownToMultiple } from 'common/util/roundDownToMultiple';

 export type IAIFullSliderProps = {
   label: string;
@ -117,52 +108,31 @@ const IAISlider = (props: IAIFullSliderProps) => {
     [max, sliderNumberInputProps?.max]
   );

-  const handleSliderChange = useCallback(
-    (v: number) => {
-      onChange(v);
-    },
-    [onChange]
-  );
+  const handleSliderChange = (v: number) => {
+    onChange(v);
+  };

-  const handleInputBlur = useCallback(
-    (e: FocusEvent<HTMLInputElement>) => {
-      if (e.target.value === '') {
-        e.target.value = String(min);
-      }
-      const clamped = clamp(
-        isInteger
-          ? Math.floor(Number(e.target.value))
-          : Number(localInputValue),
-        min,
-        numberInputMax
-      );
-      const quantized = roundDownToMultiple(clamped, step);
-      onChange(quantized);
-      setLocalInputValue(quantized);
-    },
-    [isInteger, localInputValue, min, numberInputMax, onChange, step]
-  );
+  const handleInputBlur = (e: FocusEvent<HTMLInputElement>) => {
+    if (e.target.value === '') e.target.value = String(min);
+    const clamped = clamp(
+      isInteger ? Math.floor(Number(e.target.value)) : Number(localInputValue),
+      min,
+      numberInputMax
+    );
+    onChange(clamped);
+  };

-  const handleInputChange = useCallback((v: number | string) => {
-    setLocalInputValue(v);
-  }, []);
+  const handleInputChange = (v: number | string) => {
+    setLocalInputValue(v);
+  };

-  const handleResetDisable = useCallback(() => {
-    if (!handleReset) {
-      return;
-    }
-    handleReset();
-  }, [handleReset]);
+  const handleResetDisable = () => {
+    if (!handleReset) return;
+    handleReset();
+  };

-  const forceInputBlur = useCallback((e: MouseEvent) => {
-    if (e.target instanceof HTMLDivElement) {
-      e.target.focus();
-    }
-  }, []);

   return (
     <FormControl
-      onClick={forceInputBlur}
       sx={
         isCompact
           ? {
@ -233,7 +203,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
           hidden={hideTooltip}
           {...sliderTooltipProps}
         >
-          <SliderThumb {...sliderThumbProps} zIndex={0} />
+          <SliderThumb {...sliderThumbProps} />
         </Tooltip>
       </Slider>
@ -245,7 +215,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
           value={localInputValue}
           onChange={handleInputChange}
           onBlur={handleInputBlur}
-          focusInputOnChange={false}
           {...sliderNumberInputProps}
         >
           <NumberInputField
@ -268,7 +237,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
         <IAIIconButton
           size="sm"
           aria-label={t('accessibility.reset')}
-          tooltip={t('accessibility.reset')}
+          tooltip="Reset"
           icon={<BiReset />}
           isDisabled={isDisabled}
           onClick={handleResetDisable}
@ -34,9 +34,10 @@ const IAISwitch = (props: Props) => {
       display="flex"
       gap={4}
       alignItems="center"
+      justifyContent="space-between"
       {...formControlProps}
     >
-      <FormLabel my={1} flexGrow={1} {...formLabelProps}>
+      <FormLabel my={1} {...formLabelProps}>
         {label}
       </FormLabel>
       <Switch {...rest} />
@ -1,11 +1,32 @@
-import { Badge, Box, Flex } from '@chakra-ui/react';
-import { Image } from 'app/types/invokeai';
+import { Badge, Box, ButtonGroup, Flex } from '@chakra-ui/react';
+import { RootState } from 'app/store';
+import { useAppDispatch, useAppSelector } from 'app/storeHooks';
+import { clearInitialImage } from 'features/parameters/store/generationSlice';
+import { useCallback } from 'react';
+import IAIIconButton from 'common/components/IAIIconButton';
+import { FaUndo, FaUpload } from 'react-icons/fa';
+import { useTranslation } from 'react-i18next';
+import { Image } from 'app/invokeai';

 type ImageToImageOverlayProps = {
+  setIsLoaded: (isLoaded: boolean) => void;
   image: Image;
 };

-const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
+const ImageToImageOverlay = ({
+  setIsLoaded,
+  image,
+}: ImageToImageOverlayProps) => {
+  const isImageToImageEnabled = useAppSelector(
+    (state: RootState) => state.generation.isImageToImageEnabled
+  );
+  const dispatch = useAppDispatch();
+  const { t } = useTranslation();
+  const handleResetInitialImage = useCallback(() => {
+    dispatch(clearInitialImage());
+    setIsLoaded(false);
+  }, [dispatch, setIsLoaded]);
+
   return (
     <Box
       sx={{
@ -16,12 +37,34 @@ const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
         position: 'absolute',
       }}
     >
-      <Flex
+      <ButtonGroup
         sx={{
           position: 'absolute',
           top: 0,
           right: 0,
           p: 2,
+        }}
+      >
+        <IAIIconButton
+          size="sm"
+          isDisabled={!isImageToImageEnabled}
+          icon={<FaUndo />}
+          aria-label={t('accessibility.reset')}
+          onClick={handleResetInitialImage}
+        />
+        <IAIIconButton
+          size="sm"
+          isDisabled={!isImageToImageEnabled}
+          icon={<FaUpload />}
+          aria-label={t('common.upload')}
+        />
+      </ButtonGroup>
+      <Flex
+        sx={{
+          position: 'absolute',
+          bottom: 0,
+          left: 0,
+          p: 2,
           alignItems: 'flex-start',
         }}
       >
@ -1,41 +0,0 @@
import { ButtonGroup, Flex, Spacer, Text } from '@chakra-ui/react';
import IAIIconButton from 'common/components/IAIIconButton';

import { useTranslation } from 'react-i18next';
import { FaUndo, FaUpload } from 'react-icons/fa';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { clearInitialImage } from 'features/parameters/store/generationSlice';

const ImageToImageSettingsHeader = () => {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

  const handleResetInitialImage = useCallback(() => {
    dispatch(clearInitialImage());
  }, [dispatch]);

  return (
    <Flex w="full" alignItems="center">
      <Text size="sm" fontWeight={500} color="base.300">
        Image to Image
      </Text>
      <Spacer />
      <ButtonGroup>
        <IAIIconButton
          size="sm"
          icon={<FaUndo />}
          aria-label={t('accessibility.reset')}
          onClick={handleResetInitialImage}
        />
        <IAIIconButton
          size="sm"
          icon={<FaUpload />}
          aria-label={t('common.upload')}
        />
      </ButtonGroup>
    </Flex>
  );
};

export default ImageToImageSettingsHeader;
@@ -1,6 +1,6 @@
 import { Box, useToast } from '@chakra-ui/react';
 import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useAppDispatch, useAppSelector } from 'app/storeHooks';
 import useImageUploader from 'common/hooks/useImageUploader';
 import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
 import { ResourceKey } from 'i18next';
@@ -1,33 +0,0 @@
-import { Flex, Image, Spinner } from '@chakra-ui/react';
-import InvokeAILogoImage from 'assets/images/logo.png';
-import { memo } from 'react';
-
-// This component loads before the theme so we cannot use theme tokens here
-
-const Loading = () => {
-  return (
-    <Flex
-      position="relative"
-      width="100vw"
-      height="100vh"
-      alignItems="center"
-      justifyContent="center"
-      bg="#151519"
-    >
-      <Image src={InvokeAILogoImage} w="8rem" h="8rem" />
-      <Spinner
-        label="Loading"
-        color="grey"
-        position="absolute"
-        size="sm"
-        width="24px !important"
-        height="24px !important"
-        right="1.5rem"
-        bottom="1.5rem"
-        speed="1.2s"
-      />
-    </Flex>
-  );
-};
-
-export default memo(Loading);
@@ -3,17 +3,7 @@ import { FaImage } from 'react-icons/fa';
 
 const SelectImagePlaceholder = () => {
   return (
-    <Flex
-      sx={{
-        w: 'full',
-        h: 'full',
-        bg: 'base.800',
-        borderRadius: 'base',
-        alignItems: 'center',
-        justifyContent: 'center',
-        aspectRatio: '1/1',
-      }}
-    >
+    <Flex sx={{ h: 36, alignItems: 'center', justifyContent: 'center' }}>
       <Icon color="base.400" boxSize={32} as={FaImage}></Icon>
     </Flex>
   );
@@ -0,0 +1,160 @@
+// import WorkInProgress from './WorkInProgress';
+// import ReactFlow, {
+//   applyEdgeChanges,
+//   applyNodeChanges,
+//   Background,
+//   Controls,
+//   Edge,
+//   Handle,
+//   Node,
+//   NodeTypes,
+//   OnEdgesChange,
+//   OnNodesChange,
+//   Position,
+// } from 'reactflow';
+
+// import 'reactflow/dist/style.css';
+// import {
+//   Fragment,
+//   FunctionComponent,
+//   ReactNode,
+//   useCallback,
+//   useMemo,
+//   useState,
+// } from 'react';
+// import { OpenAPIV3 } from 'openapi-types';
+// import { filter, map, reduce } from 'lodash';
+// import {
+//   Box,
+//   Flex,
+//   FormControl,
+//   FormLabel,
+//   Input,
+//   Select,
+//   Switch,
+//   Text,
+//   NumberInput,
+//   NumberInputField,
+//   NumberInputStepper,
+//   NumberIncrementStepper,
+//   NumberDecrementStepper,
+//   Tooltip,
+//   chakra,
+//   Badge,
+//   Heading,
+//   VStack,
+//   HStack,
+//   Menu,
+//   MenuButton,
+//   MenuList,
+//   MenuItem,
+//   MenuItemOption,
+//   MenuGroup,
+//   MenuOptionGroup,
+//   MenuDivider,
+//   IconButton,
+// } from '@chakra-ui/react';
+// import { FaPlus } from 'react-icons/fa';
+// import {
+//   FIELD_NAMES as FIELD_NAMES,
+//   FIELDS,
+//   INVOCATION_NAMES as INVOCATION_NAMES,
+//   INVOCATIONS,
+// } from 'features/nodeEditor/constants';
+
+// console.log('invocations', INVOCATIONS);
+
+// const nodeTypes = reduce(
+//   INVOCATIONS,
+//   (acc, val, key) => {
+//     acc[key] = val.component;
+//     return acc;
+//   },
+//   {} as NodeTypes
+// );
+
+// console.log('nodeTypes', nodeTypes);
+
+// // make initial nodes one of every node for now
+// let n = 0;
+// const initialNodes = map(INVOCATIONS, (i) => ({
+//   id: i.type,
+//   type: i.title,
+//   position: { x: (n += 20), y: (n += 20) },
+//   data: {},
+// }));
+
+// console.log('initialNodes', initialNodes);
+
+// export default function NodesWIP() {
+//   const [nodes, setNodes] = useState<Node[]>([]);
+//   const [edges, setEdges] = useState<Edge[]>([]);
+
+//   const onNodesChange: OnNodesChange = useCallback(
+//     (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
+//     []
+//   );
+
+//   const onEdgesChange: OnEdgesChange = useCallback(
+//     (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
+//     []
+//   );
+
+//   return (
+//     <Box
+//       sx={{
+//         position: 'relative',
+//         width: 'full',
+//         height: 'full',
+//         borderRadius: 'md',
+//       }}
+//     >
+//       <ReactFlow
+//         nodeTypes={nodeTypes}
+//         nodes={nodes}
+//         edges={edges}
+//         onNodesChange={onNodesChange}
+//         onEdgesChange={onEdgesChange}
+//       >
+//         <Background />
+//         <Controls />
+//       </ReactFlow>
+//       <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
+//         {FIELD_NAMES.map((field) => (
+//           <Badge
+//             key={field}
+//             colorScheme={FIELDS[field].color}
+//             sx={{ userSelect: 'none' }}
+//           >
+//             {field}
+//           </Badge>
+//         ))}
+//       </HStack>
+//       <Menu>
+//         <MenuButton
+//           as={IconButton}
+//           aria-label="Options"
+//           icon={<FaPlus />}
+//           sx={{ position: 'absolute', top: 2, left: 2 }}
+//         />
+//         <MenuList>
+//           {INVOCATION_NAMES.map((name) => {
+//             const invocation = INVOCATIONS[name];
+//             return (
+//               <Tooltip
+//                 key={name}
+//                 label={invocation.description}
+//                 placement="end"
+//                 hasArrow
+//               >
+//                 <MenuItem>{invocation.title}</MenuItem>
+//               </Tooltip>
+//             );
+//           })}
+//         </MenuList>
+//       </Menu>
+//     </Box>
+//   );
+// }
+
+export default {};
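Since the new file above lands fully commented out, here is a minimal sketch of the controlled ReactFlow wiring it gestures at, with the InvokeAI-specific `INVOCATIONS`/`FIELDS` registries (assumed to live elsewhere in the tree) stripped out. It mirrors the commented source rather than adding anything new.

import { useCallback, useState } from 'react';
import ReactFlow, {
  applyEdgeChanges,
  applyNodeChanges,
  Background,
  Controls,
  Edge,
  Node,
  OnEdgesChange,
  OnNodesChange,
} from 'reactflow';
import 'reactflow/dist/style.css';

// Controlled-flow setup: React owns the node/edge state and applies
// ReactFlow's change events back onto it.
export default function MinimalFlow() {
  const [nodes, setNodes] = useState<Node[]>([]);
  const [edges, setEdges] = useState<Edge[]>([]);

  const onNodesChange: OnNodesChange = useCallback(
    (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
    []
  );
  const onEdgesChange: OnEdgesChange = useCallback(
    (changes) => setEdges((eds) => applyEdgeChanges(changes, eds)),
    []
  );

  return (
    <ReactFlow
      nodes={nodes}
      edges={edges}
      onNodesChange={onNodesChange}
      onEdgesChange={onEdgesChange}
    >
      <Background />
      <Controls />
    </ReactFlow>
  );
}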
@@ -1,39 +0,0 @@
-import { createSelector } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
-import { isEqual } from 'lodash-es';
-import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
-
-const globalHotkeysSelector = createSelector(
-  (state: RootState) => state.hotkeys,
-  (hotkeys) => {
-    const { shift } = hotkeys;
-    return { shift };
-  },
-  {
-    memoizeOptions: {
-      resultEqualityCheck: isEqual,
-    },
-  }
-);
-
-// TODO: Does not catch keypresses while focused in an input. Maybe there is a way?
-
-export const useGlobalHotkeys = () => {
-  const dispatch = useAppDispatch();
-  const { shift } = useAppSelector(globalHotkeysSelector);
-
-  useHotkeys(
-    '*',
-    () => {
-      if (isHotkeyPressed('shift')) {
-        !shift && dispatch(shiftKeyPressed(true));
-      } else {
-        shift && dispatch(shiftKeyPressed(false));
-      }
-    },
-    { keyup: true, keydown: true },
-    [shift]
-  );
-};
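The deleted hook above mirrored shift-key state into Redux via react-hotkeys-hook. For reference, a minimal sketch of the same binding pattern in isolation; the component and handler names are illustrative, not from the codebase.

import { useHotkeys } from 'react-hotkeys-hook';

// Illustrative consumer: the hook registers the listener on mount and
// cleans it up on unmount, so no manual addEventListener bookkeeping.
const UndoShortcut = ({ onUndo }: { onUndo: () => void }) => {
  useHotkeys('ctrl+z', onUndo, [onUndo]);
  return null;
};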
@@ -1,4 +1,4 @@
-import * as InvokeAI from 'app/types/invokeai';
+import * as InvokeAI from 'app/invokeai';
 import promptToString from './promptToString';
 
 export function getPromptAndNegative(inputPrompt: InvokeAI.Prompt) {
@@ -1,6 +1,5 @@
-import { RootState } from 'app/store/store';
-import { useAppSelector } from 'app/store/storeHooks';
-import { useCallback } from 'react';
+import { RootState } from 'app/store';
+import { useAppSelector } from 'app/storeHooks';
 import { OpenAPI } from 'services/api';
 
 export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
@@ -13,22 +12,17 @@ export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
 
 export const useGetUrl = () => {
   const shouldTransformUrls = useAppSelector(
-    (state: RootState) => state.config.shouldTransformUrls
+    (state: RootState) => state.system.shouldTransformUrls
   );
 
-  const getUrl = useCallback(
-    (url?: string) => {
+  return {
+    shouldTransformUrls,
+    getUrl: (url?: string) => {
       if (OpenAPI.BASE && shouldTransformUrls) {
         return [OpenAPI.BASE, url].join('/');
       }
 
       return url;
     },
-    [shouldTransformUrls]
-  );
-
-  return {
-    shouldTransformUrls,
-    getUrl,
   };
 };
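In plain terms, `getUrl` prefixes a relative image path with the API base only when URL transformation is enabled. Restated as a standalone helper for clarity; `toFullUrl` and `apiBase` are made-up names standing in for the hook and `OpenAPI.BASE`.

// Standalone restatement of the hook's URL logic, for illustration only.
const toFullUrl = (apiBase: string, url?: string, transform = false) =>
  apiBase && transform ? [apiBase, url].join('/') : url;

// e.g. toFullUrl('http://localhost:9090', 'api/v1/images/foo.png', true)
//   -> 'http://localhost:9090/api/v1/images/foo.png'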
@@ -1,5 +1,5 @@
-import { forEach, size } from 'lodash-es';
-import { ImageField, LatentsField, ConditioningField } from 'services/api';
+import { forEach, size } from 'lodash';
+import { ImageField, LatentsField } from 'services/api';
 
 const OBJECT_TYPESTRING = '[object Object]';
 const STRING_TYPESTRING = '[object String]';
@@ -74,38 +74,8 @@ const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
   };
 };
 
-const parseConditioningField = (
-  conditioningField: unknown
-): ConditioningField | undefined => {
-  // Must be an object
-  if (!isObject(conditioningField)) {
-    return;
-  }
-
-  // A ConditioningField must have a `conditioning_name`
-  if (!('conditioning_name' in conditioningField)) {
-    return;
-  }
-
-  // A ConditioningField's `conditioning_name` must be a string
-  if (typeof conditioningField.conditioning_name !== 'string') {
-    return;
-  }
-
-  // Build a valid ConditioningField
-  return {
-    conditioning_name: conditioningField.conditioning_name,
-  };
-};
-
 type NodeMetadata = {
-  [key: string]:
-    | string
-    | number
-    | boolean
-    | ImageField
-    | LatentsField
-    | ConditioningField;
+  [key: string]: string | number | boolean | ImageField | LatentsField;
 };
 
 type InvokeAIMetadata = {
@@ -131,7 +101,7 @@ export const parseNodeMetadata = (
       return;
     }
 
-    // the only valid object types are ImageField, LatentsField and ConditioningField
+    // the only valid object types are ImageField and LatentsField
    if (isObject(nodeItem)) {
      if ('image_name' in nodeItem || 'image_type' in nodeItem) {
        const imageField = parseImageField(nodeItem);
@@ -148,14 +118,6 @@ export const parseNodeMetadata = (
      }
      return;
    }
-
-    if ('conditioning_name' in nodeItem) {
-      const conditioningField = parseConditioningField(nodeItem);
-      if (conditioningField) {
-        parsed[nodeKey] = conditioningField;
-      }
-      return;
-    }
   }
 
   // otherwise we accept any string, number or boolean
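The removed `parseConditioningField` follows the same defensive-narrowing pattern as the surviving `parseImageField`/`parseLatentsField` parsers. A generic sketch of that pattern, with a made-up `NamedField` shape standing in for the real field types:

type NamedField = { name: string };

// Progressive narrowing from `unknown`: reject non-objects, missing keys,
// and wrongly-typed values, then rebuild a clean object at the end.
const parseNamedField = (value: unknown): NamedField | undefined => {
  if (typeof value !== 'object' || value === null) {
    return;
  }
  if (!('name' in value)) {
    return;
  }
  if (typeof (value as { name: unknown }).name !== 'string') {
    return;
  }
  return { name: (value as { name: string }).name };
};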
@@ -1,4 +1,4 @@
-import * as InvokeAI from 'app/types/invokeai';
+import * as InvokeAI from 'app/invokeai';
 
 const promptToString = (prompt: InvokeAI.Prompt): string => {
   if (typeof prompt === 'string') {
@@ -1,4 +1,4 @@
-import * as InvokeAI from 'app/types/invokeai';
+import * as InvokeAI from 'app/invokeai';
 
 export const stringToSeedWeights = (
   string: string
|
@ -1,9 +1,10 @@
|
|||||||
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
|
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
|
||||||
import { Provider } from 'react-redux';
|
import { Provider } from 'react-redux';
|
||||||
import { PersistGate } from 'redux-persist/integration/react';
|
import { PersistGate } from 'redux-persist/integration/react';
|
||||||
import { store } from 'app/store/store';
|
import { buildMiddleware, store } from './app/store';
|
||||||
import { persistor } from '../store/persistor';
|
import { persistor } from './persistor';
|
||||||
import { OpenAPI } from 'services/api';
|
import { OpenAPI } from 'services/api';
|
||||||
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import '@fontsource/inter/100.css';
|
import '@fontsource/inter/100.css';
|
||||||
import '@fontsource/inter/200.css';
|
import '@fontsource/inter/200.css';
|
||||||
import '@fontsource/inter/300.css';
|
import '@fontsource/inter/300.css';
|
||||||
@ -14,23 +15,33 @@ import '@fontsource/inter/700.css';
|
|||||||
import '@fontsource/inter/800.css';
|
import '@fontsource/inter/800.css';
|
||||||
import '@fontsource/inter/900.css';
|
import '@fontsource/inter/900.css';
|
||||||
|
|
||||||
import Loading from '../../common/components/Loading/Loading';
|
import Loading from './Loading';
|
||||||
|
|
||||||
|
// Localization
|
||||||
|
import './i18n';
|
||||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
|
||||||
|
|
||||||
import '../../i18n';
|
const App = lazy(() => import('./app/App'));
|
||||||
import { socketMiddleware } from 'services/events/middleware';
|
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
|
||||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
|
||||||
|
|
||||||
interface Props extends PropsWithChildren {
|
interface Props extends PropsWithChildren {
|
||||||
apiUrl?: string;
|
apiUrl?: string;
|
||||||
|
disabledPanels?: string[];
|
||||||
|
disabledTabs?: InvokeTabName[];
|
||||||
token?: string;
|
token?: string;
|
||||||
config?: PartialAppConfig;
|
shouldTransformUrls?: boolean;
|
||||||
|
shouldFetchImages?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
export default function Component({
|
||||||
|
apiUrl,
|
||||||
|
disabledPanels = [],
|
||||||
|
disabledTabs = [],
|
||||||
|
token,
|
||||||
|
children,
|
||||||
|
shouldTransformUrls,
|
||||||
|
shouldFetchImages = false,
|
||||||
|
}: Props) {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// configure API client token
|
// configure API client token
|
||||||
if (token) {
|
if (token) {
|
||||||
@ -51,22 +62,29 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
|||||||
// the `apiUrl`/`token` dynamically.
|
// the `apiUrl`/`token` dynamically.
|
||||||
|
|
||||||
// rebuild socket middleware with token and apiUrl
|
// rebuild socket middleware with token and apiUrl
|
||||||
addMiddleware(socketMiddleware());
|
addMiddleware(buildMiddleware());
|
||||||
}, [apiUrl, token]);
|
}, [apiUrl, token]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<React.StrictMode>
|
<React.StrictMode>
|
||||||
<Provider store={store}>
|
<Provider store={store}>
|
||||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||||
<React.Suspense fallback={<Loading />}>
|
<React.Suspense fallback={<Loading showText />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<App config={config}>{children}</App>
|
<App
|
||||||
|
options={{
|
||||||
|
disabledPanels,
|
||||||
|
disabledTabs,
|
||||||
|
shouldTransformUrls,
|
||||||
|
shouldFetchImages,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</App>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
</PersistGate>
|
</PersistGate>
|
||||||
</Provider>
|
</Provider>
|
||||||
</React.StrictMode>
|
</React.StrictMode>
|
||||||
);
|
);
|
||||||
};
|
}
|
||||||
|
|
||||||
export default memo(InvokeAIUI);
|
|
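A hedged sketch of how a host application might mount the refactored default export; the entry-point code, import path, and prop values here are illustrative, not part of the diff.

import { createRoot } from 'react-dom/client';
// Import path is a placeholder; resolve it however the package is consumed.
import Component from './exports';

// Illustrative host-app entry point; apiUrl and tab names are made up.
createRoot(document.getElementById('root')!).render(
  <Component apiUrl="http://localhost:9090" disabledTabs={[]} shouldTransformUrls />
);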
20
invokeai/frontend/web/src/exports.tsx
Normal file
@@ -0,0 +1,20 @@
+import Component from './component';
+
+import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
+import ThemeChanger from './features/system/components/ThemeChanger';
+import IAIPopover from './common/components/IAIPopover';
+import IAIIconButton from './common/components/IAIIconButton';
+import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
+import StatusIndicator from './features/system/components/StatusIndicator';
+import ModelSelect from 'features/system/components/ModelSelect';
+
+export default Component;
+export {
+  InvokeAiLogoComponent,
+  ThemeChanger,
+  IAIPopover,
+  IAIIconButton,
+  SettingsModal,
+  StatusIndicator,
+  ModelSelect,
+};
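And a sketch of consuming these re-exports downstream; the package specifier is an assumption, so substitute whatever name the bundle is actually published under.

// '@invoke-ai/ui' is a hypothetical package name for illustration.
import InvokeAIUI, { StatusIndicator, ThemeChanger } from '@invoke-ai/ui';

// The default export is the full app shell; the named exports let a host
// embed individual widgets on their own.
const Toolbar = () => (
  <>
    <ThemeChanger />
    <StatusIndicator />
  </>
);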
@@ -1,4 +1,4 @@
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useAppDispatch, useAppSelector } from 'app/storeHooks';
 import IAIAlertDialog from 'common/components/IAIAlertDialog';
 import IAIButton from 'common/components/IAIButton';
 import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';
@@ -1,6 +1,6 @@
 import { Box, chakra, Flex } from '@chakra-ui/react';
 import { createSelector } from '@reduxjs/toolkit';
-import { useAppSelector } from 'app/store/storeHooks';
+import { useAppSelector } from 'app/storeHooks';
 import {
   canvasSelector,
   isStagingSelector,
@@ -8,7 +8,7 @@ import {
 import Konva from 'konva';
 import { KonvaEventObject } from 'konva/lib/Node';
 import { Vector2d } from 'konva/lib/types';
-import { isEqual } from 'lodash-es';
+import { isEqual } from 'lodash';
 
 import { useCallback, useRef } from 'react';
 import { Layer, Stage } from 'react-konva';
@@ -1,6 +1,6 @@
 import { createSelector } from '@reduxjs/toolkit';
-import { useAppSelector } from 'app/store/storeHooks';
-import { isEqual } from 'lodash-es';
+import { useAppSelector } from 'app/storeHooks';
+import { isEqual } from 'lodash';
 
 import { Group, Rect } from 'react-konva';
 import { canvasSelector } from '../store/canvasSelectors';
@@ -2,10 +2,10 @@
 
 import { useToken } from '@chakra-ui/react';
 import { createSelector } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
-import { useAppSelector } from 'app/store/storeHooks';
+import { RootState } from 'app/store';
+import { useAppSelector } from 'app/storeHooks';
 import { canvasSelector } from 'features/canvas/store/canvasSelectors';
-import { isEqual, range } from 'lodash-es';
+import { isEqual, range } from 'lodash';
 
 import { ReactNode, useCallback, useLayoutEffect, useState } from 'react';
 import { Group, Line as KonvaLine } from 'react-konva';
@@ -1,10 +1,10 @@
 import { createSelector } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
-import { useAppSelector } from 'app/store/storeHooks';
+import { RootState } from 'app/store';
+import { useAppSelector } from 'app/storeHooks';
 import { useGetUrl } from 'common/util/getUrl';
 import { GalleryState } from 'features/gallery/store/gallerySlice';
 import { ImageConfig } from 'konva/lib/shapes/Image';
-import { isEqual } from 'lodash-es';
+import { isEqual } from 'lodash';
 
 import { useEffect, useState } from 'react';
 import { Image as KonvaImage } from 'react-konva';
@@ -1,12 +1,12 @@
 import { createSelector } from '@reduxjs/toolkit';
-import { useAppSelector } from 'app/store/storeHooks';
+import { useAppSelector } from 'app/storeHooks';
 import { canvasSelector } from 'features/canvas/store/canvasSelectors';
 import { RectConfig } from 'konva/lib/shapes/Rect';
 import { Rect } from 'react-konva';
 
 import { rgbaColorToString } from 'features/canvas/util/colorToString';
 import Konva from 'konva';
-import { isNumber } from 'lodash-es';
+import { isNumber } from 'lodash';
 import { useCallback, useEffect, useRef, useState } from 'react';
 
 export const canvasMaskCompositerSelector = createSelector(
@@ -1,8 +1,8 @@
 import { createSelector } from '@reduxjs/toolkit';
-import { useAppSelector } from 'app/store/storeHooks';
+import { useAppSelector } from 'app/storeHooks';
 import { canvasSelector } from 'features/canvas/store/canvasSelectors';
 import { GroupConfig } from 'konva/lib/Group';
-import { isEqual } from 'lodash-es';
+import { isEqual } from 'lodash';
 
 import { Group, Line } from 'react-konva';
 import { isCanvasMaskLine } from '../store/canvasTypes';
Some files were not shown because too many files have changed in this diff.