feat(nodes): add datatypes module

This commit is contained in:
psychedelicious 2023-04-04 13:08:42 +10:00
parent 77bf3c780f
commit a065f7db56
13 changed files with 56 additions and 48 deletions

View File

@ -5,7 +5,8 @@ import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..invocations.image import ImageField from invokeai.app.datatypes.image import ImageField
from ..services.graph import GraphExecutionState from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker from ..services.invoker import Invoker

View File

View File

@ -0,0 +1,3 @@
class CanceledException(Exception):
    """Raised to abort an in-progress graph execution when the user cancels it."""

View File

@ -0,0 +1,26 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
# Closed set of on-disk image categories; each value is the storage
# folder name. Built with the Enum functional API; ``type=str`` gives the
# same str mixin as ``class ImageType(str, Enum)`` so members compare
# equal to their plain string values.
ImageType = Enum(
    "ImageType",
    {
        "RESULT": "results",
        "INTERMEDIATE": "intermediates",
        "UPLOAD": "uploads",
    },
    type=str,
)
class ImageField(BaseModel):
    """An image field used for passing image objects between invocations"""

    # NOTE(review): annotated as plain ``str`` but defaulted to an ImageType
    # enum member — pydantic will accept any string here, not just ImageType
    # values. Confirm the widened annotation is intentional before tightening.
    image_type: str = Field(
        default=ImageType.RESULT, description="The type of the image"
    )
    # Bare file name within the image_type folder; None until an image is
    # actually produced/assigned.
    image_name: Optional[str] = Field(default=None, description="The name of the image")

    class Config:
        # Mark both fields as required in the emitted JSON schema even though
        # they carry defaults — presumably so schema consumers (e.g. generated
        # API clients) always treat them as present; verify against consumers.
        schema_extra = {
            "required": [
                "image_type",
                "image_name",
            ]
        }

View File

@ -7,9 +7,9 @@ import numpy
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import Field from pydantic import Field
from ..services.image_storage import ImageType from invokeai.app.datatypes.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageOutput
class CvInpaintInvocation(BaseInvocation): class CvInpaintInvocation(BaseInvocation):

View File

@ -8,12 +8,13 @@ from torch import Tensor
from pydantic import Field from pydantic import Field
from ..services.image_storage import ImageType from invokeai.app.datatypes.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.util import diffusers_step_callback_adapter, CanceledException from ..datatypes.exceptions import CanceledException
from ..util.step_callback import diffusers_step_callback_adapter
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers()) tuple(InvokeAIGenerator.schedulers())

View File

@ -7,27 +7,10 @@ import numpy
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..services.image_storage import ImageType from invokeai.app.datatypes.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_type: str = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {
'required': [
'image_type',
'image_name',
]
}
class ImageOutput(BaseInvocationOutput): class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image""" """Base class for invocations that output an image"""
#fmt: off #fmt: off

View File

@ -3,10 +3,10 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from ..services.image_storage import ImageType from invokeai.app.datatypes.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation): class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image.""" """Restores faces in an image."""

View File

@ -5,10 +5,10 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from ..services.image_storage import ImageType from invokeai.app.datatypes.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageOutput
class UpscaleInvocation(BaseInvocation): class UpscaleInvocation(BaseInvocation):

View File

@ -10,19 +10,12 @@ from queue import Queue
from typing import Callable, Dict from typing import Callable, Dict
from PIL.Image import Image from PIL.Image import Image
from invokeai.app.invocations.image import ImageField from invokeai.app.datatypes.image import ImageField, ImageType
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.save_thumbnail import save_thumbnail from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter from invokeai.backend.image_util import PngWriter
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
class ImageStorageBase(ABC): class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images.""" """Responsible for storing and retrieving images."""

View File

@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker from .invoker import InvocationProcessorABC, Invoker
from ..util.util import CanceledException from ..datatypes.exceptions import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC): class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread __invoker_thread: Thread

View File

View File

@ -1,14 +1,16 @@
import torch import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
class CanceledException(Exception): def fast_latents_step_callback(
pass sample: torch.Tensor,
step: int,
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ): steps: int,
id: str,
context: InvocationContext,
):
# TODO: only output a preview image when requested # TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample) image = Generator.sample_to_lowres_estimated_image(sample)
@ -21,15 +23,12 @@ def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id:
context.services.events.emit_generator_progress( context.services.events.emit_generator_progress(
context.graph_execution_state_id, context.graph_execution_state_id,
id, id,
{ {"width": width, "height": height, "dataURL": dataURL},
"width": width,
"height": height,
"dataURL": dataURL
},
step, step,
steps, steps,
) )
def diffusers_step_callback_adapter(*cb_args, **kwargs): def diffusers_step_callback_adapter(*cb_args, **kwargs):
""" """
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState. txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
@ -37,6 +36,8 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
""" """
if isinstance(cb_args[0], PipelineIntermediateState): if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0] progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs) return fast_latents_step_callback(
progress_state.latents, progress_state.step, **kwargs
)
else: else:
return fast_latents_step_callback(*cb_args, **kwargs) return fast_latents_step_callback(*cb_args, **kwargs)