Merge branch 'main' into bugfix/prevent-cli-crash

Lincoln Stein 2023-04-09 12:08:09 -04:00 committed by GitHub
commit cee159dfa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 377 additions and 174 deletions

View File

@ -1,10 +1,18 @@
# Invocations
Invocations represent a single operation, its inputs, and its outputs. These
operations and their outputs can be chained together to generate and modify
images.
## Creating a new invocation
To create a new invocation, either find the appropriate module file in
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
that folder. All invocations in that folder will be discovered and made
available to the CLI and API automatically. Invocations make use of
[typing](https://docs.python.org/3/library/typing.html) and
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
into the CLI and API.
An invocation looks like this:
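The full example class is hidden by the diff view here; the following is a rough
sketch assembled from the fragments walked through in the sections below. The
upscaling work itself is replaced by a placeholder and the result naming is
illustrative only, so treat it as an outline rather than the real implementation:

```py
class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""
    type: Literal['upscale'] = 'upscale'

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
    level: Literal[2, 4] = Field(default=2, description="The upscale level")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        upscaled = image  # placeholder: the real implementation runs an upscaler here
        image_type = ImageType.RESULT
        image_name = f"{context.graph_execution_state_id}_{self.id}.png"  # illustrative naming only
        context.services.images.save(image_type, image_name, upscaled)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
```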
@ -41,34 +49,54 @@ class UpscaleInvocation(BaseInvocation):
Each portion is important to implement correctly.
### Class definition and type
```py
class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""
    type: Literal['upscale'] = 'upscale'
```
All invocations must derive from `BaseInvocation`. They should have a docstring
that declares what they do in a single, short line. They should also have a
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
is what the user will type on the CLI or use in the API to create this
invocation. The `command_name` must be unique. The `type` must be assigned to
the value of the literal in the type hint.
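For instance, a new (hypothetical) invocation that inverts an image would declare
its own unique command name like this; `invert` here is purely illustrative:

```py
class InvertInvocation(BaseInvocation):
    """Inverts the colors of an image."""
    type: Literal['invert'] = 'invert'  # users would type `invert` on the CLI
```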
### Inputs
```py
    # Inputs
    image: Union[ImageField,None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
    level: Literal[2,4] = Field(default=2, description="The upscale level")
```
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
description, and validation information. For example:
| Part | Value | Description |
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
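Pydantic enforces these `Field` constraints before `invoke` ever runs. The minimal
standalone sketch below (using an illustrative `UpscaleArgs` stand-in rather than a
real invocation class) shows a value outside the (0,1] range being rejected:

```py
from pydantic import BaseModel, Field, ValidationError

class UpscaleArgs(BaseModel):  # illustrative stand-in for the invocation's input fields
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")

UpscaleArgs(strength=0.5)      # accepted: within (0, 1]
try:
    UpscaleArgs(strength=1.5)  # rejected: greater than 1
except ValidationError as err:
    print(err)
```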
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
field to be parsed with `None` as a value, which enables linking to previous
invocations. All fields should either provide a default value or allow `None` as
a value, so that they can be overwritten with a linked output from another
invocation.
The special type `ImageField` is also used here. All images are passed as
`ImageField`, which protects them from pydantic validation errors (since images
only ever come from links).
Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and matching.
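To make that concrete, here is a hedged sketch of how matching names and types line
up between an output class and a downstream invocation's input (both classes are
abbreviated versions of code shown elsewhere in this document):

```py
class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""
    type: Literal['image'] = 'image'
    image: ImageField = Field(default=None, description="The output image")

class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""
    type: Literal['upscale'] = 'upscale'
    # Because this field is named `image` and carries an ImageField, it can be
    # linked automatically to the `image` field of a previous ImageOutput; an
    # explicit link only needs the types to match.
    image: Union[ImageField, None] = Field(description="The input image")
```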
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
    image = context.services.images.get(self.image.image_type, self.image.image_name)
@ -88,13 +116,22 @@ Finally, note that for all linking, the `type` of the linked fields must match.
    return ImageOutput(
        image = ImageField(image_type = image_type, image_name = image_name)
    )
```
The `invoke` function is the last portion of an invocation. It is provided an
`InvocationContext` which contains services to perform work as well as a
`session_id` for use as needed. It should return a class with output values that
derives from `BaseInvocationOutput`.
Before being called, the invocation will have all of its fields set from
defaults, inputs, and finally links (overriding in that order).
Assume that this invocation may be running simultaneously with other
invocations, may be running on another machine, or in other interesting
scenarios. If you need functionality, please provide it as a service in the
`InvocationServices` class, and make sure it can be overridden.
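As an illustration of that pattern (mirroring the abstract-base-plus-implementation
approach the image storage service in this commit uses), a hypothetical service
might look like the sketch below; all names here are invented for the example:

```py
from abc import ABC, abstractmethod

class NoteServiceBase(ABC):  # hypothetical service interface
    """Records short text notes about a session."""

    @abstractmethod
    def add(self, session_id: str, note: str) -> None:
        pass

class MemoryNoteService(NoteServiceBase):  # hypothetical default implementation
    def __init__(self):
        self.__notes: dict[str, list[str]] = {}

    def add(self, session_id: str, note: str) -> None:
        self.__notes.setdefault(session_id, []).append(note)
```

The concrete instance would then be attached to `InvocationServices` when the
application wires up its services, and a different subclass can be supplied in
other environments (tests, other machines, cloud storage, and so on).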
### Outputs
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
@ -102,4 +139,64 @@ class ImageOutput(BaseInvocationOutput):
image: ImageField = Field(default=None, description="The output image")
```
Output classes look like an invocation class without the invoke method. Prefer
to use an existing output class if available, and prefer to name inputs the same
as outputs when possible, to promote automatic invocation linking.
## Schema Generation
Invocation, output and related classes are used to generate an OpenAPI schema.
### Required Properties
The schema generation treats all properties with default values as optional. This
makes sense internally, but when using these classes via the generated schema, we
end up with e.g. the `ImageOutput` class having its `image` property marked as
optional.
We know that this property will always be present, so the additional logic
needed to always check if the property exists adds a lot of extraneous cruft.
To fix this, we can leverage `pydantic`'s
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
to mark properties that we know will always be present as required.
Here's that `ImageOutput` class, without the needed schema customisation:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
```
The generated OpenAPI schema, and all clients/types generated from it, will have
the `type` and `image` properties marked as optional, even though we know they
will always have a value by the time we can interact with them via the API.
Here's the same class, but with the schema customisation added:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
```
The resultant schema (and any API client or types generated from it) will now
have `type` as the string literal `"image"` and `image` as an `ImageField`
object, both marked as required.
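As a quick sanity check, the minimal standalone sketch below (pydantic v1 style, as
used here, with simplified stand-ins rather than the real base classes) shows the
customised schema listing both fields as required:

```python
from typing import Literal, Optional
from pydantic import BaseModel, Field

class ImageField(BaseModel):  # simplified stand-in
    image_type: str = Field(default="results", description="The type of the image")
    image_name: Optional[str] = Field(default=None, description="The name of the image")

class ImageOutput(BaseModel):  # stand-in for BaseInvocationOutput
    """Base class for invocations that output an image"""
    type: Literal["image"] = "image"
    image: ImageField = Field(default=None, description="The output image")

    class Config:
        schema_extra = {
            "required": ["type", "image"],
        }

# Without the Config block, "required" would be absent from the generated schema.
print(ImageOutput.schema()["required"])  # ['type', 'image']
```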
See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577>

View File

@ -50,7 +50,7 @@ subset that are currently installed are found in
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |

View File

@ -0,0 +1,14 @@
from pydantic import BaseModel, Field

from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata


class ImageResponse(BaseModel):
    """The response type for images"""

    image_type: ImageType = Field(description="The type of the image")
    image_name: str = Field(description="The name of the image")
    image_url: str = Field(description="The url of the image")
    thumbnail_url: str = Field(description="The url of the image's thumbnail")
    metadata: ImageMetadata = Field(description="The image's metadata")

View File

@ -1,18 +1,20 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
import uuid
from fastapi import Path, Request, UploadFile
from fastapi import Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
@ -53,14 +55,30 @@ async def upload_image(file: UploadFile, request: Request):
# Error opening the image
return Response(status_code=415)
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
return Response(
status_code=201,
headers={
"Location": request.url_for(
"get_image", image_type=ImageType.UPLOAD, image_name=filename
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
},
)
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(
image_type, page, per_page
)
return result

View File

@ -1,11 +1,17 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
import shutil
import asyncio
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
from invokeai.backend.args import Args
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -15,11 +21,9 @@ class VaeRepo(BaseModel):
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
@ -29,7 +33,6 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
@ -37,12 +40,29 @@ class DiffusersModelInfo(ModelInfo):
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get(
"/",
operation_id="list_models",
@ -54,108 +74,61 @@ async def list_models() -> ModelsList:
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)
# @socketio.on("searchForModels")
# def handle_search_models(search_folder: str):
# try:
# if not search_folder:
# socketio.emit(
# "foundModels",
# {"search_folder": None, "found_models": None},
# )
# else:
# (
# search_folder,
# found_models,
# ) = self.generate.model_manager.search_models(search_folder)
# socketio.emit(
# "foundModels",
# {"search_folder": search_folder, "found_models": found_models},
# )
# except Exception as e:
# self.handle_exceptions(e)
# print("\n")
@models_router.post(
"/",
operation_id="update_model",
responses={200: {"status": "success"}},
)
async def update_model(
model_request: CreateModelRequest
) -> CreateModelResponse:
""" Add Model """
model_request_info = model_request.info
info_dict = model_request_info.dict()
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
# @socketio.on("addNewModel")
# def handle_add_model(new_model_config: dict):
# try:
# model_name = new_model_config["name"]
# del new_model_config["name"]
# model_attributes = new_model_config
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name,
model_attributes=info_dict,
clobber=True,
)
# print(f">> Adding New Model: {model_name}")
return model_response
# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "newModelAdded",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": update,
# },
# )
# print(f">> New Model Added: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
@models_router.delete(
"/{model_name}",
operation_id="del_model",
responses={
204: {
"description": "Model deleted successfully"
},
404: {
"description": "Model not found"
}
},
)
async def delete_model(model_name: str) -> None:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
model_exists = model_name in model_names
# @socketio.on("deleteModel")
# def handle_delete_model(model_name: str):
# try:
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name)
# self.generate.model_manager.commit(opt.conf)
# updated_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelDeleted",
# {
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# check if model exists
print(f">> Checking for model {model_name}...")
if model_exists:
print(f">> Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
print(f">> Model Deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
print(f">> Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
@ -275,5 +248,4 @@ async def list_models() -> ModelsList:
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# self.handle_exceptions(e)
# except Exception as e:

View File

@ -6,7 +6,8 @@ from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_t
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
from ..invocations.image import ImageField
from ..models.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker

View File

@ -7,9 +7,9 @@ import numpy
from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageOutput
class CvInpaintInvocation(BaseInvocation):

View File

@ -8,12 +8,13 @@ from torch import Tensor
from pydantic import Field
from ..services.image_storage import ImageType
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.util import diffusers_step_callback_adapter, CanceledException
from ..models.exceptions import CanceledException
from ..util.step_callback import diffusers_step_callback_adapter
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())

View File

@ -7,20 +7,10 @@ import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..services.image_storage import ImageType
from ..models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_type: str = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
#fmt: off

View File

@ -6,7 +6,7 @@ from torch import Tensor
import torch
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
@ -110,7 +110,7 @@ class NoiseInvocation(BaseInvocation):
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(CUDA_DEVICE)
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'

View File

@ -3,10 +3,10 @@ from typing import Literal, Union
from pydantic import Field
from ..services.image_storage import ImageType
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""

View File

@ -5,10 +5,10 @@ from typing import Literal, Union
from pydantic import Field
from ..services.image_storage import ImageType
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageOutput
class UpscaleInvocation(BaseInvocation):

View File

View File

@ -0,0 +1,3 @@
class CanceledException(Exception):
    """Execution canceled by user."""
    pass

View File

@ -0,0 +1,26 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class ImageType(str, Enum):
    RESULT = "results"
    INTERMEDIATE = "intermediates"
    UPLOAD = "uploads"


class ImageField(BaseModel):
    """An image field used for passing image objects between invocations"""

    image_type: str = Field(
        default=ImageType.RESULT, description="The type of the image"
    )
    image_name: Optional[str] = Field(default=None, description="The name of the image")

    class Config:
        schema_extra = {
            "required": [
                "image_type",
                "image_name",
            ]
        }

View File

@ -0,0 +1,11 @@
from typing import Optional
from pydantic import BaseModel, Field
class ImageMetadata(BaseModel):
    """An image's metadata"""

    timestamp: float = Field(description="The creation timestamp of the image")
    width: int = Field(description="The width of the image in pixels")
    height: int = Field(description="The height of the image in pixels")

    # TODO: figure out metadata
    sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")

View File

@ -794,9 +794,6 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [

View File

@ -2,24 +2,25 @@
import datetime
import os
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Dict
from typing import Callable, Dict, List
from PIL.Image import Image
import PIL.Image as PILImage
from pydantic import BaseModel
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@ -27,9 +28,17 @@ class ImageStorageBase(ABC):
def get(self, image_type: ImageType, image_name: str) -> Image:
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
pass
@abstractmethod
@ -71,19 +80,74 @@ class DiskImageStorage(ImageStorageBase):
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{image_type.value}/{filename}",
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
timestamp=os.path.getctime(path),
width=img.width,
height=img.height,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = Image.open(image_path)
image = PILImage.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_type: ImageType, image_name: str) -> str:
path = os.path.join(self.__output_folder, image_type, image_name)
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", image_name
)
else:
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
@ -101,12 +165,19 @@ class DiskImageStorage(ImageStorageBase):
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)
thumbnail_path = self.get_path(image_type, image_name, True)
if os.path.exists(image_path):
os.remove(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
if os.path.exists(thumbnail_path):
os.remove(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image:
return None if image_name not in self.__cache else self.__cache[image_name]

View File

@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..util.util import CanceledException
from ..models.exceptions import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread

View File

View File

@ -1,14 +1,16 @@
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
class CanceledException(Exception):
pass
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
def fast_latents_step_callback(
sample: torch.Tensor,
step: int,
steps: int,
id: str,
context: InvocationContext,
):
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
@ -21,15 +23,12 @@ def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id:
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
{"width": width, "height": height, "dataURL": dataURL},
step,
steps,
)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
@ -37,6 +36,8 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
return fast_latents_step_callback(
progress_state.latents, progress_state.step, **kwargs
)
else:
return fast_latents_step_callback(*cb_args, **kwargs)

View File

@ -63,6 +63,7 @@ dependencies = [
"prompt-toolkit",
"pypatchmatch",
"pyreadline3",
"python-multipart==0.0.6",
"pytorch-lightning==1.7.7",
"realesrgan",
"requests==2.28.2",