Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 109 additions & 40 deletions src/diffusers/models/controlnets/controlnet_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@
SEQ_MULTI_OF = 32


# Copied from diffusers.models.transformers.transformer_z_image.apply_rotary_emb
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: torch.Tensor | tuple[torch.Tensor],
use_real: bool = True,
) -> torch.Tensor:
if use_real:
cos, sin = freqs_cis
cos = cos.unsqueeze(2).to(x.device)
sin = sin.unsqueeze(2).to(x.device)
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
with torch.amp.autocast("cuda", enabled=False):
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
return x_out.type_as(x)


# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder
class TimestepEmbedder(nn.Module):
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
Expand Down Expand Up @@ -84,11 +106,12 @@ class ZSingleStreamAttnProcessor:
_attention_backend = None
_parallel_config = None

def __init__(self):
def __init__(self, use_real: bool = False):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
self.use_real = use_real

def __call__(
self,
Expand All @@ -113,16 +136,9 @@ def __call__(
key = attn.norm_k(key)

# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo

if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
query = apply_rotary_emb(query, freqs_cis, use_real=self.use_real)
key = apply_rotary_emb(key, freqs_cis, use_real=self.use_real)

# Cast to correct dtype
dtype = query.dtype
Expand Down Expand Up @@ -197,6 +213,7 @@ def __init__(
norm_eps: float,
qk_norm: bool,
modulation=True,
use_real: bool = False,
):
super().__init__()
self.dim = dim
Expand All @@ -213,7 +230,7 @@ def __init__(
eps=1e-5,
bias=False,
out_bias=False,
processor=ZSingleStreamAttnProcessor(),
processor=ZSingleStreamAttnProcessor(use_real=use_real),
)

self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
Expand Down Expand Up @@ -293,22 +310,29 @@ def __init__(
theta: float = 256.0,
axes_dims: list[int] = (16, 56, 56),
axes_lens: list[int] = (64, 128, 128),
use_real: bool = False,
):
self.theta = theta
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.use_real = use_real
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
self.freqs_cis = None

@staticmethod
def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0):
def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0, use_real: bool = False):
with torch.device("cpu"):
freqs_cis = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
if use_real:
cos = freqs.cos().repeat_interleave(2, dim=-1)
sin = freqs.sin().repeat_interleave(2, dim=-1)
freqs_cis_i = (cos, sin)
else:
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)
freqs_cis.append(freqs_cis_i)

return freqs_cis
Expand All @@ -319,18 +343,34 @@ def __call__(self, ids: torch.Tensor):
device = ids.device

if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta, use_real=self.use_real)
if self.use_real:
self.freqs_cis = [(cos.to(device), sin.to(device)) for cos, sin in self.freqs_cis]
else:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
else:
# Ensure freqs_cis are on the same device as ids
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)
if self.use_real:
if self.freqs_cis[0][0].device != device:
self.freqs_cis = [(cos.to(device), sin.to(device)) for cos, sin in self.freqs_cis]
else:
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

if self.use_real:
cos_result = []
sin_result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
cos_result.append(self.freqs_cis[i][0][index])
sin_result.append(self.freqs_cis[i][1][index])
return (torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1))
else:
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)


@maybe_allow_in_graph
Expand All @@ -345,6 +385,7 @@ def __init__(
qk_norm: bool,
modulation=True,
block_id=0,
use_real: bool = False,
):
super().__init__()
self.dim = dim
Expand All @@ -361,7 +402,7 @@ def __init__(
eps=1e-5,
bias=False,
out_bias=False,
processor=ZSingleStreamAttnProcessor(),
processor=ZSingleStreamAttnProcessor(use_real=use_real),
)

self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
Expand Down Expand Up @@ -447,6 +488,7 @@ def __init__(
n_kv_heads=30,
norm_eps=1e-5,
qk_norm=True,
use_real: bool = False,
):
super().__init__()
self.control_layers_places = control_layers_places
Expand All @@ -459,7 +501,7 @@ def __init__(
# control blocks
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i, use_real=use_real)
for i in self.control_layers_places
]
)
Expand All @@ -485,6 +527,7 @@ def __init__(
qk_norm,
modulation=True,
block_id=layer_id,
use_real=use_real,
)
for layer_id in range(n_refiner_layers)
]
Expand All @@ -500,6 +543,7 @@ def __init__(
norm_eps,
qk_norm,
modulation=True,
use_real=use_real,
)
for layer_id in range(n_refiner_layers)
]
Expand Down Expand Up @@ -732,12 +776,19 @@ def forward(
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))

x_rope_output = self.rope_embedder(torch.cat(x_pos_ids, dim=0))
x_split_sizes = [len(_) for _ in x_pos_ids]
x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
if isinstance(x_rope_output, tuple):
x_cos_list = list(x_rope_output[0].split(x_split_sizes, dim=0))
x_sin_list = list(x_rope_output[1].split(x_split_sizes, dim=0))
x_freqs_cis = (
pad_sequence(x_cos_list, batch_first=True, padding_value=0.0)[:, :x.shape[1]],
pad_sequence(x_sin_list, batch_first=True, padding_value=0.0)[:, :x.shape[1]],
)
else:
x_freqs_cis_list = list(x_rope_output.split(x_split_sizes, dim=0))
x_freqs_cis = pad_sequence(x_freqs_cis_list, batch_first=True, padding_value=0.0)[:, :x.shape[1]]

x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
Expand Down Expand Up @@ -788,14 +839,19 @@ def forward(
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
)

cap_rope_output = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
cap_split_sizes = [len(_) for _ in cap_pos_ids]
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
if isinstance(cap_rope_output, tuple):
cap_cos_list = list(cap_rope_output[0].split(cap_split_sizes, dim=0))
cap_sin_list = list(cap_rope_output[1].split(cap_split_sizes, dim=0))
cap_freqs_cis = (
pad_sequence(cap_cos_list, batch_first=True, padding_value=0.0)[:, :cap_feats.shape[1]],
pad_sequence(cap_sin_list, batch_first=True, padding_value=0.0)[:, :cap_feats.shape[1]],
)
else:
cap_freqs_cis_list = list(cap_rope_output.split(cap_split_sizes, dim=0))
cap_freqs_cis = pad_sequence(cap_freqs_cis_list, batch_first=True, padding_value=0.0)[:, :cap_feats.shape[1]]

cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
Expand All @@ -809,19 +865,32 @@ def forward(
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

# unified
is_real = isinstance(x_freqs_cis, tuple)
unified = []
unified_freqs_cis = []
unified_freqs_cos = [] if is_real else None
unified_freqs_sin = [] if is_real else None
unified_freqs_cis = [] if not is_real else None
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
if is_real:
unified_freqs_cos.append(torch.cat([x_freqs_cis[0][i][:x_len], cap_freqs_cis[0][i][:cap_len]]))
unified_freqs_sin.append(torch.cat([x_freqs_cis[1][i][:x_len], cap_freqs_cis[1][i][:cap_len]]))
else:
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)

unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
if is_real:
unified_freqs_cis = (
pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0),
pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0),
)
else:
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
Expand Down
Loading
Loading