Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 47 additions & 35 deletions tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest
import torch

from diffusers import AutoencoderKLCogVideoX
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLCogVideoXTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLCogVideoX

def get_autoencoder_kl_cogvideox_config(self):
@property
def main_input_name(self) -> str:
return "sample"

@property
def output_shape(self) -> tuple:
return (3, 8, 16, 16)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self) -> dict:
return {
"in_channels": 3,
"out_channels": 3,
Expand All @@ -59,29 +67,27 @@ def get_autoencoder_kl_cogvideox_config(self):
"temporal_compression_ratio": 4,
}

@property
def dummy_input(self):
def get_dummy_inputs(self) -> dict:
batch_size = 4
num_frames = 8
num_channels = 3
sizes = (16, 16)
image = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
return {"sample": image}

image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

return {"sample": image}
class TestAutoencoderKLCogVideoX(AutoencoderKLCogVideoXTesterConfig, ModelTesterMixin):
base_precision = 1e-2
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the `base_precision` attribute is needed here.


@property
def input_shape(self):
return (3, 8, 16, 16)
@pytest.mark.skip("Unsupported test.")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()

Comment on lines +84 to 87
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we have to skip `test_outputs_equivalence`, do we?

@property
def output_shape(self):
return (3, 8, 16, 16)

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_cogvideox_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKLCogVideoXTraining(AutoencoderKLCogVideoXTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLCogVideoX."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {
Expand All @@ -93,8 +99,18 @@ def test_gradient_checkpointing_is_applied(self):
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestAutoencoderKLCogVideoXMemory(AutoencoderKLCogVideoXTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLCogVideoX."""


class TestAutoencoderKLCogVideoXSlicingTiling(AutoencoderKLCogVideoXTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLCogVideoX."""

# Overwritten because the base test's block_out_channels doesn't account for the length of down_block_types.
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()

init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32, 32, 32)
Expand All @@ -109,10 +125,6 @@ def test_forward_with_norm_groups(self):
if isinstance(output, dict):
output = output.to_tuple()[0]

self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
assert output.shape == expected_shape, "Input and output shapes do not match"
Loading