mostly ported to new manager API; needs testing

Lincoln Stein
2023-05-06 00:44:12 -04:00
parent af8c7c7d29
commit e0214a32bc
12 changed files with 353 additions and 332 deletions


@@ -180,31 +180,33 @@ class TextToLatentsInvocation(BaseInvocation):
     def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
         model_info = choose_model(model_manager, self.model)
-        model_name = model_info['model_name']
-        model_hash = model_info['hash']
-        model: StableDiffusionGeneratorPipeline = model_info['model']
-        model.scheduler = get_scheduler(
-            model=model,
-            scheduler_name=self.scheduler
-        )
+        model_name = model_info.name
+        model_hash = model_info.hash
+        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
+        with model_ctx as model:
+            model.scheduler = get_scheduler(
+                model=model,
+                scheduler_name=self.scheduler
+            )
-        if isinstance(model, DiffusionPipeline):
-            for component in [model.unet, model.vae]:
-                configure_model_padding(component,
+            if isinstance(model, DiffusionPipeline):
+                for component in [model.unet, model.vae]:
+                    configure_model_padding(component,
                                         self.seamless,
                                         self.seamless_axes
                                         )
-        else:
-            configure_model_padding(model,
-                self.seamless,
-                self.seamless_axes
-            )
+            else:
+                configure_model_padding(model,
+                    self.seamless,
+                    self.seamless_axes
+                )
-        return model
+        return model_ctx

     def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
         uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
+        print(f'DEBUG: uc.dtype={uc.dtype}, c.dtype={c.dtype}')
         conditioning_data = ConditioningData(
             uc,
             c,
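For orientation, a minimal sketch of the access pattern this hunk moves to, assuming (as the diff implies) that choose_model() now returns an info object whose .context is a context manager yielding the loaded pipeline; the model and scheduler names below are hypothetical:

# Sketch only: ModelManager/choose_model/get_scheduler as named in this diff.
model_info = choose_model(model_manager, "stable-diffusion-1.5")  # hypothetical name
model_name = model_info.name   # attribute access replaces the old dict lookups
model_hash = model_info.hash
with model_info.context as model:
    # The pipeline is guaranteed to be loaded only inside this block,
    # so scheduler setup and any inference must happen here.
    model.scheduler = get_scheduler(model=model, scheduler_name="ddim")  # hypothetical scheduler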
@@ -230,18 +232,17 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)

-        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
-
-        # TODO: Verify the noise is the right size
-        result_latents, result_attention_map_saver = model.latents_from_embeddings(
-            latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
-            noise=noise,
-            num_inference_steps=self.steps,
-            conditioning_data=conditioning_data,
-            callback=step_callback
-        )
+        with self.get_model(context.services.model_manager) as model:
+            conditioning_data = self.get_conditioning_data(model)
+            # TODO: Verify the noise is the right size
+            result_latents, result_attention_map_saver = model.latents_from_embeddings(
+                latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
+                noise=noise,
+                num_inference_steps=self.steps,
+                conditioning_data=conditioning_data,
+                callback=step_callback
+            )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
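The trailing torch.cuda.empty_cache() (see the linked Hugging Face thread) hands cached allocator blocks back after a pipeline stage; a hedged sketch of the same cleanup as a reusable helper (the helper name is hypothetical, not from this repository):

import torch

def run_stage_then_release(fn, *args, **kwargs):
    # Hypothetical helper: run one pipeline stage, then return cached CUDA
    # blocks to the allocator so later stages can reuse the VRAM.
    try:
        return fn(*args, **kwargs)
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()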
@@ -284,29 +285,29 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)

-        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
-
-        # TODO: Verify the noise is the right size
-        initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-            latent, device=model.device, dtype=latent.dtype
-        )
-
-        timesteps, _ = model.get_img2img_timesteps(
-            self.steps,
-            self.strength,
-            device=model.device,
-        )
-
-        result_latents, result_attention_map_saver = model.latents_from_embeddings(
-            latents=initial_latents,
-            timesteps=timesteps,
-            noise=noise,
-            num_inference_steps=self.steps,
-            conditioning_data=conditioning_data,
-            callback=step_callback
-        )
+        with self.get_model(context.services.model_manager) as model:
+            conditioning_data = self.get_conditioning_data(model)
+
+            # TODO: Verify the noise is the right size
+            initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
+                latent, device=model.device, dtype=latent.dtype
+            )
+
+            timesteps, _ = model.get_img2img_timesteps(
+                self.steps,
+                self.strength,
+                device=model.device,
+            )
+
+            result_latents, result_attention_map_saver = model.latents_from_embeddings(
+                latents=initial_latents,
+                timesteps=timesteps,
+                noise=noise,
+                num_inference_steps=self.steps,
+                conditioning_data=conditioning_data,
+                callback=step_callback
+            )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
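get_img2img_timesteps() itself is not shown in this commit; a minimal sketch of the strength-based schedule truncation it presumably performs (diffusers-style img2img; this is an assumption, not the repository's implementation):

# Assumption: img2img denoises only the final `strength` fraction of the
# schedule, so strength=1.0 starts from pure noise (hence zeros_like above).
def img2img_timesteps(timesteps, num_inference_steps: int, strength: float):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start:]  # e.g. 50 steps at strength 0.3 -> last 15 steps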


@@ -7,7 +7,7 @@ def choose_model(model_manager: ModelManager, model_name: str):
     if model_manager.valid_model(model_name):
         model = model_manager.get_model(model_name)
     else:
-        model = model_manager.get_model()
-        logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
+        model = model_manager.get_model(model_manager.default_model())
+        logger.warning(f"'{model_name}' is not a valid model name. Using default model \'{model.name}\' instead.")
     return model
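After this change the fallback path asks the manager for its default model explicitly; a sketch of the resulting behaviour (logger wiring assumed, everything else as in the hunk's new version):

import logging
logger = logging.getLogger(__name__)

def choose_model(model_manager, model_name: str):
    # Return the requested model, or fall back to the manager's default
    # and warn so the silent substitution is visible in the logs.
    if model_manager.valid_model(model_name):
        return model_manager.get_model(model_name)
    model = model_manager.get_model(model_manager.default_model())
    logger.warning(f"'{model_name}' is not a valid model name. "
                   f"Using default model '{model.name}' instead.")
    return model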


@@ -47,22 +47,21 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
     else:
         embedding_path = None

-    # migrate legacy models
-    ModelManager.migrate_models()
-
     # creating the model manager
     try:
        device = torch.device(choose_torch_device())
-        precision = 'float16' if config.precision=='float16' \
-            else 'float32' if config.precision=='float32' \
-            else choose_precision(device)
+        if config.precision=="auto":
+            precision = choose_precision(device)
+        dtype = torch.float32 if precision=='float32' \
+            else torch.float16

         model_manager = ModelManager(
-            OmegaConf.load(config.conf),
-            precision=precision,
+            config.conf,
+            precision=dtype,
             device_type=device,
             max_loaded_models=config.max_loaded_models,
-            embedding_path = Path(embedding_path),
+            # temporarily disabled until model manager stabilizes
+            # embedding_path = Path(embedding_path),
             logger = logger,
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
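The manager now receives a torch dtype rather than a precision string; a minimal sketch of the string-to-dtype resolution this hunk performs (the behaviour of choose_precision() is assumed here: float16 on capable CUDA devices, float32 otherwise):

import torch

def resolve_dtype(precision: str, device: torch.device) -> torch.dtype:
    # Assumed mapping: 'auto' defers to the device, and everything
    # except an explicit 'float32' resolves to half precision.
    if precision == "auto":
        precision = "float16" if device.type == "cuda" else "float32"
    return torch.float32 if precision == "float32" else torch.float16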