Changes from all commits
35 changes: 33 additions & 2 deletions src/virtual_stain_flow/engine/loss_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""

from dataclasses import dataclass
from typing import Optional, Union, Tuple, Dict, Sequence, List
from typing import Optional, Union, Tuple, Dict, Sequence, List, Any

import torch

@@ -79,6 +79,7 @@ def __post_init__(self):
     def __call__(
         self,
         train: bool,
+        epoch: Optional[int] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -88,6 +89,7 @@ def __call__(
         skipped during validation.

         :param train: Whether the model is in training mode.
+        :param epoch: Optional epoch number to determine the weight from the schedule.
Review comment (Member): I think some weight schedulers are also stepped per batch, so you could consider supporting that as well; or maybe it could be something like:
    :param progress: Optional training progress index (step or epoch) used to determine the scheduled weight.

         :param context: Optional Context object containing tensors.

Review comment (Member): Haven't dug into the code, but consider renaming the context object, if possible, to be less abstract (more specific).

         :param inputs: Keyword arguments containing all necessary inputs for the
             loss computation.
@@ -117,6 +119,21 @@ def __call__(

         return raw, raw * _scalar_from_ctx(self.weight, inputs)

+    def get_config(self) -> Dict[str, Any]:
Review comment (Member): Good idea.

"""
Get the configuration of the LossItem for logging or checkpointing.
"""
return {
'module': self.module.__class__.__name__,
'args': self.args,
'key': self.key,
'weight': self.weight,
'enabled': self.enabled,
'compute_at_val': self.compute_at_val,
'device': str(self.device)
}


 @dataclass
 class LossGroup:
     """
@@ -137,13 +154,15 @@ def item_names(self) -> List[Optional[str]]:
     def __call__(
         self,
         train: bool,
+        epoch: Optional[int] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
         """
         Compute the total loss and individual loss values.

         :param train: Whether the model is in training mode.
+        :param epoch: Optional epoch number to determine the weight from the schedule.
         :param context: Optional Context object containing tensors.
         :param inputs: Keyword arguments containing all necessary inputs for the
             loss computations.
@@ -156,8 +175,20 @@
         logs: Dict[str, float] = {}

         for item in self.items:
-            raw, weighted = item(train, context=context, **inputs)
+            raw, weighted = item(
+                train,
+                epoch=epoch,
+                context=context,
+                **inputs
+            )
             logs[item.key] = raw.item()  # type: ignore
             total += weighted

         return total, logs
+
+    def get_config(self) -> List[Dict[str, Any]]:
+        """
+        Get the configuration of the LossGroup for logging or checkpointing.
+        """
+
+        return [item.get_config() for item in self.items]
Review comment (Member): Consider renaming self.items here to something more specific (such as the grouping, like training_loss_objects).

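As a point of reference for the two get_config additions above: a LossItem wrapping nn.MSELoss under the key "mse" with a fixed weight of 1.0 would serialize to a dict along these lines. The keys mirror the implementation exactly; the concrete values are hypothetical.

    {
        'module': 'MSELoss',             # self.module.__class__.__name__
        'args': ('pred', 'target'),      # assumed contents of the item's `args` field
        'key': 'mse',
        'weight': 1.0,
        'enabled': True,
        'compute_at_val': True,
        'device': 'cpu'
    }

LossGroup.get_config() then simply returns a list of such dicts, one per item.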
7 changes: 7 additions & 0 deletions src/virtual_stain_flow/trainers/logging_gan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def evaluate_step(
             metric.update(*ctx.as_metric_args(), validation=True)

         return gen_logs | disc_logs
+
+    @property
+    def loss_groups(self) -> Dict[str, LossGroup]:
+        return {
+            'generator': self._generator_loss_group,
+            'discriminator': self._discriminator_loss_group
+        }

     def save_model(
         self,
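The new loss_groups property gives external consumers, such as the MLflow logger below, one uniform way to enumerate a trainer's loss groups without reaching into private attributes. A minimal sketch of the intended consumption pattern, assuming an already constructed trainer instance:

    # List each group's loss item keys via the public property.
    for group_name, group in trainer.loss_groups.items():
        item_keys = [cfg['key'] for cfg in group.get_config()]
        print(group_name, item_keys)    # e.g. generator ['l1', 'adv'] (illustrative)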
16 changes: 14 additions & 2 deletions src/virtual_stain_flow/trainers/logging_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def train_step(
             targets=targets
         )

-        weighted_total, logs = self._loss_group(train=True, context=ctx)
+        weighted_total, logs = self._loss_group(
+            train=True,
+            epoch=self.epoch,
Review comment (Member): Maybe it would be worth adding this to logging_gan_trainer.py as well?

+            context=ctx
+        )
         weighted_total.backward()
         self._forward_group.step()

@@ -133,12 +137,20 @@
             targets=targets
         )

-        _, logs = self._loss_group(train=False, context=ctx)
+        _, logs = self._loss_group(
+            train=False,
+            epoch=self.epoch,
+            context=ctx
+        )

         for _, metric in self.metrics.items():
             metric.update(*ctx.as_metric_args(), validation=True)

         return logs
+
+    @property
+    def loss_groups(self) -> Dict[str, LossGroup]:
+        return {'main': self._loss_group}

     def save_model(
         self,
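Both trainers now thread self.epoch through to their loss groups. If the per-batch scheduling raised in the review were adopted, the schedule itself could stay this simple; a hypothetical sketch, with names that do not exist in the codebase:

    from typing import Optional

    def linear_warmup(progress: Optional[int], target: float = 1.0, warmup: int = 10) -> float:
        # Ramp the weight linearly from 0 to `target` over the first `warmup` progress units.
        if progress is None:
            return target
        return target * min(1.0, (progress + 1) / warmup)

Whether progress counts epochs or batches then becomes purely the caller's choice, which is what makes the reviewer's suggested `progress` naming attractive.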
88 changes: 88 additions & 0 deletions src/virtual_stain_flow/vsf_logging/MlflowLogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def on_train_start(self):
         except Exception as e:
             print(f"Fail to log model config as artifact: {e}")

+        self._log_loss_groups_config_and_tags()
+
         for callback in self.callbacks:
             # TODO consider if we want hasattr checks
@@ -502,6 +504,92 @@ def _save_model_weights(
             artifact_path=artifact_path
         )
+    def _get_loss_groups(self) -> Dict[str, Any]:
+        """
+        Discover loss groups attached to the bound trainer.
+        """
Review comment (Member, on lines +508 to +510): Good minimal comment.

+        if self.trainer is None:
+            return {}
+
+        loss_groups: Dict[str, Any] = {}
+
+        explicit_groups = getattr(self.trainer, 'loss_groups', None)
+        if isinstance(explicit_groups, dict):
+            for group_name, group in explicit_groups.items():
+                if hasattr(group, 'get_config'):
+                    loss_groups[str(group_name)] = group
+
+        fallback_attrs = {
+            'main': '_loss_group',
+            'generator': '_generator_loss_group',
+            'discriminator': '_discriminator_loss_group'
+        }
+        for group_name, attr in fallback_attrs.items():
+            if group_name in loss_groups:
+                continue
+            group = getattr(self.trainer, attr, None)
+            if group is not None and hasattr(group, 'get_config'):
+                loss_groups[group_name] = group
+
+        return loss_groups
+
+    def _log_loss_groups_config_and_tags(self) -> None:
+        """
+        Log loss item names and weights as flat tags and full loss group
+        configuration as config artifacts.
+        """
Review comment (Member): Maybe this is mentioned somewhere else, but after reading this docstring I'm not sure what "flat tags" and "full loss group configuration" mean. Consider expanding on that here, or mentioning it somewhere else if you haven't already.

+        loss_groups = self._get_loss_groups()
+        if not loss_groups:
+            return None
+
+        for group_name, group in loss_groups.items():
+            try:
+                group_config = group.get_config()
+            except Exception as e:
+                print(
+                    f"Could not get loss group config for logging "
+                    f"({group_name}): {e}"
+                )
+                continue
+
+            if not isinstance(group_config, list):
+                continue
+
+            for idx, item_cfg in enumerate(group_config):
+                if not isinstance(item_cfg, dict):
+                    continue
+
+                if 'key' in item_cfg and item_cfg['key'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.name",
+                        str(item_cfg['key'])
Review comment (Member): Consider naming this to something more specific if possible.

+                    )
+
+                if 'weight' in item_cfg and item_cfg['weight'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.weight",
+                        str(item_cfg['weight'])
+                    )
+
+            try:
+                self.log_config(
+                    tag=f"loss_group_{group_name}",
+                    config={
+                        'group_name': group_name,
+                        'items': group_config
+                    },
+                    stage=None
+                )
+            except Exception as e:
+                print(
+                    f"Failed to log loss group config as artifact "
+                    f"({group_name}): {e}"
+                )
+
+        return None
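Concretely: for a hypothetical trainer exposing a single group "main" whose first loss item is keyed "mse" with weight 1.0, the method above issues calls equivalent to:

    import mlflow

    # Flat per-item tags: loss.<group>.<index>.<field>
    mlflow.set_tag("loss.main.0.name", "mse")
    mlflow.set_tag("loss.main.0.weight", "1.0")

    # Plus one artifact per group via log_config, tagged "loss_group_main",
    # containing {'group_name': 'main', 'items': [...get_config() dicts...]}.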

def log_config(
self,
tag: str,
Expand Down