blackified

This commit is contained in:
Lincoln Stein 2023-07-29 17:30:54 -04:00
parent e82eb0b9fc
commit 348bee8981
3 changed files with 12 additions and 10 deletions

View File

@ -6,8 +6,7 @@ from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
InvocationConfig, InvocationContext)
from .math import FloatOutput, IntOutput from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs # Pass-through parameter nodes - used by subgraphs
@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> StringOutput: def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text) return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation): class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter""" """A prompt input parameter"""

View File

@ -1284,8 +1284,8 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = BytesIO(requests.get(config_url).content) original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
if original_config['model']['params'].get('use_ema') is not None: if original_config["model"]["params"].get("use_ema") is not None:
extract_ema = original_config['model']['params']['use_ema'] extract_ema = original_config["model"]["params"]["use_ema"]
if ( if (
model_version == BaseModelType.StableDiffusion2 model_version == BaseModelType.StableDiffusion2
@ -1689,9 +1689,9 @@ def download_controlnet_from_original_ckpt(
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
# use original precision # use original precision
precision_probing_key = 'input_blocks.0.0.bias' precision_probing_key = "input_blocks.0.0.bias"
ckpt_precision = checkpoint[precision_probing_key].dtype ckpt_precision = checkpoint[precision_probing_key].dtype
logger.debug(f'original controlnet precision = {ckpt_precision}') logger.debug(f"original controlnet precision = {ckpt_precision}")
precision = precision or ckpt_precision precision = precision or ckpt_precision
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)

View File

@ -19,6 +19,7 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
class ControlNetModelFormat(str, Enum): class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint" Checkpoint = "checkpoint"
Diffusers = "diffusers" Diffusers = "diffusers"
@ -123,6 +124,7 @@ class ControlNetModel(ModelBase):
else: else:
return model_path return model_path
def _convert_controlnet_ckpt_and_cache( def _convert_controlnet_ckpt_and_cache(
model_path: str, model_path: str,
output_path: str, output_path: str,