mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy _prepare_attention_processors(...) logic.
This commit is contained in:
parent
3f860c3523
commit
4df1cdb34d
@ -13,22 +13,10 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
|
|||||||
|
|
||||||
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
||||||
"""
|
"""
|
||||||
# TODO(ryand): This logic can be simplified.
|
|
||||||
|
|
||||||
# Construct a dict of attention processors based on the UNet's architecture.
|
# Construct a dict of attention processors based on the UNet's architecture.
|
||||||
attn_procs = {}
|
attn_procs = {}
|
||||||
for idx, name in enumerate(unet.attn_processors.keys()):
|
for idx, name in enumerate(unet.attn_processors.keys()):
|
||||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
if name.endswith("attn1.processor"):
|
||||||
if name.startswith("mid_block"):
|
|
||||||
hidden_size = unet.config.block_out_channels[-1]
|
|
||||||
elif name.startswith("up_blocks"):
|
|
||||||
block_id = int(name[len("up_blocks.")])
|
|
||||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
|
||||||
elif name.startswith("down_blocks"):
|
|
||||||
block_id = int(name[len("down_blocks.")])
|
|
||||||
hidden_size = unet.config.block_out_channels[block_id]
|
|
||||||
|
|
||||||
if cross_attention_dim is None:
|
|
||||||
attn_procs[name] = AttnProcessor2_0()
|
attn_procs[name] = AttnProcessor2_0()
|
||||||
else:
|
else:
|
||||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||||
@ -43,8 +31,7 @@ def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPA
|
|||||||
"""A context manager that patches `unet` with IP-Adapter attention processors.
|
"""A context manager that patches `unet` with IP-Adapter attention processors.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Scales: The Scales object, which can be used to dynamically alter the scales of the
|
Scales: The Scales object, which can be used to dynamically alter the scales of the IP-Adapters.
|
||||||
IP-Adapters.
|
|
||||||
"""
|
"""
|
||||||
scales = Scales([1.0] * len(ip_adapters))
|
scales = Scales([1.0] * len(ip_adapters))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user