model_cache: add ability to load a diffusers model pipeline

and update associated things in Generate & Generator to not instantly fail when that happens
Kevin Turner 2022-11-09 17:17:52 -08:00
parent 9b274bd57c
commit 4c3858e079
5 changed files with 126 additions and 19 deletions

View File

@@ -18,6 +18,8 @@ import gc
 import hashlib
 import cv2
 import skimage
+from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
+    EulerAncestralDiscreteScheduler
 from omegaconf import OmegaConf
 from ldm.invoke.generator.base import downsampling
@@ -402,7 +404,10 @@ class Generate:
         width = width or self.width
         height = height or self.height
 
-        configure_model_padding(model, seamless, seamless_axes)
+        if isinstance(model, DiffusionPipeline):
+            configure_model_padding(model.unet, seamless, seamless_axes)
+        else:
+            configure_model_padding(model, seamless, seamless_axes)
 
         assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
         assert threshold >= 0.0, '--threshold must be >=0.0'
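For context on the branch above: seamless tiling is implemented by switching convolution padding to circular mode, and in a diffusers DiffusionPipeline the relevant Conv2d layers live inside pipeline.unet, which is why the padding helper is pointed at model.unet rather than the pipeline object. Below is a minimal sketch of what such a padding helper typically does; it is an illustration only, not the repo's actual configure_model_padding, and per-axis control (seamless_axes) is deliberately omitted.

import torch.nn as nn

def configure_padding_sketch(module: nn.Module, seamless: bool) -> None:
    # Toggle circular padding on every Conv2d so generated images tile seamlessly.
    # (Per-axis control, as provided by seamless_axes, is omitted from this sketch.)
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            if not hasattr(m, '_orig_padding_mode'):
                m._orig_padding_mode = m.padding_mode  # remember the default so it can be restored
            m.padding_mode = 'circular' if seamless else m._orig_padding_mode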
@@ -962,9 +967,15 @@ class Generate:
     def sample_to_lowres_estimated_image(self, samples):
         return self._make_base().sample_to_lowres_estimated_image(samples)
 
+    def _set_sampler(self):
+        if isinstance(self.model, DiffusionPipeline):
+            return self._set_scheduler()
+        else:
+            return self._set_sampler_legacy()
+
     # very repetitive code - can this be simplified? The KSampler names are
     # consistent, at least
-    def _set_sampler(self):
+    def _set_sampler_legacy(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
             self.sampler = PLMSSampler(self.model, device=self.device)
@@ -992,6 +1003,44 @@ class Generate:
         print(msg)
 
+    def _set_scheduler(self):
+        msg = f'>> Setting Sampler to {self.sampler_name}'
+        default = self.model.scheduler
+
+        # TODO: Test me! Not all schedulers take the same args.
+        scheduler_args = dict(
+            num_train_timesteps=default.num_train_timesteps,
+            beta_start=default.beta_start,
+            beta_end=default.beta_end,
+            beta_schedule=default.beta_schedule,
+        )
+        trained_betas = getattr(self.model.scheduler, 'trained_betas')
+        if trained_betas is not None:
+            scheduler_args.update(trained_betas=trained_betas)
+
+        if self.sampler_name == 'plms':
+            raise NotImplementedError("What's the diffusers implementation of PLMS?")
+        elif self.sampler_name == 'ddim':
+            self.sampler = DDIMScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_dpm_2_a':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_dpm_2':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_euler_a':
+            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_euler':
+            self.sampler = EulerDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_heun':
+            raise NotImplementedError("no diffusers implementation of Heun's sampler")
+        elif self.sampler_name == 'k_lms':
+            self.sampler = LMSDiscreteScheduler(**scheduler_args)
+        else:
+            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
+
+        print(msg)
+
+        if not hasattr(self.sampler, 'uses_inpainting_model'):
+            # FIXME: terrible kludge!
+            self.sampler.uses_inpainting_model = lambda: False
+
     def _load_img(self, img)->Image:
         if isinstance(img, Image.Image):
             image = img
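The hand-copied scheduler_args above work, but diffusers schedulers can also be built directly from the pipeline's existing scheduler config, which sidesteps the "not all schedulers take the same args" problem noted in the TODO. A rough sketch of that alternative follows, assuming a diffusers version that exposes from_config on scheduler classes; the name mapping is illustrative, not the project's final one.

from diffusers import (
    DDIMScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
)

# hypothetical lookup from legacy sampler names to diffusers scheduler classes
SCHEDULER_MAP = {
    'ddim': DDIMScheduler,
    'k_euler': EulerDiscreteScheduler,
    'k_euler_a': EulerAncestralDiscreteScheduler,
    'k_lms': LMSDiscreteScheduler,
}

def make_scheduler(pipeline, sampler_name):
    # Build a scheduler of the requested type from the pipeline's own config,
    # which already carries num_train_timesteps, beta_schedule, trained_betas, etc.
    cls = SCHEDULER_MAP.get(sampler_name)
    if cls is None:
        return pipeline.scheduler  # fall back to the model's default scheduler
    return cls.from_config(pipeline.scheduler.config)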

View File

@@ -9,8 +9,8 @@ import traceback
 import numpy as np
 import torch
-from PIL import Image, ImageFilter, ImageChops
-import cv2 as cv
+from PIL import Image, ImageFilter
+from diffusers import DiffusionPipeline
 from einops import rearrange
 from pytorch_lightning import seed_everything
 from tqdm import trange
@@ -26,9 +26,9 @@ class Generator:
     downsampling_factor: int
     latent_channels: int
     precision: str
-    model: DiffusionWrapper
+    model: DiffusionWrapper | DiffusionPipeline
 
-    def __init__(self, model: DiffusionWrapper, precision: str):
+    def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
         self.model = model
         self.precision = precision
         self.seed = None
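One small caveat about the DiffusionWrapper | DiffusionPipeline annotations above: PEP 604 union syntax is evaluated when the class and function are defined, so on Python 3.9 it raises a TypeError unless annotations are postponed. A sketch of the two safe spellings follows; the DiffusionWrapper import path is an assumption based on the legacy ldm code.

from __future__ import annotations  # makes `A | B` annotations lazy, so they work on Python 3.9

from typing import Union

from diffusers import DiffusionPipeline
from ldm.models.diffusion.ddpm import DiffusionWrapper  # assumed location of DiffusionWrapper

class Generator:
    # equivalent spelling that also evaluates eagerly on older interpreters
    model: Union[DiffusionWrapper, DiffusionPipeline]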

View File

@@ -1,4 +1,5 @@
 import secrets
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
@@ -309,6 +310,28 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
+    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
+                                 fragment_weights=None, **kwargs):
+        """
+        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
+        """
+        assert return_tokens == True
+        if fragment_weights:
+            weights = fragment_weights[0]
+            if any(weight != 1.0 for weight in weights):
+                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
+
+        if kwargs:
+            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
+
+        text_fragments = c[0]
+        text_input = self._tokenize(text_fragments)
+
+        with torch.inference_mode():
+            token_ids = text_input.input_ids.to(self.text_encoder.device)
+            text_embeddings = self.text_encoder(token_ids)[0]
+        return text_embeddings, text_input.input_ids
+
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
         return self.tokenizer(
@@ -319,6 +342,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
 
+    @property
+    def channels(self) -> int:
+        """Compatible with DiffusionWrapper"""
+        return self.unet.in_channels
+
     def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
         # get the initial random noise unless the user supplied it
         # Unlike in other pipelines, latents need to be generated in the target device
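These additions are compatibility shims so callers written against the legacy LatentDiffusion interface can treat the pipeline like the old model object. A hypothetical call site follows, mirroring how the legacy conditioning path invokes it; the prompt text and variable names are made up for illustration.

# `pipeline` is a StableDiffusionGeneratorPipeline obtained from the model cache (hypothetical setup)
prompts = [["a watercolor painting of a lighthouse"]]  # legacy callers pass a list of fragment lists

embeddings, token_ids = pipeline.get_learned_conditioning(prompts, return_tokens=True)

print(embeddings.shape)   # e.g. torch.Size([1, 77, 768]) for the SD 1.x text encoder
print(pipeline.channels)  # 4 -- the UNet's latent channels, as DiffusionWrapper exposed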

View File

@@ -24,17 +24,8 @@ class Txt2Img(Generator):
         self.perlin = perlin
         uc, c, extra_conditioning_info = conditioning
 
-        # FIXME: this should probably be either passed in to __init__ instead of model & precision,
-        # or be constructed in __init__ from those inputs.
-        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
-            "runwayml/stable-diffusion-v1-5",
-            revision="fp16", torch_dtype=torch.float16,
-            safety_checker=None, # TODO
-            # scheduler=sampler + ddim_eta, # TODO
-            # TODO: local_files_only=True
-        )
-        pipeline.unet.to("cuda")
-        pipeline.vae.to("cuda")
+        pipeline = self.model
+        # TODO: customize a new pipeline for the given sampler (Scheduler)
 
         def make_image(x_T) -> PIL.Image.Image:
             # FIXME: restore free_gpu_mem functionality
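The TODO about customizing the pipeline for the chosen sampler could plausibly be handled by swapping the scheduler on the shared pipeline, since diffusers pipelines hold the scheduler as an ordinary attribute. The sketch below illustrates that idea; it is not what this commit does, and the helper name is made up.

def apply_sampler_choice(pipeline, scheduler):
    # Swap the scheduler on an existing pipeline instead of rebuilding it;
    # no model weights are touched, so this is cheap to do per request.
    if scheduler is not None:
        pipeline.scheduler = scheduler
    return pipeline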

View File

@@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
 cleared and loaded from disk when next needed.
 '''
+from pathlib import Path
 
 import torch
 import os
@@ -20,6 +21,8 @@ import contextlib
 from typing import Union
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ldm.util import instantiate_from_config, ask_user
 from ldm.invoke.globals import Globals
 from picklescan.scanner import scan_file_path
@@ -91,7 +94,7 @@ class ModelCache(object):
             assert self.current_model,'** FATAL: no current model to restore to'
             print(f'** restoring {self.current_model}')
             self.get_model(self.current_model)
-            return
+            return None
 
         self.current_model = model_name
         self._push_newest_model(model_name)
@@ -277,7 +280,43 @@ class ModelCache(object):
         return model, width, height, model_hash
 
     def _load_diffusers_model(self, mconfig):
-        raise NotImplementedError() # return pipeline, width, height, model_hash
+        pipeline_args = {}
+
+        if 'repo_name' in mconfig:
+            name_or_path = mconfig['repo_name']
+            model_hash = "FIXME"
+            # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
+        elif 'path' in mconfig:
+            name_or_path = Path(mconfig['path'])
+            # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
+            # the submodels hashed together? The commit ID from the repo?
+            model_hash = "FIXME TOO"
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
+        print(f'>> Loading diffusers model from {name_or_path}')
+
+        if self.precision == 'float16':
+            print(' | Using faster float16 precision')
+            pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
+        else:
+            # TODO: more accurately, "using the model's default precision."
+            # How do we find out what that is?
+            print(' | Using more accurate float32 precision')
+
+        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
+            name_or_path,
+            safety_checker=None, # TODO
+            # TODO: alternate VAE
+            # TODO: local_files_only=True
+            **pipeline_args
+        )
+        pipeline.to(self.device)
+
+        width = pipeline.vae.sample_size
+        height = pipeline.vae.sample_size
+
+        return pipeline, width, height, model_hash
 
     def offload_model(self, model_name:str):
         '''
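For reference, _load_diffusers_model() only looks at repo_name or path in the model's config entry (plus the cache's precision setting). A hypothetical config entry built with OmegaConf is shown below, matching how the cache reads its models file; every key other than repo_name/path is a guess.

from omegaconf import OmegaConf

# hypothetical models.yaml fragment for a diffusers-format model
config = OmegaConf.create({
    'diffusers-1.5': {
        'description': 'Stable Diffusion 1.5 in diffusers layout',  # guessed key
        'repo_name': 'runwayml/stable-diffusion-v1-5',
        # or, for a model that is already on disk:
        # 'path': '/path/to/stable-diffusion-v1-5',
    },
})

# The loader would then be called with just that entry, e.g.:
# pipeline, width, height, model_hash = cache._load_diffusers_model(config['diffusers-1.5'])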