mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add list_images endpoint
- add `list_images` endpoint at `GET api/v1/images` - extend `ImageStorageBase` with `list()` method, implemented it for `DiskImageStorage` - add `ImageReponse` class to for image responses, which includes urls, metadata - add `ImageMetadata` class (basically a stub at the moment) - uploaded images now named `"{uuid}_{timestamp}.png"` - add `models` modules. besides separating concerns more clearly, this helps to mitigate circular dependencies - improve thumbnail handling
This commit is contained in:
parent
54d9833db0
commit
34402cc46a
14
invokeai/app/api/models/images.py
Normal file
14
invokeai/app/api/models/images.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.models.image import ImageType
|
||||||
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResponse(BaseModel):
|
||||||
|
"""The response type for images"""
|
||||||
|
|
||||||
|
image_type: ImageType = Field(description="The type of the image")
|
||||||
|
image_name: str = Field(description="The name of the image")
|
||||||
|
image_url: str = Field(description="The url of the image")
|
||||||
|
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||||
|
metadata: ImageMetadata = Field(description="The image's metadata")
|
@ -1,18 +1,20 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import uuid
|
||||||
|
|
||||||
from fastapi import Path, Request, UploadFile
|
from fastapi import Path, Query, Request, UploadFile
|
||||||
from fastapi.responses import FileResponse, Response
|
from fastapi.responses import FileResponse, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from invokeai.app.api.models.images import ImageResponse
|
||||||
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
from ...services.image_storage import ImageType
|
from ...services.image_storage import ImageType
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||||
|
|
||||||
|
|
||||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||||
async def get_image(
|
async def get_image(
|
||||||
image_type: ImageType = Path(description="The type of image to get"),
|
image_type: ImageType = Path(description="The type of image to get"),
|
||||||
@ -53,14 +55,30 @@ async def upload_image(file: UploadFile, request: Request):
|
|||||||
# Error opening the image
|
# Error opening the image
|
||||||
return Response(status_code=415)
|
return Response(status_code=415)
|
||||||
|
|
||||||
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
status_code=201,
|
status_code=201,
|
||||||
headers={
|
headers={
|
||||||
"Location": request.url_for(
|
"Location": request.url_for(
|
||||||
"get_image", image_type=ImageType.UPLOAD, image_name=filename
|
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@images_router.get(
|
||||||
|
"/",
|
||||||
|
operation_id="list_images",
|
||||||
|
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||||
|
)
|
||||||
|
async def list_images(
|
||||||
|
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
|
||||||
|
page: int = Query(default=0, description="The page of images to get"),
|
||||||
|
per_page: int = Query(default=10, description="The number of images per page"),
|
||||||
|
) -> PaginatedResults[ImageResponse]:
|
||||||
|
"""Gets a list of images"""
|
||||||
|
result = ApiDependencies.invoker.services.images.list(
|
||||||
|
image_type, page, per_page
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
@ -6,7 +6,8 @@ from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_t
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from ..invocations.image import ImageField
|
|
||||||
|
from ..models.image import ImageField
|
||||||
from ..services.graph import GraphExecutionState
|
from ..services.graph import GraphExecutionState
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
|
|
||||||
|
@ -7,9 +7,9 @@ import numpy
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class CvInpaintInvocation(BaseInvocation):
|
class CvInpaintInvocation(BaseInvocation):
|
||||||
|
@ -8,12 +8,13 @@ from torch import Tensor
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageOutput
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ..util.util import diffusers_step_callback_adapter, CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
|
from ..util.step_callback import diffusers_step_callback_adapter
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
tuple(InvokeAIGenerator.schedulers())
|
tuple(InvokeAIGenerator.schedulers())
|
||||||
|
@ -7,20 +7,10 @@ import numpy
|
|||||||
from PIL import Image, ImageFilter, ImageOps
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from ..models.image import ImageField, ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
|
||||||
"""An image field used for passing image objects between invocations"""
|
|
||||||
|
|
||||||
image_type: str = Field(
|
|
||||||
default=ImageType.RESULT, description="The type of the image"
|
|
||||||
)
|
|
||||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
|
||||||
|
|
||||||
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output an image"""
|
"""Base class for invocations that output an image"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
|
@ -3,10 +3,10 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
class RestoreFaceInvocation(BaseInvocation):
|
||||||
"""Restores faces in an image."""
|
"""Restores faces in an image."""
|
||||||
|
@ -5,10 +5,10 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
class UpscaleInvocation(BaseInvocation):
|
||||||
|
0
invokeai/app/models/__init__.py
Normal file
0
invokeai/app/models/__init__.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
class CanceledException(Exception):
|
||||||
|
"""Execution canceled by user."""
|
||||||
|
pass
|
26
invokeai/app/models/image.py
Normal file
26
invokeai/app/models/image.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ImageType(str, Enum):
|
||||||
|
RESULT = "results"
|
||||||
|
INTERMEDIATE = "intermediates"
|
||||||
|
UPLOAD = "uploads"
|
||||||
|
|
||||||
|
|
||||||
|
class ImageField(BaseModel):
|
||||||
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
|
image_type: str = Field(
|
||||||
|
default=ImageType.RESULT, description="The type of the image"
|
||||||
|
)
|
||||||
|
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"required": [
|
||||||
|
"image_type",
|
||||||
|
"image_name",
|
||||||
|
]
|
||||||
|
}
|
11
invokeai/app/models/metadata.py
Normal file
11
invokeai/app/models/metadata.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class ImageMetadata(BaseModel):
|
||||||
|
"""An image's metadata"""
|
||||||
|
|
||||||
|
timestamp: float = Field(description="The creation timestamp of the image")
|
||||||
|
width: int = Field(description="The width of the image in pixels")
|
||||||
|
height: int = Field(description="The height of the image in pixels")
|
||||||
|
# TODO: figure out metadata
|
||||||
|
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
|
@ -2,24 +2,25 @@
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
|
from glob import glob
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict
|
from typing import Callable, Dict, List
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
import PIL.Image as PILImage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from invokeai.app.api.models.images import ImageResponse
|
||||||
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||||
|
|
||||||
from invokeai.backend.image_util import PngWriter
|
from invokeai.backend.image_util import PngWriter
|
||||||
|
|
||||||
|
|
||||||
class ImageType(str, Enum):
|
|
||||||
RESULT = "results"
|
|
||||||
INTERMEDIATE = "intermediates"
|
|
||||||
UPLOAD = "uploads"
|
|
||||||
|
|
||||||
|
|
||||||
class ImageStorageBase(ABC):
|
class ImageStorageBase(ABC):
|
||||||
"""Responsible for storing and retrieving images."""
|
"""Responsible for storing and retrieving images."""
|
||||||
|
|
||||||
@ -27,9 +28,17 @@ class ImageStorageBase(ABC):
|
|||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list(
|
||||||
|
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[ImageResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
def get_path(
|
||||||
|
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -71,19 +80,74 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def list(
|
||||||
|
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[ImageResponse]:
|
||||||
|
dir_path = os.path.join(self.__output_folder, image_type)
|
||||||
|
image_paths = glob(f"{dir_path}/*.png")
|
||||||
|
count = len(image_paths)
|
||||||
|
|
||||||
|
sorted_image_paths = sorted(
|
||||||
|
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
page_of_image_paths = sorted_image_paths[
|
||||||
|
page * per_page : (page + 1) * per_page
|
||||||
|
]
|
||||||
|
|
||||||
|
page_of_images: List[ImageResponse] = []
|
||||||
|
|
||||||
|
for path in page_of_image_paths:
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
img = PILImage.open(path)
|
||||||
|
page_of_images.append(
|
||||||
|
ImageResponse(
|
||||||
|
image_type=image_type.value,
|
||||||
|
image_name=filename,
|
||||||
|
# TODO: DiskImageStorage should not be building URLs...?
|
||||||
|
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||||
|
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||||
|
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||||
|
metadata=ImageMetadata(
|
||||||
|
timestamp=os.path.getctime(path),
|
||||||
|
width=img.width,
|
||||||
|
height=img.height,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
page_count_trunc = int(count / per_page)
|
||||||
|
page_count_mod = count % per_page
|
||||||
|
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||||
|
|
||||||
|
return PaginatedResults[ImageResponse](
|
||||||
|
items=page_of_images,
|
||||||
|
page=page,
|
||||||
|
pages=page_count,
|
||||||
|
per_page=per_page,
|
||||||
|
total=count,
|
||||||
|
)
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
cache_item = self.__get_cache(image_path)
|
cache_item = self.__get_cache(image_path)
|
||||||
if cache_item:
|
if cache_item:
|
||||||
return cache_item
|
return cache_item
|
||||||
|
|
||||||
image = Image.open(image_path)
|
image = PILImage.open(image_path)
|
||||||
self.__set_cache(image_path, image)
|
self.__set_cache(image_path, image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
def get_path(
|
||||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
if is_thumbnail:
|
||||||
|
path = os.path.join(
|
||||||
|
self.__output_folder, image_type, "thumbnails", image_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||||
@ -101,12 +165,19 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||||
if os.path.exists(image_path):
|
if os.path.exists(image_path):
|
||||||
os.remove(image_path)
|
os.remove(image_path)
|
||||||
|
|
||||||
if image_path in self.__cache:
|
if image_path in self.__cache:
|
||||||
del self.__cache[image_path]
|
del self.__cache[image_path]
|
||||||
|
|
||||||
|
if os.path.exists(thumbnail_path):
|
||||||
|
os.remove(thumbnail_path)
|
||||||
|
|
||||||
|
if thumbnail_path in self.__cache:
|
||||||
|
del self.__cache[thumbnail_path]
|
||||||
|
|
||||||
def __get_cache(self, image_name: str) -> Image:
|
def __get_cache(self, image_name: str) -> Image:
|
||||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from threading import Event, Thread
|
|||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
from ..util.util import CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
__invoker_thread: Thread
|
__invoker_thread: Thread
|
||||||
|
0
invokeai/app/util/__init__.py
Normal file
0
invokeai/app/util/__init__.py
Normal file
@ -1,14 +1,16 @@
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ...backend.util.util import image_to_dataURL
|
||||||
from ...backend.generator.base import Generator
|
from ...backend.generator.base import Generator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
|
||||||
class CanceledException(Exception):
|
def fast_latents_step_callback(
|
||||||
pass
|
sample: torch.Tensor,
|
||||||
|
step: int,
|
||||||
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
|
steps: int,
|
||||||
|
id: str,
|
||||||
|
context: InvocationContext,
|
||||||
|
):
|
||||||
# TODO: only output a preview image when requested
|
# TODO: only output a preview image when requested
|
||||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||||
|
|
||||||
@ -21,15 +23,12 @@ def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id:
|
|||||||
context.services.events.emit_generator_progress(
|
context.services.events.emit_generator_progress(
|
||||||
context.graph_execution_state_id,
|
context.graph_execution_state_id,
|
||||||
id,
|
id,
|
||||||
{
|
{"width": width, "height": height, "dataURL": dataURL},
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"dataURL": dataURL
|
|
||||||
},
|
|
||||||
step,
|
step,
|
||||||
steps,
|
steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||||
"""
|
"""
|
||||||
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
||||||
@ -37,6 +36,8 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||||
progress_state: PipelineIntermediateState = cb_args[0]
|
progress_state: PipelineIntermediateState = cb_args[0]
|
||||||
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
|
return fast_latents_step_callback(
|
||||||
|
progress_state.latents, progress_state.step, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return fast_latents_step_callback(*cb_args, **kwargs)
|
return fast_latents_step_callback(*cb_args, **kwargs)
|
Loading…
x
Reference in New Issue
Block a user