mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db,nodes,api): refactor metadata
Metadata for the Linear UI is now sneakily provided via a `MetadataAccumulator` node, which the client populates / hooks up while building the graph. Additionally, we provide the unexpanded graph with the metadata API response. Both of these are embedded into the PNGs. - Remove `metadata` from `ImageDTO` - Split up the `images/` routes to accomodate this; metadata is only retrieved per-image - `images/{image_name}` now gets the DTO - `images/{image_name}/metadata` gets the new metadata - `images/{image_name}/full` gets the full-sized image file - Remove old metadata service - Add `MetadataAccumulator` node, `CoreMetadataField`, hook up to `LatentsToImage` node - Add `get_raw()` method to `ItemStorage`, retrieves the row from DB as a string, no pydantic parsing - Update `images`related services to handle storing and retrieving the new metadata - Add `get_metadata_graph_from_raw_session` which extracts the `graph` from `session` without needing to hydrate the session in pydantic, in preparation for providing it as metadata; also removes all references to the `MetadataAccumulator` node
This commit is contained in:
parent
eb0d55263b
commit
50bef87da7
@ -13,7 +13,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
|||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.metadata import CoreMetadataService
|
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -75,7 +74,6 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
metadata = CoreMetadataService()
|
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
@ -111,7 +109,6 @@ class ApiDependencies:
|
|||||||
board_image_record_storage=board_image_record_storage,
|
board_image_record_storage=board_image_record_storage,
|
||||||
image_record_storage=image_record_storage,
|
image_record_storage=image_record_storage,
|
||||||
image_file_storage=image_file_storage,
|
image_file_storage=image_file_storage,
|
||||||
metadata=metadata,
|
|
||||||
url=urls,
|
url=urls,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
names=names,
|
names=names,
|
||||||
|
@ -1,20 +1,19 @@
|
|||||||
import io
|
import io
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi import (Body, HTTPException, Path, Query, Request, Response,
|
||||||
|
UploadFile)
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from invokeai.app.models.image import (
|
|
||||||
ImageCategory,
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
ResourceOrigin,
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
)
|
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.models.image_record import (
|
|
||||||
ImageDTO,
|
|
||||||
ImageRecordChanges,
|
|
||||||
ImageUrlsDTO,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
from invokeai.app.services.models.image_record import (ImageDTO,
|
||||||
|
ImageRecordChanges,
|
||||||
|
ImageUrlsDTO)
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@ -103,23 +102,38 @@ async def update_image(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/metadata",
|
"/{image_name}",
|
||||||
operation_id="get_image_metadata",
|
operation_id="get_image_dto",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
async def get_image_metadata(
|
async def get_image_dto(
|
||||||
image_name: str = Path(description="The name of image to get"),
|
image_name: str = Path(description="The name of image to get"),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Gets an image's metadata"""
|
"""Gets an image's DTO"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
@images_router.get(
|
||||||
|
"/{image_name}/metadata",
|
||||||
|
operation_id="get_image_metadata",
|
||||||
|
response_model=ImageMetadata,
|
||||||
|
)
|
||||||
|
async def get_image_metadata(
|
||||||
|
image_name: str = Path(description="The name of image to get"),
|
||||||
|
) -> ImageMetadata:
|
||||||
|
"""Gets an image's metadata"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ApiDependencies.invoker.services.images.get_metadata(image_name)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}",
|
"/{image_name}/full",
|
||||||
operation_id="get_image_full",
|
operation_id="get_image_full",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -208,10 +222,10 @@ async def get_image_urls(
|
|||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_images_with_metadata",
|
operation_id="list_image_dtos",
|
||||||
response_model=OffsetPaginatedResults[ImageDTO],
|
response_model=OffsetPaginatedResults[ImageDTO],
|
||||||
)
|
)
|
||||||
async def list_images_with_metadata(
|
async def list_image_dtos(
|
||||||
image_origin: Optional[ResourceOrigin] = Query(
|
image_origin: Optional[ResourceOrigin] = Query(
|
||||||
default=None, description="The origin of images to list"
|
default=None, description="The origin of images to list"
|
||||||
),
|
),
|
||||||
@ -227,7 +241,7 @@ async def list_images_with_metadata(
|
|||||||
offset: int = Query(default=0, description="The page offset"),
|
offset: int = Query(default=0, description="The page offset"),
|
||||||
limit: int = Query(default=10, description="The number of images per page"),
|
limit: int = Query(default=10, description="The number of images per page"),
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
"""Gets a list of images"""
|
"""Gets a list of image DTOs"""
|
||||||
|
|
||||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||||
offset,
|
offset,
|
||||||
|
@ -34,7 +34,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
|||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.metadata import CoreMetadataService
|
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from .services.default_graphs import (default_text_to_image_graph_id,
|
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||||
@ -244,7 +243,6 @@ def invoke_cli():
|
|||||||
)
|
)
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
metadata = CoreMetadataService()
|
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
@ -277,7 +275,6 @@ def invoke_cli():
|
|||||||
board_image_record_storage=board_image_record_storage,
|
board_image_record_storage=board_image_record_storage,
|
||||||
image_record_storage=image_record_storage,
|
image_record_storage=image_record_storage,
|
||||||
image_file_storage=image_file_storage,
|
image_file_storage=image_file_storage,
|
||||||
metadata=metadata,
|
|
||||||
url=urls,
|
url=urls,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
names=names,
|
names=names,
|
||||||
|
@ -9,9 +9,9 @@ from diffusers.image_processor import VaeImageProcessor
|
|||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
@ -21,6 +21,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
|||||||
PostprocessingSettings
|
PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import torch_dtype
|
from ...backend.util.devices import torch_dtype
|
||||||
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
InvocationConfig, InvocationContext)
|
InvocationConfig, InvocationContext)
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
@ -449,6 +450,8 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
tiled: bool = Field(
|
tiled: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Decode latents by overlaping tiles(less memory consumption)")
|
description="Decode latents by overlaping tiles(less memory consumption)")
|
||||||
|
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -493,7 +496,8 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
|
124
invokeai/app/invocations/metadata.py
Normal file
124
invokeai/app/invocations/metadata.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InvocationContext)
|
||||||
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
|
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
|
||||||
|
VAEModelField)
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAMetadataField(BaseModel):
|
||||||
|
"""LoRA metadata for an image generated in InvokeAI."""
|
||||||
|
lora: LoRAModelField = Field(description="The LoRA model")
|
||||||
|
weight: float = Field(description="The weight of the LoRA model")
|
||||||
|
|
||||||
|
|
||||||
|
class CoreMetadata(BaseModel):
|
||||||
|
"""Core generation metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
|
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||||
|
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||||
|
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||||
|
width: int = Field(description="The width parameter")
|
||||||
|
height: int = Field(description="The height parameter")
|
||||||
|
seed: int = Field(description="The seed used for noise generation")
|
||||||
|
rand_device: str = Field(description="The device used for random number generation")
|
||||||
|
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||||
|
steps: int = Field(description="The number of steps used for inference")
|
||||||
|
scheduler: str = Field(description="The scheduler used for inference")
|
||||||
|
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||||
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
|
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||||
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
|
strength: Union[float, None] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The strength used for latents-to-latents",
|
||||||
|
)
|
||||||
|
init_image: Union[str, None] = Field(
|
||||||
|
default=None, description="The name of the initial image"
|
||||||
|
)
|
||||||
|
vae: Union[VAEModelField, None] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageMetadata(BaseModel):
|
||||||
|
"""An image's generation metadata"""
|
||||||
|
|
||||||
|
metadata: Optional[dict] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
||||||
|
)
|
||||||
|
graph: Optional[dict] = Field(
|
||||||
|
default=None, description="The graph that created the image"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||||
|
"""The output of the MetadataAccumulator node"""
|
||||||
|
|
||||||
|
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
||||||
|
|
||||||
|
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||||
|
"""Outputs a Core Metadata Object"""
|
||||||
|
|
||||||
|
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||||
|
|
||||||
|
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||||
|
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||||
|
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||||
|
width: int = Field(description="The width parameter")
|
||||||
|
height: int = Field(description="The height parameter")
|
||||||
|
seed: int = Field(description="The seed used for noise generation")
|
||||||
|
rand_device: str = Field(description="The device used for random number generation")
|
||||||
|
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||||
|
steps: int = Field(description="The number of steps used for inference")
|
||||||
|
scheduler: str = Field(description="The scheduler used for inference")
|
||||||
|
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||||
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
|
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||||
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
|
strength: Union[float, None] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The strength used for latents-to-latents",
|
||||||
|
)
|
||||||
|
init_image: Union[str, None] = Field(
|
||||||
|
default=None, description="The name of the initial image"
|
||||||
|
)
|
||||||
|
vae: Union[VAEModelField, None] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||||
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
|
||||||
|
return MetadataAccumulatorOutput(
|
||||||
|
metadata=CoreMetadata(
|
||||||
|
generation_mode=self.generation_mode,
|
||||||
|
positive_prompt=self.positive_prompt,
|
||||||
|
negative_prompt=self.negative_prompt,
|
||||||
|
width=self.width,
|
||||||
|
height=self.height,
|
||||||
|
seed=self.seed,
|
||||||
|
rand_device=self.rand_device,
|
||||||
|
cfg_scale=self.cfg_scale,
|
||||||
|
steps=self.steps,
|
||||||
|
scheduler=self.scheduler,
|
||||||
|
model=self.model,
|
||||||
|
strength=self.strength,
|
||||||
|
init_image=self.init_image,
|
||||||
|
vae=self.vae,
|
||||||
|
controlnets=self.controlnets,
|
||||||
|
loras=self.loras,
|
||||||
|
clip_skip=self.clip_skip,
|
||||||
|
)
|
||||||
|
)
|
@ -1,93 +0,0 @@
|
|||||||
from typing import Optional, Union, List
|
|
||||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModel):
|
|
||||||
"""
|
|
||||||
Core generation metadata for an image/tensor generated in InvokeAI.
|
|
||||||
|
|
||||||
Also includes any metadata from the image's PNG tEXt chunks.
|
|
||||||
|
|
||||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
|
||||||
of a given node.
|
|
||||||
|
|
||||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
extra = Extra.allow
|
|
||||||
"""
|
|
||||||
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
|
|
||||||
won't add any fields that are not already defined, but other a different metadata service
|
|
||||||
implementation might.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: Optional[StrictStr] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The type of the ancestor node of the image output node.",
|
|
||||||
)
|
|
||||||
"""The type of the ancestor node of the image output node."""
|
|
||||||
positive_conditioning: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The positive conditioning."
|
|
||||||
)
|
|
||||||
"""The positive conditioning"""
|
|
||||||
negative_conditioning: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The negative conditioning."
|
|
||||||
)
|
|
||||||
"""The negative conditioning"""
|
|
||||||
width: Optional[StrictInt] = Field(
|
|
||||||
default=None, description="Width of the image/latents in pixels."
|
|
||||||
)
|
|
||||||
"""Width of the image/latents in pixels"""
|
|
||||||
height: Optional[StrictInt] = Field(
|
|
||||||
default=None, description="Height of the image/latents in pixels."
|
|
||||||
)
|
|
||||||
"""Height of the image/latents in pixels"""
|
|
||||||
seed: Optional[StrictInt] = Field(
|
|
||||||
default=None, description="The seed used for noise generation."
|
|
||||||
)
|
|
||||||
"""The seed used for noise generation"""
|
|
||||||
# cfg_scale: Optional[StrictFloat] = Field(
|
|
||||||
# cfg_scale: Union[float, list[float]] = Field(
|
|
||||||
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
|
||||||
default=None, description="The classifier-free guidance scale."
|
|
||||||
)
|
|
||||||
"""The classifier-free guidance scale"""
|
|
||||||
steps: Optional[StrictInt] = Field(
|
|
||||||
default=None, description="The number of steps used for inference."
|
|
||||||
)
|
|
||||||
"""The number of steps used for inference"""
|
|
||||||
scheduler: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The scheduler used for inference."
|
|
||||||
)
|
|
||||||
"""The scheduler used for inference"""
|
|
||||||
model: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The model used for inference."
|
|
||||||
)
|
|
||||||
"""The model used for inference"""
|
|
||||||
strength: Optional[StrictFloat] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The strength used for image-to-image/latents-to-latents.",
|
|
||||||
)
|
|
||||||
"""The strength used for image-to-image/latents-to-latents."""
|
|
||||||
latents: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The ID of the initial latents."
|
|
||||||
)
|
|
||||||
"""The ID of the initial latents"""
|
|
||||||
vae: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The VAE used for decoding."
|
|
||||||
)
|
|
||||||
"""The VAE used for decoding"""
|
|
||||||
unet: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The UNet used dor inference."
|
|
||||||
)
|
|
||||||
"""The UNet used dor inference"""
|
|
||||||
clip: Optional[StrictStr] = Field(
|
|
||||||
default=None, description="The CLIP Encoder used for conditioning."
|
|
||||||
)
|
|
||||||
"""The CLIP Encoder used for conditioning"""
|
|
||||||
extra: Optional[StrictStr] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
|
||||||
)
|
|
||||||
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
|
@ -1,14 +1,14 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
from PIL.Image import Image as PILImageType
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
|
|
||||||
@ -59,7 +59,8 @@ class ImageFileStorageBase(ABC):
|
|||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
graph: Optional[dict] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||||
@ -110,20 +111,22 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
graph: Optional[dict] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(image_name)
|
||||||
|
|
||||||
if metadata is not None:
|
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
pnginfo.add_text("invokeai", metadata.json())
|
|
||||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
|
||||||
else:
|
|
||||||
image.save(image_path, "PNG")
|
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
pnginfo.add_text("metadata", json.dumps(metadata))
|
||||||
|
if graph is not None:
|
||||||
|
pnginfo.add_text("graph", json.dumps(graph))
|
||||||
|
|
||||||
|
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -8,7 +9,6 @@ from pydantic import BaseModel, Field
|
|||||||
from pydantic.generics import GenericModel
|
from pydantic.generics import GenericModel
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
|
||||||
from invokeai.app.services.models.image_record import (
|
from invokeai.app.services.models.image_record import (
|
||||||
ImageRecord, ImageRecordChanges, deserialize_image_record)
|
ImageRecord, ImageRecordChanges, deserialize_image_record)
|
||||||
|
|
||||||
@ -48,6 +48,28 @@ class ImageRecordDeleteException(Exception):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_DTO_COLS = ", ".join(
|
||||||
|
list(
|
||||||
|
map(
|
||||||
|
lambda c: "images." + c,
|
||||||
|
[
|
||||||
|
"image_name",
|
||||||
|
"image_origin",
|
||||||
|
"image_category",
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"session_id",
|
||||||
|
"node_id",
|
||||||
|
"is_intermediate",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
"deleted_at",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordStorageBase(ABC):
|
class ImageRecordStorageBase(ABC):
|
||||||
"""Low-level service responsible for interfacing with the image record store."""
|
"""Low-level service responsible for interfacing with the image record store."""
|
||||||
|
|
||||||
@ -58,6 +80,11 @@ class ImageRecordStorageBase(ABC):
|
|||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||||
|
"""Gets an image's metadata'."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@ -102,7 +129,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
height: int,
|
height: int,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[dict],
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
@ -206,7 +233,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
SELECT * FROM images
|
SELECT {IMAGE_DTO_COLS} FROM images
|
||||||
WHERE image_name = ?;
|
WHERE image_name = ?;
|
||||||
""",
|
""",
|
||||||
(image_name,),
|
(image_name,),
|
||||||
@ -224,6 +251,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
return deserialize_image_record(dict(result))
|
return deserialize_image_record(dict(result))
|
||||||
|
|
||||||
|
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
SELECT images.metadata FROM images
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(image_name,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||||
|
if not result or not result[0]:
|
||||||
|
return None
|
||||||
|
return json.loads(result[0])
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise ImageRecordNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
@ -291,8 +340,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
WHERE 1=1
|
WHERE 1=1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
images_query = """--sql
|
images_query = f"""--sql
|
||||||
SELECT images.*
|
SELECT {IMAGE_DTO_COLS}
|
||||||
FROM images
|
FROM images
|
||||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||||
WHERE 1=1
|
WHERE 1=1
|
||||||
@ -410,12 +459,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[dict],
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = (
|
metadata_json = (
|
||||||
None if metadata is None else metadata.json(exclude_none=True)
|
None if metadata is None else json.dumps(metadata)
|
||||||
)
|
)
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -465,9 +514,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def get_most_recent_image_for_board(
|
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||||
self, board_id: str
|
|
||||||
) -> Optional[ImageRecord]:
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
|
@ -1,39 +1,30 @@
|
|||||||
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Optional, TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
ImageCategory,
|
from invokeai.app.models.image import (ImageCategory,
|
||||||
ResourceOrigin,
|
|
||||||
InvalidImageCategoryException,
|
InvalidImageCategoryException,
|
||||||
InvalidOriginException,
|
InvalidOriginException, ResourceOrigin)
|
||||||
)
|
from invokeai.app.services.board_image_record_storage import \
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
BoardImageRecordStorageBase
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
from invokeai.app.services.graph import Graph
|
||||||
from invokeai.app.services.image_record_storage import (
|
|
||||||
ImageRecordDeleteException,
|
|
||||||
ImageRecordNotFoundException,
|
|
||||||
ImageRecordSaveException,
|
|
||||||
ImageRecordStorageBase,
|
|
||||||
OffsetPaginatedResults,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.models.image_record import (
|
|
||||||
ImageRecord,
|
|
||||||
ImageDTO,
|
|
||||||
ImageRecordChanges,
|
|
||||||
image_record_to_dto,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.image_file_storage import (
|
from invokeai.app.services.image_file_storage import (
|
||||||
ImageFileDeleteException,
|
ImageFileDeleteException, ImageFileNotFoundException,
|
||||||
ImageFileNotFoundException,
|
ImageFileSaveException, ImageFileStorageBase)
|
||||||
ImageFileSaveException,
|
from invokeai.app.services.image_record_storage import (
|
||||||
ImageFileStorageBase,
|
ImageRecordDeleteException, ImageRecordNotFoundException,
|
||||||
)
|
ImageRecordSaveException, ImageRecordStorageBase, OffsetPaginatedResults)
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
from invokeai.app.services.metadata import MetadataServiceBase
|
from invokeai.app.services.models.image_record import (ImageDTO, ImageRecord,
|
||||||
|
ImageRecordChanges,
|
||||||
|
image_record_to_dto)
|
||||||
from invokeai.app.services.resource_name import NameServiceBase
|
from invokeai.app.services.resource_name import NameServiceBase
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.services.graph import GraphExecutionState
|
from invokeai.app.services.graph import GraphExecutionState
|
||||||
@ -51,6 +42,7 @@ class ImageServiceABC(ABC):
|
|||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Creates an image, storing the file and its metadata."""
|
"""Creates an image, storing the file and its metadata."""
|
||||||
pass
|
pass
|
||||||
@ -79,6 +71,11 @@ class ImageServiceABC(ABC):
|
|||||||
"""Gets an image DTO."""
|
"""Gets an image DTO."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||||
|
"""Gets an image's metadata."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
"""Gets an image's path."""
|
"""Gets an image's path."""
|
||||||
@ -124,7 +121,6 @@ class ImageServiceDependencies:
|
|||||||
image_records: ImageRecordStorageBase
|
image_records: ImageRecordStorageBase
|
||||||
image_files: ImageFileStorageBase
|
image_files: ImageFileStorageBase
|
||||||
board_image_records: BoardImageRecordStorageBase
|
board_image_records: BoardImageRecordStorageBase
|
||||||
metadata: MetadataServiceBase
|
|
||||||
urls: UrlServiceBase
|
urls: UrlServiceBase
|
||||||
logger: Logger
|
logger: Logger
|
||||||
names: NameServiceBase
|
names: NameServiceBase
|
||||||
@ -135,7 +131,6 @@ class ImageServiceDependencies:
|
|||||||
image_record_storage: ImageRecordStorageBase,
|
image_record_storage: ImageRecordStorageBase,
|
||||||
image_file_storage: ImageFileStorageBase,
|
image_file_storage: ImageFileStorageBase,
|
||||||
board_image_record_storage: BoardImageRecordStorageBase,
|
board_image_record_storage: BoardImageRecordStorageBase,
|
||||||
metadata: MetadataServiceBase,
|
|
||||||
url: UrlServiceBase,
|
url: UrlServiceBase,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
names: NameServiceBase,
|
names: NameServiceBase,
|
||||||
@ -144,7 +139,6 @@ class ImageServiceDependencies:
|
|||||||
self.image_records = image_record_storage
|
self.image_records = image_record_storage
|
||||||
self.image_files = image_file_storage
|
self.image_files = image_file_storage
|
||||||
self.board_image_records = board_image_record_storage
|
self.board_image_records = board_image_record_storage
|
||||||
self.metadata = metadata
|
|
||||||
self.urls = url
|
self.urls = url
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.names = names
|
self.names = names
|
||||||
@ -165,6 +159,7 @@ class ImageService(ImageServiceABC):
|
|||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
if image_origin not in ResourceOrigin:
|
if image_origin not in ResourceOrigin:
|
||||||
raise InvalidOriginException
|
raise InvalidOriginException
|
||||||
@ -174,7 +169,16 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
image_name = self._services.names.create_image_name()
|
image_name = self._services.names.create_image_name()
|
||||||
|
|
||||||
metadata = self._get_metadata(session_id, node_id)
|
graph = None
|
||||||
|
|
||||||
|
if session_id is not None:
|
||||||
|
session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
||||||
|
if session_raw is not None:
|
||||||
|
try:
|
||||||
|
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
|
graph = None
|
||||||
|
|
||||||
(width, height) = image.size
|
(width, height) = image.size
|
||||||
|
|
||||||
@ -191,14 +195,12 @@ class ImageService(ImageServiceABC):
|
|||||||
is_intermediate=is_intermediate,
|
is_intermediate=is_intermediate,
|
||||||
# Nullable fields
|
# Nullable fields
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
session_id=session_id,
|
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._services.image_files.save(
|
self._services.image_files.save(
|
||||||
image_name=image_name,
|
image_name=image_name, image=image, metadata=metadata, graph=graph
|
||||||
image=image,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
image_dto = self.get_dto(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
@ -268,6 +270,34 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image DTO")
|
self._services.logger.error("Problem getting image DTO")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||||
|
try:
|
||||||
|
image_record = self._services.image_records.get(image_name)
|
||||||
|
|
||||||
|
if not image_record.session_id:
|
||||||
|
return ImageMetadata()
|
||||||
|
|
||||||
|
session_raw = self._services.graph_execution_manager.get_raw(
|
||||||
|
image_record.session_id
|
||||||
|
)
|
||||||
|
graph = None
|
||||||
|
|
||||||
|
if session_raw:
|
||||||
|
try:
|
||||||
|
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
|
graph = None
|
||||||
|
|
||||||
|
metadata = self._services.image_records.get_metadata(image_name)
|
||||||
|
return ImageMetadata(graph=graph, metadata=metadata)
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self._services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.image_files.get_path(image_name, thumbnail)
|
return self._services.image_files.get_path(image_name, thumbnail)
|
||||||
@ -367,15 +397,3 @@ class ImageService(ImageServiceABC):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image records and files")
|
self._services.logger.error("Problem deleting image records and files")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _get_metadata(
|
|
||||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
|
||||||
) -> Optional[ImageMetadata]:
|
|
||||||
"""Get the metadata for a node."""
|
|
||||||
metadata = None
|
|
||||||
|
|
||||||
if node_id is not None and session_id is not None:
|
|
||||||
session = self._services.graph_execution_manager.get(session_id)
|
|
||||||
metadata = self._services.metadata.create_image_metadata(session, node_id)
|
|
||||||
|
|
||||||
return metadata
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Generic, TypeVar
|
from typing import Callable, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic.generics import GenericModel
|
from pydantic.generics import GenericModel
|
||||||
@ -29,14 +29,22 @@ class ItemStorageABC(ABC, Generic[T]):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, item_id: str) -> T:
|
def get(self, item_id: str) -> T:
|
||||||
|
"""Gets the item, parsing it into a Pydantic model"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_raw(self, item_id: str) -> Optional[str]:
|
||||||
|
"""Gets the raw item as a string, skipping Pydantic parsing"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set(self, item: T) -> None:
|
def set(self, item: T) -> None:
|
||||||
|
"""Sets the item"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||||
|
"""Gets a paginated list of items"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -1,142 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Optional
|
|
||||||
import networkx as nx
|
|
||||||
|
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
|
||||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataServiceBase(ABC):
|
|
||||||
"""Handles building metadata for nodes, images, and outputs."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_image_metadata(
|
|
||||||
self, session: GraphExecutionState, node_id: str
|
|
||||||
) -> ImageMetadata:
|
|
||||||
"""Builds an ImageMetadata object for a node."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CoreMetadataService(MetadataServiceBase):
|
|
||||||
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
|
||||||
"""The ancestor types that contain the core metadata"""
|
|
||||||
|
|
||||||
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
|
|
||||||
"""The core metadata parameters in the ancestor types"""
|
|
||||||
|
|
||||||
_NOISE_FIELDS = ["seed", "width", "height"]
|
|
||||||
"""The core metadata parameters in the noise node"""
|
|
||||||
|
|
||||||
def create_image_metadata(
|
|
||||||
self, session: GraphExecutionState, node_id: str
|
|
||||||
) -> ImageMetadata:
|
|
||||||
metadata = self._build_metadata_from_graph(session, node_id)
|
|
||||||
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
|
|
||||||
have the same data as the execution graph.
|
|
||||||
node_id (str): The ID of the node.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Retrieve the node from the graph
|
|
||||||
node = G.nodes[node_id]
|
|
||||||
|
|
||||||
# If the node type is one of the core metadata node types, return its id
|
|
||||||
if node.get("type") in self._ANCESTOR_TYPES:
|
|
||||||
return node.get("id")
|
|
||||||
|
|
||||||
# Else, look for the ancestor in the predecessor nodes
|
|
||||||
for predecessor in G.predecessors(node_id):
|
|
||||||
result = self._find_nearest_ancestor(G, predecessor)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
|
|
||||||
# If there are no valid ancestors, return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_additional_metadata(
|
|
||||||
self, graph: Graph, node_id: str
|
|
||||||
) -> Optional[dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Returns additional metadata for a given node.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
graph (Graph): The execution graph.
|
|
||||||
node_id (str): The ID of the node.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Any] | None: A dictionary of additional metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
|
|
||||||
# Iterate over all edges in the graph
|
|
||||||
for edge in graph.edges:
|
|
||||||
dest_node_id = edge.destination.node_id
|
|
||||||
dest_field = edge.destination.field
|
|
||||||
source_node_dict = graph.nodes[edge.source.node_id].dict()
|
|
||||||
|
|
||||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
|
||||||
if dest_node_id == node_id:
|
|
||||||
# Prompt
|
|
||||||
if dest_field == "positive_conditioning":
|
|
||||||
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
|
||||||
# Negative prompt
|
|
||||||
if dest_field == "negative_conditioning":
|
|
||||||
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
|
||||||
# Seed, width and height
|
|
||||||
if dest_field == "noise":
|
|
||||||
for field in self._NOISE_FIELDS:
|
|
||||||
metadata[field] = source_node_dict.get(field)
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
def _build_metadata_from_graph(
|
|
||||||
self, session: GraphExecutionState, node_id: str
|
|
||||||
) -> ImageMetadata:
|
|
||||||
"""
|
|
||||||
Builds an ImageMetadata object for a node.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
session (GraphExecutionState): The session.
|
|
||||||
node_id (str): The ID of the node.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ImageMetadata: The metadata for the node.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# We need to do all the traversal on the execution graph
|
|
||||||
graph = session.execution_graph
|
|
||||||
|
|
||||||
# Find the nearest `t2l`/`l2l` ancestor of the given node
|
|
||||||
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
|
||||||
|
|
||||||
# If no ancestor was found, return an empty ImageMetadata object
|
|
||||||
if ancestor_id is None:
|
|
||||||
return ImageMetadata()
|
|
||||||
|
|
||||||
ancestor_node = graph.get_node(ancestor_id)
|
|
||||||
|
|
||||||
# Grab all the core metadata from the ancestor node
|
|
||||||
ancestor_metadata = {
|
|
||||||
param: val
|
|
||||||
for param, val in ancestor_node.dict().items()
|
|
||||||
if param in self._ANCESTOR_PARAMS
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get this image's prompts and noise parameters
|
|
||||||
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
|
||||||
|
|
||||||
# If additional metadata was found, add it to the main metadata
|
|
||||||
if addl_metadata is not None:
|
|
||||||
ancestor_metadata.update(addl_metadata)
|
|
||||||
|
|
||||||
return ImageMetadata(**ancestor_metadata)
|
|
@ -1,13 +1,14 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
|
|
||||||
class ImageRecord(BaseModel):
|
class ImageRecord(BaseModel):
|
||||||
"""Deserialized image record."""
|
"""Deserialized image record without metadata."""
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
"""The unique name of the image."""
|
"""The unique name of the image."""
|
||||||
@ -43,11 +44,6 @@ class ImageRecord(BaseModel):
|
|||||||
description="The node ID that generated this image, if it is a generated image.",
|
description="The node ID that generated this image, if it is a generated image.",
|
||||||
)
|
)
|
||||||
"""The node ID that generated this image, if it is a generated image."""
|
"""The node ID that generated this image, if it is a generated image."""
|
||||||
metadata: Optional[ImageMetadata] = Field(
|
|
||||||
default=None,
|
|
||||||
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
|
||||||
)
|
|
||||||
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||||
@ -112,6 +108,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
|
|
||||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
|
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
||||||
image_name = image_dict.get("image_name", "unknown")
|
image_name = image_dict.get("image_name", "unknown")
|
||||||
image_origin = ResourceOrigin(
|
image_origin = ResourceOrigin(
|
||||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||||
@ -128,13 +125,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||||
is_intermediate = image_dict.get("is_intermediate", False)
|
is_intermediate = image_dict.get("is_intermediate", False)
|
||||||
|
|
||||||
raw_metadata = image_dict.get("metadata")
|
|
||||||
|
|
||||||
if raw_metadata is not None:
|
|
||||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
|
||||||
else:
|
|
||||||
metadata = None
|
|
||||||
|
|
||||||
return ImageRecord(
|
return ImageRecord(
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_origin=image_origin,
|
image_origin=image_origin,
|
||||||
@ -143,7 +133,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
height=height,
|
height=height,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
metadata=metadata,
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
updated_at=updated_at,
|
updated_at=updated_at,
|
||||||
deleted_at=deleted_at,
|
deleted_at=deleted_at,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Generic, TypeVar, Optional, Union, get_args
|
from typing import Generic, Optional, TypeVar, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel, parse_raw_as
|
from pydantic import BaseModel, parse_raw_as
|
||||||
|
|
||||||
@ -78,6 +78,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
return self._parse_item(result[0])
|
return self._parse_item(result[0])
|
||||||
|
|
||||||
|
def get_raw(self, id: str) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
|
)
|
||||||
|
result = self._cursor.fetchone()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return result[0]
|
||||||
|
|
||||||
def delete(self, id: str):
|
def delete(self, id: str):
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -22,4 +22,4 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
if thumbnail:
|
if thumbnail:
|
||||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||||
|
|
||||||
return f"{self._base_url}/images/{image_basename}"
|
return f"{self._base_url}/images/{image_basename}/full"
|
||||||
|
55
invokeai/app/util/metadata.py
Normal file
55
invokeai/app/util/metadata.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from invokeai.app.services.graph import Edge
|
||||||
|
|
||||||
|
|
||||||
|
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Parses raw session string, returning a dict of the graph.
|
||||||
|
|
||||||
|
Only the general graph shape is validated; none of the fields are validated.
|
||||||
|
|
||||||
|
Any `metadata_accumulator` nodes and edges are removed.
|
||||||
|
|
||||||
|
Any validation failure will return None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph = json.loads(session_raw).get("graph", None)
|
||||||
|
|
||||||
|
# sanity check make sure the graph is at least reasonably shaped
|
||||||
|
if (
|
||||||
|
type(graph) is not dict
|
||||||
|
or "nodes" not in graph
|
||||||
|
or type(graph["nodes"]) is not dict
|
||||||
|
or "edges" not in graph
|
||||||
|
or type(graph["edges"]) is not list
|
||||||
|
):
|
||||||
|
# something has gone terribly awry, return an empty dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# delete the `metadata_accumulator` node
|
||||||
|
del graph["nodes"]["metadata_accumulator"]
|
||||||
|
except KeyError:
|
||||||
|
# no accumulator node, all good
|
||||||
|
pass
|
||||||
|
|
||||||
|
# delete any edges to or from it
|
||||||
|
for i, edge in enumerate(graph["edges"]):
|
||||||
|
try:
|
||||||
|
# try to parse the edge
|
||||||
|
Edge(**edge)
|
||||||
|
except ValidationError:
|
||||||
|
# something has gone terribly awry, return an empty dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (
|
||||||
|
edge["source"]["node_id"] == "metadata_accumulator"
|
||||||
|
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||||
|
):
|
||||||
|
del graph["edges"][i]
|
||||||
|
|
||||||
|
return graph
|
Loading…
Reference in New Issue
Block a user