blackified

This commit is contained in:
Lincoln Stein
2023-07-29 17:30:54 -04:00
parent e82eb0b9fc
commit 348bee8981
3 changed files with 12 additions and 10 deletions

View File

@ -6,8 +6,7 @@ from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext)
from .math import FloatOutput, IntOutput from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs # Pass-through parameter nodes - used by subgraphs
@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> StringOutput: def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text) return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation): class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter""" """A prompt input parameter"""
@ -80,4 +80,4 @@ class ParamPromptInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> PromptOutput: def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt) return PromptOutput(prompt=self.prompt)

View File

@ -1252,7 +1252,7 @@ def download_from_original_stable_diffusion_ckpt(
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias" precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}") logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
precision = precision or checkpoint[precision_probing_key].dtype precision = precision or checkpoint[precision_probing_key].dtype
@ -1284,9 +1284,9 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = BytesIO(requests.get(config_url).content) original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
if original_config['model']['params'].get('use_ema') is not None: if original_config["model"]["params"].get("use_ema") is not None:
extract_ema = original_config['model']['params']['use_ema'] extract_ema = original_config["model"]["params"]["use_ema"]
if ( if (
model_version == BaseModelType.StableDiffusion2 model_version == BaseModelType.StableDiffusion2
and original_config["model"]["params"].get("parameterization") == "v" and original_config["model"]["params"].get("parameterization") == "v"
@ -1689,9 +1689,9 @@ def download_controlnet_from_original_ckpt(
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
# use original precision # use original precision
precision_probing_key = 'input_blocks.0.0.bias' precision_probing_key = "input_blocks.0.0.bias"
ckpt_precision = checkpoint[precision_probing_key].dtype ckpt_precision = checkpoint[precision_probing_key].dtype
logger.debug(f'original controlnet precision = {ckpt_precision}') logger.debug(f"original controlnet precision = {ckpt_precision}")
precision = precision or ckpt_precision precision = precision or ckpt_precision
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)

View File

@ -19,6 +19,7 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
class ControlNetModelFormat(str, Enum): class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint" Checkpoint = "checkpoint"
Diffusers = "diffusers" Diffusers = "diffusers"
@ -123,6 +124,7 @@ class ControlNetModel(ModelBase):
else: else:
return model_path return model_path
def _convert_controlnet_ckpt_and_cache( def _convert_controlnet_ckpt_and_cache(
model_path: str, model_path: str,
output_path: str, output_path: str,
@ -137,7 +139,7 @@ def _convert_controlnet_ckpt_and_cache(
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path weights = app_config.root_path / model_path
output_path = Path(output_path) output_path = Path(output_path)
logger.info(f"Converting {weights} to diffusers format") logger.info(f"Converting {weights} to diffusers format")
# return cached version if it exists # return cached version if it exists
if output_path.exists(): if output_path.exists():