spike: proof of concept using diffusers for txt2img
parent 1d43512d64
commit 58ea3bf4c8

@@ -12,6 +12,7 @@ dependencies:
   - pytorch=1.12.1
   - cudatoolkit=11.6
   - pip:
+    - accelerate~=0.13
     - albumentations==0.4.3
     - dependency_injector==4.40.0
     - diffusers==0.6.0

@@ -39,6 +40,15 @@ dependencies:
     - torch-fidelity==0.3.0
     - torchmetrics==0.7.0
     - transformers==4.21.3
+    - diffusers~=0.7
+    - torchmetrics==0.7.0
+    - flask==2.1.3
+    - flask_socketio==5.3.0
+    - flask_cors==3.0.10
+    - dependency_injector==4.40.0
+    - eventlet
+    - getpass_asterisk
+    - kornia==0.6.0
     - git+https://github.com/openai/CLIP.git@main#egg=clip
     - git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
     - git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg

@ -1,7 +1,7 @@
|
|||||||
# pip will resolve the version which matches torch
|
# pip will resolve the version which matches torch
|
||||||
albumentations
|
albumentations
|
||||||
dependency_injector==4.40.0
|
dependency_injector==4.40.0
|
||||||
diffusers
|
diffusers[torch]~=0.7
|
||||||
einops
|
einops
|
||||||
eventlet
|
eventlet
|
||||||
facexlib
|
facexlib
|
||||||
|
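A side note on the new pin in this requirements hunk: `diffusers[torch]~=0.7` is a compatible-release specifier (it accepts any 0.x release at or above 0.7 and below 1.0), and the `[torch]` extra pulls in diffusers' PyTorch-specific dependencies. A minimal runtime check, as a sketch (assumes the `packaging` library is importable, which pip-managed environments generally satisfy):

    from importlib.metadata import version
    from packaging.specifiers import SpecifierSet

    # Sketch: assert the installed diffusers satisfies the ~=0.7 pin above.
    installed = version("diffusers")
    assert installed in SpecifierSet("~=0.7"), f"diffusers {installed} does not match ~=0.7"
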
@@ -1,6 +1,6 @@
 import secrets
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Callable

 import torch
 from diffusers.models import AutoencoderKL, UNet2DConditionModel

@@ -131,6 +131,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
+        callback: Optional[Callable[[PipelineIntermediateState], None]] = None,
         **extra_step_kwargs,
     ):
         r"""

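The `callback` parameter added here is how intermediate denoising states reach the caller. A minimal sketch of a compatible callback; the fields it reads (`run_id`, `step`, `timestep`, `latents`) are the ones `PipelineIntermediateState` is constructed with later in this diff:

    # Sketch of a Callable[[PipelineIntermediateState], None] for the new parameter.
    def log_step(state) -> None:
        print(f"run {state.run_id} step {state.step}: "
              f"t={state.timestep}, latents {tuple(state.latents.shape)}")
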
@@ -172,7 +173,22 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                 prompt, height=height, width=width, num_inference_steps=num_inference_steps,
                 guidance_scale=guidance_scale, generator=generator, latents=latents,
                 **extra_step_kwargs):
-            pass # discarding intermediates
+            if callback is not None:
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                              text_embeddings: torch.Tensor, guidance_scale: float,
+                              *, callback: Callable[[PipelineIntermediateState], None]=None, run_id=None,
+                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        self.scheduler.set_timesteps(num_inference_steps)
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs):
+            if callback is not None:
+                callback(result)
         if result is None:
             raise AssertionError("why was that an empty generator?")
         return result

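For orientation, a hedged sketch of driving the new `image_from_embeddings` entry point. Here `pipeline`, `uc`, and `c` are stand-ins for the pipeline and embedding tensors that the Txt2Img hunk below constructs, and the latent shape follows the usual Stable Diffusion convention of 4 channels at 1/8 the pixel resolution:

    import torch

    # Assumption: `pipeline` is a StableDiffusionGeneratorPipeline on CUDA and
    # `uc`/`c` are the unconditioned/conditioned embeddings, as in Txt2Img below.
    initial_latents = torch.randn(1, 4, 512 // 8, 512 // 8,
                                  dtype=pipeline.unet.dtype, device="cuda")
    output = pipeline.image_from_embeddings(
        latents=initial_latents,
        num_inference_steps=50,
        text_embeddings=torch.cat([uc, c]),  # combined the same way make_image does
        guidance_scale=7.5,
        callback=log_step,  # the sketch above; optional
    )
    image = pipeline.numpy_to_pil(output.images)[0]
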
@@ -199,9 +215,6 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if run_id is None:
-            run_id = secrets.token_urlsafe(self.ID_LENGTH)
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.

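The guidance comment kept as context here is the one piece of math in this file: with `w = guidance_scale`, the guided noise estimate is the unconditioned prediction plus `w` times the difference between the conditioned and unconditioned predictions, so `w = 1` reduces to the conditioned prediction alone, i.e. no classifier-free guidance. A one-line sketch of that combination (the `noise_pred_*` names are illustrative; in this pipeline the arithmetic lives inside `self.step`):

    # Classifier-free guidance, per the Imagen equation (2) referenced above.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
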
@@ -209,16 +222,23 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         text_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\
             .to(self.unet.device)
         self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(latents, batch_size, height, width,
-                                       generator, self.unet.dtype)
+        latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype)

+        yield from self.generate_from_embeddings(latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs)
+
+    def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float,
+                                 run_id: str = None, **extra_step_kwargs):
+        if run_id is None:
+            run_id = secrets.token_urlsafe(self.ID_LENGTH)
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
+        # NOTE: Depends on scheduler being already initialized!
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             step_output = self.step(t, latents, guidance_scale, text_embeddings, **extra_step_kwargs)
             latents = step_output.prev_sample
+            predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=step_output.pred_original_sample)
+                                            predicted_original=predicted_original)

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()

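Because `generate_from_embeddings` is a plain generator, intermediates can also be consumed directly, without the `image_from_embeddings` wrapper. A sketch under the same assumptions as the earlier one, honoring the NOTE that the scheduler must already be initialized:

    # Sketch: iterate the generator to inspect each PipelineIntermediateState.
    pipeline.scheduler.set_timesteps(50)  # per the NOTE: scheduler must be set up first
    state = None
    for state in pipeline.generate_from_embeddings(
            initial_latents, torch.cat([uc, c]), guidance_scale=7.5):
        # predicted_original may be None when the scheduler's step output lacks
        # pred_original_sample; that is exactly what the getattr change above guards.
        pass
    final_latents = state.latents  # assumes the generator yielded at least once
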
@@ -1,10 +1,11 @@
 '''
 ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
+import PIL.Image
 import torch

-from ldm.invoke.generator.base import Generator
+from .base import Generator
+from .diffusers_pipeline import StableDiffusionGeneratorPipeline


 class Txt2Img(Generator):

@@ -13,7 +14,8 @@ class Txt2Img(Generator):

     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
+                       **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it

@@ -22,38 +24,42 @@
         self.perlin = perlin
         uc, c, extra_conditioning_info = conditioning

-        @torch.no_grad()
-        def make_image(x_T):
-            shape = [
-                self.latent_channels,
-                height // self.downsampling_factor,
-                width // self.downsampling_factor,
-            ]
+        # FIXME: this should probably be either passed in to __init__ instead of model & precision,
+        #        or be constructed in __init__ from those inputs.
+        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="fp16", torch_dtype=torch.float16,
+            safety_checker=None,  # TODO
+            # scheduler=sampler + ddim_eta,  # TODO
+            # TODO: local_files_only=True
+        )
+        pipeline.unet.to("cuda")
+        pipeline.vae.to("cuda")

-            if self.free_gpu_mem and self.model.model.device != self.model.device:
-                self.model.model.to(self.model.device)
+        def make_image(x_T) -> PIL.Image.Image:
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem and self.model.model.device != self.model.device:
+            #     self.model.model.to(self.model.device)

-            sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
+            # FIXME: how the embeddings are combined should be internal to the pipeline
+            combined_text_embeddings = torch.cat([uc, c])

-            samples, _ = sampler.sample(
-                batch_size = 1,
-                S = steps,
-                x_T = x_T,
-                conditioning = c,
-                shape = shape,
-                verbose = False,
-                unconditional_guidance_scale = cfg_scale,
-                unconditional_conditioning = uc,
-                extra_conditioning_info = extra_conditioning_info,
-                eta = ddim_eta,
-                img_callback = step_callback,
-                threshold = threshold,
+            pipeline_output = pipeline.image_from_embeddings(
+                latents=x_T,
+                num_inference_steps=steps,
+                text_embeddings=combined_text_embeddings,
+                guidance_scale=cfg_scale,
+                callback=step_callback,
+                # TODO: extra_conditioning_info = extra_conditioning_info,
+                # TODO: eta = ddim_eta,
+                # TODO: threshold = threshold,
             )

-            if self.free_gpu_mem:
-                self.model.model.to("cpu")
+            # FIXME: restore free_gpu_mem functionality
+            # if self.free_gpu_mem:
+            #     self.model.model.to("cpu")

-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]

         return make_image

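For comparison with the custom `StableDiffusionGeneratorPipeline`, the equivalent flow through stock diffusers 0.7 looks roughly like this. A sketch only: it uses the public `StableDiffusionPipeline` (which encodes the prompt and applies guidance internally) with the same model id, fp16 revision, and disabled safety checker as the `from_pretrained` call above:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        revision="fp16",
        torch_dtype=torch.float16,
        safety_checker=None,
    ).to("cuda")

    result = pipe("a photograph of an astronaut riding a horse",
                  num_inference_steps=50, guidance_scale=7.5)
    result.images[0].save("txt2img.png")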