Fix preview, inpaint

Sergey Borisov 2023-08-07 21:27:32 +03:00
parent 2539e26c18
commit 1db2c93f75
4 changed files with 58 additions and 36 deletions

View File

@@ -16,7 +16,7 @@ from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
 from .image import ImageOutput
-from ...backend.model_management.lora import ModelPatcher
+from ...backend.model_management import ModelPatcher, BaseModelType
 from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from .model import UNetField, VaeField
 from .compel import ConditioningField
@@ -140,6 +140,7 @@ class InpaintInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         source_node_id: str,
+        base_model: BaseModelType,
         intermediate_state: PipelineIntermediateState,
     ) -> None:
         stable_diffusion_step_callback(
@@ -147,15 +148,16 @@ class InpaintInvocation(BaseInvocation):
             intermediate_state=intermediate_state,
             node=self.dict(),
             source_node_id=source_node_id,
+            base_model=base_model,
         )

     def get_conditioning(self, context, unet):
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
-        extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
+        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        extra_conditioning_info = c.extra_conditioning

         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
+        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

         return (uc, c, extra_conditioning_info)
@@ -225,7 +227,7 @@ class InpaintInvocation(BaseInvocation):
             scheduler=scheduler,
             init_image=image,
             mask_image=mask,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
+            step_callback=partial(self.dispatch_progress, context, source_node_id, self.unet.unet.base_model),
             **self.dict(
                 exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
             ),  # Shorthand for passing all of the parameters above manually
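
Note on the wiring above: functools.partial pre-binds every argument except the one the pipeline supplies, so threading self.unet.unet.base_model through here changes what the bound callback carries without changing how the pipeline invokes it. A minimal sketch of the pattern, with illustrative placeholder values in place of the real context objects:

    from functools import partial

    def dispatch_progress(context, source_node_id, base_model, intermediate_state):
        # Stand-in for InpaintInvocation.dispatch_progress (illustrative only).
        print(f"{source_node_id} [{base_model}]: {intermediate_state}")

    # partial() binds the leading arguments up front; the pipeline later calls
    # the result with only the intermediate state, exactly as it did before
    # this change. Note base_model must precede intermediate_state in the
    # signature for this positional binding to work, matching the new
    # parameter order in the hunk above.
    step_callback = partial(dispatch_progress, {"services": None}, "node-1", "sd-1")
    step_callback({"step": 0, "total_steps": 30})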

View File

@@ -24,7 +24,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.model_management import ModelPatcher
+from ...backend.model_management import ModelPatcher, BaseModelType
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
 from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
@@ -160,12 +160,14 @@ class TextToLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         source_node_id: str,
         intermediate_state: PipelineIntermediateState,
+        base_model: BaseModelType,
     ) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
             node=self.dict(),
             source_node_id=source_node_id,
+            base_model=base_model,
         )

     def get_conditioning_data(
@@ -340,7 +342,7 @@ class TextToLatentsInvocation(BaseInvocation):
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, source_node_id, state)
+            self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)

         def _lora_loader():
             for lora in self.unet.loras:
@@ -379,7 +381,7 @@ class TextToLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=True,
                 exit_stack=exit_stack,
             )

             num_inference_steps, timesteps = self.init_scheduler(
                 scheduler,
                 device=unet.device,
@@ -448,7 +450,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, source_node_id, state)
+            self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)

         def _lora_loader():
             for lora in self.unet.loras:
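
The latent invocations do the same threading with a local closure rather than partial, because base_model lands after the pipeline-supplied state argument in this signature. A sketch of the equivalence, with placeholder values standing in for the real objects:

    def dispatch_progress(context, source_node_id, state, base_model):
        # Illustrative stand-in; base_model follows the state argument here.
        print(f"{source_node_id} [{base_model}]: {state}")

    context, source_node_id, base_model = object(), "node-1", "sdxl"

    # Closure form, as in the hunks above: everything except `state` is
    # captured from the enclosing scope. A purely positional partial() could
    # not pre-bind base_model, since it comes after the free argument.
    def step_callback(state):
        dispatch_progress(context, source_node_id, state, base_model)

    step_callback({"step": 3})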

View File

@@ -7,6 +7,7 @@ from ...backend.util.util import image_to_dataURL
 from ...backend.generator.base import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
 from invokeai.app.services.config import InvokeAIAppConfig
+from ...backend.model_management.models import BaseModelType


 def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
@@ -29,6 +30,7 @@ def stable_diffusion_step_callback(
     intermediate_state: PipelineIntermediateState,
     node: dict,
     source_node_id: str,
+    base_model: BaseModelType,
 ):
     if context.services.queue.is_canceled(context.graph_execution_state_id):
         raise CanceledException
@@ -56,23 +58,51 @@ def stable_diffusion_step_callback(
     # TODO: only output a preview image when requested
-    # originally adapted from code by @erucipe and @keturn here:
-    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
+    if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
+        sdxl_latent_rgb_factors = torch.tensor(
+            [
+                #   R        G        B
+                [0.3816, 0.4930, 0.5320],
+                [-0.3753, 0.1631, 0.1739],
+                [0.1770, 0.3588, -0.2048],
+                [-0.4350, -0.2644, -0.4289],
+            ],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
-    # these updated numbers for v1.5 are from @torridgristle
-    v1_5_latent_rgb_factors = torch.tensor(
-        [
-            #   R        G        B
-            [0.3444, 0.1385, 0.0670],  # L1
-            [0.1247, 0.4027, 0.1494],  # L2
-            [-0.3192, 0.2513, 0.2103],  # L3
-            [-0.1307, -0.1874, -0.7445],  # L4
-        ],
-        dtype=sample.dtype,
-        device=sample.device,
-    )
+        sdxl_smooth_matrix = torch.tensor(
+            [
+                # [0.0478, 0.1285, 0.0478],
+                # [0.1285, 0.2948, 0.1285],
+                # [0.0478, 0.1285, 0.0478],
+                [0.0358, 0.0964, 0.0358],
+                [0.0964, 0.4711, 0.0964],
+                [0.0358, 0.0964, 0.0358],
+            ],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
-    image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
+        image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
+    else:
+        # originally adapted from code by @erucipe and @keturn here:
+        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
+
+        # these updated numbers for v1.5 are from @torridgristle
+        v1_5_latent_rgb_factors = torch.tensor(
+            [
+                #   R        G        B
+                [0.3444, 0.1385, 0.0670],  # L1
+                [0.1247, 0.4027, 0.1494],  # L2
+                [-0.3192, 0.2513, 0.2103],  # L3
+                [-0.1307, -0.1874, -0.7445],  # L4
+            ],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
+        image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)

     (width, height) = image.size
     width *= 8
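
For context: sample_to_lowres_estimated_image projects the 4-channel latent onto RGB using one of the factor matrices above (SDXL latents encode colors differently than v1.5, which is the preview bug this branch fixes) and, for SDXL, smooths the result with the 3x3 kernel. Latents are 1/8 of pixel resolution, hence the width *= 8 scaling that follows. A rough sketch of such a projection, assuming a [4, H, W] sample; lowres_preview is an illustrative name, not the repo's exact implementation:

    import torch
    from PIL import Image

    def lowres_preview(sample, rgb_factors, smooth_matrix=None):
        # sample: [4, H, W] latents; rgb_factors: [4, 3] latent->RGB projection.
        rgb = sample.permute(1, 2, 0) @ rgb_factors        # -> [H, W, 3]
        if smooth_matrix is not None:
            # Run the 3x3 smoothing kernel over each color channel.
            chans = rgb.permute(2, 0, 1).unsqueeze(1)      # -> [3, 1, H, W]
            kernel = smooth_matrix.reshape(1, 1, 3, 3)
            chans = torch.nn.functional.conv2d(chans, kernel, padding=1)
            rgb = chans.squeeze(1).permute(1, 2, 0)
        # Latents land roughly in [-1, 1]; map to 0..255 bytes for a PIL image.
        ubyte = ((rgb + 1) / 2).clamp(0, 1).mul(255).byte().cpu().numpy()
        return Image.fromarray(ubyte)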

View File

@@ -50,7 +50,6 @@ from .offloading import FullyLoadedModelGroup, ModelGroup
 @dataclass
 class PipelineIntermediateState:
-    run_id: str
     step: int
     timestep: int
     latents: torch.Tensor
@@ -407,7 +406,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: Optional[torch.Tensor],
         timesteps=None,
         additional_guidance: List[Callable] = None,
-        run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
@@ -427,7 +425,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps,
             conditioning_data,
             noise=noise,
-            run_id=run_id,
             additional_guidance=additional_guidance,
             control_data=control_data,
             callback=callback,
@@ -441,13 +438,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         conditioning_data: ConditioningData,
         *,
         noise: Optional[torch.Tensor],
-        run_id: str = None,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
     ):
         self._adjust_memory_efficient_attention(latents)
-        if run_id is None:
-            run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
@@ -468,7 +462,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             latents = self.scheduler.add_noise(latents, noise, batched_t)

         yield PipelineIntermediateState(
-            run_id=run_id,
             step=-1,
             timestep=self.scheduler.config.num_train_timesteps,
             latents=latents,
@@ -507,7 +500,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

                 yield PipelineIntermediateState(
-                    run_id=run_id,
                     step=i,
                     timestep=int(t),
                     latents=latents,
@@ -619,7 +611,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         conditioning_data: ConditioningData,
         *,
         callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
@@ -645,7 +636,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data,
             strength,
             noise,
-            run_id,
             callback,
         )
@@ -678,7 +668,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         conditioning_data: ConditioningData,
         *,
         callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
@@ -737,7 +726,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 noise=noise,
                 timesteps=timesteps,
                 additional_guidance=guidance,
-                run_id=run_id,
                 callback=callback,
             )
         finally:
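
With run_id gone from every signature and call site, PipelineIntermediateState carries only per-step data. Restricted to the fields visible in the hunks above (the real dataclass may define more), the result is:

    from dataclasses import dataclass
    import torch

    @dataclass
    class PipelineIntermediateState:
        # run_id: str  <- removed by this commit
        step: int              # -1 for the pre-denoise state, then 0..N-1
        timestep: int          # scheduler timestep emitted with each step
        latents: torch.Tensor
        # ...fields outside the hunks shown here are unchanged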