mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
182 Commits
dev/pytorc
...
feat/ui/no
Author | SHA1 | Date | |
---|---|---|---|
fc52cab590 | |||
be0a033b90 | |||
688d3a9453 | |||
7f2dcbb66a | |||
5cadd74a81 | |||
db0bc47f67 | |||
262126aaea | |||
1e12c9b21f | |||
162bcda49e | |||
2b53ce50e0 | |||
1933184ca7 | |||
cc1e6374a6 | |||
a65ad1b42f | |||
16f39978e9 | |||
4e5ac85567 | |||
dffdca674e | |||
0ad0efcc44 | |||
daaf41daab | |||
1a22d50269 | |||
53e3cf162a | |||
61cf59d4f6 | |||
ecb5bdaf7e | |||
f6cdff2c5b | |||
f600104e80 | |||
fff55bd991 | |||
64f044a984 | |||
a15300ac8d | |||
fb0ec1c8d0 | |||
ae172b74a4 | |||
63d10027a4 | |||
ef0773b8a3 | |||
3daaddf15b | |||
570c3fe690 | |||
cbd1a7263a | |||
7fc5fbd4ce | |||
6f6de402ad | |||
1284bab4af | |||
20bf47e9cd | |||
978bde315b | |||
caa1bf9d17 | |||
50eb02f68b | |||
d73f3adc43 | |||
116107f464 | |||
da44bb1707 | |||
f43aed677e | |||
0d051aaae2 | |||
e4e48ff995 | |||
442a6bffa4 | |||
dfb934a2d4 | |||
f94d63ec94 | |||
50357e8b4e | |||
b1240de669 | |||
75f433b9bd | |||
53a1a3eb61 | |||
65f2a7ea31 | |||
49612d69d0 | |||
77ceb950b9 | |||
bffb860535 | |||
8807089c5b | |||
7cb3b2e56e | |||
387e7f949a | |||
cf562f140c | |||
ef890058b9 | |||
442848598d | |||
80c555ef76 | |||
d729d1c100 | |||
266ce200cc | |||
eb02acb22e | |||
f4e2928ac3 | |||
48677ac10b | |||
7a4d9e18d8 | |||
e1279e63d1 | |||
bab407fc65 | |||
afb9a9589a | |||
45bc2211c8 | |||
cb185f16bc | |||
3ab32aedc0 | |||
ea334aa92a | |||
1ef2bf2d2d | |||
a165959ab5 | |||
2386d5d786 | |||
18aa0c91da | |||
3f5a443c0c | |||
b771e9a190 | |||
0ffe2c67b0 | |||
9560a2b890 | |||
8fe49fdb55 | |||
106420fba9 | |||
85f101cdc8 | |||
7155360378 | |||
793a4ddbb2 | |||
b31b8d31ad | |||
914a7f160b | |||
5819c32fb8 | |||
f118933467 | |||
e4e5409d32 | |||
35021565ff | |||
ff9c78cee7 | |||
d5b03408da | |||
97f764c7c5 | |||
b565b6b2f5 | |||
87a917b22b | |||
b1dbf5428e | |||
927a6e425d | |||
aa89be32f7 | |||
5c29af4883 | |||
85949bc5c8 | |||
85111e8d76 | |||
98ebba7ba4 | |||
891b067470 | |||
cb849995e4 | |||
156de26995 | |||
7436a9b35d | |||
aa7eaaed45 | |||
1520a9e2fc | |||
1a21edf085 | |||
8b66a737a7 | |||
183a20cfd8 | |||
1260dfcacc | |||
9ee5cb4395 | |||
3554d3568f | |||
c2a92d1254 | |||
efa6e89dc2 | |||
ba500fc3cb | |||
8446c6cd1f | |||
8e2350ec4c | |||
dd66b3bf25 | |||
dfa69d815e | |||
bf8682fd4e | |||
64b02ead37 | |||
afc2518c66 | |||
28b7b785b0 | |||
9cb592539a | |||
41a87406b3 | |||
163c075b3d | |||
c84f689766 | |||
c38a712c0b | |||
13de5edd70 | |||
090f2a839e | |||
e5cb04f309 | |||
d6faf6d5a1 | |||
b4ade3db3a | |||
7647f8899d | |||
c82d92bc82 | |||
67b13c3b70 | |||
9b93d85746 | |||
818c254cd4 | |||
23d65e7162 | |||
024fd54d0b | |||
c44c19e911 | |||
d923d1d66b | |||
1f2c1e14db | |||
07e3a0ec15 | |||
427db7c7e2 | |||
dad3a7f263 | |||
5bd0bb637f | |||
f05095770c | |||
de189f2db6 | |||
4463124bdd | |||
34402cc46a | |||
54d9833db0 | |||
5fe8cb56fc | |||
7919d81fb1 | |||
9d80b28a4f | |||
1fcd91bcc5 | |||
e456e2e63a | |||
ee41b99049 | |||
111d674e71 | |||
8f048cfbd9 | |||
7103ac6a32 | |||
f6b131e706 | |||
d1b2b99226 | |||
e356f2511b | |||
e5f8b22a43 | |||
45b84fb4bb | |||
f022c89249 | |||
ab05144716 | |||
aeb4914e67 | |||
4c339dd4b0 | |||
7268131f57 | |||
d44151d6ff | |||
1f89cf3343 |
19
.github/stale.yaml
vendored
Normal file
19
.github/stale.yaml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
# Number of days of inactivity before an issue becomes stale
|
||||
daysUntilStale: 28
|
||||
# Number of days of inactivity before a stale issue is closed
|
||||
daysUntilClose: 14
|
||||
# Issues with these labels will never be considered stale
|
||||
exemptLabels:
|
||||
- pinned
|
||||
- security
|
||||
# Label to use when marking an issue as stale
|
||||
staleLabel: stale
|
||||
# Comment to post when marking an issue as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Please
|
||||
update the ticket if this is still a problem on the latest release.
|
||||
# Comment to post when closing a stale issue. Set to `false` to disable
|
||||
closeComment: >
|
||||
Due to inactivity, this issue has been automatically closed. If this is
|
||||
still a problem on the latest release, please recreate the issue.
|
@ -84,7 +84,7 @@ installing lots of models.
|
||||
|
||||
6. Wait while the installer does its thing. After installing the software,
|
||||
the installer will launch a script that lets you configure InvokeAI and
|
||||
select a set of starting image generaiton models.
|
||||
select a set of starting image generation models.
|
||||
|
||||
7. Find the folder that InvokeAI was installed into (it is not the
|
||||
same as the unpacked zip file directory!) The default location of this
|
||||
|
@ -1,10 +1,18 @@
|
||||
# Invocations
|
||||
|
||||
Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
|
||||
Invocations represent a single operation, its inputs, and its outputs. These
|
||||
operations and their outputs can be chained together to generate and modify
|
||||
images.
|
||||
|
||||
## Creating a new invocation
|
||||
|
||||
To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
|
||||
To create a new invocation, either find the appropriate module file in
|
||||
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
|
||||
that folder. All invocations in that folder will be discovered and made
|
||||
available to the CLI and API automatically. Invocations make use of
|
||||
[typing](https://docs.python.org/3/library/typing.html) and
|
||||
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
|
||||
into the CLI and API.
|
||||
|
||||
An invocation looks like this:
|
||||
|
||||
@ -41,34 +49,54 @@ class UpscaleInvocation(BaseInvocation):
|
||||
Each portion is important to implement correctly.
|
||||
|
||||
### Class definition and type
|
||||
|
||||
```py
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
type: Literal['upscale'] = 'upscale'
|
||||
```
|
||||
All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
|
||||
|
||||
All invocations must derive from `BaseInvocation`. They should have a docstring
|
||||
that declares what they do in a single, short line. They should also have a
|
||||
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
|
||||
is what the user will type on the CLI or use in the API to create this
|
||||
invocation. The `command_name` must be unique. The `type` must be assigned to
|
||||
the value of the literal in the type hint.
|
||||
|
||||
### Inputs
|
||||
|
||||
```py
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2,4] = Field(default=2, description="The upscale level")
|
||||
```
|
||||
Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
|
||||
| Part | Value | Description |
|
||||
| ---- | ----- | ----------- |
|
||||
| Name | `strength` | This field is referred to as `strength` |
|
||||
| Type Hint | `float` | This field must be of type `float` |
|
||||
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
|
||||
|
||||
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
|
||||
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
|
||||
description, and validation information. For example:
|
||||
|
||||
The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
|
||||
| Part | Value | Description |
|
||||
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Name | `strength` | This field is referred to as `strength` |
|
||||
| Type Hint | `float` | This field must be of type `float` |
|
||||
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
|
||||
|
||||
Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.
|
||||
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
|
||||
field to be parsed with `None` as a value, which enables linking to previous
|
||||
invocations. All fields should either provide a default value or allow `None` as
|
||||
a value, so that they can be overwritten with a linked output from another
|
||||
invocation.
|
||||
|
||||
The special type `ImageField` is also used here. All images are passed as
|
||||
`ImageField`, which protects them from pydantic validation errors (since images
|
||||
only ever come from links).
|
||||
|
||||
Finally, note that for all linking, the `type` of the linked fields must match.
|
||||
If the `name` also matches, then the field can be **automatically linked** to a
|
||||
previous invocation by name and matching.
|
||||
|
||||
### Invoke Function
|
||||
|
||||
```py
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
@ -88,13 +116,22 @@ Finally, note that for all linking, the `type` of the linked fields must match.
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
```
|
||||
The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
|
||||
|
||||
Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
|
||||
The `invoke` function is the last portion of an invocation. It is provided an
|
||||
`InvocationContext` which contains services to perform work as well as a
|
||||
`session_id` for use as needed. It should return a class with output values that
|
||||
derives from `BaseInvocationOutput`.
|
||||
|
||||
Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
|
||||
Before being called, the invocation will have all of its fields set from
|
||||
defaults, inputs, and finally links (overriding in that order).
|
||||
|
||||
Assume that this invocation may be running simultaneously with other
|
||||
invocations, may be running on another machine, or in other interesting
|
||||
scenarios. If you need functionality, please provide it as a service in the
|
||||
`InvocationServices` class, and make sure it can be overridden.
|
||||
|
||||
### Outputs
|
||||
|
||||
```py
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
@ -102,4 +139,64 @@ class ImageOutput(BaseInvocationOutput):
|
||||
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
```
|
||||
Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.
|
||||
|
||||
Output classes look like an invocation class without the invoke method. Prefer
|
||||
to use an existing output class if available, and prefer to name inputs the same
|
||||
as outputs when possible, to promote automatic invocation linking.
|
||||
|
||||
## Schema Generation
|
||||
|
||||
Invocation, output and related classes are used to generate an OpenAPI schema.
|
||||
|
||||
### Required Properties
|
||||
|
||||
The schema generation treat all properties with default values as optional. This
|
||||
makes sense internally, but when when using these classes via the generated
|
||||
schema, we end up with e.g. the `ImageOutput` class having its `image` property
|
||||
marked as optional.
|
||||
|
||||
We know that this property will always be present, so the additional logic
|
||||
needed to always check if the property exists adds a lot of extraneous cruft.
|
||||
|
||||
To fix this, we can leverage `pydantic`'s
|
||||
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
|
||||
to mark properties that we know will always be present as required.
|
||||
|
||||
Here's that `ImageOutput` class, without the needed schema customisation:
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
```
|
||||
|
||||
The generated OpenAPI schema, and all clients/types generated from it, will have
|
||||
the `type` and `image` properties marked as optional, even though we know they
|
||||
will always have a value by the time we can interact with them via the API.
|
||||
|
||||
Here's the same class, but with the schema customisation added:
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The resultant schema (and any API client or types generated from it) will now
|
||||
have see `type` as string literal `"image"` and `image` as an `ImageField`
|
||||
object.
|
||||
|
||||
See this `pydantic` issue for discussion on this solution:
|
||||
<https://github.com/pydantic/pydantic/discussions/4577>
|
||||
|
@ -50,7 +50,7 @@ subset that are currently installed are found in
|
||||
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
|
||||
|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
|
||||
|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
|
||||
|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
|
||||
|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
|
||||
|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |
|
||||
|
@ -3,12 +3,16 @@
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@ -58,7 +62,9 @@ class ApiDependencies:
|
||||
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images')
|
||||
metadata = PngMetadataService()
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
@ -68,7 +74,11 @@ class ApiDependencies:
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@ -76,6 +86,8 @@ class ApiDependencies:
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
|
34
invokeai/app/api/models/images.py
Normal file
34
invokeai/app/api/models/images.py
Normal file
@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
|
||||
|
||||
class ImageResponseMetadata(BaseModel):
|
||||
"""An image's metadata. Used only in HTTP responses."""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
@ -1,11 +1,18 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from fastapi import Path, Request, UploadFile
|
||||
from fastapi import HTTPException, Path, Query, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -17,50 +24,105 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
):
|
||||
) -> FileResponse | Response:
|
||||
"""Gets a result"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail"
|
||||
)
|
||||
async def get_thumbnail(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
):
|
||||
) -> FileResponse | Response:
|
||||
"""Gets a thumbnail"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=image_type, image_name=image_name, is_thumbnail=True
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/uploads/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
404: {"description": "Session not found"},
|
||||
201: {
|
||||
"description": "The image was uploaded successfully",
|
||||
"model": ImageResponse,
|
||||
},
|
||||
415: {"description": "Image upload failed"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def upload_image(file: UploadFile, request: Request):
|
||||
async def upload_image(
|
||||
file: UploadFile, request: Request, response: Response
|
||||
) -> ImageResponse:
|
||||
if not file.content_type.startswith("image"):
|
||||
return Response(status_code=415)
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await file.read()
|
||||
|
||||
try:
|
||||
im = Image.open(contents)
|
||||
img = Image.open(io.BytesIO(contents))
|
||||
except:
|
||||
# Error opening the image
|
||||
return Response(status_code=415)
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
|
||||
return Response(
|
||||
status_code=201,
|
||||
headers={
|
||||
"Location": request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD, image_name=filename
|
||||
)
|
||||
},
|
||||
(image_path, thumbnail_path, ctime) = ApiDependencies.invoker.services.images.save(
|
||||
ImageType.UPLOAD, filename, img
|
||||
)
|
||||
|
||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||
|
||||
res = ImageResponse(
|
||||
image_type=ImageType.UPLOAD,
|
||||
image_name=filename,
|
||||
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
metadata=ImageResponseMetadata(
|
||||
created=ctime,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images",
|
||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||
)
|
||||
async def list_images(
|
||||
image_type: ImageType = Query(
|
||||
default=ImageType.RESULT, description="The type of images to get"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a list of images"""
|
||||
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
||||
return result
|
||||
|
@ -1,11 +1,17 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
|
||||
from invokeai.backend.args import Args
|
||||
|
||||
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -15,11 +21,9 @@ class VaeRepo(BaseModel):
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
|
||||
@ -29,7 +33,6 @@ class CkptModelInfo(ModelInfo):
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['diffusers'] = 'diffusers'
|
||||
|
||||
@ -37,12 +40,29 @@ class DiffusersModelInfo(ModelInfo):
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
@ -54,108 +74,61 @@ async def list_models() -> ModelsList:
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
# @socketio.on("requestSystemConfig")
|
||||
# def handle_request_capabilities():
|
||||
# print(">> System config requested")
|
||||
# config = self.get_system_config()
|
||||
# config["model_list"] = self.generate.model_manager.list_models()
|
||||
# config["infill_methods"] = infill_methods()
|
||||
# socketio.emit("systemConfig", config)
|
||||
|
||||
# @socketio.on("searchForModels")
|
||||
# def handle_search_models(search_folder: str):
|
||||
# try:
|
||||
# if not search_folder:
|
||||
# socketio.emit(
|
||||
# "foundModels",
|
||||
# {"search_folder": None, "found_models": None},
|
||||
# )
|
||||
# else:
|
||||
# (
|
||||
# search_folder,
|
||||
# found_models,
|
||||
# ) = self.generate.model_manager.search_models(search_folder)
|
||||
# socketio.emit(
|
||||
# "foundModels",
|
||||
# {"search_folder": search_folder, "found_models": found_models},
|
||||
# )
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
# print("\n")
|
||||
@models_router.post(
|
||||
"/",
|
||||
operation_id="update_model",
|
||||
responses={200: {"status": "success"}},
|
||||
)
|
||||
async def update_model(
|
||||
model_request: CreateModelRequest
|
||||
) -> CreateModelResponse:
|
||||
""" Add Model """
|
||||
model_request_info = model_request.info
|
||||
info_dict = model_request_info.dict()
|
||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
||||
|
||||
# @socketio.on("addNewModel")
|
||||
# def handle_add_model(new_model_config: dict):
|
||||
# try:
|
||||
# model_name = new_model_config["name"]
|
||||
# del new_model_config["name"]
|
||||
# model_attributes = new_model_config
|
||||
# if len(model_attributes["vae"]) == 0:
|
||||
# del model_attributes["vae"]
|
||||
# update = False
|
||||
# current_model_list = self.generate.model_manager.list_models()
|
||||
# if model_name in current_model_list:
|
||||
# update = True
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
model_name=model_request.name,
|
||||
model_attributes=info_dict,
|
||||
clobber=True,
|
||||
)
|
||||
|
||||
# print(f">> Adding New Model: {model_name}")
|
||||
return model_response
|
||||
|
||||
# self.generate.model_manager.add_model(
|
||||
# model_name=model_name,
|
||||
# model_attributes=model_attributes,
|
||||
# clobber=True,
|
||||
# )
|
||||
# self.generate.model_manager.commit(opt.conf)
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "newModelAdded",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": update,
|
||||
# },
|
||||
# )
|
||||
# print(f">> New Model Added: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
"description": "Model deleted successfully"
|
||||
},
|
||||
404: {
|
||||
"description": "Model not found"
|
||||
}
|
||||
},
|
||||
)
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# @socketio.on("deleteModel")
|
||||
# def handle_delete_model(model_name: str):
|
||||
# try:
|
||||
# print(f">> Deleting Model: {model_name}")
|
||||
# self.generate.model_manager.del_model(model_name)
|
||||
# self.generate.model_manager.commit(opt.conf)
|
||||
# updated_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelDeleted",
|
||||
# {
|
||||
# "deleted_model_name": model_name,
|
||||
# "model_list": updated_model_list,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Deleted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
# check if model exists
|
||||
print(f">> Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
print(f">> Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
# @socketio.on("requestModelChange")
|
||||
# def handle_set_model(model_name: str):
|
||||
# try:
|
||||
# print(f">> Model change requested: {model_name}")
|
||||
# model = self.generate.set_model(model_name)
|
||||
# model_list = self.generate.model_manager.list_models()
|
||||
# if model is None:
|
||||
# socketio.emit(
|
||||
# "modelChangeFailed",
|
||||
# {"model_name": model_name, "model_list": model_list},
|
||||
# )
|
||||
# else:
|
||||
# socketio.emit(
|
||||
# "modelChanged",
|
||||
# {"model_name": model_name, "model_list": model_list},
|
||||
# )
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
@ -275,5 +248,4 @@ async def list_models() -> ModelsList:
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
# except Exception as e:
|
@ -6,11 +6,41 @@ from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_t
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
@ -35,30 +65,26 @@ def add_parsers(
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
add_field_argument(command_parser, name, field)
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
|
||||
class CliContext:
|
||||
@ -66,17 +92,38 @@ class CliContext:
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
graph_nodes: dict[str, str]
|
||||
nodes_added: list[str]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
def reset(self):
|
||||
self.session = self.invoker.create_execution_state()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
# Leave defaults unchanged
|
||||
|
||||
def add_node(self, node: BaseInvocation):
|
||||
self.get_session()
|
||||
self.session.graph.add_node(node)
|
||||
self.nodes_added.append(node.id)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
def add_edge(self, edge: Edge):
|
||||
self.get_session()
|
||||
self.session.add_edge(edge)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
|
@ -13,17 +13,22 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from invokeai.app.services.metadata import PngMetadataService
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@ -58,7 +63,7 @@ def add_invocation_args(command_parser):
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@ -76,20 +81,72 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
@ -98,14 +155,14 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
)
|
||||
for field in matching_fields
|
||||
for alias in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
@ -145,6 +202,8 @@ def invoke_cli():
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
@ -156,8 +215,12 @@ def invoke_cli():
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images'),
|
||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||
metadata=metadata,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@ -165,9 +228,12 @@ def invoke_cli():
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
|
||||
@ -185,11 +251,12 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(history)
|
||||
start_id = len(context.nodes_added)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
@ -205,8 +272,24 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
@ -217,6 +300,7 @@ def invoke_cli():
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
@ -229,7 +313,7 @@ def invoke_cli():
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command
|
||||
from_node, command.command, context
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
@ -242,7 +326,7 @@ def invoke_cli():
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
@ -256,12 +340,14 @@ def invoke_cli():
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
)
|
||||
)
|
||||
|
||||
@ -270,10 +356,10 @@ def invoke_cli():
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.session.add_node(command.command)
|
||||
context.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.session.add_edge(edge)
|
||||
context.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
@ -285,7 +371,7 @@ def invoke_cli():
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -76,3 +76,56 @@ class BaseInvocation(ABC, BaseModel):
|
||||
#fmt: off
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: figure out a better way to provide these hints
|
||||
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
|
||||
class UIConfig(TypedDict, total=False):
|
||||
type_hints: Dict[
|
||||
str,
|
||||
Literal[
|
||||
"integer",
|
||||
"float",
|
||||
"boolean",
|
||||
"string",
|
||||
"enum",
|
||||
"image",
|
||||
"latents",
|
||||
"model",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
title: str
|
||||
|
||||
class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
|
||||
|
||||
class InvocationConfig(BaseModel.Config):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
|
||||
`tags`
|
||||
- A list of strings, used to categorise invocations.
|
||||
|
||||
`type_hints`
|
||||
- A dict of field types which override the types in the invocation definition.
|
||||
- Each key should be the name of one of the invocation's fields.
|
||||
- Each value should be one of the valid types:
|
||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
||||
|
||||
```python
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"initial_image": "image",
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
schema_extra: CustomisedSchemaExtra
|
||||
|
@ -1,16 +1,17 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import numpy.random
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
BaseInvocationOutput,
|
||||
)
|
||||
|
||||
|
||||
class IntCollectionOutput(BaseInvocationOutput):
|
||||
@ -33,7 +34,9 @@ class RangeInvocation(BaseInvocation):
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.stop, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
@ -43,8 +46,19 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: Optional[int] = Field(
|
||||
ge=0,
|
||||
le=np.iinfo(np.int32).max,
|
||||
description="The seed for the RNG",
|
||||
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntCollectionOutput(
|
||||
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
|
||||
)
|
||||
|
@ -5,14 +5,26 @@ from typing import Literal
|
||||
import cv2 as cv
|
||||
import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
class CvInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["cv", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
"""Simple inpaint using opencv."""
|
||||
#fmt: off
|
||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||
@ -44,7 +56,14 @@ class CvInpaintInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_inpainted)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_inpainted, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image_inpainted,
|
||||
)
|
@ -6,21 +6,36 @@ from typing import Literal, Optional, Union
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.util import diffusers_step_callback_adapter, CanceledException
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
]
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation):
|
||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
|
||||
type: Literal["txt2img"] = "txt2img"
|
||||
@ -34,7 +49,7 @@ class TextToImageInvocation(BaseInvocation):
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
@ -42,35 +57,31 @@ class TextToImageInvocation(BaseInvocation):
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# def step_callback(state: PipelineIntermediateState):
|
||||
# if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
# raise CanceledException
|
||||
# self.dispatch_progress(context, state.latents, state.step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
# (right now uses whatever current model is set in model manager)
|
||||
model= context.services.model_manager.get_model()
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
@ -86,9 +97,18 @@ class TextToImageInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, generate_output.image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(
|
||||
image_type, image_name, generate_output.image, metadata
|
||||
)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=generate_output.image,
|
||||
)
|
||||
|
||||
|
||||
@ -108,20 +128,17 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
@ -134,18 +151,23 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
mask = None
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
model = context.services.model_manager.get_model()
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
@ -160,11 +182,19 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
@ -180,20 +210,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
@ -210,18 +237,23 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
model = context.services.model_manager.get_model()
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
@ -236,7 +268,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, result_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image,
|
||||
)
|
||||
|
@ -1,70 +1,97 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from ..models.image import ImageField, ImageType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
image_type: str = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
#fmt: on
|
||||
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
||||
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
]
|
||||
"required": ["type", "image", "width", "height", "mode"]
|
||||
}
|
||||
|
||||
|
||||
def build_image_output(
|
||||
image_type: ImageType, image_name: str, image: Image.Image
|
||||
) -> ImageOutput:
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
image_field = ImageField(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=image_field,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
mode=image.mode,
|
||||
)
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'mask',
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
# TODO: this isn't really necessary anymore
|
||||
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
"""Load an image from a filename and provide it as output."""
|
||||
#fmt: off
|
||||
"""Load an image and provide it as output."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# Inputs
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
||||
image = context.services.images.get(self.image_type, self.image_name)
|
||||
|
||||
return build_image_output(
|
||||
image_type=self.image_type,
|
||||
image_name=self.image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
@ -85,16 +112,17 @@ class ShowImageInvocation(BaseInvocation):
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_type=self.image.image_type, image_name=self.image.image_name
|
||||
)
|
||||
return build_image_output(
|
||||
image_type=self.image.image_type,
|
||||
image_name=self.image.image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation):
|
||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["crop"] = "crop"
|
||||
|
||||
# Inputs
|
||||
@ -103,7 +131,7 @@ class CropImageInvocation(BaseInvocation):
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
@ -119,15 +147,23 @@ class CropImageInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_crop)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_crop, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image_crop,
|
||||
)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation):
|
||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Pastes an image into another image."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["paste"] = "paste"
|
||||
|
||||
# Inputs
|
||||
@ -136,7 +172,7 @@ class PasteImageInvocation(BaseInvocation):
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get(
|
||||
@ -149,7 +185,7 @@ class PasteImageInvocation(BaseInvocation):
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
@ -169,21 +205,29 @@ class PasteImageInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, new_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, new_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=new_image,
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get(
|
||||
@ -198,22 +242,27 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_mask)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_mask, metadata)
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
|
||||
class BlurInvocation(BaseInvocation):
|
||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Blurs an image"""
|
||||
|
||||
#fmt: off
|
||||
# fmt: off
|
||||
type: Literal["blur"] = "blur"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
@ -230,22 +279,28 @@ class BlurInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, blur_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, blur_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=blur_image
|
||||
)
|
||||
|
||||
|
||||
class LerpInvocation(BaseInvocation):
|
||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lerp"] = "lerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
@ -261,23 +316,29 @@ class LerpInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, lerp_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, lerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=lerp_image
|
||||
)
|
||||
|
||||
|
||||
class InverseLerpInvocation(BaseInvocation):
|
||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["ilerp"] = "ilerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
@ -297,7 +358,12 @@ class InverseLerpInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, ilerp_image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, ilerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=ilerp_image
|
||||
)
|
||||
|
@ -1,25 +1,26 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from torch import Tensor
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from accelerate.utils import set_seed
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Generator
|
||||
from .image import ImageField, ImageOutput, build_image_output
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline
|
||||
@ -30,6 +31,8 @@ class LatentsField(BaseModel):
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["latents_name"]}
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
@ -99,18 +102,31 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
return x
|
||||
|
||||
|
||||
def random_seed():
|
||||
return random.randint(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "noise"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
device = torch.device(CUDA_DEVICE)
|
||||
device = torch.device(choose_torch_device())
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
@ -136,48 +152,45 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, sample: Tensor, step: int
|
||||
) -> None:
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
self.id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
step,
|
||||
self.steps,
|
||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = model_manager.get_model(self.model)
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_name = model_info['model_name']
|
||||
model_hash = model_info['hash']
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model.scheduler = get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=self.sampler_name
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
@ -213,8 +226,12 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
@ -244,6 +261,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
@ -252,8 +280,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
@ -263,7 +295,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
@ -299,12 +331,23 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = context.services.model_manager.get_model(self.model)
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
|
||||
with torch.inference_mode():
|
||||
@ -315,7 +358,14 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image
|
||||
)
|
||||
|
@ -1,15 +1,22 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all math invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["math"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
@ -20,7 +27,7 @@ class IntOutput(BaseInvocationOutput):
|
||||
#fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation):
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Adds two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
@ -32,7 +39,7 @@ class AddInvocation(BaseInvocation):
|
||||
return IntOutput(a=self.a + self.b)
|
||||
|
||||
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Subtracts two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
@ -44,7 +51,7 @@ class SubtractInvocation(BaseInvocation):
|
||||
return IntOutput(a=self.a - self.b)
|
||||
|
||||
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Multiplies two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
@ -56,7 +63,7 @@ class MultiplyInvocation(BaseInvocation):
|
||||
return IntOutput(a=self.a * self.b)
|
||||
|
||||
|
||||
class DivideInvocation(BaseInvocation):
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Divides two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
|
18
invokeai/app/invocations/params.py
Normal file
18
invokeai/app/invocations/params.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
@ -1,12 +1,11 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
@ -18,6 +17,14 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||
#fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["restoration", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
@ -36,7 +43,14 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
@ -1,14 +1,12 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
@ -22,6 +20,15 @@ class UpscaleInvocation(BaseInvocation):
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
@ -40,7 +47,14 @@ class UpscaleInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
14
invokeai/app/invocations/util/choose_model.py
Normal file
14
invokeai/app/invocations/util/choose_model.py
Normal file
@ -0,0 +1,14 @@
|
||||
from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
else:
|
||||
model = model_manager.get_model()
|
||||
print(
|
||||
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
|
||||
)
|
||||
|
||||
return model
|
0
invokeai/app/models/__init__.py
Normal file
0
invokeai/app/models/__init__.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
@ -0,0 +1,3 @@
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
pass
|
29
invokeai/app/models/image.py
Normal file
29
invokeai/app/models/image.py
Normal file
@ -0,0 +1,29 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
def is_image_type(obj):
|
||||
try:
|
||||
ImageType(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: ImageType = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_type", "image_name"]}
|
56
invokeai/app/services/default_graphs.py
Normal file
56
invokeai/app/services/default_graphs.py
Normal file
@ -0,0 +1,56 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': TextToLatentsInvocation(id='4'),
|
||||
'5': LatentsToImageInvocation(id='5')
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
||||
]
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height')
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='5', field='image', alias='image')
|
||||
])
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
#if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
graphs.append(text_to_image)
|
||||
|
||||
return graphs
|
@ -1,10 +1,9 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Dict, TypedDict
|
||||
from typing import Any
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
ProgressImage = TypedDict(
|
||||
"ProgressImage", {"dataURL": str, "width": int, "height": int}
|
||||
)
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@ -14,7 +13,8 @@ class EventServiceBase:
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
@ -25,7 +25,8 @@ class EventServiceBase:
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
progress_image: ProgressImage | None,
|
||||
step: int,
|
||||
total_steps: int,
|
||||
@ -35,48 +36,60 @@ class EventServiceBase:
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation_id=invocation_id,
|
||||
progress_image=progress_image,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self, graph_execution_state_id: str, invocation_id: str, result: Dict
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation_id=invocation_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_error(
|
||||
self, graph_execution_state_id: str, invocation_id: str, error: str
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation_id=invocation_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(
|
||||
self, graph_execution_state_id: str, invocation_id: str
|
||||
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation_id=invocation_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -84,5 +97,7 @@ class EventServiceBase:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_session_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload=dict(graph_execution_state_id=graph_execution_state_id),
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
import uuid
|
||||
from types import NoneType
|
||||
from typing import (
|
||||
@ -17,7 +16,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
@ -26,7 +25,6 @@ from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
)
|
||||
from .invocation_services import InvocationServices
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
@ -215,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@ -283,7 +281,8 @@ class Graph(BaseModel):
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
"""
|
||||
|
||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
||||
self._validate_edge(edge)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
@ -354,7 +353,7 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
@ -362,54 +361,53 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
|
||||
return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@ -733,7 +731,7 @@ class Graph(BaseModel):
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
@ -750,9 +748,7 @@ class Graph(BaseModel):
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(
|
||||
description="The id of the execution state", default_factory=uuid.uuid4
|
||||
)
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
@ -794,9 +790,6 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
# Declare all fields as required; necessary for OpenAPI schema generation build.
|
||||
# Technically only fields without a `default_factory` need to be listed here.
|
||||
# See: https://github.com/pydantic/pydantic/discussions/4577
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
@ -861,7 +854,8 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
@ -949,11 +943,11 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph()
|
||||
g = self.graph.nx_graph_flat()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.nodes[n], CollectInvocation)
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
)
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
@ -965,7 +959,7 @@ class GraphExecutionState(BaseModel):
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.nodes[n], IterateInvocation)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
]
|
||||
return iterators
|
||||
|
||||
@ -1101,7 +1095,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
try:
|
||||
self.graph._validate_edge(edge)
|
||||
except InvalidEdgeError:
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
@ -1147,4 +1143,52 @@ class GraphExecutionState(BaseModel):
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
||||
|
@ -1,23 +1,24 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from PIL.Image import Image
|
||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
import PIL.Image as PILImage
|
||||
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import (
|
||||
InvokeAIMetadata,
|
||||
MetadataServiceBase,
|
||||
build_invokeai_metadata_pnginfo,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
@ -25,40 +26,66 @@ class ImageStorageBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a paginated list of images."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the path to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> Tuple[str, str, int]:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image path, thumbnail path, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
def create_name(self, context_id: str, node_id: str) -> str:
|
||||
return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
|
||||
"""Creates a unique contextual image filename."""
|
||||
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__pngWriter: PngWriter
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, Image]
|
||||
__max_cache_size: int
|
||||
__metadata_service: MetadataServiceBase
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
||||
self.__output_folder = output_folder
|
||||
self.__pngWriter = PngWriter(output_folder)
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
self.__metadata_service = metadata_service
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -71,42 +98,132 @@ class DiskImageStorage(ImageStorageBase):
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
dir_path = os.path.join(self.__output_folder, image_type)
|
||||
image_paths = glob(f"{dir_path}/*.png")
|
||||
count = len(image_paths)
|
||||
|
||||
sorted_image_paths = sorted(
|
||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
||||
)
|
||||
|
||||
page_of_image_paths = sorted_image_paths[
|
||||
page * per_page : (page + 1) * per_page
|
||||
]
|
||||
|
||||
page_of_images: List[ImageResponse] = []
|
||||
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
|
||||
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
||||
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
||||
metadata=ImageResponseMetadata(
|
||||
created=int(os.path.getctime(path)),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
page_count_trunc = int(count / per_page)
|
||||
page_count_mod = count % per_page
|
||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||
|
||||
return PaginatedResults[ImageResponse](
|
||||
items=page_of_images,
|
||||
page=page,
|
||||
pages=page_count,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
image = PILImage.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", basename
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
image_subpath = os.path.join(image_type, image_name)
|
||||
self.__pngWriter.save_image_and_prompt_to_png(
|
||||
image, "", image_subpath, None
|
||||
) # TODO: just pass full path to png writer
|
||||
save_thumbnail(
|
||||
image=image,
|
||||
filename=image_name,
|
||||
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
||||
)
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
os.stat(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> Tuple[str, str, int]:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
|
||||
# TODO: Reading the image and then saving it strips the metadata...
|
||||
if metadata:
|
||||
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path) # this saved image has an empty info
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
|
||||
return (image_path, thumbnail_path, int(os.path.getctime(image_path)))
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
os.remove(thumbnail_path)
|
||||
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
|
@ -1,30 +1,17 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# TODO: make this serializable
|
||||
class InvocationQueueItem:
|
||||
# session_id: str
|
||||
graph_execution_state_id: str
|
||||
invocation_id: str
|
||||
invoke_all: bool
|
||||
timestamp: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# session_id: str,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
invoke_all: bool = False,
|
||||
):
|
||||
# self.session_id = session_id
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.invocation_id = invocation_id
|
||||
self.invoke_all = invoke_all
|
||||
self.timestamp = time.time()
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
from .events import EventServiceBase
|
||||
@ -14,11 +15,13 @@ class InvocationServices:
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
@ -28,7 +31,9 @@ class InvocationServices:
|
||||
events: EventServiceBase,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
@ -37,7 +42,9 @@ class InvocationServices:
|
||||
self.events = events
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.metadata = metadata
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
|
96
invokeai/app/services/metadata.py
Normal file
96
invokeai/app/services/metadata.py
Normal file
@ -0,0 +1,96 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, TypedDict
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.models.image import ImageType, is_image_type
|
||||
|
||||
|
||||
class MetadataImageField(TypedDict):
|
||||
"""Pydantic-less ImageField, used for metadata parsing."""
|
||||
|
||||
image_type: ImageType
|
||||
image_name: str
|
||||
|
||||
|
||||
class MetadataLatentsField(TypedDict):
|
||||
"""Pydantic-less LatentsField, used for metadata parsing."""
|
||||
|
||||
latents_name: str
|
||||
|
||||
|
||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||
NodeMetadata = Dict[
|
||||
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
||||
]
|
||||
|
||||
|
||||
class InvokeAIMetadata(TypedDict, total=False):
|
||||
"""InvokeAI-specific metadata format."""
|
||||
|
||||
session_id: Optional[str]
|
||||
node: Optional[NodeMetadata]
|
||||
|
||||
|
||||
def build_invokeai_metadata_pnginfo(
|
||||
metadata: InvokeAIMetadata | None,
|
||||
) -> PngImagePlugin.PngInfo:
|
||||
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
|
||||
return pnginfo
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
||||
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_metadata(
|
||||
self, session_id: str, node: BaseModel
|
||||
) -> InvokeAIMetadata | None:
|
||||
"""Builds an InvokeAIMetadata object"""
|
||||
pass
|
||||
|
||||
|
||||
class PngMetadataService(MetadataServiceBase):
|
||||
"""Handles loading and building metadata for images."""
|
||||
|
||||
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
||||
def _load_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Loads a specific info entry from a PIL Image."""
|
||||
|
||||
try:
|
||||
info = image.info.get("invokeai")
|
||||
|
||||
if type(info) is not str:
|
||||
return None
|
||||
|
||||
loaded_metadata = json.loads(info)
|
||||
|
||||
if type(loaded_metadata) is not dict:
|
||||
return None
|
||||
|
||||
if len(loaded_metadata.items()) == 0:
|
||||
return None
|
||||
|
||||
return loaded_metadata
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Retrieves an image's metadata as a dict"""
|
||||
loaded_metadata = self._load_metadata(image)
|
||||
|
||||
return loaded_metadata
|
||||
|
||||
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
||||
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
||||
|
||||
return metadata
|
@ -4,7 +4,7 @@ from threading import Event, Thread
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..util.util import CanceledException
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
@ -43,10 +43,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item.invocation_id
|
||||
)
|
||||
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id
|
||||
)
|
||||
|
||||
# Invoke
|
||||
@ -75,7 +79,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
|
||||
@ -99,7 +104,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
0
invokeai/app/util/__init__.py
Normal file
0
invokeai/app/util/__init__.py
Normal file
5
invokeai/app/util/misc.py
Normal file
5
invokeai/app/util/misc.py
Normal file
@ -0,0 +1,5 @@
|
||||
import datetime
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
@ -1,25 +0,0 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def save_thumbnail(
|
||||
image: Image.Image,
|
||||
filename: str,
|
||||
path: str,
|
||||
size: int = 256,
|
||||
) -> str:
|
||||
"""
|
||||
Saves a thumbnail of an image, returning its path.
|
||||
"""
|
||||
base_filename = os.path.splitext(filename)[0]
|
||||
thumbnail_path = os.path.join(path, base_filename + ".webp")
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
return thumbnail_path
|
||||
|
||||
image_copy = image.copy()
|
||||
image_copy.thumbnail(size=(size, size))
|
||||
|
||||
image_copy.save(thumbnail_path, "WEBP")
|
||||
|
||||
return thumbnail_path
|
55
invokeai/app/util/step_callback.py
Normal file
55
invokeai/app/util/step_callback.py
Normal file
@ -0,0 +1,55 @@
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
|
||||
def stable_diffusion_step_callback(
|
||||
context: InvocationContext,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be. Use
|
||||
# that estimate if it is available.
|
||||
if intermediate_state.predicted_original is not None:
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
# TODO: This does not seem to be needed any more?
|
||||
# # txt2img provides a Tensor in the step_callback
|
||||
# # img2img provides a PipelineIntermediateState
|
||||
# if isinstance(sample, PipelineIntermediateState):
|
||||
# # this was an img2img
|
||||
# print('img2img')
|
||||
# latents = sample.latents
|
||||
# step = sample.step
|
||||
# else:
|
||||
# print('txt2img')
|
||||
# latents = sample
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
total_steps=node["steps"],
|
||||
)
|
15
invokeai/app/util/thumbnails.py
Normal file
15
invokeai/app/util/thumbnails.py
Normal file
@ -0,0 +1,15 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_thumbnail_name(image_name: str) -> str:
|
||||
"""Formats given an image name, returns the appropriate thumbnail image name"""
|
||||
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
|
||||
return thumbnail_name
|
||||
|
||||
|
||||
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
|
||||
"""Makes a thumbnail from a PIL Image"""
|
||||
thumbnail = image.copy()
|
||||
thumbnail.thumbnail(size=(size, size))
|
||||
return thumbnail
|
@ -1,42 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
class CanceledException(Exception):
|
||||
pass
|
||||
|
||||
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
step,
|
||||
steps,
|
||||
)
|
||||
|
||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||
"""
|
||||
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
||||
This adapter grabs the needed data and passes it along to the callback function.
|
||||
"""
|
||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||
progress_state: PipelineIntermediateState = cb_args[0]
|
||||
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
|
||||
else:
|
||||
return fast_latents_step_callback(*cb_args, **kwargs)
|
@ -561,7 +561,7 @@ class Args(object):
|
||||
"--autoimport",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||
help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--autoconvert",
|
||||
|
@ -67,7 +67,6 @@ def install_requested_models(
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
convert_to_diffusers: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
@ -113,7 +112,6 @@ def install_requested_models(
|
||||
try:
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
convert=convert_to_diffusers,
|
||||
commit_to_conf=config_file_path,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
@ -122,7 +120,7 @@ def install_requested_models(
|
||||
pass
|
||||
|
||||
if scan_at_startup and scan_directory.is_dir():
|
||||
argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
|
||||
argument = "--autoconvert"
|
||||
initfile = Path(Globals.root, Globals.initfile)
|
||||
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
||||
directory = str(scan_directory).replace("\\", "/")
|
||||
|
@ -7,3 +7,4 @@ from .convert_ckpt_to_diffusers import (
|
||||
)
|
||||
from .model_manager import ModelManager
|
||||
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""
|
||||
"""enum
|
||||
Manage a cache of Stable Diffusion model files for fast switching.
|
||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
@ -15,7 +15,7 @@ import sys
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union, Callable
|
||||
@ -24,8 +24,12 @@ import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import logging as dlogging
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
SchedulerMixin,
|
||||
logging as dlogging,
|
||||
)
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
@ -33,37 +37,58 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPFeatureExtractor,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from ..stable_diffusion import (
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = 1
|
||||
V1_INPAINT = 2
|
||||
V2 = 3
|
||||
V2_e = 4
|
||||
V2_v = 5
|
||||
UNKNOWN = 99
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = auto()
|
||||
V1_INPAINT = auto()
|
||||
V2 = auto()
|
||||
V2_e = auto()
|
||||
V2_v = auto()
|
||||
UNKNOWN = auto()
|
||||
|
||||
class SDModelComponent(Enum):
|
||||
vae="vae"
|
||||
text_encoder="text_encoder"
|
||||
tokenizer="tokenizer"
|
||||
unet="unet"
|
||||
scheduler="scheduler"
|
||||
safety_checker="safety_checker"
|
||||
feature_extractor="feature_extractor"
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
|
||||
class ModelManager(object):
|
||||
'''
|
||||
"""
|
||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf|Path,
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: str = "float16",
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path=None,
|
||||
self,
|
||||
config: OmegaConf | Path,
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: str = "float16",
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path = None,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file or
|
||||
an initialized OmegaConf dictionary. Optional parameters
|
||||
are the torch device type, precision, max_loaded_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
# prevent nasty-looking CLIP log message
|
||||
@ -87,15 +112,25 @@ class ModelManager(object):
|
||||
"""
|
||||
return model_name in self.config
|
||||
|
||||
def get_model(self, model_name: str=None)->dict:
|
||||
"""
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
def get_model(self, model_name: str = None) -> dict:
|
||||
"""Given a model named identified in models.yaml, return a dict
|
||||
containing the model object and some of its key features. If
|
||||
in RAM will load into GPU VRAM. If on disk, will load from
|
||||
there.
|
||||
The dict has the following keys:
|
||||
'model': The StableDiffusionGeneratorPipeline object
|
||||
'model_name': The name of the model in models.yaml
|
||||
'width': The width of images trained by this model
|
||||
'height': The height of images trained by this model
|
||||
'hash': A unique hash of this model's files on disk.
|
||||
"""
|
||||
if not model_name:
|
||||
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
|
||||
|
||||
return (
|
||||
self.get_model(self.current_model)
|
||||
if self.current_model
|
||||
else self.get_model(self.default_model())
|
||||
)
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
print(
|
||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
@ -135,6 +170,81 @@ class ModelManager(object):
|
||||
"hash": hash,
|
||||
}
|
||||
|
||||
def get_model_vae(self, model_name: str=None)->AutoencoderKL:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned VAE as an
|
||||
AutoencoderKL object. If no model name is provided, return the
|
||||
vae from the model currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.vae)
|
||||
|
||||
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPTokenizer. If no
|
||||
model name is provided, return the tokenizer from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
|
||||
|
||||
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned UNet2DConditionModel. If no model
|
||||
name is provided, return the UNet from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.unet)
|
||||
|
||||
def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPTextModel. If no
|
||||
model name is provided, return the text encoder from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.text_encoder)
|
||||
|
||||
def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPFeatureExtractor. If no
|
||||
model name is provided, return the text encoder from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
|
||||
|
||||
def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned scheduler. If no
|
||||
model name is provided, return the text encoder from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.scheduler)
|
||||
|
||||
def _get_sub_model(
|
||||
self,
|
||||
model_name: str=None,
|
||||
model_part: SDModelComponent=SDModelComponent.vae,
|
||||
) -> Union[
|
||||
AutoencoderKL,
|
||||
CLIPTokenizer,
|
||||
CLIPFeatureExtractor,
|
||||
UNet2DConditionModel,
|
||||
CLIPTextModel,
|
||||
StableDiffusionSafetyChecker,
|
||||
]:
|
||||
"""Given a model name identified in models.yaml, and the part of the
|
||||
model you wish to retrieve, return that part. Parts are in an Enum
|
||||
class named SDModelComponent, and consist of:
|
||||
SDModelComponent.vae
|
||||
SDModelComponent.text_encoder
|
||||
SDModelComponent.tokenizer
|
||||
SDModelComponent.unet
|
||||
SDModelComponent.scheduler
|
||||
SDModelComponent.safety_checker
|
||||
SDModelComponent.feature_extractor
|
||||
"""
|
||||
model_dict = self.get_model(model_name)
|
||||
model = model_dict["model"]
|
||||
return getattr(model, model_part.value)
|
||||
|
||||
def default_model(self) -> str | None:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
@ -360,7 +470,7 @@ class ModelManager(object):
|
||||
f"Unknown model format {model_name}: {model_format}"
|
||||
)
|
||||
self._add_embeddings_to_model(model)
|
||||
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
||||
@ -433,7 +543,7 @@ class ModelManager(object):
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def _load_ckpt_model(self, model_name, mconfig):
|
||||
@ -454,14 +564,18 @@ class ModelManager(object):
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
try:
|
||||
if self.list_models()[self.current_model]['status'] == 'active':
|
||||
if self.list_models()[self.current_model]["status"] == "active":
|
||||
self.offload_model(self.current_model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
vae_path = None
|
||||
if vae:
|
||||
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
|
||||
vae_path = (
|
||||
vae
|
||||
if os.path.isabs(vae)
|
||||
else os.path.normpath(os.path.join(Globals.root, vae))
|
||||
)
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
@ -571,9 +685,7 @@ class ModelManager(object):
|
||||
models.yaml file.
|
||||
"""
|
||||
model_name = model_name or Path(repo_or_path).stem
|
||||
model_description = (
|
||||
description or f"Imported diffusers model {model_name}"
|
||||
)
|
||||
model_description = description or f"Imported diffusers model {model_name}"
|
||||
new_config = dict(
|
||||
description=model_description,
|
||||
vae=vae,
|
||||
@ -602,7 +714,7 @@ class ModelManager(object):
|
||||
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
|
||||
SDLegacyType.UNKNOWN
|
||||
"""
|
||||
global_step = checkpoint.get('global_step')
|
||||
global_step = checkpoint.get("global_step")
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
|
||||
try:
|
||||
@ -628,13 +740,13 @@ class ModelManager(object):
|
||||
return SDLegacyType.UNKNOWN
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
path_url_or_repo: str,
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
model_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
config_file_callback: Callable[[Path], Path] = None,
|
||||
self,
|
||||
path_url_or_repo: str,
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
model_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
config_file_callback: Callable[[Path], Path] = None,
|
||||
) -> str:
|
||||
"""Accept a string which could be:
|
||||
- a HF diffusers repo_id
|
||||
@ -738,8 +850,8 @@ class ModelManager(object):
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
checkpoint = None
|
||||
if model_path.suffix in [".ckpt",".pt"]:
|
||||
self.scan_model(model_path,model_path)
|
||||
if model_path.suffix in [".ckpt", ".pt"]:
|
||||
self.scan_model(model_path, model_path)
|
||||
checkpoint = torch.load(model_path)
|
||||
else:
|
||||
checkpoint = safetensors.torch.load_file(model_path)
|
||||
@ -761,19 +873,16 @@ class ModelManager(object):
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
print(" | SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
|
||||
Globals.root,
|
||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
print(
|
||||
" | SD-v2-v model detected"
|
||||
)
|
||||
print(" | SD-v2-v model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
print(
|
||||
" | SD-v2-e model detected"
|
||||
)
|
||||
print(" | SD-v2-e model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
@ -820,16 +929,16 @@ class ModelManager(object):
|
||||
return model_name
|
||||
|
||||
def convert_and_import(
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae:dict=None,
|
||||
vae_path:Path=None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
scan_needed: bool=True,
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae: dict = None,
|
||||
vae_path: Path = None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
scan_needed: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Convert a legacy ckpt weights file to diffuser model and import
|
||||
@ -857,10 +966,10 @@ class ModelManager(object):
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model=None
|
||||
vae_model = None
|
||||
if vae:
|
||||
vae_model=self._load_vae(vae)
|
||||
vae_path=None
|
||||
vae_model = self._load_vae(vae)
|
||||
vae_path = None
|
||||
convert_ckpt_to_diffusers(
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
@ -976,16 +1085,16 @@ class ModelManager(object):
|
||||
legacy_locations = [
|
||||
Path(
|
||||
models_dir,
|
||||
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
|
||||
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
|
||||
),
|
||||
Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
|
||||
Path(
|
||||
models_dir,
|
||||
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
|
||||
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
|
||||
),
|
||||
]
|
||||
legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
|
||||
|
||||
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
|
||||
|
||||
legacy_layout = False
|
||||
for model in legacy_locations:
|
||||
legacy_layout = legacy_layout or model.exists()
|
||||
@ -1003,7 +1112,7 @@ class ModelManager(object):
|
||||
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
|
||||
>> Otherwise press <enter> to continue."""
|
||||
)
|
||||
input('continue> ')
|
||||
input("continue> ")
|
||||
|
||||
# transformer files get moved into the hub directory
|
||||
if cls._is_huggingface_hub_directory_present():
|
||||
@ -1090,12 +1199,12 @@ class ModelManager(object):
|
||||
print(
|
||||
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
)
|
||||
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.device.type == "cuda"
|
||||
|
||||
def _diffuser_sha256(
|
||||
self, name_or_path: Union[str, Path], chunksize=4096
|
||||
self, name_or_path: Union[str, Path], chunksize=16777216
|
||||
) -> Union[str, bytes]:
|
||||
path = None
|
||||
if isinstance(name_or_path, Path):
|
||||
|
@ -57,7 +57,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
else:
|
||||
elif Globals.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
@ -73,6 +73,8 @@ class HuggingFaceConceptsLibrary(object):
|
||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
return self.concept_list
|
||||
|
||||
def get_concept_model_path(self, concept_name: str) -> str:
|
||||
"""
|
||||
|
@ -158,14 +158,9 @@ def main():
|
||||
report_model_error(opt, e)
|
||||
|
||||
# try to autoconvert new models
|
||||
if path := opt.autoimport:
|
||||
gen.model_manager.heuristic_import(
|
||||
str(path), convert=False, commit_to_conf=opt.conf
|
||||
)
|
||||
|
||||
if path := opt.autoconvert:
|
||||
gen.model_manager.heuristic_import(
|
||||
str(path), convert=True, commit_to_conf=opt.conf
|
||||
str(path), commit_to_conf=opt.conf
|
||||
)
|
||||
|
||||
# web server loops forever
|
||||
@ -581,6 +576,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
|
||||
elif command.startswith("!replay"):
|
||||
file_path = command.replace("!replay", "", 1).strip()
|
||||
file_path = os.path.join(opt.outdir, file_path)
|
||||
if infile is None and os.path.isfile(file_path):
|
||||
infile = open(file_path, "r", encoding="utf-8")
|
||||
completer.add_history(command)
|
||||
|
@ -199,17 +199,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.convert_models = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
|
||||
values=["Keep original format", "Convert to diffusers"],
|
||||
value=0,
|
||||
begin_entry_at=4,
|
||||
max_height=4,
|
||||
hidden=True, # will appear when imported models box is edited
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.cancel = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name="CANCEL",
|
||||
@ -244,8 +233,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
self.show_directory_fields.addVisibleWhenSelected(i)
|
||||
|
||||
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
||||
self.import_model_paths.when_value_edited = self._show_hide_convert
|
||||
self.autoload_directory.when_value_edited = self._show_hide_convert
|
||||
|
||||
def resize(self):
|
||||
super().resize()
|
||||
@ -256,13 +243,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
if not self.show_directory_fields.value:
|
||||
self.autoload_directory.value = ""
|
||||
|
||||
def _show_hide_convert(self):
|
||||
model_paths = self.import_model_paths.value or ""
|
||||
autoload_directory = self.autoload_directory.value or ""
|
||||
self.convert_models.hidden = (
|
||||
len(model_paths) == 0 and len(autoload_directory) == 0
|
||||
)
|
||||
|
||||
def _get_starter_model_labels(self) -> List[str]:
|
||||
window_width, window_height = get_terminal_size()
|
||||
label_width = 25
|
||||
@ -322,7 +302,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
.scan_directory: Path to a directory of models to scan and import
|
||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
|
||||
"""
|
||||
# we're using a global here rather than storing the result in the parentapp
|
||||
# due to some bug in npyscreen that is causing attributes to be lost
|
||||
@ -359,7 +338,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
|
||||
# URLs and the like
|
||||
selections.import_model_paths = self.import_model_paths.value.split()
|
||||
selections.convert_to_diffusers = self.convert_models.value[0] == 1
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
@ -372,7 +350,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
scan_directory=None,
|
||||
autoscan_on_startup=None,
|
||||
import_model_paths=None,
|
||||
convert_to_diffusers=None,
|
||||
)
|
||||
|
||||
def onStart(self):
|
||||
@ -393,7 +370,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
||||
directory_to_scan = selections.scan_directory
|
||||
scan_at_startup = selections.autoscan_on_startup
|
||||
potential_models_to_install = selections.import_model_paths
|
||||
convert_to_diffusers = selections.convert_to_diffusers
|
||||
|
||||
install_requested_models(
|
||||
install_initial_models=models_to_install,
|
||||
@ -401,7 +377,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
||||
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
||||
external_models=potential_models_to_install,
|
||||
scan_at_startup=scan_at_startup,
|
||||
convert_to_diffusers=convert_to_diffusers,
|
||||
precision="float32"
|
||||
if opt.full_precision
|
||||
else choose_precision(torch.device(choose_torch_device())),
|
||||
|
@ -6,3 +6,5 @@ stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
|
@ -3,4 +3,8 @@ dist/
|
||||
node_modules/
|
||||
patches/
|
||||
stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
|
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
File diff suppressed because one or more lines are too long
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
|
||||
import{j as y,cO as Ie,r as _,cP as bt,q as Lr,cQ as o,cR as b,cS as v,cT as S,cU as Vr,cV as ut,cW as vt,cN as ft,cX as mt,n as gt,cY as ht,E as pt}from"./index-e53e8108.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-5cde7d31.js";var Or=`
|
||||
:root {
|
||||
--chakra-vh: 100vh;
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-e53e8108.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-5483945c.css">
|
||||
</head>
|
||||
|
||||
|
1
invokeai/frontend/web/dist/locales/ar.json
vendored
1
invokeai/frontend/web/dist/locales/ar.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "داكن",
|
||||
"lightTheme": "فاتح",
|
||||
"greenTheme": "أخضر",
|
||||
"text2img": "نص إلى صورة",
|
||||
"img2img": "صورة إلى صورة",
|
||||
"unifiedCanvas": "لوحة موحدة",
|
||||
"nodes": "عقد",
|
||||
|
1
invokeai/frontend/web/dist/locales/de.json
vendored
1
invokeai/frontend/web/dist/locales/de.json
vendored
@ -7,7 +7,6 @@
|
||||
"darkTheme": "Dunkel",
|
||||
"lightTheme": "Hell",
|
||||
"greenTheme": "Grün",
|
||||
"text2img": "Text zu Bild",
|
||||
"img2img": "Bild zu Bild",
|
||||
"nodes": "Knoten",
|
||||
"langGerman": "Deutsch",
|
||||
|
4
invokeai/frontend/web/dist/locales/en.json
vendored
4
invokeai/frontend/web/dist/locales/en.json
vendored
@ -505,7 +505,9 @@
|
||||
"info": "Info",
|
||||
"deleteImage": "Delete Image",
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel"
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
|
12
invokeai/frontend/web/dist/locales/es.json
vendored
12
invokeai/frontend/web/dist/locales/es.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Oscuro",
|
||||
"lightTheme": "Claro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto a Imagen",
|
||||
"img2img": "Imagen a Imagen",
|
||||
"unifiedCanvas": "Lienzo Unificado",
|
||||
"nodes": "Nodos",
|
||||
@ -70,7 +69,11 @@
|
||||
"langHebrew": "Hebreo",
|
||||
"pinOptionsPanel": "Pin del panel de opciones",
|
||||
"loading": "Cargando",
|
||||
"loadingInvokeAI": "Cargando invocar a la IA"
|
||||
"loadingInvokeAI": "Cargando invocar a la IA",
|
||||
"postprocessing": "Tratamiento posterior",
|
||||
"txt2img": "De texto a imagen",
|
||||
"accept": "Aceptar",
|
||||
"cancel": "Cancelar"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generaciones",
|
||||
@ -404,7 +407,8 @@
|
||||
"none": "ninguno",
|
||||
"pickModelType": "Elige el tipo de modelo",
|
||||
"v2_768": "v2 (768px)",
|
||||
"addDifference": "Añadir una diferencia"
|
||||
"addDifference": "Añadir una diferencia",
|
||||
"scanForModels": "Buscar modelos"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Imágenes",
|
||||
@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Guardar automáticamente en galería",
|
||||
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
|
||||
"limitStrokesToBox": "Limitar trazos a la caja",
|
||||
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
|
||||
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
|
||||
"clearCanvasHistory": "Limpiar historial de lienzo",
|
||||
"clearHistory": "Limpiar historial",
|
||||
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",
|
||||
|
25
invokeai/frontend/web/dist/locales/fr.json
vendored
25
invokeai/frontend/web/dist/locales/fr.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Sombre",
|
||||
"lightTheme": "Clair",
|
||||
"greenTheme": "Vert",
|
||||
"text2img": "Texte en image",
|
||||
"img2img": "Image en image",
|
||||
"unifiedCanvas": "Canvas unifié",
|
||||
"nodes": "Nœuds",
|
||||
@ -47,7 +46,19 @@
|
||||
"statusLoadingModel": "Chargement du modèle",
|
||||
"statusModelChanged": "Modèle changé",
|
||||
"discordLabel": "Discord",
|
||||
"githubLabel": "Github"
|
||||
"githubLabel": "Github",
|
||||
"accept": "Accepter",
|
||||
"statusMergingModels": "Mélange des modèles",
|
||||
"loadingInvokeAI": "Chargement de Invoke AI",
|
||||
"cancel": "Annuler",
|
||||
"langEnglish": "Anglais",
|
||||
"statusConvertingModel": "Conversion du modèle",
|
||||
"statusModelConverted": "Modèle converti",
|
||||
"loading": "Chargement",
|
||||
"pinOptionsPanel": "Épingler la page d'options",
|
||||
"statusMergedModels": "Modèles mélangés",
|
||||
"txt2img": "Texte vers image",
|
||||
"postprocessing": "Post-Traitement"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Générations",
|
||||
@ -518,5 +529,15 @@
|
||||
"betaDarkenOutside": "Assombrir à l'extérieur",
|
||||
"betaLimitToBox": "Limiter à la boîte",
|
||||
"betaPreserveMasked": "Conserver masqué"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Charger une image",
|
||||
"reset": "Réinitialiser",
|
||||
"nextImage": "Image suivante",
|
||||
"previousImage": "Image précédente",
|
||||
"useThisParameter": "Utiliser ce paramètre",
|
||||
"zoomIn": "Zoom avant",
|
||||
"zoomOut": "Zoom arrière",
|
||||
"showOptionsPanel": "Montrer la page d'options"
|
||||
}
|
||||
}
|
||||
|
1
invokeai/frontend/web/dist/locales/he.json
vendored
1
invokeai/frontend/web/dist/locales/he.json
vendored
@ -125,7 +125,6 @@
|
||||
"langSimplifiedChinese": "סינית",
|
||||
"langUkranian": "אוקראינית",
|
||||
"langSpanish": "ספרדית",
|
||||
"text2img": "טקסט לתמונה",
|
||||
"img2img": "תמונה לתמונה",
|
||||
"unifiedCanvas": "קנבס מאוחד",
|
||||
"nodes": "צמתים",
|
||||
|
14
invokeai/frontend/web/dist/locales/it.json
vendored
14
invokeai/frontend/web/dist/locales/it.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Scuro",
|
||||
"lightTheme": "Chiaro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Testo a Immagine",
|
||||
"img2img": "Immagine a Immagine",
|
||||
"unifiedCanvas": "Tela unificata",
|
||||
"nodes": "Nodi",
|
||||
@ -70,7 +69,11 @@
|
||||
"loading": "Caricamento in corso",
|
||||
"oceanTheme": "Oceano",
|
||||
"langHebrew": "Ebraico",
|
||||
"loadingInvokeAI": "Caricamento Invoke AI"
|
||||
"loadingInvokeAI": "Caricamento Invoke AI",
|
||||
"postprocessing": "Post Elaborazione",
|
||||
"txt2img": "Testo a Immagine",
|
||||
"accept": "Accetta",
|
||||
"cancel": "Annulla"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@ -404,7 +407,8 @@
|
||||
"v2_768": "v2 (768px)",
|
||||
"none": "niente",
|
||||
"addDifference": "Aggiungi differenza",
|
||||
"pickModelType": "Scegli il tipo di modello"
|
||||
"pickModelType": "Scegli il tipo di modello",
|
||||
"scanForModels": "Cerca modelli"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
|
||||
"saveBoxRegionOnly": "Salva solo l'area di selezione",
|
||||
"limitStrokesToBox": "Limita i tratti all'area di selezione",
|
||||
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
|
||||
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
|
||||
"clearCanvasHistory": "Cancella cronologia Tela",
|
||||
"clearHistory": "Cancella la cronologia",
|
||||
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
|
||||
@ -612,7 +616,7 @@
|
||||
"copyMetadataJson": "Copia i metadati JSON",
|
||||
"exitViewer": "Esci dal visualizzatore",
|
||||
"zoomIn": "Zoom avanti",
|
||||
"zoomOut": "Zoom Indietro",
|
||||
"zoomOut": "Zoom indietro",
|
||||
"rotateCounterClockwise": "Ruotare in senso antiorario",
|
||||
"rotateClockwise": "Ruotare in senso orario",
|
||||
"flipHorizontally": "Capovolgi orizzontalmente",
|
||||
|
1
invokeai/frontend/web/dist/locales/ko.json
vendored
1
invokeai/frontend/web/dist/locales/ko.json
vendored
@ -11,7 +11,6 @@
|
||||
"langArabic": "العربية",
|
||||
"langEnglish": "English",
|
||||
"langDutch": "Nederlands",
|
||||
"text2img": "텍스트->이미지",
|
||||
"unifiedCanvas": "통합 캔버스",
|
||||
"langFrench": "Français",
|
||||
"langGerman": "Deutsch",
|
||||
|
1
invokeai/frontend/web/dist/locales/nl.json
vendored
1
invokeai/frontend/web/dist/locales/nl.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Donker",
|
||||
"lightTheme": "Licht",
|
||||
"greenTheme": "Groen",
|
||||
"text2img": "Tekst naar afbeelding",
|
||||
"img2img": "Afbeelding naar afbeelding",
|
||||
"unifiedCanvas": "Centraal canvas",
|
||||
"nodes": "Knooppunten",
|
||||
|
1
invokeai/frontend/web/dist/locales/pl.json
vendored
1
invokeai/frontend/web/dist/locales/pl.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Ciemny",
|
||||
"lightTheme": "Jasny",
|
||||
"greenTheme": "Zielony",
|
||||
"text2img": "Tekst na obraz",
|
||||
"img2img": "Obraz na obraz",
|
||||
"unifiedCanvas": "Tryb uniwersalny",
|
||||
"nodes": "Węzły",
|
||||
|
1
invokeai/frontend/web/dist/locales/pt.json
vendored
1
invokeai/frontend/web/dist/locales/pt.json
vendored
@ -20,7 +20,6 @@
|
||||
"langSpanish": "Espanhol",
|
||||
"langRussian": "Русский",
|
||||
"langUkranian": "Украї́нська",
|
||||
"text2img": "Texto para Imagem",
|
||||
"img2img": "Imagem para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nós",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Noite",
|
||||
"lightTheme": "Dia",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto Para Imagem",
|
||||
"img2img": "Imagem Para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nódulos",
|
||||
|
1
invokeai/frontend/web/dist/locales/ru.json
vendored
1
invokeai/frontend/web/dist/locales/ru.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темная",
|
||||
"lightTheme": "Светлая",
|
||||
"greenTheme": "Зеленая",
|
||||
"text2img": "Изображение из текста (text2img)",
|
||||
"img2img": "Изображение в изображение (img2img)",
|
||||
"unifiedCanvas": "Универсальный холст",
|
||||
"nodes": "Ноды",
|
||||
|
1
invokeai/frontend/web/dist/locales/uk.json
vendored
1
invokeai/frontend/web/dist/locales/uk.json
vendored
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темна",
|
||||
"lightTheme": "Світла",
|
||||
"greenTheme": "Зелена",
|
||||
"text2img": "Зображення із тексту (text2img)",
|
||||
"img2img": "Зображення із зображення (img2img)",
|
||||
"unifiedCanvas": "Універсальне полотно",
|
||||
"nodes": "Вузли",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "暗色",
|
||||
"lightTheme": "亮色",
|
||||
"greenTheme": "绿色",
|
||||
"text2img": "文字到图像",
|
||||
"img2img": "图像到图像",
|
||||
"unifiedCanvas": "统一画布",
|
||||
"nodes": "节点",
|
||||
|
@ -33,7 +33,6 @@
|
||||
"langBrPortuguese": "巴西葡萄牙語",
|
||||
"langRussian": "俄語",
|
||||
"langSpanish": "西班牙語",
|
||||
"text2img": "文字到圖像",
|
||||
"unifiedCanvas": "統一畫布"
|
||||
}
|
||||
}
|
||||
|
87
invokeai/frontend/web/docs/API_CLIENT.md
Normal file
87
invokeai/frontend/web/docs/API_CLIENT.md
Normal file
@ -0,0 +1,87 @@
|
||||
# Generated axios API client
|
||||
|
||||
- [Generated axios API client](#generated-axios-api-client)
|
||||
- [Generation](#generation)
|
||||
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
|
||||
- [Generate the API client from JSON](#generate-the-api-client-from-json)
|
||||
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
|
||||
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
|
||||
- [Generate the API client](#generate-the-api-client)
|
||||
- [The generated client](#the-generated-client)
|
||||
- [API client customisation](#api-client-customisation)
|
||||
|
||||
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
|
||||
|
||||
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
|
||||
|
||||
## Generation
|
||||
|
||||
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
|
||||
|
||||
### Generate the API client from the nodes web server
|
||||
|
||||
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
|
||||
|
||||
1. Start the nodes web server.
|
||||
|
||||
```bash
|
||||
# from the repo root
|
||||
python scripts/invoke-new.py --web
|
||||
```
|
||||
|
||||
2. Generate the API client.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
yarn api:web
|
||||
```
|
||||
|
||||
### Generate the API client from JSON
|
||||
|
||||
The JSON can be acquired from the nodes web server, or with a python script.
|
||||
|
||||
#### Getting the JSON from the nodes web server
|
||||
|
||||
Start the nodes web server as described above, then download the file.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
curl http://localhost:9090/openapi.json -o openapi.json
|
||||
```
|
||||
|
||||
#### Getting the JSON with a python script
|
||||
|
||||
Run this python script from the repo root, so it can access the nodes server modules.
|
||||
|
||||
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
|
||||
|
||||
```bash
|
||||
# from the repo root
|
||||
python invokeai/app/util/generate_openapi_json.py
|
||||
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
|
||||
```
|
||||
|
||||
#### Generate the API client
|
||||
|
||||
Now we can generate the API client from the JSON.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
yarn api:file
|
||||
```
|
||||
|
||||
## The generated client
|
||||
|
||||
The client will be written to `invokeai/frontend/web/services/api/`:
|
||||
|
||||
- `axios` client
|
||||
- TS types
|
||||
- An easily parseable schema, which we can use to generate UI
|
||||
|
||||
## API client customisation
|
||||
|
||||
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
|
||||
|
||||
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
|
||||
|
||||
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.
|
21
invokeai/frontend/web/docs/EVENTS.md
Normal file
21
invokeai/frontend/web/docs/EVENTS.md
Normal file
@ -0,0 +1,21 @@
|
||||
# Events
|
||||
|
||||
Events via `socket.io`
|
||||
|
||||
## `actions.ts`
|
||||
|
||||
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
|
||||
|
||||
Any reducer (or middleware) can respond to the actions.
|
||||
|
||||
## `middleware.ts`
|
||||
|
||||
Redux middleware for events.
|
||||
|
||||
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
|
||||
|
||||
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
|
||||
|
||||
## `types.ts`
|
||||
|
||||
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.
|
17
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
17
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Node Editor Design
|
||||
|
||||
WIP
|
||||
|
||||
nodes
|
||||
|
||||
everything in `src/features/nodes/`
|
||||
|
||||
have a look at `state.nodes.invocation`
|
||||
|
||||
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
|
||||
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
|
||||
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
|
||||
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
|
||||
- `reactflow` sends changes to nodes/edges to redux
|
||||
- to invoke, `buildNodesGraph()` state, then send this
|
||||
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`
|
29
invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md
Normal file
29
invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md
Normal file
@ -0,0 +1,29 @@
|
||||
# Package Scripts
|
||||
|
||||
WIP walkthrough of `package.json` scripts.
|
||||
|
||||
## `theme` & `theme:watch`
|
||||
|
||||
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
|
||||
|
||||
The CLI essentially monkeypatches Chakra's files in `node_modules`.
|
||||
|
||||
## `postinstall`
|
||||
|
||||
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
|
||||
|
||||
### Patch `@chakra-ui/cli`
|
||||
|
||||
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
|
||||
|
||||
### Patch `redux-persist`
|
||||
|
||||
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
|
||||
|
||||
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
|
||||
|
||||
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
|
||||
|
||||
### Patch `redux-deep-persist`
|
||||
|
||||
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
|
@ -1,10 +1,16 @@
|
||||
# InvokeAI Web UI
|
||||
|
||||
- [InvokeAI Web UI](#invokeai-web-ui)
|
||||
- [Stack](#stack)
|
||||
- [Contributing](#contributing)
|
||||
- [Dev Environment](#dev-environment)
|
||||
- [Production builds](#production-builds)
|
||||
|
||||
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
|
||||
|
||||
Code in `invokeai/frontend/web/` if you want to have a look.
|
||||
|
||||
## Details
|
||||
## Stack
|
||||
|
||||
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
||||
|
||||
@ -32,7 +38,7 @@ Start everything in dev mode:
|
||||
|
||||
1. Start the dev server: `yarn dev`
|
||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
||||
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
|
||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||
|
||||
### Production builds
|
||||
|
21
invokeai/frontend/web/index.d.ts
vendored
21
invokeai/frontend/web/index.d.ts
vendored
@ -1,6 +1,7 @@
|
||||
import React, { PropsWithChildren } from 'react';
|
||||
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
||||
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
|
||||
export {};
|
||||
|
||||
@ -64,9 +65,25 @@ declare module '@invoke-ai/invoke-ai-ui' {
|
||||
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
||||
public constructor(props: SettingsModalProps);
|
||||
}
|
||||
|
||||
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
|
||||
public constructor(props: StatusIndicatorProps);
|
||||
}
|
||||
|
||||
declare class ModelSelect extends React.Component<ModelSelectProps> {
|
||||
public constructor(props: ModelSelectProps);
|
||||
}
|
||||
}
|
||||
|
||||
declare function Invoke(props: PropsWithChildren): JSX.Element;
|
||||
interface InvokeProps extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
disabledPanels?: string[];
|
||||
disabledTabs?: InvokeTabName[];
|
||||
token?: string;
|
||||
shouldTransformUrls?: boolean;
|
||||
}
|
||||
|
||||
declare function Invoke(props: InvokeProps): JSX.Element;
|
||||
|
||||
export {
|
||||
ThemeChanger,
|
||||
@ -74,5 +91,7 @@ export {
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
StatusIndicator,
|
||||
ModelSelect,
|
||||
};
|
||||
export = Invoke;
|
||||
|
@ -5,7 +5,10 @@
|
||||
"scripts": {
|
||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
"preview": "vite preview",
|
||||
"lint:madge": "madge --circular src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
@ -41,9 +44,11 @@
|
||||
"@chakra-ui/react": "^2.5.1",
|
||||
"@chakra-ui/styled-system": "^2.6.1",
|
||||
"@chakra-ui/theme-tools": "^2.0.16",
|
||||
"@dagrejs/graphlib": "^2.1.12",
|
||||
"@emotion/react": "^11.10.6",
|
||||
"@emotion/styled": "^11.10.6",
|
||||
"@reduxjs/toolkit": "^1.9.2",
|
||||
"@fontsource/inter": "^4.5.15",
|
||||
"@reduxjs/toolkit": "^1.9.3",
|
||||
"chakra-ui-contextmenu": "^1.0.5",
|
||||
"dateformat": "^5.0.3",
|
||||
"formik": "^2.2.9",
|
||||
@ -67,15 +72,17 @@
|
||||
"react-redux": "^8.0.5",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"react-zoom-pan-pinch": "^2.6.1",
|
||||
"reactflow": "^11.7.0",
|
||||
"redux-deep-persist": "^1.0.7",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-persist": "^6.0.0",
|
||||
"socket.io-client": "^4.6.0",
|
||||
"use-image": "^1.1.0",
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@fontsource/inter": "^4.5.15",
|
||||
"@types/dateformat": "^5.0.0",
|
||||
"@types/lodash": "^4.14.194",
|
||||
"@types/react": "^18.0.28",
|
||||
"@types/react-dom": "^18.0.11",
|
||||
"@types/react-transition-group": "^4.4.5",
|
||||
@ -83,6 +90,7 @@
|
||||
"@typescript-eslint/eslint-plugin": "^5.52.0",
|
||||
"@typescript-eslint/parser": "^5.52.0",
|
||||
"@vitejs/plugin-react-swc": "^3.2.0",
|
||||
"axios": "^1.3.4",
|
||||
"babel-plugin-transform-imports": "^2.0.0",
|
||||
"concurrently": "^7.6.0",
|
||||
"eslint": "^8.34.0",
|
||||
@ -90,13 +98,17 @@
|
||||
"eslint-plugin-prettier": "^4.2.1",
|
||||
"eslint-plugin-react": "^7.32.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"form-data": "^4.0.0",
|
||||
"husky": "^8.0.3",
|
||||
"lint-staged": "^13.1.2",
|
||||
"madge": "^6.0.0",
|
||||
"openapi-types": "^12.1.0",
|
||||
"openapi-typescript-codegen": "^0.23.0",
|
||||
"postinstall-postinstall": "^2.1.0",
|
||||
"prettier": "^2.8.4",
|
||||
"rollup-plugin-visualizer": "^5.9.0",
|
||||
"terser": "^5.16.4",
|
||||
"typescript": "4.9.5",
|
||||
"vite": "^4.1.2",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.0.5",
|
||||
|
@ -52,6 +52,7 @@
|
||||
"txt2img": "Text To Image",
|
||||
"img2img": "Image To Image",
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Nodes",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
@ -505,7 +506,9 @@
|
||||
"info": "Info",
|
||||
"deleteImage": "Delete Image",
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel"
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
@ -522,6 +525,10 @@
|
||||
"resetComplete": "Web UI has been reset. Refresh the page to reload."
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
"disconnected": "Disconnected from Server",
|
||||
"connected": "Connected to Server",
|
||||
"canceled": "Processing Canceled",
|
||||
"tempFoldersEmptied": "Temp Folder Emptied",
|
||||
"uploadFailed": "Upload failed",
|
||||
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
|
||||
|
@ -13,16 +13,42 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
|
||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppSelector } from './storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from './storeHooks';
|
||||
import { PropsWithChildren, useEffect } from 'react';
|
||||
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
|
||||
|
||||
keepGUIAlive();
|
||||
|
||||
const App = (props: PropsWithChildren) => {
|
||||
interface Props extends PropsWithChildren {
|
||||
options: {
|
||||
disabledPanels: string[];
|
||||
disabledTabs: InvokeTabName[];
|
||||
shouldTransformUrls?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
const App = (props: Props) => {
|
||||
useToastWatcher();
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
const { setColorMode } = useColorMode();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(setDisabledPanels(props.options.disabledPanels));
|
||||
}, [dispatch, props.options.disabledPanels]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(setDisabledTabs(props.options.disabledTabs));
|
||||
}, [dispatch, props.options.disabledTabs]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
|
||||
);
|
||||
}, [dispatch, props.options.shouldTransformUrls]);
|
||||
|
||||
useEffect(() => {
|
||||
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
||||
|
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
@ -14,6 +14,8 @@
|
||||
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { ImageMetadata, ImageType } from 'services/api';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
|
||||
/**
|
||||
* TODO:
|
||||
@ -113,7 +115,7 @@ export declare type Metadata = SystemGenerationMetadata & {
|
||||
};
|
||||
|
||||
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
|
||||
export declare type Image = {
|
||||
export declare type _Image = {
|
||||
uuid: string;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
@ -124,11 +126,23 @@ export declare type Image = {
|
||||
category: GalleryCategory;
|
||||
isBase64?: boolean;
|
||||
dreamPrompt?: 'string';
|
||||
name?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* ResultImage
|
||||
*/
|
||||
export declare type Image = {
|
||||
name: string;
|
||||
type: ImageType;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
metadata: ImageMetadata;
|
||||
};
|
||||
|
||||
// GalleryImages is an array of Image.
|
||||
export declare type GalleryImages = {
|
||||
images: Array<Image>;
|
||||
images: Array<_Image>;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -275,7 +289,7 @@ export declare type SystemStatusResponse = SystemStatus;
|
||||
|
||||
export declare type SystemConfigResponse = SystemConfig;
|
||||
|
||||
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
|
||||
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
|
||||
boundingBox?: IRect;
|
||||
generationMode: InvokeTabName;
|
||||
};
|
||||
@ -296,7 +310,7 @@ export declare type ErrorResponse = {
|
||||
};
|
||||
|
||||
export declare type GalleryImagesResponse = {
|
||||
images: Array<Omit<Image, 'uuid'>>;
|
||||
images: Array<Omit<_Image, 'uuid'>>;
|
||||
areMoreImagesAvailable: boolean;
|
||||
category: GalleryCategory;
|
||||
};
|
||||
|
@ -20,6 +20,7 @@ export const readinessSelector = createSelector(
|
||||
seedWeights,
|
||||
initialImage,
|
||||
seed,
|
||||
isImageToImageEnabled,
|
||||
} = generation;
|
||||
|
||||
const { isProcessing, isConnected } = system;
|
||||
@ -33,7 +34,7 @@ export const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('Missing prompt');
|
||||
}
|
||||
|
||||
if (activeTabName === 'img2img' && !initialImage) {
|
||||
if (isImageToImageEnabled && !initialImage) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
|
@ -13,9 +13,13 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
export const generateImage = createAction<InvokeTabName>(
|
||||
'socketio/generateImage'
|
||||
);
|
||||
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
|
||||
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
|
||||
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
|
||||
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
|
||||
export const runFacetool = createAction<InvokeAI._Image>(
|
||||
'socketio/runFacetool'
|
||||
);
|
||||
export const deleteImage = createAction<InvokeAI._Image>(
|
||||
'socketio/deleteImage'
|
||||
);
|
||||
export const requestImages = createAction<GalleryCategory>(
|
||||
'socketio/requestImages'
|
||||
);
|
||||
|
@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
|
||||
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
|
||||
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
dispatch(removeImage(imageToDelete));
|
||||
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
|
@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
clearInitialImage,
|
||||
initialImageSelected,
|
||||
setInfillMethod,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setMaskPath,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { tabMap } from 'features/ui/store/tabMap';
|
||||
@ -142,15 +143,17 @@ const makeSocketIOListeners = (
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldLoopback) {
|
||||
const activeTabName = tabMap[activeTab];
|
||||
switch (activeTabName) {
|
||||
case 'img2img': {
|
||||
dispatch(setInitialImage(newImage));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: fix
|
||||
// if (shouldLoopback) {
|
||||
// const activeTabName = tabMap[activeTab];
|
||||
// switch (activeTabName) {
|
||||
// case 'img2img': {
|
||||
// dispatch(initialImageSelected(newImage.uuid));
|
||||
// // dispatch(setInitialImage(newImage));
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
dispatch(clearIntermediateImage());
|
||||
|
||||
@ -262,7 +265,7 @@ const makeSocketIOListeners = (
|
||||
*/
|
||||
|
||||
// Generate a UUID for each image
|
||||
const preparedImages = images.map((image): InvokeAI.Image => {
|
||||
const preparedImages = images.map((image): InvokeAI._Image => {
|
||||
return {
|
||||
uuid: uuidv4(),
|
||||
...image,
|
||||
@ -334,7 +337,7 @@ const makeSocketIOListeners = (
|
||||
|
||||
if (
|
||||
initialImage === url ||
|
||||
(initialImage as InvokeAI.Image)?.url === url
|
||||
(initialImage as InvokeAI._Image)?.url === url
|
||||
) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
|
@ -29,6 +29,8 @@ export const socketioMiddleware = () => {
|
||||
path: `${window.location.pathname}socket.io`,
|
||||
});
|
||||
|
||||
socketio.disconnect();
|
||||
|
||||
let areListenersSet = false;
|
||||
|
||||
const middleware: Middleware = (store) => (next) => (action) => {
|
||||
|
@ -2,18 +2,32 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||
|
||||
import { persistReducer } from 'redux-persist';
|
||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { getPersistConfig } from 'redux-deep-persist';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import resultsReducer from 'features/gallery/store/resultsSlice';
|
||||
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
import { canvasBlacklist } from 'features/canvas/store/canvasPersistBlacklist';
|
||||
import { galleryBlacklist } from 'features/gallery/store/galleryPersistBlacklist';
|
||||
import { generationBlacklist } from 'features/parameters/store/generationPersistBlacklist';
|
||||
import { lightboxBlacklist } from 'features/lightbox/store/lightboxPersistBlacklist';
|
||||
import { modelsBlacklist } from 'features/system/store/modelsPersistBlacklist';
|
||||
import { nodesBlacklist } from 'features/nodes/store/nodesPersistBlacklist';
|
||||
import { postprocessingBlacklist } from 'features/parameters/store/postprocessingPersistBlacklist';
|
||||
import { systemBlacklist } from 'features/system/store/systemPersistsBlacklist';
|
||||
import { uiBlacklist } from 'features/ui/store/uiPersistBlacklist';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
@ -29,49 +43,18 @@ import { socketioMiddleware } from './socketio/middleware';
|
||||
* The necesssary nested persistors with blacklists are configured below.
|
||||
*/
|
||||
|
||||
const canvasBlacklist = [
|
||||
'cursorPosition',
|
||||
'isCanvasInitialized',
|
||||
'doesCanvasNeedScaling',
|
||||
].map((blacklistItem) => `canvas.${blacklistItem}`);
|
||||
|
||||
const systemBlacklist = [
|
||||
'currentIteration',
|
||||
'currentStatus',
|
||||
'currentStep',
|
||||
'isCancelable',
|
||||
'isConnected',
|
||||
'isESRGANAvailable',
|
||||
'isGFPGANAvailable',
|
||||
'isProcessing',
|
||||
'socketId',
|
||||
'totalIterations',
|
||||
'totalSteps',
|
||||
'openModel',
|
||||
'cancelOptions.cancelAfter',
|
||||
].map((blacklistItem) => `system.${blacklistItem}`);
|
||||
|
||||
const galleryBlacklist = [
|
||||
'categories',
|
||||
'currentCategory',
|
||||
'currentImage',
|
||||
'currentImageUuid',
|
||||
'shouldAutoSwitchToNewImages',
|
||||
'intermediateImage',
|
||||
].map((blacklistItem) => `gallery.${blacklistItem}`);
|
||||
|
||||
const lightboxBlacklist = ['isLightboxOpen'].map(
|
||||
(blacklistItem) => `lightbox.${blacklistItem}`
|
||||
);
|
||||
|
||||
const rootReducer = combineReducers({
|
||||
generation: generationReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
gallery: galleryReducer,
|
||||
system: systemReducer,
|
||||
canvas: canvasReducer,
|
||||
ui: uiReducer,
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
lightbox: lightboxReducer,
|
||||
models: modelsReducer,
|
||||
nodes: nodesReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
results: resultsReducer,
|
||||
system: systemReducer,
|
||||
ui: uiReducer,
|
||||
uploads: uploadsReducer,
|
||||
});
|
||||
|
||||
const rootPersistConfig = getPersistConfig({
|
||||
@ -80,23 +63,40 @@ const rootPersistConfig = getPersistConfig({
|
||||
rootReducer,
|
||||
blacklist: [
|
||||
...canvasBlacklist,
|
||||
...systemBlacklist,
|
||||
...galleryBlacklist,
|
||||
...generationBlacklist,
|
||||
...lightboxBlacklist,
|
||||
...modelsBlacklist,
|
||||
...nodesBlacklist,
|
||||
...postprocessingBlacklist,
|
||||
// ...resultsBlacklist,
|
||||
'results',
|
||||
...systemBlacklist,
|
||||
...uiBlacklist,
|
||||
// ...uploadsBlacklist,
|
||||
'uploads',
|
||||
],
|
||||
debounce: 300,
|
||||
});
|
||||
|
||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||
|
||||
// Continue with store setup
|
||||
// TODO: rip the old middleware out when nodes is complete
|
||||
export function buildMiddleware() {
|
||||
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
||||
return socketMiddleware();
|
||||
} else {
|
||||
return socketioMiddleware();
|
||||
}
|
||||
}
|
||||
|
||||
export const store = configureStore({
|
||||
reducer: persistedReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
}).concat(socketioMiddleware()),
|
||||
}).concat(dynamicMiddlewares),
|
||||
devTools: {
|
||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||
actionsDenylist: [
|
||||
|
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
@ -0,0 +1,8 @@
|
||||
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||
import { AppDispatch, RootState } from './store';
|
||||
|
||||
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
|
||||
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
|
||||
state: RootState;
|
||||
dispatch: AppDispatch;
|
||||
}>();
|
@ -44,12 +44,10 @@ export type IAIFullSliderProps = {
|
||||
inputReadOnly?: boolean;
|
||||
withReset?: boolean;
|
||||
handleReset?: () => void;
|
||||
isResetDisabled?: boolean;
|
||||
isSliderDisabled?: boolean;
|
||||
isInputDisabled?: boolean;
|
||||
tooltipSuffix?: string;
|
||||
hideTooltip?: boolean;
|
||||
isCompact?: boolean;
|
||||
isDisabled?: boolean;
|
||||
sliderFormControlProps?: FormControlProps;
|
||||
sliderFormLabelProps?: FormLabelProps;
|
||||
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
|
||||
@ -80,10 +78,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
withReset = false,
|
||||
hideTooltip = false,
|
||||
isCompact = false,
|
||||
isDisabled = false,
|
||||
handleReset,
|
||||
isResetDisabled,
|
||||
isSliderDisabled,
|
||||
isInputDisabled,
|
||||
sliderFormControlProps,
|
||||
sliderFormLabelProps,
|
||||
sliderMarkProps,
|
||||
@ -149,6 +145,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
}
|
||||
: {}
|
||||
}
|
||||
isDisabled={isDisabled}
|
||||
{...sliderFormControlProps}
|
||||
>
|
||||
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
||||
@ -166,15 +163,13 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
onMouseEnter={() => setShowTooltip(true)}
|
||||
onMouseLeave={() => setShowTooltip(false)}
|
||||
focusThumbOnChange={false}
|
||||
isDisabled={isSliderDisabled}
|
||||
// width={width}
|
||||
isDisabled={isDisabled}
|
||||
{...rest}
|
||||
>
|
||||
{withSliderMarks && (
|
||||
<>
|
||||
<SliderMark
|
||||
value={min}
|
||||
// insetInlineStart={0}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
@ -185,7 +180,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
</SliderMark>
|
||||
<SliderMark
|
||||
value={max}
|
||||
// insetInlineEnd={0}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
@ -221,7 +215,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
value={localInputValue}
|
||||
onChange={handleInputChange}
|
||||
onBlur={handleInputBlur}
|
||||
isDisabled={isInputDisabled}
|
||||
{...sliderNumberInputProps}
|
||||
>
|
||||
<NumberInputField
|
||||
@ -246,8 +239,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
aria-label={t('accessibility.reset')}
|
||||
tooltip="Reset"
|
||||
icon={<BiReset />}
|
||||
isDisabled={isDisabled}
|
||||
onClick={handleResetDisable}
|
||||
isDisabled={isResetDisabled}
|
||||
{...sliderIAIIconButtonProps}
|
||||
/>
|
||||
)}
|
||||
|
@ -0,0 +1,79 @@
|
||||
import { Badge, Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { useCallback } from 'react';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Image } from 'app/invokeai';
|
||||
|
||||
type ImageToImageOverlayProps = {
|
||||
setIsLoaded: (isLoaded: boolean) => void;
|
||||
image: Image;
|
||||
};
|
||||
|
||||
const ImageToImageOverlay = ({
|
||||
setIsLoaded,
|
||||
image,
|
||||
}: ImageToImageOverlayProps) => {
|
||||
const isImageToImageEnabled = useAppSelector(
|
||||
(state: RootState) => state.generation.isImageToImageEnabled
|
||||
);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const handleResetInitialImage = useCallback(() => {
|
||||
dispatch(clearInitialImage());
|
||||
setIsLoaded(false);
|
||||
}, [dispatch, setIsLoaded]);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
top: 0,
|
||||
left: 0,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
position: 'absolute',
|
||||
}}
|
||||
>
|
||||
<ButtonGroup
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
p: 2,
|
||||
}}
|
||||
>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
isDisabled={!isImageToImageEnabled}
|
||||
icon={<FaUndo />}
|
||||
aria-label={t('accessibility.reset')}
|
||||
onClick={handleResetInitialImage}
|
||||
/>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
isDisabled={!isImageToImageEnabled}
|
||||
icon={<FaUpload />}
|
||||
aria-label={t('common.upload')}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
bottom: 0,
|
||||
left: 0,
|
||||
p: 2,
|
||||
alignItems: 'flex-start',
|
||||
}}
|
||||
>
|
||||
<Badge variant="solid" colorScheme="base">
|
||||
{image.metadata?.width} × {image.metadata?.height}
|
||||
</Badge>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageToImageOverlay;
|
@ -2,7 +2,6 @@ import { Box, useToast } from '@chakra-ui/react';
|
||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import {
|
||||
@ -15,6 +14,7 @@ import {
|
||||
} from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||
|
||||
type ImageUploaderProps = {
|
||||
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
|
||||
const fileAcceptedCallback = useCallback(
|
||||
async (file: File) => {
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
};
|
||||
document.addEventListener('paste', pasteImageListener);
|
||||
return () => {
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { Flex, Icon } from '@chakra-ui/react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
|
||||
const SelectImagePlaceholder = () => {
|
||||
return (
|
||||
<Flex sx={{ h: 36, alignItems: 'center', justifyContent: 'center' }}>
|
||||
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default SelectImagePlaceholder;
|
@ -1,27 +1,160 @@
|
||||
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import WorkInProgress from './WorkInProgress';
|
||||
// import WorkInProgress from './WorkInProgress';
|
||||
// import ReactFlow, {
|
||||
// applyEdgeChanges,
|
||||
// applyNodeChanges,
|
||||
// Background,
|
||||
// Controls,
|
||||
// Edge,
|
||||
// Handle,
|
||||
// Node,
|
||||
// NodeTypes,
|
||||
// OnEdgesChange,
|
||||
// OnNodesChange,
|
||||
// Position,
|
||||
// } from 'reactflow';
|
||||
|
||||
export default function NodesWIP() {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<WorkInProgress>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
w: '100%',
|
||||
h: '100%',
|
||||
gap: 4,
|
||||
textAlign: 'center',
|
||||
}}
|
||||
>
|
||||
<Heading>{t('common.nodes')}</Heading>
|
||||
<VStack maxW="50rem" gap={4}>
|
||||
<Text>{t('common.nodesDesc')}</Text>
|
||||
</VStack>
|
||||
</Flex>
|
||||
</WorkInProgress>
|
||||
);
|
||||
}
|
||||
// import 'reactflow/dist/style.css';
|
||||
// import {
|
||||
// Fragment,
|
||||
// FunctionComponent,
|
||||
// ReactNode,
|
||||
// useCallback,
|
||||
// useMemo,
|
||||
// useState,
|
||||
// } from 'react';
|
||||
// import { OpenAPIV3 } from 'openapi-types';
|
||||
// import { filter, map, reduce } from 'lodash';
|
||||
// import {
|
||||
// Box,
|
||||
// Flex,
|
||||
// FormControl,
|
||||
// FormLabel,
|
||||
// Input,
|
||||
// Select,
|
||||
// Switch,
|
||||
// Text,
|
||||
// NumberInput,
|
||||
// NumberInputField,
|
||||
// NumberInputStepper,
|
||||
// NumberIncrementStepper,
|
||||
// NumberDecrementStepper,
|
||||
// Tooltip,
|
||||
// chakra,
|
||||
// Badge,
|
||||
// Heading,
|
||||
// VStack,
|
||||
// HStack,
|
||||
// Menu,
|
||||
// MenuButton,
|
||||
// MenuList,
|
||||
// MenuItem,
|
||||
// MenuItemOption,
|
||||
// MenuGroup,
|
||||
// MenuOptionGroup,
|
||||
// MenuDivider,
|
||||
// IconButton,
|
||||
// } from '@chakra-ui/react';
|
||||
// import { FaPlus } from 'react-icons/fa';
|
||||
// import {
|
||||
// FIELD_NAMES as FIELD_NAMES,
|
||||
// FIELDS,
|
||||
// INVOCATION_NAMES as INVOCATION_NAMES,
|
||||
// INVOCATIONS,
|
||||
// } from 'features/nodeEditor/constants';
|
||||
|
||||
// console.log('invocations', INVOCATIONS);
|
||||
|
||||
// const nodeTypes = reduce(
|
||||
// INVOCATIONS,
|
||||
// (acc, val, key) => {
|
||||
// acc[key] = val.component;
|
||||
// return acc;
|
||||
// },
|
||||
// {} as NodeTypes
|
||||
// );
|
||||
|
||||
// console.log('nodeTypes', nodeTypes);
|
||||
|
||||
// // make initial nodes one of every node for now
|
||||
// let n = 0;
|
||||
// const initialNodes = map(INVOCATIONS, (i) => ({
|
||||
// id: i.type,
|
||||
// type: i.title,
|
||||
// position: { x: (n += 20), y: (n += 20) },
|
||||
// data: {},
|
||||
// }));
|
||||
|
||||
// console.log('initialNodes', initialNodes);
|
||||
|
||||
// export default function NodesWIP() {
|
||||
// const [nodes, setNodes] = useState<Node[]>([]);
|
||||
// const [edges, setEdges] = useState<Edge[]>([]);
|
||||
|
||||
// const onNodesChange: OnNodesChange = useCallback(
|
||||
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// const onEdgesChange: OnEdgesChange = useCallback(
|
||||
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// return (
|
||||
// <Box
|
||||
// sx={{
|
||||
// position: 'relative',
|
||||
// width: 'full',
|
||||
// height: 'full',
|
||||
// borderRadius: 'md',
|
||||
// }}
|
||||
// >
|
||||
// <ReactFlow
|
||||
// nodeTypes={nodeTypes}
|
||||
// nodes={nodes}
|
||||
// edges={edges}
|
||||
// onNodesChange={onNodesChange}
|
||||
// onEdgesChange={onEdgesChange}
|
||||
// >
|
||||
// <Background />
|
||||
// <Controls />
|
||||
// </ReactFlow>
|
||||
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
|
||||
// {FIELD_NAMES.map((field) => (
|
||||
// <Badge
|
||||
// key={field}
|
||||
// colorScheme={FIELDS[field].color}
|
||||
// sx={{ userSelect: 'none' }}
|
||||
// >
|
||||
// {field}
|
||||
// </Badge>
|
||||
// ))}
|
||||
// </HStack>
|
||||
// <Menu>
|
||||
// <MenuButton
|
||||
// as={IconButton}
|
||||
// aria-label="Options"
|
||||
// icon={<FaPlus />}
|
||||
// sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||
// />
|
||||
// <MenuList>
|
||||
// {INVOCATION_NAMES.map((name) => {
|
||||
// const invocation = INVOCATIONS[name];
|
||||
// return (
|
||||
// <Tooltip
|
||||
// key={name}
|
||||
// label={invocation.description}
|
||||
// placement="end"
|
||||
// hasArrow
|
||||
// >
|
||||
// <MenuItem>{invocation.title}</MenuItem>
|
||||
// </Tooltip>
|
||||
// );
|
||||
// })}
|
||||
// </MenuList>
|
||||
// </Menu>
|
||||
// </Box>
|
||||
// );
|
||||
// }
|
||||
|
||||
export default {};
|
||||
|
@ -14,6 +14,8 @@ const WorkInProgress = (props: WorkInProgressProps) => {
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
bg: 'base.850',
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
119
invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
Normal file
119
invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
Normal file
@ -0,0 +1,119 @@
|
||||
/**
|
||||
* PARTIAL ZOD IMPLEMENTATION
|
||||
*
|
||||
* doesn't work well bc like most validators, zod is not built to skip invalid values.
|
||||
* it mostly works but just seems clearer and simpler to manually parse for now.
|
||||
*
|
||||
* in the future it would be really nice if we could use zod for some things:
|
||||
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
|
||||
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
|
||||
*/
|
||||
|
||||
// import { z } from 'zod';
|
||||
|
||||
// const zMetadataStringField = z.string();
|
||||
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
|
||||
|
||||
// const zMetadataIntegerField = z.number().int();
|
||||
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
|
||||
|
||||
// const zMetadataFloatField = z.number();
|
||||
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
|
||||
|
||||
// const zMetadataBooleanField = z.boolean();
|
||||
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
|
||||
|
||||
// const zMetadataImageField = z.object({
|
||||
// image_type: z.union([
|
||||
// z.literal('results'),
|
||||
// z.literal('uploads'),
|
||||
// z.literal('intermediates'),
|
||||
// ]),
|
||||
// image_name: z.string().min(1),
|
||||
// });
|
||||
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
|
||||
|
||||
// const zMetadataLatentsField = z.object({
|
||||
// latents_name: z.string().min(1),
|
||||
// });
|
||||
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
|
||||
|
||||
// /**
|
||||
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
|
||||
// */
|
||||
// const zAnyMetadataField = z.any().transform((val, ctx) => {
|
||||
// // Grab the field name from the path
|
||||
// const fieldName = String(ctx.path[ctx.path.length - 1]);
|
||||
|
||||
// // `id` and `type` must be strings if they exist
|
||||
// if (['id', 'type'].includes(fieldName)) {
|
||||
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
|
||||
// if (reservedStringPropertyResult.success) {
|
||||
// return reservedStringPropertyResult.data;
|
||||
// }
|
||||
|
||||
// return;
|
||||
// }
|
||||
|
||||
// // Parse the rest of the fields, only returning the data if the parsing is successful
|
||||
|
||||
// const stringFieldResult = zMetadataStringField.safeParse(val);
|
||||
// if (stringFieldResult.success) {
|
||||
// return stringFieldResult.data;
|
||||
// }
|
||||
|
||||
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
|
||||
// if (integerFieldResult.success) {
|
||||
// return integerFieldResult.data;
|
||||
// }
|
||||
|
||||
// const floatFieldResult = zMetadataFloatField.safeParse(val);
|
||||
// if (floatFieldResult.success) {
|
||||
// return floatFieldResult.data;
|
||||
// }
|
||||
|
||||
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
|
||||
// if (booleanFieldResult.success) {
|
||||
// return booleanFieldResult.data;
|
||||
// }
|
||||
|
||||
// const imageFieldResult = zMetadataImageField.safeParse(val);
|
||||
// if (imageFieldResult.success) {
|
||||
// return imageFieldResult.data;
|
||||
// }
|
||||
|
||||
// const latentsFieldResult = zMetadataImageField.safeParse(val);
|
||||
// if (latentsFieldResult.success) {
|
||||
// return latentsFieldResult.data;
|
||||
// }
|
||||
// });
|
||||
|
||||
// /**
|
||||
// * The node metadata schema.
|
||||
// */
|
||||
// const zNodeMetadata = z.object({
|
||||
// session_id: z.string().min(1).optional(),
|
||||
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
|
||||
// });
|
||||
|
||||
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
|
||||
|
||||
// const zMetadata = z.object({
|
||||
// invokeai: zNodeMetadata.optional(),
|
||||
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
|
||||
// });
|
||||
// export type Metadata = z.infer<typeof zMetadata>;
|
||||
|
||||
// export const parseMetadata = (
|
||||
// metadata: Record<string, any>
|
||||
// ): Metadata | undefined => {
|
||||
// const result = zMetadata.safeParse(metadata);
|
||||
// if (!result.success) {
|
||||
// console.log(result.error.issues);
|
||||
// return;
|
||||
// }
|
||||
|
||||
// return result.data;
|
||||
// };
|
||||
|
||||
export default {};
|
6
invokeai/frontend/web/src/common/util/getTimestamp.ts
Normal file
6
invokeai/frontend/web/src/common/util/getTimestamp.ts
Normal file
@ -0,0 +1,6 @@
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
/**
|
||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||
*/
|
||||
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
|
28
invokeai/frontend/web/src/common/util/getUrl.ts
Normal file
28
invokeai/frontend/web/src/common/util/getUrl.ts
Normal file
@ -0,0 +1,28 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { OpenAPI } from 'services/api';
|
||||
|
||||
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
|
||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||
return [OpenAPI.BASE, url].join('/');
|
||||
}
|
||||
|
||||
return url;
|
||||
};
|
||||
|
||||
export const useGetUrl = () => {
|
||||
const shouldTransformUrls = useAppSelector(
|
||||
(state: RootState) => state.system.shouldTransformUrls
|
||||
);
|
||||
|
||||
return {
|
||||
shouldTransformUrls,
|
||||
getUrl: (url?: string) => {
|
||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||
return [OpenAPI.BASE, url].join('/');
|
||||
}
|
||||
|
||||
return url;
|
||||
},
|
||||
};
|
||||
};
|
169
invokeai/frontend/web/src/common/util/parseMetadata.ts
Normal file
169
invokeai/frontend/web/src/common/util/parseMetadata.ts
Normal file
@ -0,0 +1,169 @@
|
||||
import { forEach, size } from 'lodash';
|
||||
import { ImageField, LatentsField } from 'services/api';
|
||||
|
||||
const OBJECT_TYPESTRING = '[object Object]';
|
||||
const STRING_TYPESTRING = '[object String]';
|
||||
const NUMBER_TYPESTRING = '[object Number]';
|
||||
const BOOLEAN_TYPESTRING = '[object Boolean]';
|
||||
const ARRAY_TYPESTRING = '[object Array]';
|
||||
|
||||
const isObject = (obj: unknown): obj is Record<string | number, any> =>
|
||||
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
|
||||
|
||||
const isString = (obj: unknown): obj is string =>
|
||||
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
|
||||
|
||||
const isNumber = (obj: unknown): obj is number =>
|
||||
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
|
||||
|
||||
const isBoolean = (obj: unknown): obj is boolean =>
|
||||
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
|
||||
|
||||
const isArray = (obj: unknown): obj is Array<any> =>
|
||||
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
|
||||
|
||||
const parseImageField = (imageField: unknown): ImageField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(imageField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField must have both `image_name` and `image_type`
|
||||
if (!('image_name' in imageField && 'image_type' in imageField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField's `image_type` must be one of the allowed values
|
||||
if (
|
||||
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField's `image_name` must be a string
|
||||
if (typeof imageField.image_name !== 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a valid ImageField
|
||||
return {
|
||||
image_type: imageField.image_type,
|
||||
image_name: imageField.image_name,
|
||||
};
|
||||
};
|
||||
|
||||
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(latentsField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A LatentsField must have a `latents_name`
|
||||
if (!('latents_name' in latentsField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A LatentsField's `latents_name` must be a string
|
||||
if (typeof latentsField.latents_name !== 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a valid LatentsField
|
||||
return {
|
||||
latents_name: latentsField.latents_name,
|
||||
};
|
||||
};
|
||||
|
||||
type NodeMetadata = {
|
||||
[key: string]: string | number | boolean | ImageField | LatentsField;
|
||||
};
|
||||
|
||||
type InvokeAIMetadata = {
|
||||
session_id?: string;
|
||||
node?: NodeMetadata;
|
||||
};
|
||||
|
||||
export const parseNodeMetadata = (
|
||||
nodeMetadata: Record<string | number, any>
|
||||
): NodeMetadata | undefined => {
|
||||
if (!isObject(nodeMetadata)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed: NodeMetadata = {};
|
||||
|
||||
forEach(nodeMetadata, (nodeItem, nodeKey) => {
|
||||
// `id` and `type` must be strings if they are present
|
||||
if (['id', 'type'].includes(nodeKey)) {
|
||||
if (isString(nodeItem)) {
|
||||
parsed[nodeKey] = nodeItem;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// the only valid object types are ImageField and LatentsField
|
||||
if (isObject(nodeItem)) {
|
||||
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
||||
const imageField = parseImageField(nodeItem);
|
||||
if (imageField) {
|
||||
parsed[nodeKey] = imageField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('latents_name' in nodeItem) {
|
||||
const latentsField = parseLatentsField(nodeItem);
|
||||
if (latentsField) {
|
||||
parsed[nodeKey] = latentsField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise we accept any string, number or boolean
|
||||
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
|
||||
parsed[nodeKey] = nodeItem;
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
if (size(parsed) === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
return parsed;
|
||||
};
|
||||
|
||||
export const parseInvokeAIMetadata = (
|
||||
metadata: Record<string | number, any> | undefined
|
||||
): InvokeAIMetadata | undefined => {
|
||||
if (metadata === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isObject(metadata)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed: InvokeAIMetadata = {};
|
||||
|
||||
forEach(metadata, (item, key) => {
|
||||
if (key === 'session_id' && isString(item)) {
|
||||
parsed['session_id'] = item;
|
||||
}
|
||||
|
||||
if (key === 'node' && isObject(item)) {
|
||||
const nodeMetadata = parseNodeMetadata(item);
|
||||
|
||||
if (nodeMetadata) {
|
||||
parsed['node'] = nodeMetadata;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (size(parsed) === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
return parsed;
|
||||
};
|
@ -1,8 +1,10 @@
|
||||
import React, { lazy, PropsWithChildren } from 'react';
|
||||
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from './app/store';
|
||||
import { buildMiddleware, store } from './app/store';
|
||||
import { persistor } from './persistor';
|
||||
import { OpenAPI } from 'services/api';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
@ -17,18 +19,61 @@ import Loading from './Loading';
|
||||
|
||||
// Localization
|
||||
import './i18n';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
|
||||
const App = lazy(() => import('./app/App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||
|
||||
export default function Component(props: PropsWithChildren) {
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
disabledPanels?: string[];
|
||||
disabledTabs?: InvokeTabName[];
|
||||
token?: string;
|
||||
shouldTransformUrls?: boolean;
|
||||
}
|
||||
|
||||
export default function Component({
|
||||
apiUrl,
|
||||
disabledPanels = [],
|
||||
disabledTabs = [],
|
||||
token,
|
||||
children,
|
||||
shouldTransformUrls,
|
||||
}: Props) {
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
if (token) {
|
||||
OpenAPI.TOKEN = token;
|
||||
}
|
||||
|
||||
// configure API client base url
|
||||
if (apiUrl) {
|
||||
OpenAPI.BASE = apiUrl;
|
||||
}
|
||||
|
||||
// reset dynamically added middlewares
|
||||
resetMiddlewares();
|
||||
|
||||
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
|
||||
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
|
||||
// outside the provider. it's not needed until there is the possibility that we will change
|
||||
// the `apiUrl`/`token` dynamically.
|
||||
|
||||
// rebuild socket middleware with token and apiUrl
|
||||
addMiddleware(buildMiddleware());
|
||||
}, [apiUrl, token]);
|
||||
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App>{props.children}</App>
|
||||
<App
|
||||
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
|
||||
>
|
||||
{children}
|
||||
</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
|
@ -5,6 +5,8 @@ import ThemeChanger from './features/system/components/ThemeChanger';
|
||||
import IAIPopover from './common/components/IAIPopover';
|
||||
import IAIIconButton from './common/components/IAIIconButton';
|
||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||
import StatusIndicator from './features/system/components/StatusIndicator';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
|
||||
export default Component;
|
||||
export {
|
||||
@ -13,4 +15,6 @@ export {
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
StatusIndicator,
|
||||
ModelSelect,
|
||||
};
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||
import { isEqual } from 'lodash';
|
||||
@ -25,7 +26,7 @@ type Props = Omit<ImageConfig, 'image'>;
|
||||
const IAICanvasIntermediateImage = (props: Props) => {
|
||||
const { ...rest } = props;
|
||||
const intermediateImage = useAppSelector(selector);
|
||||
|
||||
const { getUrl } = useGetUrl();
|
||||
const [loadedImageElement, setLoadedImageElement] =
|
||||
useState<HTMLImageElement | null>(null);
|
||||
|
||||
@ -36,8 +37,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
|
||||
tempImage.onload = () => {
|
||||
setLoadedImageElement(tempImage);
|
||||
};
|
||||
tempImage.src = intermediateImage.url;
|
||||
}, [intermediateImage]);
|
||||
tempImage.src = getUrl(intermediateImage.url);
|
||||
}, [intermediateImage, getUrl]);
|
||||
|
||||
if (!intermediateImage?.boundingBox) return null;
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||
import { isEqual } from 'lodash';
|
||||
@ -32,6 +33,7 @@ const selector = createSelector(
|
||||
|
||||
const IAICanvasObjectRenderer = () => {
|
||||
const { objects } = useAppSelector(selector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
if (!objects) return null;
|
||||
|
||||
@ -40,7 +42,12 @@ const IAICanvasObjectRenderer = () => {
|
||||
{objects.map((obj, i) => {
|
||||
if (isCanvasBaseImage(obj)) {
|
||||
return (
|
||||
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
|
||||
<IAICanvasImage
|
||||
key={i}
|
||||
x={obj.x}
|
||||
y={obj.y}
|
||||
url={getUrl(obj.image.url)}
|
||||
/>
|
||||
);
|
||||
} else if (isCanvasBaseLine(obj)) {
|
||||
const line = (
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user