spike: proof of concept using diffusers for txt2img

Kevin Turner 2022-11-09 11:33:19 -08:00
parent 1d43512d64
commit 58ea3bf4c8
4 changed files with 75 additions and 39 deletions

View File (conda environment file)

@@ -12,6 +12,7 @@ dependencies:
   - pytorch=1.12.1
   - cudatoolkit=11.6
   - pip:
+    - accelerate~=0.13
     - albumentations==0.4.3
     - dependency_injector==4.40.0
     - diffusers==0.6.0
@@ -39,6 +40,15 @@ dependencies:
     - torch-fidelity==0.3.0
     - torchmetrics==0.7.0
     - transformers==4.21.3
+    - diffusers~=0.7
+    - torchmetrics==0.7.0
+    - flask==2.1.3
+    - flask_socketio==5.3.0
+    - flask_cors==3.0.10
+    - dependency_injector==4.40.0
+    - eventlet
+    - getpass_asterisk
+    - kornia==0.6.0
     - git+https://github.com/openai/CLIP.git@main#egg=clip
     - git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
     - git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg

View File (pip requirements file)

@@ -1,7 +1,7 @@
 # pip will resolve the version which matches torch
 albumentations
 dependency_injector==4.40.0
-diffusers
+diffusers[torch]~=0.7
 einops
 eventlet
 facexlib
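
For reference, a quick sanity check (a sketch, not part of the commit) that the pins added in these two files resolve as intended; the [torch] extra on diffusers pulls in a torch backend if one is missing:

import accelerate  # new ~=0.13 pin in the conda environment
import diffusers   # pinned ~=0.7 in both files
import torch       # guaranteed present via the diffusers[torch] extra

print(accelerate.__version__)  # expect 0.13.x
print(diffusers.__version__)   # expect 0.7.x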

View File (ldm/invoke/generator/diffusers_pipeline.py)

@ -1,6 +1,6 @@
import secrets import secrets
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union, Callable
import torch import torch
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -131,6 +131,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
+        callback: Optional[Callable[[PipelineIntermediateState], None]] = None,
         **extra_step_kwargs,
     ):
         r"""
@@ -172,7 +173,22 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                 prompt, height=height, width=width, num_inference_steps=num_inference_steps,
                 guidance_scale=guidance_scale, generator=generator, latents=latents,
                 **extra_step_kwargs):
-            pass # discarding intermediates
+            if callback is not None:
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                              text_embeddings: torch.Tensor, guidance_scale: float,
+                              *, callback: Callable[[PipelineIntermediateState], None] = None, run_id=None,
+                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        self.scheduler.set_timesteps(num_inference_steps)
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs):
+            if callback is not None:
+                callback(result)
         if result is None:
             raise AssertionError("why was that an empty generator?")
         return result
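
A minimal usage sketch for the new entry point (assuming a pipeline constructed as in the txt2img.py changes below, with latents and text_embeddings prepared as in the end-to-end sketch at the bottom of this page):

def report(state):  # receives each PipelineIntermediateState as the loop yields it
    print(f"run {state.run_id}: step {state.step} at timestep {state.timestep}")

output = pipeline.image_from_embeddings(
    latents=latents,
    num_inference_steps=50,
    text_embeddings=text_embeddings,
    guidance_scale=7.5,
    callback=report,
)
image = pipeline.numpy_to_pil(output.images)[0]
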
@@ -199,9 +215,6 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-        if run_id is None:
-            run_id = secrets.token_urlsafe(self.ID_LENGTH)
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
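
The guidance weight those comments describe enters the denoising loop as a single linear mix of the two UNet predictions; a worked sketch with stand-in tensors (not code from this commit):

import torch

guidance_scale = 7.5  # w in Imagen's equation (2); 1.0 means no classifier-free guidance
noise_pred_uncond = torch.randn(1, 4, 64, 64)  # stand-in: prediction for the empty prompt
noise_pred_cond = torch.randn(1, 4, 64, 64)    # stand-in: prediction for the real prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
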
@@ -209,16 +222,23 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         text_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\
             .to(self.unet.device)
         self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(latents, batch_size, height, width,
-                                       generator, self.unet.dtype)
+        latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype)
+        yield from self.generate_from_embeddings(latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs)
+
+    def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float,
+                                 run_id: str = None, **extra_step_kwargs):
+        if run_id is None:
+            run_id = secrets.token_urlsafe(self.ID_LENGTH)
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
+        # NOTE: Depends on scheduler being already initialized!
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             step_output = self.step(t, latents, guidance_scale, text_embeddings, **extra_step_kwargs)
             latents = step_output.prev_sample
+            predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=step_output.pred_original_sample)
+                                            predicted_original=predicted_original)
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
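
Since image_from_embeddings() is now just a loop over this generator, callers can also consume the stream directly; a sketch (same assumed pipeline, latents, and text_embeddings as above) that mirrors the scheduler prerequisite flagged in the NOTE:

pipeline.scheduler.set_timesteps(50)  # generate_from_embeddings() depends on this
result = None
for result in pipeline.generate_from_embeddings(latents, text_embeddings, guidance_scale=7.5):
    pass  # each intermediate state carries run_id, step, timestep, latents, predicted_original
# `result` now holds the generator's final value, which image_from_embeddings() returns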

View File (ldm/invoke/generator/txt2img.py)

@@ -1,10 +1,11 @@
 '''
 ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
+import PIL.Image
 import torch

-from ldm.invoke.generator.base import Generator
+from .base import Generator
+from .diffusers_pipeline import StableDiffusionGeneratorPipeline


 class Txt2Img(Generator):
@@ -13,7 +14,8 @@ class Txt2Img(Generator):
     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+                       **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
@@ -22,38 +24,42 @@ class Txt2Img(Generator):
         self.perlin = perlin
         uc, c, extra_conditioning_info = conditioning

-        @torch.no_grad()
-        def make_image(x_T):
-            shape = [
-                self.latent_channels,
-                height // self.downsampling_factor,
-                width // self.downsampling_factor,
-            ]
+        # FIXME: this should probably be either passed in to __init__ instead of model & precision,
+        #     or be constructed in __init__ from those inputs.
+        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="fp16", torch_dtype=torch.float16,
+            safety_checker=None,  # TODO
+            # scheduler=sampler + ddim_eta,  # TODO
+            # TODO: local_files_only=True
+        )
+        pipeline.unet.to("cuda")
+        pipeline.vae.to("cuda")

-            if self.free_gpu_mem and self.model.model.device != self.model.device:
-                self.model.model.to(self.model.device)
+        def make_image(x_T) -> PIL.Image.Image:
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem and self.model.model.device != self.model.device:
+            #     self.model.model.to(self.model.device)

-            sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
+            # FIXME: how the embeddings are combined should be internal to the pipeline
+            combined_text_embeddings = torch.cat([uc, c])

-            samples, _ = sampler.sample(
-                batch_size = 1,
-                S = steps,
-                x_T = x_T,
-                conditioning = c,
-                shape = shape,
-                verbose = False,
-                unconditional_guidance_scale = cfg_scale,
-                unconditional_conditioning = uc,
-                extra_conditioning_info = extra_conditioning_info,
-                eta = ddim_eta,
-                img_callback = step_callback,
-                threshold = threshold,
-            )
+            pipeline_output = pipeline.image_from_embeddings(
+                latents=x_T,
+                num_inference_steps=steps,
+                text_embeddings=combined_text_embeddings,
+                guidance_scale=cfg_scale,
+                callback=step_callback,
+                # TODO: extra_conditioning_info = extra_conditioning_info,
+                # TODO: eta = ddim_eta,
+                # TODO: threshold = threshold,
+            )

-            if self.free_gpu_mem:
-                self.model.model.to("cpu")
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem:
+            #     self.model.model.to("cpu")

-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]

         return make_image
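
Putting the spike together, a self-contained sketch of the new txt2img path (prompt text, 512x512 size, and the divide-by-8 latent shape are illustrative assumptions; get_text_embeddings' argument order follows the diffusers_pipeline.py hunk above):

import torch
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline

pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16", torch_dtype=torch.float16,
    safety_checker=None,
)
pipeline.unet.to("cuda")
pipeline.vae.to("cuda")

# (prompt, opposing_prompt, do_classifier_free_guidance, batch_size), per the pipeline diff
text_embeddings = pipeline.get_text_embeddings("a red sports car", "", True, 1) \
    .to(pipeline.unet.device)

# initial noise; spatial dims are image size / 8, matching the height % 8 check above
x_T = torch.randn(1, pipeline.unet.in_channels, 512 // 8, 512 // 8,
                  dtype=pipeline.unet.dtype, device=pipeline.unet.device)

output = pipeline.image_from_embeddings(
    latents=x_T, num_inference_steps=50,
    text_embeddings=text_embeddings, guidance_scale=7.5,
)
pipeline.numpy_to_pil(output.images)[0].save("txt2img_spike.png")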