mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
blackified
This commit is contained in:
parent
e82eb0b9fc
commit
348bee8981
@ -6,8 +6,7 @@ from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.prompt import PromptOutput
|
||||
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .math import FloatOutput, IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
class ParamPromptInvocation(BaseInvocation):
|
||||
"""A prompt input parameter"""
|
||||
|
||||
@ -80,4 +80,4 @@ class ParamPromptInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptOutput:
|
||||
return PromptOutput(prompt=self.prompt)
|
||||
return PromptOutput(prompt=self.prompt)
|
||||
|
@ -1252,7 +1252,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
|
||||
|
||||
|
||||
precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
|
||||
logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
|
||||
precision = precision or checkpoint[precision_probing_key].dtype
|
||||
@ -1284,9 +1284,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
if original_config['model']['params'].get('use_ema') is not None:
|
||||
extract_ema = original_config['model']['params']['use_ema']
|
||||
|
||||
if original_config["model"]["params"].get("use_ema") is not None:
|
||||
extract_ema = original_config["model"]["params"]["use_ema"]
|
||||
|
||||
if (
|
||||
model_version == BaseModelType.StableDiffusion2
|
||||
and original_config["model"]["params"].get("parameterization") == "v"
|
||||
@ -1689,9 +1689,9 @@ def download_controlnet_from_original_ckpt(
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
# use original precision
|
||||
precision_probing_key = 'input_blocks.0.0.bias'
|
||||
precision_probing_key = "input_blocks.0.0.bias"
|
||||
ckpt_precision = checkpoint[precision_probing_key].dtype
|
||||
logger.debug(f'original controlnet precision = {ckpt_precision}')
|
||||
logger.debug(f"original controlnet precision = {ckpt_precision}")
|
||||
precision = precision or ckpt_precision
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
@ -19,6 +19,7 @@ from .base import (
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
@ -123,6 +124,7 @@ class ControlNetModel(ModelBase):
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
def _convert_controlnet_ckpt_and_cache(
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
@ -137,7 +139,7 @@ def _convert_controlnet_ckpt_and_cache(
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
weights = app_config.root_path / model_path
|
||||
output_path = Path(output_path)
|
||||
|
||||
|
||||
logger.info(f"Converting {weights} to diffusers format")
|
||||
# return cached version if it exists
|
||||
if output_path.exists():
|
||||
|
Loading…
Reference in New Issue
Block a user