mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move clip skip to separate node
This commit is contained in:
parent
04b57c408f
commit
a9e77675a8
@ -44,7 +44,6 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
clip_skip: int = Field(0, description="Layers to skip in text_encoder")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -96,7 +95,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip_skip),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
compel = Compel(
|
||||
@ -136,6 +135,24 @@ class CompelInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
type: Literal["clip_skip"] = "clip_skip"
|
||||
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return ClipSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
|
@ -30,6 +30,7 @@ class UNetField(BaseModel):
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
@ -154,6 +155,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
|
Loading…
Reference in New Issue
Block a user