Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 5de3c41d19 (parent f071b03ceb)

    feat(nodes): add metadata handling
@@ -4,6 +4,7 @@ from logging import Logger
 import os
 from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService
+from invokeai.app.services.metadata import CoreMetadataService
 from invokeai.app.services.urls import LocalUrlService
 from invokeai.backend.util.logging import InvokeAILogger
 
@@ -18,7 +19,6 @@ from ..services.invocation_services import InvocationServices
 from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
-from ..services.metadata import PngMetadataService
 from .events import FastAPIEventService
 
 
@@ -59,7 +59,7 @@ class ApiDependencies:
            DiskLatentsStorage(f"{output_folder}/latents")
        )

-       metadata = PngMetadataService()
+       metadata = CoreMetadataService()

        urls = LocalUrlService()

@@ -80,6 +80,7 @@ class ApiDependencies:
            metadata=metadata,
            url=urls,
            logger=logger,
+           graph_execution_manager=graph_execution_manager,
        )

        services = InvocationServices(
@@ -2,7 +2,6 @@ from typing import Optional
 from pydantic import BaseModel, Field
 
 from invokeai.app.models.image import ImageType
-from invokeai.app.services.metadata import InvokeAIMetadata
 
 
 class ImageResponseMetadata(BaseModel):
@@ -11,9 +10,9 @@ class ImageResponseMetadata(BaseModel):
    created: int = Field(description="The creation timestamp of the image")
    width: int = Field(description="The width of the image in pixels")
    height: int = Field(description="The height of the image in pixels")
-   invokeai: Optional[InvokeAIMetadata] = Field(
-       description="The image's InvokeAI-specific metadata"
-   )
+   # invokeai: Optional[InvokeAIMetadata] = Field(
+   #     description="The image's InvokeAI-specific metadata"
+   # )


 class ImageResponse(BaseModel):
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
 import torch
 
 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.models.image import ImageCategory
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
@@ -356,20 +357,30 @@ class LatentsToImageInvocation(BaseInvocation):
        np_image = model.decode_latents(latents)
        image = model.numpy_to_pil(np_image)[0]

-       image_type = ImageType.RESULT
-       image_name = context.services.images.create_name(
-           context.graph_execution_state_id, self.id
-       )
-
-       metadata = context.services.metadata.build_metadata(
-           session_id=context.graph_execution_state_id, node=self
-       )
-
-       torch.cuda.empty_cache()
-
-       context.services.images.save(image_type, image_name, image, metadata)
+       # image_type = ImageType.RESULT
+       # image_name = context.services.images.create_name(
+       #     context.graph_execution_state_id, self.id
+       # )
+
+       # metadata = context.services.metadata.build_metadata(
+       #     session_id=context.graph_execution_state_id, node=self
+       # )
+
+       # torch.cuda.empty_cache()
+
+       # context.services.images.save(image_type, image_name, image, metadata)
+       image_dto = context.services.images_new.create(
+           image=image,
+           image_type=ImageType.RESULT,
+           image_category=ImageCategory.IMAGE,
+           session_id=context.graph_execution_state_id,
+           node_id=self.id,
+       )
+
        return build_image_output(
-           image_type=image_type, image_name=image_name, image=image
+           image_type=image_dto.image_type,
+           image_name=image_dto.image_name,
+           image=image,
        )
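
The shape of this change: the invocation no longer names the image, builds metadata, or saves the file itself; it hands the PIL image to the images service and reads the storage details back off the returned DTO. A minimal hedged sketch of that hand-off (the wrapper function and the bare `context` argument are hypothetical; the service calls and field names come from this diff):

# Sketch of the new save path used by LatentsToImageInvocation.
from invokeai.app.models.image import ImageCategory, ImageType

def save_result_image(context, image, node_id):
    # The images service persists the image, derives metadata from the
    # execution graph, and returns a DTO describing what was stored.
    image_dto = context.services.images_new.create(
        image=image,
        image_type=ImageType.RESULT,
        image_category=ImageCategory.IMAGE,
        session_id=context.graph_execution_state_id,
        node_id=node_id,
    )
    # Downstream code reads names/types off the DTO instead of computing them.
    return image_dto.image_type, image_dto.image_name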
@@ -1,5 +1,5 @@
 from typing import Optional
-from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
+from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
 
 
 class ImageMetadata(BaseModel):
@@ -8,11 +8,24 @@ class ImageMetadata(BaseModel):
 
     Also includes any metadata from the image's PNG tEXt chunks.
 
-    Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
+    Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
+    of a given node.
 
     Full metadata may be accessed by querying for the session in the `graph_executions` table.
     """
 
+    class Config:
+        extra = Extra.allow
+        """
+        This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
+        won't add any fields that are not already defined, but a different metadata service
+        implementation might.
+        """
+
+    type: Optional[StrictStr] = Field(
+        default=None,
+        description="The type of the ancestor node of the image output node.",
+    )
    positive_conditioning: Optional[StrictStr] = Field(
        default=None, description="The positive conditioning."
    )
@@ -20,10 +33,10 @@ class ImageMetadata(BaseModel):
        default=None, description="The negative conditioning."
    )
    width: Optional[StrictInt] = Field(
-       default=None, description="Width of the image/tensor in pixels."
+       default=None, description="Width of the image/latents in pixels."
    )
    height: Optional[StrictInt] = Field(
-       default=None, description="Height of the image/tensor in pixels."
+       default=None, description="Height of the image/latents in pixels."
    )
    seed: Optional[StrictInt] = Field(
        default=None, description="The seed used for noise generation."
@@ -42,18 +55,21 @@ class ImageMetadata(BaseModel):
    )
    strength: Optional[StrictFloat] = Field(
        default=None,
-       description="The strength used for image-to-image/tensor-to-tensor.",
+       description="The strength used for image-to-image/latents-to-latents.",
    )
-   image: Optional[StrictStr] = Field(
-       default=None, description="The ID of the initial image."
+   latents: Optional[StrictStr] = Field(
+       default=None, description="The ID of the initial latents."
    )
-   tensor: Optional[StrictStr] = Field(
-       default=None, description="The ID of the initial tensor."
+   vae: Optional[StrictStr] = Field(
+       default=None, description="The VAE used for decoding."
+   )
+   unet: Optional[StrictStr] = Field(
+       default=None, description="The UNet used for inference."
+   )
+   clip: Optional[StrictStr] = Field(
+       default=None, description="The CLIP Encoder used for conditioning."
    )
-   # Pending model refactor:
-   # vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
-   # unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
-   # clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
    extra: Optional[StrictStr] = Field(
-       default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
+       default=None,
+       description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
    )
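
A quick illustration of what `Extra.allow` buys here. This is a hedged sketch against pydantic v1 semantics, which the `Extra` import and the `parse_raw` usage elsewhere in this commit imply; the model name is a stand-in, not the real class:

# Minimal sketch: with extra = Extra.allow, unknown keys survive a round-trip
# instead of being dropped or rejected.
from typing import Optional
from pydantic import BaseModel, Extra, StrictInt

class ExampleMetadata(BaseModel):
    class Config:
        extra = Extra.allow

    seed: Optional[StrictInt] = None

m = ExampleMetadata.parse_raw('{"seed": 123, "custom_sampler": "foo"}')
print(m.seed)    # 123
print(m.dict())  # {'seed': 123, 'custom_sampler': 'foo'}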
@@ -713,6 +713,13 @@ class Graph(BaseModel):
        g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
        return g

+   def nx_graph_with_data(self) -> nx.DiGraph:
+       """Returns a NetworkX DiGraph representing the data and layout of this graph"""
+       g = nx.DiGraph()
+       g.add_nodes_from([n for n in self.nodes.items()])
+       g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
+       return g
+
    def nx_graph_flat(
        self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
    ) -> nx.DiGraph:
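
The new helper mirrors the existing `nx_graph` but keeps node data attached, which is what the metadata service later reads via `G.nodes[node_id]`. A small hedged illustration of the underlying networkx pattern, using plain dicts in place of real invocation nodes:

# Illustration only: plain dicts stand in for invocation nodes.
import networkx as nx

nodes = {
    "noise_1": {"id": "noise_1", "type": "noise", "seed": 42, "width": 512, "height": 512},
    "t2l_1": {"id": "t2l_1", "type": "t2l", "steps": 30},
}
g = nx.DiGraph()
g.add_nodes_from(nodes.items())          # (node_id, attr_dict) tuples keep the data
g.add_edges_from([("noise_1", "t2l_1")])

print(g.nodes["t2l_1"].get("type"))      # "t2l"
print(list(g.predecessors("t2l_1")))     # ["noise_1"]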
@@ -6,11 +6,11 @@ from queue import Queue
 from typing import Dict, Optional
 
 from PIL.Image import Image as PILImageType
-from PIL import Image
-from PIL.PngImagePlugin import PngInfo
+from PIL import Image, PngImagePlugin
 from send2trash import send2trash
 
 from invokeai.app.models.image import ImageType
+from invokeai.app.models.metadata import ImageMetadata
 from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
 
 
@@ -54,7 +54,7 @@ class ImageFileStorageBase(ABC):
        image: PILImageType,
        image_type: ImageType,
        image_name: str,
-       pnginfo: Optional[PngInfo] = None,
+       metadata: Optional[ImageMetadata] = None,
        thumbnail_size: int = 256,
    ) -> None:
        """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@@ -109,12 +109,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
        image: PILImageType,
        image_type: ImageType,
        image_name: str,
-       pnginfo: Optional[PngInfo] = None,
+       metadata: Optional[ImageMetadata] = None,
        thumbnail_size: int = 256,
    ) -> None:
        try:
            image_path = self.get_path(image_type, image_name)
-           image.save(image_path, "PNG", pnginfo=pnginfo)
+
+           if metadata is not None:
+               pnginfo = PngImagePlugin.PngInfo()
+               pnginfo.add_text("invokeai", metadata.json())
+               image.save(image_path, "PNG", pnginfo=pnginfo)
+           else:
+               image.save(image_path, "PNG")
+
            thumbnail_name = get_thumbnail_name(image_name)
            thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
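
The metadata now travels inside the PNG itself as a tEXt chunk keyed `invokeai`, serialized by pydantic's `.json()`. A hedged round-trip sketch using only PIL APIs (the file path and payload are illustrative, not taken from the codebase):

# Round-trip sketch: embed JSON metadata in a PNG tEXt chunk, then read it back.
import json
from PIL import Image, PngImagePlugin

img = Image.new("RGB", (64, 64))

pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", json.dumps({"seed": 42, "steps": 30}))
img.save("example.png", "PNG", pnginfo=pnginfo)   # illustrative path

reloaded = Image.open("example.png")
print(json.loads(reloaded.info["invokeai"]))      # {'seed': 42, 'steps': 30}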
@@ -1,10 +1,9 @@
 from abc import ABC, abstractmethod
 import json
 from logging import Logger
-from typing import Optional, Union
+from typing import Optional, TYPE_CHECKING, Union
 import uuid
 from PIL.Image import Image as PILImageType
-from PIL import PngImagePlugin
 
 from invokeai.app.models.image import ImageCategory, ImageType
 from invokeai.app.models.metadata import ImageMetadata
@@ -17,12 +16,16 @@ from invokeai.app.services.models.image_record import (
    image_record_to_dto,
 )
 from invokeai.app.services.image_file_storage import ImageFileStorageBase
-from invokeai.app.services.item_storage import PaginatedResults
+from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
 from invokeai.app.services.metadata import MetadataServiceBase
 from invokeai.app.services.urls import UrlServiceBase
 from invokeai.app.util.misc import get_iso_timestamp
 
 
+if TYPE_CHECKING:
+    from invokeai.app.services.graph import GraphExecutionState
+
+
 class ImageServiceABC(ABC):
    """
    High-level service for image management.
@@ -59,7 +62,9 @@ class ImageServiceABC:
        pass

    @abstractmethod
-   def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str:
+   def get_url(
+       self, image_type: ImageType, image_name: str, thumbnail: bool = False
+   ) -> str:
        """Gets an image's or thumbnail's URL"""
        pass

@@ -113,6 +118,7 @@ class ImageServiceDependencies:
    metadata: MetadataServiceBase
    urls: UrlServiceBase
    logger: Logger
+   graph_execution_manager: ItemStorageABC["GraphExecutionState"]

    def __init__(
        self,
@@ -121,12 +127,14 @@ class ImageServiceDependencies:
        metadata: MetadataServiceBase,
        url: UrlServiceBase,
        logger: Logger,
+       graph_execution_manager: ItemStorageABC["GraphExecutionState"],
    ):
        self.records = image_record_storage
        self.files = image_file_storage
        self.metadata = metadata
        self.urls = url
        self.logger = logger
+       self.graph_execution_manager = graph_execution_manager


 class ImageService(ImageServiceABC):
@@ -139,6 +147,7 @@ class ImageService(ImageServiceABC):
        metadata: MetadataServiceBase,
        url: UrlServiceBase,
        logger: Logger,
+       graph_execution_manager: ItemStorageABC["GraphExecutionState"],
    ):
        self._services = ImageServiceDependencies(
            image_record_storage=image_record_storage,
@@ -146,6 +155,7 @@ class ImageService(ImageServiceABC):
            metadata=metadata,
            url=url,
            logger=logger,
+           graph_execution_manager=graph_execution_manager,
        )

    def create(
@@ -155,7 +165,6 @@ class ImageService(ImageServiceABC):
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
-       metadata: Optional[ImageMetadata] = None,
    ) -> ImageDTO:
        image_name = self._create_image_name(
            image_type=image_type,
@@ -165,12 +174,7 @@ class ImageService(ImageServiceABC):
        )

        timestamp = get_iso_timestamp()
+       metadata = self._get_metadata(session_id, node_id)

-       if metadata is not None:
-           pnginfo = PngImagePlugin.PngInfo()
-           pnginfo.add_text("invokeai", json.dumps(metadata))
-       else:
-           pnginfo = None
-
        try:
            # TODO: Consider using a transaction here to ensure consistency between storage and database
@@ -178,7 +182,7 @@ class ImageService(ImageServiceABC):
                image_type=image_type,
                image_name=image_name,
                image=image,
-               pnginfo=pnginfo,
+               metadata=metadata,
            )

            self._services.records.save(
@@ -237,24 +241,6 @@ class ImageService(ImageServiceABC):
            self._services.logger.error("Problem getting image record")
            raise e

-   def get_path(
-       self, image_type: ImageType, image_name: str, thumbnail: bool = False
-   ) -> str:
-       try:
-           return self._services.files.get_path(image_type, image_name, thumbnail)
-       except Exception as e:
-           self._services.logger.error("Problem getting image path")
-           raise e
-
-   def get_url(
-       self, image_type: ImageType, image_name: str, thumbnail: bool = False
-   ) -> str:
-       try:
-           return self._services.urls.get_image_url(image_type, image_name, thumbnail)
-       except Exception as e:
-           self._services.logger.error("Problem getting image path")
-           raise e
-
    def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
        try:
            image_record = self._services.records.get(image_type, image_name)
@@ -273,6 +259,24 @@ class ImageService(ImageServiceABC):
            self._services.logger.error("Problem getting image DTO")
            raise e

+   def get_path(
+       self, image_type: ImageType, image_name: str, thumbnail: bool = False
+   ) -> str:
+       try:
+           return self._services.files.get_path(image_type, image_name, thumbnail)
+       except Exception as e:
+           self._services.logger.error("Problem getting image path")
+           raise e
+
+   def get_url(
+       self, image_type: ImageType, image_name: str, thumbnail: bool = False
+   ) -> str:
+       try:
+           return self._services.urls.get_image_url(image_type, image_name, thumbnail)
+       except Exception as e:
+           self._services.logger.error("Problem getting image path")
+           raise e
+
    def get_many(
        self,
        image_type: ImageType,
@@ -353,3 +357,15 @@ class ImageService(ImageServiceABC):
            return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"

        return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
+
+   def _get_metadata(
+       self, session_id: Optional[str] = None, node_id: Optional[str] = None
+   ) -> Union[ImageMetadata, None]:
+       """Get the metadata for a node."""
+       metadata = None
+
+       if node_id is not None and session_id is not None:
+           session = self._services.graph_execution_manager.get(session_id)
+           metadata = self._services.metadata.create_image_metadata(session, node_id)
+
+       return metadata
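
Putting the pieces together: `ImageService.create` now derives metadata itself from the session's execution graph instead of accepting it from the caller, then passes both the image and the metadata to file storage, which embeds it in the PNG. A hedged sketch of that flow (the wrapper function and the literal image name are hypothetical; the service calls match this diff):

# Conceptual flow of ImageService.create after this change.
def create_flow(services, image, image_type, session_id, node_id):
    # 1. Look up the session and build core metadata from its graph.
    session = services.graph_execution_manager.get(session_id)
    metadata = services.metadata.create_image_metadata(session, node_id)

    # 2. File storage embeds the metadata into the PNG it writes.
    services.files.save(
        image=image,
        image_type=image_type,
        image_name="example.png",  # real code derives the name via _create_image_name
        metadata=metadata,
    )
    return metadata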
@@ -1,295 +1,142 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, TypedDict
-from PIL import Image, PngImagePlugin
-from pydantic import BaseModel
+from typing import Any, Union
+import networkx as nx
 
-from invokeai.app.models.image import ImageType, is_image_type
+from invokeai.app.models.metadata import ImageMetadata
+from invokeai.app.services.graph import Edge, Graph, GraphExecutionState
 
 
-class MetadataImageField(TypedDict):
-    """Pydantic-less ImageField, used for metadata parsing."""
-
-    image_type: ImageType
-    image_name: str
-
-
-class MetadataLatentsField(TypedDict):
-    """Pydantic-less LatentsField, used for metadata parsing."""
-
-    latents_name: str
-
-
-class MetadataColorField(TypedDict):
-    """Pydantic-less ColorField, used for metadata parsing"""
-
-    r: int
-    g: int
-    b: int
-    a: int
-
-
-# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
-NodeMetadata = Dict[
-    str,
-    None
-    | str
-    | int
-    | float
-    | bool
-    | MetadataImageField
-    | MetadataLatentsField
-    | MetadataColorField,
-]
-
-
-class InvokeAIMetadata(TypedDict, total=False):
-    """InvokeAI-specific metadata format."""
-
-    session_id: Optional[str]
-    node: Optional[NodeMetadata]
-
-
-def build_invokeai_metadata_pnginfo(
-    metadata: InvokeAIMetadata | None,
-) -> PngImagePlugin.PngInfo:
-    """Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
-    pnginfo = PngImagePlugin.PngInfo()
-
-    if metadata is not None:
-        pnginfo.add_text("invokeai", json.dumps(metadata))
-
-    return pnginfo
-
-
 class MetadataServiceBase(ABC):
-    @abstractmethod
-    def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
-        """Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
-        pass
+    """Handles building metadata for nodes, images, and outputs."""
 
     @abstractmethod
-    def build_metadata(
-        self, session_id: str, node: BaseModel
-    ) -> InvokeAIMetadata | None:
-        """Builds an InvokeAIMetadata object"""
+    def create_image_metadata(
+        self, session: GraphExecutionState, node_id: str
+    ) -> ImageMetadata:
+        """Builds an ImageMetadata object for a node."""
        pass

-    # @abstractmethod
-    # def create_metadata(self, session_id: str, node_id: str) -> dict:
-    #     """Creates metadata for a result"""
-    #     pass
-

-class PngMetadataService(MetadataServiceBase):
-    """Handles loading and building metadata for images."""
+class CoreMetadataService(MetadataServiceBase):
+    _ANCESTOR_TYPES = ["t2l", "l2l"]
+    """The ancestor types that contain the core metadata"""

-    # TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
-    def _load_metadata(self, image: Image.Image) -> dict | None:
-        """Loads a specific info entry from a PIL Image."""
+    _ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
+    """The core metadata parameters in the ancestor types"""

-        try:
-            info = image.info.get("invokeai")
+    _NOISE_FIELDS = ["seed", "width", "height"]
+    """The core metadata parameters in the noise node"""

-            if type(info) is not str:
-                return None
-
-            loaded_metadata = json.loads(info)
-
-            if type(loaded_metadata) is not dict:
-                return None
-
-            if len(loaded_metadata.items()) == 0:
-                return None
-
-            return loaded_metadata
-        except:
-            return None
-
-    def get_metadata(self, image: Image.Image) -> dict | None:
-        """Retrieves an image's metadata as a dict"""
-        loaded_metadata = self._load_metadata(image)
-
-        return loaded_metadata
-
-    def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
-        metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
+    def create_image_metadata(
+        self, session: GraphExecutionState, node_id: str
+    ) -> ImageMetadata:
+        metadata = self._build_metadata_from_graph(session, node_id)

        return metadata

+    def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
+        """
+        Finds the id of the nearest ancestor (of a valid type) of a given node.
+
+        Parameters:
+        G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
+        have the same data as the execution graph.
+        node_id (str): The ID of the node.
+
+        Returns:
+        str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
+        """
+
+        # Retrieve the node from the graph
+        node = G.nodes[node_id]
+
+        # If the node type is one of the core metadata node types, return its id
+        if node.get("type") in self._ANCESTOR_TYPES:
+            return node.get("id")
+
+        # Else, look for the ancestor in the predecessor nodes
+        for predecessor in G.predecessors(node_id):
+            result = self._find_nearest_ancestor(G, predecessor)
+            if result:
+                return result
+
+        # If there are no valid ancestors, return None
+        return None
+
+    def _get_additional_metadata(
+        self, graph: Graph, node_id: str
+    ) -> Union[dict[str, Any], None]:
+        """
+        Returns additional metadata for a given node.
+
+        Parameters:
+        graph (Graph): The execution graph.
+        node_id (str): The ID of the node.
+
+        Returns:
+        dict[str, Any] | None: A dictionary of additional metadata.
+        """
+
+        metadata = {}
+
+        # Iterate over all edges in the graph
+        for edge in graph.edges:
+            dest_node_id = edge.destination.node_id
+            dest_field = edge.destination.field
+            source_node_dict = graph.nodes[edge.source.node_id].dict()
+
+            # If the destination node ID matches the given node ID, gather necessary metadata
+            if dest_node_id == node_id:
+                # If the destination field is 'positive_conditioning', add the 'prompt' from the source node
+                if dest_field == "positive_conditioning":
+                    metadata["positive_conditioning"] = source_node_dict.get("prompt")
+                # If the destination field is 'negative_conditioning', add the 'prompt' from the source node
+                if dest_field == "negative_conditioning":
+                    metadata["negative_conditioning"] = source_node_dict.get("prompt")
+                # If the destination field is 'noise', add the core noise fields from the source node
+                if dest_field == "noise":
+                    for field in self._NOISE_FIELDS:
+                        metadata[field] = source_node_dict.get(field)
+        return metadata
+
+    def _build_metadata_from_graph(
+        self, session: GraphExecutionState, node_id: str
+    ) -> ImageMetadata:
+        """
+        Builds an ImageMetadata object for a node.
+
+        Parameters:
+        session (GraphExecutionState): The session.
+        node_id (str): The ID of the node.
+
+        Returns:
+        ImageMetadata: The metadata for the node.
+        """
+
+        graph = session.execution_graph
+
+        # Find the nearest ancestor of the given node
+        ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
+
+        # If no ancestor was found, return an empty ImageMetadata object
+        if ancestor_id is None:
+            return ImageMetadata()
+
+        ancestor_node = graph.get_node(ancestor_id)
+
+        ancestor_metadata = {
+            param: val
+            for param, val in ancestor_node.dict().items()
+            if param in self._ANCESTOR_PARAMS
+        }
+
+        # Get additional metadata related to the ancestor
+        addl_metadata = self._get_additional_metadata(graph, ancestor_id)
+
+        # If additional metadata was found, add it to the main metadata
+        if addl_metadata is not None:
+            ancestor_metadata.update(addl_metadata)
+
+        return ImageMetadata(**ancestor_metadata)
-
-from enum import Enum
-
-from abc import ABC, abstractmethod
-import json
-import sqlite3
-from threading import Lock
-from typing import Any, Union
-
-import networkx as nx
-
-from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as
-from invokeai.app.invocations.image import ImageOutput
-from invokeai.app.services.graph import Edge, GraphExecutionState
-from invokeai.app.invocations.latent import LatentsOutput
-from invokeai.app.services.item_storage import PaginatedResults
-from invokeai.app.util.misc import get_timestamp
-
-
-class ResultType(str, Enum):
-    image_output = "image_output"
-    latents_output = "latents_output"
-
-
-class Result(BaseModel):
-    """A session result"""
-
-    id: str = Field(description="Result ID")
-    session_id: str = Field(description="Session ID")
-    node_id: str = Field(description="Node ID")
-    data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")
-
-
-class ResultWithSession(BaseModel):
-    """A result with its session"""
-
-    result: Result = Field(description="The result")
-    session: GraphExecutionState = Field(description="The session")
-
-
-# # Create a directed graph
-# from typing import Any, TypedDict, Union
-# from networkx import DiGraph
-# import networkx as nx
-# import json
-
-# # We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter
-# # model used by the system, because we may be a graph in an old format. We can, however, use the
-# # Edge model, because the edge format does not change.
-# class LooseGraph(BaseModel):
-#     id: str
-#     nodes: dict[str, dict[str, Any]]
-#     edges: list[Edge]
-
-# # An intermediate type used during parsing
-# class NearestAncestor(TypedDict):
-#     node_id: str
-#     metadata: dict[str, Any]
-
-# # The ancestor types that contain the core metadata
-# ANCESTOR_TYPES = ['t2l', 'l2l']
-
-# # The core metadata parameters in the ancestor types
-# ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength']
-
-# # The core metadata parameters in the noise node
-# NOISE_FIELDS = ['seed', 'width', 'height']
-
-# # Find nearest t2l or l2l ancestor from a given l2i node
-# def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
-#     """Returns metadata for the nearest ancestor of a given node.
-
-#     Parameters:
-#     G (DiGraph): A directed graph.
-#     node_id (str): The ID of the starting node.
-
-#     Returns:
-#     NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
-#     """
-
-#     # Retrieve the node from the graph
-#     node = G.nodes[node_id]
-
-#     # If the node type is one of the core metadata node types, gather necessary metadata and return
-#     if node.get('type') in ANCESTOR_TYPES:
-#         parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
-#         return NearestAncestor(node_id=node_id, metadata=parsed_metadata)
-
-#     # Else, look for the ancestor in the predecessor nodes
-#     for predecessor in G.predecessors(node_id):
-#         result = find_nearest_ancestor(G, predecessor)
-#         if result:
-#             return result
-
-#     # If there are no valid ancestors, return None
-#     return None
-
-
-# def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
-#     """Collects additional metadata from nodes connected to a given node.
-
-#     Parameters:
-#     graph (LooseGraph): The graph.
-#     node_id (str): The ID of the node.
-
-#     Returns:
-#     dict | None: A dictionary containing additional metadata.
-#     """
-
-#     metadata = {}
-
-#     # Iterate over all edges in the graph
-#     for edge in graph.edges:
-#         dest_node_id = edge.destination.node_id
-#         dest_field = edge.destination.field
-#         source_node = graph.nodes[edge.source.node_id]
-
-#         # If the destination node ID matches the given node ID, gather necessary metadata
-#         if dest_node_id == node_id:
-#             # If the destination field is 'positive_conditioning', add the 'prompt' from the source node
-#             if dest_field == 'positive_conditioning':
-#                 metadata['positive_conditioning'] = source_node.get('prompt')
-#             # If the destination field is 'negative_conditioning', add the 'prompt' from the source node
-#             if dest_field == 'negative_conditioning':
-#                 metadata['negative_conditioning'] = source_node.get('prompt')
-#             # If the destination field is 'noise', add the core noise fields from the source node
-#             if dest_field == 'noise':
-#                 for field in NOISE_FIELDS:
-#                     metadata[field] = source_node.get(field)
-#     return metadata
-
-
-# def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]:
-#     """Builds the core metadata for a given node.
-
-#     Parameters:
-#     graph_raw (str): The graph structure as a raw string.
-#     node_id (str): The ID of the node.
-
-#     Returns:
-#     dict | None: A dictionary containing core metadata.
-#     """
-
-#     # Create a directed graph to facilitate traversal
-#     G = nx.DiGraph()
-
-#     # Convert the raw graph string into a JSON object
-#     graph = parse_obj_as(LooseGraph, graph_raw)
-
-#     # Add nodes and edges to the graph
-#     for node_id, node_data in graph.nodes.items():
-#         G.add_node(node_id, **node_data)
-#     for edge in graph.edges:
-#         G.add_edge(edge.source.node_id, edge.destination.node_id)
-
-#     # Find the nearest ancestor of the given node
-#     ancestor = find_nearest_ancestor(G, node_id)
-
-#     # If no ancestor was found, return None
-#     if ancestor is None:
-#         return None
-
-#     metadata = ancestor['metadata']
-#     ancestor_id = ancestor['node_id']
-
-#     # Get additional metadata related to the ancestor
-#     addl_metadata = get_additional_metadata(graph, ancestor_id)
-
-#     # If additional metadata was found, add it to the main metadata
-#     if addl_metadata is not None:
-#         metadata.update(addl_metadata)
-
-#     return metadata
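
The behavior of `CoreMetadataService` hinges on the recursive nearest-ancestor search over the execution graph. A hedged, self-contained illustration of that traversal on a toy graph (plain dicts stand in for invocation nodes; the real service walks `Graph.nx_graph_with_data()`):

# Toy illustration of the nearest-ancestor search used by CoreMetadataService.
from typing import Union
import networkx as nx

ANCESTOR_TYPES = ["t2l", "l2l"]

def find_nearest_ancestor(G: nx.DiGraph, node_id: str) -> Union[str, None]:
    node = G.nodes[node_id]
    if node.get("type") in ANCESTOR_TYPES:
        return node.get("id")
    for predecessor in G.predecessors(node_id):
        result = find_nearest_ancestor(G, predecessor)
        if result:
            return result
    return None

g = nx.DiGraph()
g.add_nodes_from(
    [
        ("noise_1", {"id": "noise_1", "type": "noise"}),
        ("t2l_1", {"id": "t2l_1", "type": "t2l", "steps": 30}),
        ("l2i_1", {"id": "l2i_1", "type": "l2i"}),
    ]
)
g.add_edges_from([("noise_1", "t2l_1"), ("t2l_1", "l2i_1")])

print(find_nearest_ancestor(g, "l2i_1"))  # "t2l_1"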
@@ -62,9 +62,12 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
 
    image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))

-   raw_metadata = image_dict.get("metadata", "{}")
+   raw_metadata = image_dict.get("metadata")

-   metadata = ImageMetadata.parse_raw(raw_metadata)
+   if raw_metadata is not None:
+       metadata = ImageMetadata.parse_raw(raw_metadata)
+   else:
+       metadata = None

    return ImageRecord(
        image_name=image_dict.get("id", "unknown"),
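
A note on the behavior change: previously a missing `metadata` column defaulted to `"{}"` and always yielded an empty `ImageMetadata`; now a missing value stays `None`. A minimal pydantic v1 sketch of the distinction (the model name is a stand-in for `ImageMetadata`):

# Sketch: distinguishing "no metadata stored" from "empty metadata".
from typing import Optional
from pydantic import BaseModel

class ExampleImageMetadata(BaseModel):  # stand-in for ImageMetadata
    seed: Optional[int] = None

raw = None  # e.g. a NULL metadata column in sqlite
metadata = ExampleImageMetadata.parse_raw(raw) if raw is not None else None
print(metadata)  # None, rather than ExampleImageMetadata(seed=None)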