spike: proof of concept using diffusers for txt2img

Kevin Turner 2022-11-09 11:33:19 -08:00
parent 1d43512d64
commit 58ea3bf4c8
4 changed files with 75 additions and 39 deletions


@@ -12,6 +12,7 @@ dependencies:
   - pytorch=1.12.1
   - cudatoolkit=11.6
   - pip:
+    - accelerate~=0.13
     - albumentations==0.4.3
     - dependency_injector==4.40.0
     - diffusers==0.6.0
@@ -39,6 +40,15 @@ dependencies:
     - torch-fidelity==0.3.0
     - torchmetrics==0.7.0
     - transformers==4.21.3
+    - diffusers~=0.7
+    - torchmetrics==0.7.0
+    - flask==2.1.3
+    - flask_socketio==5.3.0
+    - flask_cors==3.0.10
+    - dependency_injector==4.40.0
+    - eventlet
+    - getpass_asterisk
+    - kornia==0.6.0
     - git+https://github.com/openai/CLIP.git@main#egg=clip
     - git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
     - git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
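With the conda-side pins in place (accelerate~=0.13 alongside the new diffusers~=0.7 requirement), a quick import check confirms the environment resolved the versions the spike targets. This snippet is illustrative only, not part of the commit:

import accelerate
import diffusers

# the pins above should resolve to 0.13.x and 0.7.x respectively
print(accelerate.__version__, diffusers.__version__)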


@@ -1,7 +1,7 @@
 # pip will resolve the version which matches torch
 albumentations
 dependency_injector==4.40.0
-diffusers
+diffusers[torch]~=0.7
 einops
 eventlet
 facexlib


@ -1,6 +1,6 @@
import secrets
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Union, Callable
import torch
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -131,6 +131,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
+        callback: Optional[Callable[[PipelineIntermediateState], None]] = None,
         **extra_step_kwargs,
     ):
         r"""
@@ -172,7 +173,22 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                 prompt, height=height, width=width, num_inference_steps=num_inference_steps,
                 guidance_scale=guidance_scale, generator=generator, latents=latents,
                 **extra_step_kwargs):
-            pass  # discarding intermediates
+            if callback is not None:
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                              text_embeddings: torch.Tensor, guidance_scale: float,
+                              *, callback: Callable[[PipelineIntermediateState], None]=None, run_id=None,
+                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        self.scheduler.set_timesteps(num_inference_steps)
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs):
+            if callback is not None:
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
@@ -199,9 +215,6 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-        if run_id is None:
-            run_id = secrets.token_urlsafe(self.ID_LENGTH)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -209,16 +222,23 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         text_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\
             .to(self.unet.device)
         self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(latents, batch_size, height, width,
-                                       generator, self.unet.dtype)
+        latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype)
+
+        yield from self.generate_from_embeddings(latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs)
+
+    def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float,
+                                 run_id: str = None, **extra_step_kwargs):
+        if run_id is None:
+            run_id = secrets.token_urlsafe(self.ID_LENGTH)
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
 
         # NOTE: Depends on scheduler being already initialized!
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             step_output = self.step(t, latents, guidance_scale, text_embeddings, **extra_step_kwargs)
             latents = step_output.prev_sample
+            predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=step_output.pred_original_sample)
+                                            predicted_original=predicted_original)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
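The net effect of these hunks: stepwise generation now lives in a generate_from_embeddings generator, and image_from_embeddings drives it to completion while forwarding every PipelineIntermediateState to an optional callback. A rough usage sketch of that callback contract follows; report_progress and the latent shape are illustrative, and pipeline stands for a StableDiffusionGeneratorPipeline constructed as in the txt2img change below:

def report_progress(state):
    # PipelineIntermediateState fields come straight from the yields above;
    # step starts at -1 for the initial pure-noise state.
    print(f"run {state.run_id}: step {state.step}, timestep {state.timestep}")
    # state.latents and state.predicted_original (None for schedulers without
    # pred_original_sample) could feed a live preview here.

output = pipeline.image_from_embeddings(
    latents=torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16),
    num_inference_steps=50,
    text_embeddings=text_embeddings,  # already torch.cat([uc, c]), per txt2img below
    guidance_scale=7.5,
    callback=report_progress,
)
image = pipeline.numpy_to_pil(output.images)[0]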


@@ -1,10 +1,11 @@
 '''
 ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
+import PIL.Image
 import torch
 
-from ldm.invoke.generator.base import Generator
+from .base import Generator
+from .diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 
 class Txt2Img(Generator):
@@ -13,7 +14,8 @@ class Txt2Img(Generator):
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+                       **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
@@ -22,38 +24,42 @@ class Txt2Img(Generator):
         self.perlin = perlin
         uc, c, extra_conditioning_info = conditioning
 
-        @torch.no_grad()
-        def make_image(x_T):
-            shape = [
-                self.latent_channels,
-                height // self.downsampling_factor,
-                width // self.downsampling_factor,
-            ]
+        # FIXME: this should probably be either passed in to __init__ instead of model & precision,
+        #     or be constructed in __init__ from those inputs.
+        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="fp16", torch_dtype=torch.float16,
+            safety_checker=None,  # TODO
+            # scheduler=sampler + ddim_eta,  # TODO
+            # TODO: local_files_only=True
+        )
+        pipeline.unet.to("cuda")
+        pipeline.vae.to("cuda")
 
-            if self.free_gpu_mem and self.model.model.device != self.model.device:
-                self.model.model.to(self.model.device)
-
-            sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
+        def make_image(x_T) -> PIL.Image.Image:
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem and self.model.model.device != self.model.device:
+            #     self.model.model.to(self.model.device)
 
-            samples, _ = sampler.sample(
-                batch_size = 1,
-                S = steps,
-                x_T = x_T,
-                conditioning = c,
-                shape = shape,
-                verbose = False,
-                unconditional_guidance_scale = cfg_scale,
-                unconditional_conditioning = uc,
-                extra_conditioning_info = extra_conditioning_info,
-                eta = ddim_eta,
-                img_callback = step_callback,
-                threshold = threshold,
+            # FIXME: how the embeddings are combined should be internal to the pipeline
+            combined_text_embeddings = torch.cat([uc, c])
+            pipeline_output = pipeline.image_from_embeddings(
+                latents=x_T,
+                num_inference_steps=steps,
+                text_embeddings=combined_text_embeddings,
+                guidance_scale=cfg_scale,
+                callback=step_callback,
+                # TODO: extra_conditioning_info = extra_conditioning_info,
+                # TODO: eta = ddim_eta,
+                # TODO: threshold = threshold,
             )
 
-            if self.free_gpu_mem:
-                self.model.model.to("cpu")
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem:
+            #     self.model.model.to("cpu")
 
-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
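For orientation, a sketch of how this generator would be driven end to end. The surrounding InvokeAI call site supplies model, precision, sampler, and the (uc, c, extra_conditioning_info) conditioning tuple; the concrete values here are illustrative, not part of the commit:

generator = Txt2Img(model, precision)
make_image = generator.get_make_image(
    prompt,
    sampler=sampler,           # unused by the spike for now, see the TODOs above
    steps=50,
    cfg_scale=7.5,
    ddim_eta=0.0,
    conditioning=(uc, c, extra_conditioning_info),
    width=512,
    height=512,
    step_callback=None,        # or a PipelineIntermediateState consumer, as sketched earlier
)

# initial latents: 4 channels at 1/8 the pixel resolution, matching the fp16 CUDA pipeline
x_T = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda", dtype=torch.float16)
image = make_image(x_T)        # PIL.Image.Image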