model_cache: add ability to load a diffusers model pipeline

and update the associated code in Generate & Generator so it does not immediately fail when a diffusers pipeline is loaded
Kevin Turner 2022-11-09 17:17:52 -08:00
parent 9b274bd57c
commit 4c3858e079
5 changed files with 126 additions and 19 deletions

ldm/generate.py

@@ -18,6 +18,8 @@ import gc
import hashlib
import cv2
import skimage
from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
EulerAncestralDiscreteScheduler
from omegaconf import OmegaConf
from ldm.invoke.generator.base import downsampling
@@ -402,7 +404,10 @@ class Generate:
width = width or self.width
height = height or self.height
configure_model_padding(model, seamless, seamless_axes)
if isinstance(model, DiffusionPipeline):
configure_model_padding(model.unet, seamless, seamless_axes)
else:
configure_model_padding(model, seamless, seamless_axes)
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert threshold >= 0.0, '--threshold must be >=0.0'
@@ -962,9 +967,15 @@ class Generate:
def sample_to_lowres_estimated_image(self, samples):
return self._make_base().sample_to_lowres_estimated_image(samples)
def _set_sampler(self):
if isinstance(self.model, DiffusionPipeline):
return self._set_scheduler()
else:
return self._set_sampler_legacy()
# very repetitive code - can this be simplified? The KSampler names are
# consistent, at least
def _set_sampler(self):
def _set_sampler_legacy(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
@@ -992,6 +1003,44 @@ class Generate:
print(msg)
def _set_scheduler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
default = self.model.scheduler
# TODO: Test me! Not all schedulers take the same args.
scheduler_args = dict(
num_train_timesteps=default.num_train_timesteps,
beta_start=default.beta_start,
beta_end=default.beta_end,
beta_schedule=default.beta_schedule,
)
trained_betas = getattr(self.model.scheduler, 'trained_betas', None)
if trained_betas is not None:
scheduler_args.update(trained_betas=trained_betas)
if self.sampler_name == 'plms':
raise NotImplementedError("What's the diffusers implementation of PLMS?")
elif self.sampler_name == 'ddim':
self.sampler = DDIMScheduler(**scheduler_args)
elif self.sampler_name == 'k_dpm_2_a':
raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
elif self.sampler_name == 'k_dpm_2':
raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
elif self.sampler_name == 'k_euler_a':
self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
elif self.sampler_name == 'k_euler':
self.sampler = EulerDiscreteScheduler(**scheduler_args)
elif self.sampler_name == 'k_heun':
raise NotImplementedError("no diffusers implementation of Heun's sampler")
elif self.sampler_name == 'k_lms':
self.sampler = LMSDiscreteScheduler(**scheduler_args)
else:
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
self.sampler = default
print(msg)
if not hasattr(self.sampler, 'uses_inpainting_model'):
# FIXME: terrible kludge!
self.sampler.uses_inpainting_model = lambda: False
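
The repetitive chain above could collapse into a lookup table. A sketch, not part of this commit: it assumes the from_config constructor that current diffusers schedulers expose (which copies compatible parameters such as betas and timestep count from an existing scheduler's config, sidestepping the "not all schedulers take the same args" TODO), and it uses PNDMScheduler, diffusers' pseudo-linear-multistep sampler, as the closest answer to the PLMS question above.

    from diffusers import (DDIMScheduler, EulerAncestralDiscreteScheduler,
                           EulerDiscreteScheduler, LMSDiscreteScheduler,
                           PNDMScheduler)

    # Sampler names are InvokeAI's; classes are the nearest diffusers schedulers.
    SCHEDULER_CLASSES = {
        'plms': PNDMScheduler,  # pseudo linear multistep, the PLMS analogue
        'ddim': DDIMScheduler,
        'k_euler': EulerDiscreteScheduler,
        'k_euler_a': EulerAncestralDiscreteScheduler,
        'k_lms': LMSDiscreteScheduler,
    }

    def make_scheduler(name, default):
        """Build the named scheduler from the pipeline's default scheduler config."""
        cls = SCHEDULER_CLASSES.get(name)
        if cls is None:
            print(f'>> Unsupported Sampler: {name}, Defaulting to {default}')
            return default
        # from_config reads compatible keys from the existing config and
        # ignores ones the target scheduler class does not accept.
        return cls.from_config(default.config)

The k_dpm_2 and k_heun cases stay unmapped here, matching the NotImplementedError branches above.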
def _load_img(self, img)->Image:
if isinstance(img, Image.Image):
image = img

ldm/invoke/generator/base.py

@@ -9,8 +9,8 @@ import traceback
import numpy as np
import torch
from PIL import Image, ImageFilter, ImageChops
import cv2 as cv
from PIL import Image, ImageFilter
from diffusers import DiffusionPipeline
from einops import rearrange
from pytorch_lightning import seed_everything
from tqdm import trange
@@ -26,9 +26,9 @@ class Generator:
downsampling_factor: int
latent_channels: int
precision: str
model: DiffusionWrapper
model: DiffusionWrapper | DiffusionPipeline
def __init__(self, model: DiffusionWrapper, precision: str):
def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
self.model = model
self.precision = precision
self.seed = None
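
One portability note on the annotation above: DiffusionWrapper | DiffusionPipeline is evaluated at class-definition time and only parses on Python 3.10+. On older interpreters the same union would be spelled with typing.Union (or deferred with from __future__ import annotations), as in this equivalent sketch:

    from typing import Union

    from diffusers import DiffusionPipeline
    from ldm.models.diffusion.ddpm import DiffusionWrapper  # assumed location

    class Generator:
        model: Union[DiffusionWrapper, DiffusionPipeline]  # == the | union on 3.10+

        def __init__(self, model: Union[DiffusionWrapper, DiffusionPipeline], precision: str):
            self.model = model
            self.precision = precision
            self.seed = None

Behavior is identical; only the spelling of the union differs.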

ldm/invoke/generator/diffusers_pipeline.py

@@ -1,4 +1,5 @@
import secrets
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union, Callable
@@ -309,6 +310,28 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
fragment_weights=None, **kwargs):
"""
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
"""
assert return_tokens == True
if fragment_weights:
weights = fragment_weights[0]
if any(weight != 1.0 for weight in weights):
warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
if kwargs:
warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
text_fragments = c[0]
text_input = self._tokenize(text_fragments)
with torch.inference_mode():
token_ids = text_input.input_ids.to(self.text_encoder.device)
text_embeddings = self.text_encoder(token_ids)[0]
return text_embeddings, text_input.input_ids
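
A hypothetical call against this shim, mirroring how ldm's samplers request conditioning (the prompt text is illustrative; the nested list is one batch entry holding its prompt fragments):

    # pipeline: a loaded StableDiffusionGeneratorPipeline
    prompts = [["a fox in the snow"]]  # batch of one prompt, a single fragment
    embeddings, token_ids = pipeline.get_learned_conditioning(prompts, return_tokens=True)
    # embeddings: CLIP hidden states, shape [1, 77, 768] for Stable Diffusion v1
    # token_ids: the padded token ids produced by the tokenizer

Fragment weights other than 1.0 only trigger the warning above; they are not yet applied.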
@torch.inference_mode()
def _tokenize(self, prompt: Union[str, List[str]]):
return self.tokenizer(
@@ -319,6 +342,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
return_tensors="pt",
)
@property
def channels(self) -> int:
"""Compatible with DiffusionWrapper"""
return self.unet.in_channels
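
With that property in place, a caller that reads the channel count off the legacy model object needs no change; for example:

    # Hypothetical caller: works for both the legacy model and the pipeline.
    latent_channels = model.channels  # 4 for standard Stable Diffusion v1 unets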
def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device

ldm/invoke/generator/txt2img.py

@@ -24,17 +24,8 @@ class Txt2Img(Generator):
self.perlin = perlin
uc, c, extra_conditioning_info = conditioning
# FIXME: this should probably be either passed in to __init__ instead of model & precision,
# or be constructed in __init__ from those inputs.
pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16", torch_dtype=torch.float16,
safety_checker=None, # TODO
# scheduler=sampler + ddim_eta, # TODO
# TODO: local_files_only=True
)
pipeline.unet.to("cuda")
pipeline.vae.to("cuda")
pipeline = self.model
# TODO: customize a new pipeline for the given sampler (Scheduler)
def make_image(x_T) -> PIL.Image.Image:
# FIXME: restore free_gpu_mem functionality
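
A sketch of how the remaining TODO might go: since diffusers pipeline components are plain attributes, the cached pipeline's scheduler can simply be reassigned before sampling. This helper is hypothetical and not part of the commit:

    def with_scheduler(pipeline, scheduler):
        """Point the shared pipeline at the scheduler chosen for this generation."""
        if scheduler is not None and scheduler is not pipeline.scheduler:
            pipeline.scheduler = scheduler  # per-run state is set later via set_timesteps
        return pipeline

Reassignment is cheap because schedulers hold no weights, only configuration and per-run timestep state.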

ldm/invoke/model_cache.py

@@ -4,6 +4,7 @@ They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
cleared and loaded from disk when next needed.
'''
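
A toy sketch of the eviction policy this docstring describes (the real ModelCache checks free CPU memory rather than a fixed capacity, and also shuttles models between GPU and CPU):

    from collections import OrderedDict

    class TinyModelLRU:
        """Illustration only: most recently used models stay resident."""
        def __init__(self, loader, max_resident=2):
            self.loader = loader              # callable: name -> model, loads from disk
            self.max_resident = max_resident  # stand-in for the CPU-memory check
            self.resident = OrderedDict()

        def get(self, name):
            if name in self.resident:
                self.resident.move_to_end(name)        # mark most recently used
            else:
                while len(self.resident) >= self.max_resident:
                    self.resident.popitem(last=False)  # clear least recently used
                self.resident[name] = self.loader(name)
            return self.resident[name]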
from pathlib import Path
import torch
import os
@@ -20,6 +21,8 @@ import contextlib
from typing import Union
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.util import instantiate_from_config, ask_user
from ldm.invoke.globals import Globals
from picklescan.scanner import scan_file_path
@@ -91,7 +94,7 @@ class ModelCache(object):
assert self.current_model,'** FATAL: no current model to restore to'
print(f'** restoring {self.current_model}')
self.get_model(self.current_model)
return
return None
self.current_model = model_name
self._push_newest_model(model_name)
@@ -277,7 +280,43 @@ class ModelCache(object):
return model, width, height, model_hash
def _load_diffusers_model(self, mconfig):
raise NotImplementedError() # return pipeline, width, height, model_hash
pipeline_args = {}
if 'repo_name' in mconfig:
name_or_path = mconfig['repo_name']
model_hash = "FIXME"
# model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
elif 'path' in mconfig:
name_or_path = Path(mconfig['path'])
# FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
# the submodels hashed together? The commit ID from the repo?
model_hash = "FIXME TOO"
else:
raise ValueError("Model config must specify either repo_name or path.")
print(f'>> Loading diffusers model from {name_or_path}')
if self.precision == 'float16':
print(' | Using faster float16 precision')
pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
else:
# TODO: more accurately, "using the model's default precision."
# How do we find out what that is?
print(' | Using more accurate float32 precision')
pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
name_or_path,
safety_checker=None, # TODO
# TODO: alternate VAE
# TODO: local_files_only=True
**pipeline_args
)
pipeline.to(self.device)
width = pipeline.vae.sample_size
height = pipeline.vae.sample_size
return pipeline, width, height, model_hash
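
For reference, the two config shapes this method accepts might look like the following; only the repo_name and path keys come from the code above, the values are illustrative. The final line sketches one possible answer to the model_hash FIXME for hub models: the repo's current commit hash, fetched via huggingface_hub, identifies every file in the snapshot.

    import huggingface_hub
    from omegaconf import OmegaConf

    # A stanza naming a Hugging Face Hub repo...
    hub_config = OmegaConf.create({'repo_name': 'runwayml/stable-diffusion-v1-5'})
    # ...or one pointing at a local diffusers-format directory (illustrative path).
    local_config = OmegaConf.create({'path': '/data/models/stable-diffusion-v1-5'})

    # Possible model_hash for the hub case: the repo's current commit hash.
    model_hash = huggingface_hub.model_info('runwayml/stable-diffusion-v1-5').sha

What a stable hash should be for local directories (weights hash, combined submodel hash, or repo commit ID) is left open by the FIXME above.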
def offload_model(self, model_name:str):
'''