mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
97 lines
2.7 KiB
Python
97 lines
2.7 KiB
Python
|
import json
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import Any, Dict, Optional, TypedDict
|
||
|
from PIL import Image, PngImagePlugin
|
||
|
from pydantic import BaseModel
|
||
|
|
||
|
from invokeai.app.models.image import ImageType, is_image_type
|
||
|
|
||
|
|
||
|
class MetadataImageField(TypedDict):
|
||
|
"""Pydantic-less ImageField, used for metadata parsing."""
|
||
|
|
||
|
image_type: ImageType
|
||
|
image_name: str
|
||
|
|
||
|
|
||
|
class MetadataLatentsField(TypedDict):
|
||
|
"""Pydantic-less LatentsField, used for metadata parsing."""
|
||
|
|
||
|
latents_name: str
|
||
|
|
||
|
|
||
|
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||
|
NodeMetadata = Dict[
|
||
|
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
||
|
]
|
||
|
|
||
|
|
||
|
class InvokeAIMetadata(TypedDict, total=False):
|
||
|
"""InvokeAI-specific metadata format."""
|
||
|
|
||
|
session_id: Optional[str]
|
||
|
node: Optional[NodeMetadata]
|
||
|
|
||
|
|
||
|
def build_invokeai_metadata_pnginfo(
|
||
|
metadata: InvokeAIMetadata | None,
|
||
|
) -> PngImagePlugin.PngInfo:
|
||
|
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
||
|
pnginfo = PngImagePlugin.PngInfo()
|
||
|
|
||
|
if metadata is not None:
|
||
|
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||
|
|
||
|
return pnginfo
|
||
|
|
||
|
|
||
|
class MetadataServiceBase(ABC):
|
||
|
@abstractmethod
|
||
|
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
||
|
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
||
|
pass
|
||
|
|
||
|
@abstractmethod
|
||
|
def build_metadata(
|
||
|
self, session_id: str, node: BaseModel
|
||
|
) -> InvokeAIMetadata | None:
|
||
|
"""Builds an InvokeAIMetadata object"""
|
||
|
pass
|
||
|
|
||
|
|
||
|
class PngMetadataService(MetadataServiceBase):
|
||
|
"""Handles loading and building metadata for images."""
|
||
|
|
||
|
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
||
|
def _load_metadata(self, image: Image.Image) -> dict | None:
|
||
|
"""Loads a specific info entry from a PIL Image."""
|
||
|
|
||
|
try:
|
||
|
info = image.info.get("invokeai")
|
||
|
|
||
|
if type(info) is not str:
|
||
|
return None
|
||
|
|
||
|
loaded_metadata = json.loads(info)
|
||
|
|
||
|
if type(loaded_metadata) is not dict:
|
||
|
return None
|
||
|
|
||
|
if len(loaded_metadata.items()) == 0:
|
||
|
return None
|
||
|
|
||
|
return loaded_metadata
|
||
|
except:
|
||
|
return None
|
||
|
|
||
|
def get_metadata(self, image: Image.Image) -> dict | None:
|
||
|
"""Retrieves an image's metadata as a dict"""
|
||
|
loaded_metadata = self._load_metadata(image)
|
||
|
|
||
|
return loaded_metadata
|
||
|
|
||
|
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
||
|
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
||
|
|
||
|
return metadata
|