feat(nodes): move fully* to new images service

* except i haven't rebuilt inpaint in latents
This commit is contained in:
psychedelicious
2023-05-24 15:50:55 +10:00
committed by Kent Keirsey
parent dd16f788ed
commit d2c223de8f
9 changed files with 273 additions and 258 deletions

View File

@ -3,7 +3,7 @@
import random
from typing import Literal, Optional, Union
import einops
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
import torch
from invokeai.app.invocations.util.choose_model import choose_model
@ -23,7 +23,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np
from ..services.image_file_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .image import ImageField, ImageOutput
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
@ -362,19 +362,9 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
# image_type = ImageType.RESULT
# image_name = context.services.images.create_name(
# context.graph_execution_state_id, self.id
# )
torch.cuda.empty_cache()
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# torch.cuda.empty_cache()
# context.services.images.save(image_type, image_name, image, metadata)
image_dto = context.services.images_new.create(
image_dto = context.services.images.create(
image=image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
@ -382,10 +372,13 @@ class LatentsToImageInvocation(BaseInvocation):
node_id=self.id,
)
return build_image_output(
image_type=image_dto.image_type,
image_name=image_dto.image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -474,7 +467,7 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -496,3 +489,4 @@ class ImageToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)