diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index a4800b255ef0..d78f3e3cd0a0 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -35,6 +35,28 @@ SEQ_MULTI_OF = 32 +# Copied from diffusers.models.transformers.transformer_z_image.apply_rotary_emb +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: torch.Tensor | tuple[torch.Tensor], + use_real: bool = True, +) -> torch.Tensor: + if use_real: + cos, sin = freqs_cis + cos = cos.unsqueeze(2).to(x.device) + sin = sin.unsqueeze(2).to(x.device) + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + else: + with torch.amp.autocast("cuda", enabled=False): + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3) + return x_out.type_as(x) + + # Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder class TimestepEmbedder(nn.Module): def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): @@ -84,11 +106,12 @@ class ZSingleStreamAttnProcessor: _attention_backend = None _parallel_config = None - def __init__(self): + def __init__(self, use_real: bool = False): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." ) + self.use_real = use_real def __call__( self, @@ -113,16 +136,9 @@ def __call__( key = attn.norm_k(key) # Apply RoPE - def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - with torch.amp.autocast("cuda", enabled=False): - x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x * freqs_cis).flatten(3) - return x_out.type_as(x_in) # todo - if freqs_cis is not None: - query = apply_rotary_emb(query, freqs_cis) - key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis, use_real=self.use_real) + key = apply_rotary_emb(key, freqs_cis, use_real=self.use_real) # Cast to correct dtype dtype = query.dtype @@ -197,6 +213,7 @@ def __init__( norm_eps: float, qk_norm: bool, modulation=True, + use_real: bool = False, ): super().__init__() self.dim = dim @@ -213,7 +230,7 @@ def __init__( eps=1e-5, bias=False, out_bias=False, - processor=ZSingleStreamAttnProcessor(), + processor=ZSingleStreamAttnProcessor(use_real=use_real), ) self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) @@ -293,22 +310,29 @@ def __init__( theta: float = 256.0, axes_dims: list[int] = (16, 56, 56), axes_lens: list[int] = (64, 128, 128), + use_real: bool = False, ): self.theta = theta self.axes_dims = axes_dims self.axes_lens = axes_lens + self.use_real = use_real assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" self.freqs_cis = None @staticmethod - def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0): + def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0, use_real: bool = False): with torch.device("cpu"): freqs_cis = [] for i, (d, e) in enumerate(zip(dim, end)): freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + if use_real: + cos = freqs.cos().repeat_interleave(2, dim=-1) + sin = freqs.sin().repeat_interleave(2, dim=-1) + freqs_cis_i = (cos, sin) + else: + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) freqs_cis.append(freqs_cis_i) return freqs_cis @@ -319,18 +343,34 @@ def __call__(self, ids: torch.Tensor): device = ids.device if self.freqs_cis is None: - self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta, use_real=self.use_real) + if self.use_real: + self.freqs_cis = [(cos.to(device), sin.to(device)) for cos, sin in self.freqs_cis] + else: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] else: # Ensure freqs_cis are on the same device as ids - if self.freqs_cis[0].device != device: - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - result.append(self.freqs_cis[i][index]) - return torch.cat(result, dim=-1) + if self.use_real: + if self.freqs_cis[0][0].device != device: + self.freqs_cis = [(cos.to(device), sin.to(device)) for cos, sin in self.freqs_cis] + else: + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + if self.use_real: + cos_result = [] + sin_result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + cos_result.append(self.freqs_cis[i][0][index]) + sin_result.append(self.freqs_cis[i][1][index]) + return (torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1)) + else: + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) @maybe_allow_in_graph @@ -345,6 +385,7 @@ def __init__( qk_norm: bool, modulation=True, block_id=0, + use_real: bool = False, ): super().__init__() self.dim = dim @@ -361,7 +402,7 @@ def __init__( eps=1e-5, bias=False, out_bias=False, - processor=ZSingleStreamAttnProcessor(), + processor=ZSingleStreamAttnProcessor(use_real=use_real), ) self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) @@ -447,6 +488,7 @@ def __init__( n_kv_heads=30, norm_eps=1e-5, qk_norm=True, + use_real: bool = False, ): super().__init__() self.control_layers_places = control_layers_places @@ -459,7 +501,7 @@ def __init__( # control blocks self.control_layers = nn.ModuleList( [ - ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) + ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i, use_real=use_real) for i in self.control_layers_places ] ) @@ -485,6 +527,7 @@ def __init__( qk_norm, modulation=True, block_id=layer_id, + use_real=use_real, ) for layer_id in range(n_refiner_layers) ] @@ -500,6 +543,7 @@ def __init__( norm_eps, qk_norm, modulation=True, + use_real=use_real, ) for layer_id in range(n_refiner_layers) ] @@ -732,12 +776,19 @@ def forward( adaln_input = t.type_as(x) x[torch.cat(x_inner_pad_mask)] = self.x_pad_token x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) - + x_rope_output = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) + x_split_sizes = [len(_) for _ in x_pos_ids] x = pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + if isinstance(x_rope_output, tuple): + x_cos_list = list(x_rope_output[0].split(x_split_sizes, dim=0)) + x_sin_list = list(x_rope_output[1].split(x_split_sizes, dim=0)) + x_freqs_cis = ( + pad_sequence(x_cos_list, batch_first=True, padding_value=0.0)[:, :x.shape[1]], + pad_sequence(x_sin_list, batch_first=True, padding_value=0.0)[:, :x.shape[1]], + ) + else: + x_freqs_cis_list = list(x_rope_output.split(x_split_sizes, dim=0)) + x_freqs_cis = pad_sequence(x_freqs_cis_list, batch_first=True, padding_value=0.0)[:, :x.shape[1]] x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) for i, seq_len in enumerate(x_item_seqlens): @@ -788,14 +839,19 @@ def forward( cap_feats = self.cap_embedder(cap_feats) cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list( - self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) - ) - + cap_rope_output = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)) + cap_split_sizes = [len(_) for _ in cap_pos_ids] cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + if isinstance(cap_rope_output, tuple): + cap_cos_list = list(cap_rope_output[0].split(cap_split_sizes, dim=0)) + cap_sin_list = list(cap_rope_output[1].split(cap_split_sizes, dim=0)) + cap_freqs_cis = ( + pad_sequence(cap_cos_list, batch_first=True, padding_value=0.0)[:, :cap_feats.shape[1]], + pad_sequence(cap_sin_list, batch_first=True, padding_value=0.0)[:, :cap_feats.shape[1]], + ) + else: + cap_freqs_cis_list = list(cap_rope_output.split(cap_split_sizes, dim=0)) + cap_freqs_cis = pad_sequence(cap_freqs_cis_list, batch_first=True, padding_value=0.0)[:, :cap_feats.shape[1]] cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) for i, seq_len in enumerate(cap_item_seqlens): @@ -809,19 +865,32 @@ def forward( cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) # unified + is_real = isinstance(x_freqs_cis, tuple) unified = [] - unified_freqs_cis = [] + unified_freqs_cos = [] if is_real else None + unified_freqs_sin = [] if is_real else None + unified_freqs_cis = [] if not is_real else None for i in range(bsz): x_len = x_item_seqlens[i] cap_len = cap_item_seqlens[i] unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + if is_real: + unified_freqs_cos.append(torch.cat([x_freqs_cis[0][i][:x_len], cap_freqs_cis[0][i][:cap_len]])) + unified_freqs_sin.append(torch.cat([x_freqs_cis[1][i][:x_len], cap_freqs_cis[1][i][:cap_len]])) + else: + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] assert unified_item_seqlens == [len(_) for _ in unified] unified_max_item_seqlen = max(unified_item_seqlens) unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + if is_real: + unified_freqs_cis = ( + pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0), + pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0), + ) + else: + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) for i, seq_len in enumerate(unified_item_seqlens): unified_attn_mask[i, :seq_len] = 1 diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 4cea745e5ed5..5c6f481a4fe8 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -34,6 +34,27 @@ X_PAD_DIM = 64 +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: torch.Tensor | tuple[torch.Tensor], + use_real: bool = True, +) -> torch.Tensor: + if use_real: + cos, sin = freqs_cis + cos = cos.unsqueeze(2).to(x.device) + sin = sin.unsqueeze(2).to(x.device) + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + else: + with torch.amp.autocast("cuda", enabled=False): + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3) + return x_out.type_as(x) + + class TimestepEmbedder(nn.Module): def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): super().__init__() @@ -81,11 +102,12 @@ class ZSingleStreamAttnProcessor: _attention_backend = None _parallel_config = None - def __init__(self): + def __init__(self, use_real: bool = False): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." ) + self.use_real = use_real def __call__( self, @@ -109,17 +131,9 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - # Apply RoPE - def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - with torch.amp.autocast("cuda", enabled=False): - x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x * freqs_cis).flatten(3) - return x_out.type_as(x_in) # todo - if freqs_cis is not None: - query = apply_rotary_emb(query, freqs_cis) - key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis, use_real=self.use_real) + key = apply_rotary_emb(key, freqs_cis, use_real=self.use_real) # Cast to correct dtype dtype = query.dtype @@ -191,6 +205,7 @@ def __init__( norm_eps: float, qk_norm: bool, modulation=True, + use_real: bool = False, ): super().__init__() self.dim = dim @@ -207,7 +222,7 @@ def __init__( eps=1e-5, bias=False, out_bias=False, - processor=ZSingleStreamAttnProcessor(), + processor=ZSingleStreamAttnProcessor(use_real=use_real), ) self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) @@ -316,22 +331,29 @@ def __init__( theta: float = 256.0, axes_dims: list[int] = (16, 56, 56), axes_lens: list[int] = (64, 128, 128), + use_real: bool = False, ): self.theta = theta self.axes_dims = axes_dims self.axes_lens = axes_lens + self.use_real = use_real assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" self.freqs_cis = None @staticmethod - def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0): + def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0, use_real: bool = False): with torch.device("cpu"): freqs_cis = [] for i, (d, e) in enumerate(zip(dim, end)): freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + if use_real: + cos = freqs.cos().repeat_interleave(2, dim=-1) + sin = freqs.sin().repeat_interleave(2, dim=-1) + freqs_cis_i = (cos, sin) + else: + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) freqs_cis.append(freqs_cis_i) return freqs_cis @@ -342,18 +364,34 @@ def __call__(self, ids: torch.Tensor): device = ids.device if self.freqs_cis is None: - self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta, use_real=self.use_real) + if self.use_real: + self.freqs_cis = [(cos.to(device), sin.to(device)) for cos, sin in self.freqs_cis] + else: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] else: # Ensure freqs_cis are on the same device as ids - if self.freqs_cis[0].device != device: - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - result.append(self.freqs_cis[i][index]) - return torch.cat(result, dim=-1) + if self.use_real: + if self.freqs_cis[0][0].device != device: + self.freqs_cis = [(cos.to(device), sin.to(device)) for cos, sin in self.freqs_cis] + else: + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + if self.use_real: + cos_result = [] + sin_result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + cos_result.append(self.freqs_cis[i][0][index]) + sin_result.append(self.freqs_cis[i][1][index]) + return (torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1)) + else: + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): @@ -381,6 +419,7 @@ def __init__( t_scale=1000.0, axes_dims=[32, 48, 48], axes_lens=[1024, 512, 512], + use_real: bool = False, ) -> None: super().__init__() self.in_channels = in_channels @@ -417,6 +456,7 @@ def __init__( norm_eps, qk_norm, modulation=True, + use_real=use_real, ) for layer_id in range(n_refiner_layers) ] @@ -431,6 +471,7 @@ def __init__( norm_eps, qk_norm, modulation=False, + use_real=use_real, ) for layer_id in range(n_refiner_layers) ] @@ -453,6 +494,7 @@ def __init__( norm_eps, qk_norm, modulation=False, + use_real=use_real, ) for layer_id in range(n_refiner_layers) ] @@ -468,7 +510,7 @@ def __init__( self.layers = nn.ModuleList( [ - ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, use_real=use_real) for layer_id in range(n_layers) ] ) @@ -477,7 +519,7 @@ def __init__( self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens, use_real=use_real) def unpatchify( self, @@ -782,11 +824,20 @@ def _prepare_sequence( feats = list(feats_cat.split(item_seqlens, dim=0)) # RoPE - freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + rope_output = self.rope_embedder(torch.cat(pos_ids, dim=0)) + split_sizes = [len(p) for p in pos_ids] + if isinstance(rope_output, tuple): + cos_list = list(rope_output[0].split(split_sizes, dim=0)) + sin_list = list(rope_output[1].split(split_sizes, dim=0)) + cos_padded = pad_sequence(cos_list, batch_first=True, padding_value=0.0)[:, :max_seqlen] + sin_padded = pad_sequence(sin_list, batch_first=True, padding_value=0.0)[:, :max_seqlen] + freqs_cis = (cos_padded, sin_padded) + else: + freqs_cis_list = list(rope_output.split(split_sizes, dim=0)) + freqs_cis = pad_sequence(freqs_cis_list, batch_first=True, padding_value=0.0)[:, :max_seqlen] # Pad to batch feats = pad_sequence(feats, batch_first=True, padding_value=0.0) - freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] # Attention mask if all(seq == max_seqlen for seq in item_seqlens): @@ -810,15 +861,15 @@ def _prepare_sequence( def _build_unified_sequence( self, x: torch.Tensor, - x_freqs: torch.Tensor, + x_freqs: torch.Tensor | tuple[torch.Tensor, torch.Tensor], x_seqlens: list[int], x_noise_mask: list[list[int]] | None, cap: torch.Tensor, - cap_freqs: torch.Tensor, + cap_freqs: torch.Tensor | tuple[torch.Tensor, torch.Tensor], cap_seqlens: list[int], cap_noise_mask: list[list[int]] | None, siglip: torch.Tensor | None, - siglip_freqs: torch.Tensor | None, + siglip_freqs: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None, siglip_seqlens: list[int] | None, siglip_noise_mask: list[list[int]] | None, omni_mode: bool, @@ -828,8 +879,11 @@ def _build_unified_sequence( Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] """ bsz = len(x_seqlens) + is_real = isinstance(x_freqs, tuple) unified = [] - unified_freqs = [] + unified_freqs_cos = [] if is_real else None + unified_freqs_sin = [] if is_real else None + unified_freqs = [] if not is_real else None unified_noise_mask = [] for i in range(bsz): @@ -840,9 +894,17 @@ def _build_unified_sequence( if siglip is not None and siglip_seqlens is not None: sig_len = siglip_seqlens[i] unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) - unified_freqs.append( - torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) - ) + if is_real: + unified_freqs_cos.append(torch.cat([ + cap_freqs[0][i][:cap_len], x_freqs[0][i][:x_len], siglip_freqs[0][i][:sig_len] + ])) + unified_freqs_sin.append(torch.cat([ + cap_freqs[1][i][:cap_len], x_freqs[1][i][:x_len], siglip_freqs[1][i][:sig_len] + ])) + else: + unified_freqs.append(torch.cat([ + cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len] + ])) unified_noise_mask.append( torch.tensor( cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device @@ -850,14 +912,22 @@ def _build_unified_sequence( ) else: unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) - unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + if is_real: + unified_freqs_cos.append(torch.cat([cap_freqs[0][i][:cap_len], x_freqs[0][i][:x_len]])) + unified_freqs_sin.append(torch.cat([cap_freqs[1][i][:cap_len], x_freqs[1][i][:x_len]])) + else: + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) unified_noise_mask.append( torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) ) else: # Basic: [x, cap] unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) - unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + if is_real: + unified_freqs_cos.append(torch.cat([x_freqs[0][i][:x_len], cap_freqs[0][i][:cap_len]])) + unified_freqs_sin.append(torch.cat([x_freqs[1][i][:x_len], cap_freqs[1][i][:cap_len]])) + else: + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) # Compute unified seqlens if omni_mode: @@ -872,7 +942,13 @@ def _build_unified_sequence( # Pad to batch unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + if is_real: + unified_freqs = ( + pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0), + pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0), + ) + else: + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) # Attention mask if all(seq == max_seqlen for seq in unified_seqlens):