diff --git a/src/virtual_stain_flow/engine/loss_group.py b/src/virtual_stain_flow/engine/loss_group.py index 16b3a74..72e79c3 100644 --- a/src/virtual_stain_flow/engine/loss_group.py +++ b/src/virtual_stain_flow/engine/loss_group.py @@ -22,7 +22,7 @@ """ from dataclasses import dataclass -from typing import Optional, Union, Tuple, Dict, Sequence, List +from typing import Optional, Union, Tuple, Dict, Sequence, List, Any import torch @@ -79,6 +79,7 @@ def __post_init__(self): def __call__( self, train: bool, + epoch: Optional[int] = None, context: Optional[Context] = None, **inputs: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -88,6 +89,7 @@ def __call__( skipped during validation. :param train: Whether the model is in training mode. + :param epoch: Optional epoch number to determine the weight from the schedule. :param context: Optional Context object containing tensors. :param inputs: Keyword arguments containing all necessary inputs for the loss computation. @@ -117,6 +119,21 @@ def __call__( return raw, raw * _scalar_from_ctx(self.weight, inputs) + def get_config(self) -> Dict[str, Any]: + """ + Get the configuration of the LossItem for logging or checkpointing. + """ + return { + 'module': self.module.__class__.__name__, + 'args': self.args, + 'key': self.key, + 'weight': self.weight, + 'enabled': self.enabled, + 'compute_at_val': self.compute_at_val, + 'device': str(self.device) + } + + @dataclass class LossGroup: """ @@ -137,6 +154,7 @@ def item_names(self) -> List[Optional[str]]: def __call__( self, train: bool, + epoch: Optional[int] = None, context: Optional[Context] = None, **inputs: torch.Tensor ) -> Tuple[torch.Tensor, Dict[str, Scalar]]: @@ -144,6 +162,7 @@ def __call__( Compute the total loss and individual loss values. :param train: Whether the model is in training mode. + :param epoch: Optional epoch number to determine the weight from the schedule. :param context: Optional Context object containing tensors. :input inputs: Keyword arguments containing all necessary inputs for the loss computations. @@ -156,8 +175,20 @@ def __call__( logs: Dict[str, float] = {} for item in self.items: - raw, weighted = item(train, context=context, **inputs) + raw, weighted = item( + train, + epoch=epoch, + context=context, + **inputs + ) logs[item.key] = raw.item() # type: ignore total += weighted return total, logs + + def get_config(self) -> List[Dict[str, Any]]: + """ + Get the configuration of the LossGroup for logging or checkpointing. + """ + + return [item.get_config() for item in self.items] diff --git a/src/virtual_stain_flow/trainers/logging_gan_trainer.py b/src/virtual_stain_flow/trainers/logging_gan_trainer.py index 8e5fe69..d50c437 100644 --- a/src/virtual_stain_flow/trainers/logging_gan_trainer.py +++ b/src/virtual_stain_flow/trainers/logging_gan_trainer.py @@ -187,6 +187,13 @@ def evaluate_step( metric.update(*ctx.as_metric_args(), validation=True) return gen_logs | disc_logs + + @property + def loss_groups(self) -> Dict[str, LossGroup]: + return { + 'generator': self._generator_loss_group, + 'discriminator': self._discriminator_loss_group + } def save_model( self, diff --git a/src/virtual_stain_flow/trainers/logging_trainer.py b/src/virtual_stain_flow/trainers/logging_trainer.py index 6442dea..27e9f09 100644 --- a/src/virtual_stain_flow/trainers/logging_trainer.py +++ b/src/virtual_stain_flow/trainers/logging_trainer.py @@ -105,7 +105,11 @@ def train_step( targets=targets ) - weighted_total, logs = self._loss_group(train=True, context=ctx) + weighted_total, logs = self._loss_group( + train=True, + epoch=self.epoch, + context=ctx + ) weighted_total.backward() self._forward_group.step() @@ -133,12 +137,20 @@ def evaluate_step( targets=targets ) - _, logs = self._loss_group(train=False, context=ctx) + _, logs = self._loss_group( + train=False, + epoch=self.epoch, + context=ctx + ) for _, metric in self.metrics.items(): metric.update(*ctx.as_metric_args(), validation=True) return logs + + @property + def loss_groups(self) -> Dict[str, LossGroup]: + return {'main': self._loss_group} def save_model( self, diff --git a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py index a46c61d..9bfe22b 100644 --- a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py +++ b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py @@ -222,6 +222,8 @@ def on_train_start(self): except Exception as e: print(f"Fail to log model config as artifact: {e}") + self._log_loss_groups_config_and_tags() + for callback in self.callbacks: # TODO consider if we want hasattr checks @@ -502,6 +504,92 @@ def _save_model_weights( artifact_path=artifact_path ) + def _get_loss_groups(self) -> Dict[str, Any]: + """ + Discover loss groups attached to the bound trainer. + """ + + if self.trainer is None: + return {} + + loss_groups: Dict[str, Any] = {} + + explicit_groups = getattr(self.trainer, 'loss_groups', None) + if isinstance(explicit_groups, dict): + for group_name, group in explicit_groups.items(): + if hasattr(group, 'get_config'): + loss_groups[str(group_name)] = group + + fallback_attrs = { + 'main': '_loss_group', + 'generator': '_generator_loss_group', + 'discriminator': '_discriminator_loss_group' + } + for group_name, attr in fallback_attrs.items(): + if group_name in loss_groups: + continue + group = getattr(self.trainer, attr, None) + if group is not None and hasattr(group, 'get_config'): + loss_groups[group_name] = group + + return loss_groups + + def _log_loss_groups_config_and_tags(self) -> None: + """ + Log loss item names and weights as flat tags and full loss group + configuration as config artifacts. + """ + + loss_groups = self._get_loss_groups() + if not loss_groups: + return None + + for group_name, group in loss_groups.items(): + try: + group_config = group.get_config() + except Exception as e: + print( + f"Could not get loss group config for logging " + f"({group_name}): {e}" + ) + continue + + if not isinstance(group_config, list): + continue + + for idx, item_cfg in enumerate(group_config): + if not isinstance(item_cfg, dict): + continue + + if 'key' in item_cfg and item_cfg['key'] is not None: + mlflow.set_tag( + f"loss.{group_name}.{idx}.name", + str(item_cfg['key']) + ) + + if 'weight' in item_cfg and item_cfg['weight'] is not None: + mlflow.set_tag( + f"loss.{group_name}.{idx}.weight", + str(item_cfg['weight']) + ) + + try: + self.log_config( + tag=f"loss_group_{group_name}", + config={ + 'group_name': group_name, + 'items': group_config + }, + stage=None + ) + except Exception as e: + print( + f"Fail to log loss group config as artifact " + f"({group_name}): {e}" + ) + + return None + def log_config( self, tag: str, diff --git a/tests/conftest.py b/tests/conftest.py index 5859a44..b635c69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,16 +2,21 @@ Testing fixtures meant to be shared across the whole package """ +import json +import importlib import pathlib +from types import SimpleNamespace import pytest import torch from torch.utils.data import DataLoader, Dataset +from virtual_stain_flow.trainers.AbstractTrainer import AbstractTrainer +from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer from virtual_stain_flow.vsf_logging import MlflowLogger -# ----- Mock virtual_stain_flow components ----- # +# ----- Logger test doubles ----- # class DummyLogger(MlflowLogger): """ @@ -75,6 +80,8 @@ def dummy_logger(): return DummyLogger() +# ----- Model/optimizer fixtures ----- # + class MockModelWithSaveWeights(torch.nn.Module): """ Mock model that implements save_weights method for testing. @@ -115,7 +122,7 @@ def mock_optimizer(mock_model_with_save): return torch.optim.Adam(mock_model_with_save.parameters(), lr=0.001) -# ----- Fixtures for simulating minimal training ----- # +# ----- Dataset/dataloader fixtures ----- # class MinimalDataset(Dataset): """Minimal torch.utils.data.Dataset to test training.""" @@ -222,7 +229,6 @@ def empty_dataloader(): return DataLoader(dataset, batch_size=2, shuffle=False) - @pytest.fixture def image_train_loader(image_dataset): """Create a train dataloader with image data.""" @@ -243,6 +249,8 @@ def image_val_loader(image_dataset): return DataLoader(val_dataset, batch_size=4, shuffle=False) +# ----- Generic training fixtures ----- # + @pytest.fixture def simple_loss(): """Create a simple MSE loss function.""" @@ -278,3 +286,291 @@ def mock_metric(): def dataset_for_splitting(): """Create a larger dataset suitable for train/val/test splitting.""" return MinimalDataset(num_samples=100, input_size=4, target_size=2) + + +# ----- Trainer fixtures ----- # + +class MinimalTrainerRealization(AbstractTrainer): + """ + Minimal concrete realization of AbstractTrainer for testing. + Tracks method calls and provides controllable step behavior. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.train_step_calls = [] + self.evaluate_step_calls = [] + self.on_epoch_start_called = False + self.on_epoch_end_called = False + + class DummyProgressBar: + def set_postfix_str(self, *args, **kwargs): + pass + + self._epoch_pbar = DummyProgressBar() # type: ignore + + def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: + self.train_step_calls.append({ + 'inputs_shape': inputs.shape, + 'targets_shape': targets.shape, + }) + + return { + 'loss_a': torch.tensor(0.5), + 'loss_b': torch.tensor(0.3), + } + + def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: + self.evaluate_step_calls.append({ + 'inputs_shape': inputs.shape, + 'targets_shape': targets.shape, + }) + + return { + 'loss_a': torch.tensor(0.4), + 'loss_b': torch.tensor(0.2), + } + + def save_model(self, save_path, file_name_prefix=None, file_name_suffix=None, + file_ext='.pth', best_model=True): + return None + + +@pytest.fixture +def minimal_trainer_cls(): + """Expose the minimal concrete trainer class for tests needing custom init.""" + return MinimalTrainerRealization + + +@pytest.fixture +def trainer_with_loaders(minimal_model, minimal_optimizer, train_dataloader, val_dataloader): + """ + Create a MinimalTrainerRealization with train and validation loaders. + """ + trainer = MinimalTrainerRealization( + model=minimal_model, + optimizer=minimal_optimizer, + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2, + device=torch.device('cpu') + ) + return trainer + + +@pytest.fixture +def trainer_with_empty_val_loader(minimal_model, minimal_optimizer, train_dataloader, empty_dataloader): + """ + Create a MinimalTrainerRealization with empty validation loader. + """ + trainer = MinimalTrainerRealization( + model=minimal_model, + optimizer=minimal_optimizer, + train_loader=train_dataloader, + val_loader=empty_dataloader, + batch_size=2, + device=torch.device('cpu') + ) + return trainer + + +@pytest.fixture +def single_generator_trainer(minimal_model, minimal_optimizer, simple_loss, train_dataloader, val_dataloader): + """ + Create a SingleGeneratorTrainer with a single loss function. + """ + trainer = SingleGeneratorTrainer( + model=minimal_model, + optimizer=minimal_optimizer, + losses=simple_loss, + device=torch.device('cpu'), + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2 + ) + return trainer + + +@pytest.fixture +def multi_loss_trainer(minimal_model, minimal_optimizer, multiple_losses, train_dataloader, val_dataloader): + """ + Create a SingleGeneratorTrainer with multiple loss functions. + """ + trainer = SingleGeneratorTrainer( + model=minimal_model, + optimizer=minimal_optimizer, + losses=multiple_losses, + device=torch.device('cpu'), + loss_weights=[0.5, 0.5], + train_loader=train_dataloader, + val_loader=val_dataloader, + batch_size=2 + ) + return trainer + + +@pytest.fixture +def conv_trainer(conv_model, conv_optimizer, simple_loss, image_train_loader, image_val_loader): + """ + Create a SingleGeneratorTrainer with conv model for full training tests. + """ + trainer = SingleGeneratorTrainer( + model=conv_model, + optimizer=conv_optimizer, + losses=simple_loss, + device=torch.device('cpu'), + train_loader=image_train_loader, + val_loader=image_val_loader, + batch_size=4, + early_termination_metric='MSELoss' + ) + return trainer + + +@pytest.fixture +def simple_discriminator(): + """ + Simple discriminator model for GAN testing. + Takes concatenated input/target stack (B, 2, H, W) -> outputs score (B, 1) + """ + import torch.nn as nn + + class SimpleDiscriminator(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding=1) + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(16, 1) + + def forward(self, x): + x = torch.relu(self.conv(x)) + x = self.pool(x).flatten(1) + return self.fc(x) + + return SimpleDiscriminator() + + +@pytest.fixture +def discriminator_optimizer(simple_discriminator): + """Create an optimizer for the discriminator.""" + return torch.optim.Adam(simple_discriminator.parameters(), lr=0.0001) + + +@pytest.fixture +def wgan_trainer(conv_model, simple_discriminator, conv_optimizer, discriminator_optimizer, + simple_loss, image_train_loader, image_val_loader): + """ + Create a LoggingWGANTrainer for testing. + """ + from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer + + trainer = LoggingWGANTrainer( + generator=conv_model, + discriminator=simple_discriminator, + generator_optimizer=conv_optimizer, + discriminator_optimizer=discriminator_optimizer, + generator_losses=simple_loss, + device=torch.device('cpu'), + train_loader=image_train_loader, + val_loader=image_val_loader, + batch_size=4, + n_discriminator_steps=3 + ) + return trainer + + +# ----- MLflow patch fixture ----- # + +@pytest.fixture +def patched_mlflow(monkeypatch): + """Patch MLflow module methods used by MlflowLogger and capture calls.""" + + captured = { + 'tags': {}, + 'artifacts': [], + 'active_run_id': None, + } + + mlflow_logger_module = importlib.import_module( + 'virtual_stain_flow.vsf_logging.MlflowLogger' + ) + + def fake_get_experiment_by_name(_name): + return None + + def fake_create_experiment(_name): + return 'exp-1' + + def fake_start_run(*args, **kwargs): + run_id = 'run-123' + captured['active_run_id'] = run_id + return SimpleNamespace(info=SimpleNamespace(run_id=run_id)) + + def fake_active_run(): + run_id = captured['active_run_id'] + if run_id is None: + return None + return SimpleNamespace(info=SimpleNamespace(run_id=run_id)) + + def fake_end_run(): + captured['active_run_id'] = None + + def fake_set_tag(key, value): + captured['tags'][key] = value + + def fake_log_artifact(file_path, artifact_path=None): + file_content = None + try: + with open(file_path, 'r', encoding='utf-8') as f: + file_content = json.load(f) + except Exception: + file_content = None + + captured['artifacts'].append({ + 'file_path': file_path, + 'artifact_path': artifact_path, + 'content': file_content, + }) + + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'get_experiment_by_name', + fake_get_experiment_by_name, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'create_experiment', + fake_create_experiment, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'start_run', + fake_start_run, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'active_run', + fake_active_run, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'end_run', + fake_end_run, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'set_tag', + fake_set_tag, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'log_artifact', + fake_log_artifact, + ) + monkeypatch.setattr( + mlflow_logger_module.mlflow, + 'log_params', + lambda *_args, **_kwargs: None, + ) + + return captured diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py deleted file mode 100644 index cd5e443..0000000 --- a/tests/trainers/conftest.py +++ /dev/null @@ -1,208 +0,0 @@ -""" -Fixtures for trainer tests -""" - -import pytest -import torch - -from virtual_stain_flow.trainers.AbstractTrainer import AbstractTrainer -from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer - - -class MinimalTrainerRealization(AbstractTrainer): - """ - Minimal concrete realization of AbstractTrainer for testing. - Tracks method calls and provides controllable step behavior. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Track method calls for testing - self.train_step_calls = [] - self.evaluate_step_calls = [] - self.on_epoch_start_called = False - self.on_epoch_end_called = False - - # Create a dummy progress bar property that does nothing beyond - # allowing set_postfix_str calls - class DummyProgressBar: - def set_postfix_str(self, *args, **kwargs): - pass - - self._epoch_pbar = DummyProgressBar() # type: ignore - - def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: - """ - Minimal train step that returns a dict of losses. - Stores call information for verification. - """ - - self.train_step_calls.append({ - 'inputs_shape': inputs.shape, - 'targets_shape': targets.shape, - }) - - # Return scalar tensor losses (simulating real losses) - return { - 'loss_a': torch.tensor(0.5), - 'loss_b': torch.tensor(0.3), - } - - def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: - """ - Minimal evaluate step that returns a dict of losses. - Stores call information for verification. - """ - - self.evaluate_step_calls.append({ - 'inputs_shape': inputs.shape, - 'targets_shape': targets.shape, - }) - - # Return scalar tensor losses (simulating real losses) - return { - 'loss_a': torch.tensor(0.4), - 'loss_b': torch.tensor(0.2), - } - - def save_model(self, save_path, file_name_prefix=None, file_name_suffix=None, - file_ext='.pth', best_model=True): - """Minimal save_model implementation.""" - return None - - -@pytest.fixture -def trainer_with_loaders(minimal_model, minimal_optimizer, train_dataloader, val_dataloader): - """ - Create a MinimalTrainerRealization with train and validation loaders. - """ - trainer = MinimalTrainerRealization( - model=minimal_model, - optimizer=minimal_optimizer, - train_loader=train_dataloader, - val_loader=val_dataloader, - batch_size=2, - device=torch.device('cpu') - ) - return trainer - - -@pytest.fixture -def trainer_with_empty_val_loader(minimal_model, minimal_optimizer, train_dataloader, empty_dataloader): - """ - Create a MinimalTrainerRealization with empty validation loader. - """ - trainer = MinimalTrainerRealization( - model=minimal_model, - optimizer=minimal_optimizer, - train_loader=train_dataloader, - val_loader=empty_dataloader, - batch_size=2, - device=torch.device('cpu') - ) - return trainer - - -@pytest.fixture -def single_generator_trainer(minimal_model, minimal_optimizer, simple_loss, train_dataloader, val_dataloader): - """ - Create a SingleGeneratorTrainer with a single loss function. - """ - trainer = SingleGeneratorTrainer( - model=minimal_model, - optimizer=minimal_optimizer, - losses=simple_loss, - device=torch.device('cpu'), - train_loader=train_dataloader, - val_loader=val_dataloader, - batch_size=2 - ) - return trainer - - -@pytest.fixture -def multi_loss_trainer(minimal_model, minimal_optimizer, multiple_losses, train_dataloader, val_dataloader): - """ - Create a SingleGeneratorTrainer with multiple loss functions. - """ - trainer = SingleGeneratorTrainer( - model=minimal_model, - optimizer=minimal_optimizer, - losses=multiple_losses, - device=torch.device('cpu'), - loss_weights=[0.5, 0.5], - train_loader=train_dataloader, - val_loader=val_dataloader, - batch_size=2 - ) - return trainer - - -@pytest.fixture -def conv_trainer(conv_model, conv_optimizer, simple_loss, image_train_loader, image_val_loader): - """ - Create a SingleGeneratorTrainer with conv model for full training tests. - """ - trainer = SingleGeneratorTrainer( - model=conv_model, - optimizer=conv_optimizer, - losses=simple_loss, - device=torch.device('cpu'), - train_loader=image_train_loader, - val_loader=image_val_loader, - batch_size=4, - early_termination_metric='MSELoss' - ) - return trainer - - -@pytest.fixture -def simple_discriminator(): - """ - Simple discriminator model for GAN testing. - Takes concatenated input/target stack (B, 2, H, W) -> outputs score (B, 1) - """ - import torch.nn as nn - - class SimpleDiscriminator(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding=1) - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Linear(16, 1) - - def forward(self, x): - x = torch.relu(self.conv(x)) - x = self.pool(x).flatten(1) - return self.fc(x) - - return SimpleDiscriminator() - - -@pytest.fixture -def discriminator_optimizer(simple_discriminator): - """Create an optimizer for the discriminator.""" - return torch.optim.Adam(simple_discriminator.parameters(), lr=0.0001) - - -@pytest.fixture -def wgan_trainer(conv_model, simple_discriminator, conv_optimizer, discriminator_optimizer, - simple_loss, image_train_loader, image_val_loader): - """ - Create a LoggingWGANTrainer for testing. - """ - from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer - - trainer = LoggingWGANTrainer( - generator=conv_model, - discriminator=simple_discriminator, - generator_optimizer=conv_optimizer, - discriminator_optimizer=discriminator_optimizer, - generator_losses=simple_loss, - device=torch.device('cpu'), - train_loader=image_train_loader, - val_loader=image_val_loader, - batch_size=4, - n_discriminator_steps=3 - ) - return trainer diff --git a/tests/trainers/test_abstract_trainer.py b/tests/trainers/test_abstract_trainer.py index 5d9555a..62d6e38 100644 --- a/tests/trainers/test_abstract_trainer.py +++ b/tests/trainers/test_abstract_trainer.py @@ -5,7 +5,12 @@ import pytest import torch -from conftest import MinimalTrainerRealization + +@pytest.fixture(autouse=True) +def _bind_minimal_trainer_cls(minimal_trainer_cls): + """Bind concrete trainer class from fixture to avoid direct conftest imports.""" + global MinimalTrainerRealization + MinimalTrainerRealization = minimal_trainer_cls class TestTrainEpochBatchIteration: diff --git a/tests/vsf_logging/test_mlflow_logger_loss_config.py b/tests/vsf_logging/test_mlflow_logger_loss_config.py new file mode 100644 index 0000000..296940f --- /dev/null +++ b/tests/vsf_logging/test_mlflow_logger_loss_config.py @@ -0,0 +1,80 @@ +""" +Tests for automatic loss-group config logging in MlflowLogger. +""" + + +class TestMlflowLoggerLossConfigLogging: + + def test_on_train_start_logs_single_trainer_loss_tags_and_config( + self, + patched_mlflow, + single_generator_trainer, + ): + from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger + + captured = patched_mlflow + + logger = MlflowLogger( + name='logger', + experiment_name='exp', + ) + logger.bind_trainer(single_generator_trainer) + + logger.on_train_start() + + assert captured['tags']['loss.main.0.name'] == 'MSELoss' + assert captured['tags']['loss.main.0.weight'] == '1.0' + + loss_group_artifacts = [ + artifact + for artifact in captured['artifacts'] + if artifact['content'] is not None + and artifact['content'].get('group_name') == 'main' + ] + + assert len(loss_group_artifacts) == 1 + artifact = loss_group_artifacts[0] + assert artifact['artifact_path'] == 'configs' + assert len(artifact['content']['items']) == 1 + assert artifact['content']['items'][0]['key'] == 'MSELoss' + assert artifact['content']['items'][0]['weight'] == 1.0 + + logger.end_run() + + def test_on_train_start_logs_wgan_loss_tags_and_configs( + self, + patched_mlflow, + wgan_trainer, + ): + from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger + + captured = patched_mlflow + + logger = MlflowLogger( + name='logger', + experiment_name='exp', + ) + logger.bind_trainer(wgan_trainer) + + logger.on_train_start() + + assert captured['tags']['loss.generator.0.name'] == 'MSELoss' + assert captured['tags']['loss.generator.0.weight'] == '1.0' + assert captured['tags']['loss.generator.1.name'] == 'AdversarialLoss' + assert captured['tags']['loss.generator.1.weight'] == '1.0' + + assert captured['tags']['loss.discriminator.0.name'] == 'WassersteinLoss' + assert captured['tags']['loss.discriminator.0.weight'] == '1.0' + assert captured['tags']['loss.discriminator.1.name'] == 'GradientPenaltyLoss' + assert captured['tags']['loss.discriminator.1.weight'] == '10.0' + + group_names = { + artifact['content']['group_name'] + for artifact in captured['artifacts'] + if artifact['content'] is not None + and 'group_name' in artifact['content'] + } + + assert group_names == {'generator', 'discriminator'} + + logger.end_run()